In [116]:
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image

from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import numpy as np
import os

In [117]:
# Load the CLIP model and its processor from the same pretrained checkpoint,
# naming the checkpoint once so the two calls cannot drift apart.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

class MultimodalModel(nn.Module):
    """CLIP image/text embeddings concatenated and fed to a 3-way linear head.

    Args:
        clip_model: Model exposing ``config.projection_dim`` and returning an
            object with ``image_embeds`` and ``text_embeds`` when called with
            the processor's outputs as keyword arguments.
        processor: Optional callable used to turn raw ``text``/``images`` into
            model inputs. When omitted, the notebook-level ``clip_processor``
            is looked up at call time (the original behaviour).
    """

    def __init__(self, clip_model, processor=None):
        super().__init__()
        self.clip_model = clip_model
        self.processor = processor
        # Image and text embeddings are concatenated, hence twice the
        # projection size as input; the head outputs 3 class logits.
        self.fc = nn.Linear(clip_model.config.projection_dim * 2, 3)

    def forward(self, image, text):
        """Return class logits for the given image(s) and text(s).

        Args:
            image: Image input forwarded to the processor as ``images``.
            text: Text input forwarded to the processor as ``text``.

        Returns:
            The output of the linear head applied to the image and text
            embeddings concatenated along ``dim=1``.
        """
        processor = self.processor if self.processor is not None else clip_processor
        inputs = processor(text=text, images=image, return_tensors="pt", padding=True, truncation=True)

        # The processor builds its tensors on CPU; move them next to the
        # weights so the model also works after ``model.to("cuda")``.
        device = self.fc.weight.device
        inputs = {
            key: value.to(device) if hasattr(value, "to") else value
            for key, value in inputs.items()
        }

        outputs = self.clip_model(**inputs)
        img_features = outputs.image_embeds
        text_features = outputs.text_embeds

        combined_features = torch.cat((img_features, text_features), dim=1)
        return self.fc(combined_features)

# nn.Linear initialises its weights randomly, so seed the RNG before building
# the model; otherwise the classification head (and every metric computed
# from it) changes on each kernel restart.
# NOTE(review): no trained weights for the head are loaded in the cells shown
# here - confirm whether a checkpoint should be restored before evaluating.
torch.manual_seed(42)

model = MultimodalModel(clip_model)
model.eval()

MultimodalModel(
  (clip_model): CLIPModel(
    (text_model): CLIPTextTransformer(
      (embeddings): CLIPTextEmbeddings(
        (token_embedding): Embedding(49408, 512)
        (position_embedding): Embedding(77, 512)
      )
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=512, out_features=512, bias=True)
              (v_proj): Linear(in_features=512, out_features=512, bias=True)
              (q_proj): Linear(in_features=512, out_features=512, bias=True)
              (out_proj): Linear(in_features=512, out_features=512, bias=True)
            )
            (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (fc2): Linear(in_features=2048, out_features=512, bias=

In [118]:
# Name of the video to evaluate. It selects the dataset sub-folder, the CSV
# file and the frame filename prefix used by the cells below; edit this value
# to evaluate a different video.
video = "WICC"

In [119]:
# Read the per-video CSV, dropping its first line and naming the columns
# explicitly instead of relying on the file's own header.
# NOTE(review): the preview printed by this cell shows a repeated,
# non-sequential index (0, 0, 0, 2, 2). That suggests the file has more
# columns than the two names supplied and pandas used the first one as the
# index - confirm the CSV layout.
csv_path = 'dataset/{0}/{0}.csv'.format(video)
data = pd.read_csv(csv_path, header=None, skiprows=1, names=['label', 'text'])
print(data.head())

   label                                               text
0      0                            What is Climate Change?
0      1  goal 13 of the sustainable development goals c...
0      0  me Tom Tom is a college professor that teaches...
2      2  during one of his lectures one of his students...
2      2  during one of his lectures one of his students...


In [120]:
# Folder in which evaluate_model looks for the "<video>-NNN.jpg" frame images.
image_folder = 'dataset/{0}/{0}_frames'.format(video)

def preprocess_image(image_path):
    """Load the image at ``image_path`` fully into memory as an RGB image.

    ``Image.open`` only reads the file header and keeps the file open until
    the pixel data is needed. Since this function is called once per frame,
    the file is opened in a ``with`` block and a loaded copy is returned so
    no file handle outlives the call.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A PIL image in RGB mode that no longer depends on the open file.
    """
    with Image.open(image_path) as image:
        # convert() forces the pixel data to be read and returns a new image
        # object, which stays valid after the file is closed.
        return image.convert("RGB")

In [121]:
def evaluate_model(data, image_folder, video, model):
    """Score ``model`` on every row of ``data`` that has a matching frame.

    The frame for the n-th row (counting from 1) is expected at
    ``<image_folder>/<video>-<n as 3 digits>.jpg``. Rows whose frame file is
    missing are reported with ``print`` and skipped.

    Args:
        data: DataFrame with a ``'label'`` column (true class) and a
            ``'text'`` column (passed to the model as the text input).
        image_folder: Directory containing the frame images.
        video: Prefix of the frame filenames.
        model: Callable invoked as ``model(image, text)`` that returns logits
            for a single sample; the prediction is ``argmax`` over ``dim=1``.

    Returns:
        Tuple ``(accuracy, f1)`` where ``f1`` is the weighted F1 score.

    Raises:
        ValueError: If no row had an existing frame, so there is nothing to
            score.
    """
    true_labels = []
    predictions = []

    # Number frames by row position, not by the DataFrame index label: the
    # index is not guaranteed to be 0..n-1 (it can contain repeated or
    # non-sequential values), which would map rows to the wrong frame.
    for frame_number, (_, row) in enumerate(data.iterrows(), start=1):
        image_path = os.path.join(image_folder, f"{video}-{frame_number:03d}.jpg")

        if not os.path.exists(image_path):
            print(f"Image {image_path} does not exist.")
            continue

        image_inputs = preprocess_image(image_path)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(image_inputs, row['text'])

        predictions.append(torch.argmax(logits, dim=1).item())
        true_labels.append(row['label'])

    if not true_labels:
        # Fail loudly instead of handing empty arrays to the metric functions.
        raise ValueError(
            f"No frames matching '{video}-NNN.jpg' were found in {image_folder!r}; "
            "nothing to evaluate."
        )

    true_labels = np.array(true_labels)
    predictions = np.array(predictions)

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='weighted')

    return accuracy, f1

In [122]:
# Evaluate the selected video and append its scores to the shared results CSV.
accuracy, f1 = evaluate_model(data, image_folder, video, model)

# Scores are stored with two decimals.
accuracy = round(accuracy, 2)
f1 = round(f1, 2)

results = pd.DataFrame({
    'video': [video],
    'accuracy': [accuracy],
    'f1': [f1]
})

filename = 'results/CLIP_test_results.csv'

# Create the output folder on first use; to_csv does not create missing
# parent directories.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Always append (mode 'a' creates the file when needed) and write the header
# only when the file is new or still empty, so each run adds exactly one row.
write_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
results.to_csv(filename, mode='a', header=write_header, index=False)
