In [1]:
import numpy as np
import os

In [None]:
def normalize(x, mean=None, std=None):
    mean = np.mean(x) if mean is None else mean
    std = np.std(x) if std is None else std
    return (x - mean) / (std * 1.0)

def identity(x):
    return x

def list_get_all_index(list, value):
    return [i for i, v in enumerate(list) if v == value]


def pad_to_patch_size(x, patch_size):
    # pad the last dimension only
    padding_config = [(0,0)] * (x.ndim - 1) + [(0, patch_size-x.shape[-1]%patch_size)]
    return np.pad(x, padding_config, 'wrap')

def get_stimuli_list(root, sub):
    sti_name = []
    path = os.path.join(root, 'Stimuli_Presentation_Lists', sub)
    folders = os.listdir(path)
    folders.sort()
    for folder in folders:
        if not os.path.isdir(os.path.join(path, folder)):
            continue
        files = os.listdir(os.path.join(path, folder))
        files.sort()
        for file in files:
            if file.endswith('.txt'):
                sti_name += list(np.loadtxt(os.path.join(path, folder, file), dtype=str))

    sti_name_to_return = []
    for name in sti_name:
        if name.startswith('rep_'):
            name = name.replace('rep_', '', 1)
        sti_name_to_return.append(name)
    return sti_name_to_return

def pad_fmri_to_target(fmri_data, target_samples, target_voxels):
    # Padding for the first dimension (samples)
    if fmri_data.shape[0] < target_samples:
        sample_pad_size = target_samples - fmri_data.shape[0]
        fmri_data = np.pad(fmri_data, ((0, sample_pad_size), (0, 0)), mode='constant')
    elif fmri_data.shape[0] > target_samples:
        fmri_data = fmri_data[:target_samples, :]  # Crop to target sample size

    # Padding for the second dimension (voxels)
    if fmri_data.shape[1] < target_voxels:
        voxel_pad_size = target_voxels - fmri_data.shape[1]
        fmri_data = np.pad(fmri_data, ((0, 0), (0, voxel_pad_size)), mode='constant')
    elif fmri_data.shape[1] > target_voxels:
        fmri_data = fmri_data[:, :target_voxels]  # Crop to target voxel size

    return fmri_data

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from torchvision import transforms, models
from tqdm import tqdm

class BOLD5000_Contrastive_Dataset(Dataset):
    def __init__(self, fmri, image, fmri_transform=None, image_transform=None, num_voxels=0):
        self.fmri = torch.tensor(fmri)
        self.image = image
        self.fmri_transform = fmri_transform
        self.image_transform = image_transform
        self.num_voxels = num_voxels
        
        self.image = np.transpose(image, (0, 3, 1, 2)) 
        self.image = self.image.astype(np.float32) / 255.0
        
        self.resnet_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) 
        self.resnet_model.eval()
        self.resnet_model = torch.nn.Sequential(*list(self.resnet_model.children())[:-1])

        # self.image_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")  
        # self.image_encoder = torch.quantization.quantize_dynamic(self.image_encoder, dtype=torch.qint8)
        # self.image_embeddings = self.image_encoder.get_image_features(pixel_values=inputs)

        inputs = torch.from_numpy(self.image)
        self.image_embeddings = self.resnet_model(inputs).squeeze(-1).squeeze(-1)

    def __len__(self):
        return len(self.fmri)

    def __getitem__(self, idx):
        fmri = self.fmri[idx]
        img = self.image[idx]

        if np.random.rand() > 0.5:
            label = 1
            return {'fmri': fmri, 'image': img, 'label': torch.tensor(label)}
        else:
            dissimilar_idx1 = np.random.randint(0, len(self.fmri))
            dissimilar_idx2 = np.random.randint(0, len(self.fmri))
            fmri_dissimilar = self.fmri[dissimilar_idx1]
            img_dissimilar = self.image[dissimilar_idx2]

            label = 0
            return {'fmri': fmri_dissimilar, 'image': img_dissimilar, 'label': torch.tensor(label)}

def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transform=None,
                            image_transform=None, subjects=['CSI1', 'CSI2', 'CSI3', 'CSI4'], 
                            include_nonavg_test=False, include_image_caption=False, include_image_clip=False):
    roi_list = ['EarlyVis', 'LOC', 'OPA', 'PPA', 'RSC']
    fmri_path = os.path.join(path, 'BOLD5000_GLMsingle_ROI_betas/py')
    img_path = os.path.join(path, 'BOLD5000_Stimuli')
    imgs_dict = np.load(os.path.join(img_path, 'Scene_Stimuli/Presented_Stimuli/img_dict.npy'), allow_pickle=True).item()
    
    repeated_imgs_list = np.loadtxt(os.path.join(img_path, 'Scene_Stimuli', 'repeated_stimuli_113_list.txt'), dtype=str)

    fmri_files = [f for f in os.listdir(fmri_path) if f.endswith('.npy')]
    fmri_files.sort()

    fmri_train_major = []
    fmri_test_major = []
    img_train_major = []
    img_test_major = []

    for sub in subjects:
        fmri_data_sub = []
        for roi in roi_list:
            for npy in fmri_files:
                if npy.endswith('.npy') and sub in npy and roi in npy:
                    fmri_data_sub.append(np.load(os.path.join(fmri_path, npy)))
        fmri_data_sub = np.concatenate(fmri_data_sub, axis=-1) 
        fmri_data_sub = normalize(pad_to_patch_size(fmri_data_sub, patch_size))
      
        img_files = get_stimuli_list(img_path, sub)
        img_data_sub = [imgs_dict[name] for name in img_files]

        test_idx = [list_get_all_index(img_files, img) for img in repeated_imgs_list]
        test_idx = [i for i in test_idx if len(i) > 0]  
        test_fmri = np.stack([fmri_data_sub[idx].mean(axis=0) for idx in test_idx])
        test_img = np.stack([img_data_sub[idx[0]] for idx in test_idx])

        test_idx_flatten = []
        for idx in test_idx:
            test_idx_flatten += idx 
        if include_nonavg_test:
            test_fmri = np.concatenate([test_fmri, fmri_data_sub[test_idx_flatten]], axis=0)
            test_img = np.concatenate([test_img, np.stack([img_data_sub[idx] for idx in test_idx_flatten])], axis=0)

        train_idx = [i for i in range(len(img_files)) if i not in test_idx_flatten]
        train_img = np.stack([img_data_sub[idx] for idx in train_idx])
        train_fmri = fmri_data_sub[train_idx]
        train_fmri = pad_fmri_to_target(train_fmri, target_samples=4803, target_voxels=512)
        test_fmri = pad_fmri_to_target(test_fmri, target_samples=4803, target_voxels=512)

        fmri_train_major.append(train_fmri)
        fmri_test_major.append(test_fmri)
        img_train_major.append(train_img)
        img_test_major.append(test_img)

    fmri_train_major = np.concatenate(fmri_train_major, axis=0)[:100,:]
    fmri_test_major = np.concatenate(fmri_test_major, axis=0)[:100,:]
    img_train_major = np.concatenate(img_train_major, axis=0)[:100,:]
    img_test_major = np.concatenate(img_test_major, axis=0)[:100,:]

    num_voxels = fmri_train_major.shape[-1]
    
    # fmri_enc = FMRITransformerEncoder(cfg)  # Initialize the fMRI transformer encoder
    
    return (BOLD5000_Contrastive_Dataset(
                fmri_train_major, img_train_major, fmri_transform=fmri_transform, image_transform=image_transform, num_voxels=num_voxels), 
            BOLD5000_Contrastive_Dataset(
                fmri_test_major, img_test_major, fmri_transform=fmri_transform, image_transform=image_transform, num_voxels=num_voxels))



In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

class FMRIEncoderDecoder(nn.Module):
    def __init__(self, cfg: dict):
        super(FMRIEncoderDecoder, self).__init__()
        self.cfg = cfg

        # Encoder
        max_seq_len = max(cfg['model']['max_seq_len'], cfg['model']['hidden_size'])
        self.positional_embedding = nn.Embedding(max_seq_len, cfg['model']['hidden_size'])

        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg['model']['hidden_size'],
            nhead=cfg['model']['num_attention_heads'],
            dim_feedforward=cfg['model']['dim_feedforward'],
            activation=cfg['model']['activation'],
            dropout=cfg['model']['dropout'],
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg['model']['num_layers'])

        # Decoder
        dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg['model']['hidden_size'],
            nhead=cfg['model']['num_attention_heads'],
            dim_feedforward=cfg['model']['dim_feedforward'],
            activation=cfg['model']['activation'],
            dropout=cfg['model']['dropout'],
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=cfg['model']['num_layers'])

        self.fc = nn.Linear(cfg['model']['hidden_size'], 512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_voxels = x.size()

        # Encoder
        x = x.unsqueeze(1)  
        positional_embedding = self.positional_embedding(torch.arange(num_voxels, device=x.device)).unsqueeze(0)
        positional_embedding = positional_embedding.expand(batch_size, num_voxels, -1)
        x = x + positional_embedding
        memory = self.encoder(x)

        # Decoder
        x = self.decoder(memory, memory) 
        x = torch.mean(x, dim=1) 
        x = self.fc(x)  

        return x


In [None]:
import torch
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = '/home2/prateekj/tether-assgn/notebooks/fmri_encoder_decoder.pth'

cfg = {
    'model': {
        'max_seq_len': 100,
        'hidden_size': 512,
        'num_attention_heads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'dropout': 0.1,
        'num_layers': 6
    }
}
fmri_encoder_decoder = FMRIEncoderDecoder(cfg).to(device)
fmri_encoder_decoder.load_state_dict(torch.load(model_path))
fmri_encoder_decoder.eval() 

fmri_encoder = fmri_encoder_decoder.encoder

fmri_encoder

TransformerEncoder(
  (layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
)

In [None]:
import torch
from torchvision import models, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet_model = models.resnet18(pretrained=True)
resnet_model = torch.nn.Sequential(*list(resnet_model.children())[:-1]) 
resnet_model.to(device)
resnet_model.eval()



Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [None]:
import torch
import torch.nn as nn

class FMRIImageAlignModel(nn.Module):
    def __init__(self, fmri_encoder, output_dim=512):
        super(FMRIImageAlignModel, self).__init__()
        self.fmri_encoder = fmri_encoder  
        self.fc_layer = nn.Linear(512, output_dim)  

    def forward(self, fmri_data):
        fmri_embedding = self.fmri_encoder(fmri_data)
        return self.fc_layer(fmri_embedding)

In [None]:
cfg = {
    'model': {
        'max_seq_len': 512,  
        'hidden_size': 512, 
        'num_attention_heads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'dropout': 0.1,
        'num_layers': 6
    }
}

In [None]:
def contrastive_loss(anchor, positive, label, margin=1.0):
    cosine_similarity = torch.nn.functional.cosine_similarity(anchor, positive)
    loss = (1 - label) * torch.clamp(margin - cosine_similarity, min=0) + label * (1 - cosine_similarity)
    return loss.mean()

In [None]:
fmri_encoder_decoder = FMRIEncoderDecoder(cfg).to(device)
fmri_encoder_decoder.load_state_dict(torch.load('/home2/prateekj/tether-assgn/notebooks/fmri_encoder_decoder.pth'))
fmri_encoder_decoder.eval()

data_path = '/home2/prateekj/tether-assgn/data/BOLD5000'
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset, test_dataset = create_BOLD5000_dataset(
    path=data_path,
    patch_size=16,
    fmri_transform=lambda x: torch.tensor(x, dtype=torch.float32),
    image_transform=image_transform,
    subjects=['CSI1', 'CSI2', 'CSI3', 'CSI4'],
    include_image_clip=False  
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
model = FMRIImageAlignModel(fmri_encoder_decoder).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir='./logs-contrast')

# Training loop
for epoch in range(25):
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/25"):
        optimizer.zero_grad()

        fmri_data = batch['fmri'].to(device)
        image_data = batch['image'].to(device)
        labels = batch['label'].to(device)

        with torch.no_grad():
            image_embedding = resnet_model(image_data).squeeze(3).squeeze(2)

        fmri_embedding = model.fmri_encoder(fmri_data)
        fmri_embedding_aligned = model(fmri_embedding)

        loss = contrastive_loss(fmri_embedding_aligned, image_embedding, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    writer.add_scalar('Loss/train', avg_train_loss, epoch)
    print(f"Epoch {epoch + 1}: Train Loss: {avg_train_loss:.4f}")

    model.eval()  
    running_val_loss = 0.0
    with torch.no_grad(): 
        for batch in tqdm(test_loader, desc="Validation"):
            fmri_data = batch['fmri'].to(device)
            image_data = batch['image'].to(device)
            labels = batch['label'].to(device)

            image_embedding = resnet_model(image_data).squeeze(3).squeeze(2)
            fmri_embedding = model.fmri_encoder(fmri_data)

            fmri_embedding_aligned = model(fmri_embedding)
            loss = contrastive_loss(fmri_embedding_aligned, image_embedding, labels)

            running_val_loss += loss.item()

    avg_val_loss = running_val_loss / len(test_loader)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)
    print(f"Validation Loss: {avg_val_loss:.4f}")

torch.save(model.state_dict(), "fmri_image_align_model.pth")
writer.close()

Epoch 1/25: 100%|██████████| 25/25 [00:07<00:00,  3.28it/s]


Epoch 1: Train Loss: 0.3361


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.50it/s]


Validation Loss: 0.2324


Epoch 2/25: 100%|██████████| 25/25 [00:05<00:00,  4.52it/s]


Epoch 2: Train Loss: 0.2269


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.26it/s]


Validation Loss: 0.2361


Epoch 3/25: 100%|██████████| 25/25 [00:05<00:00,  4.47it/s]


Epoch 3: Train Loss: 0.2229


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.14it/s]


Validation Loss: 0.2402


Epoch 4/25: 100%|██████████| 25/25 [00:05<00:00,  4.46it/s]


Epoch 4: Train Loss: 0.2206


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.76it/s]


Validation Loss: 0.2299


Epoch 5/25: 100%|██████████| 25/25 [00:05<00:00,  4.31it/s]


Epoch 5: Train Loss: 0.2142


Validation: 100%|██████████| 25/25 [00:01<00:00, 13.38it/s]


Validation Loss: 0.2361


Epoch 6/25: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


Epoch 6: Train Loss: 0.2184


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.19it/s]


Validation Loss: 0.2319


Epoch 7/25: 100%|██████████| 25/25 [00:05<00:00,  4.32it/s]


Epoch 7: Train Loss: 0.2160


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.24it/s]


Validation Loss: 0.2280


Epoch 8/25: 100%|██████████| 25/25 [00:05<00:00,  4.36it/s]


Epoch 8: Train Loss: 0.2151


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.80it/s]


Validation Loss: 0.2317


Epoch 9/25: 100%|██████████| 25/25 [00:05<00:00,  4.30it/s]


Epoch 9: Train Loss: 0.2228


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.84it/s]


Validation Loss: 0.2265


Epoch 10/25: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s]


Epoch 10: Train Loss: 0.2140


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.77it/s]


Validation Loss: 0.2299


Epoch 11/25: 100%|██████████| 25/25 [00:05<00:00,  4.36it/s]


Epoch 11: Train Loss: 0.2206


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.00it/s]


Validation Loss: 0.2278


Epoch 12/25: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s]


Epoch 12: Train Loss: 0.2178


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.38it/s]


Validation Loss: 0.2326


Epoch 13/25: 100%|██████████| 25/25 [00:05<00:00,  4.35it/s]


Epoch 13: Train Loss: 0.2190


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.69it/s]


Validation Loss: 0.2318


Epoch 14/25: 100%|██████████| 25/25 [00:05<00:00,  4.42it/s]


Epoch 14: Train Loss: 0.2142


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.01it/s]


Validation Loss: 0.2350


Epoch 15/25: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s]


Epoch 15: Train Loss: 0.2164


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.23it/s]


Validation Loss: 0.2253


Epoch 16/25: 100%|██████████| 25/25 [00:05<00:00,  4.40it/s]


Epoch 16: Train Loss: 0.2156


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.29it/s]


Validation Loss: 0.2310


Epoch 17/25: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s]


Epoch 17: Train Loss: 0.2149


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.00it/s]


Validation Loss: 0.2248


Epoch 18/25: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s]


Epoch 18: Train Loss: 0.2173


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.68it/s]


Validation Loss: 0.2282


Epoch 19/25: 100%|██████████| 25/25 [00:05<00:00,  4.37it/s]


Epoch 19: Train Loss: 0.2170


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.15it/s]


Validation Loss: 0.2335


Epoch 20/25: 100%|██████████| 25/25 [00:05<00:00,  4.43it/s]


Epoch 20: Train Loss: 0.2164


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.98it/s]


Validation Loss: 0.2366


Epoch 21/25: 100%|██████████| 25/25 [00:05<00:00,  4.40it/s]


Epoch 21: Train Loss: 0.2147


Validation: 100%|██████████| 25/25 [00:01<00:00, 16.80it/s]


Validation Loss: 0.2283


Epoch 22/25: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s]


Epoch 22: Train Loss: 0.2167


Validation: 100%|██████████| 25/25 [00:01<00:00, 15.22it/s]


Validation Loss: 0.2301


Epoch 23/25: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s]


Epoch 23: Train Loss: 0.2125


Validation: 100%|██████████| 25/25 [00:01<00:00, 14.24it/s]


Validation Loss: 0.2306


Epoch 24/25: 100%|██████████| 25/25 [00:05<00:00,  4.39it/s]


Epoch 24: Train Loss: 0.2180


Validation: 100%|██████████| 25/25 [00:01<00:00, 15.30it/s]


Validation Loss: 0.2273


Epoch 25/25: 100%|██████████| 25/25 [00:05<00:00,  4.43it/s]


Epoch 25: Train Loss: 0.2163


Validation: 100%|██████████| 25/25 [00:01<00:00, 17.59it/s]


Validation Loss: 0.2310
