# Deep Learning (EE-559) Mini-Project
Members:
Luca Salvador,
Marco Giuliano,
Paolo Giaretta
#
Professor:
Andrea Cavallaro

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
from functools import partial
from collections import Counter
from IPython.display import display
from torch.utils.data import WeightedRandomSampler

import sys
sys.path.insert(0, './code')
from trainer import BimodalTrainer
from GANtrainer import GANTrainer
from models import Classifier, Generator
from dataset_utils import BimodalDataset
from utils import evaluate_cos_similarities

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Reproducibility: fix the PyTorch RNG seed before any model or sampler is created.
torch.manual_seed(42)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data and checkpoints are resolved relative to the notebook's working directory.
cd = os.getcwd()
data_path = os.path.join(cd, 'data')

# Report the resolved configuration.
print(f'Device: {device}')
print(f'Data path: {data_path}')

Loading data

In [None]:
# Resolve the three split files inside <data_path>/hateful_memes.
dataset_folder_name = 'hateful_memes'
dataset_path = os.path.join(data_path, dataset_folder_name)
train_path, val_path, test_path = (
    os.path.join(dataset_path, file_name)
    for file_name in ('train.pth', 'dev_unseen.pth', 'test_unseen.pth')
)

# Load all three splits (train, validation, test) with torch.load.
train_data, val_data, test_data = (
    torch.load(split_path) for split_path in (train_path, val_path, test_path)
)

# Per-split label lists; reused by later cells to derive class weights.
train_labels, val_labels, test_labels = (
    [sample['label'] for sample in split]
    for split in (train_data, val_data, test_data)
)

# Summary per split: sample keys, number of samples and label histogram.
print('Train data:', train_data[0].keys(), len(train_data), Counter(train_labels))
print('Val data:', val_data[0].keys(), len(val_data), Counter(val_labels))
print('Test data:', test_data[0].keys(), len(test_data), Counter(test_labels))

# GAN

In [None]:
# GAN HYPERPARAMETERS #
# Module-level configuration consumed by the GAN cells below. Several names defined here
# (batch_size, epochs) are re-assigned by later classifier cells, so re-run this cell
# before re-running any GAN cell out of order.

# Label to train on
GAN_label = 1                       # 0: non-offensive, 1: offensive
# NOTE(review): toxicity_threshold is not read by any cell visible in this notebook; the
# sample-generation cell defines and uses its own `prediction_threshold` instead.
toxicity_threshold = 0.5            # Threshold to add generated data to the training set based on frozen-classifier
# NOTE(review): re-assigned to 1 in the sample-generation cell further down.
similarilty_threshold = 0.5         # Threshold to add generated data to the training set based on similarity to the original data

# Training
batch_size = 64
epochs = 35
num_gen_steps = 1
num_disc_steps = 1

# Loss weights (all forwarded verbatim to GANTrainer; their exact meaning is defined there).
# Presumably: gp = gradient penalty, L1/L2 = weight regularisation, ms = mode seeking -- TODO confirm in GANtrainer.py
lambda_gp = 0.5
lambda_L1_gen = 0
lambda_L2_gen = 1
lambda_L1_disc = 0
lambda_L2_disc = 1
lambda_consistency = 0.2
lambda_ms = 1

Training parameters

In [None]:
# Optimisation and architecture settings for the GAN (generator + discriminator).

# Optimization
# Optimizers and schedulers are stored as partials so they can be bound to the
# model parameters / optimizer instances once the models exist.
lr_gen = 1e-4
lr_disc =  3e-4
gen_optimi = partial(optim.RMSprop, lr=lr_gen)
disc_optim = partial(optim.RMSprop, lr=lr_disc)
# StepLR with step_size=1 multiplies the learning rate by gamma after every scheduler step.
# NOTE(review): an earlier tuning note said "from 0.99 to 1", which does not match the current 0.95.
gen_scheduler = partial(StepLR, step_size=1, gamma=0.95)
disc_scheduler = partial(StepLR, step_size=1, gamma=0.95)
weight_cliping = None  # passed to GANTrainer as `weight_clip`; None presumably disables clipping -- TODO confirm

# Metrics
# NOTE: `metrics` is re-assigned by the classifier hyperparameter cell further down.
metrics = {'acc_fake': BinaryAccuracy(), 'acc_real':BinaryAccuracy()}

# Model ########################################################
clip_feature_dim = 768  # width of each image/text embedding fed to the models

# Generator
noise_dim = 512
gen_hidden_dims = [2*clip_feature_dim] * 2  # two hidden layers, each twice the embedding width
gen_dropout_prob = 0.2
gen_normalize_features = False
gen_bn = False

# Discriminator
# NOTE: classifier_hidden_dims, comb_dropout_prob and comb_fusion are re-assigned by the
# classifier hyperparameter cell; re-run this cell before rebuilding the discriminator.
classifier_hidden_dims = [256, 64]
# NOTE(review): an earlier tuning note said "from 0.5 to 0.3", which does not match the current 0.38.
classifiers_dropout_prob = 0.38
comb_dropout_prob = 0.2
comb_fusion = 'concat'
disc_normalize_features = False
disc_bn = False

In [None]:
# Keep only the training samples whose label equals GAN_label; the GAN is fitted on this subset.
train_data_filtered = [sample for sample in train_data if sample['label'] == GAN_label]
print(f'Train dataset size after filtering on {GAN_label}:', len(train_data_filtered))

In [None]:
# Dataset and shuffled loader over the label-filtered subset used to train the GAN.
GAN_dataset = BimodalDataset(train_data_filtered)
gan_loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=False,
)
train_dataloader_GAN = DataLoader(GAN_dataset, **gan_loader_options)

# Sanity check on the first batch: report the cosine similarities returned by
# evaluate_cos_similarities for the image and text embeddings.
batch = next(iter(train_dataloader_GAN))
img_embedding, text_embedding, _ = batch
cos_sim_img, cos_sim_text = evaluate_cos_similarities(img_embedding, text_embedding)
print(f'Mean cosine similarity in first batch \n image: {cos_sim_img: .4f}, text: {cos_sim_text: .4f}')

In [None]:

# GAN models ########################################################
# Generator: built from the generator hyperparameters and moved to `device`.
generator_config = dict(
    img_embedding_size=clip_feature_dim,
    text_embedding_size=clip_feature_dim,
    noise_dim=noise_dim,
    hidden_dims=gen_hidden_dims,
    dropout_prob=gen_dropout_prob,
    bn=gen_bn,
    act="relu",
    normalize_features=gen_normalize_features,
)
gen = Generator(**generator_config).to(device)

# Discriminator: reuses the project's Classifier architecture.
discriminator_config = dict(
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=disc_bn,
    classifiers_dropout_prob=classifiers_dropout_prob,
    normalize_features=disc_normalize_features,
)
disc = Classifier(clip_feature_dim, **discriminator_config).to(device)

In [None]:
# TRAINER ########################################################
# Bind the optimizer factories to the freshly built models.
gen_optimizer = gen_optimi(gen.parameters())
disc_optimizer = disc_optim(disc.parameters())


# WGAN losses
def gen_GAN_criterion(prediction):
    """Generator loss: negated mean of the discriminator scores on generated samples."""
    return -torch.mean(prediction)


def disc_GAN_criterion(fake_predictions, real_predictions):
    """Discriminator loss: mean score on fake samples minus mean score on real samples."""
    return torch.mean(fake_predictions) - torch.mean(real_predictions)

In [None]:
# Collect every GANTrainer argument in one mapping so the configuration reads as a table.
gan_trainer_config = dict(
    # models and optimisation
    gen=gen,
    disc=disc,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    gen_scheduler=gen_scheduler(gen_optimizer),
    disc_scheduler=disc_scheduler(disc_optimizer),
    noise_dim=noise_dim,
    # adversarial criteria and update schedule
    gen_GAN_criterion=gen_GAN_criterion,
    disc_GAN_criterion=disc_GAN_criterion,
    num_gen_steps=num_gen_steps,
    num_disc_steps=num_disc_steps,
    metrics=metrics,
    weight_clip=weight_cliping,
    # loss weights
    lambda_gp=lambda_gp,
    lambda_L1_gen=lambda_L1_gen,
    lambda_L2_gen=lambda_L2_gen,
    lambda_L1_disc=lambda_L1_disc,
    lambda_L2_disc=lambda_L2_disc,
    lambda_consistency=lambda_consistency,
    lambda_ms=lambda_ms,
    device=device,
)
GAN_trainer = GANTrainer(**gan_trainer_config)

GAN training

In [None]:
# Train the GAN and checkpoint it.
# `fig` and its display handle `dh` are handed to the trainer together with the loader.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
dh = display(fig, display_id=True)

# Create the checkpoint directory *before* training so a long run is not lost
# to a missing 'models' folder when saving.
path = os.path.join(cd, 'models/GAN.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)

GAN_trainer.train(train_dataloader_GAN, epochs, fig, dh)
GAN_trainer.save(path)
# Close the figure like the classifier training cells do, so the inline backend
# does not render it a second time at the end of the cell.
plt.close()

# Classifiers

In [None]:
# CLASSIFIER HYPERPARAMETERS ################################################################
# Shared configuration for every classifier variant trained below.
# NOTE: this cell re-assigns names that the GAN section also uses (batch_size, metrics,
# clip_feature_dim, classifier_hidden_dims, comb_dropout_prob, comb_fusion). Re-run the GAN
# configuration cells before re-running any GAN cell after this point.
batch_size = 32

# Loss weights (forwarded to BimodalTrainer; 0 presumably disables the penalty -- TODO confirm in trainer.py)
lambda_L1 = 0
lambda_L2 = 0

# Optimization
# Factories: bound to model parameters / optimizer instances in the trainer cells.
lr = 1e-3
optimizer_fun = partial(optim.AdamW, lr=lr)
scheduler_fun = partial(StepLR, step_size=1, gamma=0.8)  # lr is multiplied by 0.8 at every scheduler step

# Metrics
metrics = {'acc': BinaryAccuracy(), 'auroc': BinaryAUROC()}

# Model
clip_feature_dim = 768
classifier_hidden_dims = [64]
classifier_dropout_prob = 0.2  # distinct from `classifiers_dropout_prob`, which configures the GAN discriminator
comb_dropout_prob = 0.2
comb_fusion = 'concat'
normalize_features = False
classifier_bn = False

# Bimodal Classifier

1) Bimodal trained on the unbalanced dataset

In [None]:
# Model: baseline classifier, trained on the unbalanced dataset without re-weighting.
bimodal_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
classifier_bimodal = Classifier(**bimodal_config).to(device)

In [None]:
# Wrap the raw splits in the project's dataset class.
train_dataset = BimodalDataset(train_data)
val_dataset = BimodalDataset(val_data)

# Loaders over the unbalanced data: shuffle only the training split.
common_loader_options = dict(batch_size=batch_size, pin_memory=True, drop_last=False)
train_dataloader = DataLoader(train_dataset, shuffle=True, **common_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **common_loader_options)

In [None]:
# TRAINER ################################################################
# Optimizer/scheduler are built from the shared factories; plain BCE-with-logits loss.
optimizer = optimizer_fun(classifier_bimodal.parameters())
scheduler = scheduler_fun(optimizer)
criterion = nn.BCEWithLogitsLoss()

baseline_trainer_config = dict(
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    metrics=metrics,
    lambda_L1=lambda_L1,
    lambda_L2=lambda_L2,
    device=device,
)
trainer = BimodalTrainer(classifier_bimodal, **baseline_trainer_config)

In [None]:
# Train the baseline classifier and checkpoint it.
# NOTE: `epochs` is a notebook-wide name; this overrides the value set for GAN training.
epochs = 20
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
dh = display(fig, display_id=True)

# Create the checkpoint directory before training so the run is not lost
# to a missing 'models' folder when saving.
path = os.path.join(cd, 'models/Classifier.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)

trainer.train(train_dataloader, val_dataloader, epochs, fig, dh)
# Save
trainer.save(path)
plt.close()

2) Bimodal trained with weighted loss

In [None]:
# Model for the weighted-loss experiment; same architecture as the baseline.
weighted_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
weighted_bimodal = Classifier(**weighted_config).to(device)

In [None]:
# Calculate class weights: weight[c] = 1 - (frequency of class c), so rarer classes weigh more.
length_train = len(train_data)
labels_train = [x['label'] for x in train_data]
label_counts = Counter(labels_train)
# Iterate labels in sorted order so that weights[c] really belongs to class c.
# Counter.values() follows first-seen order, which would swap the two weights whenever the
# first training sample is not of class 0.
# Assumes labels are the integers 0..K-1 -- later cells index `weights` by label.
weights = torch.tensor(
    [1 - (label_counts[label] / length_train) for label in sorted(label_counts)]
).to(device)
print('Weights:', weights)


In [None]:
# Wrap the raw splits in the project's dataset class (same data as the baseline experiment).
train_dataset = BimodalDataset(train_data)
val_dataset = BimodalDataset(val_data)

# Loaders over the unbalanced data: the imbalance is handled by the loss, not the sampling.
common_loader_options = dict(batch_size=batch_size, pin_memory=True, drop_last=False)
train_dataloader = DataLoader(train_dataset, shuffle=True, **common_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **common_loader_options)

In [None]:
# TRAINER ################################################################
# Train the model built for this experiment (`weighted_bimodal`); the previous code referenced
# `classifier`, which is not defined at this point on a fresh kernel.
optimizer = optimizer_fun(weighted_bimodal.parameters())
scheduler = scheduler_fun(optimizer)


def criterion(preds, labels):
    """Class-weighted BCE-with-logits.

    Each sample is scaled by the weight of its own class through `weight`. `pos_weight`
    only rescales the positive term of the loss, so it would ignore the weight of class 0.
    `weights` is looked up at call time, like the lambda this replaces.
    """
    sample_weights = weights[labels.type(torch.int64)]
    return F.binary_cross_entropy_with_logits(preds.view(-1), labels, weight=sample_weights)


trainer = BimodalTrainer(weighted_bimodal,
                        optimizer=optimizer,
                        criterion=criterion,
                        scheduler=scheduler,
                        metrics=metrics,
                        lambda_L1=lambda_L1,
                        lambda_L2=lambda_L2,
                        device=device)

In [None]:
# Train the weighted-loss classifier and checkpoint it.
epochs = 20
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
dh = display(fig, display_id=True)

# Create the checkpoint directory before training so the run is not lost
# to a missing 'models' folder when saving.
path = os.path.join(cd, 'models/Classifier_weightsBalanced.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)

trainer.train(train_dataloader, val_dataloader, epochs, fig, dh)
# Save
trainer.save(path)
plt.close()

3) Bimodal trained on oversampled dataset

In [None]:
# Model for the oversampling experiment; same architecture as the baseline.
oversampled_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
oversampled_bimodal = Classifier(**oversampled_config).to(device)

In [None]:
# One sampling weight per training sample, looked up from the per-class `weights`.
individual_weights = torch.tensor([weights[label] for label in train_labels]).to(device)
# Draw len(train_data) indices per epoch according to those weights.
sampler = WeightedRandomSampler(individual_weights, len(train_data))

# Wrap the raw splits in the project's dataset class.
train_dataset = BimodalDataset(train_data)
val_dataset = BimodalDataset(val_data)

# Training order is decided by the sampler, hence shuffle=False on both loaders.
common_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
train_dataloader = DataLoader(train_dataset, sampler=sampler, **common_loader_options)
val_dataloader = DataLoader(val_dataset, **common_loader_options)

In [None]:
# TRAINER ################################################################
# Train the model built for this experiment (`oversampled_bimodal`); the previous code
# referenced `classifier`, which is not the model created for the oversampling run.
optimizer = optimizer_fun(oversampled_bimodal.parameters())
scheduler = scheduler_fun(optimizer)
# Class balance comes from the sampler, so the loss itself stays unweighted.
criterion = nn.BCEWithLogitsLoss()

trainer = BimodalTrainer(oversampled_bimodal,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        criterion=criterion,
                        metrics=metrics,
                        lambda_L1=lambda_L1,
                        lambda_L2=lambda_L2,
                        device=device)

In [None]:
# Train the oversampled classifier and checkpoint it.
epochs = 20
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
dh = display(fig, display_id=True)

# Create the checkpoint directory before training so the run is not lost
# to a missing 'models' folder when saving.
path = os.path.join(cd, 'models/Classifier_samplerBalanced.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)

trainer.train(train_dataloader, val_dataloader, epochs, fig, dh)
# Save
trainer.save(path)
plt.close()

4) Bimodal trained on GAN augmented dataset

In [1]:
# Model: fresh baseline-architecture classifier; the next cell loads the baseline checkpoint into it.
filter_classifier_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
classifier_bimodal = Classifier(**filter_classifier_config).to(device)


SyntaxError: unmatched ')' (2142120098.py, line 1)

In [None]:
# Load the baseline checkpoint from the same absolute location the training cell saved it to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
classifier_bimodal.load_state_dict(torch.load(os.path.join(cd, 'models/Classifier.pth'), map_location=device))

In [None]:
# Freeze classifier_bimodal: no parameter of the filtering classifier should receive gradients.
for frozen_param in classifier_bimodal.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Stack the embeddings of the label-filtered training samples, L2-normalise every row and
# keep the result on the CPU. Used as the reference set for the similarity filter.
img_from_train = F.normalize(
    torch.cat([sample['image'] for sample in train_data_filtered], dim=0).float(),
    p=2,
    dim=-1,
).cpu()
text_from_train = F.normalize(
    torch.cat([sample['text'] for sample in train_data_filtered], dim=0).float(),
    p=2,
    dim=-1,
).cpu()

In [None]:
# Rejection sampling of synthetic (image, text) embedding pairs from the trained generator.
# A draw is accepted only if
#   (a) the frozen classifier's probability is below `prediction_threshold`,
#   (b) its max cosine similarity to the already accepted draws is below `similarilty_threshold`,
#   (c) its max cosine similarity to the real filtered training samples lies in (sim_df_L, sim_df_U).
num_samples = 100000            # upper bound on the number of generator draws
num_samples_to_add = 300        # stop as soon as this many draws have been accepted
generated_data = torch.zeros(num_samples_to_add, 2, clip_feature_dim).to(device)
prediction_threshold = 1        # compared against a probability; 1 effectively disables filter (a)
similarilty_threshold = 1       # compared against a cosine similarity; 1 effectively disables filter (b)
sim_df_L = 0
sim_df_U = 0.75
count = 0

gen.eval()
classifier_bimodal.eval()

# Generator outputs live on `device`, while the reference embeddings were moved to the CPU
# when they were built. Move them once so torch.mm does not fail on a device mismatch.
img_reference = img_from_train.to(device)
text_reference = text_from_train.to(device)

with torch.no_grad():
    for i in range(num_samples):
        # Generate data
        noise = torch.randn(1, noise_dim).to(device)
        image, text = gen(noise)

        # Predict. The threshold applies to the probability; the previous code stored the
        # sigmoid in a misspelled variable and therefore thresholded the raw logit.
        prediction = torch.sigmoid(classifier_bimodal(image, text))
        if not (prediction < prediction_threshold):
            continue

        # Normalize
        image_norm = F.normalize(image, p=2, dim=-1)
        text_norm = F.normalize(text, p=2, dim=-1)

        # Max cosine similarity to the draws accepted so far. There is nothing to compare
        # against for the first draw, so it trivially passes filter (b).
        if count > 0:
            generated_data_img = F.normalize(generated_data[:count, 0, :], p=2, dim=-1)
            generated_data_text = F.normalize(generated_data[:count, 1, :], p=2, dim=-1)
            cos_sim_img = torch.mm(generated_data_img, image_norm.T)
            cos_sim_text = torch.mm(generated_data_text, text_norm.T)
            sim = max(cos_sim_img.max().item(), cos_sim_text.max().item())
        else:
            sim = float('-inf')

        # Max cosine similarity to the original data. This is now checked for the first
        # draw as well, which previously bypassed the filter.
        cos_sim_img_df = torch.mm(img_reference, image_norm.T)
        cos_sim_text_df = torch.mm(text_reference, text_norm.T)
        sim_gen = max(cos_sim_img_df.max().item(), cos_sim_text_df.max().item())

        if sim < similarilty_threshold and sim_df_L < sim_gen < sim_df_U:
            generated_data[count, 0, :] = image
            generated_data[count, 1, :] = text
            count += 1
            print(f"Added {count} samples, remainig data {num_samples-i} samples.")
            if count == num_samples_to_add:
                break

In [None]:
# Join generated data with train dataset
# 1. Transform generated data into list of dicts
# Only the first `count` rows of `generated_data` were filled by the sampling loop.
# NOTE(review): each generated tensor is a 1-D slice that still lives on `device`, while the
# reference cell concatenates the original 'image'/'text' entries along dim 0. Confirm that
# BimodalDataset accepts both layouts/devices.
generated_data_list =[{'image': generated_data[i, 0, :], 'text': generated_data[i, 1, :], 'label': GAN_label} for i in range(count)]

# 2. Add generated data to the train dataset
# Builds a new list; `train_data` itself is left unchanged.
train_data_GAN = train_data + generated_data_list
print(len(train_data_GAN))


# 3. Define dataloader
# NOTE: this re-binds `train_dataloader_GAN`, which previously referred to the loader used to
# train the GAN itself. Re-running the GAN training cell after this one would use the augmented set.
train_dataset_GAN = BimodalDataset(train_data_GAN)
train_dataloader_GAN = DataLoader(train_dataset_GAN, 
                              batch_size=batch_size, 
                              shuffle=True,
                              pin_memory=True,
                              drop_last=False,
                              )


In [None]:
# Model for the GAN-augmentation experiment; same architecture as the baseline.
gan_augmented_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
classifier_bimodal_gen = Classifier(**gan_augmented_config).to(device)

In [None]:


# Class weights for the GAN-augmented training set: weight[c] = 1 - (frequency of class c).
# NOTE: this re-binds the notebook-wide `weights` computed earlier from the original training set.
length_train = len(train_data_GAN)
labels_train = [x['label'] for x in train_data_GAN]
label_counts = Counter(labels_train)
# Sorted label order guarantees weights[c] belongs to class c. Counter.values() follows
# first-seen order instead. Assumes labels are the integers 0..K-1.
weights = torch.tensor(
    [1 - (label_counts[label] / length_train) for label in sorted(label_counts)]
).to(device)


def criterion(preds, labels):
    """Class-weighted BCE-with-logits.

    Each sample is scaled by the weight of its own class through `weight`. `pos_weight`
    only rescales the positive term of the loss, so it would ignore the weight of class 0.
    """
    sample_weights = weights[labels.type(torch.int64)]
    return F.binary_cross_entropy_with_logits(preds.view(-1), labels, weight=sample_weights)

In [None]:
# Prepare training of the classifier on the GAN-augmented dataset.
# Make sure every parameter is trainable.
for trainable_param in classifier_bimodal_gen.parameters():
    trainable_param.requires_grad_(True)

# TRAINER ################################################################
# This experiment uses its own learning rate; `criterion` comes from the previous cell.
lr = 5e-4
optimizer = optim.AdamW(classifier_bimodal_gen.parameters(), lr=lr)
scheduler = scheduler_fun(optimizer)
metrics = {'acc': BinaryAccuracy(), 'auroc': BinaryAUROC()}

gan_augmented_trainer_config = dict(
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    metrics=metrics,
    lambda_L1=lambda_L1,
    lambda_L2=lambda_L2,
    device=device,
)
trainer = BimodalTrainer(classifier_bimodal_gen, **gan_augmented_trainer_config)

In [None]:
# Train the classifier on the GAN-augmented dataset and checkpoint it.
epochs = 20
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
dh = display(fig, display_id=True)

# Create the checkpoint directory before training so the run is not lost
# to a missing 'models' folder when saving.
path = os.path.join(cd, 'models/Classifier_GANBalanced.pth')
os.makedirs(os.path.dirname(path), exist_ok=True)

# Use the augmented loader. `train_dataloader` still points at the loader of the previous
# experiment and contains none of the generated samples.
trainer.train(train_dataloader_GAN, val_dataloader, epochs, fig, dh)
# Save
trainer.save(path)
plt.close()

# Evaluation

1) Bimodal trained on the unbalanced dataset

In [None]:
# Model: fresh baseline-architecture classifier that receives the baseline checkpoint for evaluation.
eval_baseline_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
classifier_bimodal = Classifier(**eval_baseline_config).to(device)

In [None]:
# Load the baseline checkpoint from the same absolute location the training cell saved it to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
classifier_bimodal.load_state_dict(torch.load(os.path.join(cd, 'models/Classifier.pth'), map_location=device))

In [None]:
# TRAINER ################################################################
# The trainer is only used for evaluation here, but is built with an optimizer and scheduler.
optimizer = optimizer_fun(classifier_bimodal.parameters())
scheduler = scheduler_fun(optimizer)
criterion = nn.BCEWithLogitsLoss()

eval_trainer_config = dict(
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    metrics=metrics,
    lambda_L1=lambda_L1,
    lambda_L2=lambda_L2,
    device=device,
)
trainer = BimodalTrainer(classifier_bimodal, **eval_trainer_config)

In [None]:
# Evaluate on the test split, iterated in order.
test_dataset = BimodalDataset(test_data)
test_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
test_dataloader = DataLoader(test_dataset, **test_loader_options)

# Run the evaluation, then print the last logged value of every metric.
trainer.evaluate(test_dataloader)
for metric_name, history in trainer.val_metrics_log.items():
    print(f'{metric_name}: {history[-1]}')


2) Bimodal trained with weighted loss


In [None]:
# Model
classifier = Classifier(
                 input_dim=clip_feature_dim,
                 comb_convex_tensor=False,
                 comb_proj=False, 
                 comb_fusion=comb_fusion, 
                 comb_dropout_prob=comb_dropout_prob,
                 classifier_hidden_dims=classifier_hidden_dims,
                 act="relu",
                 bn=classifier_bn,
                 classifiers_dropout_prob=classifier_dropout_prob,
                 normalize_features=normalize_features
                ).to(device)

In [None]:
# Load the weighted-loss checkpoint from the same absolute location the training cell saved it to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
classifier.load_state_dict(torch.load(os.path.join(cd, 'models/Classifier_weightsBalanced.pth'), map_location=device))

In [None]:
# TRAINER ################################################################
optimizer = optimizer_fun(classifier.parameters())
scheduler = scheduler_fun(optimizer)

# Recompute the class weights from the ORIGINAL training labels. The notebook-wide `weights`
# has been re-bound by the GAN-augmentation cell by the time this cell runs.
# Sorted label order guarantees eval_class_weights[c] belongs to class c (labels assumed 0..K-1).
eval_label_counts = Counter(train_labels)
eval_class_weights = torch.tensor(
    [1 - (eval_label_counts[label] / len(train_labels)) for label in sorted(eval_label_counts)]
).to(device)


def criterion(preds, labels):
    """Class-weighted BCE-with-logits.

    Each sample is scaled by the weight of its own class through `weight`. `pos_weight`
    only rescales the positive term of the loss, so it would ignore the weight of class 0.
    """
    sample_weights = eval_class_weights[labels.type(torch.int64)]
    return F.binary_cross_entropy_with_logits(preds.view(-1), labels, weight=sample_weights)


trainer = BimodalTrainer(classifier,
                        optimizer=optimizer,
                        criterion=criterion,
                        scheduler=scheduler,
                        metrics=metrics,
                        lambda_L1=lambda_L1,
                        lambda_L2=lambda_L2,
                        device=device)

In [None]:
# Evaluate on the test split, iterated in order.
test_dataset = BimodalDataset(test_data)
test_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
test_dataloader = DataLoader(test_dataset, **test_loader_options)

# Run the evaluation, then print the last logged value of every metric.
trainer.evaluate(test_dataloader)
for metric_name, history in trainer.val_metrics_log.items():
    print(f'{metric_name}: {history[-1]}')


3) Bimodal trained on oversampled dataset

In [None]:
# Model
classifier = Classifier(
                 input_dim=clip_feature_dim,
                 comb_convex_tensor=False,
                 comb_proj=False, 
                 comb_fusion=comb_fusion, 
                 comb_dropout_prob=comb_dropout_prob,
                 classifier_hidden_dims=classifier_hidden_dims,
                 act="relu",
                 bn=classifier_bn,
                 classifiers_dropout_prob=classifier_dropout_prob,
                 normalize_features=normalize_features
                ).to(device)

In [None]:
# Load the oversampling checkpoint from the same absolute location the training cell saved it to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
classifier.load_state_dict(torch.load(os.path.join(cd, 'models/Classifier_samplerBalanced.pth'), map_location=device))

In [None]:
# TRAINER ################################################################
# The trainer is only used for evaluation here, but is built with an optimizer and scheduler.
optimizer = optimizer_fun(classifier.parameters())
scheduler = scheduler_fun(optimizer)
criterion = nn.BCEWithLogitsLoss()

eval_trainer_config = dict(
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    metrics=metrics,
    lambda_L1=lambda_L1,
    lambda_L2=lambda_L2,
    device=device,
)
trainer = BimodalTrainer(classifier, **eval_trainer_config)

In [None]:
# Evaluate on the test split, iterated in order.
test_dataset = BimodalDataset(test_data)
test_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
test_dataloader = DataLoader(test_dataset, **test_loader_options)

# Run the evaluation, then print the last logged value of every metric.
trainer.evaluate(test_dataloader)
for metric_name, history in trainer.val_metrics_log.items():
    print(f'{metric_name}: {history[-1]}')


4) Bimodal trained on GAN augmented dataset

In [None]:
# Model: fresh classifier that receives the GAN-augmentation checkpoint for evaluation.
eval_gan_augmented_config = dict(
    input_dim=clip_feature_dim,
    comb_convex_tensor=False,
    comb_proj=False,
    comb_fusion=comb_fusion,
    comb_dropout_prob=comb_dropout_prob,
    classifier_hidden_dims=classifier_hidden_dims,
    act="relu",
    bn=classifier_bn,
    classifiers_dropout_prob=classifier_dropout_prob,
    normalize_features=normalize_features,
)
classifier_bimodal_gen = Classifier(**eval_gan_augmented_config).to(device)

In [None]:
# Load the GAN-augmentation checkpoint from the same absolute location the training cell saved it to.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
classifier_bimodal_gen.load_state_dict(torch.load(os.path.join(cd, 'models/Classifier_GANBalanced.pth'), map_location=device))

In [None]:
# Build the trainer used to EVALUATE the GAN-augmented classifier (no training happens here).
# The freshly constructed model was never frozen, so the former "unfreeze" loop was a no-op
# and has been dropped together with its copy-pasted training comments.

# TRAINER ################################################################
lr = 5e-4
optimizer = optim.AdamW(classifier_bimodal_gen.parameters(), lr=lr)
scheduler = scheduler_fun(optimizer)
# Define the criterion explicitly instead of inheriting whatever the last executed cell bound
# to `criterion`. On a top-to-bottom run that was this same unweighted BCE-with-logits loss.
criterion = nn.BCEWithLogitsLoss()
metrics = {'acc': BinaryAccuracy(), 'auroc': BinaryAUROC()}

trainer = BimodalTrainer(classifier_bimodal_gen,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        criterion=criterion,
                        metrics=metrics,
                        lambda_L1=lambda_L1,
                        lambda_L2=lambda_L2,
                        device=device)

In [None]:
# Evaluate on the test split, iterated in order.
test_dataset = BimodalDataset(test_data)
test_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)
test_dataloader = DataLoader(test_dataset, **test_loader_options)

# Run the evaluation, then print the last logged value of every metric.
trainer.evaluate(test_dataloader)
for metric_name, history in trainer.val_metrics_log.items():
    print(f'{metric_name}: {history[-1]}')