In [1]:
import numpy as np
import argparse
import os
import sys
import torch
import torch.fft as F
from importlib import reload
from torch.nn.functional import relu
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
import utils
import logging

from matplotlib import pyplot as plt
from utils import kplot,mask_naiveRand,mask_filter
from mnet import MNet
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [2]:
def sigmoid_binarize(M,threshold=0.6):
    sigmoid = nn.Sigmoid()
    mask = sigmoid(M)
    mask_pred = torch.ones_like(mask)
    mask_pred[mask<=threshold] = 0
    return mask_pred
    
def trainMNet(trainimgs,trainlabels,testimgs,testlabels,\
              epochs=20,batchsize=5,\
              lr=0.01,lr_weight_decay=1e-8,opt_momentum=0,positive_weight=6,\
              lr_s_stepsize=5,lr_s_gamma=0.5,\
              model=None,save_cp=True,threshold=0.5,\
              beta=1,poolk=3,datatype=torch.float,print_every=10):
    '''
    trainimgs    : train data, with dimension (#imgs,height,width,layer)
    '''
    
    train_shape  = trainimgs.shape; test_shape = testimgs.shape 
    trainimgs    = torch.tensor(trainimgs,dtype=datatype).view(train_shape[0],-1,train_shape[1],train_shape[2])
    trainlabels  = torch.tensor(trainlabels,dtype=datatype)
    testimgs     = torch.tensor(testimgs,dtype=datatype).view(test_shape[0],-1,test_shape[1],test_shape[2])
    testlabels   = torch.tensor(testlabels ,dtype=datatype)
    
    train_shape = trainimgs.shape
    dir_checkpoint = '/home/huangz78/mri/checkpoints/'
    # add normalization for images here
    
    if model is None:
        net = MNet(beta=beta,in_channels=train_shape[1],out_size=trainlabels.shape[1],\
                   imgsize=(train_shape[2],train_shape[3]),poolk=poolk)
        optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=lr_weight_decay, momentum=opt_momentum)
        epoch_init = 0
    else:
        net = model[0]
        optimizer  = model[1]
        epoch_init = model[2] + 1
#     criterion = nn.MSELoss()
    pos_weight = torch.ones([trainlabels.shape[1]]) * positive_weight # weight assigned to positive labels 
    criterion  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    test_criterion = nn.BCELoss()
    sigmoid = nn.Sigmoid()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=lr_s_stepsize,factor=lr_s_gamma)

    epoch_loss        = np.full((epochs),np.nan)
    precision_history = np.full((epochs),np.nan)
    recall_history    = np.full((epochs),np.nan)
    for epoch in range(epoch_init,epoch_init + epochs):
        batch_init = 0; step_count = 1
        while batch_init < train_shape[0]:
            batch = np.arange(batch_init,min(batch_init+batchsize,train_shape[0]))
            imgbatch = trainimgs[batch,:,:,:] # maybe shuffling?
            batchlabels = trainlabels[batch,:]
            mask_pred   = net(imgbatch)
            train_loss  = criterion(mask_pred,batchlabels)
            batch_init += batchsize; step_count += 1
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            if (step_count%print_every)==0:
                with torch.no_grad():
                    net.eval()
                    mask_test = sigmoid_binarize(net(testimgs),threshold=threshold)
                    test_loss = test_criterion(mask_test,testlabels)
                    print('epoch {} global step {}: train batch loss {}, test loss {} '.format(epoch+1,step_count,train_loss.item(),test_loss.item()))
                    net.train()
        net.eval()
        mask_test = sigmoid_binarize(net(testimgs),threshold=threshold)
        test_loss = test_criterion(mask_test,testlabels)
        net.train()
        scheduler.step(test_loss)
        epoch_loss[epoch-epoch_init] = test_loss.item()
        precision_history[epoch-epoch_init] = precision_score(torch.flatten(testlabels),torch.flatten(mask_test))
        recall_history[epoch-epoch_init] = recall_score(torch.flatten(testlabels),torch.flatten(mask_test))
        print('\t epoch {} end: test loss {} '.format(epoch+1,test_loss.item()))
        print('\t epoch {} end: precision {} '.format(epoch+1,precision_history[epoch-epoch_init]))
        print('\t epoch {} end: recall    {} '.format(epoch+1,recall_history[epoch-epoch_init]))
        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                print('Created checkpoint directory')
            except OSError:
                pass
            torch.save({'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'threshold':threshold
                        }, dir_checkpoint + 'mnet.pth')
#                         }, dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            print(f'\t Checkpoint saved after epoch {epoch + 1}!')
    
            np.savez(dir_checkpoint+'epoch_loss.npz', loss=epoch_loss,precision=precision_history,recall=recall_history)

In [3]:
data_path = '/home/huangz78/data/'
data_imgs = np.load(data_path+'data_gt.npz')
data_labels = np.load(data_path+'data_gt_greedymask.npz')
print(data_imgs.files)
data    = data_imgs['imgdata']
labels  = data_labels['mask'].T
datashape = data.shape
print(datashape)
print(labels.shape)

['imgdata']
(320, 320, 199)
(199, 320)


#### data preparation

In [4]:
base = 24
mask = torch.tensor( mask_naiveRand(320,fix=base,other=0,roll=False)[0] ,dtype=torch.float )
data_under = np.zeros((datashape[2],datashape[0],datashape[1]))
for ind in range(data.shape[2]):
    img = data[:,:,ind]
    img = img/np.max(np.abs(img))
    yfull = F.fftn(torch.tensor(img,dtype=torch.float),dim=(0,1),norm='ortho')
    ypart = torch.tensordot(torch.diag(mask).to(torch.cfloat) , yfull,dims=([1],[0]))
    data_under[ind,:,:] = torch.abs(F.ifftn(ypart,dim=(0,1),norm='ortho'))

  mask = torch.tensor( mask_naiveRand(320,fix=base,other=0,roll=False)[0] ,dtype=torch.float )


In [5]:
from sklearn.model_selection import train_test_split
imgNum = 199
traininds, testinds = train_test_split(np.arange(imgNum),random_state=0,shuffle=True,train_size=round(imgNum*0.8))
test_total = testinds.size
traindata    = data_under[traininds,:,:]
trainlabels  = mask_filter(labels[traininds,:],base=base)
valdata      = data_under[testinds[0:test_total//2],:,:]
vallabels    = mask_filter(labels[testinds[0:test_total//2],:],base=base)

In [6]:
print(traindata.shape)
print(trainlabels.shape)
print(valdata.shape)
print(vallabels.shape)

(159, 320, 320)
(159, 296)
(20, 320, 320)
(20, 296)


In [7]:
sys.path.insert(0,'/home/huangz78/mri/mnet/')
import mnet
reload(mnet)
from mnet import MNet

#### training

In [8]:
# mnet = MNet(out_size=trainlabels.shape[1])
# checkpoint = torch.load('/home/huangz78/mri/checkpoints/mnet.pth')
# mnet.load_state_dict(checkpoint['model_state_dict'])
# print('mnet loaded successfully from : ' + '/home/huangz78/mri/checkpoints/mnet.pth' )
# mnet.train()
# # print(mnet)
trainMNet(traindata,trainlabels, valdata,vallabels,
          epochs=60, batchsize=5, \
          lr=1e-3, lr_weight_decay=0,opt_momentum=0,positive_weight=3,\
          lr_s_stepsize=2,lr_s_gamma=0.5,\
          threshold=.5, beta=1,save_cp=True)

epoch 1 global step 10: train batch loss 9.79659652709961, test loss 24.341217041015625 
epoch 1 global step 20: train batch loss 2.8813672065734863, test loss 24.594594955444336 
epoch 1 global step 30: train batch loss 7.909975528717041, test loss 22.027027130126953 
	 epoch 1 end: test loss 36.08108139038086 
	 epoch 1 end: precision 0.4134366925064599 
	 epoch 1 end: recall    0.8 
	 Checkpoint saved after epoch 1!
epoch 2 global step 10: train batch loss 1.4218230247497559, test loss 21.756755828857422 
epoch 2 global step 20: train batch loss 0.879325270652771, test loss 22.972972869873047 
epoch 2 global step 30: train batch loss 2.503659725189209, test loss 21.79054069519043 
	 epoch 2 end: test loss 25.405405044555664 
	 epoch 2 end: precision 0.5210711150131695 
	 epoch 2 end: recall    0.741875 
	 Checkpoint saved after epoch 2!
epoch 3 global step 10: train batch loss 0.9571739435195923, test loss 24.070945739746094 
epoch 3 global step 20: train batch loss 0.69599437713623

epoch 20 global step 10: train batch loss 0.5642018914222717, test loss 18.648649215698242 
epoch 20 global step 20: train batch loss 0.4912300407886505, test loss 18.530405044555664 
epoch 20 global step 30: train batch loss 0.491893470287323, test loss 19.138513565063477 
	 epoch 20 end: test loss 20.320945739746094 
	 epoch 20 end: precision 0.5988053758088602 
	 epoch 20 end: recall    0.751875 
	 Checkpoint saved after epoch 20!
epoch 21 global step 10: train batch loss 0.5627310276031494, test loss 18.4797306060791 
epoch 21 global step 20: train batch loss 0.48933014273643494, test loss 18.29391860961914 
epoch 21 global step 30: train batch loss 0.4920302927494049, test loss 19.1047306060791 
	 epoch 21 end: test loss 19.966217041015625 
	 epoch 21 end: precision 0.6047094188376754 
	 epoch 21 end: recall    0.754375 
	 Checkpoint saved after epoch 21!
epoch 22 global step 10: train batch loss 0.5616095066070557, test loss 18.41216278076172 
epoch 22 global step 20: train batch

	 Checkpoint saved after epoch 38!
epoch 39 global step 10: train batch loss 0.5390455722808838, test loss 18.429054260253906 
epoch 39 global step 20: train batch loss 0.45853179693222046, test loss 18.20945930480957 
epoch 39 global step 30: train batch loss 0.44844359159469604, test loss 18.716217041015625 
	 epoch 39 end: test loss 18.648649215698242 
	 epoch 39 end: precision 0.6319148936170212 
	 epoch 39 end: recall    0.7425 
	 Checkpoint saved after epoch 39!
epoch 40 global step 10: train batch loss 0.5369148254394531, test loss 18.597972869873047 
epoch 40 global step 20: train batch loss 0.45839521288871765, test loss 18.445945739746094 
epoch 40 global step 30: train batch loss 0.4429212510585785, test loss 18.58108139038086 
	 epoch 40 end: test loss 18.530405044555664 
	 epoch 40 end: precision 0.6344200962052379 
	 epoch 40 end: recall    0.741875 
	 Checkpoint saved after epoch 40!
epoch 41 global step 10: train batch loss 0.5370559692382812, test loss 18.5304050445556

	 Checkpoint saved after epoch 57!
epoch 58 global step 10: train batch loss 0.5332120060920715, test loss 18.513513565063477 
epoch 58 global step 20: train batch loss 0.451916366815567, test loss 18.344594955444336 
epoch 58 global step 30: train batch loss 0.43805569410324097, test loss 18.46283721923828 
	 epoch 58 end: test loss 18.4797306060791 
	 epoch 58 end: precision 0.636021505376344 
	 epoch 58 end: recall    0.739375 
	 Checkpoint saved after epoch 58!
epoch 59 global step 10: train batch loss 0.5330859422683716, test loss 18.513513565063477 
epoch 59 global step 20: train batch loss 0.4519413411617279, test loss 18.3952693939209 
epoch 59 global step 30: train batch loss 0.4373060166835785, test loss 18.46283721923828 
	 epoch 59 end: test loss 18.4797306060791 
	 epoch 59 end: precision 0.636021505376344 
	 epoch 59 end: recall    0.739375 
	 Checkpoint saved after epoch 59!
epoch 60 global step 10: train batch loss 0.5330671072006226, test loss 18.513513565063477 
epoch

	 Checkpoint saved after epoch 76!
epoch 77 global step 10: train batch loss 0.5328631401062012, test loss 18.530405044555664 
epoch 77 global step 20: train batch loss 0.45169079303741455, test loss 18.41216278076172 
epoch 77 global step 30: train batch loss 0.43692412972450256, test loss 18.41216278076172 
	 epoch 77 end: test loss 18.4797306060791 
	 epoch 77 end: precision 0.6361679224973089 
	 epoch 77 end: recall    0.73875 
	 Checkpoint saved after epoch 77!
epoch 78 global step 10: train batch loss 0.5328613519668579, test loss 18.530405044555664 
epoch 78 global step 20: train batch loss 0.4516866207122803, test loss 18.41216278076172 
epoch 78 global step 30: train batch loss 0.4369288980960846, test loss 18.41216278076172 
	 epoch 78 end: test loss 18.4797306060791 
	 epoch 78 end: precision 0.6361679224973089 
	 epoch 78 end: recall    0.73875 
	 Checkpoint saved after epoch 78!
epoch 79 global step 10: train batch loss 0.532859742641449, test loss 18.530405044555664 
epoc

	 Checkpoint saved after epoch 95!
epoch 96 global step 10: train batch loss 0.5328307151794434, test loss 18.54729652404785 
epoch 96 global step 20: train batch loss 0.4516112804412842, test loss 18.41216278076172 
epoch 96 global step 30: train batch loss 0.43700966238975525, test loss 18.41216278076172 
	 epoch 96 end: test loss 18.4797306060791 
	 epoch 96 end: precision 0.6361679224973089 
	 epoch 96 end: recall    0.73875 
	 Checkpoint saved after epoch 96!
epoch 97 global step 10: train batch loss 0.5328290462493896, test loss 18.54729652404785 
epoch 97 global step 20: train batch loss 0.4516071677207947, test loss 18.41216278076172 
epoch 97 global step 30: train batch loss 0.4370136260986328, test loss 18.41216278076172 
	 epoch 97 end: test loss 18.4797306060791 
	 epoch 97 end: precision 0.6361679224973089 
	 epoch 97 end: recall    0.73875 
	 Checkpoint saved after epoch 97!
epoch 98 global step 10: train batch loss 0.5328272581100464, test loss 18.54729652404785 
epoch 9

	 epoch 114 end: test loss 18.4797306060791 
	 epoch 114 end: precision 0.6361679224973089 
	 epoch 114 end: recall    0.73875 
	 Checkpoint saved after epoch 114!
epoch 115 global step 10: train batch loss 0.5327966809272766, test loss 18.54729652404785 
epoch 115 global step 20: train batch loss 0.4515336751937866, test loss 18.37837791442871 
epoch 115 global step 30: train batch loss 0.4370791018009186, test loss 18.429054260253906 
	 epoch 115 end: test loss 18.4797306060791 
	 epoch 115 end: precision 0.6361679224973089 
	 epoch 115 end: recall    0.73875 
	 Checkpoint saved after epoch 115!
epoch 116 global step 10: train batch loss 0.5327948927879333, test loss 18.54729652404785 
epoch 116 global step 20: train batch loss 0.4515294134616852, test loss 18.37837791442871 
epoch 116 global step 30: train batch loss 0.4370822012424469, test loss 18.429054260253906 
	 epoch 116 end: test loss 18.4797306060791 
	 epoch 116 end: precision 0.6361679224973089 
	 epoch 116 end: recall   

epoch 133 global step 20: train batch loss 0.4514618217945099, test loss 18.37837791442871 
epoch 133 global step 30: train batch loss 0.4371333420276642, test loss 18.429054260253906 
	 epoch 133 end: test loss 18.4797306060791 
	 epoch 133 end: precision 0.6361679224973089 
	 epoch 133 end: recall    0.73875 
	 Checkpoint saved after epoch 133!
epoch 134 global step 10: train batch loss 0.5327611565589905, test loss 18.54729652404785 
epoch 134 global step 20: train batch loss 0.45145782828330994, test loss 18.37837791442871 
epoch 134 global step 30: train batch loss 0.43713614344596863, test loss 18.429054260253906 
	 epoch 134 end: test loss 18.4797306060791 
	 epoch 134 end: precision 0.6361679224973089 
	 epoch 134 end: recall    0.73875 
	 Checkpoint saved after epoch 134!
epoch 135 global step 10: train batch loss 0.5327591300010681, test loss 18.54729652404785 
epoch 135 global step 20: train batch loss 0.45145383477211, test loss 18.37837791442871 
epoch 135 global step 30: 

	 Checkpoint saved after epoch 151!
epoch 152 global step 10: train batch loss 0.5327259302139282, test loss 18.54729652404785 
epoch 152 global step 20: train batch loss 0.4513866603374481, test loss 18.37837791442871 
epoch 152 global step 30: train batch loss 0.43718093633651733, test loss 18.429054260253906 
	 epoch 152 end: test loss 18.4797306060791 
	 epoch 152 end: precision 0.6361679224973089 
	 epoch 152 end: recall    0.73875 
	 Checkpoint saved after epoch 152!
epoch 153 global step 10: train batch loss 0.532724142074585, test loss 18.54729652404785 
epoch 153 global step 20: train batch loss 0.45138299465179443, test loss 18.37837791442871 
epoch 153 global step 30: train batch loss 0.437183141708374, test loss 18.429054260253906 
	 epoch 153 end: test loss 18.4797306060791 
	 epoch 153 end: precision 0.6361679224973089 
	 epoch 153 end: recall    0.73875 
	 Checkpoint saved after epoch 153!
epoch 154 global step 10: train batch loss 0.5327220559120178, test loss 18.547296

epoch 170 global step 30: train batch loss 0.4372166395187378, test loss 18.41216278076172 
	 epoch 170 end: test loss 18.46283721923828 
	 epoch 170 end: precision 0.6365105008077544 
	 epoch 170 end: recall    0.73875 
	 Checkpoint saved after epoch 170!
epoch 171 global step 10: train batch loss 0.5326883792877197, test loss 18.54729652404785 
epoch 171 global step 20: train batch loss 0.45131364464759827, test loss 18.361486434936523 
epoch 171 global step 30: train batch loss 0.43721824884414673, test loss 18.41216278076172 
	 epoch 171 end: test loss 18.46283721923828 
	 epoch 171 end: precision 0.6365105008077544 
	 epoch 171 end: recall    0.73875 
	 Checkpoint saved after epoch 171!
epoch 172 global step 10: train batch loss 0.5326865315437317, test loss 18.54729652404785 
epoch 172 global step 20: train batch loss 0.4513099491596222, test loss 18.361486434936523 
epoch 172 global step 30: train batch loss 0.4372200667858124, test loss 18.41216278076172 
	 epoch 172 end: test 

epoch 189 global step 10: train batch loss 0.5326519012451172, test loss 18.54729652404785 
epoch 189 global step 20: train batch loss 0.45124632120132446, test loss 18.344594955444336 
epoch 189 global step 30: train batch loss 0.4372458755970001, test loss 18.41216278076172 
	 epoch 189 end: test loss 18.46283721923828 
	 epoch 189 end: precision 0.6365105008077544 
	 epoch 189 end: recall    0.73875 
	 Checkpoint saved after epoch 189!
epoch 190 global step 10: train batch loss 0.5326499938964844, test loss 18.54729652404785 
epoch 190 global step 20: train batch loss 0.451242595911026, test loss 18.344594955444336 
epoch 190 global step 30: train batch loss 0.4372476041316986, test loss 18.41216278076172 
	 epoch 190 end: test loss 18.46283721923828 
	 epoch 190 end: precision 0.6365105008077544 
	 epoch 190 end: recall    0.73875 
	 Checkpoint saved after epoch 190!
epoch 191 global step 10: train batch loss 0.5326481461524963, test loss 18.54729652404785 
epoch 191 global step 20