In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

class FMRIEncoderDecoder(nn.Module):
    def __init__(self, cfg: dict):
        super(FMRIEncoderDecoder, self).__init__()
        self.cfg = cfg

        # Encoder
        max_seq_len = max(cfg['model']['max_seq_len'], cfg['model']['hidden_size'])
        self.positional_embedding = nn.Embedding(max_seq_len, cfg['model']['hidden_size'])

        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg['model']['hidden_size'],
            nhead=cfg['model']['num_attention_heads'],
            dim_feedforward=cfg['model']['dim_feedforward'],
            activation=cfg['model']['activation'],
            dropout=cfg['model']['dropout'],
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg['model']['num_layers'])

        # Decoder
        dec_layer = nn.TransformerDecoderLayer(
            d_model=cfg['model']['hidden_size'],
            nhead=cfg['model']['num_attention_heads'],
            dim_feedforward=cfg['model']['dim_feedforward'],
            activation=cfg['model']['activation'],
            dropout=cfg['model']['dropout'],
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=cfg['model']['num_layers'])

        self.fc = nn.Linear(cfg['model']['hidden_size'], 512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_voxels = x.size()

        # Encoder
        x = x.unsqueeze(1)  # Add sequence dimension: [batch_size, 1, num_voxels]
        positional_embedding = self.positional_embedding(torch.arange(num_voxels, device=x.device)).unsqueeze(0)
        positional_embedding = positional_embedding.expand(batch_size, num_voxels, -1)
        x = x + positional_embedding
        memory = self.encoder(x)

        # Decoder (identity: decoding back the same representation)
        x = self.decoder(memory, memory)  # Use memory itself as the input to the decoder
        x = torch.mean(x, dim=1)  # Average pooling over the sequence dimension
        x = self.fc(x)  # Map to 512-d feature vector

        return x


In [2]:
def normalize(x, mean=None, std=None):
    mean = np.mean(x) if mean is None else mean
    std = np.std(x) if std is None else std
    return (x - mean) / (std * 1.0)

def identity(x):
    return x

def list_get_all_index(list, value):
    return [i for i, v in enumerate(list) if v == value]


def pad_to_patch_size(x, patch_size):
    # pad the last dimension only
    padding_config = [(0,0)] * (x.ndim - 1) + [(0, patch_size-x.shape[-1]%patch_size)]
    return np.pad(x, padding_config, 'wrap')

def get_stimuli_list(root, sub):
    sti_name = []
    path = os.path.join(root, 'Stimuli_Presentation_Lists', sub)
    folders = os.listdir(path)
    folders.sort()
    for folder in folders:
        if not os.path.isdir(os.path.join(path, folder)):
            continue
        files = os.listdir(os.path.join(path, folder))
        files.sort()
        for file in files:
            if file.endswith('.txt'):
                sti_name += list(np.loadtxt(os.path.join(path, folder, file), dtype=str))

    sti_name_to_return = []
    for name in sti_name:
        if name.startswith('rep_'):
            name = name.replace('rep_', '', 1)
        sti_name_to_return.append(name)
    return sti_name_to_return

def pad_fmri_to_target(fmri_data, target_samples, target_voxels):
    # Padding for the first dimension (samples)
    if fmri_data.shape[0] < target_samples:
        sample_pad_size = target_samples - fmri_data.shape[0]
        fmri_data = np.pad(fmri_data, ((0, sample_pad_size), (0, 0)), mode='constant')
    elif fmri_data.shape[0] > target_samples:
        fmri_data = fmri_data[:target_samples, :]  # Crop to target sample size

    # Padding for the second dimension (voxels)
    if fmri_data.shape[1] < target_voxels:
        voxel_pad_size = target_voxels - fmri_data.shape[1]
        fmri_data = np.pad(fmri_data, ((0, 0), (0, voxel_pad_size)), mode='constant')
    elif fmri_data.shape[1] > target_voxels:
        fmri_data = fmri_data[:, :target_voxels]  # Crop to target voxel size

    return fmri_data



In [None]:
# from transformers import CLIPModel
from torchvision import transforms, models

class BOLD5000_ResNet_dataset(Dataset):
    def __init__(self, fmri, image, fmri_transform=identity, image_transform=identity, num_voxels=0, fmri_encoder=None):
        # self.fmri = fmri
        self.fmri = torch.tensor(fmri)
        self.image = image
        self.fmri_transform = fmri_transform
        self.image_transform = image_transform
        self.num_voxels = num_voxels
        self.image = np.transpose(image, (0, 3, 1, 2))  
        self.image = self.image.astype(np.float32) / 255.0
        
        self.resnet_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) 
        self.resnet_model.eval()
        self.resnet_model = torch.nn.Sequential(*list(self.resnet_model.children())[:-1])
        
        inputs = torch.from_numpy(self.image)
        self.image_embeddings = self.resnet_model(inputs).squeeze(-1).squeeze(-1) 
        
        # print(f"Image embeddings shape: {self.image_embeddings.shape}")
        
        self.fmri_encoder = fmri_encoder

    def __len__(self):
        return len(self.fmri)
    
    def __getitem__(self, idx):
        fmri = self.fmri[idx]
        fmri = self.fmri_transform(fmri)
        
        fmri_embedding = self.fmri_encoder(fmri.unsqueeze(0)) 
        
        return {
            'fmri': fmri_embedding, 
            'image': self.image_embeddings[idx].unsqueeze(0) 
        }

def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transform=None,
                            image_transform=None, subjects=['CSI1', 'CSI2', 'CSI3', 'CSI4'], 
                            include_nonavg_test=False, include_image_caption=False, include_image_clip=False):
    roi_list = ['EarlyVis', 'LOC', 'OPA', 'PPA', 'RSC']
    fmri_path = os.path.join(path, 'BOLD5000_GLMsingle_ROI_betas/py')
    img_path = os.path.join(path, 'BOLD5000_Stimuli')
    imgs_dict = np.load(os.path.join(img_path, 'Scene_Stimuli/Presented_Stimuli/img_dict.npy'), allow_pickle=True).item()
    
    repeated_imgs_list = np.loadtxt(os.path.join(img_path, 'Scene_Stimuli', 'repeated_stimuli_113_list.txt'), dtype=str)

    fmri_files = [f for f in os.listdir(fmri_path) if f.endswith('.npy')]
    fmri_files.sort()

    fmri_train_major = []
    fmri_test_major = []
    img_train_major = []
    img_test_major = []
    caption_train_major = []
    caption_test_major = []

    for sub in subjects:
        fmri_data_sub = []
        for roi in roi_list:
            for npy in fmri_files:
                if npy.endswith('.npy') and sub in npy and roi in npy:
                    fmri_data_sub.append(np.load(os.path.join(fmri_path, npy)))
        fmri_data_sub = np.concatenate(fmri_data_sub, axis=-1)  # concatenate all rois
        fmri_data_sub = normalize(pad_to_patch_size(fmri_data_sub, patch_size))
      
        img_files = get_stimuli_list(img_path, sub)
        img_data_sub = [imgs_dict[name] for name in img_files]
        # print("Image data sub shape: ", np.shape(img_data_sub))

        test_idx = [list_get_all_index(img_files, img) for img in repeated_imgs_list]
        test_idx = [i for i in test_idx if len(i) > 0]  # remove empty list for CSI4
        test_fmri = np.stack([fmri_data_sub[idx].mean(axis=0) for idx in test_idx])
        test_img = np.stack([img_data_sub[idx[0]] for idx in test_idx])

        test_idx_flatten = []
        for idx in test_idx:
            test_idx_flatten += idx  
        if include_nonavg_test:
            test_fmri = np.concatenate([test_fmri, fmri_data_sub[test_idx_flatten]], axis=0)
            test_img = np.concatenate([test_img, np.stack([img_data_sub[idx] for idx in test_idx_flatten])], axis=0)

        train_idx = [i for i in range(len(img_files)) if i not in test_idx_flatten]
        train_img = np.stack([img_data_sub[idx] for idx in train_idx])
        train_fmri = fmri_data_sub[train_idx]
        train_fmri = pad_fmri_to_target(train_fmri, target_samples=4803, target_voxels=512)
        test_fmri = pad_fmri_to_target(test_fmri, target_samples=4803, target_voxels=512)

        fmri_train_major.append(train_fmri)
        fmri_test_major.append(test_fmri)
        img_train_major.append(train_img)
        img_test_major.append(test_img)

    fmri_train_major = np.concatenate(fmri_train_major, axis=0)[:64,:]
    fmri_test_major = np.concatenate(fmri_test_major, axis=0)[:64,:]
    img_train_major = np.concatenate(img_train_major, axis=0)[:64,:]
    img_test_major = np.concatenate(img_test_major, axis=0)[:64,:]

    # print(img_train_major.shape)

    num_voxels = fmri_train_major.shape[-1]
    
    # fmri_enc = FMRITransformerEncoder(cfg)  # Initialize the fMRI transformer encoder
    
    
    if include_image_clip:  
        return (BOLD5000_ResNet_dataset(
                    fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels, fmri_encoder=fmri_enc), 
                BOLD5000_ResNet_dataset(
                    fmri_test_major, img_test_major, fmri_transform, image_transform, num_voxels, fmri_encoder=fmri_enc))
    else:  # Use BOLD5000_dataset if include_image_clip is False
        return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels),
                BOLD5000_dataset(fmri_test_major, img_test_major, fmri_transform, image_transform, num_voxels))
           

from torch.utils.data import Dataset
import numpy as np
from PIL import Image

class BOLD5000_dataset(Dataset):
    def __init__(self, fmri, image, fmri_transform=identity, image_transform=identity, num_voxels=0):
        self.fmri = fmri
        self.image = image
        self.fmri_transform = fmri_transform
        self.image_transform = image_transform
        self.num_voxels = num_voxels
    
    def __len__(self):
        return len(self.fmri)
    
    def __getitem__(self, index):
        
        fmri = self.fmri[index]
        img = self.image[index] / 255.0  
        fmri = np.expand_dims(fmri, axis=0)
        
        if isinstance(img, np.ndarray):
            img = Image.fromarray(np.uint8(img))
        
        fmri = self.fmri_transform(fmri) 
        img = self.image_transform(img) 
        return {'fmri': fmri, 'image': img}


class BOLD5000_CLIP_dataset(Dataset):
    def __init__(self, fmri, image, fmri_transform=identity, image_transform=identity, num_voxels=0, fmri_encoder=None):
        self.fmri = fmri
        self.image = image
        self.fmri_transform = fmri_transform
        self.image_transform = image_transform
        self.num_voxels = num_voxels
        
        resize_transform = transforms.Resize((224, 224)) 
        self.image = np.transpose(image, (0, 3, 1, 2))  
        self.image = np.array([resize_transform(torch.from_numpy(img)) for img in self.image])

        self.image = self.image.astype(np.float32) / 255.0
    
        self.image_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")  
        
        self.image_encoder = torch.quantization.quantize_dynamic(self.image_encoder, dtype=torch.qint8)
        
        inputs = torch.from_numpy(self.image)
        self.image_embeddings = self.image_encoder.get_image_features(pixel_values=inputs)
        self.fmri_encoder = fmri_encoder

    def __len__(self):
        return len(self.fmri)
    
    def __getitem__(self, idx):
        fmri = self.fmri[idx]
        fmri = self.fmri_transform(fmri)
        
        fmri_embedding = self.fmri_encoder(fmri.unsqueeze(0))  
        
        img = self.image[idx]
        img = self.image_transform(img) 
        
        return {
            'fmri': fmri_embedding, 
            'image': self.image_embeddings[idx] 
        }

In [None]:
import torch
from torch.utils.data import DataLoader
from torch import optim
from torch import nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir='./logs')  

data_path = '/home2/prateekj/tether-assgn/data/BOLD5000' 
from torchvision import transforms

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset, test_dataset = create_BOLD5000_dataset(
    path=data_path,
    patch_size=16,
    fmri_transform=lambda x: torch.tensor(x, dtype=torch.float32),
    image_transform=image_transform,
    subjects=['CSI1', 'CSI2', 'CSI3', 'CSI4'],
    include_nonavg_test=False,
    include_image_clip=False
)

# Hyperparameters
batch_size = 4
epochs = 25
learning_rate = 1e-4
cfg = {
    'model': {
        'max_seq_len': 100,
        'hidden_size': 512,
        'num_attention_heads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'dropout': 0.1,
        'num_layers': 6
    }
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = FMRIEncoderDecoder(cfg).to(device) 
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        optimizer.zero_grad()

        fmri_data = batch['fmri'].to(device) 
        output = model(fmri_data.squeeze(1))
        loss = loss_fn(output, fmri_data)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    writer.add_scalar('Loss/train', avg_train_loss, epoch)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            fmri_data = batch['fmri'].to(device)

            output = model(fmri_data.squeeze(1))
            loss = loss_fn(output, fmri_data)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(test_loader)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)

    print(f"Epoch {epoch + 1}: Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

torch.save(model.state_dict(), "fmri_encoder_decoder.pth")
writer.close()


Using device: cuda


  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1/25: 100%|██████████| 16/16 [00:06<00:00,  2.65it/s]


Epoch 1: Train Loss: 0.7727, Validation Loss: 0.4032


Epoch 2/25: 100%|██████████| 16/16 [00:01<00:00,  8.43it/s]


Epoch 2: Train Loss: 0.6651, Validation Loss: 0.3878


Epoch 3/25: 100%|██████████| 16/16 [00:01<00:00,  8.41it/s]


Epoch 3: Train Loss: 0.6583, Validation Loss: 0.3829


Epoch 4/25: 100%|██████████| 16/16 [00:01<00:00,  8.36it/s]


Epoch 4: Train Loss: 0.6581, Validation Loss: 0.3788


Epoch 5/25: 100%|██████████| 16/16 [00:01<00:00,  8.08it/s]


Epoch 5: Train Loss: 0.6600, Validation Loss: 0.3829


Epoch 6/25: 100%|██████████| 16/16 [00:01<00:00,  8.36it/s]


Epoch 6: Train Loss: 0.6557, Validation Loss: 0.3765


Epoch 7/25: 100%|██████████| 16/16 [00:01<00:00,  8.31it/s]


Epoch 7: Train Loss: 0.6535, Validation Loss: 0.3777


Epoch 8/25: 100%|██████████| 16/16 [00:01<00:00,  8.36it/s]


Epoch 8: Train Loss: 0.6565, Validation Loss: 0.3866


Epoch 9/25: 100%|██████████| 16/16 [00:01<00:00,  8.37it/s]


Epoch 9: Train Loss: 0.6544, Validation Loss: 0.3839


Epoch 10/25: 100%|██████████| 16/16 [00:02<00:00,  7.76it/s]


Epoch 10: Train Loss: 0.6611, Validation Loss: 0.3738


Epoch 11/25: 100%|██████████| 16/16 [00:01<00:00,  8.31it/s]


Epoch 11: Train Loss: 0.6556, Validation Loss: 0.3834


Epoch 12/25: 100%|██████████| 16/16 [00:01<00:00,  8.32it/s]


Epoch 12: Train Loss: 0.6541, Validation Loss: 0.3739


Epoch 13/25: 100%|██████████| 16/16 [00:01<00:00,  8.34it/s]


Epoch 13: Train Loss: 0.6532, Validation Loss: 0.3768


Epoch 14/25: 100%|██████████| 16/16 [00:01<00:00,  8.31it/s]


Epoch 14: Train Loss: 0.6533, Validation Loss: 0.3804


Epoch 15/25: 100%|██████████| 16/16 [00:01<00:00,  8.14it/s]


Epoch 15: Train Loss: 0.6522, Validation Loss: 0.3759


Epoch 16/25: 100%|██████████| 16/16 [00:01<00:00,  8.32it/s]


Epoch 16: Train Loss: 0.6525, Validation Loss: 0.3776


Epoch 17/25: 100%|██████████| 16/16 [00:01<00:00,  8.29it/s]


Epoch 17: Train Loss: 0.6514, Validation Loss: 0.3790


Epoch 18/25: 100%|██████████| 16/16 [00:01<00:00,  8.17it/s]


Epoch 18: Train Loss: 0.6490, Validation Loss: 0.3758


Epoch 19/25: 100%|██████████| 16/16 [00:01<00:00,  8.24it/s]


Epoch 19: Train Loss: 0.6483, Validation Loss: 0.3759


Epoch 20/25: 100%|██████████| 16/16 [00:01<00:00,  8.24it/s]


Epoch 20: Train Loss: 0.6474, Validation Loss: 0.3742


Epoch 21/25: 100%|██████████| 16/16 [00:02<00:00,  7.97it/s]


Epoch 21: Train Loss: 0.6502, Validation Loss: 0.3783


Epoch 22/25: 100%|██████████| 16/16 [00:01<00:00,  8.17it/s]


Epoch 22: Train Loss: 0.6454, Validation Loss: 0.3728


Epoch 23/25: 100%|██████████| 16/16 [00:01<00:00,  8.25it/s]


Epoch 23: Train Loss: 0.6477, Validation Loss: 0.3754


Epoch 24/25: 100%|██████████| 16/16 [00:01<00:00,  8.09it/s]


Epoch 24: Train Loss: 0.6473, Validation Loss: 0.3757


Epoch 25/25: 100%|██████████| 16/16 [00:01<00:00,  8.27it/s]


Epoch 25: Train Loss: 0.6483, Validation Loss: 0.3785
