In [None]:
import sys
sys.path.append(r'./Mind_Vis_utils/')

from fmri_caption import GPTCaptionModel, create_fmri_encoder_from_pretrained,top_k_top_p_filtering, set_parameter_requires_grad, define_GPTCaption_model
from utils import calculate_accuracy_on_test, calculate_semantic_similarity, state_dict_MLP_to_MLP_dropout, get_k_best_torch, print_batch
from dataset import BOLD5000_dataset, identity
from dataset import create_BOLD5000_dataset
from torch.utils.data import DataLoader, Subset
import torch
import torch.optim as optim
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, util
from transformers import get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import SelectKBest
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import os
#import optuna
%matplotlib inline

#### Setup

In [20]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Default run mode; the hyper-parameter cell further down assigns this flag again.
train_from_checkpoint = False

# Subjects to use, out of ['CSI1', 'CSI2', 'CSI3', 'CSI4'].
# Only needed when create_dataset=True.
subjects_list = ['CSI1']

# Locations of the pretrained fMRI encoder, the pre-saved dataset and the checkpoint directory.
path_pretrained_fmri_encoder = r"../pretrains/fmri_encoder_pretrain_metafile.pth"
path_presaved_dataset = r"../data/BOLD5000/CSI1_no_duplicates.pth"
path_checkpoints = '../data/Checkpoints'

# Echo the run configuration so it is captured in the notebook output.
print(f"Using device: {device}")
print(f"Training on subject's {subjects_list}")
print(f"MinD-Vis pretrained encoder: {path_pretrained_fmri_encoder}")


Using device: cuda:0
Training on subject's ['CSI1']
MinD-Vis pretrained encoder: /databases/roeyshafran/BrainCap/pretrains/pretrain_metafile.pth


#### Load Dataset

In [3]:
# Load the pre-saved BOLD5000 dataset; the file holds a dict with 'train' and 'test' entries.
# NOTE(review): torch.load unpickles arbitrary objects - only open dataset files from a trusted source.
BOLD_dataset = torch.load(path_presaved_dataset)
bold_train, bold_test = BOLD_dataset['train'], BOLD_dataset['test']
num_voxels = bold_test.num_voxels

# Hold out 10% of the training samples for validation.
# The split is seeded so the same samples end up in the validation set on every run.
# Without a seed, resuming from a checkpoint would draw a fresh split and could move
# samples the model was already trained on into the validation set.
SPLIT_SEED = 42
train_idx, val_idx = train_test_split(
    list(range(len(bold_train))),
    test_size=0.1,
    random_state=SPLIT_SEED,
)
bold_val = Subset(bold_train, val_idx)
bold_train = Subset(bold_train, train_idx)

print(f"Train length: {len(bold_train)}, Validation length: {len(bold_val)}, Test length: {len(bold_test)}")

#### Model Initialization and Hyper-Parameters

In [None]:
# Hyper-Parameters
BATCH_SIZE = 8
LEARNING_RATE = 1.25e-5
NUM_EPOCHS = 16
weight_decay = 0.2  # passed to the AdamW optimizer
use_amp = False  # toggles autocast and the gradient scaler in the training cells
train_from_checkpoint = True # Set to True if you want to load a previous checkpoint and train from there.
scheduler_milestones = []  # epochs at which MultiStepLR multiplies the learning rate by scheduler_gamma
scheduler_gamma = 0.5

# Only the training loader is shuffled. Validation and test loaders keep a fixed
# order so evaluation passes iterate the samples identically on every call.
train_dl = DataLoader(bold_train, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(bold_val, batch_size=BATCH_SIZE, shuffle=False)
test_dl = DataLoader(bold_test, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Build the fMRI encoder from the pretrained file and place it on the selected device.
encoder = create_fmri_encoder_from_pretrained(
    path_pretrained_fmri_encoder,
    num_voxels,
    feature_extraction=True,
).to(device)

# Layer widths of the MLP mapping network, given as multiples of the encoder embedding size.
width_factors = (1, 4, 4, 4, 2)
projection_sizes = [factor * encoder.embed_dim for factor in width_factors]

# Set use_dropout when loading a checkpoint that was trained with dropout
# (the final checkpoint, for example).
decoder = define_GPTCaption_model(
    encoder,
    projection_sizes=projection_sizes,
    use_dropout=True,
)

Position interpolate from 262 to 106
missing keys: ['mask_token']
unexpected keys: ['decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight', 'decoder_blocks.0.norm1.bias', 'decoder_blocks.0.attn.qkv.weight', 'decoder_blocks.0.attn.qkv.bias', 'decoder_blocks.0.attn.proj.weight', 'decoder_blocks.0.attn.proj.bias', 'decoder_blocks.0.norm2.weight', 'decoder_blocks.0.norm2.bias', 'decoder_blocks.0.mlp.fc1.weight', 'decoder_blocks.0.mlp.fc1.bias', 'decoder_blocks.0.mlp.fc2.weight', 'decoder_blocks.0.mlp.fc2.bias', 'decoder_blocks.1.norm1.weight', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.1.attn.qkv.weight', 'decoder_blocks.1.attn.qkv.bias', 'decoder_blocks.1.attn.proj.weight', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.1.norm2.weight', 'decoder_blocks.1.norm2.bias', 'decoder_blocks.1.mlp.fc1.weight', 'decoder_blocks.1.mlp.fc1.bias', 'decoder_blocks.1.mlp.fc2.weight', 'decoder_blocks.1.mlp.fc2.bias', 'decoder_blocks.2.norm1.weight', 'dec

In [None]:
# If train_from_checkpoint == True, load checkpoint.
checkpoint_name = r'decoder_22012023_12-20-13.pth'
if train_from_checkpoint:
    # map_location='cpu' lets a checkpoint written during a GPU run be opened on any
    # machine and avoids holding a second copy of the weights in GPU memory;
    # load_state_dict below copies the values into the module's own parameters.
    # NOTE(review): torch.load unpickles arbitrary objects - only open trusted checkpoints.
    model_dict = torch.load(os.path.join(path_checkpoints, checkpoint_name), map_location='cpu')
    print(f"Loaded checkpoin comment: {model_dict['comment']}")

    # Loading training data to visualize the complete training process
    running_loss = model_dict['training_data']['running_loss']
    running_semantic_accuracy = model_dict['training_data']['running_semantic_accuracy']
    val_accuracy = model_dict['training_data']['val_accuracy']
    lr_monitor = model_dict['training_data']['lr_monitor']

    # Convert the saved projection weights to the layout of the current projection module
    # (see state_dict_MLP_to_MLP_dropout), load them and print the load report.
    new_sd = state_dict_MLP_to_MLP_dropout(decoder.embedding_space_projection.state_dict(), model_dict['decoder_projection']['sd'])
    print(decoder.embedding_space_projection.load_state_dict(new_sd))
    set_parameter_requires_grad(decoder.embedding_space_projection, feature_extraction=False)
    del model_dict
    torch.cuda.empty_cache()

decoder.to(device)

In [7]:
# AdamW over all decoder parameters, using the learning rate and weight decay from the hyper-parameter cell.
optimizer = optim.AdamW(
    params=decoder.parameters(),
    lr=LEARNING_RATE,
    weight_decay=weight_decay,
)

# Step-wise learning-rate decay at the configured milestones; verbose=True prints every adjustment.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
    verbose=True,
)

# Gradient scaler for mixed precision; enabled only when use_amp is set.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

Adjusting learning rate of group 0 to 1.2500e-05.


#### Training Loop

In [None]:
decoder.train()
encoder.eval()
decoder.to(device)

# Use saved training data if training from a checkpoint
if not train_from_checkpoint:
    running_loss = []
    running_semantic_accuracy = []
    val_accuracy = []
    lr_monitor = []

for epoch in range(NUM_EPOCHS):
    print(f"\n---- Starting epoch {epoch} ----")
    decoder.train()

    # Remember where this epoch's entries begin in the running logs. The logs may
    # already hold history restored from a checkpoint, and the number of entries
    # per epoch is ceil(len(train_dl) / 10), so the start cannot be derived from
    # the epoch number alone.
    epoch_loss_start = len(running_loss)
    epoch_accuracy_start = len(running_semantic_accuracy)

    with tqdm(train_dl, unit='batch') as tepoch:
        semantic_accuracy = 0
        for batch_idx, batch in enumerate(tepoch):
            # The periodic evaluation below switches to eval mode; switch back every batch.
            decoder.train()
            tepoch.set_description(f"Epoch: {epoch}")

            batch_fmri = batch['fmri'].to(device)
            batch_caption = batch['caption']

            with torch.cuda.amp.autocast(enabled=use_amp):
                # Call the modules themselves rather than .forward so registered hooks run.
                fmri_prefix = encoder(batch_fmri)
                tokens, attention_mask = decoder.tokenizer(batch_caption, return_tensors="pt", padding=True).values()
                tokens, attention_mask, fmri_prefix = tokens.to(device), attention_mask.to(device), fmri_prefix.to(device)
                outputs = decoder(tokens, fmri_prefix, attention_mask)
                # Keep the logits from position prefix_length-1 up to the second-to-last one;
                # presumably these are the predictions for the caption tokens - confirm
                # against GPTCaptionModel.forward.
                logits = outputs.logits[:, decoder.prefix_length-1:-1]
                # Padding positions are excluded from the loss.
                loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=decoder.tokenizer.pad_token_id)

            decoder.zero_grad(set_to_none=True)
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Evaluate model on the current batch every 10th batch and log the metrics
            if batch_idx % 10 == 0:
                decoder.eval()
                with torch.no_grad():
                    generated_caption = decoder.generate_caption(fmri_prefix, device)
                    semantic_accuracy = torch.mean(calculate_semantic_similarity(generated_caption, batch_caption, device)).item()
                    running_semantic_accuracy.append(semantic_accuracy)
                    running_loss.append(loss.item())
                    lr_monitor.append(scheduler.get_last_lr())

            tepoch.set_postfix(loss=loss.item(), train_accuracy=semantic_accuracy)

            # Free GPU memory, this was needed to prevent our GPU RAM from reaching capacity
            del batch_fmri, batch_caption, tokens, attention_mask, logits, outputs, loss
            torch.cuda.empty_cache()

    decoder.eval()
    val_accuracy.append(calculate_accuracy_on_test(encoder, decoder, val_dl, device, return_best_batch=False))

    # Average over every value logged during this epoch (including the last one).
    epoch_loss = np.mean(running_loss[epoch_loss_start:])
    epoch_train_accuracy = np.mean(running_semantic_accuracy[epoch_accuracy_start:])
    print(f"---- epoch {epoch} loss: {epoch_loss:.4}, train accuracy: {epoch_train_accuracy:.4}, validation accuracy (running, % above thresh): {val_accuracy[-1]} ---- ")
    scheduler.step()


#### Model Evaluation and Visualization

In [None]:
# Visualize Training

fig, axs = plt.subplots(3,1, figsize=(10,10))
axs = axs.flatten()
# Training metrics are logged every 10 batches, so the i-th logged value is placed at batch 10*i.
batch_iterations = np.arange(0, len(running_loss)*10, 10)
# Validation accuracy is measured after an epoch has finished, so the k-th
# measurement belongs at the END of epoch k, i.e. at (k + 1) * len(train_dl) batches.
val_iterations = np.arange(1, len(val_accuracy) + 1) * len(train_dl)
axs[0].plot(batch_iterations, running_loss)
axs[0].set_xlabel('Batch')
axs[0].set_ylabel('Loss')

axs[1].plot(batch_iterations, running_semantic_accuracy, label='Train')
# Each val_accuracy entry holds two values, hence the two labels.
axs[1].plot(val_iterations, val_accuracy, '.', label=('Validation', 'Validation - above 0.5'))
axs[1].set_xlabel('Batch')
axs[1].set_ylabel('Accuracy')
axs[1].legend()

# Plot every 10th logged learning-rate value.
axs[2].plot(batch_iterations[:-1:10], np.array(lr_monitor).squeeze()[0:-1:10])
axs[2].set_xlabel('Batch')
axs[2].set_ylabel('lr')

fig.tight_layout()

In [None]:
# Calculate accuracies
decoder.to(device)
decoder.eval()


def _evaluate(loader):
    """Run calculate_accuracy_on_test on `loader` with the current encoder and decoder."""
    return calculate_accuracy_on_test(encoder, decoder, loader, device, return_best_batch=False)


# Evaluation order: test, validation, train.
accuracy_test, above_thresh_test = _evaluate(test_dl)
accuracy_val, above_thresh_val = _evaluate(val_dl)
accuracy_train, above_thresh_train = _evaluate(train_dl)

# Each line reports the accuracy followed by the above-threshold value as a percentage.
print(f"Train: {accuracy_train}, {above_thresh_train*100}%")
print(f"Validaion: {accuracy_val}, {above_thresh_val*100}%")
print(f"Test: {accuracy_test}, {above_thresh_test*100}%")

In [None]:
# Get Best Samples
K = 2  # number of best samples to collect per split

# Loaders are processed in the order validation, train, test.
best_k_val, best_k_train, best_k_test = (
    get_k_best_torch(encoder, decoder, loader, K, device)
    for loader in (val_dl, train_dl, test_dl)
)

def best_k_dict_to_records(best_k):
    """Convert a dict of per-sample columns into a list of per-sample record dicts.

    The 'image', 'fmri' and 'accuracy' entries are moved to the CPU with ``.cpu()``;
    every other entry is used as-is. The i-th record maps each key of ``best_k`` to
    the i-th element of the corresponding value.

    Args:
        best_k: dict holding at least the keys 'image', 'fmri' and 'accuracy'
            (values exposing ``.cpu()``). All values must be iterable.

    Returns:
        A list of dicts, one per sample, with the same keys and key order as
        ``best_k``. As with ``zip``, its length is that of the shortest value.

    Note:
        ``best_k`` itself is not modified; the CPU copies are stored in a shallow copy.
    """
    # Shallow copy so the caller's dict keeps its original entries.
    columns = dict(best_k)
    for key in ('image', 'fmri', 'accuracy'):
        columns[key] = columns[key].cpu()

    # Transpose "key -> sequence" into one {key: element} dict per sample.
    return [dict(zip(columns, row)) for row in zip(*columns.values())]

# Convert each split's best-K dict into a list of per-sample records (train, test, validation order).
best_k_train_records, best_k_test_records, best_k_val_records = map(
    best_k_dict_to_records,
    (best_k_train, best_k_test, best_k_val),
)

In [None]:
# Print best results
# Only the best test samples are shown by default; uncomment a line below to
# also show the best train / validation samples.
# print_batch(best_k_train_records, fontsize=10, num_of_columns=2, caption_as_title=False)
# print_batch(best_k_val_records, fontsize=10, num_of_columns=2, caption_as_title=False)
print_batch(best_k_test_records, fontsize=10, num_of_columns=2, caption_as_title=False)

#### Save Model

In [None]:
# Test accuracy is used for the file name. Run calculate accuracies cell before.
# Change saved parameters as needed
to_save = {
    'comment': "You can add some comments on the checkpoint",
    'hyperparameters': {'batch_size': BATCH_SIZE},
    # Only the projection network's sizes and weights are stored.
    'decoder_projection': {'sizes': decoder.projection_sizes, 'sd': decoder.embedding_space_projection.state_dict()},
    # Only the optimizer's param_groups are stored, not its full state.
    'optimizer': {'type': type(optimizer), 'optimizer_param_groups': optimizer.state_dict()['param_groups']},
    'scheduler': {'type': type(scheduler), 'sd': scheduler.state_dict()},
    # Logged curves, so a later run can plot the complete training history.
    'training_data': {
        "running_loss": running_loss,
        'running_semantic_accuracy': running_semantic_accuracy,
        'val_accuracy': val_accuracy,
        'lr_monitor': lr_monitor
    }
}

# Create the checkpoint directory if needed; torch.save fails when it does not exist.
os.makedirs(path_checkpoints, exist_ok=True)

# Timestamped file name, e.g. decoder_test_accuracy_<accuracy>_<ddmmYYYY_HH-MM-SS>.pth
now = datetime.now()
dt_string = now.strftime("%d%m%Y_%H-%M-%S")
torch.save(to_save, os.path.join(path_checkpoints, f"decoder_test_accuracy_{accuracy_test:.4}_{dt_string}.pth"))