In [1]:
!pip install libauc==1.2.0
!pip install medmnist

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting libauc==1.2.0
  Downloading libauc-1.2.0-py3-none-any.whl (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.6/73.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: libauc
Successfully installed libauc-1.2.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting medmnist
  Downloading medmnist-2.2.1-py3-none-any.whl (21 kB)
Collecting fire
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116952 sha256=ed47ddc13bb414ea12539189a6fc18ec3

In [2]:
import libauc;
import numpy as np
import pandas as pd
from medmnist import VesselMNIST3D

from libauc.models import resnet18
from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.utils import ImbalancedDataGenerator
from libauc.sampler import DualSampler  # data resampling (for binary class)
from libauc.metrics import auc_roc_score
from sklearn.metrics import roc_auc_score

import random
import scipy
from scipy.ndimage import rotate
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import zoom

import torch 
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import warnings
warnings.filterwarnings("ignore")
SEED=123

  from scipy.ndimage.filters import gaussian_filter


In [3]:
'''
Adapted from kuangliu/pytorch-cifar .
'''

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # self.bn1 = nn.GroupNorm(num_groups=2, num_channels=planes)
        
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # self.bn2 = nn.GroupNorm(num_groups=2, num_channels=planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
                # nn.GroupNorm(num_groups=2, num_channels=self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # self.bn1 = nn.GroupNorm(num_groups=2, num_channels=planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # self.bn2 = nn.GroupNorm(num_groups=2, num_channels=planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)
        # self.bn3 = nn.GroupNorm(num_groups=2, num_channels=self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
                # nn.GroupNorm(num_groups=2, num_channels=self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # self.bn1 = nn.GroupNorm(num_groups=2, num_channels=64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # out = F.avg_pool2d(out, 4)
        # out = F.adaptive_avg_pool3d(out, output_size=4)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(in_channels, num_classes):
    return ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, num_classes=num_classes)


def ResNet50(in_channels, num_classes):
    return ResNet(Bottleneck, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)

In [4]:
train_npz=VesselMNIST3D(split="train", download=True)
val_npz=VesselMNIST3D(split="val", download=True)
test_npz=VesselMNIST3D(split="test", download=True)

Downloading https://zenodo.org/record/6496656/files/vesselmnist3d.npz?download=1 to /root/.medmnist/vesselmnist3d.npz


100%|██████████| 398373/398373 [00:01<00:00, 300186.87it/s]


Using downloaded and verified file: /root/.medmnist/vesselmnist3d.npz
Using downloaded and verified file: /root/.medmnist/vesselmnist3d.npz


In [5]:
train_npz

Dataset VesselMNIST3D (vesselmnist3d)
    Number of datapoints: 1335
    Root location: /root/.medmnist
    Split: train
    Task: binary-class
    Number of channels: 1
    Meaning of labels: {'0': 'vessel', '1': 'aneurysm'}
    Number of samples: {'train': 1335, 'val': 192, 'test': 382}
    Description: The VesselMNIST3D is based on an open-access 3D intracranial aneurysm dataset, IntrA, containing 103 3D models (meshes) of entire brain vessels collected by reconstructing MRA images. 1,694 healthy vessel segments and 215 aneurysm segments are generated automatically from the complete models. We fix the non-watertight mesh with PyMeshFix and voxelize the watertight mesh with trimesh into 28×28×28 voxels. We split the source dataset with a ratio of 7:1:2 into training, validation and test set.
    License: CC BY 4.0

In [6]:
def gaussian_blur_3d(img):
    random.seed(SEED)
    sigma = random.uniform(0.1,0.9)
    blurred = gaussian_filter(img, sigma=sigma)
    return blurred

In [7]:
def x_flip(img):
    random.seed(SEED)
    flipped = img[:, :, ::-1]
    return flipped

In [8]:
def y_flip(img):
    random.seed(SEED)
    flipped = img[:, ::-1, :]
    return flipped

In [9]:
def zoom_xy(img, min_zoom, max_zoom):
    random.seed(SEED)
    zoom_factor = random.uniform(min_zoom, max_zoom)
    h, w = img.shape[0], img.shape[1]

    # For multichannel images we don't want to apply the zoom factor to the RGB
    # dimension, so instead we create a tuple of zoom factors, one per array
    # dimension, with 1's for any trailing dimensions after the width and height.
    zoom_tuple = (1, zoom_factor, zoom_factor)

    # Zooming out
    if zoom_factor < 1:

        # Bounding box of the zoomed-out image within the output array
        zh = int(np.round(h * zoom_factor))
        zw = int(np.round(w * zoom_factor))
        top = (h - zh) // 2
        left = (w - zw) // 2

        # Zero-padding
        out = np.zeros_like(img)
        zoomed_img = zoom(img, zoom_tuple, order=0)
        #print(f"zoomed shape: {zoomed_img.shape}")
        #print(f"out shape:{out.shape}")
        #print(f"w:{w},h:{h},l:{left},t:{top},zw:{zw}, zh:{zh}")
        out[:, top:top+zh, left:left+zw] = zoomed_img

    # Zooming in
    elif zoom_factor > 1:

        # Bounding box of the zoomed-in region within the input array
        zh = int(np.ceil(h / zoom_factor))
        zw = int(np.ceil(w / zoom_factor))
        top = (h - zh) // 2
        left = (w - zw) // 2

        #out_template = np.zeros_like(img)
        out = zoom(img[:, top:top+zh, left:left+zw], zoom_tuple, order=0)
        #print(f"out shape:{out.shape}")
        #print(f"w:{w},h:{h},l:{left},t:{top},zw:{zw}, zh:{zh}")

        # `out` might still be slightly larger than `img` due to rounding, so
        # trim off any extra pixels at the edges
        trim_top = ((out.shape[1] - h) // 2)
        trim_left = ((out.shape[2] - w) // 2)
        #print(f"out shape before:{out.shape}")
        out = out[:, trim_top:trim_top+h, trim_left:trim_left+w]
        #print(f"out shape after:{out.shape}")
        #print(f"w:{w},h:{h},l:{left},trimtop:{trim_top},trimleft:{trim_left}")

    # If zoom_factor == 1, just return the input array
    else:
        out = img
    #print(out.shape)
    return out

In [10]:
def random_rotation_3d(img, min_angle, max_angle):
    """ Randomly rotate an image by a random angle (-max_angle, max_angle).

    Arguments:
    max_angle: `float`. The maximum rotation angle.

    Returns:
    rotated 3D image
    """
    random.seed(SEED)
    img_rot = np.zeros(img.shape)
    angle = random.uniform(min_angle, max_angle)
    if random.randint(1,100) > 50:
        #in half the cases, rotate left. in other half, rotate right.
        angle *= -1
        # Following lines would rotate on z and y axis as well, but not using them in this kernel
#        # rotate along z-axis
#        image2 = scipy.ndimage.interpolation.rotate(image1, angle, mode='nearest', axes=(0, 1), reshape=False)
#        # rotate along y-axis
#        image3 = scipy.ndimage.interpolation.rotate(image2, angle, mode='nearest', axes=(0, 2), reshape=False)

    # rotate along x-axis
    img_rot = scipy.ndimage.interpolation.rotate(img, angle, mode='nearest', axes=(1, 2), reshape=False)
    return img_rot.reshape(img.shape)

In [11]:
def img_augment_3d(X_train,y_train):
      my_img=X_train
      my_label=y_train
      for i in range(0,X_train.shape[0]):
        img=X_train[i]
        img1 = gaussian_blur_3d(img)
        my_img=np.append(my_img,np.expand_dims(img1,axis=0),axis=0)
      my_label=np.append(my_label,y_train,axis=0)
      print("done")
      for i in range(0,X_train.shape[0]):
        img=X_train[i]
        img1 = x_flip(img)
        my_img=np.append(my_img,np.expand_dims(img1,axis=0),axis=0)
      my_label=np.append(my_label,y_train,axis=0)
      print("done")
      for i in range(0,X_train.shape[0]):
        img=X_train[i]
        img1 = y_flip(img)
        my_img=np.append(my_img,np.expand_dims(img1,axis=0),axis=0)
      my_label=np.append(my_label,y_train,axis=0)

      for i in range(0,X_train.shape[0]):
        img=X_train[i]
        img1 = random_rotation_3d(img, 1, 10)
        my_img=np.append(my_img,np.expand_dims(img1,axis=0),axis=0)
      my_label=np.append(my_label,y_train,axis=0)

      for i in range(0,X_train.shape[0]):
        img=X_train[i]
        img1 = zoom_xy(img, 0.9, 1.1)
        my_img=np.append(my_img,np.expand_dims(img1,axis=0),axis=0)
      my_label=np.append(my_label,y_train,axis=0)
    
      return my_img,my_label

In [12]:
X_train=train_npz.imgs
y_train=train_npz.labels

X_val=val_npz.imgs
y_val=val_npz.labels

X_test=test_npz.imgs
y_test=test_npz.labels

In [13]:
print(X_train.shape)
print(y_train.shape)
print(X_val.shape)
print(y_val.shape)
print(X_test.shape)
print(y_test.shape)

(1335, 28, 28, 28)
(1335, 1)
(192, 28, 28, 28)
(192, 1)
(382, 28, 28, 28)
(382, 1)


In [14]:
# X_train,y_train=img_augment_3d(X_train,y_train)

In [15]:
class ImageDataset(Dataset):
    def __init__(self, images, targets, image_size=28, crop_size=26, mode='train'):
       self.images = images.astype(np.uint8)
       self.targets = targets
       self.mode = mode
       self.transform_train = transforms.Compose([                                               
                              transforms.ToTensor(),
                              # transforms.RandomCrop((crop_size, crop_size, crop_), padding=None),
                              transforms.RandomHorizontalFlip(),
                              # transforms.Resize((image_size, image_size, image_size)),
                              ])
       self.transform_test = transforms.Compose([
                             transforms.ToTensor(),
                            #  transforms.Resize((image_size, image_size, image_size)),
                              ])
       
       # for loss function
       self.pos_indices = np.flatnonzero(targets==1)
       self.pos_index_map = {}
       for i, idx in enumerate(self.pos_indices):
           self.pos_index_map[idx] = i

    def __len__(self):
        # print(len(self.images))
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # print("GetItem enter: {}".format(image.shape))
        # print("hello12")
        target = self.targets[idx]
        # image = Image.fromarray(image.astype('uint8'))
        # image = Image.fromarray(image.squeeze(), mode='L')
        # print("Image shape: {}".format(image.shape))
        # image = Image.fromarray(image.squeeze().astype('uint8'), mode='L')
        if self.mode == 'train':
            idx = self.pos_index_map[idx] if idx in self.pos_indices else -1
            image = self.transform_train(image)
            # print(type(image))
        else:
            image = self.transform_test(image)
            # print(type(image))
        return image, target, int(idx)



In [16]:
def set_all_seeds(SEED):
    # REPRODUCIBILITY
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [17]:

# HyperParameters
SEED = 123
batch_size = 64
total_epochs = 80
decay_epochs = [50, 75]

lr = 0.07
margin = 1.0
epoch_decay = 0.003 # refers gamma in the paper
weight_decay = 0.0001

# oversampling minority class, you can tune it in (0, 0.5]
# e.g., sampling_rate=0.2 is that num of positive samples in mini-batch is sampling_rate*batch_size=13
sampling_rate = 0.2


In [18]:
imratio = 0.3
generator = ImbalancedDataGenerator(shuffle=True, verbose=True, random_seed=0)

(train_images, train_labels) = generator.transform(X_train, y_train, imratio=imratio)
(eval_images, eval_labels) = generator.transform(X_val, y_val, imratio=imratio)
(test_images, test_labels) = generator.transform(X_test, y_test, imratio=0.5) 

print((train_images.shape))
print((eval_images.shape))
print((test_images.shape))
trainSet = ImageDataset(train_images, train_labels)
evalSet = ImageDataset(eval_images, eval_labels)
testSet = ImageDataset(test_images, test_labels, mode='test')

sampler = DualSampler(trainSet, batch_size, sampling_rate=sampling_rate)
trainloader = torch.utils.data.DataLoader(trainSet, batch_size=batch_size,  sampler=sampler,  shuffle=False,  num_workers=1)
evalloader = torch.utils.data.DataLoader(evalSet, batch_size=batch_size,  shuffle=False,  num_workers=1)
testloader = torch.utils.data.DataLoader(testSet , batch_size=batch_size, shuffle=False, num_workers=1)

#SAMPLES: [1335], POS:NEG: [150 : 1185], POS RATIO: 0.1124
#SAMPLES: [192], POS:NEG: [22 : 170], POS RATIO: 0.1146
#SAMPLES: [382], POS:NEG: [43 : 339], POS RATIO: 0.1126
(1335, 28, 28, 28)
(192, 28, 28, 28)
(382, 28, 28, 28)


In [19]:


# model
set_all_seeds(SEED)
# model = resnet18(pretrained=False, num_classes=1, last_activation=None) 
model = ResNet18(in_channels = 28, num_classes= 2)

model = model.cuda()

# You can also pass Loss.a, Loss.b, Loss.alpha to optimizer (for old version users)
loss_fn = AUCMLoss()
optimizer = PESG(model, 
                 loss_fn=loss_fn,
                 lr=lr, 
                 momentum=0.9,
                 margin=margin,
                 epoch_decay=epoch_decay, 
                 weight_decay=weight_decay)

# model

In [20]:
print(evalloader)

<torch.utils.data.dataloader.DataLoader object at 0x7f8f1b3c4e20>


In [22]:
# # training
# print ('Start Training')
# print ('-'*30)

# best_val_auc = 0 
# for epoch in range(total_epochs):
#     # if epoch % 10 == 0:
#     #     optimizer.update_regularizer(decay_factor=2)    

#     for idx, (data, targets, _) in enumerate(trainloader):
#       train_data, train_labels = data, targets
#       train_data, train_labels  = train_data.cuda(), train_labels.cuda()
#       y_pred = model(train_data)
#       y_pred = torch.sigmoid(y_pred)
#       loss = loss_fn(y_pred, train_labels)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()

#       # scheduler1.step(loss)
#       # torch.save(model.state_dict(), "Epoch: {}".format(epoch))
#       # print("Saving model for epoch number: {}".format(epoch))

        
#       # validation  
#       if idx % 20 == 0:
#          model.eval()
#          with torch.no_grad():    
#               test_pred = []
#               test_true = [] 
#               for jdx, (data, targets, _) in enumerate(testloader):
#                   test_data, test_labels = data, targets
#                   test_data = test_data.cuda()
#                   y_pred = model(test_data)
#                   y_pred = torch.sigmoid(y_pred)
#                   test_pred.append(y_pred.cpu().detach().numpy())
#                   test_true.append(test_labels.numpy())
            
#               test_true = np.concatenate(test_true)
#               test_pred = np.concatenate(test_pred)
#               # val_auc_mean = auc_roc_score(test_true, test_pred)
#               val_auc_mean = roc_auc_score(test_true, test_pred)[0]
#               print(val_auc_mean)
#               # val_auc_mean=val_auc_mean[0]
#               val_auc_mean=val_auc_mean

#               model.train()

#               if best_val_auc < val_auc_mean:
#                  best_val_auc = val_auc_mean
#                  torch.save(model.state_dict(), 'vessel_model.pt')

#               print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc))
    


Start Training
------------------------------


ValueError: ignored

In [26]:
print ('Start Training')
print ('-'*30)
test_best = 0
best_val_auc = 0 
train_list, test_list = [], []
for epoch in range(total_epochs):
    #if epoch in decay_epochs:
     #   optimizer.update_lr(decay_factor=10, coef_decay_factor=10)
            
    train_pred, train_true = [], []
    model.train() 
    for idx, (data, targets, index) in enumerate(trainloader):
        data, targets  = data.cuda(), targets.cuda()
        y_pred = model(data)
        y_prob = torch.sigmoid(y_pred)
        loss = loss_fn(y_prob, targets) # Notes: make index>0 for positive samples, and index<0 for negative samples
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_pred.append(y_prob.cpu().detach().numpy())
        train_true.append(targets.cpu().detach().numpy())

    # validation  
    model.eval()
    with torch.no_grad():    
        test_pred = []
        test_true = [] 
        for jdx, data in enumerate(testloader):
            test_data, test_labels, index = data
            test_data = test_data.cuda()
            y_pred = model(test_data)
            test_pred.append(y_pred.cpu().detach().numpy())
            test_true.append(test_labels.numpy())
      
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_auc_mean =  roc_auc_score(test_true[:,0], test_pred[:,0]) 
        model.train()

        if best_val_auc < val_auc_mean:
            best_val_auc = val_auc_mean
            torch.save(model.state_dict(), 'ce_pretrained_model.pth')

        print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc ))
    


Start Training
------------------------------
Epoch=0, BatchID=21, Val_AUC=0.8293, Best_Val_AUC=0.8293
Epoch=1, BatchID=21, Val_AUC=0.8033, Best_Val_AUC=0.8293
Epoch=2, BatchID=21, Val_AUC=0.8263, Best_Val_AUC=0.8293
Epoch=3, BatchID=21, Val_AUC=0.8567, Best_Val_AUC=0.8567
Epoch=4, BatchID=21, Val_AUC=0.8461, Best_Val_AUC=0.8567
Epoch=5, BatchID=21, Val_AUC=0.8638, Best_Val_AUC=0.8638
Epoch=6, BatchID=21, Val_AUC=0.8557, Best_Val_AUC=0.8638
Epoch=7, BatchID=21, Val_AUC=0.8641, Best_Val_AUC=0.8641
Epoch=8, BatchID=21, Val_AUC=0.8749, Best_Val_AUC=0.8749
Epoch=9, BatchID=21, Val_AUC=0.8666, Best_Val_AUC=0.8749
Epoch=10, BatchID=21, Val_AUC=0.8711, Best_Val_AUC=0.8749
Epoch=11, BatchID=21, Val_AUC=0.8725, Best_Val_AUC=0.8749
Epoch=12, BatchID=21, Val_AUC=0.8740, Best_Val_AUC=0.8749
Epoch=13, BatchID=21, Val_AUC=0.8705, Best_Val_AUC=0.8749
Epoch=14, BatchID=21, Val_AUC=0.8732, Best_Val_AUC=0.8749
Epoch=15, BatchID=21, Val_AUC=0.8728, Best_Val_AUC=0.8749
Epoch=16, BatchID=21, Val_AUC=0.8723

In [28]:
 # Testing
 ckpt =  torch.load("ce_pretrained_model.pth")
 model.load_state_dict(ckpt)
 model.eval()
 with torch.no_grad():    
      test_pred = []
      test_true = [] 
      for jdx, (data, targets, _) in enumerate(testloader):
          test_data, test_labels = data, targets
          test_data = test_data.cuda()
          y_pred = model(test_data)
          y_pred = torch.sigmoid(y_pred)
          test_pred.append(y_pred.cpu().detach().numpy())
          test_true.append(test_labels.numpy())

      test_true = np.concatenate(test_true)
      test_pred = np.concatenate(test_pred)
      # test_auc_mean = auc_roc_score(test_true, test_pred)[0]
      test_auc_mean =  roc_auc_score(test_true[:,0], test_pred[:,0])
      model.train()


      print ('Test result   ::::::::   Test_AUC = %.4f'%(test_auc_mean))

Test result   ::::::::   Test_AUC = 0.8764


In [30]:
from google.colab import files
files.download('ce_pretrained_model.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>