## Loading Libraries and data

In [None]:
import os
import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
# NOTE(review): this silences every warning (including torch copy/deprecation warnings);
# consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Configuration.
# GPU_INDEX: the CUDA device that all data tensors are placed on.
# SEED: seeds the torch and numpy global RNGs so that weight init, the input-noise
# augmentation and DataLoader shuffling are reproducible across Restart & Run All.
GPU_INDEX = 2
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
device = torch.device(f"cuda:{GPU_INDEX}" if torch.cuda.is_available() else "cpu")

In [None]:
## Loading BOLD5000 data

# True: augmented training split (/BOLD5000/data_augmentation); False: original split (/BOLD5000_V2).
USE_AUGMENTED_BOLD = True

if USE_AUGMENTED_BOLD:
    BOLD_TRAIN_SIGNALS_PATH = "/BOLD5000/data_augmentation/train_25percent_noise_75.pt"
    BOLD_TRAIN_EMBEDS_PATH = "/BOLD5000/data_augmentation/train_embds_25percent.pt"
    BOLD_TRAIN_IDS_PATH = "/BOLD5000/data_augmentation/train_ids_25percent.pt"
else:
    BOLD_TRAIN_SIGNALS_PATH = "/BOLD5000_V2/brain_signals_train.pt"
    BOLD_TRAIN_EMBEDS_PATH = "/BOLD5000_V2/image_embeddings_train.pt"
    BOLD_TRAIN_IDS_PATH = "/BOLD5000_V2/subject_ids_train.pt"

# torch.as_tensor accepts either a tensor or a numpy array (the /BOLD5000_V2 signal files
# presumably hold numpy arrays, since they previously needed torch.from_numpy) and, unlike
# torch.tensor(<tensor>), does not make an extra copy.
BOLD_brain_signals_train = torch.as_tensor(torch.load(BOLD_TRAIN_SIGNALS_PATH, map_location=device), device=device)
BOLD_brain_signals_test = torch.as_tensor(torch.load("/BOLD5000_V2/brain_signals_test.pt", map_location=device), device=device)

# Normalizing the BOLD signals with training-set statistics
mean = BOLD_brain_signals_train.mean(0)
std = BOLD_brain_signals_train.std(0)
# A zero-variance feature would make test values +-inf ((x - mean) / 0), which nan_to_num
# then maps to the dtype's largest finite value; divide those features by 1 instead.
std = torch.where(std > 0, std, torch.ones_like(std))
BOLD_brain_signals_train = (BOLD_brain_signals_train - mean) / std
BOLD_brain_signals_test = (BOLD_brain_signals_test - mean) / std

BOLD_brain_signals_train = torch.nan_to_num(BOLD_brain_signals_train).float()
BOLD_brain_signals_test = torch.nan_to_num(BOLD_brain_signals_test).float()

BOLD_stimulus_embeddings_train = torch.load(BOLD_TRAIN_EMBEDS_PATH, map_location=device)
BOLD_stimulus_embeddings_test = torch.load("/BOLD5000_V2/image_embeddings_test.pt", map_location=device)

BOLD_subject_ids_train = torch.as_tensor(torch.load(BOLD_TRAIN_IDS_PATH, map_location=device), device=device)
BOLD_subject_ids_test = torch.load("/BOLD5000_V2/subject_ids_test.pt", map_location=device)

# Extract CLIP embeddings from BOLD5000 (the [:, 1, 0] slice of the stored tensors)
BOLD_stimulus_embeddings_train = BOLD_stimulus_embeddings_train[:, 1, 0].float()
BOLD_stimulus_embeddings_test = BOLD_stimulus_embeddings_test[:, 1, 0].float()



In [3]:
# Sanity check: shapes of the BOLD5000 train/test signals, embeddings and subject ids.
tuple(
    tensor.shape
    for tensor in (
        BOLD_brain_signals_train,
        BOLD_stimulus_embeddings_train,
        BOLD_subject_ids_train,
        BOLD_brain_signals_test,
        BOLD_stimulus_embeddings_test,
        BOLD_subject_ids_test,
    )
)

(torch.Size([21566, 15724]),
 torch.Size([21566, 1280]),
 torch.Size([21566]),
 torch.Size([445, 15724]),
 torch.Size([445, 1280]),
 torch.Size([445]))

In [None]:
# Loading NSD dataset

NSD_train_data = torch.load("/decoding_NSD/data_augmentation/train_25percent_noise_50.pt", map_location=device).float()

NSD_test_data = np.load("/decoding_NSD/data_fmri_nsd/test_data.npy")
NSD_test_data = torch.from_numpy(NSD_test_data).to(device).float()

# Normalize with training-set statistics.
mean = NSD_train_data.mean(0)
std = NSD_train_data.std(0)
# A zero-variance feature would make test values +-inf ((x - mean) / 0), which nan_to_num
# then maps to the dtype's largest finite value; divide those features by 1 instead.
std = torch.where(std > 0, std, torch.ones_like(std))
NSD_train_data = (NSD_train_data - mean) / std
NSD_test_data = (NSD_test_data - mean) / std

NSD_train_data = torch.nan_to_num(NSD_train_data)
NSD_test_data = torch.nan_to_num(NSD_test_data)

NSD_train_img_embeds = torch.load("/decoding_NSD/data_augmentation/train_embds_25percent.pt", map_location=device)
NSD_test_img_embeds = torch.load("/decoding_NSD/data_fmri_nsd/test_clip_img_embeds.pt", map_location=device)

NSD_train_img_embeds = NSD_train_img_embeds[:, 1, 0].float()
NSD_test_img_embeds = NSD_test_img_embeds[:, 1, 0].float()


NSD_subject_train_ids = torch.load("/decoding_NSD/data_augmentation/train_ids_25percent.pt", map_location=device)
NSD_subject_test_ids = np.load("/decoding_NSD/data_fmri_nsd/subject_test_ids.npy")

# The last element (character) of each test id is parsed as the subject number, so only
# single-digit subject ids are supported.
NSD_subject_test_ids = [int(i[-1]) for i in NSD_subject_test_ids]

# torch.as_tensor handles lists, numpy arrays and (possibly CUDA) tensors alike; the previous
# np.array(...) round-trip failed if the loaded ids were a CUDA tensor.
NSD_subject_train_ids = torch.as_tensor(NSD_subject_train_ids, device=device)
NSD_subject_test_ids = torch.as_tensor(NSD_subject_test_ids, device=device)

In [None]:
# Loading GOD dataset
GOD_train_brain_signals = torch.load('/GOD/train_brain_signals.pt', map_location=device).to(device)
GOD_test_brain_signals = torch.load('/GOD/test_brain_signals.pt', map_location=device).to(device)

# Normalize with training-set statistics.
mean = GOD_train_brain_signals.mean(0)
std = GOD_train_brain_signals.std(0)
# A zero-variance feature would make test values +-inf ((x - mean) / 0), which nan_to_num
# then maps to the dtype's largest finite value; divide those features by 1 instead.
std = torch.where(std > 0, std, torch.ones_like(std))
GOD_train_brain_signals = (GOD_train_brain_signals - mean) / std
GOD_test_brain_signals = (GOD_test_brain_signals - mean) / std

GOD_train_brain_signals = torch.nan_to_num(GOD_train_brain_signals).float()
GOD_test_brain_signals = torch.nan_to_num(GOD_test_brain_signals).float()

GOD_train_image_embds = torch.load('/GOD/train_image_embeddings.pt', map_location=device)
GOD_test_image_embds = torch.load('/GOD/test_image_embeddings.pt', map_location=device)

GOD_train_image_embds = GOD_train_image_embds[:, 1, 0].float()
GOD_test_image_embds = GOD_test_image_embds[:, 1, 0].float()


GOD_train_subject_ids = torch.load('/GOD/train_subject_ids.pt', map_location=device)
GOD_test_subject_ids = torch.load('/GOD/test_subject_ids.pt', map_location=device)

# Sanity check: shapes of the GOD train/test signals, embeddings and subject ids.
GOD_train_brain_signals.shape, GOD_train_image_embds.shape, GOD_train_subject_ids.shape, GOD_test_brain_signals.shape, GOD_test_image_embds.shape, GOD_test_subject_ids.shape

(torch.Size([6000, 15724]),
 torch.Size([6000, 1280]),
 torch.Size([6000]),
 torch.Size([250, 15724]),
 torch.Size([250, 1280]),
 torch.Size([250]))

In [6]:
# Concatenate datasets along the sample axis in the order NSD, BOLD5000, GOD.
# Keep this order in sync with any code that slices the test tensors/predictions per dataset.
cross_train_data = torch.cat([NSD_train_data, BOLD_brain_signals_train, GOD_train_brain_signals], dim=0)
cross_test_data = torch.cat([NSD_test_data, BOLD_brain_signals_test, GOD_test_brain_signals], dim=0)

cross_train_embeds = torch.cat([NSD_train_img_embeds, BOLD_stimulus_embeddings_train, GOD_train_image_embds], dim=0)
cross_test_embeds = torch.cat([NSD_test_img_embeds, BOLD_stimulus_embeddings_test, GOD_test_image_embds], dim=0)

# torch.as_tensor converts lists/arrays but, unlike torch.tensor(<tensor>), does not copy
# (and warn about) inputs that are already tensors.
cross_train_subjects = torch.cat(
    [torch.as_tensor(ids, device=device) for ids in (NSD_subject_train_ids, BOLD_subject_ids_train, GOD_train_subject_ids)],
    dim=0,
)
cross_test_subjects = torch.cat(
    [torch.as_tensor(ids, device=device) for ids in (NSD_subject_test_ids, BOLD_subject_ids_test, GOD_test_subject_ids)],
    dim=0,
)

# Every sample needs exactly one signal, one target embedding and one subject id.
assert len(cross_train_data) == len(cross_train_embeds) == len(cross_train_subjects)
assert len(cross_test_data) == len(cross_test_embeds) == len(cross_test_subjects)

## Defining the model

In [34]:
import torch
import torch.nn as nn
import pytorch_lightning as pl


class Encoder(nn.Module):
    """Subject-aware encoder that maps brain signals to an embedding vector.

    Each subject id owns an input projection (``alignment_layers``) into a shared
    ``common_dim`` space. The projected vectors then pass through a Transformer encoder
    (each sample as a length-1 sequence), an MLP head (``net``) and a final square linear
    layer (``alignment_layer``).

    Args:
        input_dim: Feature size of each input row.
        hidden_dims: Hidden widths of the MLP head. Each entry adds
            Linear(width -> h) -> LayerNorm -> act_fn -> Linear(h -> width), so the running
            width stays ``common_dim``. This is a plain chain with no skip connection.
        output_dim: Size of the returned embedding.
        act_fn: Activation class, instantiated with no arguments.
        alignment_layers_keys: Subject ids that get their own input projection.
        common_dim: Width of the shared space; must be divisible by the 8 attention heads.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, act_fn=nn.ReLU, alignment_layers_keys=(12, 13, 14, 15, 16), common_dim=1024):
        super().__init__()
        self.common_dim = common_dim
        # One input projection per subject id (ModuleDict keys must be strings).
        self.alignment_layers = nn.ModuleDict({str(k): nn.Linear(input_dim, common_dim) for k in alignment_layers_keys})
        # Defined but not applied in forward; kept so the module layout is unchanged.
        self.dropout = nn.Dropout(p=0.3)

        # batch_first=True so the (batch, 1, features) tensor built in forward is read as
        # batch x seq_len x features. With the default (batch_first=False) the batch axis was
        # treated as the sequence axis and samples in a batch attended to each other.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=common_dim, nhead=8, dim_feedforward=2048, activation='gelu', batch_first=True),
            num_layers=4,
        )

        layers = [nn.LayerNorm(common_dim)]
        width = common_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(width, hidden_dim))
            layers.append(nn.LayerNorm(hidden_dim))
            layers.append(act_fn())
            # Projects back to `width`, so the running width must stay `width`; updating it to
            # hidden_dim here broke every hidden size other than common_dim.
            layers.append(nn.Linear(hidden_dim, width))
        layers.append(nn.Linear(width, output_dim))
        layers.append(nn.LayerNorm(output_dim))
        self.net = nn.Sequential(*layers)
        # Final square projection applied after the MLP head.
        self.alignment_layer = nn.Linear(output_dim, output_dim)

    def forward(self, x, k=None):
        """Embed each row of ``x`` using the projection of its subject id.

        Args:
            x: 2-D tensor of shape (batch, input_dim).
            k: Subject id per row (tensor, list or array of length batch). Defaults to id 1
                for every row, which only works if 1 is in ``alignment_layers_keys``.

        Returns:
            Tensor of shape (batch, output_dim).

        Raises:
            KeyError: If a subject id in ``k`` has no alignment layer.
        """
        if k is None:
            k = torch.ones(len(x), dtype=torch.long, device=x.device)
        else:
            # Keep the selection mask on the same device as x.
            k = torch.as_tensor(k, device=x.device)

        result = torch.zeros(len(x), self.common_dim, device=x.device)
        for key in k.unique():
            layer_key = str(key.item())
            if layer_key not in self.alignment_layers:
                raise KeyError(
                    f"No alignment layer for subject id {layer_key}; "
                    f"known ids: {sorted(self.alignment_layers.keys())}"
                )
            mask = k == key
            result[mask] = self.alignment_layers[layer_key](x[mask])

        # Each sample is a length-1 sequence: (batch, 1, common_dim) -> back to (batch, common_dim).
        result = self.transformer(result.unsqueeze(1)).squeeze(1)
        return self.alignment_layer(self.net(result))

class ContrastiveModel(pl.LightningModule):
    """Lightning wrapper that trains an ``Encoder`` to predict stimulus embeddings from brain signals.

    Batches are ``(x, y, idx)``: input signals, target embeddings and the subject ids that
    select the encoder's per-subject input projection.

    Args:
        num_input_channels: Input feature size (``Encoder`` ``input_dim``).
        base_channel_size: Hidden widths of the encoder's MLP head (``Encoder`` ``hidden_dims``).
        latent_dim: Size of the predicted embedding (``Encoder`` ``output_dim``).
        temperature: Divisor applied to the similarity logits in ``contrastive_loss``.
        act_fn: Activation class used in the encoder's MLP head.
        loss_type: One of ``"contrastive"``, ``"mean_contrastive"``, ``"mse"``, ``"cosine"``.
        noise_std: Scale of the Gaussian noise added to ``x`` in ``training_step``.

    Raises:
        ValueError: If ``loss_type`` is not a supported name.
    """

    def __init__(self, num_input_channels, base_channel_size, latent_dim, temperature=0.1, act_fn=nn.GELU, loss_type="contrastive", noise_std=0.15):
        super().__init__()
        self.temperature = temperature
        self.noise_std = noise_std
        self.model = Encoder(num_input_channels, base_channel_size, latent_dim, act_fn)
        self.loss_type = loss_type

        loss_functions = {
            "contrastive": self.contrastive_loss,
            "mean_contrastive": self.mean_contrastive,
            "mse": torch.nn.functional.mse_loss,
            "cosine": self.cosine_loss,
        }
        # Fail fast on a typo instead of leaving self.loss_fn undefined until the first step.
        if loss_type not in loss_functions:
            raise ValueError(f"Unknown loss_type {loss_type!r}; expected one of {sorted(loss_functions)}")
        self.loss_fn = loss_functions[loss_type]

        # Per-batch metric buffers, averaged into the *_history dicts and cleared each epoch.
        self.train_losses = []
        self.train_mse = []
        self.train_cosine = []
        self.val_losses = []
        self.val_mse = []
        self.val_cosine = []

        self.train_history = {
            "train_loss": [],
            "train_mse": [],
            "train_cosine": []
        }

        self.val_history = {
            "val_loss": [],
            "val_mse": [],
            "val_cosine": []
        }

    def forward(self, x, **kwargs):
        """Delegate to the wrapped ``Encoder`` (kwargs such as ``k`` are passed through)."""
        return self.model(x, **kwargs)

    def contrastive_loss(self, z_i, z_j, temperature=None):
        """Symmetric cross-entropy over cosine-similarity logits; row i of z_i pairs with row i of z_j.

        Args:
            z_i, z_j: (batch, dim) tensors; rows with the same index are positive pairs.
            temperature: Logit divisor; defaults to ``self.temperature``.
        """
        if temperature is None:
            temperature = self.temperature
        z_i = nn.functional.normalize(z_i, dim=1)
        z_j = nn.functional.normalize(z_j, dim=1)

        logits = (z_i @ z_j.T) / temperature
        targets = torch.arange(logits.shape[0], device=logits.device)

        # Average the z_i -> z_j and z_j -> z_i directions.
        loss1 = torch.nn.functional.cross_entropy(logits, targets)
        loss2 = torch.nn.functional.cross_entropy(logits.T, targets)
        return 0.5 * loss1 + 0.5 * loss2

    def mean_contrastive(self, z_i, z_j, temperature=1.0):
        """MSE plus one eighth of the contrastive loss evaluated at ``temperature``."""
        # contrastive_loss now accepts `temperature`; previously this call raised TypeError.
        return nn.functional.mse_loss(z_i, z_j) + self.contrastive_loss(z_i, z_j, temperature=temperature) / 8

    def cosine_loss(self, z_i, z_j, temperature=1.0):
        """1 - mean row-wise cosine similarity (``temperature`` is accepted but unused)."""
        cosine_similarity = torch.nn.functional.cosine_similarity(z_i, z_j).mean()
        return 1 - cosine_similarity

    def training_step(self, batch, batch_idx):
        x, y, idx = batch

        # Input-noise augmentation.
        x = x + self.noise_std * torch.randn_like(x)

        y_hat = self(x, k=idx)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.train_losses.append(loss.item())

        mse_loss = torch.nn.functional.mse_loss(y_hat, y)
        cosine_similarity = torch.nn.functional.cosine_similarity(y_hat, y).mean()
        self.train_mse.append(mse_loss.item())
        self.train_cosine.append(cosine_similarity.item())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, idx = batch
        y_hat = self(x, k=idx)
        loss = self.loss_fn(y_hat, y)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

        mse_loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log('val_mse_loss', mse_loss, on_epoch=True, prog_bar=True)

        cosine_similarity = torch.nn.functional.cosine_similarity(y_hat, y).mean()
        self.log('val_cosine_similarity', cosine_similarity, on_epoch=True, prog_bar=True)

        self.val_losses.append(loss.item())
        self.val_mse.append(mse_loss.item())
        self.val_cosine.append(cosine_similarity.item())
        return mse_loss

    def on_train_epoch_end(self):
        self.train_history["train_loss"].append(np.mean(self.train_losses))
        self.train_history["train_mse"].append(np.mean(self.train_mse))
        self.train_history["train_cosine"].append(np.mean(self.train_cosine))
        self.train_losses = []
        self.train_mse = []
        self.train_cosine = []
        super().on_train_epoch_end()

    def on_validation_epoch_end(self):
        self.val_history["val_loss"].append(np.mean(self.val_losses))
        self.val_history["val_mse"].append(np.mean(self.val_mse))
        self.val_history["val_cosine"].append(np.mean(self.val_cosine))
        self.val_losses = []
        self.val_mse = []
        self.val_cosine = []
        super().on_validation_epoch_end()

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau driven by the logged ``val_loss``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=50, verbose=True)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

In [None]:
BS = 256  # batch size for both loaders

# Each sample is (brain signal, target embedding, subject id).
train_dataset = TensorDataset(cross_train_data, cross_train_embeds, cross_train_subjects)
# Evaluation signals must come from the test split so they line up with the test embeddings
# and ids (TensorDataset also rejects tensors of different lengths).
test_dataset = TensorDataset(cross_test_data, cross_test_embeds, cross_test_subjects)

# Only the training loader shuffles; the test loader keeps the concatenation order so that
# predictions line up row-for-row with cross_test_embeds.
train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BS, shuffle=False)

In [None]:
brain_model = ContrastiveModel(
    num_input_channels=cross_train_data.shape[-1],
    base_channel_size=[1024],
    latent_dim=1280,
    act_fn=nn.GELU,
    loss_type="contrastive",
)

# Train on the same device the data tensors live on (`device`). If CUDA is unavailable,
# fall back to CPU instead of requesting GPU 2 unconditionally.
if device.type == "cuda":
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=[device.index if device.index is not None else 0])
else:
    trainer = pl.Trainer(max_epochs=10, accelerator="cpu", devices=1)

# Train the model.
# NOTE(review): the validation loader is the test split, so val_loss (also monitored by the
# LR scheduler) is computed on test data; consider holding out a separate validation split.
trainer.fit(brain_model, train_dataloader, test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | Encoder | 119 M  | train
------------------------------------------
119 M     Trainable params
0         Non-trainable params
119 M     Total params
476.676   Total estimated model params size (MB)
59        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:

# RIDGE REGRESSION
# Fit a ridge map from the encoder's predictions to the ground-truth embeddings on the
# training set, then apply it to the encoder's test-set predictions.
from sklearn.linear_model import Ridge

RIDGE_ALPHA = 50000.0


def _predict_embeddings(model, loader):
    """Run `model` over `loader` in eval mode without gradients.

    Inputs are moved to whatever device the model's parameters are on.

    Returns:
        (predictions, targets) as numpy arrays stacked in loader order.
    """
    model_device = next(model.parameters()).device
    preds, targets = [], []
    model.eval()
    with torch.no_grad():
        for x, y, k in loader:
            y_hat = model(x.to(model_device), k=k.to(model_device))
            preds.append(y_hat.cpu().numpy())
            targets.append(y.cpu().numpy())
    return np.vstack(preds), np.vstack(targets)


# The training loader shuffles, but predictions and targets are collected from the same
# batches, so the rows stay paired.
train_pred_embeddings, train_gt_embeddings = _predict_embeddings(brain_model, train_dataloader)

ridge_reg = Ridge(alpha=RIDGE_ALPHA)
ridge_reg.fit(train_pred_embeddings, train_gt_embeddings)

# Predict embeddings for the test set (unshuffled, so rows follow the concatenation order).
test_pred_embeddings = _predict_embeddings(brain_model, test_dataloader)[0]

# Apply the ridge map and convert the refined embeddings back to a float32 tensor.
refined_test_embeddings = torch.tensor(ridge_reg.predict(test_pred_embeddings)).float()



In [32]:
BOLD_stimulus_embeddings_test = BOLD_stimulus_embeddings_test.to("cpu")

# refined_test_embeddings covers the concatenated test set (NSD, BOLD5000, GOD in that order);
# keep only the BOLD5000 rows.
bold_start = len(NSD_test_img_embeds)
bold_stop = bold_start + len(BOLD_stimulus_embeddings_test)
BOLD_pred_IP = refined_test_embeddings[bold_start:bold_stop].to("cpu")
assert len(BOLD_pred_IP) == len(BOLD_stimulus_embeddings_test)


In [74]:
NSD_stimulus_embeddings_test = NSD_test_img_embeds.to("cpu")
# refined_test_embeddings covers the concatenated test set (NSD, BOLD5000, GOD in that order);
# the NSD rows come first.
NSD_pred_IP = refined_test_embeddings[:len(NSD_stimulus_embeddings_test)].to("cpu")
assert len(NSD_pred_IP) == len(NSD_stimulus_embeddings_test)

In [58]:
GOD_stimulus_embeddings_test = GOD_test_image_embds.to("cpu")
# refined_test_embeddings covers the concatenated test set (NSD, BOLD5000, GOD in that order);
# the GOD rows come after the NSD and BOLD5000 rows.
god_start = len(NSD_test_img_embeds) + len(BOLD_stimulus_embeddings_test)
GOD_pred_IP = refined_test_embeddings[god_start:god_start + len(GOD_stimulus_embeddings_test)].to("cpu")
assert len(GOD_pred_IP) == len(GOD_stimulus_embeddings_test)

In [28]:
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

# SDXL text-to-image pipeline in float16 on GPU 2, with IP-Adapter weights at full scale.
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda:2")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipeline.set_ip_adapter_scale(1.0)

# RESHAPING THE EMBEDDINGS
# Pack the BOLD5000 predictions (N, D) into a float16 (N, 2, 1, D) tensor: slot 0 along dim 1
# is an all-zero embedding and slot 1 the predicted one (presumably the negative/positive pair
# IP-Adapter uses for classifier-free guidance - TODO confirm).
bold_pred_fp16 = BOLD_pred_IP.to(torch.float16)
y_pred_IP = torch.stack([torch.zeros_like(bold_pred_fp16), bold_pred_fp16], dim=1).unsqueeze(2)

# The output directory may not exist yet.
os.makedirs("New_results", exist_ok=True)
torch.save(y_pred_IP, "New_results/BOLD_predicted_embeddings_08042025.pt")
y_pred_IP.shape

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]