In [3]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pprint
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
from sklearn.metrics import confusion_matrix

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from collections import OrderedDict

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel
from transformers import BertForSequenceClassification, AdamW, BertConfig


import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

from models import  MultimodalFramework
from model_utils import set_seed, build_optimizer, MemesDataset, evaluate_f1


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()

#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)
def train_model(model_name, dataloaders, criterion, config, path):
#def train_model(model, dataloaders, criterion, optimizer, len_train, len_val, num_epochs, path):
    
    set_seed(42)
    """
    if model_name == "bert":
        model = Bert() #Bert()
        #for param in model.bert.parameters():
        #    param.requires_grad = False
    elif model_name == "mlp":
        model = Net()
    else:
        model = ResNet() 
        #for param in model.resnet18.parameters():
        #    param.requires_grad = False
    """
    model = MultimodalFramework()
    
    torch.cuda.empty_cache()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)
    
    num_epochs = 1
    optimizer = build_optimizer(model, "adamW",0.01, 0.9)
    #"""
    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []

    #best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    patience = 5 
    trigger = 0
    
    acc_dict = {}
    for epoch in range(num_epochs):
        #scheduler.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            f1 = 0

            for data in dataloaders[phase]:
                if model_name == "bert":
                    inputs, masks, labels = data
                    inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
                elif model_name == "mlp":
                    inputs, labels = data
                    inputs, labels =inputs.float(), labels.long()
                    inputs, labels = inputs.to(device), labels.to(device)
                else:
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)


                #print(torch.equal(text_labels,img_labels))

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if model_name == "bert":
                        outputs = model([inputs, masks], model_name)

                    else:
                        outputs = model(inputs, model_name)


                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)


                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("text_inp.size(0)")
                #print(text_inp.size(0))

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                f = evaluate_f1(preds, labels.data).to(device)
                print(f)
                f1 += f
                print("f1: " + str(f1))
            epoch_loss = running_loss / len(labels)
            epoch_acc = running_corrects.double() / len(labels)
            epoch_f1 = f1.double() / len(labels)
            print("epoch f1: " + str(epoch_f1))

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                #wandb.log({"val_loss": epoch_loss, "val_f1": epoch_f1})
                acc_dict[epoch] = float(epoch_acc.detach().cpu())
                val_acc_history.append(epoch_acc.detach().cpu())
                val_loss_history.append(epoch_loss)
                torch.save(model.state_dict(), path+"_current.pth")
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    #best_model_wts = copy.deepcopy(model.state_dict())
                    #torch.save(model.state_dict(), path+"_best.pth")
                #"""
                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):
                    trigger +=1
                    if trigger >= patience:
                        return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}
                else:
                    trigger = 0
                #"""    
            if phase == 'train':
                #wandb.log({"train_loss": epoch_loss, "train_acc": epoch_f1, "epoch": epoch})
                train_acc_history.append(epoch_acc.detach().cpu())
                train_loss_history.append(epoch_loss)


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    #model.load_state_dict(best_model_wts)
    return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}

         

    

PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


In [11]:
dataloaders = dataloaders_dict
set_seed(42)
"""
if model_name == "bert":
    model = Bert() #Bert()
    #for param in model.bert.parameters():
    #    param.requires_grad = False
elif model_name == "mlp":
    model = Net()
else:
    model = ResNet() 
    #for param in model.resnet18.parameters():
    #    param.requires_grad = False
"""
model = MultimodalFramework()

torch.cuda.empty_cache()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

num_epochs = 1
optimizer = build_optimizer(model, "adamW",0.01, 0.9)
#"""
since = time.time()

val_acc_history = []
val_loss_history = []
train_acc_history = []
train_loss_history = []

#best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
patience = 5 
trigger = 0

acc_dict = {}
for epoch in range(num_epochs):
    predicted_labels, ground_truth_labels = [], []
    #scheduler.step()
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0
        f1 = 0

        for data in dataloaders[phase]:
            if model_name == "bert":
                inputs, masks, labels = data
                inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
            elif model_name == "mlp":
                inputs, labels = data
                inputs, labels =inputs.float(), labels.long()
                inputs, labels = inputs.to(device), labels.to(device)
            else:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)


            #print(torch.equal(text_labels,img_labels))

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                # Get model outputs and calculate loss
                # Special case for inception because in training it has an auxiliary output. In train
                #   mode we calculate the loss by summing the final output and the auxiliary output
                #   but in testing we only consider the final output.
                if model_name == "bert":
                    outputs = model([inputs, masks], model_name)

                else:
                    outputs = model(inputs, model_name)


                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)


                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            #print("text_inp.size(0)")
            #print(text_inp.size(0))

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            predicted_labels.extend(preds.cpu().detach().numpy())
            ground_truth_labels.extend(labels.cpu().detach().numpy())
            
            #abc
            
            f = evaluate_f1(preds, labels.data).to(device)
            
            from sklearn.metrics import f1_score   
            f1_score = f1_score(labels.cpu().data, preds.cpu())
            print(f)
            print(f1_score)
            f1 += f
            print("f1: " + str(f1))
        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
        epoch_f1 = f1.double() / len(dataloaders[phase].dataset)
        print("epoch f1: " + str(epoch_f1))

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        if phase == 'val':
            #wandb.log({"val_loss": epoch_loss, "val_f1": epoch_f1})
            acc_dict[epoch] = float(epoch_acc.detach().cpu())
            val_acc_history.append(epoch_acc.detach().cpu())
            val_loss_history.append(epoch_loss)
            torch.save(model.state_dict(), path+"_current.pth")
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                #best_model_wts = copy.deepcopy(model.state_dict())
                #torch.save(model.state_dict(), path+"_best.pth")
            #"""
            if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):
                trigger +=1
                if trigger >= patience:
                    print("here")#return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}
            else:
                trigger = 0
            #"""    
        if phase == 'train':
            #wandb.log({"train_loss": epoch_loss, "train_acc": epoch_f1, "epoch": epoch})
            train_acc_history.append(epoch_acc.detach().cpu())
            train_loss_history.append(epoch_loss)


time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

# load best model weights
    #model.load_state_dict(best_model_wts)

Random seed set as 42


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 0/0
----------
tensor(0.5000)
0.5
f1: tensor(0.5000)
tensor(0.7500)
0.6666666666666666
f1: tensor(1.2500)
tensor(0.)
0.0
f1: tensor(1.2500)


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


tensor(1.)
0.0
f1: tensor(2.2500)
tensor(0.2500)
0.0
f1: tensor(2.5000)
tensor(0.)
0.0
f1: tensor(2.5000)
tensor(0.5000)
0.6666666666666666
f1: tensor(3.)
tensor(0.2500)
0.4
f1: tensor(3.2500)
tensor(0.7500)
0.0
f1: tensor(4.)
tensor(0.7500)
0.0
f1: tensor(4.7500)


KeyboardInterrupt: 

In [14]:
from sklearn.metrics import f1_score
f1_score(ground_truth_labels, predicted_labels)

0.3225806451612903

In [3]:
model_name = "bert"

if model_name == "bert":
    train_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')
    val_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')

else:
    train_inputs = torch.load("/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt")
    val_inputs = torch.load("/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt")


criterion = nn.CrossEntropyLoss()



train_dataloader = DataLoader(train_inputs, batch_size=4,shuffle=False)

val_dataloader= DataLoader(val_inputs, batch_size=4, shuffle=False)

dataloaders_dict = {'train':train_dataloader, 'val':val_dataloader}

path = '//users/mgolovan/data/mgolovan/facebook_memes/unimodal_models/model_' 
#train_model(model_name, dataloaders_dict, criterion, "config", path)  



In [13]:
len(dataloaders_dict["train"])

3

In [8]:
train_dataloader

[tensor([[  101,  2043, 19817,  1008,  1050, 15580,  2215,  2000,  2022,  3970,
           2021,  2027, 14163, 26065,  2618,  2037,  4230,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
         [  101,  3342, 12455,  3531,  2393,  2491,  1996,  6355,  1010,  2317,
           6494,  4095,  1010, 10041,  2313,  1025,  2031,  2115,  8398,  6793,
          12403, 20821,  2030, 11265, 1

In [2]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset
 
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel


from models import MultimodalFramework
from model_utils import set_seed, build_optimizer, ReviewsDataset

 #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'
model_name = "bert_resnet_l"
lr = 0.00005
epochs = 34
batch_size = int(32)
#best_model_1e-06_22_adamW_20_resnet.pth
random_seeds = [15, 0, 1, 67, 128, 87, 261, 510, 340, 22] # 
df = pd.DataFrame(columns = ['AUROC','accuracy', "precision", "recall", "f1-score", "CM", "CR"])

for seed in random_seeds:
    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_'+str(lr)+'_' + str(seed)+'_adamW_' + str(epochs)+'_' + str(model_name)+ '.pth_current.pth'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    torch.cuda.empty_cache()

    model = MultimodalFramework()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) #eager-sweep-1
    model.to(device)

    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')
    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')

    if model_name.split("_")[:2] == ["bert", "resnet"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size)
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size) 

    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size) 
        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size)

    else:
        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size) 
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size)


    correct = 0
    total = 0
    pred = []
    test_labels = []

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for modality1, modality2 in zip(modality_1, modality_2):

            if model_name.split("_")[:2] == ["bert", "resnet"]:
                text_inp, masks, text_labels = modality2
                img_inp, labels = modality1

                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([img_inp, text_inp, masks], model_name)

            elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                img_inp, labels = modality1
                tab_inp, tab_labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([tab_inp, img_inp], model_name)
            else:
                tab_inp, tab_labels = modality1
                text_inp, masks, labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)

                outputs = model([tab_inp, text_inp, masks], model_name)

            test_labels.extend(np.array(labels.cpu()))
            _, predicted = torch.max(outputs, 1)
            pred.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc= 100 * correct // total
    print(f'Accuracy of the bert: {100 * correct // total} %')

    test_labels = np.array(test_labels)

    #print(confusion_matrix(test_labels, pred))
    cm = confusion_matrix(test_labels, pred)
    #print(classification_report(test_labels, pred))
    cr = classification_report(test_labels, pred, output_dict=True)
    auc = roc_auc_score(test_labels, pred)
    df = df.append({'AUROC': auc,'accuracy': acc, "precision":cr["macro avg"]["precision"]*100 ,
                    "recall":cr["macro avg"]["recall"]*100, "f1-score":cr["macro avg"]["f1-score"]*100,
                    "CM":cm, "CR":cr}, ignore_index=True)

df.to_csv(model_name + "_current_results.csv")
print(df.mean())
print(df.std())


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 67 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 70 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 70 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 68 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 69 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 69 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 68 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 68 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 69 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 70 %
AUROC         0.643529
accuracy     68.800000
precision    66.177446
recall       64.352910
f1-score     64.797999
dtype: float64
AUROC        0.009263
accuracy     1.032796
precision    1.116388
recall       0.926337
f1-score     0.975161
dtype: float64




In [8]:
pd.read_csv("bert_resnet_luong_current_results.csv").std()

  """Entry point for launching an IPython kernel.


Unnamed: 0     3.027650
AUROC          0.048424
accuracy       2.796824
precision     12.079991
recall         4.842412
f1-score       9.065280
dtype: float64

In [1]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel


class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    


class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 2, batch_first = True)
        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_multihead_attention = MultiHeadAttention(256,2, typ = "OvO")
        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
            
        elif model == "bert_resnet_l":
            img, text, masks = x
            res_emb = self.resnet18(img)
            res = self.res_wrap(res_emb)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            bert = self.bert_wrap(bert_emb)
            
            combined = torch.cat((res,
                                  bert), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_attention(bert, res)
            attn_output_VL = self.luong_attention(res, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            print(res.shape)
            res = self.res_wrap(res)
            
            attn_txt = self.OvO_concat_attention([feat, res], bert)
            attn_img = self.OvO_concat_attention([feat, bert],res)
            attn_tab = self.OvO_concat_attention([bert, res], feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset
 
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel


#from models import MultimodalFramework
from model_utils import set_seed, build_optimizer, ReviewsDataset

 #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'
model_name = "resnet"
lr = 0.001
epochs = 25
batch_size = int(32)
#best_model_1e-06_22_adamW_20_resnet.pth
random_seeds = [15, 0, 1, 67, 128, 87, 261, 510, 340, 22] # 
df = pd.DataFrame(columns = ['AUROC','accuracy', "precision", "recall", "f1-score", "CM", "CR"])

for seed in random_seeds:
    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/unimodal_models/model_'+str(lr)+'_' + str(seed)+'_adamW_' + str(epochs)+'_' + str(model_name)+ '_current.pth'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    torch.cuda.empty_cache()

    model = MultimodalFramework()

    if model_name == "bert":
        test_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')
        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)

    elif model_name == "resnet":
        test_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')
        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)
    else:
        test_inputs = torch.load('/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/Electronics/test_rvw_binary_tab.pt')
        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)

    model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu'))) #eager-sweep-1
    model.to(device)

    correct = 0
    total = 0
    pred = []
    test_labels = []

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in test_dataloader:
            if model_name == "bert":
                inputs, masks, labels = data
                inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
                outputs = model([inputs, masks], model_name)

            elif model_name == "resnet":
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs, model_name)
            else:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                inputs, labels =inputs.float(), labels.long()
                outputs = model(inputs, model_name)

            test_labels.extend(np.array(labels.cpu()))
            _, predicted = torch.max(outputs, 1)
            pred.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc= 100 * correct // total
    print(f'Accuracy of the bert: {100 * correct // total} %')

    test_labels = np.array(test_labels)

    #print(confusion_matrix(test_labels, pred))
    cm = confusion_matrix(test_labels, pred)
    #print(classification_report(test_labels, pred))
    cr = classification_report(test_labels, pred, output_dict=True)
    auc = roc_auc_score(test_labels, pred)
    df = df.append({'AUROC': auc,'accuracy': acc, "precision":cr["macro avg"]["precision"]*100 ,
                    "recall":cr["macro avg"]["recall"]*100, "f1-score":cr["macro avg"]["f1-score"]*100,
                    "CM":cm, "CR":cr}, ignore_index=True)

df.to_csv(model_name + "_current_results.csv")


df.to_csv(model_name + "_current_results.csv")
print(df.mean())
print(df.std())
    



PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 62 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 58 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 61 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 62 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 59 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 59 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 60 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 58 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 63 %
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 61 %
AUROC         0.562000
accuracy     60.300000
precision    56.513300
recall       56.200030
f1-score     56.265014
dtype: float64
AUROC        0.016231
accuracy     1.766981
precision    1.751285
recall       1.623102
f1-score     1.674605
dtype: float64




In [55]:
i = torch.rand((64,256))

In [57]:
F.softmax(i, dim=0)

tensor([[0.0136, 0.0166, 0.0119,  ..., 0.0106, 0.0193, 0.0188],
        [0.0217, 0.0102, 0.0218,  ..., 0.0102, 0.0110, 0.0099],
        [0.0234, 0.0094, 0.0222,  ..., 0.0246, 0.0182, 0.0183],
        ...,
        [0.0103, 0.0122, 0.0147,  ..., 0.0236, 0.0208, 0.0157],
        [0.0248, 0.0157, 0.0184,  ..., 0.0111, 0.0204, 0.0153],
        [0.0109, 0.0236, 0.0130,  ..., 0.0199, 0.0097, 0.0097]])

In [58]:
F.softmax(i, dim=1)

tensor([[0.0035, 0.0043, 0.0030,  ..., 0.0027, 0.0050, 0.0048],
        [0.0055, 0.0026, 0.0055,  ..., 0.0026, 0.0028, 0.0025],
        [0.0058, 0.0023, 0.0055,  ..., 0.0061, 0.0045, 0.0046],
        ...,
        [0.0026, 0.0031, 0.0037,  ..., 0.0061, 0.0053, 0.0040],
        [0.0059, 0.0038, 0.0044,  ..., 0.0027, 0.0049, 0.0037],
        [0.0027, 0.0059, 0.0032,  ..., 0.0050, 0.0024, 0.0024]])

In [59]:
F.softmax(i, dim=-1)

tensor([[0.0035, 0.0043, 0.0030,  ..., 0.0027, 0.0050, 0.0048],
        [0.0055, 0.0026, 0.0055,  ..., 0.0026, 0.0028, 0.0025],
        [0.0058, 0.0023, 0.0055,  ..., 0.0061, 0.0045, 0.0046],
        ...,
        [0.0026, 0.0031, 0.0037,  ..., 0.0061, 0.0053, 0.0040],
        [0.0059, 0.0038, 0.0044,  ..., 0.0027, 0.0049, 0.0037],
        [0.0027, 0.0059, 0.0032,  ..., 0.0050, 0.0024, 0.0024]])

In [1]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel


class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    


class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 2, batch_first = True)
        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
            
        elif model == "bert_resnet_l":
            img, text, masks = x
            res_emb = self.resnet18(img)
            res = self.res_wrap(res_emb)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            bert = self.bert_wrap(bert_emb)
            
            combined = torch.cat((res,
                                  bert), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_attention(bert, res)
            attn_output_VL = self.luong_attention(res, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            print(res.shape)
            res = self.res_wrap(res)
            
            attn_txt = self.OvO_concat_attention([feat, res], bert)
            attn_img = self.OvO_concat_attention([feat, bert],res)
            attn_tab = self.OvO_concat_attention([bert, res], feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pprint
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
from sklearn.metrics import confusion_matrix,f1_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from collections import OrderedDict

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel
from transformers import BertForSequenceClassification, AdamW, BertConfig


import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

from model_utils import set_seed, build_optimizer


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()

#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)
def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):
    
    set_seed(42)
    """
    if model_name == "bert_resnet":
        model = BertResNet()
    elif model_name == "bert_resnet_luong":
        model = BertResNetLuong() 
    else:
        model = BertResNetVaswani()  
    
    #for param in model.resnet18.parameters():
    #        param.requires_grad = False
    
    #for param in model.bert.parameters():
    #        param.requires_grad = False
    """
    
    model = MultimodalFramework()
    
    torch.cuda.empty_cache()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)
    
    num_epochs = 11
    optimizer = build_optimizer(model, "adamW", 0.1, 0.9)

    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []

    best_acc = 0.0
    patience = 5 
    trigger = 0
    acc_dict = {}

    for epoch in range(num_epochs):
        #scheduler.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                length = len_train
                model.train()  # Set model to training mode
            else:
                length = len_val
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            predicted_labels, ground_truth_labels = [], []

            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    
                    if model_name.split("_")[:2] == ["bert", "resnet"]:
                        text_inp, masks, text_labels = modality2
                        img_inp, labels = modality1

                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        
                        inp_len = text_inp.size(0)
                        outputs = model([img_inp, text_inp, masks], model_name)

                    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                        img_inp, labels = modality1
                        tab_inp, tab_labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, img_inp], model_name)
                    else:
                        tab_inp, tab_labels = modality1
                        text_inp, masks, labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, text_inp, masks], model_name)
                    
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("text_inp.size(0)")
                #print(text_inp.size(0))

                running_loss += loss.item() * labels.size(0)
                running_corrects += torch.sum(preds == labels.data)
                predicted_labels.extend(preds.cpu().detach().numpy())
                ground_truth_labels.extend(labels.cpu().detach().numpy())
                
            epoch_loss = running_loss / length
            epoch_acc = running_corrects.double() / length
            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)
            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                #wandb.log({"val_loss": epoch_loss, "val_acc": epoch_acc, "val_f1": epoch_f1})
                acc_dict[epoch] = float(epoch_acc.detach().cpu())
                val_acc_history.append(epoch_acc.detach().cpu())
                val_loss_history.append(epoch_loss)
                torch.save(model.state_dict(), path+"_current.pth")
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    #best_model_wts = copy.deepcopy(model.state_dict())
                    #torch.save(model.state_dict(), path+"_best.pth")
                #"""
                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):
                    trigger +=1
                    if trigger >= patience:
                        return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}
                else:
                    trigger = 0
                #"""    
            if phase == 'train':
                wandb.log({"train_loss": epoch_loss, "train_acc": epoch_acc,"train_f1": epoch_f1, "epoch": epoch})
                train_acc_history.append(epoch_acc.detach().cpu())
                train_loss_history.append(epoch_loss)


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    #model.load_state_dict(best_model_wts)
    return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}

 


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


In [3]:

model_name = "bert_resnet_ours_concat"

train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')
val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')

train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')
val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')


criterion = nn.CrossEntropyLoss()


train_dataloader_text = DataLoader(train_inputs_txt, batch_size=4,shuffle=False)
val_dataloader_text = DataLoader(val_inputs_txt, batch_size=4, shuffle=False)

train_dataloader_img = DataLoader(train_inputs_img, batch_size=4,shuffle=False)
val_dataloader_img = DataLoader(val_inputs_img, batch_size=4, shuffle=False)

len_val = len(val_inputs_txt)
len_train = len(train_inputs_txt)

dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}


path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' 

train_model(model_name, dataloaders_dict, criterion, len_train, len_val, "config", path)  


    

Random seed set as 42


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 0/10
----------


ValueError: not enough values to unpack (expected 4, got 3)

In [5]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pprint
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
from sklearn.metrics import confusion_matrix,f1_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from collections import OrderedDict

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel
from transformers import BertForSequenceClassification, AdamW, BertConfig


import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

#from models import MultimodalFramework
from model_utils import set_seed, build_optimizer


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()

#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)
def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):
    
    set_seed(42)

    model = MultimodalFramework()
    torch.cuda.empty_cache()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)
    
    num_epochs = 1
    optimizer = build_optimizer(model, "adamW", 0.1, 0.9)

    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []

    best_acc = 0.0
    patience = 10 
    trigger = 0
    acc_dict = {}

    for epoch in range(num_epochs):
        #scheduler.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                length = len_train
                model.train()  # Set model to training mode
            else:
                length = len_val
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            predicted_labels, ground_truth_labels = [], []

            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    
                    if model_name.split("_")[:2] == ["bert", "resnet"]:
                        text_inp, masks, text_labels = modality2
                        img_inp, labels = modality1

                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        
                        inp_len = text_inp.size(0)
                        outputs = model([img_inp, text_inp, masks], model_name)

                    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                        img_inp, labels = modality1
                        tab_inp, tab_labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, img_inp], model_name)
                    else:
                        tab_inp, tab_labels = modality1
                        text_inp, masks, labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, text_inp, masks], model_name)
                    
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("text_inp.size(0)")
                #print(text_inp.size(0))

                running_loss += loss.item() * labels.size(0)
                running_corrects += torch.sum(preds == labels.data)
                predicted_labels.extend(preds.cpu().detach().numpy())
                ground_truth_labels.extend(labels.cpu().detach().numpy())
                
            epoch_loss = running_loss / length
            epoch_acc = running_corrects.double() / length
            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)
            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                wandb.log({"val_loss": epoch_loss, "val_acc": epoch_acc, "val_f1": epoch_f1})
                acc_dict[epoch] = float(epoch_acc.detach().cpu())
                val_acc_history.append(epoch_acc.detach().cpu())
                val_loss_history.append(epoch_loss)
                torch.save(model.state_dict(), path+"_current.pth")
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    #best_model_wts = copy.deepcopy(model.state_dict())
                    #torch.save(model.state_dict(), path+"_best.pth")
                #"""
                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):
                    trigger +=1
                    if trigger >= patience:
                        return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}
                else:
                    trigger = 0
                #"""    
            if phase == 'train':
                wandb.log({"train_loss": epoch_loss, "train_acc": epoch_acc,"train_f1": epoch_f1, "epoch": epoch})
                train_acc_history.append(epoch_acc.detach().cpu())
                train_loss_history.append(epoch_loss)


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    #model.load_state_dict(best_model_wts)
    return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}

    


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


In [8]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        print(query.shape)
        print(self.W.shape)
        print(values.shape)
        print(values.transpose(2,3).shape)
        weights = query.transpose(2, 3) @ self.W @ values.transpose(2, 3) #.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    
class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights/np.sqrt(self.decoder_dim)

class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num, typ,
                 bias=True,
                 activation=F.relu):
        
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.type = typ
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None): 
        #q = self.linear_q(q)
        #k = self.linear_k(k)
        #v = self.linear_v(v)
        
        dim = int(self.in_features / self.head_num)
        #y = ScaledDotProductAttention()(q, k, v, mask)
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if self.type == "OvO":
            att = OneVSOthers(dim, dim)
            y = att([q, k], v) #.cuda()
        else:
            att = MultiplicativeAttention(dim, dim)
            y = att(q, v) #.cuda()
        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        #if self.activation is not None:
        #    y = self.activation(y)
        return y

    """
    def _reshape_to_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, self.head_num, sub_dim)\
                .permute(0, 1, 2)\
                .reshape(batch_size * self.head_num, sub_dim)

    def _reshape_from_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num,  in_feature)\
                .permute(0, 1, 2)\
                .reshape(batch_size,  out_dim)
    """

    def _reshape_to_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)
    
    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation)
#"""

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 8, batch_first = True)
        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out


In [9]:
model_name = "bert_resnet_luong"
train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')
val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')

train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')
val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')


criterion = nn.CrossEntropyLoss()



train_dataloader_text = DataLoader(train_inputs_txt, batch_size=4,shuffle=False)
val_dataloader_text = DataLoader(val_inputs_txt, batch_size=4, shuffle=False)

train_dataloader_img = DataLoader(train_inputs_img, batch_size=4,shuffle=False)
val_dataloader_img = DataLoader(val_inputs_img, batch_size=4, shuffle=False)

len_val = len(val_inputs_txt)
len_train = len(train_inputs_txt)

dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}


path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' 
train_model(model_name, dataloaders_dict, criterion, len_train, len_val, "config", path)  



Random seed set as 42


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 0/0
----------
torch.Size([8, 1, 128])
torch.Size([128, 128])


IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

In [123]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights/np.sqrt(self.decoder_dim)
#"""
class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num, typ,
                 bias=True,
                 activation=F.relu):
        
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.type = typ
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None): 
        #q = self.linear_q(q)
        #k = self.linear_k(k)
        #v = self.linear_v(v)
        
        dim = int(self.in_features / self.head_num)
        #y = ScaledDotProductAttention()(q, k, v, mask)
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if self.type == "OvO":
            att = OneVSOthers(dim, dim)
            y = att([q, k], v) #.cuda()
        else:
            att = MultiplicativeAttention(dim, dim)
            y = att(q, v) #.cuda()
        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        #if self.activation is not None:
        #    y = self.activation(y)
        return y

    
    def _reshape_to_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, self.head_num, sub_dim)\
                .permute(0, 1, 2)\
                .reshape(batch_size * self.head_num, sub_dim)

    def _reshape_from_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num,  in_feature)\
                .permute(0, 1, 2)\
                .reshape(batch_size,  out_dim)
    """

    def _reshape_to_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)
    
    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation)
    """

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out


In [118]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights/np.sqrt(self.decoder_dim)
#"""
class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num, typ,
                 bias=True,
                 activation=F.relu):
        
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.type = typ
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None): 
        #q = self.linear_q(q)
        #k = self.linear_k(k)
        #v = self.linear_v(v)
        
        dim = int(self.in_features / self.head_num)
        #y = ScaledDotProductAttention()(q, k, v, mask)
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if self.type == "OvO":
            att = OneVSOthers(dim, dim)
            y = att([q, k], v) #.cuda()
        else:
            att = MultiplicativeAttention(dim, dim)
            y = att(q, v) #.cuda()
        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        #if self.activation is not None:
        #    y = self.activation(y)
        return y

    
    def _reshape_to_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, self.head_num, sub_dim)\
                .permute(0, 1, 2)\
                .reshape(batch_size * self.head_num, sub_dim)

    def _reshape_from_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num,  in_feature)\
                .permute(0, 1, 2)\
                .reshape(batch_size,  out_dim)
    """

    def _reshape_to_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)
    
    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation)
    """

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        #self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_attention(bert, res)
            attn_output_VL = self.luong_attention(res, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out


In [121]:
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset
 
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel


#from models import MultimodalFramework
from model_utils import set_seed, build_optimizer, ReviewsDataset

 #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'
model_name = "bert_resnet_luong"
lr = 5e-05 
epochs = 6
batch_size = 64
#best_model_1e-06_22_adamW_20_resnet.pth
random_seeds = [15] #15, 0, 1,67,  128, 87, 261, 510, 340, 22
df = pd.DataFrame(columns = ['AUROC','accuracy', "precision", "recall", "f1-score", "CM", "CR"]) #str(batch_size)+ '_'

for seed in random_seeds:
    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' + str(lr)+'_' + str(seed)+'_adamW_'  +  str(epochs)+'_' + str(model_name)+ '.pth_current.pth'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #str(batch_size)+ '_' +
    print(device)

    torch.cuda.empty_cache()

    model = MultimodalFramework()
    model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu'))) #eager-sweep-1
    model.to(device)
    model.eval()

    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')
    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')

    if model_name.split("_")[:2] == ["bert", "resnet"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) 

    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) 
        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )

    else:
        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) 
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )


    correct = 0
    total = 0
    pred = []
    test_labels = []

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for modality1, modality2 in zip(modality_1, modality_2):

            if model_name.split("_")[:2] == ["bert", "resnet"]:
                text_inp, masks, text_labels = modality2
                img_inp, labels = modality1

                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([img_inp, text_inp, masks], model_name)

            elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                img_inp, labels = modality1
                tab_inp, tab_labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([tab_inp, img_inp], model_name)
            else:
                tab_inp, tab_labels = modality1
                text_inp, masks, labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)

                outputs = model([tab_inp, text_inp, masks], model_name)

            test_labels.extend(np.array(labels.cpu()))
            _, predicted = torch.max(outputs, 1)
            pred.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    acc= 100 * correct / total
    print(f'Accuracy of the bert: {100 * correct // total} %')

    test_labels = np.array(test_labels)

    #print(confusion_matrix(test_labels, pred))
    cm = confusion_matrix(test_labels, pred)
    #print(classification_report(test_labels, pred))
    cr = classification_report(test_labels, pred, output_dict=True)
    auc = roc_auc_score(test_labels, pred)
    df = df.append({'AUROC': auc,'accuracy': acc, "precision":cr["macro avg"]["precision"]*100 ,
                    "recall":cr["macro avg"]["recall"]*100, "f1-score":cr["macro avg"]["f1-score"]*100,
                    "CM":cm, "CR":cr}, ignore_index=True)

#df.to_csv(model_name + "_2_no_act_one_results.csv")
print(df.mean())
print(df.std())


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Accuracy of the bert: 56 %
AUROC         0.568000
accuracy     56.800000
precision    57.343286
recall       56.800000
f1-score     55.985915
dtype: float64
AUROC       NaN
accuracy    NaN
precision   NaN
recall      NaN
f1-score    NaN
dtype: float64




In [None]:
Accuracy of the bert: 56 %
AUROC         0.568000
accuracy     56.800000
precision    57.343286
recall       56.800000
f1-score     55.985915
dtype: float64
AUROC       NaN
accuracy    NaN
precision   NaN
recall      NaN
f1-score    NaN
dtype: float64

In [None]:
Accuracy of the bert: 60 %
AUROC         0.606000
accuracy     60.600000
precision    61.775894
recall       60.600000
f1-score     59.591236
dtype: float64
AUROC       NaN
accuracy    NaN
precision   NaN
recall      NaN
f1-score    NaN
dtype: float64

In [None]:
Accuracy of the bert: 66 %
AUROC         0.653181
accuracy     66.000000
precision    64.550000
recall       65.318093
f1-score     64.732094
dtype: float64
AUROC       NaN
accuracy    NaN
precision   NaN
recall      NaN
f1-score    NaN
dtype: float64

In [None]:
Accuracy of the bert: 66 %
AUROC         0.659618
accuracy     66.000000
precision    64.841198
recall       65.961799
f1-score     64.824099
dtype: float64
AUROC       NaN
accuracy    NaN
precision   NaN
recall      NaN
f1-score    NaN
dtype: float64

In [15]:
model.parameters()

<generator object Module.parameters at 0x7fcc38b1a8d0>

In [17]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(
            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights/np.sqrt(self.decoder_dim)
#"""
class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num, typ,
                 bias=True,
                 activation=F.relu):
        
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.type = typ
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None): 
        #q = self.linear_q(q)
        #k = self.linear_k(k)
        #v = self.linear_v(v)
        
        dim = int(self.in_features / self.head_num)
        #y = ScaledDotProductAttention()(q, k, v, mask)
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if self.type == "OvO":
            att = OneVSOthers(dim, dim)
            y = att([q, k], v) #.cuda()
        else:
            att = MultiplicativeAttention(dim, dim)
            y = att(q, v) #.cuda()
        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        #if self.activation is not None:
        #    y = self.activation(y)
        return y

    
    def _reshape_to_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, self.head_num, sub_dim)\
                .permute(0, 1, 2)\
                .reshape(batch_size * self.head_num, sub_dim)

    def _reshape_from_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num,  in_feature)\
                .permute(0, 1, 2)\
                .reshape(batch_size,  out_dim)
    """

    def _reshape_to_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)
    
    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation)
    """

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out

In [1]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
    
class OneVSOthers_concat(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))

    def _get_weights(self,others, main):
        m2 = others[0]
        m3 = others[1]
        con = torch.cat((m2, m3), dim=1)
        weights = con @ self.W @ main.T  # [seq_length]
        return weights/np.sqrt(self.decoder_dim)
#"""

class ScaledDotProductAttention(nn.Module):
    """
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))
    """
    def forward(self, W, query, key, value):

        scores = query @ W @ value.T
        attention = torch.nn.functional.softmax(scores, dim=0)
        #attention = F.softmax(scores, dim=-1)
        return attention.matmul(value)
    
class MultiHeadAttention(nn.Module):

    def __init__(self,
                 in_features,
                 head_num, typ,
                 bias=True,
                 activation=F.relu):
        
        super(MultiHeadAttention, self).__init__()
        if in_features % head_num != 0:
            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))
        self.in_features = in_features
        self.type = typ
        self.head_num = head_num
        self.activation = activation
        self.bias = bias
        self.dim = int(self.in_features / self.head_num)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.dim, self.dim).uniform_(-0.1, 0.1))
        self.linear_q = nn.Linear(in_features, in_features, bias)
        self.linear_k = nn.Linear(in_features, in_features, bias)
        self.linear_v = nn.Linear(in_features, in_features, bias)
        self.linear_o = nn.Linear(in_features, in_features, bias)

    def forward(self, q, k, v, mask=None): 
        #q = self.linear_q(q)
        #k = self.linear_k(k)
        #v = self.linear_v(v)
        
        dim = int(self.in_features / self.head_num)
        #y = ScaledDotProductAttention()(q, k, v, mask)
        #k, q, v = k.to("cpu"), q.to("cpu"), v.to("cpu")
        q = self._reshape_to_batches(q)
        k = self._reshape_to_batches(k)
        v = self._reshape_to_batches(v)
        if self.type == "OvO":
            att = OneVSOthers(dim, dim)
            y = att([q, k], v) #.cuda()
        else:
            #att = MultiplicativeAttention(dim, dim)
            att = ScaledDotProductAttention()
            y = att(self.W, q,k, v) #.cuda()
        #y = y.to("cuda")
        y = self._reshape_from_batches(y)
        y = self.linear_o(y)
        #if self.activation is not None:
        #    y = self.activation(y)
        return y

    
    def _reshape_to_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, self.head_num, sub_dim)\
                .permute(0, 1, 2)\
                .reshape(batch_size * self.head_num, sub_dim)

    def _reshape_from_batches(self, x):
        seq_len = 1
        batch_size, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num,  in_feature)\
                .permute(0, 1, 2)\
                .reshape(batch_size,  out_dim)
    """

    def _reshape_to_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        sub_dim = in_feature // self.head_num
        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size * self.head_num, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        x = x.unsqueeze(1)
        batch_size, seq_len, in_feature = x.size()
        batch_size //= self.head_num
        out_dim = in_feature * self.head_num
        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\
                .permute(0, 2, 1, 3)\
                .reshape(batch_size, seq_len, out_dim)
    
    def extra_repr(self):
        return 'in_features={}, head_num={}, bias={}, activation={}'.format(
            self.in_features, self.head_num, self.bias, self.activation)
    """

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out

  from .autonotebook import tqdm as notebook_tqdm


In [40]:
#from models import MultimodalFramework
model = MultimodalFramework()
model.load_state_dict(torch.load("/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_lin1e-06_340_adamW_128_34_bert_resnet_luong.pth_best.pth",map_location=torch.device('cpu')))


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [41]:
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset
 
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #str(batch_size)+ '_' +
from model_utils import set_seed

model.eval()
set_seed(42)
batch_size =128
model_name = "bert_resnet_luong"
test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')
test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')

if model_name.split("_")[:2] == ["bert", "resnet"]:
    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)
    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) 

elif model_name.split("_")[:2] == ["resnet", "mlp"]:
    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) 
    modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )

else:
    modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) 
    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )


correct = 0
total = 0
running_loss = 0
pred = []
test_labels = []
criterion = nn.CrossEntropyLoss()

# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for modality1, modality2 in zip(modality_1, modality_2):

        if model_name.split("_")[:2] == ["bert", "resnet"]:
            text_inp, masks, text_labels = modality2
            img_inp, labels = modality1

            text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
            img_inp, labels = img_inp.to(device), labels.to(device)

            outputs = model([img_inp, text_inp, masks], model_name)

        elif model_name.split("_")[:2] == ["resnet", "mlp"]:
            img_inp, labels = modality1
            tab_inp, tab_labels = modality2
            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
            img_inp, labels = img_inp.to(device), labels.to(device)

            outputs = model([tab_inp, img_inp], model_name)
        else:
            tab_inp, tab_labels = modality1
            text_inp, masks, labels = modality2
            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
            text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)

            outputs = model([tab_inp, text_inp, masks], model_name)
        
        loss = criterion(outputs, labels)
        test_labels.extend(np.array(labels.cpu()))
        _, predicted = torch.max(outputs, 1)
        pred.extend(predicted.cpu().numpy())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item() * labels.size(0)

acc= 100 * correct / total
loss_f = running_loss/ total
print(f'Accuracy of the bert: {100 * correct / total} %')

test_labels = np.array(test_labels)

#print(confusion_matrix(test_labels, pred))
cm = confusion_matrix(test_labels, pred)
#print(classification_report(test_labels, pred))
cr = classification_report(test_labels, pred, output_dict=True)
auc = roc_auc_score(test_labels, pred)


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
Random seed set as 42
Accuracy of the bert: 57.2 %


In [8]:
lr = 5e-07
epochs = 68
batch_size = 128 #32
#best_model_1e-06_22_adamW_20_resnet.pth
random_seeds = [15, 0, 1,67,  128, 87, 261, 510, 340, 22] #15, 0, 1,67,  128, 87, 261, 510, 340, 22
#df = pd.DataFrame(columns = ['AUROC','accuracy', "precision", "recall", "f1-score", "CM", "CR"])
model_name = "bert_resnet_luong"
for seed in random_seeds:
    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' + str(lr)+'_' + str(seed)+'_adamW_' +str(batch_size)+ '_' +  str(epochs)+'_' + str(model_name)+ '.pth_current.pth'
    model = MultimodalFramework()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
    import pandas as pd
    import numpy as np
    import json
    import sys
    import logging
    from pathlib import Path
    import random
    import tarfile
    import tempfile
    import warnings
    import matplotlib.pyplot as plt
    # import pandas_path  # Path style access for pandas
    from tqdm import tqdm
    import torch                    
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    from torchvision import datasets, models, transforms
    import matplotlib.pyplot as plt
    import time
    import os
    import copy
    print("PyTorch Version: ",torch.__version__)
    print("Torchvision Version: ",torchvision.__version__)

    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

    from torch.utils.data import TensorDataset
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset

    import matplotlib.pyplot as plt
    from PIL import Image
    from transformers import BertTokenizer
    from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #str(batch_size)+ '_' +
    from model_utils import set_seed

    model.eval()
    set_seed(42)
    batch_size =128
    model_name = "bert_resnet_luong"
    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')
    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')

    if model_name.split("_")[:2] == ["bert", "resnet"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) 

    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) 
        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )

    else:
        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) 
        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )


    correct = 0
    total = 0
    running_loss = 0
    pred = []
    test_labels = []
    criterion = nn.CrossEntropyLoss()

    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for modality1, modality2 in zip(modality_1, modality_2):

            if model_name.split("_")[:2] == ["bert", "resnet"]:
                text_inp, masks, text_labels = modality2
                img_inp, labels = modality1

                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([img_inp, text_inp, masks], model_name)

            elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                img_inp, labels = modality1
                tab_inp, tab_labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                img_inp, labels = img_inp.to(device), labels.to(device)

                outputs = model([tab_inp, img_inp], model_name)
            else:
                tab_inp, tab_labels = modality1
                text_inp, masks, labels = modality2
                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)

                outputs = model([tab_inp, text_inp, masks], model_name)

            loss = criterion(outputs, labels)
            test_labels.extend(np.array(labels.cpu()))
            _, predicted = torch.max(outputs, 1)
            pred.extend(predicted.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item() * labels.size(0)

    acc= 100 * correct / total
    loss_f = running_loss/ total
    print(f'Accuracy of the bert: {100 * correct / total} %')

    test_labels = np.array(test_labels)

    #print(confusion_matrix(test_labels, pred))
    cm = confusion_matrix(test_labels, pred)
    #print(classification_report(test_labels, pred))
    cr = classification_report(test_labels, pred, output_dict=True)
    auc = roc_auc_score(test_labels, pred)
    print(cr['macro avg']['f1-score'])
    df = df.append({'AUROC': auc,'accuracy': acc, "precision":cr["macro avg"]["precision"]*100 ,
                    "recall":cr["macro avg"]["recall"]*100, "f1-score":cr["macro avg"]["f1-score"]*100,
                    "CM":cm, "CR":cr}, ignore_index=True)
    


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
Random seed set as 42
Accuracy of the bert: 66.52941176470588 %
0.6334595049107635


In [9]:
df

Unnamed: 0,AUROC,accuracy,precision,recall,f1-score,CM,CR
0,0.649902,69.176471,66.140263,64.990224,65.364182,"[[870, 220], [304, 306]]","{'0': {'precision': 0.7410562180579217, 'recal..."
1,0.638555,68.647059,65.444905,63.855467,64.270153,"[[881, 209], [324, 286]]","{'0': {'precision': 0.7311203319502074, 'recal..."
2,0.614183,65.058824,61.726473,61.418258,61.543207,"[[810, 280], [314, 296]]","{'0': {'precision': 0.7206405693950177, 'recal..."
3,0.652286,70.176471,67.31272,65.228606,65.752486,"[[902, 188], [319, 291]]","{'0': {'precision': 0.7387387387387387, 'recal..."
4,0.645044,69.294118,66.226512,64.504437,64.958087,"[[888, 202], [320, 290]]","{'0': {'precision': 0.7350993377483444, 'recal..."
5,0.647142,69.470588,66.441448,64.714243,65.17591,"[[889, 201], [318, 292]]","{'0': {'precision': 0.7365368682684341, 'recal..."
6,0.628215,66.117647,63.017255,62.821477,62.909091,"[[812, 278], [298, 312]]","{'0': {'precision': 0.7315315315315315, 'recal..."
7,0.618439,65.882353,62.450593,61.843886,62.053571,"[[830, 260], [320, 290]]","{'0': {'precision': 0.7217391304347827, 'recal..."
8,0.650526,69.117647,66.090551,65.052639,65.403638,"[[866, 224], [301, 309]]","{'0': {'precision': 0.7420736932305055, 'recal..."
9,0.660911,70.588235,67.790275,66.091142,66.591412,"[[894, 196], [304, 306]]","{'0': {'precision': 0.7462437395659433, 'recal..."


In [4]:
df["f1-score"].mean()

64.40217369520445

In [7]:
df.mean()

  """Entry point for launching an IPython kernel.


AUROC         0.640520
accuracy     68.352941
precision    65.264099
recall       64.052038
f1-score     64.402174
dtype: float64

In [20]:
correct / total

0.586

In [43]:
loss_f

1.3884286813735962

In [39]:
cr['macro avg']['f1-score']

0.648864681332673

In [66]:
torch.backends.cudnn.version()

8200

In [None]:
val Loss: 0.8501 Acc: 0.5000

In [None]:
train Loss: 0.6654 Acc: 0.6334
val Loss: 0.7465 Acc: 0.5020
Training complete in 1m 57s
Best val Acc: 0.502000

In [None]:
train Loss: 0.6606 Acc: 0.6379
val Loss: 0.7492 Acc: 0.5000
Training complete in 4m 37s
Best val Acc: 0.500000

In [None]:
train Loss: 0.6575 Acc: 0.6391
val Loss: 0.7548 Acc: 0.5000

In [None]:
train Loss: 0.6589 Acc: 0.6363
val Loss: 0.7782 Acc: 0.5000
Training complete in 1m 55s
Best val Acc: 0.500000
    
train Loss: 0.6603 Acc: 0.6394
val Loss: 0.7611 Acc: 0.5000
Training complete in 2m 27s
Best val Acc: 0.500000

train Loss: 0.6603 Acc: 0.6394
val Loss: 0.7611 Acc: 0.5000
Training complete in 2m 16s
Best val Acc: 0.500000

In [128]:
train Loss: 0.6418 Acc: 0.6437
val Loss: 0.8099 Acc: 0.4940
Training complete in 2m 6s
Best val Acc: 0.494000

train Loss: 0.6418 Acc: 0.6437
val Loss: 0.8099 Acc: 0.4940
Training complete in 2m 6s
Best val Acc: 0.494000

True

In [101]:
pred1 = pred

In [63]:
acc

50.2

In [61]:
acc

48.8

In [None]:
train Loss: 55.9858 Acc: 0.5862
val Loss: 543.6128 Acc: 0.5040
Training complete in 4m 12s
Best val Acc: 0.504000

In [3]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from typing import Optional, Tuple

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
"""
class ScaledDotProductAttention(nn.Module):
    
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask.view(score.size()), -float('Inf'))

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)

        return context, attn
"""

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)
    
    def forward(self, query, key, value, W):

        score = query @ W @ value.transpose(1, 2) #.T
        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn
    
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need"
    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
    These are concatenated and once again projected, resulting in the final values.
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.
    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)
    Args:
        d_model (int): The dimension of keys / values / quries (default: 512)
        num_heads (int): The number of attention heads. (default: 8)
    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): In transformer, three different ways:
            Case 1: come from previoys decoder layer
            Case 2: come from the input embedding
            Case 3: come from the output embedding (masked)
        - **key** (batch, k_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)
        - **value** (batch, v_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)
        - **mask** (-): tensor containing indices to be masked
    Returns: output, attn
        - **output** (batch, output_len, dimensions): tensor containing the attended output features.
        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.d_head, self.d_head).uniform_(-0.1, 0.1))

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD

        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN

        #context, attn = self.scaled_dot_attn(query, key, value, mask)
        context, attn = self.scaled_dot_attn(query, key, value, self.W)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND

        return context

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        self.luong_multihead_attention = MultiHeadAttention(256,2)
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res).unsqueeze(1)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert).unsqueeze(1)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out

In [10]:
model = MultimodalFramework()
model.load_state_dict(torch.load("/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth",map_location=torch.device('cpu')))


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


FileNotFoundError: [Errno 2] No such file or directory: '/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth'

In [11]:
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset
 
import matplotlib.pyplot as plt
from PIL import Image
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #str(batch_size)+ '_' +
from model_utils import set_seed

model.eval()
set_seed(42)
batch_size =16
model_name = "bert_resnet_luong"
test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')
test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')

if model_name.split("_")[:2] == ["bert", "resnet"]:
    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)
    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) 

elif model_name.split("_")[:2] == ["resnet", "mlp"]:
    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) 
    modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )

else:
    modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) 
    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )


correct = 0
total = 0
running_loss = 0
pred = []
test_labels = []
criterion = nn.CrossEntropyLoss()

# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for modality1, modality2 in zip(modality_1, modality_2):

        if model_name.split("_")[:2] == ["bert", "resnet"]:
            text_inp, masks, text_labels = modality2
            img_inp, labels = modality1

            text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
            img_inp, labels = img_inp.to(device), labels.to(device)

            outputs = model([img_inp, text_inp, masks], model_name)

        elif model_name.split("_")[:2] == ["resnet", "mlp"]:
            img_inp, labels = modality1
            tab_inp, tab_labels = modality2
            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
            img_inp, labels = img_inp.to(device), labels.to(device)

            outputs = model([tab_inp, img_inp], model_name)
        else:
            tab_inp, tab_labels = modality1
            text_inp, masks, labels = modality2
            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
            text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)

            outputs = model([tab_inp, text_inp, masks], model_name)
        
        loss = criterion(outputs, labels)
        test_labels.extend(np.array(labels.cpu()))
        _, predicted = torch.max(outputs, 1)
        pred.extend(predicted.cpu().numpy())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item() * labels.size(0)

acc= 100 * correct / total
loss_f = running_loss/ total
print(f'Accuracy of the bert: {100 * correct / total} %')

test_labels = np.array(test_labels)

#print(confusion_matrix(test_labels, pred))
cm = confusion_matrix(test_labels, pred)
#print(classification_report(test_labels, pred))
cr = classification_report(test_labels, pred, output_dict=True)
auc = roc_auc_score(test_labels, pred)


PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
Random seed set as 42
Accuracy of the bert: 64.11764705882354 %


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [27]:
loss_f

1.472200929697822

In [12]:
cr['macro avg']['f1-score']

0.3906810035842294

In [16]:
import pandas as pd
import numpy as np
import json
import logging
import random
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch import flatten

from collections import OrderedDict
from transformers import BertModel, DistilBertModel

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from MHA_modified import MultiheadAttention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from typing import Optional, Tuple

class Attention(torch.nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, 
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
        ):
        weights = self._get_weights(query, values) # [seq_length]
        weights = torch.nn.functional.softmax(weights, dim=0)
        return weights @ values  # [encoder_dim]

class MultiplicativeAttention(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,
        query: torch.Tensor,  # [decoder_dim]
        values: torch.Tensor, # [seq_length, encoder_dim]
    ):
        weights = query @ self.W @ values.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim)    



class OneVSOthers(Attention):

    def __init__(self, encoder_dim: int, decoder_dim: int):
        super().__init__(encoder_dim, decoder_dim)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')

    def _get_weights(self,others, main):
        mean = sum(others) / len(others)
        weights = mean @ self.W @ main.T  # [seq_length]
        return weights #/np.sqrt(self.decoder_dim) 
"""
class ScaledDotProductAttention(nn.Module):
    
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask.view(score.size()), -float('Inf'))

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)

        return context, attn
"""

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)
    
    def forward(self, query, key, value, W):

        score = query @ W @ value.transpose(1, 2) #.T
        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn
    
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need"
    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
    These are concatenated and once again projected, resulting in the final values.
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.
    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)
    Args:
        d_model (int): The dimension of keys / values / quries (default: 512)
        num_heads (int): The number of attention heads. (default: 8)
    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): In transformer, three different ways:
            Case 1: come from previoys decoder layer
            Case 2: come from the input embedding
            Case 3: come from the output embedding (masked)
        - **key** (batch, k_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)
        - **value** (batch, v_len, d_model): In transformer, three different ways:
            Case 1: come from the output of the encoder
            Case 2: come from the input embeddings
            Case 3: come from the output embedding (masked)
        - **mask** (-): tensor containing indices to be masked
    Returns: output, attn
        - **output** (batch, output_len, dimensions): tensor containing the attended output features.
        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.W = torch.nn.Parameter(torch.FloatTensor(self.d_head, self.d_head).uniform_(-0.1, 0.1))

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)
        
        print(query.shape)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD
        print(query.shape)
        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
        print(query.shape)
        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN

        #context, attn = self.scaled_dot_attn(query, key, value, mask)
        context, attn = self.scaled_dot_attn(query, key, value, self.W)

        context = context.view(self.num_heads, batch_size, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND

        return context

class MultimodalFramework(nn.Module):

    def __init__(self):
        super(MultimodalFramework, self).__init__()
        ##MLP
        self.fc1 = nn.Linear(53, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 2)
        self.relu = nn.ReLU()
        
        ##RESNET
        self.resnet18 = models.resnet18(pretrained=True)
        n_inputs = self.resnet18.fc.in_features

        self.resnet18.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(n_inputs, 512))
        ])) 

        self.resnet_classification = nn.Linear(512, 2) #4
        
        ##BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased') 
        self.bert_classification = nn.Linear(768, 2)
        
        #Two Modality models
        self.bert_resnet_classification = nn.Linear(512 + 768, 2)
        self.bert_mlp_classification = nn.Linear(256 + 768, 2)
        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)
        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)
        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)
        self.att_classification = nn.Linear(256*2, 2)
        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)
        self.OvO_classification = nn.Linear(3*256, 2)

        self.res_wrap = nn.Linear(512, 256)
        self.bert_wrap = nn.Linear(768, 256)
        
        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)
        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)
        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)
        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)
        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = "OvO")
        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = "luong")
        self.luong_multihead_attention = MultiHeadAttention(256,2)
        
    def bi_directional_att(self, pair):
        x = pair[0]
        y = pair[1]
        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)
        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)
        combined = torch.cat((attn_output_LV,
                              attn_output_VL), dim=1)
        return combined

    def forward(self, x, model):
        if model == "mlp":
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            out = self.fc3(x)
            
        elif model == "resnet":
            res = self.resnet18(x)       
            out = self.resnet_classification(res)
        
        elif model == "bert":
            text, masks = x
            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]
            out = self.bert_classification(bert)
            
        elif model == "bert_resnet":
            img, text, masks = x
            res_emb = self.resnet18(img)
            
            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]
            combined = torch.cat((res_emb,
                                  bert_emb), dim=1)
            out = self.bert_resnet_classification(combined)
        
        elif model == "bert_resnet_luong":
            img, text, masks = x
            
            res = self.resnet18(img)
            res = self.res_wrap(res) #.unsqueeze(1)

            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert) #.unsqueeze(1)

            attn_output_LV = self.luong_multihead_attention(bert, res, res)
            attn_output_VL = self.luong_multihead_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV, #.squeeze(1)
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
        
        elif model == "bert_resnet_vaswani":
            img, text, masks = x
            res_emb = self.resnet18(img)
        
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.res_wrap(res_emb)
            res = res[:, None, :]
            bert = self.bert_wrap(bert_emb)
            bert = bert[:, None, :]

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)

            combined = torch.cat((attn_output_LV.squeeze(1),
                                  attn_output_VL.squeeze(1)), dim=1)
            out = self.att_classification(combined)
            
        elif model == "bert_mlp":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            
            combined = torch.cat((bert,feat), dim=1)
            out = self.bert_mlp_classification(combined)
            
        elif model == "bert_mlp_luong":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV = self.luong_attention(bert, feat)
            attn_output_VL = self.luong_attention(feat, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "bert_mlp_vaswani":
            features, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert_emb)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)
            
        elif model == "resnet_mlp":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            
            combined = torch.cat((feat,res), dim=1)
            out = self.resnet_mlp_classification(combined)
            
        elif model == "resnet_mlp_luong":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV = self.luong_attention(res, feat)
            attn_output_VL = self.luong_attention(feat, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            out = self.att_classification(combined)   
            
        elif model == "resnet_mlp_vaswani":
            features, img = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)

            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)
            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)

            combined = torch.cat((attn_output_LV,
                                  attn_output_VL), dim=1)
            
            out = self.att_classification(combined)
            
        elif model == "bert_resnet_mlp":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            res = self.resnet18(img)

            
            combined = torch.cat((bert,feat, res), dim=1)
            out = self.bert_resnet_mlp_classification(combined)
            
        elif model == "bert_resnet_mlp_l":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            res = self.resnet18(img)
            res = self.res_wrap(res)

            combined = torch.cat((bert, feat, res), dim=1)
            out = self.bert_resnet_mlp_l_classification(combined)
        
        elif model == "bert_resnet_mlp_vaswani":
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            pairs = [[feat, bert],[feat,res],[bert,res]]
        
            results = []
            for pair in pairs:
                combined = self.bi_directional_att(pair)
                results.append(combined)

            comb = torch.cat(results, dim=1)
            out = self.vaswani_3_classification(comb)
            
        else:
            features, img, text, masks = x
            
            feat = self.fc1(features)
            feat = self.relu(feat)
            feat = self.fc2(feat)
            
            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits
            bert = self.bert_wrap(bert)
            
            res = self.resnet18(img)
            res = self.res_wrap(res)
            
            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert
            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)
            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)

            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)
            out = self.OvO_classification(comb)

        return out

import pandas as pd
import numpy as np
import json
import sys
import logging
from pathlib import Path
import random
import tarfile
import tempfile
import warnings
import matplotlib.pyplot as plt
# import pandas_path  # Path style access for pandas
from tqdm import tqdm
import torch                    
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pprint
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
from sklearn.metrics import confusion_matrix,f1_score

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from collections import OrderedDict

from PIL import ImageFile,Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

from transformers import BertModel
from transformers import BertForSequenceClassification, AdamW, BertConfig


import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

#from models import MultimodalFramework
from model_utils import set_seed, build_optimizer


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.empty_cache()

def save_model(net, optim, ckpt_fname, epoch):                                                                                                                                                             
    state_dict = net.state_dict()                                                                          
    for key in state_dict.keys():                                                                                 
        state_dict[key] = state_dict[key].cpu()  

    torch.save({                                                                                                                                                                                                 
        'epoch': epoch,                                                                                                                                                                                     
        'state_dict': state_dict,                                                                                                                                                                                
        'optimizer': optim},                                                                                                                                                                                     
        ckpt_fname)

#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)
def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):
    
    set_seed(42)
    model = MultimodalFramework()
    
    #torch.cuda.empty_cache()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

   
    #model = model.to('cuda')
    
    num_epochs = 1
    optimizer = build_optimizer(model, "adamW", 0.0001, 0.9)

    since = time.time()

    val_acc_history = []
    val_loss_history = []
    train_acc_history = []
    train_loss_history = []

    best_acc = 0.0
    patience = 5 
    trigger = 0
    acc_dict = {}

    for epoch in range(num_epochs):
        #scheduler.step()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                length = len_train
                model.train()  # Set model to training mode
            else:
                length = len_val
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            predicted_labels, ground_truth_labels = [], []

            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    
                    if model_name.split("_")[:2] == ["bert", "resnet"]:
                        text_inp, masks, text_labels = modality2
                        img_inp, labels = modality1

                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        #text_inp, masks, text_labels = text_inp.cuda(), masks.cuda(), text_labels.cuda()
                        #img_inp, labels = img_inp.cuda(), labels.cuda()
                        
                        inp_len = text_inp.size(0)
                        outputs = model([img_inp, text_inp, masks], model_name)

                    elif model_name.split("_")[:2] == ["resnet", "mlp"]:
                        img_inp, labels = modality1
                        tab_inp, tab_labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        img_inp, labels = img_inp.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, img_inp], model_name)
                    else:
                        tab_inp, tab_labels = modality1
                        text_inp, masks, labels = modality2
                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()

                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)
                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)
                        
                        inp_len = tab_inp.size(0)
                        outputs = model([tab_inp, text_inp, masks], model_name)
                    
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                #print("text_inp.size(0)")
                #print(text_inp.size(0))

                running_loss += loss.item() * labels.size(0)
                running_corrects += torch.sum(preds == labels.data)
                predicted_labels.extend(preds.cpu().detach().numpy())
                ground_truth_labels.extend(labels.cpu().detach().numpy())
                
            epoch_loss = running_loss / length
            epoch_acc = running_corrects.double() / length
            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)
            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)

            print('{} Loss: {} Acc: {}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                #wandb.log({"val_loss": epoch_loss, "val_acc": epoch_acc, "val_f1": epoch_f1})
                acc_dict[epoch] = float(epoch_acc.detach().cpu())
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                save_model(model, optimizer, path+"_save.pth", epoch)
                #print(model.state_dict())
                #torch.save(model.cpu().state_dict(), path+"_current.pth")
                #model = model.cuda()
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    #best_model_wts = copy.deepcopy(model.state_dict())
                    #torch.save(model.state_dict(), path+"_best.pth")
                """
                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):
                    trigger +=1
                    if trigger >= patience:
                        return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}
                else:
                    trigger = 0
                """    
            if phase == 'train':
                #wandb.log({"train_loss": epoch_loss, "train_acc": epoch_acc,"train_f1": epoch_f1, "epoch": epoch})
                train_acc_history.append(epoch_acc.detach().cpu())
                train_loss_history.append(epoch_loss)


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    #model.load_state_dict(best_model_wts)
    #torch.save(model.cpu().state_dict(), path+"_last.pth")
    return model, {"train_acc":train_acc_history, "val_acc":val_acc_history,"train_loss":train_loss_history, "val_loss":val_loss_history}


model_name = "bert_resnet_luong"

train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')
val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')

train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')
val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')


criterion = nn.CrossEntropyLoss()


train_dataloader_text = DataLoader(train_inputs_txt, batch_size=16,shuffle=False)
val_dataloader_text = DataLoader(val_inputs_txt, batch_size=16, shuffle=False)

train_dataloader_img = DataLoader(train_inputs_img, batch_size=16,shuffle=False)
val_dataloader_img = DataLoader(val_inputs_img, batch_size=16, shuffle=False)

len_val = len(val_inputs_txt)
len_train = len(train_inputs_txt)

dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}


path = '' + model_name + "_original_"

model, dic = train_model(model_name, dataloaders_dict, criterion, len_train, len_val, "config", path)  




PyTorch Version:  1.11.0+cu113
Torchvision Version:  0.12.0+cu113
cpu
cpu
Random seed set as 42


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 0/0
----------
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])


KeyboardInterrupt: 

In [None]:
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])

In [None]:
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])
torch.Size([16, 1, 256])
torch.Size([16, 1, 2, 128])
torch.Size([32, 1, 128])