In [43]:
from fmri_caption import GPTCaptionModel, create_fmri_encoder_from_pretrained,top_k_top_p_filtering
from dataset import create_BOLD5000_dataset
from torch.utils.data import DataLoader, Subset
import torch
import torch.optim as optim
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split

#### Setup

In [20]:
# Parameters
BATCH_SIZE = 3
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Pretrained files
path_fmri_encoder = r"C:\Users\roeys\OneDrive - Technion\Semester 7\DL\Project\Mind-Cap\Mind-Cap\pretrains\pretrain_metafile.pth"
path_BOLD_dataset = r"C:\Users\roeys\OneDrive - Technion\Semester 7\DL\Project\Mind-Cap\Mind-Cap\data\BOLD5000\CSI1_dataset.pth"

# create BOLD5000 dataset
BOLD_dataset = torch.load(path_BOLD_dataset)
bold_train, bold_test = BOLD_dataset['train'], BOLD_dataset['test']


In [23]:
test_idx, val_idx = train_test_split(list(range(len(bold_test))),test_size=0.5)
bold_val = Subset(bold_test, val_idx)
bold_test = Subset(bold_test, test_idx)

train_dl = DataLoader(bold_train, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(bold_val, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(bold_test, batch_size=BATCH_SIZE, shuffle=True)


10064

#### Function Delerations

In [45]:
def calculate_semantic_similarity(fmri_prefix, real_caption, decoder):
    generated_caption = decoder.generate_caption(fmri_prefix)
    sentence_model = SentenceTransformer('all-mpnet-base-v2')
    embed_generated = sentence_model.encode(generated_caption, convert_to_tensor=True)
    embed_real_caption = sentence_model.encode(real_caption, convert_to_tensor=True)

    return util.pytorch_cos_sim(embed_generated, embed_real_caption)

def define_GPTCaption_model(encoder, trial=None, projection_sizes=None):
    if trial:
        # TODO: Add Optuna support. If trial is used, override projection_sizes with Optuna suggestion
        num_layers = trial.suggest_int("num_projection_layers", 1, 5)
        projection_sizes = [encoder.embed_dim]*num_layers
    gpt_decoder = GPTCaptionModel(encoder.num_patches, encoder.embed_dim, projection_sizes)

    return gpt_decoder

def objective(trial, encoder, train_dl, val_dl):

    # Generate the model
    model = define_GPTCaption_model(encoder, trial)

    # Generate the optimizers
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer_name = trial.suggest_categorial("optimizer", ['Adam', 'AdamW', "SGD"])
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr)

    # Training
    encoder.eval()
    for epoch in range(NUM_EPOCHS):
        model.train()
        print(f"** Starting epoch {epoch} **")
        for batch_idx, batch in enumerate(train_dl):
            if batch_idx*BATCH_SIZE >= TRIAL_NUM_TRAIN_EXAMPLES:
                break

            encoder.zero_grad()
            optimizer.zero_grad()

            print(f">>>> encoding fmri scans ", end="")
            fmri_prefix = encoder.forward(batch['fmri'])
            print(f"-> tokenizing captions ", end="")
            tokens, attention_mask = decoder.tokenizer(batch['caption'], return_tensors="pt", padding=True).values()
            print(f"-> to device {device} ", end="")
            tokens, attention_mask, fmri_prefix = tokens.to(device), attention_mask.to(device), fmri_prefix.to(device)
            print(f"-> decoding ", end="")
            outputs = model.forward(tokens, fmri_prefix, attention_mask)
            logits = outputs.logits[:, decoder.prefix_length-1:-1]

            print(f"-> calculating loss ")
            loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=decoder.tokenizer.pad_token_id)
            loss.backward()
            epoch_train_loss.append(loss.item())
            optimizer.step()
            print(f">>>> batch {batch_idx} finished")

        print(f"---- epoch loss: {epoch_train_loss[epoch]} ---- ")

    # Evaluating model
    model.eval()



#### Training Loop

In [24]:
# Constants
lr = 1e-3
NUM_EPOCHS = 5
TRIAL_NUM_TRAIN_EXAMPLES = BATCH_SIZE*30
TRIAL_NUM_VAL_EXAMPLES = BATCH_SIZE*10

In [None]:
# Get encoder-decoder
encoder = create_fmri_encoder_from_pretrained(path_fmri_encoder, bold_train.num_voxels)
encoder.eval()
decoder = define_GPTCaption_model(encoder, projection_sizes=[encoder.embed_dim])

optimizer = optim.AdamW(decoder.parameters(), lr=lr)


In [29]:
# Train
print("\n\n")
epoch_train_loss = []
for epoch in range(NUM_EPOCHS):
    decoder.train()
    print(f"** Starting epoch {epoch} **")
    for batch_idx, batch in enumerate(train_dl):
        encoder.zero_grad()
        optimizer.zero_grad()

        print(f">>>> encoding fmri scans ", end="")
        fmri_prefix = encoder.forward(batch['fmri'])
        print(f"-> tokenizing captions ", end="")
        tokens, attention_mask = decoder.tokenizer(batch['caption'], return_tensors="pt", padding=True).values()
        tokens, attention_mask, fmri_prefix = tokens.to(device), attention_mask.to(device), fmri_prefix.to(device)
        print(f"-> decoding ")
        outputs = decoder.forward(tokens, fmri_prefix, attention_mask)
        logits = outputs.logits[:, decoder.prefix_length-1:-1]

        loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=decoder.tokenizer.pad_token_id)
        loss.backward()
        epoch_train_loss.append(loss.item())
        optimizer.step()
        print(f">>>> batch {batch_idx} finished")

    print(f"---- epoch loss: {epoch_train_loss[epoch]} ---- ")






** Starting epoch 0 **
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 0 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 1 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 2 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 3 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 4 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 5 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 6 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 7 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 8 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 9 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 
>>>> batch 10 finished
>>>> encoding fmri scans -> tokenizing captions -> decoding 

KeyboardInterrupt: 

In [33]:
sent1 = "I like Python because I can build AI applications"
sent2 = "I like Python because I can do data analytics"
token1 = decoder.tokenizer.encode(sent1, return_tensors="pt")
token2 = decoder.tokenizer.encode(sent2, return_tensors="pt")
embed1 = decoder.gpt.transformer.wte(token1)
embed2 = decoder.gpt.transformer.wte(token2)
#torch.cosine_similarity()

In [44]:
filtered = top_k_top_p_filtering(logits,top_k=1000, top_p=0.95)

RuntimeError: torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries tensor and input value tensor must match, but we got boundaries tensor [3, 16, 1000] and input value tensor [3, 1]