In [1]:
import sys
import datetime
import numpy as np
import nibabel as nib
import warnings
import ants
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score as auc
from sklearn.metrics import roc_curve as roc

# Global settings
# To pin GPUs, uncomment the next line; note it also needs `import os`,
# which the import cell above does not include.
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
DEBUG=True  # read by debug(); set to False to silence its output
np.set_printoptions(threshold=sys.maxsize)  # never truncate printed numpy arrays
warnings.filterwarnings('ignore', category=UserWarning)  # hides ALL UserWarnings for the rest of the kernel session
# Autosave the notebook every 60 seconds and render matplotlib figures inline.
%autosave 60
%matplotlib inline

In [None]:
class Record():
    
    ''' A record of an epoch.

    Accumulates per-batch targets, model scores and losses, then derives
    binary-classification metrics from them.

    Attributes:
        tar: 1-D int array of ground-truth labels (assumed 0/1, matching the
            labels=[0, 1] used below).
        out: 1-D array of model scores, thresholded in calculate().
        lss: 1-D float array of loss values; NaNs are ignored by loss().
        tn, fp, fn, tp: confusion-matrix counts, only set by calculate().
    '''
    
    def __init__(self):
        self.tar = np.array([]).astype(int)
        self.out = np.array([]).astype(np.float32)
        self.lss = np.array([]).astype(np.float32)
    
    def append(self, tar, out, lss):
        ''' Add one batch; inputs are flattened onto the running arrays. '''
        # np.append copies the whole array on each call, so the total cost grows
        # quadratically with the number of batches.
        self.tar = np.append(self.tar, tar)
        self.out = np.append(self.out, out)
        self.lss = np.append(self.lss, lss)
            
    def calculate(self, threshold=None):
        ''' Fill tn/fp/fn/tp by thresholding self.out.

        Args:
            threshold: score cut-off; out >= threshold is predicted positive.
                If None, the ROC threshold maximising tpr - fpr (Youden's J)
                is chosen from self.tar / self.out.

        Returns:
            The threshold actually applied, so a value picked on one Record
            (e.g. validation) can be passed to another.
        '''
        if threshold is None:
            # get the best threshold
            fpr, tpr, thresholds = roc(self.tar, self.out)
            threshold = thresholds[np.argmax(tpr - fpr)]
        pred = np.where(self.out >= threshold, 1, 0)
        # labels=[0, 1] keeps the matrix 2x2 even when a class is absent.
        (self.tn, self.fp, self.fn, self.tp) = confusion_matrix(self.tar, pred, labels=[0,1]).ravel()
        return threshold
    
    def result(self):
        ''' [loss, accuracy %, sensitivity %, specificity %, AUC %].

        calculate() must have been called first (it sets the counts).
        '''
        return np.array([self.loss(), self.acc()*100, self.sen()*100, self.spe()*100, self.auc()*100])

    @staticmethod
    def _ratio(num, den):
        ''' num / den, or NaN when den is 0 (the metric is undefined). '''
        return num / den if den else np.nan
        
    def loss(self):
        ''' Mean of the recorded losses, ignoring NaNs. '''
        return np.nanmean(self.lss)
    
    def acc(self):
        ''' Accuracy (tp + tn) / total; NaN if no samples were counted. '''
        return self._ratio(self.tp + self.tn, self.tp + self.tn + self.fp + self.fn)
    
    def sen(self):
        ''' Sensitivity tp / (tp + fn); NaN if there are no positives. '''
        return self._ratio(self.tp, self.tp + self.fn)
    
    def spe(self):
        ''' Specificity tn / (tn + fp); NaN if there are no negatives. '''
        return self._ratio(self.tn, self.tn + self.fp)
    
    def auc(self):
        ''' ROC AUC of self.out against self.tar.

        Returns NaN when tar holds fewer than two distinct labels (including
        an empty record), where ROC AUC is undefined.
        '''
        # roc_auc_score raises ValueError in that case, which would abort
        # result(); report NaN like the other undefined metrics instead.
        if np.unique(self.tar).size < 2:
            return np.nan
        # `auc` below is the module-level sklearn alias, not this method.
        return auc(self.tar, self.out, labels=[0,1])

In [1]:
def time(tag='', date=False, end='\n'):
    """Print the current local time followed by `tag`.

    Args:
        tag: text printed after the timestamp, separated by a space.
        date: if True, the timestamp is prefixed with the date (YYYY-MM-DD).
        end: line terminator passed straight through to print().
    """
    stamp_fmt = '%H:%M:%S'
    if date:
        stamp_fmt = '%Y-%m-%d ' + stamp_fmt
    now = datetime.datetime.now()
    print(now.strftime(stamp_fmt), tag, end=end)

In [None]:
def debug(s=''):
    """Print `s` (an empty line by default) only when the global DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(s)

In [None]:
def showMRI(img, slice_idx=48):
    """Show the 2-D slice img[0, :, slice_idx, :] in a new grayscale figure.

    Args:
        img: array indexable as img[0, :, slice_idx, :]; presumably a
            (channel, x, y, z) volume — TODO confirm axis order.
        slice_idx: index along the third axis; defaults to 48, the value
            that used to be hard-coded.

    The grey window is [0, img.max()], with the maximum taken over the whole
    array rather than only the displayed slice.
    """
    plt.figure()
    plt.imshow(img[0, :, slice_idx, :], cmap="gray", vmax=img.max(), vmin=0)

In [None]:
def showMRIs(imgs, slice_idx=48, max_imgs=6):
    """Show slice [0, :, slice_idx, :] of up to `max_imgs` arrays side by side.

    Args:
        imgs: sized sequence of arrays, each indexable as img[0, :, slice_idx, :].
        slice_idx: index along the third axis (previously hard-coded to 48).
        max_imgs: maximum number of panels drawn (previously hard-coded to 6).

    Each panel uses its own slice's maximum as the upper grey level.
    Draws nothing and returns immediately if there is nothing to show.
    """
    n = min(len(imgs), max_imgs)
    if n == 0:
        # Previously the figure size was computed with `// n`, which raised
        # ZeroDivisionError for an empty input.
        return
    # All panels share a single row of n columns; the old row-count
    # expression (n+(n-1))//n evaluated to 1 for every n >= 1.
    plt.figure(figsize=(20, 3))
    for i, img in zip(range(n), imgs):
        panel = img[0, :, slice_idx, :]
        plt.subplot(1, n, i + 1)
        plt.axis('off')
        plt.imshow(panel, cmap="gray", vmax=panel.max())
    plt.show()

In [None]:
def nii2img(nii):
    """Read the image file `nii` with ANTsPy and return its voxel data as a float32 numpy array."""
    # An earlier version loaded via nibabel instead: nib.load(nii).get_fdata()
    volume = ants.image_read(nii)
    return volume.numpy().astype(np.float32)