In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import BertModel


In [3]:
class MultiModalSentimentModel(nn.Module):
    """Late-fusion classifier over paired text and image inputs.

    A pretrained BERT encoder yields one pooled vector per text sample and a
    pretrained ResNet-50 (with its classification layer replaced by an
    identity) yields one feature vector per image. The two vectors are
    concatenated and fed through a two-layer MLP head that returns
    unnormalised class scores (logits).

    Args:
        num_classes: Size of the output logit vector.
        hidden_dim: Width of the hidden layer in the fusion head.
        dropout_p: Dropout probability applied after the hidden activation.
    """

    def __init__(self, num_classes=3, hidden_dim=512, dropout_p=0.3):
        super().__init__()

        # Text encoder (BERT)
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Read the width from the loaded config instead of hard-coding it
        # (768 for bert-base-uncased).
        text_dim = self.bert.config.hidden_size

        # Image encoder (ResNet-50)
        # `pretrained=True` is deprecated in torchvision; request the weights
        # explicitly. IMAGENET1K_V1 is the checkpoint the legacy flag loaded.
        self.resnet = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V1
        )
        # Capture the feature width before the classifier is dropped
        # (2048 for ResNet-50).
        image_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # remove classifier

        # Fusion + classifier (attribute names kept stable so existing
        # state_dict checkpoints still load).
        self.fc1 = nn.Linear(text_dim + image_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, images):
        """Compute class logits for a batch of (text, image) pairs.

        Args:
            input_ids: Token ids passed to the BERT encoder as `input_ids`.
            attention_mask: Mask passed to the BERT encoder as
                `attention_mask`.
            images: Image batch passed to the ResNet encoder
                (presumably (B, 3, H, W) — confirm against the data loader).

        Returns:
            Tensor of shape (B, num_classes) holding unnormalised scores.
        """
        text_features = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).pooler_output   # (B, 768) with the default encoder

        image_features = self.resnet(images)  # (B, 2048) with the default encoder

        # Concatenate along the feature axis so each row holds text + image.
        fused = torch.cat((text_features, image_features), dim=1)
        x = self.fc1(fused)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)


In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the model (this loads the pretrained encoders) and move it to the
# selected device.
model = MultiModalSentimentModel()
model = model.to(device)

# Print the full module tree for inspection.
print(model)




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/viaansharma/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████████████████████████████████| 97.8M/97.8M [00:33<00:00, 3.08MB/s]


MultiModalSentimentModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, 