In [None]:
# This code extensively reuses the ACN implementation; please see:
## https://github.com/Wangyixinxin/ACN##
#!/usr/bin/env python3
# encoding: utf-8
import yaml
from data import make_data_loaders
from models import build_model
from models.discriminator import get_style_discriminator
from solver import make_optimizer_double
from losses import get_losses, bce_loss, DiceLoss
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from utils import init_env
import nibabel as nib
import numpy as np



In [None]:
def load_old_model(model_full, model_missing, d_style, optimizer, saved_model_path):
    """Restore both segmenters, the style discriminator and the optimizer from a checkpoint.

    Args:
        model_full: module whose state is stored under ``"model_full"``.
        model_missing: module whose state is stored under ``"model_missing"``.
        d_style: discriminator whose state is stored under ``"d_style"``.
        optimizer: optimizer whose state is stored under ``"optimizer"``.
        saved_model_path: path of a file written with ``torch.save``.

    Returns:
        ``(model_full, model_missing, d_style, optimizer, epoch)``, where the four objects
        are the same instances passed in (updated in place) and ``epoch`` is the value
        stored under ``"epochs"``.

    Raises:
        KeyError: if any expected key is missing from the checkpoint.
    """
    print("Constructing model from saved file... ")
    ckpt = torch.load(saved_model_path)
    # (object, checkpoint key) pairs, restored in this order
    restore_targets = (
        (model_full, "model_full"),
        (model_missing, "model_missing"),
        (optimizer, "optimizer"),
        (d_style, "d_style"),
    )
    for target, key in restore_targets:
        target.load_state_dict(ckpt[key])
    return model_full, model_missing, d_style, optimizer, ckpt["epochs"]
        
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; plain Python ints/floats pass through.

    The tensor is detached via ``.data`` and copied to host memory before conversion.
    """
    if not isinstance(tensor, (int, float)):
        return tensor.data.cpu().numpy()
    return tensor




In [None]:
## Main section
# Load the experiment configuration; the context manager closes the file handle
# (the previous bare open() leaked it).
with open('./config.yml') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
init_env('0')
loaders = make_data_loaders(config)
# inp_dim1/inp_dim2 presumably are the input channel counts of the full (4) and
# missing-modality (1) models -- train_val feeds them batch_x[:, 0:] and batch_x[:, 0:1].
model_full, model_missing = build_model(inp_dim1 = 4, inp_dim2 = 1)
model_full    = model_full.cuda()
model_missing = model_missing.cuda()
d_style       = get_style_discriminator(num_classes = 128).cuda()
task_name = 'brats2018_flair'
log_dir = os.path.join(config['path_to_log'], task_name)
optimizer, scheduler = make_optimizer_double(config, model_full, model_missing)
losses = get_losses(config)

# Set to True to resume from <log_dir>/model_weights.pth
continue_training = False
epoch = 0

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(log_dir, exist_ok=True)

# Used by eval_metrics for the per-region validation scores
criteria = DiceLoss()

In [None]:
## evaluate the performance
def get_mask(seg_volume):
    """Collapse a 3-channel (WT, TC, ET) probability volume into one BraTS-style label map.

    Voxels above 0.5 are labelled 2 (channel 0), then 1 (channel 1), then 4 (channel 2).
    Later channels overwrite earlier ones, so the innermost region wins.

    NOTE(review): ``np.squeeze`` then indexing channels on axis 0 assumes a batch size of 1.
    With a larger batch, axis 0 is the batch axis instead -- confirm the loaders use batch_size=1.

    Returns:
        ``uint8`` array with values in {0, 1, 2, 4}.
    """
    channels = np.squeeze(seg_volume.detach().cpu().numpy())
    label_map = np.zeros_like(channels[0])
    for channel_idx, label in ((0, 2), (1, 1), (2, 4)):
        label_map[channels[channel_idx] > 0.5] = label
    return label_map.astype("uint8")

def eval_metrics(gt, pred):
    """Score predicted vs. ground-truth BraTS label maps on the three nested regions.

    Each region is turned into a 0/1 integer map and scored with the module-level
    ``criteria`` object:
      * whole tumour (WT): any label > 0
      * tumour core (TC): labels 1 or 4
      * enhancing tumour (ET): label 4

    NOTE(review): ``criteria`` is a ``DiceLoss`` instance; whether it returns a loss
    (lower is better) or a score (higher is better) is not visible here -- confirm.

    Returns:
        ``(wt, et, tc)`` -- note that ET comes before TC in the tuple.
    """
    whole = criteria((gt > 0).astype(int), (pred > 0).astype(int))
    core = criteria(np.isin(gt, (1, 4)).astype(int), np.isin(pred, (1, 4)).astype(int))
    enhancing = criteria((gt == 4).astype(int), (pred == 4).astype(int))
    return whole, enhancing, core

def measure_dice_score(batch_pred, batch_y):
    """Average the WT/ET/TC region metrics between a prediction and its target.

    Both tensors are binarised into label maps with ``get_mask`` (threshold 0.5) and
    compared with ``eval_metrics``. The result is the plain mean of the three values.
    """
    pred_mask = get_mask(batch_pred)
    true_mask = get_mask(batch_y)
    wt, et, tc = eval_metrics(true_mask, pred_mask)
    return (wt + et + tc) / 3.0
    


In [None]:
def _train_step(model_full, model_missing, d_style, optimizer, optimizer_d_style, losses,
                batch_x, batch_y, epoch, source_label, target_label, adv_weight):
    """Run one co-training + style-adversarial optimisation step on a single batch.

    Returns the Python float value of ``loss_dict['loss_Co']`` for logging.
    """
    seg_f, style_f, content_f = model_full(batch_x[:, 0:])
    # the missing-modality model only receives the first input channel
    seg_m, style_m, content_m = model_missing(batch_x[:, 0:1])
    loss_dict = losses['co_loss'](config, seg_f, content_f, batch_y, seg_m, content_m,
                                  style_f, style_m, epoch)

    optimizer.zero_grad()
    optimizer_d_style.zero_grad()

    # Segmenter update: freeze the discriminator so no gradients accumulate in it.
    for param in d_style.parameters():
        param.requires_grad = False
    # retain_graph: style_m's graph is reused by the adversarial term below
    loss_dict['loss_Co'].backward(retain_graph=True)
    # Adversarial term: push missing-modality styles to be classified as "source"
    # (full-modality) styles, i.e. try to fool the discriminator.
    loss_adv = bce_loss(d_style(style_m), source_label)
    (adv_weight * loss_adv).backward()

    # Discriminator update on detached styles so no gradient reaches the segmenters.
    for param in d_style.parameters():
        param.requires_grad = True
    bce_loss(d_style(style_f.detach()), source_label).backward()
    bce_loss(d_style(style_m.detach()), target_label).backward()

    optimizer.step()
    optimizer_d_style.step()
    return loss_dict['loss_Co'].item()


def train_val(model_full, model_missing, d_style, loaders, optimizer, scheduler, losses, epoch_init=0,
              adv_weight=0.0002):
    """Alternate training and validation of the full/missing-modality models.

    Reads the module-level ``config`` (``'epochs'``, ``'lr'``) and ``log_dir``.

    Each epoch has two phases:
      * ``'train'``: one ``_train_step`` per batch of ``loaders['train']``, then ``scheduler.step()``.
      * ``'eval'``: no-grad forward passes over ``loaders['eval']``, scored with ``measure_dice_score``.

    After validation, the latest state is written to ``<log_dir>/model_weights.pth`` (the
    file ``load_old_model`` resumes from). The state with the highest missing-modality
    score is also written to ``<log_dir>/best_model_weights.pth``. The checkpoint's
    ``'epochs'`` entry is the index of the epoch that just finished.

    NOTE(review): "highest is best" assumes ``measure_dice_score`` returns a Dice *score*.
    If ``DiceLoss`` returns 1 - dice, flip the comparison -- confirm.

    Args:
        epoch_init: first epoch index to run (0-based).
        adv_weight: weight of the adversarial style-alignment loss for the segmenters.
    """
    n_epochs = int(config['epochs'])
    best_dice = -float('inf')
    # Created once so Adam's moment estimates persist across iterations
    # (re-creating it every batch reset its state each step).
    optimizer_d_style = optim.Adam(d_style.parameters(), lr=float(config['lr']), betas=(0.9, 0.99))
    # labels for style adversarial training: full-modality = source, missing = target
    source_label = 0
    target_label = 1
    latest_path = os.path.join(log_dir, 'model_weights.pth')
    best_path = os.path.join(log_dir, 'best_model_weights.pth')

    for epoch in range(epoch_init, n_epochs):
        train_loss = 0.0
        val_scores_full = 0.0
        val_scores_miss = 0.0

        for phase in ['train', 'eval']:
            is_train = phase == 'train'
            # switch BatchNorm/Dropout behaviour to match the phase
            for net in (model_full, model_missing, d_style):
                net.train(is_train)

            n_batches = 0
            for batch_id, (batch_x, batch_y) in enumerate(loaders[phase]):
                n_batches += 1
                batch_x = batch_x.cuda(non_blocking=True)
                batch_y = batch_y.cuda(non_blocking=True)
                if is_train:
                    train_loss += _train_step(model_full, model_missing, d_style, optimizer,
                                              optimizer_d_style, losses, batch_x, batch_y, epoch,
                                              source_label, target_label, adv_weight)
                    if (batch_id + 1) % 20 == 0:
                        print(f'Epoch {epoch+1}>> itteration {batch_id+1}>> training loss>> {train_loss/(batch_id+1)}')
                else:
                    with torch.no_grad():
                        seg_f, _, _ = model_full(batch_x[:, 0:])
                        seg_m, _, _ = model_missing(batch_x[:, 0:1])
                    val_scores_full += measure_dice_score(seg_f, batch_y)
                    val_scores_miss += measure_dice_score(seg_m, batch_y)

            # guard against an empty loader (previously a NameError on batch_id)
            denom = max(n_batches, 1)
            if is_train:
                print(f'Epoch {epoch+1} overall training loss>> {train_loss/denom}')
            else:
                dice = val_scores_miss / denom
                dice_full = val_scores_full / denom
                print(f'Epoch {epoch+1} validation dice score for missing modality>> {dice}')
                print(f'Epoch {epoch+1} validation dice score for full modality>> {dice_full}')
                state = {
                    'model_full': model_full.state_dict(),
                    'model_missing': model_missing.state_dict(),
                    'd_style': d_style.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epochs': epoch,
                }
                torch.save(state, latest_path)
                # Separate file: writing "best" to latest_path was overwritten every epoch.
                if dice > best_dice:
                    torch.save(state, best_path)
                    best_dice = dice

        # PyTorch >= 1.1: step the LR scheduler after the epoch's optimizer steps.
        scheduler.step()


In [None]:
saved_model_path = log_dir+'/model_weights.pth'
start_epoch = epoch
if continue_training:
    model_full, model_missing, d_style, optimizer, epoch = load_old_model(model_full, model_missing, d_style, optimizer, saved_model_path)
    # The checkpoint's 'epochs' is the index of the last epoch that finished training and
    # validation, so resume with the next epoch instead of repeating it.
    # NOTE: the LR scheduler state is not stored in the checkpoint and is not restored here.
    start_epoch = epoch + 1

train_val(model_full, model_missing, d_style, loaders, optimizer, scheduler, losses, start_epoch)
print('Training process is finished')


