In [1]:
import re
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch.hub import load_state_dict_from_url
from torch import Tensor
from typing import Any, List, Tuple
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import albumentations
import torch.utils.model_zoo as model_zoo
from torch.optim import Adam
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
from skimage.measure import label
from torch.optim import lr_scheduler


from PIL import Image
import timm

import cv2

In [2]:
torch.cuda.get_device_name(0)

'Tesla P100-PCIE-12GB'

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

In [4]:
df= pd.read_csv('/home/parkar.s/NIH_multilabel_classification/Data_Entry_2017.csv')

In [5]:
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
0,00000001_000.png,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143,
1,00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143,
2,00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
3,00000002_000.png,No Finding,0,2,81,M,PA,2500,2048,0.171,0.171,
4,00000003_000.png,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143,


In [6]:
df['View Position'].value_counts()

PA    67310
AP    44810
Name: View Position, dtype: int64

## Creating Dataset

In [7]:
N_CLASSES = 14
CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

In [8]:
file_path = "/scratch/parkar.s/NIH/images/"

In [9]:
train_list = []
targets = []
with open('/home/parkar.s/NIH_multilabel_classification/NIH_labels/train_list.txt') as fp:
    for line in fp:
        file, target = line.split(' ',1)
        a = os.path.join(file_path,file)
        train_list.append(a)
        targets.append([int(i) for i in target.strip().split()])

In [10]:
train_list[5], targets[5]

('/scratch/parkar.s/NIH/images/00023313_004.png',
 [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [11]:
class ImageDataset:
    def __init__(
        self,
        data_path,
        image_paths = [],
        targets = [],
        resize = None,
        augmentations=None,
        backend="cv2",
        channel_first= True

    ):
        """
        :param image_paths: list of paths to images
        :param targets: numpy array
        :param resize: tuple or None
        :param augmentations: albumentations augmentations
        """
        self.data_path = data_path
        self.resize = resize
        self.augmentations = augmentations
        self.backend = backend
        self.channel_first = channel_first
        
        self.image_paths = []
        self.targets = []
        with open(data_path) as fp:
            for line in fp:
                file, target = line.split(' ',1)
                a = os.path.join(file_path,file)
                self.image_paths.append(a)
                self.targets.append([int(i) for i in target.strip().split()])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        targets = self.targets[item]
        
        if self.backend == "cv2":
            
            #Load the image
            image = cv2.imread(self.image_paths[item],1)
            
            #Resize
            if self.resize is not None:
                image = cv2.resize(
                    image,
                    (self.resize[1], self.resize[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
            
            if self.augmentations is not None:
                augmented = self.augmentations(image=image)
                image = augmented["image"]
        
        else:
            raise Exception("Backend not implemented")
        
        # converting to pytorch image format & 2,0,1 because pytorch excepts image channel first then dimension of image
        if self.channel_first:
            image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        
        # finally returning image tensor and its image id
        return {
            "image": torch.tensor(image),
            "targets": torch.tensor(targets),
        }

In [12]:
# learning rate
LR_G = 1e-8
LR_L = 1e-8
LR_F = 1e-3
num_epochs = 50
BATCH_SIZE = 16


FLAGS = {
    'fold': 0,
    'model': 'resnet152d',
    'pretrained': True,
    'batch_size': 4,
    'num_workers': 4,
    'lr': 3e-4,
    'epochs': 10,
    'beta1': 0.9,
    'beta2': 0.999,
    'cuda': True
}

In [13]:
train_aug = albumentations.Compose(
    [       albumentations.Resize(224, 224),
            albumentations.CenterCrop(224, 224),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            )
    ]
)

In [14]:
train_path = '/home/parkar.s/NIH_multilabel_classification/NIH_labels/train_list.txt'

train_dataset = ImageDataset(
    data_path = train_path,
    resize= None,
    augmentations= train_aug,
)


'''
The drop_last=True parameter ignores the last batch 
(when the number of examples in your dataset is not divisible by your batch_size) 
while drop_last=False will make the last batch smaller than your batch_size 
'''

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=4, pin_memory=True, drop_last = True)



In [15]:
train_loader.__len__()

4904

In [16]:
val_path = '/home/parkar.s/NIH_multilabel_classification/NIH_labels/val_list.txt'

valid_dataset = ImageDataset(
    data_path = val_path,
    resize= None,
    augmentations= train_aug,
)

valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=4, pin_memory=True, drop_last = True)

In [17]:
valid_loader.__len__()

701

## Helper Functions

In [18]:
def plot_image(img_dict):
    image_tensor = img_dict["image"]
    target = img_dict["targets"]
    print(target)
    plt.figure(figsize=(5, 5))
    image = image_tensor.permute(1, 2, 0)
    
    '''
    Loading an RGB image will result in an image consisting of integer values
    Converting into tensor will convert them into float. Now when you try to display such an image, you ll get the following error-
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    
    So we first need to convert it back to uint8
    '''
    
    plt.imshow(image.numpy().astype('uint8'))

In [19]:
def compute_AUCs(gt, pred):
    """Computes Area Under the Curve (AUC) from prediction scores.
    Args:
        gt: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          true binary labels.
        pred: Pytorch tensor on GPU, shape = [n_samples, n_classes]
          can either be probability estimates of the positive class,
          confidence values, or binary decisions.
    Returns:
        List of AUROCs of all classes.
    """
    AUROCs = []
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    for i in range(N_CLASSES):
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs

In [20]:
def Attention_gen_patchs(ori_image, fm_cuda):
    # feature map -> feature mask (using feature map to crop on the original image) -> crop -> patchs
    feature_conv = fm_cuda.data.cpu().numpy()
    size_upsample = (224, 224) 
    bz, nc, h, w = feature_conv.shape

    patchs_cuda = torch.FloatTensor().cuda()

    for i in range(0, bz):
        feature = feature_conv[i]
        cam = feature.reshape((nc, h*w))
        cam = cam.sum(axis=0)
        cam = cam.reshape(h,w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)

        heatmap_bin = binImage(cv2.resize(cam_img, size_upsample))
        heatmap_maxconn = selectMaxConnect(heatmap_bin)
        heatmap_mask = heatmap_bin * heatmap_maxconn

        ind = np.argwhere(heatmap_mask != 0)
        minh = min(ind[:,0])
        minw = min(ind[:,1])
        maxh = max(ind[:,0])
        maxw = max(ind[:,1])
        
        # to ori image 
        image = ori_image[i].numpy().reshape(224,224,3)
        image = image[int(224*0.334):int(224*0.667),int(224*0.334):int(224*0.667),:]

        image = cv2.resize(image, size_upsample)
        image_crop = image[minh:maxh,minw:maxw,:] * 256 # because image was normalized before
        image_crop = preprocess(Image.fromarray(image_crop.astype('uint8')).convert('RGB')) 

        img_variable = torch.autograd.Variable(image_crop.reshape(3,224,224).unsqueeze(0).cuda())

        patchs_cuda = torch.cat((patchs_cuda,img_variable),0)

    return patchs_cuda

In [21]:
def binImage(heatmap):
    _, heatmap_bin = cv2.threshold(heatmap , 0 , 255 , cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    # t in the paper
    #_, heatmap_bin = cv2.threshold(heatmap , 178 , 255 , cv2.THRESH_BINARY)
    return heatmap_bin


def selectMaxConnect(heatmap):
    labeled_img, num = label(heatmap, connectivity=2, background=0, return_num=True)    
    max_label = 0
    max_num = 0
    for i in range(1, num+1):
        if np.sum(labeled_img == i) > max_num:
            max_num = np.sum(labeled_img == i)
            max_label = i
    lcc = (labeled_img == max_label)
    if max_num == 0:
        lcc = (labeled_img == -1)
    lcc = lcc + 0
    return lcc 

In [22]:
normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
   transforms.Resize((256,256)),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize,
])

## Global Only Resnet (Pretrained)

### Train and eval loops

In [23]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #device, will be different for each core on the TPU
epochs = FLAGS['epochs']
fold = FLAGS['fold']

In [24]:
def train_loop_fn(data_loader, loss_fn, model, optimizer, device, scheduler, epoch):
    model.train() # put model in training mode
    
    correct = 0
    total = 0
    running_loss = 0.0
    for index, d in enumerate(data_loader): # enumerate through the dataloader
        
        images = d['image'] # obtain the ids
        targets = d['targets'] # obtain the target
        if FLAGS['cuda']:
            images, targets = images.to(device), targets.to(device)

        # pass image to model
        
        # clear out the accumulated gradients
        optimizer.zero_grad()
        
        # make predictions
        outputs = model(images)
        
        targets = targets.to(torch.float32)
        
        # calculate loss
        loss = loss_fn(outputs, targets)
        
        if (index%500) == 0: 
            print('step: {} totalloss: {loss:.3f} '.format(index, loss = loss))

        # backpropagate
        loss.backward()
        
        # Performs parameter update
        optimizer.step()
        
        
        running_loss += loss.data.item()



        # Step the scheduler
        if scheduler is not None: 
            scheduler.step()
            
    epoch_loss = float(running_loss) / float(index)
    print(' Epoch over  Loss: {:.5f}'.format(epoch_loss))

     # put model in eval mode for later use
    
def eval_loop_fn(data_loader, loss_fn, model, device):
    
    #will notify all your layers that you are in eval mode,
    #that way, batchnorm or dropout layers will work in eval
    #mode instead of training mode.
    
    model.eval()
    
    fin_targets = []
    fin_outputs = []
    for bi, d in enumerate(data_loader): # enumerate through dataloader
        
        images = d['image'] # obtain the ids
        targets = d['targets']# # obtain the targets
        
        if FLAGS['cuda']:
            images, targets = images.to(device), targets.to(device)
            

        # pass image to model
        
        # no_grad impacts the autograd engine and deactivate it.
        # It will reduce memory usage and speed up computations
        # but you won’t be able to backprop
        
        with torch.no_grad(): 
            outputs = model(images)

        # Add the outputs and targets to a list 
        targets_np = targets.cpu().detach().numpy().tolist()
        outputs_np = outputs.cpu().detach().numpy().tolist()
        fin_targets.extend(targets_np)
        fin_outputs.extend(outputs_np)    
        del targets_np, outputs_np
        gc.collect() # delete for memory conservation
                
    o,t = np.array(fin_outputs), np.array(fin_targets)
    

    
    # calculate loss
    # loss = loss_fn(torch.tensor(o), t)
    
    AUROCs_g = compute_AUCs(torch.tensor(t), torch.tensor(o))
    AUROC_avg = np.array(AUROCs_g).mean()
    print('Global branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_g[i]))
    


### Models

In [25]:
# Using Ross Wightman's timm package
class TimmModels(nn.Module):
    def __init__(self, model_name,pretrained=True, num_classes=3):
        super(TimmModels, self).__init__()
        self.Sigmoid = nn.Sigmoid()
        self.classifier = nn.Linear(2048, num_classes)

        self.m = timm.create_model(model_name,pretrained=pretrained)
        model_list = list(self.m.children())
        model_list = model_list[:-2]
        model_list[-1][-1].act3 = nn.Identity()
        self.m = nn.Sequential(*model_list)
        
    def forward(self, image):
        features = self.m(image)
        out = F.relu(features, inplace=True)
        out_after_pooling = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out_after_pooling)
        out = self.Sigmoid(out)
        return out

In [26]:
class Fusion_Branch(nn.Module):
    def __init__(self, input_size, output_size):
        super(Fusion_Branch, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, global_pool, local_pool):
        #fusion = torch.cat((global_pool.unsqueeze(2), local_pool.unsqueeze(2)), 2).cuda()
        #fusion = fusion.max(2)[0]#.squeeze(2).cuda()
        #print(fusion.shape)
        fusion = torch.cat((global_pool,local_pool), 1).cuda()
        fusion_var = torch.autograd.Variable(fusion)
        x = self.fc(fusion_var)
        x = self.Sigmoid(x)

        return x

In [27]:
MX = TimmModels(FLAGS['model'],pretrained=FLAGS['pretrained'], num_classes=14)

### Training

In [33]:
model = MX.to(device) # put model onto the current GPU
loss_fn = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=FLAGS['lr']) # often a good idea to scale the learning rate by number of cores
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*FLAGS['epochs']) #let's use a scheduler

In [None]:


print(f'========== training fold {FLAGS["fold"]} for {FLAGS["epochs"]} epochs ==========')
for i in range(2):
    print(f'EPOCH {i}:')
    # train one epoch
    train_loop_fn(train_loader, loss_fn, model, optimizer, device, scheduler, i)

    # validation one epoch
    eval_loop_fn(valid_loader, loss_fn, model, device)

    gc.collect()

print('Saving model...')

torch.save(model.state_dict(), f'resnet_NIH_{FLAGS["epochs"]}_epochs_pretrained.pth')

### Testing

In [34]:
val_path = '/home/parkar.s/NIH_multilabel_classification/NIH_labels/test_list.txt'

test_dataset = ImageDataset(
    data_path = val_path,
    resize= None,
    augmentations= train_aug,
)


test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                         shuffle=True, num_workers=4, pin_memory=True, drop_last = True)

In [35]:
MX_test = TimmModels(FLAGS['model'],pretrained=FLAGS['pretrained'], num_classes=14)
model_test = MX_test.to(device)

In [36]:
model_path = 'resnet_NIH_10_epochs_pretrained.pth'


In [37]:
model_test.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [38]:
eval_loop_fn(test_loader, loss_fn, model_test, device)

Global branch: The average AUROC is 0.831
The AUROC of Atelectasis is 0.8146734892948528
The AUROC of Cardiomegaly is 0.907767266665094
The AUROC of Effusion is 0.8788024915771728
The AUROC of Infiltration is 0.7049163282638219
The AUROC of Mass is 0.8338879825915774
The AUROC of Nodule is 0.768601744856423
The AUROC of Pneumonia is 0.7596380247226246
The AUROC of Pneumothorax is 0.863845226467845
The AUROC of Consolidation is 0.7939528478960858
The AUROC of Edema is 0.8908690128611136
The AUROC of Emphysema is 0.9247012695891238
The AUROC of Fibrosis is 0.8347302780955623
The AUROC of Pleural_Thickening is 0.7795339818358678
The AUROC of Hernia is 0.8815617090963228


# Training with attention

## Flags

In [23]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #device, will be different for each core on the TPU
epochs = FLAGS['epochs']
fold = FLAGS['fold']

In [24]:
FLAGS = {
    'fold': 0,
    'model': 'resnet152d',
    'pretrained': True,
    'batch_size': 4,
    'num_workers': 4,
    'lr': 3e-6,
    'lr_f': 1e-4,
    'epochs': 10,
    'beta1': 0.9,
    'beta2': 0.999,
    'cuda': True
}

## Models

In [25]:
# Using Ross Wightman's timm package
class TimmModels(nn.Module):
    def __init__(self, model_name,pretrained=True, num_classes=3):
        super(TimmModels, self).__init__()
        self.Sigmoid = nn.Sigmoid()
        self.classifier = nn.Linear(2048, num_classes)

        self.m = timm.create_model(model_name,pretrained=pretrained)
        model_list = list(self.m.children())
        model_list = model_list[:-2]
        model_list[-1][-1].act3 = nn.Identity()
        self.m = nn.Sequential(*model_list)
        
    def forward(self, image):
        features = self.m(image)
        out = F.relu(features, inplace=True)
        out_after_pooling = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out_after_pooling)
        out = self.Sigmoid(out)
        return out, features, out_after_pooling

In [26]:
class Fusion_Branch(nn.Module):
    def __init__(self, input_size, output_size):
        super(Fusion_Branch, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.Sigmoid = nn.Sigmoid()

    def forward(self, global_pool, local_pool):
        #fusion = torch.cat((global_pool.unsqueeze(2), local_pool.unsqueeze(2)), 2).cuda()
        #fusion = fusion.max(2)[0]#.squeeze(2).cuda()
        #print(fusion.shape)
        fusion = torch.cat((global_pool,local_pool), 1).cuda()
        fusion_var = torch.autograd.Variable(fusion)
        x = self.fc(fusion_var)
        x = self.Sigmoid(x)

        return x

## Training and validation

In [27]:
def train_with_attn(data_loader, loss_fn, Global_Branch_model,Local_Branch_model, Fusion_Branch_model, 
                    optimizer_global, optimizer_local, optimizer_fusion, 
                    lr_scheduler_global, lr_scheduler_local, lr_scheduler_fusion, device, epoch):
    # put model in training mode
    Global_Branch_model.train()  #set model to training mode
    Local_Branch_model.train()
    Fusion_Branch_model.train()
    
    correct = 0
    total = 0
    running_loss = 0.0
    for index, d in enumerate(data_loader): # enumerate through the dataloader
        
        images = d['image'] # obtain the ids
        targets = d['targets'] # obtain the target
        if FLAGS['cuda']:
            images, targets = images.to(device), targets.to(device)

        # pass image to model
        
        # clear out the accumulated gradients
        
        optimizer_global.zero_grad()
        optimizer_local.zero_grad()
        optimizer_fusion.zero_grad()
        
        # make predictions

        output_global, fm_global, pool_global = Global_Branch_model(images)
        
        patchs_var = Attention_gen_patchs(images.cpu(),fm_global)
        
        output_local, _, pool_local = Local_Branch_model(patchs_var)

        output_fusion = Fusion_Branch_model(pool_global, pool_local)
        
        targets = targets.to(torch.float32)
        
        # calculate loss
        loss1 = loss_fn(output_global, targets)
        loss2 = loss_fn(output_local, targets)
        loss3 = loss_fn(output_fusion, targets)
        
        loss = loss1*0.8 + loss2*0.1 + loss3*0.1 

        if (index%500) == 0: 
            print('step: {} totalloss: {loss:.3f} loss1: {loss1:.3f} loss2: {loss2:.3f} loss3: {loss3:.3f}'.format(index, loss = loss, loss1 = loss1, loss2 = loss2, loss3 = loss3))

        # backpropagate
        loss.backward()
        
        # Performs parameter update
        optimizer_global.step()  
        optimizer_local.step()
        optimizer_fusion.step()
        
        running_loss += loss.data.item()

        # Step the scheduler
        lr_scheduler_global.step()
        lr_scheduler_local.step()
        lr_scheduler_fusion.step()
            
    epoch_loss = float(running_loss) / float(index)
    print(' Epoch over  Loss: {:.5f}'.format(epoch_loss))

    
     # put model in eval mode for later use

In [28]:
def eval_with_attn(data_loader, loss_fn, model_global, model_local, model_fusion, device):
    
    #will notify all your layers that you are in eval mode,
    #that way, batchnorm or dropout layers will work in eval
    #mode instead of training mode.
    
    model_global.eval()
    model_local.eval()
    model_fusion.eval()
    
    fin_targets = []
    global_outputs = []
    local_outputs = []
    fusion_outputs = []
    
    for bi, d in enumerate(data_loader): # enumerate through dataloader
        
        images = d['image'] # obtain the ids
        targets = d['targets']# # obtain the targets
        
        if FLAGS['cuda']:
            images, targets = images.to(device), targets.to(device)
            

        # pass image to model
        
        # no_grad impacts the autograd engine and deactivate it.
        # It will reduce memory usage and speed up computations
        # but you won’t be able to backprop
        
        with torch.no_grad(): 
            output_global, fm_global, pool_global = model_global(images)
        
            patchs_var = Attention_gen_patchs(images.cpu(),fm_global)

            output_local, _, pool_local = model_local(patchs_var)

            output_fusion = model_fusion(pool_global, pool_local)

        # Add the outputs and targets to a list 
        targets_np = targets.cpu().detach().numpy().tolist()
        outputs_g = output_global.cpu().detach().numpy().tolist()
        outputs_l = output_local.cpu().detach().numpy().tolist()
        outputs_f = output_fusion.cpu().detach().numpy().tolist()
        
        
        fin_targets.extend(targets_np) 
        
        global_outputs.extend(outputs_g)
        local_outputs.extend(outputs_l)
        fusion_outputs.extend(outputs_f)
        
        del targets_np, outputs_g, outputs_l, outputs_f
        gc.collect() # delete for memory conservation
                
    og, ol, of, t = np.array(global_outputs), np.array(local_outputs), np.array(fusion_outputs), np.array(fin_targets)
    

    
    # calculate loss
    # loss = loss_fn(torch.tensor(o), t)
    
    AUROCs_g = compute_AUCs(torch.tensor(t), torch.tensor(og))
    AUROC_avg = np.array(AUROCs_g).mean()
    print('Global branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_g[i]))
    
    AUROCs_l = compute_AUCs(torch.tensor(t), torch.tensor(ol))
    AUROC_avg = np.array(AUROCs_l).mean()
    print('\n')
    print('Local branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_l[i]))

    AUROCs_f = compute_AUCs(torch.tensor(t), torch.tensor(of))
    AUROC_avg = np.array(AUROCs_f).mean()
    print('\n')
    print('Fusion branch: The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs_f[i]))


In [29]:
MX_global = TimmModels(FLAGS['model'],pretrained=FLAGS['pretrained'], num_classes=14)
MX_local = TimmModels(FLAGS['model'],pretrained=FLAGS['pretrained'], num_classes=14)

Global_Branch_model = MX_global.to(device) # put model onto the current GPU
Local_Branch_model = MX_local.to(device) # put model onto the current GPU
Fusion_Branch_model = Fusion_Branch(input_size = 4096, output_size = N_CLASSES).to(device)

loss_fn = nn.BCELoss()

optimizer_global = Adam(Global_Branch_model.parameters(), lr=FLAGS['lr'])
optimizer_local = Adam(Local_Branch_model.parameters(), lr=FLAGS['lr'])
optimizer_fusion = Adam(Fusion_Branch_model.parameters(), lr=FLAGS['lr_f'])

#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*FLAGS['epochs']) #let's use a scheduler

lr_scheduler_global = lr_scheduler.StepLR(optimizer_global , step_size = 10, gamma = 1)
lr_scheduler_local = lr_scheduler.StepLR(optimizer_local , step_size = 10, gamma = 1)
lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion , step_size = 15, gamma = 0.1)

print(f'========== training fold {FLAGS["fold"]} for {FLAGS["epochs"]} epochs ==========')
for i in range(FLAGS['epochs']):
    print(f'EPOCH {i}:')
    # train one epoch
    train_with_attn(train_loader, loss_fn,  Global_Branch_model,Local_Branch_model, Fusion_Branch_model, optimizer_global, optimizer_local, optimizer_fusion, lr_scheduler_global, lr_scheduler_local, lr_scheduler_fusion, device,   i)

    # validation one epoch
    eval_with_attn(valid_loader, loss_fn, Global_Branch_model, Local_Branch_model, Fusion_Branch_model, device)

    gc.collect()
    
    torch.save(Global_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_'+str(i)+'_global.pth')
    torch.save(Local_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_'+str(i)+'_local.pth')
    torch.save(Fusion_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_'+str(i)+'_fusion.pth')

print('Saving model...')

torch.save(Global_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_global.pth')
torch.save(Local_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_local.pth')
torch.save(Fusion_Branch_model.state_dict(), f'densenet_attn_NIH_{FLAGS["epochs"]}_fusion.pth')

EPOCH 0:
step: 0 totalloss: 0.695 loss1: 0.693 loss2: 0.699 loss3: 0.702
step: 500 totalloss: 0.390 loss1: 0.395 loss2: 0.440 loss3: 0.293
step: 1000 totalloss: 0.190 loss1: 0.185 loss2: 0.232 loss3: 0.191
step: 1500 totalloss: 0.202 loss1: 0.201 loss2: 0.202 loss3: 0.210
step: 2000 totalloss: 0.145 loss1: 0.144 loss2: 0.137 loss3: 0.159
step: 2500 totalloss: 0.183 loss1: 0.180 loss2: 0.184 loss3: 0.205
step: 3000 totalloss: 0.149 loss1: 0.143 loss2: 0.167 loss3: 0.177
step: 3500 totalloss: 0.090 loss1: 0.087 loss2: 0.091 loss3: 0.106
step: 4000 totalloss: 0.127 loss1: 0.126 loss2: 0.123 loss3: 0.135
step: 4500 totalloss: 0.127 loss1: 0.120 loss2: 0.142 loss3: 0.170
 Epoch over  Loss: 0.22480
Global branch: The average AUROC is 0.691
The AUROC of Atelectasis is 0.6991853728396662
The AUROC of Cardiomegaly is 0.6433295483176501
The AUROC of Effusion is 0.7787464919144008
The AUROC of Infiltration is 0.6418663858282632
The AUROC of Mass is 0.6565093381172694
The AUROC of Nodule is 0.6291

KeyboardInterrupt: 