In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, VisualBertModel, VisualBertForPreTraining, logging
from PIL import Image
from tqdm import tqdm
import os
import warnings
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
import numpy as np

In [None]:
# Load both annotation files in one pass (train first, then dev) and preview the training rows
train_df, dev_df = (
    pd.read_json(annotation_file)
    for annotation_file in ("../data/facebook/train.json", "../data/facebook/dev.json")
)
train_df.head()

In [None]:
# Notebook-wide configuration

# Training schedule
BATCH_SIZE = 32
EPOCHS = 5

# Dataset root; image paths from the annotation files are joined onto it
ROOT_PATH = '../data/facebook'
IMAGE_SIZE = 224 * 224

# Model dimensions
NUM_CLASSES = 2
TEXTUAL_DIMENSION = 512
VISUAL_DIMENSION = 512

# File the best checkpoint is written to
CHECKPOINT = './model.pt'

# Use the GPU when torch can see one, otherwise fall back to the CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Initialize the dataset and maintain the dataloader
class DynamicDataset(Dataset):
    """Dataset backed by a JSON annotation file with 'img' and 'text' columns.

    Each sample is ``(image, text, label)`` when the file has a 'label'
    column and ``(image, text)`` otherwise. Image paths are resolved
    relative to ``ROOT_PATH``.

    Args:
        json_path: Path of the annotation file, read with ``pd.read_json``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, json_path, transform = None):
        self.transform = transform
        self.df = pd.read_json(json_path)

    def __len__(self):
        # One sample per annotation row
        return self.df.shape[0]

    def __getitem__(self, index):
        # Resolve the image location and load it as a three-channel image
        record_path = os.path.join(ROOT_PATH, self.df.loc[index, 'img'])
        picture = Image.open(record_path).convert("RGB")
        picture = picture if self.transform is None else self.transform(picture)

        caption = self.df.loc[index, 'text']

        # Labelled files yield a triple, unlabelled files a pair
        if 'label' in self.df.columns:
            return picture, caption, self.df.loc[index, 'label']
        return picture, caption

In [None]:
# Define a transform function for image preprocessing.
# Resizing first gives every image the same 224x224 shape, so the default
# DataLoader collate function can stack a batch into one tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Create objects of each set of data
train_data = DynamicDataset(os.path.join(ROOT_PATH, 'train.json'), transform = transform)
dev_data = DynamicDataset(os.path.join(ROOT_PATH, 'dev.json'), transform = transform)

# Create the dataloaders; only the training split is shuffled, since the dev
# metrics are aggregated over the whole split and do not depend on order
train_loader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True)
dev_loader = DataLoader(dev_data, batch_size = BATCH_SIZE, shuffle = False)

In [None]:
class Visual_Feature(nn.Module):
    """Frozen pretrained ResNet-50 backbone with an optional convolutional head.

    The backbone keeps every child module of torchvision's ResNet-50 except
    the last one, and all of its parameters have ``requires_grad = False``.
    The convolutional head (2048 -> 1024 -> 512 channels) is left trainable.
    """

    def __init__(self):
        super().__init__()

        # Define resnet50 model
        resnet50 = models.resnet50(weights = models.ResNet50_Weights.DEFAULT)
        convolution_layers = nn.Sequential(
            nn.Conv2d(2048, 1024, kernel_size=(3, 3), stride = (1, 1), padding = (1, 1)),
            nn.ReLU(),
            nn.Conv2d(1024, 512, kernel_size=(3, 3), stride = (1, 1), padding = (1, 1)),
            nn.ReLU(),
        )

        # Freeze parameters
        for param in resnet50.parameters():
            param.requires_grad = False

        # Drop the last child module of the pretrained network
        self.resnet50 = nn.Sequential(*list(resnet50.children())[:-1])
        self.convolution_layers = convolution_layers

    def get_visual_features(self, images, get_conv_features):
        """Extract one feature vector per image.

        Args:
            images: Batch of image tensors accepted by the ResNet backbone.
            get_conv_features: When truthy, the backbone output is also passed
                through the convolutional head; otherwise the backbone output
                is returned directly.

        Returns:
            Tensor of shape ``(batch, 1, features)``, where ``batch`` is the
            number of images actually passed in.
        """
        visual_features = self.resnet50(images)
        if get_conv_features:
            visual_features = self.convolution_layers(visual_features)

        # Take the batch size from the tensor itself. A hard-coded BATCH_SIZE
        # silently mixes samples together for a final partial batch or a
        # single image.
        visual_features = visual_features.reshape(visual_features.size(0), 1, -1)

        return visual_features

In [None]:
class Textual_Feature(nn.Module):
    """VisualBERT-based encoder that maps (image, text) pairs to 512-d features.

    The text is tokenized with the 'bert-base-uncased' tokenizer, the images
    are encoded by a ``Visual_Feature`` backbone, and both are fed to
    ``VisualBertForPreTraining``. The first-position slice of the model's
    first output is then reduced to 512 dimensions by a stack of dense layers.
    """

    def __init__(self):
        super().__init__()

        # Define visual bert model
        visual_bert = VisualBertForPreTraining.from_pretrained('uclanlp/visualbert-vqa')
        dense_layers = nn.Sequential(
            nn.Linear(30522, 20000),
            nn.ReLU(),
            nn.Linear(20000, 10000),
            nn.ReLU(),
            nn.Linear(10000, 5000),
            nn.ReLU(),
            nn.Linear(5000, 2000),
            nn.ReLU(),
            nn.Linear(2000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 512),
            nn.ReLU(),
        )

        # The VisualBERT parameters are not frozen here. To freeze them, set
        # requires_grad = False on every parameter of visual_bert.

        self.visual_bert = visual_bert
        self.dense_layers = dense_layers

        # Build the image encoder once and keep it as a submodule, instead of
        # re-instantiating the pretrained ResNet-50 on every call. It follows
        # this module's .to(...), .train() and .eval() calls.
        self.visual_encoder = Visual_Feature()

        # Define tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def get_textual_features(self, images, texts):
        """Encode a batch of (image, text) pairs.

        Args:
            images: Batch of image tensors; moved to ``DEVICE`` before encoding.
            texts: Sequence of strings, one per image.

        Returns:
            Tensor with one 512-dimensional feature vector per sample.
        """
        # Define indices and attention mask
        inputs = self.tokenizer(list(texts), padding = True, truncation = True, return_tensors = 'pt')
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)
        token_ids = inputs['token_type_ids'].to(DEVICE)

        # Extract visual features (backbone output only, no convolutional head)
        visual_features = self.visual_encoder.get_visual_features(images.to(DEVICE), get_conv_features = False)
        visual_token_ids = torch.ones(visual_features.shape[:-1], dtype=torch.long).to(DEVICE)
        visual_attention_mask = torch.ones(visual_features.shape[:-1], dtype=torch.float).to(DEVICE)

        # Run visual bert on the joint text + image input
        textual_features = self.visual_bert(
            input_ids = input_ids,
            attention_mask = attention_mask,
            token_type_ids = token_ids,
            visual_embeds = visual_features,
            visual_token_type_ids = visual_token_ids,
            visual_attention_mask = visual_attention_mask,
        )

        # Keep the first sequence position of the model's first output. This is
        # 30522 wide, matching the first dense layer.
        # NOTE(review): for VisualBertForPreTraining without labels this should
        # be the prediction logits rather than a hidden state - confirm.
        textual_features = textual_features[0][:, 0, :]
        textual_features = self.dense_layers(textual_features)

        return textual_features

In [None]:
# Test visual bert (WORKS BUT SKIPPED TO PRESERVE MEMORY)
# vbert = Textual_Feature().to(DEVICE)
# for images, texts, labels in tqdm(train_loader):
#     images = images.to(DEVICE)
#     textual_feature = vbert.get_textual_features(images, texts)
#     print(textual_feature.shape)
#     break

In [None]:
# Test resnet50: run a single dev image through the visual feature extractor
resnet50 = Visual_Feature()
resnet50.to(DEVICE)
# Force three channels: `transform` normalizes with three means/stds, so an
# RGBA, palette or grayscale PNG would otherwise fail
image = Image.open(os.path.join(ROOT_PATH, 'dev/hateful/01456.png')).convert("RGB")
# Add the batch dimension without hard-coding the image size
image = transform(image).unsqueeze(0)
visual_features = resnet50.get_visual_features(image.to(DEVICE), get_conv_features = True)
print(visual_features.shape)

In [None]:
class Fusion(nn.Module):
    """Binary classifier that fuses visual and textual features.

    The visual features (``VISUAL_DIMENSION`` wide) and textual features
    (``TEXTUAL_DIMENSION`` wide) are concatenated and passed through a stack
    of dense layers that ends in a single logit per sample.
    """

    def __init__(self):
        super().__init__()

        # Create the encoders once and register them as submodules. Building
        # them inside forward() would re-randomise their trainable layers on
        # every batch and hide their parameters from parameters(),
        # state_dict(), .to(), .train() and .eval().
        self.visual_class = Visual_Feature()
        self.textual_class = Textual_Feature()

        # Define fusion layers
        fusion_layers = nn.Sequential(
            nn.Linear((VISUAL_DIMENSION + TEXTUAL_DIMENSION), 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

        self.fusion_layers = fusion_layers

    def forward(self, images, texts):
        """Compute one raw logit per (image, text) pair.

        Args:
            images: Batch of image tensors.
            texts: Sequence of strings, one per image.

        Returns:
            Tensor of shape ``(batch, 1)`` holding unnormalised logits.
        """
        # Use the real batch size so a final partial batch is handled correctly
        batch_size = images.size(0)

        # Extract visual and textual features
        visual_features = self.visual_class.get_visual_features(images, get_conv_features = True).reshape(batch_size, -1)

        textual_features = self.textual_class.get_textual_features(images, texts)

        # Concatenate visual and textual features
        features = torch.cat((visual_features, textual_features), dim = 1)

        # Pass through fusion layers
        output = self.fusion_layers(features)

        return output

In [None]:
# Build the fusion network and place it on the target device
fusion = Fusion().to(DEVICE)

# Binary cross-entropy computed on raw logits
criterion = nn.BCEWithLogitsLoss()

# Adam over every parameter exposed by the fusion model
optimizer = optim.Adam(params = fusion.parameters(), lr = 1e-2)

In [None]:
def train_model(model):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Uses the module-level ``optimizer`` and ``criterion``. The model is not
    switched to training mode here; the caller does that.

    Args:
        model: Callable taking ``(images, texts)`` and returning one logit
            per sample.

    Returns:
        Tuple ``(train_acc, train_loss)``: accuracy in percent and the mean
        per-sample loss over the samples seen. Both are 0.0 when the loader
        yields no samples.
    """
    # Initialize required variables
    train_loss = 0
    total = 0
    correct = 0

    for images, texts, labels in tqdm(train_loader):
        images = images.to(DEVICE)
        labels = torch.reshape(labels, (-1, 1)).to(dtype = torch.float32, device = DEVICE)

        optimizer.zero_grad()
        # Call the model that was passed in, not the global `fusion`
        outputs = model(images, texts)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics need no gradient tracking, so work on a detached copy
        predicted = torch.round(torch.sigmoid(outputs.detach()))
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # criterion returns a batch mean; weight it by the batch size
        train_loss += loss.item() * images.size(0)

    if total == 0:
        return 0.0, 0.0

    train_acc = 100 * correct / total
    train_loss /= total
    return train_acc, train_loss
        

In [None]:
def eval_model(model):
    """Evaluate ``model`` over ``dev_loader`` without tracking gradients.

    Uses the module-level ``criterion``. The model is not switched to
    evaluation mode here; the caller does that.

    Args:
        model: Callable taking ``(images, texts)`` and returning one logit
            per sample.

    Returns:
        Tuple ``(dev_acc, dev_loss)``: accuracy in percent and the mean
        per-sample loss over the samples seen. Both are 0.0 when the loader
        yields no samples.
    """
    # Initialize the required variables
    dev_loss = 0
    total = 0
    correct = 0

    # No backward pass happens here, so skip building autograd graphs
    with torch.no_grad():
        for images, texts, labels in tqdm(dev_loader):
            images = images.to(DEVICE)
            labels = torch.reshape(labels, (-1, 1)).to(dtype = torch.float32, device = DEVICE)

            outputs = model(images, texts)
            predicted = torch.round(torch.sigmoid(outputs)) # threshold issues

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # criterion returns a batch mean; weight it by the batch size
            loss = criterion(outputs, labels)
            dev_loss += loss.item() * images.size(0)

    if total == 0:
        return 0.0, 0.0

    dev_acc = 100 * correct / total
    dev_loss /= total

    return dev_acc, dev_loss

In [None]:
def save_model(prev_acc, curr_acc, epoch, model, optimizer):
    """Write a checkpoint when the current accuracy beats the previous best.

    Args:
        prev_acc: Best accuracy seen so far.
        curr_acc: Accuracy of the current model.
        epoch: Epoch number stored in the checkpoint.
        model: Object whose ``state_dict()`` is saved.
        optimizer: Object whose ``state_dict()`` is saved.

    Returns:
        ``curr_acc`` if a checkpoint was written to ``CHECKPOINT``, otherwise
        ``prev_acc`` unchanged.
    """
    improved = curr_acc > prev_acc
    if not improved:
        # Nothing to save; the best accuracy stays as it was
        return prev_acc

    snapshot = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(snapshot, CHECKPOINT)

    # The current accuracy is the new best
    return curr_acc

In [None]:
prev_dev_acc = 0
dev_acc = 0
epoch = 0  # bound up front so the exception handler can always reference it
try:
    for epoch in range(EPOCHS):
        # Train model
        fusion.train()
        train_acc, train_loss = train_model(fusion)
        print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss = {train_loss:.4f}, Train Accuracy = {train_acc:.4f}")

        # Evaluate model
        fusion.eval()
        dev_acc, dev_loss = eval_model(fusion)
        print(f"Epoch {epoch+1}/{EPOCHS}, Dev Loss = {dev_loss:.4f}, Dev Accuracy = {dev_acc:.4f}")

        # Save best model
        prev_dev_acc = save_model(prev_dev_acc, dev_acc, epoch + 1, fusion, optimizer)

except Exception as e:
    # Log the exception with its type, so an empty message is still diagnosable
    print(f"{type(e).__name__}: {e}")

    # Best-effort save; only writes if the last dev accuracy beats the best so far
    prev_dev_acc = save_model(prev_dev_acc, dev_acc, epoch, fusion, optimizer)

    # Re-raise so the failure and its traceback are visible instead of the
    # cell appearing to finish successfully
    raise