In [None]:
# pip install torch torchvision transformers


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from PIL import Image
import os

# Define the dataset
class MultimodalDataset(Dataset):
    """Pair on-disk images with captions and class labels.

    Sample ``idx`` is built from three sources:

    * the image file ``<image_dir>/<idx>.jpg``,
    * ``captions[idx]``,
    * ``labels[idx]``.

    Images must therefore be named by their zero-based position in
    ``captions`` / ``labels``.

    Args:
        image_dir: Directory containing the ``<idx>.jpg`` image files.
        captions: Caption strings, one per sample.
        labels: Class labels, one per sample.
        transform: Optional callable applied to the loaded RGB PIL image
            (e.g. a ``torchvision.transforms.Compose``).
        max_length: Token length every caption is padded/truncated to.

    Raises:
        ValueError: If ``captions`` and ``labels`` differ in length.
    """

    def __init__(self, image_dir, captions, labels, transform=None, max_length=50):
        if len(captions) != len(labels):
            # __len__ is driven by labels, so a shorter captions list would
            # otherwise only surface later as an IndexError inside __getitem__.
            raise ValueError(
                f"captions and labels must have the same length "
                f"(got {len(captions)} and {len(labels)})"
            )
        self.image_dir = image_dir
        self.captions = captions
        self.labels = labels
        self.transform = transform
        self.max_length = max_length
        # Loaded once per dataset instance rather than once per sample.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask, label)`` for sample ``idx``.

        ``input_ids`` and ``attention_mask`` are 1-D tensors of length
        ``max_length``. ``image`` is whatever ``transform`` produces, or an RGB
        PIL image when no transform is set.
        """
        img_path = os.path.join(self.image_dir, f'{idx}.jpg')
        # Close the file handle deterministically. convert() returns an
        # independent in-memory image, so it stays valid after the with-block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        encoding = self.tokenizer(
            self.captions[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields a leading batch dimension of 1. Flatten it
        # so the DataLoader can stack samples into (batch, max_length).
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        return image, input_ids, attention_mask, self.labels[idx]

# Define the model
class MultimodalModel(nn.Module):
    """Late-fusion classifier over ResNet-50 image features and BERT text features.

    The two branches are:

    * Image: a pretrained ResNet-50 with its classification head replaced by
      ``nn.Identity``, so it outputs pooled features.
    * Text: pretrained ``bert-base-uncased``, contributing its ``pooler_output``.

    The two feature vectors are concatenated and passed through a single
    linear layer.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Image model (ResNet-50).
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
        # favour of `weights=...`; kept here for compatibility with older versions.
        self.resnet = models.resnet50(pretrained=True)
        # Read the head's input width *before* replacing it: nn.Identity has
        # no `in_features` attribute, so reading it afterwards raises.
        img_feature_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # Remove the classification layer

        # Text model (BERT)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        text_feature_dim = self.bert.config.hidden_size

        # Combined classifier over the concatenated image + text features
        self.fc = nn.Linear(img_feature_dim + text_feature_dim, num_classes)

    def forward(self, image, input_ids, attention_mask):
        """Return classification logits of shape (batch, num_classes).

        Args:
            image: Image batch for ResNet-50. Presumably (batch, 3, H, W) float
                tensors; TODO confirm against the transform in use.
            input_ids: BERT token ids, (batch, seq_len).
            attention_mask: BERT attention mask, (batch, seq_len).
        """
        # Image branch: pooled ResNet features (classification head removed)
        img_features = self.resnet(image)

        # Text branch: BERT pooled representation
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.pooler_output

        # Concatenate along the feature dimension and classify
        combined_features = torch.cat((img_features, text_features), dim=1)
        return self.fc(combined_features)

# Hyperparameters
batch_size = 16
num_classes = 10
learning_rate = 1e-4
num_epochs = 5

# Data preparation
# The pretrained ResNet-50 weights were trained on ImageNet-normalized inputs.
# Without Normalize, the backbone sees pixel values on a different scale than
# it was trained on, which degrades the pretrained features.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Dummy data (replace with real data loading)
image_dir = 'path/to/images'
captions = ["A caption describing the image"] * 1000  # Dummy captions
labels = [0] * 1000  # Dummy labels

dataset = MultimodalDataset(image_dir, captions, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, loss, and optimizer
# Use a GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultimodalModel(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    # `batch_labels` (not `labels`) avoids shadowing the module-level labels list.
    for images, input_ids, attention_masks, batch_labels in dataloader:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, input_ids, attention_masks)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Otherwise `loss` would be unbound here and the report meaningless.
        raise RuntimeError("dataloader yielded no batches; check the dataset")
    # Report the mean loss over the epoch rather than the last batch's loss.
    avg_loss = running_loss / num_batches
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")

print("Training complete.")
