In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
!pip install transformers --quiet

In [3]:
import torch
from torchvision import models
import torch.nn as nn
import os
import transformers
from transformers import AutoModel, BertTokenizerFast, get_linear_schedule_with_warmup

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [5]:
image_model = models.vgg16(progress=True, weights='IMAGENET1K_V1')
image_model.classifier[6] = nn.Linear(in_features=4096, out_features=4)
image_model = image_model.to(device)
image_model.load_state_dict(torch.load('/content/drive/MyDrive/FYDP/Combined/ImageModel/com_image_model.bin'))

<All keys matched successfully>

In [6]:
pre_trained_model_name = 'bert-base-uncased'
bert = AutoModel.from_pretrained(pre_trained_model_name, return_dict=False)

class GenreClassifier(nn.Module): #1 D CNN MOdel
    def __init__(self, bert):
        super(GenreClassifier, self).__init__()
        self.bert = bert
        self.conv = nn.Conv1d(in_channels=768, out_channels=256, kernel_size=5, padding='valid', stride=1)
        self.relu = nn.ReLU()
        self.clf1 = nn.Linear(256, 256)
        self.clf2 = nn.Linear(256, 4)

    # Define the forward pass
    def forward(self, sent_id, mask):
        cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)
        x = cls_hs[0]
        x = x.permute(0, 2, 1)
        x = self.conv(x)
        x = self.relu(x)

        # Apply adaptive max pooling along the sequence_length dimension
        x = nn.functional.adaptive_max_pool1d(x, 1)  # Pooling with kernel_size=1

        x = x.view(x.size(0), -1)  # Flatten the tensor along the spatial dimensions
        #x = self.dropout(x)
        x = self.clf1(x)
        #x = self.relu(x)
        return self.clf2(x)

In [7]:
text_model = GenreClassifier(bert)
text_model = text_model.to(device)
text_model.load_state_dict(torch.load('/content/drive/MyDrive/FYDP/Combined/TextModel/com_text_model.bin'))

<All keys matched successfully>

In [8]:
import pandas as pd
import numpy as np
import random

In [9]:
augment_image_folder = '/content/drive/MyDrive/FYDP/ImageData/augmented_resized'
test_csv_path = '/content/drive/MyDrive/FYDP/TextData/fl_client_1/data.csv'
genre_mapping = {
    0: 'Horror', 1: 'Comedy', 2: 'Romance', 3: 'Action'
}

def find_random_genre_image(genre):
  genre_path = os.path.join(augment_image_folder, genre_mapping.get(genre))
  files = os.listdir(genre_path)
  return os.path.join(genre_path, random.choice(files))

df = pd.read_csv(test_csv_path)
df = df.sample(n=32)
df['file'] = df['label'].apply(lambda x: find_random_genre_image(x))
df = df.reset_index()
df.head()

Unnamed: 0,index,description,label,file
0,65,A young girl finds solace in her artist father...,0,/content/drive/MyDrive/FYDP/ImageData/augmente...
1,363,As the world searches for a cure to a disastro...,0,/content/drive/MyDrive/FYDP/ImageData/augmente...
2,389,Xu Niannian and Yang Yi met at the most beauti...,2,/content/drive/MyDrive/FYDP/ImageData/augmente...
3,133,The Knight Jean de Carrouges must settle the d...,3,/content/drive/MyDrive/FYDP/ImageData/augmente...
4,267,Caught in the cross hairs of police corruption...,3,/content/drive/MyDrive/FYDP/ImageData/augmente...


In [10]:
tokenizer = BertTokenizerFast.from_pretrained(pre_trained_model_name)

In [11]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

In [16]:
from tqdm import tqdm

def force_three_channel(image):
    if image.mode == 'L':
        # If the image is single-channel (grayscale), convert it to RGB.
        image = image.convert('RGB')
    return image

class MovieData(Dataset):
  def __init__(self, content, targets, tokenizer, image):
        self.content = content
        self.targets = targets
        self.tokenizer = tokenizer
        self.image_file = image
        self.transform = transforms.Compose([
          transforms.Resize(244),
          transforms.Lambda(force_three_channel),
          transforms.ToTensor()
        ])

  def __len__(self):
    return len(self.content)

  def __getitem__(self, item):
      content = str(self.content[item])
      image_file = self.image_file[item]
      target = self.targets[item]

      image = Image.open(image_file)
      transformed_image = self.transform(image)

      encoding = self.tokenizer.encode_plus(
        content,
        max_length=512,
        add_special_tokens=True,
        return_token_type_ids=False,
        padding="max_length",
        truncation = True,
        return_attention_mask=True,
          return_tensors='pt'
      )


      return {
          'content_text':content,
          'input_ids':encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'image': transformed_image,
          'targets':torch.tensor(target, dtype=torch.long)
      }

def create_data_loader(df, tokenizer, batch_size, shuffle=True):
  ds = MovieData(
      content = df.description.to_numpy(),
      image = df.file,
      targets = df['label'].to_numpy(),
      tokenizer = tokenizer
  )

  return DataLoader(
      ds,
      batch_size = batch_size,
      num_workers = 2,
      shuffle = shuffle
  )
BATCH_SIZE = 4
test_dataloader = create_data_loader(df, tokenizer, BATCH_SIZE, shuffle=False)

def test(image_model, text_model, dataloader, n_examples):
  image_model.eval()
  text_model.eval()
  correct_predictions = 0
  for d in tqdm(dataloader):
    input_ids = d['input_ids'].to(device)
    attention_mask = d['attention_mask'].to(device)
    targets = d['targets'].to(device)
    image = d['image'].to(device)

    text_outputs = text_model(input_ids, attention_mask)
    image_outputs = image_model(image)

    outputs = (text_outputs + image_outputs) / 2
    _, preds = torch.max(outputs, dim=1)
    correct_predictions += torch.sum(preds == targets)
  return correct_predictions.double()/n_examples

acc = test(image_model, text_model, test_dataloader, len(df))
print()
print(f"Combined Model Accuracy Is: {100 * acc.item() :.2f}%")

100%|██████████| 8/8 [00:01<00:00,  4.72it/s]


Combined Model Accuracy Is: 93.75%



