In [1]:
import os, math, torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Iterator, Tuple
from transformers import AutoTokenizer, AutoModel
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from pprint import pprint

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
class Config:
    input_dim = 30
    features_hidden_1_dim =1024
    news_hidden_1_dim = 1024
    common_hidden_1_dim = 1024
    common_hidden_2_dim = 1024
    model_dim = 512
    num_heads = 8
    num_layers = 8
    output_dim = 3
    max_seq_length = 512
    news_embed_size = 768
    dropout = 0.2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bert_model = "ProsusAI/finbert"

print(Config.device)

cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
class NewsWrapper:
    tokenizer: AutoTokenizer
    news_batch: List[List[List[str]]]
    flattened: List[str]
    tokenized: BatchEncoding

    def __init__(
        self, 
        tokenizer: AutoTokenizer,
        news_batch: List[List[List[str]]]
    ):
        self.tokenizer = tokenizer
        self.news_batch = news_batch
        self._flatten()
        self._tokenize()

    def get_encoding(
        self, bert_output: BaseModelOutputWithPoolingAndCrossAttentions
    ) -> torch.Tensor:
        cls_tokens = bert_output.last_hidden_state[:, 0, :]

        index = 0
        unflattened = []

        for sequence in self.news_batch:
            unflattened_sequence = []

            for step in sequence:
                step_length = len(step)
                unflattened_sequence.append(
                    self._combine_step_news(cls_tokens[index:index+step_length])
                )
                index += step_length

            unflattened.append(torch.stack(unflattened_sequence))

        return torch.stack(unflattened)

    def _flatten(self):
        self.flattened = []
        for sequence in self.news_batch:
            for step in sequence:
                for sentence in step:
                    self.flattened.append(sentence)

    def _tokenize(self):
        self.tokenized = self.tokenizer(
            self.flattened, padding=True, truncation=True, return_tensors='pt'
        )

    @staticmethod
    def _combine_step_news(vectors: torch.Tensor) -> torch.Tensor:
        # Compute the row-wise average
        return vectors.mean(dim=0)


In [4]:
class TimeSeriesTransformer(nn.Module):
    config: Config

    def __init__(self, config: Config):
        super(TimeSeriesTransformer, self).__init__()
        self.config = config

        self.features_hidden_1 = nn.Linear(config.input_dim, config.features_hidden_1_dim)
        self.features_hidden_2 = nn.Linear(config.features_hidden_1_dim, config.common_hidden_1_dim)
        
        self.news_bert = AutoModel.from_pretrained(config.bert_model)
        self.news_hidden_1 = nn.Linear(config.news_embed_size, config.news_hidden_1_dim)
        self.news_hidden_2 = nn.Linear(config.news_hidden_1_dim, config.common_hidden_1_dim)
        
        self.common_hidden_1 = nn.Linear(config.common_hidden_1_dim, config.common_hidden_2_dim)
        self.common_hidden_2 = nn.Linear(config.common_hidden_2_dim, config.model_dim)

        self.positional_encoding = self._create_positional_encoding(config.max_seq_length, config.model_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=config.model_dim, nhead=config.num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=config.num_layers)
        self.final_projection = nn.Linear(config.model_dim, config.output_dim)

        self.dropout = nn.Dropout(config.dropout)
        self.activation_gelu = nn.GELU()

        self.to(config.device)
 
    def forward(
        self,
        features: torch.Tensor, # (batch_size, sequence_length, input_dim)
        news: NewsWrapper,  # (batch_size, sequence_length, bert_embed_size)
    ) -> Tuple[
        torch.Tensor, # (batch_size, sequence_length, model_dim)
        torch.Tensor, # (batch_size, sequence_length, model_dim)
        torch.Tensor, # (batch_size, sequence_length, output_dim)
    ]:
        features = features.to(self.config.device)

        news_bert = self.news_bert(**{
            key: value.to(self.config.device) 
            for key, value in news.tokenized.items()
        })

        features = self.features_hidden_1(features)
        features = self.activation_gelu(features)
        features = self.dropout(features)
        features = self.features_hidden_2(features)

        news_hidden = self.news_hidden_1(news.get_encoding(news_bert))
        news_hidden = self.activation_gelu(news_hidden)
        news_hidden = self.dropout(news_hidden)
        news_hidden = self.news_hidden_2(news_hidden)

        common = features + news_hidden
        common = self.activation_gelu(common)
        common = self.dropout(common)
        common = self.common_hidden_1(common)
        common = self.activation_gelu(common)
        common = self.dropout(common)
        model = self.common_hidden_2(common)

        return self.apply_transformer(model)
    
    def apply_transformer(
        self, 
        common: torch.Tensor # (batch_size, sequence_length, model_dim)
    ) -> Tuple[
        torch.Tensor, # (batch_size, sequence_length, model_dim)
        torch.Tensor, # (batch_size, sequence_length, model_dim)
        torch.Tensor, # (batch_size, sequence_length, output_dim)
    ]:
        common = common.to(self.config.device)
        transformer_input = common + self.positional_encoding[:, :common.size(1), :]
        transformer_output = self.transformer_encoder(transformer_input)
        
        return transformer_input, transformer_output, self.final_projection(transformer_output)
    
    def predict(
        self, 
        features: torch.Tensor, # (batch_size, sequence_length, input_dim)
        news_wrapper: NewsWrapper, # (batch_size, sequence_length, bert_embed_size)
        future_steps: int
    ) -> torch.Tensor: # (batch_size, future_steps, output_dim)
        self.eval()
        predictions = []
        
        # use the full model for the initial prediction
        tr_inputs, tr_outputs, step_predictions = self.forward(
            features, news_wrapper
        )

        # get the prediction for the last token
        predictions.append(step_predictions[:, -1, :])
        
        with torch.no_grad():
            for _ in range(future_steps - 1):
                tr_inputs = torch.cat((
                    tr_inputs[:, 1:, :], 
                    tr_outputs[:, -1:, :]
                ), dim=1)

                # use the apply_transformer method for subsequent predictions
                tr_inputs, tr_outputs, step_predictions = self.apply_transformer(
                    tr_inputs
                )

                # get the prediction for the last token
                predictions.append(step_predictions[:, -1, :])

        # return predictions
        return torch.stack(predictions, dim=1)
    
    def freeze_bert(self, freeze: bool = True):
        for param in self.news_bert.parameters():
            param.requires_grad = not freeze

    def get_params_count(self, trainable: bool = False) -> int:
        return sum(param.numel() for param in self.parameters() if ((not trainable) or param.requires_grad))
    
    def _create_positional_encoding(
        self, max_seq_length: int, model_dim: int
    ) -> torch.Tensor:
        pe = torch.zeros(max_seq_length, model_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Add batch dimension

        return nn.Parameter(pe, requires_grad=False).to(self.config.device)

In [None]:
config = Config()

tokenizer = AutoTokenizer.from_pretrained(config.bert_model)

news_wrapper = NewsWrapper(tokenizer, [
    [
        [
            "I am a sentence", 
            "I am another sentence", 
            "Hello world", 
            "Just a quite longer sentence to test the padding",
        ],
        [
            "So long and thanks for all the fish", 
        ],
        [
            "I am another sentence", 
            "Hello world", 
        ]
    ],
    [
        [
            "Hello world", 
            "Just a quite longer sentence to test the padding",
        ],
        [
            "So long and thanks for all the fish",
            "Hello world", 
        ],
        [
            "I am another sentence", 
            "Hello world", 
            "Just a quite longer sentence to test the padding",
        ]
    ]
])

input_features = torch.rand(2, 3, config.input_dim).to(config.device)

self = TimeSeriesTransformer(config)
# raw_input, raw_output, output = self(input_features, news_wrapper)


self.freeze_bert(True)
print(self.get_params_count(True) / 1e6, "Million parameters")

predictions = self.predict(input_features, news_wrapper, 10)

print(predictions.shape)
pprint(predictions)

# print(raw_input.shape)
# print(raw_output.shape)
# print(output.shape)

# print(input_features)

# outputs = bertModel(**news_wrapper.tokenized)
# news_encoding = news_wrapper.get_encoding(outputs)

# print(news_encoding)

32.865795 Million parameters
torch.Size([2, 10, 3])
tensor([[[ 0.0583, -0.5177,  0.1569],
         [-0.1305, -0.1571, -0.1989],
         [-0.2078,  0.0415, -0.3686],
         [-0.1813,  0.1612, -0.4192],
         [-0.1127,  0.1667, -0.4001],
         [-0.0730,  0.1627, -0.3799],
         [-0.0423,  0.1504, -0.3645],
         [-0.0171,  0.1410, -0.3506],
         [ 0.0064,  0.1316, -0.3367],
         [ 0.0269,  0.1226, -0.3201]],

        [[ 0.0522, -0.5234,  0.1530],
         [-0.1350, -0.1569, -0.2025],
         [-0.2113,  0.0438, -0.3722],
         [-0.1835,  0.1638, -0.4229],
         [-0.1143,  0.1689, -0.4036],
         [-0.0742,  0.1644, -0.3824],
         [-0.0431,  0.1516, -0.3663],
         [-0.0176,  0.1418, -0.3514],
         [ 0.0060,  0.1320, -0.3371],
         [ 0.0267,  0.1227, -0.3202]]], grad_fn=<StackBackward0>)
