In [None]:
# for torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import *
from sklearn.model_selection import StratifiedKFold, train_test_split
from ranger.ranger2020 import Ranger
from torch_summary import summary

# import custom libraries
import import_ipynb
from utils import *
from dataset import *
from net import *

In [None]:
# Setting args
# Every tunable value for the run lives in this cell.
LR_CAE      = 0.003            # learning rate for CAE
LR_MTL      = 0.003            # learning rate for MTL-CAE
EPOCHS_CAE  = 150              # training epochs for CAE
EPOCHS_MTL  = 50               # training epochs for MTL-CAE
BATCH_SIZE  = 8                # batch size
PRINT_IDX   = 10               # print result every PRINT_IDX epochs
# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable so
# the notebook still runs on machines without a GPU.
DEVICE      = 'cuda:0' if torch.cuda.is_available() else 'cpu'
L2_REG      = 1e-4             # L2 regularization (optimizer weight decay)
NETCAE_FILE = 'NetCAE.pth'     # Pre-trained CAE model path
TRAIN_CAE   = False            # True for re-training CAE
K_FOLDS     = 10               # K-Fold CV
RAND_STATE  = 23               # random seed for the sklearn splits
LOSS_FUNC   = nn.BCELoss()     # Binary Cross Entropy
MODALITY    = 'multimodality'  # 'mri' or 'cdata' or 'multimodality'

LOADER_ARGS = {                # args for torch DataLoader
    'batch_size' : BATCH_SIZE,
    'shuffle'    : True,
    'num_workers': 4,
    # Pinned host memory only helps host-to-GPU copies; skip it on the CPU.
    'pin_memory' : DEVICE.startswith('cuda'),
    'drop_last'  : True,
}

In [None]:
def runMTL(model, data, record):
    """Run one multi-task batch through ``model`` and return the combined loss.

    Args:
        model: callable taking ``(mri, cdata)`` and returning two outputs,
            one per task.
        data: a batch unpacking to ``(mri, cdata, target, task)``; ``task``
            marks each sample as belonging to task 1 or task 2.
        record: metrics accumulator; receives the task-2 targets, task-2
            outputs and task-2 loss of this batch via ``record.append``.

    Returns:
        The scalar loss ``0.5 * loss1 + loss2`` (still attached to the graph,
        so the caller may call ``backward()`` on it).
    """
    mri, cdata, target, task = data
    mri, cdata, target = mri.to(DEVICE), cdata.to(DEVICE), target.to(DEVICE)

    # Batch positions belonging to each task (``task`` is not moved to DEVICE).
    t1_index = np.argwhere(task==1).reshape(-1)
    t2_index = np.argwhere(task==2).reshape(-1)

    out1, out2 = model(mri, cdata)

    # Keep, for each head, only the samples of its own task.
    out1 = out1[t1_index].reshape(-1)
    out2 = out2[t2_index].reshape(-1)
    target1 = target[t1_index].float()
    target2 = target[t2_index].float()

    # NOTE(review): when a batch holds no sample of one task the matching
    # slice is empty and a mean-reduced loss over it is NaN -- confirm that
    # Record tolerates this.
    loss1 = LOSS_FUNC(out1, target1)
    loss2 = LOSS_FUNC(out2, target2)
    # Task 1 is down-weighted relative to task 2.
    loss = 0.5*loss1 + loss2

    with torch.no_grad():
        # update metrics
        # detach() first: on a CPU device ``.cpu()`` returns the very same
        # tensor, which would otherwise keep this batch's autograd graph
        # alive inside the record.
        record.append(target2.detach().cpu(), out2.detach().cpu(), loss2.item())

    return loss
        
    
def trainMTL(model, dataloader, optimizer, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is run through ``runMTL``; its loss is back-propagated and the
    optimizer takes one step. ``epoch`` is accepted but not used here.

    Returns:
        The ``Record`` filled by ``runMTL`` during this pass.
    """
    model.train()
    epoch_record = Record()
    for batch in dataloader:
        optimizer.zero_grad()
        runMTL(model, batch, epoch_record).backward()
        optimizer.step()
    return epoch_record


def validMTL(model, dataloader, epoch):
    """Evaluate ``model`` over ``dataloader`` without tracking gradients.

    The model is switched to eval mode and every batch is passed through
    ``runMTL``. ``epoch`` is accepted but not used here.

    Returns:
        The ``Record`` filled by ``runMTL`` during this pass.
    """
    model.eval()
    eval_record = Record()
    with torch.no_grad():
        for batch in dataloader:
            # Only the side effect on the record matters; the loss is unused.
            runMTL(model, batch, eval_record)
    return eval_record


def runCAE(model, data):
    """Reconstruct the MRI of one batch and score the reconstruction.

    Args:
        model: autoencoder called with the MRI tensor only.
        data: a batch unpacking to ``(mri, cdata, target, task)``; only the
            MRI is used.

    Returns:
        ``(loss, mri, recon)``: the binary cross-entropy between the
        reconstruction and the input, the MRI on ``DEVICE``, and the
        reconstruction.
    """
    mri, _cdata, _target, _task = data
    mri_on_device = mri.to(DEVICE)
    reconstruction = model(mri_on_device)
    recon_loss = F.binary_cross_entropy(reconstruction, mri_on_device)
    return recon_loss, mri_on_device, reconstruction


def trainCAE(model, dataloader, optimizer, epoch):
    """Train the autoencoder for one pass over ``dataloader``.

    Args:
        model: autoencoder passed to ``runCAE``.
        dataloader: iterable of batches; must yield at least one batch, since
            the loss is averaged over ``len(dataloader)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number; every ``PRINT_IDX`` epochs the mean loss
            is logged and the last batch is displayed next to its
            reconstruction.

    Returns:
        The mean reconstruction loss over all batches of this pass.
    """
    model.train()
    total_loss = 0
    for data in dataloader:
        optimizer.zero_grad()
        loss, mri, recon = runCAE(model, data)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    total_loss /= len(dataloader)
    if epoch % PRINT_IDX == 0:
        time('-- Epoch {:2d}  rloss: {:6.4f}'.format(epoch,total_loss))
        with torch.no_grad():
            # ``mri`` / ``recon`` are those of the last batch. detach() first:
            # on a CPU device ``.cpu()`` returns the same tensor, and ``recon``
            # would still require grad when handed to the plotting helper.
            showMRIs(mri.detach().cpu())
            showMRIs(recon.detach().cpu())
    return total_loss

        
def validCAE(model, dataloader):
    """Show how well the autoencoder reconstructs one batch.

    Switches ``model`` to eval mode, takes the first batch produced by
    ``dataloader``, logs its reconstruction loss and displays the input
    followed by the reconstruction. Returns nothing.
    """
    model.eval()
    with torch.no_grad():
        first_batch = next(iter(dataloader))
        rloss, original, rebuilt = runCAE(model, first_batch)
        time('- rloss: {:6.4f}'.format(rloss.item()))
        # Input first, reconstruction second.
        for volume in (original, rebuilt):
            showMRIs(volume.cpu())

In [None]:
# Root directory handed to DataBuilder.
SDIR = '/media/mbl/HDD/Screening_PREPSD3'
time('Building dataset...')
dataset = DataBuilder(SDIR).build()
cn, ad, smci, pmci = (dataset.get(group) for group in ('CN', 'AD', 'sMCI', 'pMCI'))

# Hold out 16 sMCI and 16 pMCI subjects (32 in total) as the test set.
smci, test_smci = train_test_split(smci, test_size=16, random_state=RAND_STATE)
pmci, test_pmci = train_test_split(pmci, test_size=16, random_state=RAND_STATE)

# Subjects for train/valid and for test.
subjects = pd.concat((cn, ad, smci, pmci))
test_subjects = pd.concat((test_smci, test_pmci))
# One class id per row of `subjects` (0=CN, 1=AD, 2=sMCI, 3=pMCI), in the
# same order as the concat above; used to stratify the K-Fold split.
labels = np.repeat([0, 1, 2, 3], [len(cn), len(ad), len(smci), len(pmci)])

dataset.print()
# 32 = the held-out test subjects split off above.
print('Total subjects:', len(subjects)+32)

In [None]:
# Pre-train the convolutional autoencoder, or load and inspect a saved one.
# NOTE(review): the held-out test subjects are part of this loader; runCAE
# only reads the MRI from each batch -- confirm that including them is intended.
train_set = ADNIDataset(pd.concat((subjects,test_subjects)), train=True, modality=MODALITY)
dataloader = DataLoader(train_set, **LOADER_ARGS)
model = NetCAE().to(DEVICE)

if TRAIN_CAE and MODALITY!='cdata':
    time('Training NetAE...')
    optimizer = Ranger(model.parameters(), lr=LR_CAE, k=6)
    # Positional args are T_0=5 and T_mult=2.
    scheduler = CosineAnnealingWarmRestarts(optimizer,5,2,eta_min=1e-5)
    # A checkpoint is written only when the epoch loss drops below the best
    # seen so far, starting from this threshold.
    mloss = 0.2
    for epoch in range(1, EPOCHS_CAE + 1):
        loss = trainCAE(model, dataloader, optimizer, epoch)
        scheduler.step()
        if loss < mloss:
            torch.save(model.state_dict(), NETCAE_FILE)
            time('Model saved: {:6.4f}'.format(loss))
            mloss = loss

elif MODALITY=='cdata':
    time('Use clinical feature only, skip loading model')

else:
    time('NetAE Result...')
    # map_location lets a checkpoint saved from another device (e.g. a GPU)
    # be loaded onto whatever DEVICE is configured for this run.
    model.load_state_dict(torch.load(NETCAE_FILE, map_location=DEVICE))
    validCAE(model, dataloader)

In [None]:
# Stratified K-Fold training of the multi-task model; after every epoch the
# accuracy on the held-out test subjects is printed.
kfold = StratifiedKFold(n_splits=K_FOLDS, random_state=RAND_STATE, shuffle=True)
for k, (train_index, valid_index) in enumerate(kfold.split(subjects,labels), 1):

    # setting train/valid/test dataset
    train_df = subjects.iloc[train_index]
    valid_df = subjects.iloc[valid_index]

    train_set = ADNIDataset(train_df,      train=True,  modality=MODALITY)
    valid_set = ADNIDataset(valid_df,      train=False, modality=MODALITY)
    test_set  = ADNIDataset(test_subjects, train=False, modality=MODALITY)

    # NOTE(review): the validation/test loaders reuse LOADER_ARGS, i.e. they
    # are shuffled and drop an incomplete last batch -- confirm this is wanted.
    train_loader = DataLoader(train_set, **LOADER_ARGS)
    valid_loader = DataLoader(valid_set, **LOADER_ARGS)
    test_loader  = DataLoader(test_set,  **LOADER_ARGS)

    # loading model
    autoencoder = NetCAE().to(DEVICE)
    # map_location lets a checkpoint saved from another device (e.g. a GPU)
    # be loaded onto whatever DEVICE is configured for this run.
    autoencoder.load_state_dict(torch.load(NETCAE_FILE, map_location=DEVICE))
    # Freeze the parameters of the first three encoder modules.
    autoencoder.encoder[:3].requires_grad_(False)

    model = NetMTL(autoencoder.encoder, modality=MODALITY).to(DEVICE)

    # optimizer & LR scheduler
    optimizer = Ranger(model.parameters(), lr=LR_MTL, k=6, weight_decay=L2_REG)
    # Positional args are T_0=5 and T_mult=2.
    scheduler = CosineAnnealingWarmRestarts(optimizer,5,2,eta_min=1e-5)

    for epoch in range(1, EPOCHS_MTL + 1):
        # record1 / record2 (train and validation) are collected but not
        # reported here; only the test record is summarised.
        record1 = trainMTL(model, train_loader, optimizer, epoch)
        record2 = validMTL(model, valid_loader, epoch)
        record3 = validMTL(model, test_loader,  epoch)
        record3.calculate()
        print(record3.acc())
        scheduler.step()