In [1]:
import torch
import json
import argparse
import os
import monai
import pandas as pd
import numpy as np

from dev.data.dataset_csv import CSVDataset
from adrd.model import ADRDModel
from tqdm import tqdm
from collections import defaultdict

In [2]:
# Experiment configuration. Edit the train / validation / test CSV paths and the
# TOML configuration path (cnf_file) for a new run.
basedir = '.'
data_path="/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_data_single.csv" # path to the data file before train val test split
train_path="/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_train_split_single.csv"
vld_path="/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_val_split_single.csv"
test_path="/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/a4_test_split_single.csv"
cnf_file="/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/meta_files/ab_tau_config_finetune.toml"
# Checkpoint that is read with torch.load as the source of pretrained weights.
orig_ckpt_path = '/data_1/skowshik/ckpts_backbone_swinunet/ckpt_without_imaging.pt'
# Handed to ADRDModel as ckpt_path for the fine-tuned model.
new_ckpt_path = f'{basedir}/dev/ckpt/model_ckpt_finetune.pt'

# no need to change these as they will not be used with non-imaging model
emb_path = '/data_1/dlteif/SwinUNETR_MRI_stripped_MNI_emb/' 
nacc_mri_info = "dev/nacc_mri_3d.json"
other_mri_info = "dev/other_3d_mris.json"

# img_net is passed to CSVDataset as `arch` and to ADRDModel as `img_net`;
# img_mode goes to CSVDataset and mdl.fit; mri_type goes to CSVDataset.
# NOTE(review): the checkpoint-loading cell may overwrite img_net with the value
# stored in the checkpoint; the meaning of img_mode=-1 is not visible here — confirm.
img_net="NonImg"
img_mode=-1
mri_type="SEQ"

# these are labels to remove from the model's state dictionary
# (state-dict keys equal to one of these names are skipped when weights are copied)
labels_to_remove = ['NC', 'MCI', 'DE', 'AD', 'LBD', 'VD', 'PRD', 'FTD', 'NPH', 'SEF', 'PSY', 'TBI', 'ODE']

# add the new labels
# (columns of the data file whose value counts form label_distribution)
new_labels = ['amy_label', 'tau_label']

In [3]:
# Display the configured training-split path as a quick sanity check.
train_path

'/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_train_split_single.csv'

In [4]:
# Load the pretrained checkpoint with map_location set to the CPU so this cell
# also runs on machines where CUDA is unavailable.
state_dict = torch.load(orig_ckpt_path, map_location=torch.device('cpu'))
if 'state_dict' in state_dict:
    # Wrapped format: the entries of interest live under the 'state_dict' key.
    # NOTE(review): this branch defines none of the hyper-parameters popped below
    # (d_model, nhead, imgnet_layers, img_size, patch_size, imgnet_ckpt), yet the
    # model-construction cell reads them — confirm where they come from for
    # checkpoints stored in this format.
    state_dict = state_dict['state_dict']
else:
    # Flat format: metadata entries sit next to the weights. Each one is popped
    # so it is no longer present when the remaining entries are copied into the
    # new model later on.
    src_modalities = state_dict.pop('src_modalities')
    tgt_modalities = state_dict.pop('tgt_modalities')
    # Optional entries: only present in some checkpoints.
    if 'label_distribution' in state_dict:
        label_distribution = state_dict.pop('label_distribution')
    if 'optimizer' in state_dict:
        optimizer = state_dict.pop('optimizer')
    d_model = state_dict.pop('d_model')
    nhead = state_dict.pop('nhead')
    num_encoder_layers = state_dict.pop('num_encoder_layers')
    num_decoder_layers = state_dict.pop('num_decoder_layers')
    if 'epoch' in state_dict:
        start_epoch = state_dict.pop('epoch')
    # NOTE: this replaces the img_net value set in the configuration cell with
    # the one stored in the checkpoint.
    img_net = state_dict.pop('img_net')
    imgnet_layers = state_dict.pop('imgnet_layers')
    img_size = state_dict.pop('img_size')
    patch_size = state_dict.pop('patch_size')
    imgnet_ckpt = state_dict.pop('imgnet_ckpt')
    train_imgnet = state_dict.pop('train_imgnet')
    # Always discard the 'scaler' entry when it exists. Previously it was removed
    # only when truthy, so a present-but-empty entry stayed in state_dict and was
    # iterated alongside the weights.
    state_dict.pop('scaler', None)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [5]:
# Check whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

False

In [6]:
# Build the training and validation datasets from their CSV splits.
seed = 0
stripped = '_stripped_MNI'

# Keyword arguments that are identical for every split; only the data file and
# the mode differ between the calls below.
shared_dataset_kwargs = dict(
    cnf_file=cnf_file,
    img_mode=img_mode,
    mri_type=mri_type,
    arch=img_net,
    emb_path=emb_path,
    nacc_mri_info=nacc_mri_info,
    other_3d_mris=other_mri_info,
    transforms=None,
    stripped=stripped,
)

print("Loading training dataset ... ")
dat_trn = CSVDataset(dat_file=train_path, mode=0, **shared_dataset_kwargs)

print("Done.\nLoading Validation dataset ...")
dat_vld = CSVDataset(dat_file=vld_path, mode=1, **shared_dataset_kwargs)

# Loading of the test split is left disabled:
# print("Done.\nLoading testing dataset ...")
# dat_tst = CSVDataset(dat_file=test_path, mode=2, **shared_dataset_kwargs)
# print("Done.")

Loading training dataset ... 
/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_train_split_single.csv
AVAILABLE MRI Cohorts:  set()
NACC MRIs not available
Avail mris: 0
1262 are selected for mode 0.
Out of 107 features in configuration file, [] are unavailable in data file.
Out of 2 labels in configuration file, 0 are unavailable in data file.
Total mri embeddings found: {}
Total mri embeddings found: 0
Out of 1262 samples, 0 are dropped due to complete feature missing.
Out of 1262 samples, 0 are dropped due to complete label missing.
Done.
Loading Validation dataset ...
/home/varunaja/mri_pet/adrd_tool_varuna/adrd_transformer/data/adni_val_split_single.csv
AVAILABLE MRI Cohorts:  set()
NACC MRIs not available
Avail mris: 0
315 are selected for mode 1.
Out of 107 features in configuration file, [] are unavailable in data file.
Out of 2 labels in configuration file, 0 are unavailable in data file.
Total mri embeddings found: {}
Total mri embeddings found: 0
Out of 315

In [7]:
# Training hyper-parameters and switches forwarded to ADRDModel below.
num_epochs = 128
batch_size = 128
lr = 1e-3
weight_decay = 0.01
gamma = 2
fusion_stage = 'middle'
load_from_ckpt = False
save_intermediate_ckpts = True
ranking_loss = True
train_imgnet = False

# Per-label value counts computed on the full data file (before the split);
# label fractions are taken from the training dataset object.
df = pd.read_csv(data_path)
label_fractions = dat_trn.label_fractions
label_distribution = {}
for label in new_labels:
    counts = df[label].value_counts()
    label_distribution[label] = dict(counts)

print(label_fractions)
print(label_distribution)

# Collect every constructor argument in one mapping, then build the Transformer.
# Architecture values (d_model, nhead, imgnet_layers, img_size, imgnet_ckpt,
# patch_size) are the ones read from the pretrained checkpoint.
model_kwargs = {
    'src_modalities': dat_trn.feature_modalities,
    'tgt_modalities': dat_trn.label_modalities,
    'label_fractions': label_fractions,
    'd_model': d_model,
    'nhead': nhead,
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'lr': lr,
    'weight_decay': weight_decay,
    'gamma': gamma,
    'criterion': 'AUC (ROC)',
    'device': 'cpu',
    'cuda_devices': [1, 2],
    'img_net': img_net,
    'imgnet_layers': imgnet_layers,
    'img_size': img_size,
    'fusion_stage': fusion_stage,
    'imgnet_ckpt': imgnet_ckpt,
    'patch_size': patch_size,
    'ckpt_path': new_ckpt_path,
    'train_imgnet': train_imgnet,
    'load_from_ckpt': load_from_ckpt,
    'save_intermediate_ckpts': save_intermediate_ckpts,
    'data_parallel': False,
    'verbose': 4,
    'wandb_': 1,
    'label_distribution': label_distribution,
    'ranking_loss': ranking_loss,
    '_amp_enabled': False,
    '_dataloader_num_workers': 1,
}
mdl = ADRDModel(**model_kwargs)

{'amy_label': 0.49445324881141045, 'tau_label': 0.13391442155309033}
{'amy_label': {0: 804, 1: 773}, 'tau_label': {0.0: 602, 1.0: 212}}
Device: cpu


In [8]:
# Transfer pretrained tensors into the new network's state dictionary. Only keys
# that already exist in the new model are overwritten; everything else keeps the
# new model's initial values.
# NOTE(review): the exclusion is an exact match on the full key, so it only
# applies to entries literally named like a label or 'emb_aux' — confirm this is
# the intended granularity.
excluded_keys = set(labels_to_remove) | {'emb_aux'}
new_mdl_state_dict = mdl.net_.state_dict()
for key, pretrained_value in state_dict.items():
    is_transferable = key not in excluded_keys and key in new_mdl_state_dict
    if is_transferable:
        new_mdl_state_dict[key] = pretrained_value

# Push the merged state dictionary back into the model.
mdl.net_.load_state_dict(new_mdl_state_dict)

<All keys matched successfully>

In [9]:
# Fit the model: training features/labels come from dat_trn, validation
# features/labels from dat_vld; no image transforms are supplied.
mdl.fit(
    dat_trn.features,
    dat_vld.features,
    dat_trn.labels,
    dat_vld.labels,
    img_train_trans=None,
    img_vld_trans=None,
    img_mode=img_mode,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvarunaja[0m ([33mpet-proj[0m). Use [1m`wandb login --relogin`[0m to force relogin


AUC (ROC)
Ranking loss: True
Batch size: 128
None
Epoch 000 (TRN): 100%|| 1262/1262 [00:07<00:00, 176.21it/s]
Accuracy:           0.5689    0.6093    
Balanced Accuracy:  0.5688    0.5483    
Precision:          0.5656    0.3156    
Sensitivity/Recall: 0.5529    0.4201    
Specificity:        0.5846    0.6765    
F1 score:           0.5592    0.3604    
MCC:                0.1376    0.0891    
AUC (ROC):          0.5892    0.5785    
AUC (PR):           0.6021    0.3839    
Loss:               3.4742    0.7240    
Epoch 000 (VLD): 100%|| 315/315 [00:01<00:00, 308.95it/s]
Accuracy:           0.6762    0.3077    
Balanced Accuracy:  0.6660    0.5357    
Precision:          0.7474    0.2687    
Sensitivity/Recall: 0.4765    1.0000    
Specificity:        0.8554    0.0714    
F1 score:           0.5820    0.4236    
MCC:                0.3611    0.1386    
AUC (ROC):          0.7268    0.8765    
AUC (PR):           0.7300    0.7146    
Loss:               3.0679    0.6457    
model_ckpt_f