In [1]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.nn.parallel import DistributedDataParallel
import argparse
import time
import timm.optim.optim_factory as optim_factory
import datetime
import matplotlib.pyplot as plt
import wandb
import copy
from sklearn.metrics import accuracy_score

# from config import Config_MBM_SPIKE
from config import Config_MBM_fMRI
from dataset import allen_dataset_1d, allen_dataset_2d, allen_dataset_static_grating_1d
# from sc_mbm.mae_for_spike_train import MAEforSPIKE, spike_encoder
from sc_mbm.mae_for_fmri import MAEforFMRI, fmri_encoder, fmri_classifier
# from sc_mbm.trainer import train_one_epoch_spike
from sc_mbm.trainer import train_one_epoch
from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler
from sc_mbm.utils import save_model

In [2]:
os.environ["WANDB_START_METHOD"] = "thread"
os.environ['WANDB_DIR'] = "."

class wandb_logger:
    def __init__(self, config):
        wandb.init(
                    project="mind-vis",
                    anonymous="allow",
                    group='stageA_sc-mbm',
                    config=config,
                    reinit=True)

        self.config = config
        self.step = None
    
    def log(self, name, data, step=None):
        if step is None:
            wandb.log({name: data})
        else:
            wandb.log({name: data}, step=step)
            self.step = step
    
    def watch_model(self, *args, **kwargs):
        wandb.watch(*args, **kwargs)

    def log_image(self, name, fig):
        if self.step is None:
            wandb.log({name: wandb.Image(fig)})
        else:
            wandb.log({name: wandb.Image(fig)}, step=self.step)

    def finish(self):
        wandb.finish(quiet=True)

In [3]:
def create_readme(config, path):
    print(config.__dict__)
    with open(os.path.join(path, 'README.md'), 'w+') as f:
        print(config.__dict__, file=f)

In [4]:
def create_readme(config, path):
    print(config.__dict__)
    with open(os.path.join(path, 'README.md'), 'w+') as f:
        print(config.__dict__, file=f)

In [5]:
def fmri_transform(x, sparse_rate=0.2):
    # x: 1, num_voxels
    x_aug = copy.deepcopy(x)
    idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False)
    x_aug[idx] = 0
    return torch.FloatTensor(x_aug)

In [6]:
config = torch.load('../results/spike_pretrain/09-12-2024-00-07-46/checkpoints/checkpoint.pth')['config']

  config = torch.load('../results/spike_pretrain/09-12-2024-00-07-46/checkpoints/checkpoint.pth')['config']


In [7]:
config.root_path

'../'

In [8]:
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(config.local_rank) 
    torch.distributed.init_process_group(backend='nccl')
output_path = os.path.join(config.root_path, 'results', 'spike_pretrain',  '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))
# output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain')
config.output_path = output_path
logger = wandb_logger(config) if config.local_rank == 0 else None

if config.local_rank == 0:
    os.makedirs(output_path, exist_ok=True)
    create_readme(config, output_path)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mzhaizhongyuan[0m ([33m11785-bhiksha[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'lr': 0.00025, 'min_lr': 0.0, 'weight_decay': 0.05, 'num_epoch': 500, 'warmup_epochs': 40, 'batch_size': 32, 'clip_grad': 0.8, 'mask_ratio': 0.75, 'patch_size': 8, 'embed_dim': 128, 'decoder_embed_dim': 512, 'depth': 24, 'num_heads': 16, 'decoder_num_heads': 16, 'mlp_ratio': 1.0, 'root_path': '../', 'output_path': '../results/spike_pretrain/09-12-2024-09-56-08', 'seed': 2022, 'roi': 'VC', 'aug_times': 1, 'num_sub_limit': None, 'include_hcp': True, 'include_kam': True, 'accum_iter': 1, 'use_nature_img_loss': False, 'img_recon_weight': 0.5, 'focus_range': None, 'focus_rate': 0.6, 'local_rank': 0}


In [9]:
device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu')
torch.manual_seed(config.seed)
np.random.seed(config.seed)

In [10]:
# create dataset and dataloader
# allen_dataset = allen_dataset_1d(fmri_transform=fmri_transform)
allen_dataset = allen_dataset_static_grating_1d(fmri_transform=fmri_transform)

print(f'Dataset size: {len(allen_dataset)}\nNumber of neurons: {allen_dataset.n_neurons}')
sampler = torch.utils.data.DistributedSampler(allen_dataset, rank=config.local_rank) if torch.cuda.device_count() > 1 else None 

dataloader_allen = DataLoader(allen_dataset, batch_size=config.batch_size, sampler=sampler, 
            shuffle=(sampler is None), pin_memory=True)

Dataset size: 4800
Number of neurons: 1952


In [11]:
# create model
# config.num_voxels = allen_dataset.n_neurons
# model = MAEforSPIKE(img_size=allen_dataset.n_neurons, patch_size=config.patch_size, embed_dim=config.embed_dim,
#                 decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 
#                 num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio,
#                 focus_range=config.focus_range, focus_rate=config.focus_rate, 
#                 img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss)
model = MAEforFMRI(num_voxels=allen_dataset.n_neurons, patch_size=config.patch_size, embed_dim=config.embed_dim,
                decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 
                num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio,
                focus_range=config.focus_range, focus_rate=config.focus_rate, 
                img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss)
model.to(device)
model_without_ddp = model
if torch.cuda.device_count() > 1:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss)

param_groups = optim_factory.add_weight_decay(model, config.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    maximize: False
    weight_decay: 0.05
)


  self._scaler = torch.cuda.amp.GradScaler()


In [12]:
state_dict = torch.load('../results/spike_pretrain/09-12-2024-00-07-46/checkpoints/checkpoint.pth', weights_only=False)

In [13]:
model.load_state_dict(state_dict['model'])
optimizer.load_state_dict(state_dict['optimizer'])
loss_scaler.load_state_dict(state_dict['scaler'])
epoch = state_dict['epoch'] + 1 # begins from next epoch
config = state_dict['config']

### Transfer weights to a encoder-only model

In [14]:
# Load the MAEforSPIKE state_dict
mae_state_dict = torch.load('../results/spike_pretrain/09-12-2024-00-07-46/checkpoints/checkpoint.pth')['model']

  mae_state_dict = torch.load('../results/spike_pretrain/09-12-2024-00-07-46/checkpoints/checkpoint.pth')['model']


In [15]:
model.load_state_dict(mae_state_dict)

<All keys matched successfully>

In [16]:
# Define the keys that belong to the encoder
encoder_keys = [key for key in mae_state_dict.keys() if not key.startswith('decoder') and not key.startswith('mask_token')]

# Create a new state_dict with only encoder weights
encoder_state_dict = {key: mae_state_dict[key] for key in encoder_keys}

mae_encoder = fmri_encoder(num_voxels=allen_dataset.n_neurons, patch_size=config.patch_size, embed_dim=config.embed_dim,
                 depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio)
m, u = mae_encoder.load_state_dict(encoder_state_dict, strict=False)
print('missing keys:', u)
print('unexpected keys:', m)

missing keys: []
unexpected keys: ['mask_token']


In [17]:
mae_encoder.embed_dim

128

### Initialize classifier

In [18]:
model = fmri_classifier(base_encoder=mae_encoder, num_classes=6)
model = model.to(device)

In [19]:
# Training loop
def train_model(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data_dict in dataloader:
        inputs = data_dict['fmri']
        labels = data_dict['class_label']
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [20]:
# Evaluation loop
def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data_dict in dataloader:
            inputs = data_dict['fmri']
            labels = data_dict['class_label']
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [21]:
# Initialize W&B
wandb.init(project="fmri-classification", name="fine-tuning-fmri-encoder")

In [22]:
# Define split sizes
train_size = int(0.8 * len(allen_dataset))  # 80% for training
test_size = len(allen_dataset) - train_size  # Remaining 20% for testing

# Split dataset
train_dataset, test_dataset = random_split(allen_dataset, [train_size, test_size])

In [23]:
# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

In [24]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
param_groups = optim_factory.add_weight_decay(model, config.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95))

In [25]:
num_epochs = 5

In [26]:
# Training and evaluation
train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []
for epoch in range(num_epochs):
    train_loss, train_acc = train_model(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate_model(model, val_loader, criterion, device)

    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_acc)

    # Log metrics to W&B
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc
    })

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/5
Train Loss: 0.3075, Train Acc: 0.9505
Val Loss: 0.0339, Val Acc: 0.9938
Epoch 2/5
Train Loss: 0.0585, Train Acc: 0.9836
Val Loss: 0.0225, Val Acc: 0.9927
Epoch 3/5
Train Loss: 0.0510, Train Acc: 0.9836
Val Loss: 0.0414, Val Acc: 0.9885
Epoch 4/5
Train Loss: 0.0367, Train Acc: 0.9893
Val Loss: 0.0593, Val Acc: 0.9792
Epoch 5/5
Train Loss: 0.0648, Train Acc: 0.9818
Val Loss: 0.0284, Val Acc: 0.9917
