# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel

class MultiLevelAttention(nn.Module):
    """Scaled dot-product attention followed by attention pooling over positions.

    Two stages are applied:

    1. Word level: scaled dot-product attention between ``query`` and ``key``,
       with a learned per-key bias ``tanh(key @ sentiment_weights)`` added to
       the scores before the softmax.
    2. Sentence level: a linear layer scores every attended position, the
       scores are softmax-normalised over the sequence axis, and the attended
       vectors are summed with those weights into one vector per sequence.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
    """

    def __init__(self, d_model):
        super().__init__()
        # Registered as a buffer so it follows ``.to()`` / ``.cuda()`` together
        # with the parameters; non-persistent so ``state_dict`` keys are the
        # same as when this was a plain attribute.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(d_model)])),
            persistent=False,
        )
        self.sentiment_weights = nn.Parameter(torch.zeros(d_model))
        self.sentence_attention = nn.Linear(d_model, 1)

    def forward(self, query, key, value, mask=None):
        """Attend over ``key``/``value`` and pool the result over positions.

        Args:
            query: Tensor of shape ``(..., q_len, d_model)``.
            key: Tensor of shape ``(..., k_len, d_model)``.
            value: Tensor of shape ``(..., k_len, d_model)``.
            mask: Optional tensor broadcastable to ``(..., q_len, k_len)``;
                positions equal to 0 (or ``False``) are excluded from the
                word-level softmax. Extra singleton dimensions at index 1
                (e.g. a ``(batch, 1, 1, k_len)`` padding mask) are dropped.

        Returns:
            Tensor of shape ``(..., d_model)``: the attention-weighted sum of
            the attended vectors over the ``q_len`` axis.
        """
        # Word-level attention
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale

        # One bias value per key position, shape (..., k_len). Inserting the
        # axis at -2 broadcasts it across every query row.
        sentiment_bias = torch.tanh(torch.matmul(key, self.sentiment_weights))
        attention_scores = attention_scores + sentiment_bias.unsqueeze(-2)

        if mask is not None:
            # Collapse surplus singleton head-style axes so the mask cannot
            # silently broadcast the scores up to a higher rank.
            while mask.dim() > attention_scores.dim() and mask.size(1) == 1:
                mask = mask.squeeze(1)
            # Use the dtype's own minimum so the fill value stays
            # representable in reduced-precision dtypes.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention_probs = F.softmax(attention_scores, dim=-1)
        word_attention_output = torch.matmul(attention_probs, value)

        # Sentence-level attention: normalise over the sequence axis (-2) and
        # collapse it with a weighted sum.
        sentence_attention_weights = F.softmax(
            self.sentence_attention(word_attention_output), dim=-2
        )
        sentence_attention_output = torch.sum(
            sentence_attention_weights * word_attention_output, dim=-2
        )
        return sentence_attention_output

class TransformerBlock(nn.Module):
    """Attention + feed-forward block with post-layer-norm residuals.

    Args:
        d_model: Feature size of the input and output.
        nhead: Accepted for interface compatibility; not used by this block
            (the attention module is built from ``d_model`` only).
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability applied inside the feed-forward
            sub-layer and on both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1):
        super().__init__()
        self.attention = MultiLevelAttention(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention and the feed-forward network to ``x``.

        Args:
            x: Input tensor whose last dimension is ``d_model``; the
                second-to-last dimension is treated as the sequence axis.
            mask: Optional attention mask forwarded unchanged to the
                attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_output = self.attention(x, x, x, mask)
        # The attention module may return one pooled vector per sequence
        # (one fewer dimension than ``x``). Re-insert the sequence axis so the
        # vector is added to every position instead of being broadcast
        # against the wrong axis.
        if attn_output.dim() == x.dim() - 1:
            attn_output = attn_output.unsqueeze(-2)
        x = self.norm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x

class SentimentAnalysisTransformer(nn.Module):
    """Sentiment classifier built on BERT features with an auxiliary emotion head.

    Raw text is tokenized and encoded by ``bert-base-uncased`` (run under
    ``torch.no_grad()``, so no gradients reach BERT), projected to
    ``d_model``, passed through a stack of ``TransformerBlock`` layers,
    pooled over the sequence, and fed to two linear heads.

    Args:
        num_layers: Number of ``TransformerBlock`` layers.
        d_model: Feature size used by the transformer stack and the heads.
        nhead: Forwarded to each ``TransformerBlock``.
        dim_feedforward: Feed-forward hidden width of each block.
        num_classes: Number of sentiment classes.
        dropout: Dropout probability for the blocks and the sentiment head.
        num_emotions: Number of classes of the auxiliary emotion head.
    """

    def __init__(self, num_layers, d_model, nhead, dim_feedforward, num_classes,
                 dropout=0.1, num_emotions=5):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Project BERT's hidden size (read from its config rather than
        # hard-coded) to the desired d_model.
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, d_model)

        # Transformer layers
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, nhead, dim_feedforward, dropout) for _ in range(num_layers)]
        )

        # Multi-task learning layers
        self.emotion_classifier = nn.Linear(d_model, num_emotions)  # Auxiliary emotion detection task
        self.sentiment_classifier = nn.Linear(d_model, num_classes)

        # Per-layer linear projections of the block input, added to the block
        # output as a learned skip path.
        self.residual_connection = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_text):
        """Classify raw text.

        Args:
            input_text: A string or list of strings, passed directly to the
                BERT tokenizer.

        Returns:
            Tuple ``(sentiment_output, emotion_output)`` of logits produced by
            the sentiment and emotion heads respectively.
        """
        # Tokenization. The tokenizer creates CPU tensors, so move them to
        # wherever the model's parameters live.
        encoded_input = self.tokenizer(input_text, padding=True, truncation=True, return_tensors='pt')
        device = self.embedding_layer.weight.device
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)

        # BERT embeddings (no gradients flow into BERT)
        with torch.no_grad():
            bert_output = self.bert(input_ids, attention_mask=attention_mask)
        x = self.embedding_layer(bert_output.last_hidden_state)

        # Key-padding mask shaped (batch, 1, seq) so it broadcasts over the
        # query axis of the (batch, seq, seq) attention scores.
        mask = (attention_mask != 0).unsqueeze(1)
        for i, transformer in enumerate(self.transformer_blocks):
            x = transformer(x, mask) + self.residual_connection[i](x)

        # Masked mean pooling: average real tokens only, so padding added to
        # align a batch does not change a sequence's representation.
        token_weights = attention_mask.unsqueeze(-1).to(x.dtype)
        x = (x * token_weights).sum(dim=1) / token_weights.sum(dim=1).clamp(min=1.0)

        # Auxiliary emotion task
        emotion_output = self.emotion_classifier(x)

        # Main sentiment classification
        x = self.dropout(x)
        sentiment_output = self.sentiment_classifier(x)

        return sentiment_output, emotion_output