In [134]:
import torch
import torch.nn as nn
from transformers import ViTFeatureExtractor, ViTModel, BertTokenizer, BertModel
from PIL import Image

from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import numpy as np
import os

In [135]:
# Seed torch before loading anything: the ViT checkpoint has no pooler weights, so
# `vit.pooler.dense.*` is randomly initialised at load time (see the warning this
# cell prints). Seeding makes that initialisation, and any layer created later in
# the notebook, repeatable across "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Hugging Face checkpoint names, kept in one place so they are easy to swap.
VIT_CHECKPOINT = 'google/vit-base-patch16-224'
BERT_CHECKPOINT = 'bert-base-uncased'

# Image side: preprocessing object + ViT encoder.
vit_extractor = ViTFeatureExtractor.from_pretrained(VIT_CHECKPOINT)
vit_model = ViTModel.from_pretrained(VIT_CHECKPOINT)

# Text side: tokenizer + BERT encoder.
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

class MultimodalModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled outputs of both encoders are concatenated and fed to a single
    linear layer that produces the class logits.

    NOTE: the linear head `fc` is created (randomly initialised) here; unless
    trained weights are loaded afterwards, its predictions carry no signal.

    Args:
        vit_model: image encoder; must expose `config.hidden_size` and return an
            object with a `pooler_output` attribute when called on the image.
        bert_model: text encoder; must expose `config.hidden_size` and return an
            object with a `pooler_output` attribute when called with the
            tokenizer output as keyword arguments.
        text_tokenizer: optional callable used to tokenize `text` in `forward`.
            When omitted, the module-level `tokenizer` is looked up at call
            time, which is the original behaviour.
        num_classes: number of output logits (defaults to 3, the original value).
    """

    def __init__(self, vit_model, bert_model, text_tokenizer=None, num_classes=3):
        super().__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        # Plain attribute (not a sub-module): holding it here removes the hidden
        # dependency on a notebook global when a tokenizer is supplied.
        self.text_tokenizer = text_tokenizer
        fused_size = vit_model.config.hidden_size + bert_model.config.hidden_size
        self.fc = nn.Linear(fused_size, num_classes)

    def forward(self, image, text):
        """Compute class logits for an image / text pair (or batch of pairs).

        Args:
            image: tensor passed to the ViT encoder (presumably the extractor's
                `pixel_values` - confirm against the caller).
            text: a string or list of strings handed to the tokenizer.

        Returns:
            The output of the linear head; its last dimension is `num_classes`.
        """
        # Fall back to the notebook-level tokenizer only when none was injected.
        active_tokenizer = self.text_tokenizer if self.text_tokenizer is not None else tokenizer

        # Run everything on the device the head lives on, so the model also works
        # after `.to('cuda')`: tokenizer output is always created on the CPU.
        device = self.fc.weight.device

        img_features = self.vit_model(image.to(device)).pooler_output

        text_inputs = active_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}
        text_features = self.bert_model(**text_inputs).pooler_output

        combined_features = torch.cat((img_features, text_features), dim=1)
        return self.fc(combined_features)

# Wrap the two pretrained encoders in the fusion classifier and put it in
# evaluation mode. `eval()` hands back the module, so the cell still ends by
# displaying the model structure.
model = MultimodalModel(vit_model=vit_model, bert_model=bert_model)
model.eval()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MultimodalModel(
  (vit_model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, 

In [136]:
# Identifier of the video to evaluate. It selects the dataset/<video>/ folder,
# the <video>.csv transcript file and the <video>_frames image folder used below.
video = 'WICC' # change here

In [137]:
# Read the per-video transcript table. The file's own header row is skipped and
# the columns are named explicitly.
# NOTE(review): the index printed below (0, 0, 0, 2, 2) is not a running row
# number, which suggests the CSV has one more column than the two names given
# and pandas is using the leading column as the index - confirm against the file.
csv_path = f'dataset/{video}/{video}.csv'
data = pd.read_csv(
    csv_path,
    names=['label', 'text'],
    header=None,
    skiprows=1,
)
print(data.head())

   label                                               text
0      0                            What is Climate Change?
0      1  goal 13 of the sustainable development goals c...
0      0  me Tom Tom is a college professor that teaches...
2      2  during one of his lectures one of his students...
2      2  during one of his lectures one of his students...


In [138]:
# Folder holding the frames of the selected video; evaluate_model looks up the
# "<video>-NNN.jpg" files inside it.
image_folder = f'dataset/{video}/{video}_frames' 

def preprocess_image(image_path):
    """Load one frame from disk and turn it into ViT model input.

    Args:
        image_path: path of the image file to load.

    Returns:
        The `pixel_values` entry produced by `vit_extractor` (PyTorch tensors
        are requested via `return_tensors="pt"`).
    """
    # `Image.open` is lazy and keeps the file open; the context manager closes
    # the handle deterministically instead of leaking one per evaluated frame.
    with Image.open(image_path) as frame:
        # `convert` loads the pixel data and returns a detached image, so it stays
        # usable after the file is closed. Forcing RGB also keeps grayscale / RGBA /
        # CMYK files compatible with a 3-channel model.
        rgb_frame = frame.convert('RGB')

    inputs = vit_extractor(images=rgb_frame, return_tensors="pt")
    return inputs['pixel_values']

In [139]:
def evaluate_model(data, image_folder, video, model):
    """Score `model` on every (frame, transcript) pair of one video.

    The n-th row of `data` (1-based) is paired with the frame file
    `<image_folder>/<video>-<n as 3 digits>.jpg`. Rows whose frame file is missing
    are reported on stdout and skipped.

    Args:
        data: DataFrame with a `label` column (ground-truth class) and a `text`
            column (transcript passed to the model).
        image_folder: directory containing the frame images.
        video: video identifier used as the frame filename prefix.
        model: callable taking `(image_inputs, transcript)` and returning logits
            with the classes along dimension 1.

    Returns:
        Tuple `(accuracy, f1)` where `f1` is the weighted-average F1 score.
    """
    true_labels = []
    predictions = []

    # Number frames by row position rather than by DataFrame index label: the
    # index is not guaranteed to be 0..n-1 (read_csv can promote a leading CSV
    # column to the index), and a non-sequential index would silently pair
    # transcripts with the wrong frames.
    rows = zip(data['label'], data['text'])
    for frame_number, (label, transcript) in enumerate(rows, start=1):
        image_path = os.path.join(image_folder, f"{video}-{frame_number:03d}.jpg")

        if not os.path.exists(image_path):
            print(f"Image {image_path} does not exist.")
            continue

        image_inputs = preprocess_image(image_path)

        # Inference only: no gradients needed.
        with torch.no_grad():
            logits = model(image_inputs, transcript)

        predictions.append(torch.argmax(logits, dim=1).item())
        true_labels.append(label)

    true_labels = np.array(true_labels)
    predictions = np.array(predictions)

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')

    return accuracy, f1

In [140]:
accuracy, f1 = evaluate_model(data, image_folder, video, model)

# Metrics are stored with two decimals.
accuracy = round(accuracy, 2)
f1 = round(f1, 2)

# One result row per evaluated video.
results = pd.DataFrame({
    'video': [video],
    'accuracy': [accuracy],
    'f1': [f1]
})

filename='results/ViTBERT_test_results.csv'

# Make sure the output folder exists so a fresh checkout does not fail here,
# after the expensive evaluation has already run.
output_dir = os.path.dirname(filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Always append; emit the header row only when the file is being created.
write_header = not os.path.exists(filename)
results.to_csv(filename, mode='a', header=write_header, index=False)
