In [1]:
import numpy as np
import argparse
import os
import sys
import torch
import torch.fft as F
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
import logging

from matplotlib import pyplot as plt
from utils import kplot,mask_naiveRand,mask_filter
from mnet import MNet
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [2]:
def sigmoid_binarize(M,threshold=0.6):
    sigmoid = nn.Sigmoid()
    mask = sigmoid(M)
    mask_pred = torch.ones_like(mask)
    mask_pred[mask<=threshold] = 0
    return mask_pred
    
def trainMNet(trainimgs,trainlabels,testimgs,testlabels,\
              epochs=20,batchsize=5,\
              lr=0.01,lr_weight_decay=1e-8,opt_momentum=0,positive_weight=6,\
              lr_s_stepsize=5,lr_s_gamma=0.5,\
              model=None,save_cp=True,threshold=0.5,\
              beta=1,poolk=3,datatype=torch.float,print_every=10):
    '''
    trainimgs    : train data, with dimension (#imgs,height,width,layer)
    '''
    
    train_shape  = trainimgs.shape; test_shape = testimgs.shape 
    trainimgs    = torch.tensor(trainimgs,dtype=datatype).view(train_shape[0],-1,train_shape[1],train_shape[2])
    trainlabels  = torch.tensor(trainlabels,dtype=datatype)
    testimgs     = torch.tensor(testimgs,dtype=datatype).view(test_shape[0],-1,test_shape[1],test_shape[2])
    testlabels   = torch.tensor(testlabels ,dtype=datatype)
    
    train_shape = trainimgs.shape
    dir_checkpoint = '/home/huangz78/mri/checkpoints/'
    # add normalization for images here
    
    if model is None:
        net = MNet(beta=beta,in_channels=train_shape[1],out_size=trainlabels.shape[1],\
                   imgsize=(train_shape[2],train_shape[3]),poolk=poolk)
        optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=lr_weight_decay, momentum=opt_momentum)
        epoch_init = 0
    else:
        net = model[0]
        optimizer  = model[1]
        epoch_init = model[2] + 1
#     criterion = nn.MSELoss()
    pos_weight = torch.ones([trainlabels.shape[1]]) * positive_weight # weight assigned to positive labels 
    criterion  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    test_criterion = nn.BCELoss()
    sigmoid = nn.Sigmoid()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=lr_s_stepsize,factor=lr_s_gamma)

    epoch_loss        = np.full((epochs),np.nan)
    precision_history = np.full((epochs),np.nan)
    recall_history    = np.full((epochs),np.nan)
    for epoch in range(epoch_init,epoch_init + epochs):
        batch_init = 0; step_count = 1
        while batch_init < train_shape[0]:
            batch = np.arange(batch_init,min(batch_init+batchsize,train_shape[0]))
            imgbatch = trainimgs[batch,:,:,:] # maybe shuffling?
            batchlabels = trainlabels[batch,:]
            mask_pred   = net(imgbatch)
            train_loss  = criterion(mask_pred,batchlabels)
            batch_init += batchsize; step_count += 1
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            if (step_count%print_every)==0:
                with torch.no_grad():
                    net.eval()
                    mask_test = sigmoid_binarize(net(testimgs),threshold=threshold)
                    test_loss = test_criterion(mask_test,testlabels)
                    print('epoch {} global step {}: train batch loss {}, test loss {} '.format(epoch+1,step_count,train_loss.item(),test_loss.item()))
                    net.train()
        net.eval()
        mask_test = sigmoid_binarize(net(testimgs),threshold=threshold)
        test_loss = test_criterion(mask_test,testlabels)
        net.train()
        scheduler.step(test_loss)
        epoch_loss[epoch-epoch_init] = test_loss.item()
        precision_history[epoch-epoch_init] = precision_score(torch.flatten(testlabels),torch.flatten(mask_test))
        recall_history[epoch-epoch_init] = recall_score(torch.flatten(testlabels),torch.flatten(mask_test))
        print('\t epoch {} end: test loss {} '.format(epoch+1,test_loss.item()))
        print('\t epoch {} end: precision {} '.format(epoch+1,precision_history[epoch-epoch_init]))
        print('\t epoch {} end: recall    {} '.format(epoch+1,recall_history[epoch-epoch_init]))
        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                print('Created checkpoint directory')
            except OSError:
                pass
            torch.save({'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'threshold':threshold
                        }, dir_checkpoint + 'mnet.pth')
#                         }, dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            print(f'\t Checkpoint saved after epoch {epoch + 1}!')
    
            np.savez(dir_checkpoint+'epoch_loss.npz', loss=epoch_loss,precision=precision_history,recall=recall_history)

In [3]:
data_path = '/home/huangz78/data/'
data_imgs = np.load(data_path+'data_gt.npz')
data_labels = np.load(data_path+'data_gt_greedymask.npz')
print(data_imgs.files)
data    = data_imgs['imgdata']
labels  = data_labels['mask'].T
datashape = data.shape
print(datashape)
print(labels.shape)

['imgdata']
(320, 320, 199)
(199, 320)


#### data preparation

In [4]:
base = 24
mask = torch.tensor( mask_naiveRand(320,fix=base,other=0,roll=False)[0] ,dtype=torch.float )
data_under = np.zeros((datashape[2],datashape[0],datashape[1]))
for ind in range(data.shape[2]):
    img = data[:,:,ind]
    img = img/np.max(np.abs(img))
    yfull = F.fftn(torch.tensor(img,dtype=torch.float),dim=(0,1),norm='ortho')
    ypart = torch.tensordot(torch.diag(mask).to(torch.cfloat) , yfull,dims=([1],[0]))
    data_under[ind,:,:] = torch.abs(F.ifftn(ypart,dim=(0,1),norm='ortho'))

  mask = torch.tensor( mask_naiveRand(320,fix=base,other=0,roll=False)[0] ,dtype=torch.float )


In [5]:
from sklearn.model_selection import train_test_split
imgNum = 199
traininds, testinds = train_test_split(np.arange(imgNum),random_state=0,shuffle=True,train_size=round(imgNum*0.8))
test_total = testinds.size
traindata    = data_under[traininds,:,:]
trainlabels  = mask_filter(labels[traininds,:],base=base)
valdata      = data_under[testinds[0:test_total//2],:,:]
vallabels    = mask_filter(labels[testinds[0:test_total//2],:],base=base)

In [6]:
print(traindata.shape)
print(trainlabels.shape)
print(valdata.shape)
print(vallabels.shape)

(159, 320, 320)
(159, 296)
(20, 320, 320)
(20, 296)


In [7]:
sys.path.insert(0,'/home/huangz78/mri/mnet/')
import mnet
reload(mnet)
from mnet import MNet

#### training

In [12]:
# mnet = MNet(out_size=trainlabels.shape[1])
# checkpoint = torch.load('/home/huangz78/mri/checkpoints/mnet.pth')
# mnet.load_state_dict(checkpoint['model_state_dict'])
# print('mnet loaded successfully from : ' + '/home/huangz78/mri/checkpoints/mnet.pth' )
# mnet.train()
# # print(mnet)
trainMNet(traindata,trainlabels, valdata,vallabels,
          epochs=60, batchsize=5, \
          lr=1e-4, lr_weight_decay=0,opt_momentum=0,positive_weight=1,\
          lr_s_stepsize=2,lr_s_gamma=0.5,\
          threshold=.5, beta=1,save_cp=True)

epoch 1 global step 10: train batch loss 0.37502849102020264, test loss 16.08108139038086 
epoch 1 global step 20: train batch loss 0.28756147623062134, test loss 16.41891860961914 
epoch 1 global step 30: train batch loss 0.27262017130851746, test loss 16.486486434936523 
	 epoch 1 end: test loss 16.114864349365234 
	 epoch 1 end: precision 0.7307142857142858 
	 epoch 1 end: recall    0.639375 
	 Checkpoint saved after epoch 1!
epoch 2 global step 10: train batch loss 0.29684698581695557, test loss 16.824323654174805 
epoch 2 global step 20: train batch loss 0.2700611352920532, test loss 16.25 
epoch 2 global step 30: train batch loss 0.2540127635002136, test loss 16.5202693939209 
	 epoch 2 end: test loss 16.013513565063477 
	 epoch 2 end: precision 0.7318634423897582 
	 epoch 2 end: recall    0.643125 
	 Checkpoint saved after epoch 2!
epoch 3 global step 10: train batch loss 0.2949286699295044, test loss 16.5202693939209 
epoch 3 global step 20: train batch loss 0.2671421468257904,

	 Checkpoint saved after epoch 19!
epoch 20 global step 10: train batch loss 0.21977078914642334, test loss 15.236486434936523 
epoch 20 global step 20: train batch loss 0.2057623565196991, test loss 15.777027130126953 
epoch 20 global step 30: train batch loss 0.16045457124710083, test loss 15.878377914428711 
	 epoch 20 end: test loss 15.827702522277832 
	 epoch 20 end: precision 0.7366167023554604 
	 epoch 20 end: recall    0.645 
	 Checkpoint saved after epoch 20!
epoch 21 global step 10: train batch loss 0.21565930545330048, test loss 15.287161827087402 
epoch 21 global step 20: train batch loss 0.20306776463985443, test loss 15.895270347595215 
epoch 21 global step 30: train batch loss 0.15692956745624542, test loss 15.844594955444336 
	 epoch 21 end: test loss 15.726351737976074 
	 epoch 21 end: precision 0.7387580299785867 
	 epoch 21 end: recall    0.646875 
	 Checkpoint saved after epoch 21!
epoch 22 global step 10: train batch loss 0.21185781061649323, test loss 15.270270347

	 epoch 38 end: test loss 16.030405044555664 
	 epoch 38 end: precision 0.734341252699784 
	 epoch 38 end: recall    0.6375 
	 Checkpoint saved after epoch 38!
epoch 39 global step 10: train batch loss 0.19806911051273346, test loss 15.979729652404785 
epoch 39 global step 20: train batch loss 0.19376562535762787, test loss 16.08108139038086 
epoch 39 global step 30: train batch loss 0.1471765637397766, test loss 16.148649215698242 
	 epoch 39 end: test loss 16.030405044555664 
	 epoch 39 end: precision 0.734341252699784 
	 epoch 39 end: recall    0.6375 
	 Checkpoint saved after epoch 39!
epoch 40 global step 10: train batch loss 0.19802656769752502, test loss 15.979729652404785 
epoch 40 global step 20: train batch loss 0.19373400509357452, test loss 16.08108139038086 
epoch 40 global step 30: train batch loss 0.1471555083990097, test loss 16.148649215698242 
	 epoch 40 end: test loss 16.030405044555664 
	 epoch 40 end: precision 0.734341252699784 
	 epoch 40 end: recall    0.6375 
	

epoch 57 global step 30: train batch loss 0.14700840413570404, test loss 16.131755828857422 
	 epoch 57 end: test loss 16.04729652404785 
	 epoch 57 end: precision 0.7338129496402878 
	 epoch 57 end: recall    0.6375 
	 Checkpoint saved after epoch 57!
epoch 58 global step 10: train batch loss 0.1977592408657074, test loss 15.979729652404785 
epoch 58 global step 20: train batch loss 0.1935812085866928, test loss 16.08108139038086 
epoch 58 global step 30: train batch loss 0.1470058411359787, test loss 16.131755828857422 
	 epoch 58 end: test loss 16.064189910888672 
	 epoch 58 end: precision 0.7332854061826024 
	 epoch 58 end: recall    0.6375 
	 Checkpoint saved after epoch 58!
epoch 59 global step 10: train batch loss 0.19775390625, test loss 15.979729652404785 
epoch 59 global step 20: train batch loss 0.19357796013355255, test loss 16.08108139038086 
epoch 59 global step 30: train batch loss 0.14700312912464142, test loss 16.131755828857422 
	 epoch 59 end: test loss 16.0641899108