In [1]:
import os, sys
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
import argparse
import time
import timm.optim.optim_factory as optim_factory
import datetime
import matplotlib.pyplot as plt
import wandb
import copy

from config import Config_MBM_fMRI
from dataset import hcp_dataset
from sc_mbm.mae_for_spike_train import MAEforSPIKE
from sc_mbm.trainer import train_one_epoch
from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler
from sc_mbm.utils import save_model

In [2]:
config = Config_MBM_fMRI()

In [3]:
img_size = 256
model = MAEforSPIKE(img_size=img_size, patch_size=config.patch_size, embed_dim=config.embed_dim,
                decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 
                num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio,
                focus_range=config.focus_range, focus_rate=config.focus_rate, 
                img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss)   

In [4]:
model(torch.rand(1, 1, 256, 256))

(tensor(1.3977, grad_fn=<DivBackward0>),
 tensor([[[-0.4932,  1.0020,  1.2352,  ..., -1.7808, -0.8354,  1.3208],
          [-0.2488,  1.3524,  0.5713,  ..., -1.4481, -1.9327,  2.5502],
          [-0.6857,  1.2362,  0.7120,  ..., -1.7042, -0.5881,  1.5393],
          ...,
          [ 0.0126,  2.4812,  0.5324,  ..., -1.3076, -1.5454,  2.3111],
          [ 0.0196,  2.5036,  0.6580,  ..., -1.1654, -1.7286,  2.2324],
          [ 0.0081,  2.4196,  0.8417,  ..., -1.0599, -1.7259,  2.1585]]],
        grad_fn=<SliceBackward0>),
 tensor([[0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 0.,
          1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 0.,
          1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,