In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from typing import List, Dict, Tuple

class DualBertCrossAttention(nn.Module):
    def __init__(self, model_name='bert-base-uncased', hidden_size=768, freeze_input_bert=True):
        super().__init__()
        
        # Dual BERT setup
        self.input_bert = BertModel.from_pretrained(model_name)
        self.class_bert = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        
        # Freezing strategy
        if freeze_input_bert:
            for param in self.input_bert.parameters():
                param.requires_grad = False
        
        self.hidden_size = hidden_size
        
        # Cross-attention layers
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            batch_first=True,
            dropout=0.1
        )
        
        # Classification head
        self.classifier = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(0.1)
        
    def tokenize_inputs(self, texts: List[str], max_length=128) -> Dict[str, torch.Tensor]:
        """Tokenize input texts with padding and attention masks."""
        encoding = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        return encoding
    
    def tokenize_classes(self, all_classes: List[List[str]], max_class_length=32) -> Tuple[torch.Tensor, torch.Tensor, List[List[Tuple[int, int]]]]:
        """
        Tokenize class labels with padding and track boundaries.
        
        Args:
            all_classes: List of class lists per sample [["very good", "bad"], ["positive", "negative", "neutral"]]
            max_class_length: Max tokens per individual class
            
        Returns:
            class_input_ids: [batch_size, max_total_class_tokens]
            class_attention_mask: [batch_size, max_total_class_tokens] 
            class_boundaries: [batch_size][class_idx] = (start, end) positions
        """
        batch_size = len(all_classes)
        batch_class_tokens = []
        batch_boundaries = []
        
        for sample_classes in all_classes:
            sample_tokens = []
            sample_boundaries = []
            current_pos = 0
            
            # Tokenize each class in this sample
            for class_text in sample_classes:
                # Tokenize single class (remove [CLS], [SEP])
                tokens = self.tokenizer(
                    class_text,
                    add_special_tokens=False,
                    max_length=max_class_length,
                    truncation=True
                )['input_ids']
                
                # Record boundaries
                start_pos = current_pos
                end_pos = current_pos + len(tokens)
                sample_boundaries.append((start_pos, end_pos))
                
                sample_tokens.extend(tokens)
                current_pos = end_pos
            
            batch_class_tokens.append(sample_tokens)
            batch_boundaries.append(sample_boundaries)
        
        # Pad all samples to same length
        max_total_tokens = max(len(tokens) for tokens in batch_class_tokens)
        
        padded_input_ids = []
        attention_masks = []
        
        for tokens in batch_class_tokens:
            # Pad tokens
            pad_length = max_total_tokens - len(tokens)
            padded_tokens = tokens + [self.tokenizer.pad_token_id] * pad_length
            
            # Create attention mask (1 for real tokens, 0 for padding)
            attention_mask = [1] * len(tokens) + [0] * pad_length
            
            padded_input_ids.append(padded_tokens)
            attention_masks.append(attention_mask)
        
        return (
            torch.tensor(padded_input_ids),
            torch.tensor(attention_masks),
            batch_boundaries
        )
    
    def forward(self, input_texts: List[str], class_lists: List[List[str]]) -> torch.Tensor:
        """
        Forward pass with cross-attention between input and class tokens.
        
        Args:
            input_texts: ["how are you", "i am good"] 
            class_lists: [["positive", "negative"], ["good", "bad", "neutral"]]
            
        Returns:
            logits: [batch_size, max_classes] - padded with -inf for missing classes
        """
        batch_size = len(input_texts)
        device = next(self.parameters()).device
        
        # 1. Tokenize inputs
        input_encoding = self.tokenize_inputs(input_texts)
        input_ids = input_encoding['input_ids'].to(device)
        input_attention_mask = input_encoding['attention_mask'].to(device)
        
        # 2. Tokenize classes  
        class_ids, class_attention_mask, class_boundaries = self.tokenize_classes(class_lists)
        class_ids = class_ids.to(device)
        class_attention_mask = class_attention_mask.to(device)
        
        # 3. Get BERT embeddings
        with torch.no_grad() if self.input_bert.training == False else torch.enable_grad():
            input_outputs = self.input_bert(input_ids, attention_mask=input_attention_mask)
            input_embeddings = input_outputs.last_hidden_state  # [batch_size, input_seq_len, hidden_size]
        
        class_outputs = self.class_bert(class_ids, attention_mask=class_attention_mask)
        class_embeddings = class_outputs.last_hidden_state  # [batch_size, class_seq_len, hidden_size]
        
        # 4. Global Cross-Attention: Each token type attends to the other
        # Input tokens attend to class tokens
        input_enhanced, _ = self.cross_attention(
            query=input_embeddings,
            key=class_embeddings, 
            value=class_embeddings,
            key_padding_mask=~class_attention_mask.bool()  # True for padding positions
        )
        
        # Class tokens attend to input tokens  
        class_enhanced, _ = self.cross_attention(
            query=class_embeddings,
            key=input_embeddings,
            value=input_embeddings, 
            key_padding_mask=~input_attention_mask.bool()
        )
        
        # 5. Pool class tokens back to class representations
        batch_logits = []
        max_classes = max(len(classes) for classes in class_lists)
        
        for batch_idx in range(batch_size):
            sample_class_reps = []
            boundaries = class_boundaries[batch_idx]
            
            # Extract representation for each class
            for start, end in boundaries:
                # Mean pool tokens belonging to this class
                class_tokens = class_enhanced[batch_idx, start:end, :]  # [class_len, hidden_size]
                class_rep = class_tokens.mean(dim=0)  # [hidden_size]
                sample_class_reps.append(class_rep)
            
            # Convert to logits
            sample_class_reps = torch.stack(sample_class_reps)  # [num_classes, hidden_size]
            sample_logits = self.classifier(self.dropout(sample_class_reps)).squeeze(-1)  # [num_classes]
            
            # Pad to max_classes for batch consistency
            if len(sample_logits) < max_classes:
                padding = torch.full((max_classes - len(sample_logits),), float('-inf'), device=device)
                sample_logits = torch.cat([sample_logits, padding])
            
            batch_logits.append(sample_logits)
        
        return torch.stack(batch_logits)  # [batch_size, max_classes]


    def inference(self, input_texts: List[str], class_lists: List[List[str]]) -> List[Dict]:
        """
        Inference function that returns predictions with probabilities.
        
        Args:
            input_texts: List of input texts to classify
            class_lists: List of class options for each input
            
        Returns:
            List of prediction dictionaries with probabilities and predicted class
        """
        self.eval()  # Set model to evaluation mode
        
        with torch.no_grad():
            # Get logits from forward pass
            logits = self.forward(input_texts, class_lists)
            
            predictions = []
            
            for i, (sample_logits, classes) in enumerate(zip(logits, class_lists)):
                num_classes = len(classes)
                
                # Extract valid logits (remove padding)
                valid_logits = sample_logits[:num_classes]
                
                # Convert to probabilities
                probabilities = F.softmax(valid_logits, dim=0)
                
                # Get prediction
                pred_idx = torch.argmax(valid_logits).item()
                pred_class = classes[pred_idx]
                pred_prob = probabilities[pred_idx].item()
                
                # Create result dictionary
                result = {
                    'input_text': input_texts[i],
                    'predicted_class': pred_class,
                    'predicted_index': pred_idx,
                    'confidence': pred_prob,
                    'all_probabilities': {
                        class_name: prob.item() 
                        for class_name, prob in zip(classes, probabilities)
                    },
                    'class_options': classes
                }
                
                predictions.append(result)
        
        return predictions




In [37]:
 model = DualBertCrossAttention(freeze_input_bert=True,hidden_size=128)
 model

DualBertCrossAttention(
  (input_bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-

In [38]:
# Count parameters
total_params = sum(p.numel() for p in model.parameters())

# Convert to M and B
params_in_m = total_params / 1e6
params_in_b = total_params / 1e9

print(f"Total parameters: {total_params:,}")
print(f"In Millions (M): {params_in_m:.2f}M")
print(f"In Billions (B): {params_in_b:.3f}B")

Total parameters: 219,030,657
In Millions (M): 219.03M
In Billions (B): 0.219B


In [39]:
# Usage Example
def train_example():
    # Sample data
    input_texts = [
        "how are you doing today",
        "i am feeling great", 
        "this movie was okay"
    ]
    
    class_lists = [
        ["very positive", "quite negative"],  # 2 classes
        ["good", "bad", "neutral"],           # 3 classes  
        ["positive", "negative"]              # 2 classes
    ]
    
    # Ground truth labels (indices within each sample's classes)
    targets = [1, 0, 0]  # "quite negative", "good", "positive"
    
    # Initialize model
    model = DualBertCrossAttention(freeze_input_bert=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    # Forward pass
    logits = model(input_texts, class_lists)
    print(f"Logits shape: {logits.shape}")  # [3, 3] (padded to max_classes=3)
    
    # Compute loss (handle variable class numbers)
    total_loss = 0
    for i, (sample_logits, target_idx) in enumerate(zip(logits, targets)):
        num_classes = len(class_lists[i])
        valid_logits = sample_logits[:num_classes]  # Remove padded positions
        loss = F.cross_entropy(valid_logits.unsqueeze(0), torch.tensor([target_idx]))
        total_loss += loss
    
    avg_loss = total_loss / len(targets)
    print(f"Average loss: {avg_loss.item():.4f}")
    
    # Backward pass
    avg_loss.backward()
    optimizer.step()
    
    # Predictions
    predictions = []
    for i, sample_logits in enumerate(logits):
        num_classes = len(class_lists[i])
        valid_logits = sample_logits[:num_classes]
        pred_idx = torch.argmax(valid_logits).item()
        pred_class = class_lists[i][pred_idx]
        predictions.append((pred_idx, pred_class))
        print(f"Sample {i}: Predicted '{pred_class}' (index {pred_idx})")

if __name__ == "__main__":
    train_example()

Logits shape: torch.Size([3, 3])
Average loss: 0.8419
Sample 0: Predicted 'quite negative' (index 1)
Sample 1: Predicted 'neutral' (index 2)
Sample 2: Predicted 'negative' (index 1)


In [41]:

# Usage example:
def run_inference():
    # Load trained model
    model = DualBertCrossAttention(freeze_input_bert=True)
    # model.load_state_dict(torch.load('trained_model.pth'))  # Load trained weights
    
    # Sample inputs for inference
    test_inputs = [
        "I absolutely love this product!",
        "The weather is okay today", 
        "This service was terrible"
    ]
    
    test_classes = [
        ["very positive", "positive", "neutral", "negative", "very negative"],
        ["excellent", "good", "average", "poor"],
        ["satisfied", "dissatisfied", "neutral"]
    ]
    
    # Run inference
    results = model.inference(test_inputs, test_classes)
    
    # Print results
    for i, result in enumerate(results):
        print(f"\n--- Sample {i+1} ---")
        print(f"Input: '{result['input_text']}'")
        print(f"Predicted: '{result['predicted_class']}' (confidence: {result['confidence']:.3f})")
        print(f"All probabilities:")
        for class_name, prob in result['all_probabilities'].items():
            print(f"  {class_name}: {prob:.3f}")

run_inference()


--- Sample 1 ---
Input: 'I absolutely love this product!'
Predicted: 'negative' (confidence: 0.200)
All probabilities:
  very positive: 0.200
  positive: 0.200
  neutral: 0.200
  negative: 0.200
  very negative: 0.200

--- Sample 2 ---
Input: 'The weather is okay today'
Predicted: 'good' (confidence: 0.250)
All probabilities:
  excellent: 0.249
  good: 0.250
  average: 0.250
  poor: 0.250

--- Sample 3 ---
Input: 'This service was terrible'
Predicted: 'dissatisfied' (confidence: 0.334)
All probabilities:
  satisfied: 0.333
  dissatisfied: 0.334
  neutral: 0.333
