In [20]:
import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer

# Load Dataset
df = pd.read_csv("movieDataset.csv")

# Define columns
TEXT_COLUMN = "description"
IMAGE_COLUMN = "poster_path"
NUMERICAL_COLUMNS = ["ratings", "box_office_collection"]
LABEL_COLUMN = "genre"

# Fail fast if the CSV lacks any column the rest of the notebook reads.
missing_columns = {TEXT_COLUMN, IMAGE_COLUMN, LABEL_COLUMN, *NUMERICAL_COLUMNS} - set(df.columns)
assert not missing_columns, f"movieDataset.csv is missing columns: {sorted(missing_columns)}"

# Map genres to numeric labels. Ids follow the order in which each genre first appears
# in the CSV; the label column is then replaced by those integer ids.
genre_mapping = {genre: idx for idx, genre in enumerate(df[LABEL_COLUMN].unique())}
df[LABEL_COLUMN] = df[LABEL_COLUMN].map(genre_mapping)
df

Unnamed: 0,movie_id,title,description,poster_path,ratings,release_year,box_office_collection,genre
0,1,Movie 1,"This is the description of Movie 1, which belo...",poster_1.jpg,1.17,2009,219.98,0
1,2,Movie 2,"This is the description of Movie 2, which belo...",poster_2.jpg,9.69,2012,335.18,1
2,3,Movie 3,"This is the description of Movie 3, which belo...",poster_3.jpg,9.94,2021,209.16,2
3,4,Movie 4,"This is the description of Movie 4, which belo...",poster_4.jpg,5.97,2010,422.6,3
4,5,Movie 5,"This is the description of Movie 5, which belo...",poster_5.jpg,9.83,2006,257.5,4
5,6,Movie 6,"This is the description of Movie 6, which belo...",poster_6.jpg,8.31,2022,489.87,4
6,7,Movie 7,"This is the description of Movie 7, which belo...",poster_7.jpg,6.88,2015,82.61,4
7,8,Movie 8,"This is the description of Movie 8, which belo...",poster_8.jpg,5.34,2002,376.46,3
8,9,Movie 9,"This is the description of Movie 9, which belo...",poster_9.jpg,9.28,2001,354.31,0
9,10,Movie 10,"This is the description of Movie 10, which bel...",poster_10.jpg,3.09,2006,299.9,3


In [21]:
# Directory that holds the poster images. Set the POSTER_DIR environment variable to
# point elsewhere instead of editing this machine-specific default. Joining with ""
# guarantees a trailing separator so plain string concatenation below is safe.
POSTER_DIR = os.path.join(os.environ.get("POSTER_DIR", "C:/Users/pavan/Downloads/posters/"), "")

# Prefix each bare poster filename with POSTER_DIR. Values that already start with it
# (e.g. when this cell is re-run) or are missing (na=True) are left unchanged, so the
# cell is idempotent instead of prepending the directory again on every run.
has_prefix = df[IMAGE_COLUMN].str.startswith(POSTER_DIR, na=True)
df[IMAGE_COLUMN] = df[IMAGE_COLUMN].where(has_prefix, POSTER_DIR + df[IMAGE_COLUMN])


In [22]:
# Spot-check the resolved poster path stored for row label 1.
df.loc[1, "poster_path"]

'C:/Users/pavan/Downloads/posters/poster_2.jpg'

In [25]:
# Preprocessing functions
# Tokenizer for the "bert-base-uncased" checkpoint; MultiModalModel uses the same name as
# its default text encoder, so keep the two names in sync if either one changes.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

def tokenize_text(text):
    """Encode ``text`` with the module-level ``tokenizer`` into PyTorch tensors.

    Padding and truncation are switched on. The returned encoding holds the tokenizer's
    tensors (``input_ids``, ``token_type_ids`` and ``attention_mask`` in the example run).
    """
    encode_kwargs = dict(return_tensors="pt", padding=True, truncation=True)
    return tokenizer(text, **encode_kwargs)

# Channel statistics of ImageNet. The image encoder in MultiModalModel loads
# ImageNet-pretrained ResNet weights, so inputs must be normalized the way those weights
# were trained; the previous mean/std of 0.5 fed the backbone off-distribution pixels.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed square size (does not preserve aspect ratio)
    transforms.ToTensor(),          # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch tensor.

    Args:
        image_path: Path of the image file to read.

    Returns:
        ``image_transform`` applied to the RGB-converted image, with a leading batch
        dimension added via ``unsqueeze(0)``. If the file does not exist, a message is
        printed and an all-zero ``(1, 3, 224, 224)`` tensor is returned instead so the
        caller can continue.
    """
    try:
        # The context manager closes the file handle deterministically instead of
        # leaving it to garbage collection; convert() yields an in-memory copy.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
    except FileNotFoundError:
        print(f"Image {image_path} not found.")
        return torch.zeros((1, 3, 224, 224))
    # Only the file access is guarded, so transform errors surface on their own
    # instead of being mixed up with a missing file.
    return image_transform(rgb_image).unsqueeze(0)

# Example preprocessing on the first row of the dataset
example_text = df[TEXT_COLUMN].iloc[0]
example_image = str(df[IMAGE_COLUMN].iloc[0])
# Selecting row 0 as a one-row frame (iloc[[0]]) keeps the result 2-D, i.e. a batch of one.
# Building the tensor from that single ndarray avoids torch's slow list-of-ndarrays path.
# NOTE(review): these features are raw (e.g. box office in the hundreds vs. ratings
# below 10 in the output below); consider standardizing them before feeding the model.
example_numerical = torch.tensor(
    df[NUMERICAL_COLUMNS].iloc[[0]].to_numpy(dtype="float32"), dtype=torch.float32
)

text_tokens = tokenize_text(example_text)
image_tensor = preprocess_image(example_image)
print("Text Tokens:", text_tokens)
print("Image Tensor Shape:", image_tensor.shape)
print("Numerical Data:", example_numerical)


Text Tokens: {'input_ids': tensor([[ 101, 2023, 2003, 1996, 6412, 1997, 3185, 1015, 1010, 2029, 7460, 2000,
         1996, 2895, 6907, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
Image Tensor Shape: torch.Size([1, 3, 224, 224])
Numerical Data: tensor([[  1.1700, 219.9800]])


In [6]:
class MultiModalModel(nn.Module):
    """Classifier that fuses text, image and numerical features.

    Each modality is encoded separately, the three feature vectors are concatenated,
    and a single linear layer maps the fused vector to ``num_classes`` logits.

    Args:
        text_model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
        image_model_name: Name of a model constructor in ``torchvision.models`` whose
            model ends in an ``fc`` Linear layer (ResNet family, e.g. ``"resnet18"``).
            That layer is replaced by an identity so the encoder returns pooled features.
        numerical_input_size: Number of numerical features per example.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``image_model_name`` is not a torchvision model or lacks ``fc``.
    """

    def __init__(self, text_model_name="bert-base-uncased", image_model_name="resnet18", numerical_input_size=2, num_classes=5):
        super().__init__()
        # Text Encoder; its hidden size determines the width of the text features.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        text_dim = self.text_model.config.hidden_size

        # Image Encoder: honor image_model_name (it used to be ignored, always building
        # resnet18) and load the default pretrained torchvision weights.
        image_factory = getattr(models, image_model_name, None)
        if image_factory is None:
            raise ValueError(f"Unknown torchvision model: {image_model_name!r}")
        self.image_model = image_factory(weights="DEFAULT")
        if not isinstance(getattr(self.image_model, "fc", None), nn.Linear):
            raise ValueError(f"{image_model_name!r} has no final `fc` Linear layer to remove")
        # Read the pooled feature width before dropping the classification head.
        image_dim = self.image_model.fc.in_features
        self.image_model.fc = nn.Identity()

        # Numerical Feature Encoder
        self.num_fc = nn.Linear(numerical_input_size, 128)

        # Combined Classifier over the concatenated text/image/numerical features
        self.classifier = nn.Linear(text_dim + image_dim + self.num_fc.out_features, num_classes)

    def forward(self, text_tokens, image_tensor, numerical_data):
        """Return class logits for a batch.

        Args:
            text_tokens: Mapping of tokenizer outputs, unpacked as keyword arguments
                into the text model.
            image_tensor: Image batch for the image encoder (the notebook passes
                ``(1, 3, 224, 224)``).
            numerical_data: 2-D float tensor ``(batch, numerical_input_size)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # Final hidden state of the first token ([CLS] for BERT) as the text embedding.
        text_features = self.text_model(**text_tokens).last_hidden_state[:, 0, :]
        image_features = self.image_model(image_tensor)
        numerical_features = self.num_fc(numerical_data)
        combined_features = torch.cat((text_features, image_features, numerical_features), dim=1)
        return self.classifier(combined_features)

# Instantiate Model with one output logit per distinct genre in genre_mapping
model = MultiModalModel(
    num_classes=len(genre_mapping),
)
print(model)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\pavan/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 117MB/s]

MultiModalModel(
  (text_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele




In [27]:
# Training step (example only): one optimization step on the single preprocessed example.
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-3 is high for fine-tuning pretrained BERT/ResNet weights; values
# around 2e-5 to 5e-5 are common for BERT fine-tuning -- confirm before real training.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Dummy inputs: the first row's integer label as a batch of one (int64 class ids)
labels = torch.tensor([df[LABEL_COLUMN].iloc[0]], dtype=torch.long)

model.train()
optimizer.zero_grad()
predictions = model(text_tokens, image_tensor, example_numerical)

# Compute loss and update the weights. Previously the optimizer was created but never
# stepped, so nothing was trained and the saved checkpoint held the initial weights.
loss = criterion(predictions, labels)
loss.backward()
optimizer.step()
print("Loss:", loss.item())

# Save Model
torch.save(model.state_dict(), "multi_modal_movie_model.pth")
print("Model saved successfully.")


Loss: 23.676387786865234
Model saved successfully.
