# Import

In [2]:
import os
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, auc
import numpy as np
import torch as t
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch import optim, nn
from PIL import  Image
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.data.sampler import  WeightedRandomSampler
import random
import glob
import pandas as pd

%matplotlib inline

# set_param

In [3]:
# ---- experiment switches -------------------------------------------------
inchan = 1  # set 1 to 10
CT_type = 'P'  # 'P' or 'CE'
seed = 1  # random seed
gpu = "cuda:1"
n_class = 2  # classes number
save_num = 0  # 5-fold cross-validation from 0 to 4
save_state = 1  # save the training state 0 or 1

# ---- model / optimisation settings ----------------------------------------
model_name = 'Xception'
opt_name = 'adam'
epochs = 50  # epoch number
size_in = 171
bs = 120  # training_batchsize
bs2 = 60  # testing_batchsize

# Output directory prefix; it encodes the fold and the switches above.
pre_dir_name = (
    f'CV_save/save_{save_num}/'
    f'CT_type:{CT_type}'
    f'-inchan:{inchan}'
    f'-n_classes:{n_class}'
    f'-seed:{seed}'
)

path_CV = 'data/Center1'

# Fold assignment table; the columns read here are 'fold_num', 'train' and 'ID'.
df_sample_list = pd.read_csv('CV_sample_list.csv')
_in_fold = df_sample_list['fold_num'] == save_num
_train_rows = _in_fold & (df_sample_list['train'] == 1)
_test_rows = _in_fold & (df_sample_list['train'] == 0)
train_ID_list = df_sample_list[_train_rows]['ID'].tolist()
test_ID_list = df_sample_list[_test_rows]['ID'].tolist()

# Use the configured GPU when CUDA is available, otherwise run on the CPU.
if t.cuda.is_available():
    dev = t.device(gpu)
else:
    dev = t.device("cpu")

loss_func = F.cross_entropy

# function

In [4]:
if inchan == 1 : 
    # NOTE(review): CTdataset below hands the transform a float32 array built
    # from 8-bit greyscale pixels. torchvision's ToTensor is documented to
    # rescale only uint8 input, so Normalize presumably sees values in 0-255
    # rather than 0-1 -- confirm this is intended before changing it.
    transform = T.Compose([
        T.ToTensor(), 
        T.Normalize(mean=[.5], std=[.5]) 
    ])

    # Same pipeline followed by an unconditional horizontal flip (p=1).
    mirror = T.Compose([
        T.ToTensor(), 
        T.Normalize(mean=[.5], std=[.5]), 
        T.RandomHorizontalFlip(p=1)
    ])

    class CTdataset(Dataset):
        """Single-channel image dataset built from a list of image paths.

        Each item is ``(data, label)``. ``label`` is 0 when the path contains
        the substring 'KF' and 1 otherwise. ``data`` is the greyscale image
        resized to a square, as a float32 array of shape (size, size, 1),
        passed through ``transform`` when one is given.

        Args:
            imgs_list: sequence of image file paths.
            transform: optional callable applied to the float32 array.
            size_in: side length of the resized image; when ``None`` the
                module-level ``size_in`` is used.
        """

        def __init__(self, imgs_list,transform = None, size_in = None):
            self.imgs_path = imgs_list
            self.transform = transform
            # Previously this argument was accepted but ignored.
            self.size_in = size_in

        def __getitem__(self, index):
            img_path = self.imgs_path[index]
            label = 0 if 'KF' in img_path else 1

            # ``size_in`` here is the module-level value (fallback only).
            side = size_in if self.size_in is None else self.size_in

            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name. The context manager closes the
            # underlying file as soon as the pixels have been read.
            with Image.open(img_path) as src:
                pil_img = src.convert('L').resize((side, side), Image.LANCZOS)
            image_data = np.asarray(pil_img, dtype=np.float32)[:, :, np.newaxis]

            if self.transform:
                data = self.transform(image_data)
            else:
                data = image_data
            return data, label

        def __len__(self):
            return len(self.imgs_path)

if inchan >= 2 : # 1.0 or 1.0CE
    # One 0.5 entry per channel, used for both mean and std.
    std_list = [0.5] * inchan

    # NOTE(review): CTdataset below hands the transform a float32 array built
    # from 8-bit greyscale pixels. torchvision's ToTensor is documented to
    # rescale only uint8 input, so Normalize presumably sees values in 0-255
    # rather than 0-1 -- confirm this is intended before changing it.
    transform = T.Compose([
        T.ToTensor(), 
        T.Normalize(mean=std_list, std=std_list) 
    ])

    # Same pipeline followed by an unconditional horizontal flip (p=1).
    mirror = T.Compose([
        T.ToTensor(), 
        T.Normalize(mean=std_list, std=std_list), 
        T.RandomHorizontalFlip(p=1)
    ])


    class CTdataset(Dataset):
        """Multi-channel dataset: each sample stacks ``inchan`` images.

        Every entry of ``imgs_list`` is a sequence of at least ``inchan``
        image paths. Each item is ``(data, label)``. ``label`` is 0 when the
        first path of the entry contains 'KF' and 1 otherwise. ``data`` is a
        float32 array of shape (size, size, inchan), one greyscale image per
        channel, passed through ``transform`` when one is given.

        Args:
            imgs_list: sequence of path sequences.
            transform: optional callable applied to the float32 array.
            size_in: side length of the resized images; when ``None`` the
                module-level ``size_in`` is used.
        """

        def __init__(self, imgs_list,transform = None, size_in = None):
            self.imgs_path = imgs_list
            self.transform = transform
            # Previously this argument was accepted but ignored.
            self.size_in = size_in

        def __getitem__(self, index):
            img_nX = self.imgs_path[index]
            label = 0 if 'KF' in img_nX[0] else 1

            # ``size_in`` here is the module-level value (fallback only).
            side = size_in if self.size_in is None else self.size_in

            channels = []
            for i in range(inchan):
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                # same filter under its current name. The context manager
                # closes the file once the pixels have been read.
                with Image.open(img_nX[i]) as src:
                    pil_img_n = src.convert('L').resize((side, side), Image.LANCZOS)
                channels.append(np.asarray(pil_img_n, dtype=np.float32))
            # Channels go on the last axis: (side, side, inchan).
            image_data = np.stack(channels, axis=-1)

            if self.transform:
                data = self.transform(image_data)
            else:
                data = image_data
            return data, label

        def __len__(self):
            return len(self.imgs_path)

def loss_batch(model, loss_func, xb, yb, opt=None):
    """Evaluate one batch and, when an optimizer is supplied, take one step.

    Args:
        model: callable mapping ``xb`` to per-class scores.
        loss_func: callable ``(scores, yb) -> scalar loss tensor``.
        xb: input batch.
        yb: target batch.
        opt: optional optimizer; when given, the loss is back-propagated,
            the optimizer is stepped and its gradients are cleared.

    Returns:
        ``(loss_value, preds)``: the loss as a Python float and the argmax
        over dimension 1 of the model output.
    """
    scores = model(xb)
    predicted = t.argmax(scores, dim=1)
    batch_loss = loss_func(scores, yb)

    training = opt is not None
    if training:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    return batch_loss.item(), predicted


def get_data(train_ds,bs,sampler=None):
    """Wrap ``train_ds`` in a DataLoader with batch size ``bs``.

    ``sampler`` is forwarded to the DataLoader only when it is truthy;
    otherwise the DataLoader's default sampling is used.
    """
    loader_kwargs = {'batch_size': bs}
    if sampler:
        loader_kwargs['sampler'] = sampler
    return DataLoader(train_ds, **loader_kwargs)

def get_ID_root_label(ID,path):
    """List the model inputs for one patient and derive the patient label.

    Reads the module-level ``CT_type`` ('P' -> sub-directory '1.0',
    'CE' -> sub-directory '1.0CE') and ``inchan``.

    Args:
        ID: patient directory name under ``path``.
        path: root directory that contains the patient directories.

    Returns:
        ``(ID_root_list, label)``. For ``inchan == 1`` ``ID_root_list`` is
        the sorted list of ``*.jpg`` paths. For ``inchan >= 2`` it is a list
        of sliding windows, each a list of ``inchan`` consecutive paths (it
        is empty when the patient has fewer than ``inchan`` images).
        ``label`` is 0 when the first image path contains 'KF', else 1.

    Raises:
        ValueError: ``CT_type`` is not 'P'/'CE' or ``inchan`` is below 1.
        FileNotFoundError: no ``*.jpg`` file exists for this patient.
    """
    sub_dirs = {'P': '1.0', 'CE': '1.0CE'}
    if CT_type not in sub_dirs:
        raise ValueError("CT_type must be 'P' or 'CE', got %r" % (CT_type,))
    if inchan < 1:
        raise ValueError("inchan must be >= 1, got %r" % (inchan,))

    path_ID = os.path.join(path,ID)
    ID_X_list = sorted(glob.glob(path_ID + '/' + sub_dirs[CT_type] + '/*.jpg'))
    if not ID_X_list:
        # Fail with the offending directory instead of a bare IndexError.
        raise FileNotFoundError(
            "no .jpg images found under %s/%s" % (path_ID, sub_dirs[CT_type]))

    if inchan == 1:
        ID_root_list = ID_X_list
    else:
        # Sliding window of `inchan` consecutive slices, stride 1.
        ID_root_list = [ID_X_list[j:j + inchan]
                        for j in range(len(ID_X_list) - inchan + 1)]

    label = 0 if 'KF' in ID_X_list[0] else 1
    return(ID_root_list, label)

def Find_Optimal_Cutoff(TPR, FPR, Thresholds):
    """Pick the ROC point that maximises ``TPR - FPR`` (Youden's J).

    Args:
        TPR: array of true-positive rates.
        FPR: array of false-positive rates, aligned with ``TPR``.
        Thresholds: thresholds aligned with ``TPR`` / ``FPR``.

    Returns:
        ``(optimal_threshold, [fpr, tpr])`` at the first index where
        ``TPR - FPR`` is largest.
    """
    best = np.argmax(TPR - FPR)
    return Thresholds[best], [FPR[best], TPR[best]]

# Xception

In [5]:
class depthwise_separable_conv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        nin: number of input channels (also the depthwise group count).
        nout: number of output channels of the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        padding: padding of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, nin, nout, kernel_size, padding, bias=False):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, padding=padding, groups=nin, bias=bias)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class Xception(nn.Module):
    """Xception-style classifier built from depthwise separable convolutions.

    The network is an entry flow (four stages, the last three with strided
    1x1 residual projections), one middle block with an identity residual,
    an exit flow, global average pooling, dropout and two linear layers.

    Args:
        num_classes: size of the output layer.
        inchan: number of input image channels.
        p_drop: dropout probability applied to the pooled features;
            ``None`` is treated as no dropout.
    """

    def __init__(self,num_classes= None ,inchan = None, p_drop = None):
        super(Xception, self).__init__()
        
        self.p_drop = p_drop
        
        # Entry Flow
        self.entry_flow_1 = nn.Sequential(
            nn.Conv2d(inchan, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        
        self.entry_flow_2 = nn.Sequential(
            depthwise_separable_conv(64, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            depthwise_separable_conv(128, 128, 3, 1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # Strided 1x1 projection so the shortcut matches the pooled branch.
        self.entry_flow_2_residual = nn.Conv2d(64, 128, kernel_size=1, stride=2, padding=0)
        
        self.entry_flow_3 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(128, 256, 3, 1),
            nn.BatchNorm2d(256),
            
            nn.ReLU(True),
            depthwise_separable_conv(256, 256, 3, 1),
            nn.BatchNorm2d(256),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.entry_flow_3_residual = nn.Conv2d(128, 256, kernel_size=1, stride=2, padding=0)
        
        self.entry_flow_4 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(256, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        self.entry_flow_4_residual = nn.Conv2d(256, 728, kernel_size=1, stride=2, padding=0)
        
        # Middle Flow
        self.middle_flow = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728)
        )
        
        # Exit Flow
        self.exit_flow_1 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(728, 728, 3, 1),
            nn.BatchNorm2d(728),
            
            nn.ReLU(True),
            depthwise_separable_conv(728, 1024, 3, 1),
            nn.BatchNorm2d(1024),
            
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.exit_flow_1_residual = nn.Conv2d(728, 1024, kernel_size=1, stride=2, padding=0)
        self.exit_flow_2 = nn.Sequential(
            depthwise_separable_conv(1024, 1536, 3, 1),
            nn.BatchNorm2d(1536),
            nn.ReLU(True),
            
            depthwise_separable_conv(1536, 2048, 3, 1),
            nn.BatchNorm2d(2048),
            nn.ReLU(True)
        )
        
        self.linear1 = nn.Linear(2048, 1024)
        self.linear2 = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for input ``x``."""
        entry_out1 = self.entry_flow_1(x)
        entry_out2 = self.entry_flow_2(entry_out1) + self.entry_flow_2_residual(entry_out1)
        entry_out3 = self.entry_flow_3(entry_out2) + self.entry_flow_3_residual(entry_out2)
        entry_out = self.entry_flow_4(entry_out3) + self.entry_flow_4_residual(entry_out3)
        
        middle_out = self.middle_flow(entry_out) + entry_out

        exit_out1 = self.exit_flow_1(middle_out) + self.exit_flow_1_residual(middle_out)
        exit_out2 = self.exit_flow_2(exit_out1)

        exit_avg_pool = F.adaptive_avg_pool2d(exit_out2, (1, 1))     
        exit_avg_pool_flat = exit_avg_pool.view(exit_avg_pool.size(0), -1)

        # A Dropout module built inside forward() is always in training mode,
        # so it kept dropping units under model.eval(). The functional form
        # follows this module's own train/eval state and adds no parameters,
        # so existing checkpoints still load.
        drop_p = 0.0 if self.p_drop is None else self.p_drop
        exit_avg_pool_flat = F.dropout(exit_avg_pool_flat, p=drop_p, training=self.training)
        output = self.linear1(exit_avg_pool_flat)
        output = self.linear2(output)
        
        return output
    
def xception(x1,x2,x3):
    """Build an Xception model.

    Args:
        x1: number of output classes.
        x2: number of input channels.
        x3: dropout probability.
    """
    model = Xception(num_classes=x1, inchan=x2, p_drop=x3)
    return model

# fit()

In [6]:
def _run_epoch(model, loss_func, data_dl, opt=None):
    """Run one pass over ``data_dl``; return (mean batch loss, mean batch accuracy).

    Batches are moved to the module-level ``dev``. When ``opt`` is given,
    ``loss_batch`` also performs an optimisation step per batch.
    """
    batch_losses = list()
    batch_accs = list()
    for xb, yb in data_dl:
        xb = xb.to(dev)
        yb = yb.to(dev)
        b_loss, b_preds = loss_batch(model, loss_func, xb, yb, opt)
        batch_losses.append(b_loss)
        batch_accs.append((b_preds == yb).float().mean().tolist())
    # NOTE: this is an unweighted mean of per-batch values, so a smaller
    # final batch counts as much as a full one.
    return np.array(batch_losses).mean(), np.array(batch_accs).mean()


def fit(epochs, model, loss_func, opt, train_dl, test_dl):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Relies on module-level state: ``dev``, ``save_state``, ``file_record``,
    ``state_name``, ``best_acc_name`` and the history lists ``epoch_n``,
    ``loss_training``, ``acc_training``, ``loss_testing``, ``acc_testing``
    (appended to in place). Per-epoch metrics are printed and appended to
    ``file_record``. When ``save_state == 1`` the latest weights are written
    to ``state_name`` every epoch and the weights with the highest test
    accuracy so far to ``best_acc_name``.

    Args:
        epochs: number of epochs to run.
        model: network to train.
        loss_func: callable ``(scores, targets) -> loss tensor``.
        opt: optimizer used for the training pass.
        train_dl: iterable of training batches ``(xb, yb)``.
        test_dl: iterable of evaluation batches ``(xb, yb)``.
    """
    best_acc = 0
    # Initialised up front so the summary below cannot hit an unbound name
    # when epochs == 0.
    best_acc_epoch = None
    
    for epoch in range(epochs):
        epoch_n.append(epoch)
        
        model.train()
        tr_loss, tr_acc = _run_epoch(model, loss_func, train_dl, opt)
        
        print('epoch:', epoch, ", trian_loss:",tr_loss,', trian_acc:',tr_acc )
        with open(file_record , "a") as f:
            f.write('epoch:'+str(epoch)+", trian_loss:"+str(tr_loss)+', trian_acc:'+str(tr_acc)+'\n')
        
        loss_training.append(tr_loss)
        acc_training.append(tr_acc)
        
        if save_state == 1:
            t.save(model.state_dict(), state_name)
        
        model.eval()
        with t.no_grad(): 
            te_loss, te_acc = _run_epoch(model, loss_func, test_dl)
            
        now_acc = te_acc
            
        print("test_loss:", te_loss ,", test_acc:", now_acc )
        with open(file_record , "a") as f:
            f.write("test_loss:"+str(te_loss)+", test_acc:"+str(now_acc)+'\n' )
            
        loss_testing.append(te_loss)
        acc_testing.append(te_acc)
            
        # The first epoch always becomes the baseline. The best epoch is
        # tracked regardless of save_state; only the checkpoint write is
        # conditional on it.
        if epoch == 0 or now_acc > best_acc:
            if save_state == 1:
                t.save(model.state_dict(), best_acc_name)
            best_acc = now_acc
            best_acc_epoch = epoch

    # Report this run's own best value rather than max() over the shared
    # history list, which may be empty or hold entries from an earlier run.
    print('best_acc_testing:',best_acc,', best_epoch:',best_acc_epoch)
    
    with open(file_record , "a") as f:
        f.write('best_epoch:'+str(best_acc_epoch)+'\n'+'best_acc_testing:'+str(best_acc)+'\n\n')

# start_fold_training

In [None]:
# `seed` is part of the output path but was never applied. Seed every RNG
# used below (list shuffling, weighted sampling, weight initialisation) so a
# run can be repeated.
random.seed(seed)
np.random.seed(seed)
t.manual_seed(seed)

path_save = (pre_dir_name + '/' + model_name +'-'+ opt_name)          
file_record = path_save+'/record-{0}-{1}.txt'.format(model_name, opt_name)
csv_name_path = path_save+ '/record-{0}-{1}-lines.csv'.format(model_name, opt_name)
state_name = path_save+ '/'+'state-{0}-{1}-done.pkl'.format(model_name, opt_name)
best_acc_name = path_save+ '/'+'best_acc-{0}-{1}-done.pkl'.format(model_name, opt_name)

print('path_save:',path_save)
print('file_record:',file_record)
print('csv_name_path:',csv_name_path)

# The record file and the CSV are written whatever save_state is, so the
# directory has to exist in every case.
if not os.path.exists(path_save):
    os.makedirs(path_save)
else:
    print('exist')


def _collect_samples(id_list):
    """Return the dataset entries for every patient ID in ``id_list``.

    For ``inchan == 1`` each entry is one image path. For ``inchan >= 2``
    each entry is a list of ``inchan`` consecutive paths (sliding window,
    stride 1). Images come from '<path_CV>/<ID>/1.0' when ``CT_type`` is
    'P' and from '<path_CV>/<ID>/1.0CE' when it is 'CE'.
    """
    sub_dirs = {'P': '1.0', 'CE': '1.0CE'}
    if CT_type not in sub_dirs:
        raise ValueError("CT_type must be 'P' or 'CE', got %r" % (CT_type,))
    samples = list()
    for ID in id_list:
        path_ID = os.path.join(path_CV,ID)
        ID_X_list = sorted(glob.glob(path_ID + '/' + sub_dirs[CT_type] + '/*.jpg'))
        if inchan == 1:
            samples.extend(ID_X_list)
        elif inchan >= 2:
            samples.extend(ID_X_list[j:j + inchan]
                           for j in range(len(ID_X_list) - inchan + 1))
    return samples


def _sample_label(sample):
    """Label rule used by CTdataset: 0 if the (first) path contains 'KF', else 1."""
    first_path = sample if isinstance(sample, str) else sample[0]
    return 0 if 'KF' in first_path else 1


train_imgs_list = _collect_samples(train_ID_list)
test_imgs_list = _collect_samples(test_ID_list)

random.shuffle(train_imgs_list)
random.shuffle(test_imgs_list)

train_ds1 = CTdataset(train_imgs_list,transform, size_in)
train_ds2 = CTdataset(train_imgs_list,mirror, size_in)
train_ds = train_ds1 + train_ds2
test_ds = CTdataset(test_imgs_list,transform, size_in)

N = len(train_ds)
# Labels depend only on the paths, so derive the weights from the path list
# instead of decoding every image once just to read its label. train_ds is
# the plain copy followed by the mirrored copy of the same list, hence "* 2".
weights = [4.0 if _sample_label(s) == 0 else 1 for s in train_imgs_list] * 2 # RCC_patients/fp_AML_patients ≈  4
sampler = WeightedRandomSampler(weights,num_samples= N ,replacement=True)
train_dl = get_data(train_ds, bs, sampler)
test_dl = get_data(test_ds, bs2)

model = xception(n_class, inchan, 0.2) # dropout = 0.2
model.to(dev)

opt = optim.Adam(model.parameters()) #default_lr = 0.001

# History lists that fit() appends to in place.
epoch_n = list()
loss_training = list()
loss_testing = list()
acc_training = list()
acc_testing = list()

fit(epochs, model, loss_func, opt, train_dl, test_dl)

df_fore = {'epoch_n':epoch_n,
           'loss_training':loss_training,
           'loss_testing':loss_testing,
           'acc_training':acc_training,
           'acc_testing':acc_testing}

df = pd.DataFrame(df_fore)
df.to_csv(csv_name_path, header=True, index=False)

# fold_best_testing_AUC

In [None]:
# Patient-level evaluation of the best checkpoint: every image (or window of
# `inchan` images) of a test patient is classified, the mean of the hard 0/1
# predictions is the patient score, and ROC/AUC is computed over patients.
model_val = xception(n_class,inchan,0)
# map_location lets a checkpoint saved on one device load onto `dev`.
model_val.load_state_dict(t.load(best_acc_name, map_location=dev))
model_val.to(dev)

model_val.eval()
with t.no_grad():
    acc_for_imgs = list()
    labels_test = list()
    preds_list = list()
    for ID in test_ID_list:
        score_pic = list()
        ID_root,label = get_ID_root_label(ID,path_CV)
        for root in ID_root:
            # One path per input channel: a single path when inchan == 1,
            # otherwise the first `inchan` paths of the window.
            slice_paths = [root] if inchan == 1 else [root[j] for j in range(inchan)]
            channels = list()
            for slice_path in slice_paths:
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                # same filter under its current name.
                with Image.open(slice_path) as src:
                    pil_img_n = src.convert('L').resize((size_in, size_in),Image.LANCZOS)
                channels.append(np.asarray(pil_img_n, dtype=np.float32))
            # Channels on the last axis: (size_in, size_in, inchan).
            image_data = np.stack(channels, axis=-1)
            img = transform(image_data).unsqueeze(0)
            img = img.to(dev)
            out= model_val(img)
            pred = t.argmax(out,dim=1).item()
            score_pic.append(pred)
            acc_for_imgs.append(1 if pred == int(label) else 0)

        if not score_pic:
            # Fewer than `inchan` images: the mean would be NaN and break
            # roc_curve, so leave this patient out and say so.
            print('skipped (no usable images):', ID)
            continue
        score_ID = np.array(score_pic).mean()
        labels_test.append(int(label))
        preds_list.append(score_ID)
        
    acc_img_level_testing = np.array(acc_for_imgs).mean()
    fpr, tpr, thresholds = roc_curve(labels_test, preds_list)
    roc_auc = auc(fpr, tpr)
    optimal_th, optimal_point = Find_Optimal_Cutoff(TPR=tpr, FPR=fpr, Thresholds=thresholds)
    
    print('Image_level_ACC:',acc_img_level_testing)
    print('Patient_level_AUC:',roc_auc,', fpr:',optimal_point[0], ', tpr:',optimal_point[1],', cut-off:',optimal_th)

    with open(file_record , "a") as f:
        f.write('Image_level_ACC:'+str(acc_img_level_testing)+'\n')
        f.write('Patient_level_AUC:'+str(roc_auc)+', fpr:'+str(optimal_point[0])+', tpr:'+str(optimal_point[1])+', cut-off:'+str(optimal_th)+'\n\n')