In [5]:
import os, sys
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
import argparse
import time
import timm.optim.optim_factory as optim_factory
import datetime
import matplotlib.pyplot as plt
import wandb
import copy

# own code
from config import Config_MBM_finetune
from dataset import create_Kamitani_dataset, create_BOLD5000_dataset
from sc_mbm.mae_for_fmri import MAEforFMRI
from sc_mbm.trainer import train_one_epoch
from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler
from sc_mbm.utils import save_model

In [6]:
class wandb_logger:
    def __init__(self, config):
        wandb.init( project='mind-vis',
                    group="stepA_sc-mbm_tune",
                    anonymous="allow",
                    config=config,
                    reinit=True)

        self.config = config
        self.step = None
    
    def log(self, name, data, step=None):
        if step is None:
            wandb.log({name: data})
        else:
            wandb.log({name: data}, step=step)
            self.step = step
    
    def watch_model(self, *args, **kwargs):
        wandb.watch(*args, **kwargs)

    def log_image(self, name, fig):
        if self.step is None:
            wandb.log({name: wandb.Image(fig)})
        else:
            wandb.log({name: wandb.Image(fig)}, step=self.step)

    def finish(self):
        wandb.finish(quiet=True)

In [7]:
def get_args_parser():
    parser = argparse.ArgumentParser('MAE finetuning on Test fMRI', add_help=False)

    # Training Parameters
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--num_epoch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--mask_ratio', type=float)

    # Project setting
    parser.add_argument('--root_path', type=str)
    parser.add_argument('--pretrain_mbm_path', type=str)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--include_nonavg_test', type=bool)   
    
    # distributed training parameters
    parser.add_argument('--local_rank', type=int)
                        
    return parser

In [8]:
def create_readme(config, path):
    print(config.__dict__)
    with open(os.path.join(path, 'README.md'), 'w+') as f:
        print(config.__dict__, file=f)

def fmri_transform(x, sparse_rate=0.2):
    # x: 1, num_voxels
    x_aug = copy.deepcopy(x)
    idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False)
    x_aug[idx] = 0
    return torch.FloatTensor(x_aug)

In [10]:
config = Config_MBM_finetune()

In [11]:
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(config.local_rank) 
    torch.distributed.init_process_group(backend='nccl')
output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain',  '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))
# output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain')
config.output_path = output_path
logger = wandb_logger(config) if config.local_rank == 0 else None

if config.local_rank == 0:
    os.makedirs(output_path, exist_ok=True)
    create_readme(config, output_path)

device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu')
torch.manual_seed(config.seed)
np.random.seed(config.seed)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mzhaizhongyuan[0m ([33m11785-bhiksha[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'root_path': '.', 'output_path': './results/fmri_pretrain/22-11-2024-18-46-43', 'kam_path': './data/Kamitani/npz', 'bold5000_path': './data/BOLD5000', 'dataset': 'GOD', 'pretrain_mbm_path': './pretrains/GOD/fmri_encoder.pth', 'include_nonavg_test': True, 'kam_subs': ['sbj_3'], 'bold5000_subs': ['CSI1'], 'lr': 5.3e-05, 'weight_decay': 0.05, 'num_epoch': 15, 'batch_size': 16, 'mask_ratio': 0.75, 'accum_iter': 1, 'clip_grad': 0.8, 'warmup_epochs': 2, 'min_lr': 0.0, 'local_rank': 0}


AttributeError: 'Config_MBM_finetune' object has no attribute 'seed'

In [None]:
def main(config):
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(config.local_rank) 
        torch.distributed.init_process_group(backend='nccl')
    sd = torch.load(config.pretrain_mbm_path, map_location='cpu')
    config_pretrain = sd['config']
    
    output_path = os.path.join(config.root_path, 'results', 'fmri_finetune',  '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))
    # output_path = os.path.join(config.root_path, 'results', 'fmri_finetune')
    config.output_path = output_path
    logger = wandb_logger(config) if config.local_rank == 0 else None
    
    if config.local_rank == 0:
        os.makedirs(output_path, exist_ok=True)
        create_readme(config, output_path)
    
    device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu')
    torch.manual_seed(config_pretrain.seed)
    np.random.seed(config_pretrain.seed)

    # create model
    num_voxels = (sd['model']['pos_embed'].shape[1] - 1)* config_pretrain.patch_size
    model = MAEforFMRI(num_voxels=num_voxels, patch_size=config_pretrain.patch_size, embed_dim=config_pretrain.embed_dim,
                    decoder_embed_dim=config_pretrain.decoder_embed_dim, depth=config_pretrain.depth, 
                    num_heads=config_pretrain.num_heads, decoder_num_heads=config_pretrain.decoder_num_heads, 
                    mlp_ratio=config_pretrain.mlp_ratio, focus_range=None, use_nature_img_loss=False) 
    model.load_state_dict(sd['model'], strict=False)

    model.to(device)
    model_without_ddp = model

    # create dataset and dataloader
    if config.dataset == 'GOD':
        _, test_set = create_Kamitani_dataset(path=config.kam_path, patch_size=config_pretrain.patch_size, 
                                subjects=config.kam_subs, fmri_transform=torch.FloatTensor, include_nonavg_test=config.include_nonavg_test)
    elif config.dataset == 'BOLD5000':
        _, test_set = create_BOLD5000_dataset(path=config.bold5000_path, patch_size=config_pretrain.patch_size, 
                fmri_transform=torch.FloatTensor, subjects=config.bold5000_subs, include_nonavg_test=config.include_nonavg_test)
    else:
        raise NotImplementedError

    print(test_set.fmri.shape)
    if test_set.fmri.shape[-1] < num_voxels:
        test_set.fmri = np.pad(test_set.fmri, ((0,0), (0, num_voxels - test_set.fmri.shape[-1])), 'wrap')
    else:
        test_set.fmri = test_set.fmri[:, :num_voxels]
    print(f'Dataset size: {len(test_set)}')
    sampler = torch.utils.data.DistributedSampler(test_set) if torch.cuda.device_count() > 1 else torch.utils.data.RandomSampler(test_set) 
    dataloader_hcp = DataLoader(test_set, batch_size=config.batch_size, sampler=sampler)

    if torch.cuda.device_count() > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss)

    param_groups = optim_factory.add_weight_decay(model, config.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    if logger is not None:
        logger.watch_model(model,log='all', log_freq=1000)

    cor_list = []
    start_time = time.time()
    print('Finetuning MAE on test fMRI ... ...')
    for ep in range(config.num_epoch):
        if torch.cuda.device_count() > 1: 
            sampler.set_epoch(ep) # to shuffle the data at every epoch
        cor = train_one_epoch(model, dataloader_hcp, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp)
        cor_list.append(cor)
        if (ep % 2 == 0 or ep + 1 == config.num_epoch) and ep != 0 and config.local_rank == 0:
            # save models
            save_model(config_pretrain, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints'))
            # plot figures
            plot_recon_figures(model, device, test_set, output_path, 5, config, logger, model_without_ddp)
            
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    if logger is not None:
        logger.log('max cor', np.max(cor_list), step=config.num_epoch-1)
        logger.finish()
    return

In [None]:
@torch.no_grad()
def plot_recon_figures(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None):
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    model.eval()
    fig, axs = plt.subplots(num_figures, 3, figsize=(30,15))
    fig.tight_layout()
    axs[0,0].set_title('Ground-truth')
    axs[0,1].set_title('Masked Ground-truth')
    axs[0,2].set_title('Reconstruction')

    for ax in axs:
        sample = next(iter(dataloader))['fmri']
        sample = sample.to(device)
        _, pred, mask = model(sample, mask_ratio=config.mask_ratio)
        sample_with_mask = model_without_ddp.patchify(sample).to('cpu').numpy().reshape(-1, model_without_ddp.patch_size)
        pred = model_without_ddp.unpatchify(pred).to('cpu').numpy().reshape(-1)
        sample = sample.to('cpu').numpy().reshape(-1)
        mask = mask.to('cpu').numpy().reshape(-1)
        # cal the cor
        cor = np.corrcoef([pred, sample])[0,1]

        x_axis = np.arange(0, sample.shape[-1])
        # groundtruth
        ax[0].plot(x_axis, sample)
        # groundtruth with mask
        s = 0
        for x, m in zip(sample_with_mask,mask):
            if m == 0:
                ax[1].plot(x_axis[s:s+len(x)], x, color='#1f77b4')
            s += len(x)
        # pred
        ax[2].plot(x_axis, pred)
        ax[2].set_ylabel('cor: %.4f'%cor, weight = 'bold')
        ax[2].yaxis.set_label_position("right")

    fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))
    fig.savefig(os.path.join(output_path, f'{fig_name}.png'))
    if logger is not None:
        logger.log_image('reconst', fig)
    plt.close(fig)

def update_config(args, config):
    for attr in config.__dict__:
        if hasattr(args, attr):
            if getattr(args, attr) != None:
                setattr(config, attr, getattr(args, attr))
    return config

In [None]:
if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    config = Config_MBM_finetune()
    config = update_config(args, config)
    main(config)