<a href="https://colab.research.google.com/github/sisifo3/P_T_3/blob/main/Densenet121_NIH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_001.tar.gz
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_002.tar.gz
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_003.tar.gz
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_004.tar.gz
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_005.tar.gz
!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_006.tar.gz


#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_007.tar.gz
#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_008.tar.gz
#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_009.tar.gz
#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_010.tar.gz
#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_011.tar.gz
#!tar -xf /content/drive/MyDrive/NIH_Dataset_total/images_012.tar.gz

In [1]:
pkl_dir_path             = 'pickles'
train_val_df_pkl_path    = 'train_val_df.pickle'
test_df_pkl_path         = 'test_df.pickle'
disease_classes_pkl_path = 'disease_classes.pickle'
models_dir               = 'models'

from torchvision import transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# transforms.RandomHorizontalFlip() not used because some disease might be more likely to the present in a specific lung (lelf/rigth)
transform = transforms.Compose([transforms.ToPILImage(), 
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    normalize])

In [2]:
import glob, os, sys, pdb, time
import pandas as pd
import numpy as np
import cv2
import pickle
from torch.utils.data import Dataset
from tqdm import tqdm
import torch

#import config 

def q(text = ''): # easy way to exiting the script. useful while debugging
    print('> ', text)
    sys.exit()

class XRaysTrainDataset(Dataset):
    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir

        self.transform = transform
        # print('self.data_dir: ', self.data_dir)

        # full dataframe including train_val and test set
        self.df = self.get_df()
        print('self.df.shape: {}'.format(self.df.shape))

        self.make_pkl_dir(pkl_dir_path)

        # get train_val_df
        if not os.path.exists(os.path.join(pkl_dir_path, train_val_df_pkl_path)):

            self.train_val_df = self.get_train_val_df()
            print('\nself.train_val_df.shape: {}'.format(self.train_val_df.shape))

            # pickle dump the train_val_df
            with open(os.path.join(pkl_dir_path, train_val_df_pkl_path), 'wb') as handle:
                pickle.dump(self.train_val_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
            print('{}: dumped'.format(train_val_df_pkl_path))
            
        else:
            # pickle load the train_val_df
            with open(os.path.join(pkl_dir_path, train_val_df_pkl_path), 'rb') as handle:
                self.train_val_df = pickle.load(handle)
            print('\n{}: loaded'.format(train_val_df_pkl_path))
            print('self.train_val_df.shape: {}'.format(self.train_val_df.shape))

        self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
    
        if not os.path.exists(os.path.join(pkl_dir_path, disease_classes_pkl_path)):
            # pickle dump the classes list
            with open(os.path.join(pkl_dir_path, disease_classes_pkl_path), 'wb') as handle:
                pickle.dump(self.all_classes, handle, protocol = pickle.HIGHEST_PROTOCOL)
                print('\n{}: dumped'.format(disease_classes_pkl_path))
        else:
            print('\n{}: already exists'.format(disease_classes_pkl_path))

        self.new_df = self.train_val_df.iloc[self.the_chosen, :] # this is the sampled train_val data
        print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))
            
    def resample(self):
        self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
        self.new_df = self.train_val_df.iloc[self.the_chosen, :]
        print('\nself.all_classes_dict: {}'.format(self.all_classes_dict))

    def make_pkl_dir(self, pkl_dir_path):
        if not os.path.exists(pkl_dir_path):
            os.mkdir(pkl_dir_path)

    def get_train_val_df(self):

        # get the list of train_val data 
        train_val_list = self.get_train_val_list()

        train_val_df = pd.DataFrame()
        print('\nbuilding train_val_df...')
        for i in tqdm(range(self.df.shape[0])):
            filename  = os.path.basename(self.df.iloc[i,0])
            # print('filename: ', filename)
            if filename in train_val_list:
                train_val_df = train_val_df.append(self.df.iloc[i:i+1, :])

        # print('train_val_df.shape: {}'.format(train_val_df.shape))

        return train_val_df

    def __getitem__(self, index):
        row = self.new_df.iloc[index, :]

        img = cv2.imread(row['image_links'])
        labels = str.split(row['Finding Labels'], '|')
        
        target = torch.zeros(len(self.all_classes))
        for lab in labels:
            lab_idx = self.all_classes.index(lab)
            target[lab_idx] = 1            
    
        if self.transform is not None:
            img = self.transform(img)
    
        return img, target
        
    def choose_the_indices(self):
        
        max_examples_per_class = 10000 # its the maximum number of examples that would be sampled in the training set for any class
        the_chosen = []
        all_classes = {}
        length = len(self.train_val_df)
        # for i in tqdm(range(len(merged_df))):
        print('\nSampling the huuuge training dataset')
        for i in tqdm(list(np.random.choice(range(length),length, replace = False))):
            
            temp = str.split(self.train_val_df.iloc[i, :]['Finding Labels'], '|')

            # special case of ultra minority hernia. we will use all the images with 'Hernia' tagged in them.
            if 'Hernia' in temp:
                the_chosen.append(i)
                for t in temp:
                    if t not in all_classes:
                        all_classes[t] = 1
                    else:
                        all_classes[t] += 1
                continue

            # choose if multiple labels
            if len(temp) > 1:
                bool_lis = [False]*len(temp)
                # check if any label crosses the upper limit
                for idx, t in enumerate(temp):
                    if t in all_classes:
                        if all_classes[t]< max_examples_per_class: # 500
                            bool_lis[idx] = True
                    else:
                        bool_lis[idx] = True
                # if all lables under upper limit, append
                if sum(bool_lis) == len(temp):                    
                    the_chosen.append(i)
                    # maintain count
                    for t in temp:
                        if t not in all_classes:
                            all_classes[t] = 1
                        else:
                            all_classes[t] += 1
            else:        # these are single label images
                for t in temp:
                    if t not in all_classes:
                        all_classes[t] = 1
                    else:
                        if all_classes[t] < max_examples_per_class: # 500
                            all_classes[t] += 1
                            the_chosen.append(i)

        # print('len(all_classes): ', len(all_classes))
        # print('all_classes: ', all_classes)
        # print('len(the_chosen): ', len(the_chosen))
        
        '''
        if len(the_chosen) != len(set(the_chosen)):
            print('\nGadbad !!!')
            print('and the difference is: ', len(the_chosen) - len(set(the_chosen)))
        else:
            print('\nGood')
        '''

        return the_chosen, sorted(list(all_classes)), all_classes
    
    def get_df(self):
        csv_path = os.path.join('/content/drive/MyDrive/NIH_Dataset_total/Data_Entry_2017_v2020.csv')
        print('\n{} found: {}'.format(csv_path, os.path.exists(csv_path)))
        
        all_xray_df = pd.read_csv(csv_path)

        df = pd.DataFrame()        
        df['image_links'] = [x for x in glob.glob(os.path.join('/content/images/*.png'))]

        df['Image Index'] = df['image_links'].apply(lambda x : x[len(x)-16:len(x)])
        merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
        merged_df = merged_df[['image_links','Finding Labels']]
        return merged_df
    
    def get_train_val_list(self):
        f = open(os.path.join('data', 'NIH Chest X-rays', '/content/drive/MyDrive/NIH_Dataset_total/train_val_list.txt'), 'r')
        train_val_list = str.split(f.read(), '\n')
        return train_val_list

    def __len__(self):
        return len(self.new_df)


# prepare the test dataset
class XRaysTestDataset(Dataset):
    def __init__(self, data_dir, transform = None):
        self.data_dir = data_dir
        self.transform = transform
        # print('self.data_dir: ', self.data_dir)

        # full dataframe including train_val and test set
        self.df = self.get_df()
        print('\nself.df.shape: {}'.format(self.df.shape))

        self.make_pkl_dir(pkl_dir_path)

        # loading the classes list
        with open(os.path.join(pkl_dir_path, disease_classes_pkl_path), 'rb') as handle:
            self.all_classes = pickle.load(handle)

        # get test_df
        if not os.path.exists(os.path.join(pkl_dir_path, test_df_pkl_path)):

            self.test_df = self.get_test_df()
            print('self.test_df.shape: ', self.test_df.shape)
            
            # pickle dump the test_df
            with open(os.path.join(pkl_dir_path, test_df_pkl_path), 'wb') as handle:
                pickle.dump(self.test_df, handle, protocol = pickle.HIGHEST_PROTOCOL)
            print('\n{}: dumped'.format(test_df_pkl_path))
        else:
            # pickle load the test_df
            with open(os.path.join(pkl_dir_path, test_df_pkl_path), 'rb') as handle:
                self.test_df = pickle.load(handle)
            print('\n{}: loaded'.format(test_df_pkl_path))
            print('self.test_df.shape: {}'.format(self.test_df.shape))

    def __getitem__(self, index):
        row = self.test_df.iloc[index, :]
        
        img = cv2.imread(row['image_links'])
        labels = str.split(row['Finding Labels'], '|')
        
        target = torch.zeros(len(self.all_classes))
        for lab in labels:
            lab_idx = self.all_classes.index(lab)
            target[lab_idx] = 1            
    
        if self.transform is not None:
            img = self.transform(img)
    
        return img, target

    def make_pkl_dir(self, pkl_dir_path):
        if not os.path.exists(pkl_dir_path):
            os.mkdir(pkl_dir_path)

    def get_df(self):
        csv_path = os.path.join('/content/drive/MyDrive/NIH_Dataset_total/Data_Entry_2017_v2020.csv')
        
        all_xray_df = pd.read_csv(csv_path)

        df = pd.DataFrame()        
        df['image_links'] = [x for x in glob.glob(os.path.join('/content/images/*.png'))]

        df['Image Index'] = df['image_links'].apply(lambda x : x[len(x)-16:len(x)])
        merged_df = df.merge(all_xray_df, how = 'inner', on = ['Image Index'])
        merged_df = merged_df[['image_links','Finding Labels']]
        return merged_df

    def get_test_df(self):

        # get the list of test data 
        test_list = self.get_test_list()

        test_df = pd.DataFrame()
        print('\nbuilding test_df...')
        for i in tqdm(range(self.df.shape[0])):
            filename  = os.path.basename(self.df.iloc[i,0])
            # print('filename: ', filename)
            if filename in test_list:
                test_df = test_df.append(self.df.iloc[i:i+1, :])
         
        print('test_df.shape: ', test_df.shape)

        return test_df

    def get_test_list(self):
        f = open( os.path.join('data', 'NIH Chest X-rays', '/content/drive/MyDrive/NIH_Dataset_total/test_list.txt'), 'r')
        test_list = str.split(f.read(), '\n')
        return test_list

    def __len__(self):
        return len(self.test_df)






'''
# prepare the test dataset
import random
class XRaysTestDataset2(Dataset):
    def __init__(self, test_data_dir, transform = None):
        self.test_data_dir = test_data_dir
        self.transform = transform
        self.data_list = self.get_data_list(self.test_data_dir)
        
        self.subset = self.data_list[:1000]
    def __getitem__(self, index):
        img_path = self.data_list[index]
        img = cv2.imread(img_path)
        
        if self.transform is not None:
            img = self.transform(img)
            
        return img_path
    
    def sample(self):
        random.shuffle(self.data_list)
        self.subset = self.data_list[:np.random.randint(500,700)]
    def __len__(self):
        return len(self.subset)
        
    def get_data_list(self, data_dir):
        data_list = []
        for path in glob.glob(data_dir + os.sep + '*'):
            data_list.append(path)
        return data_list
'''

"\n# prepare the test dataset\nimport random\nclass XRaysTestDataset2(Dataset):\n    def __init__(self, test_data_dir, transform = None):\n        self.test_data_dir = test_data_dir\n        self.transform = transform\n        self.data_list = self.get_data_list(self.test_data_dir)\n        \n        self.subset = self.data_list[:1000]\n    def __getitem__(self, index):\n        img_path = self.data_list[index]\n        img = cv2.imread(img_path)\n        \n        if self.transform is not None:\n            img = self.transform(img)\n            \n        return img_path\n    \n    def sample(self):\n        random.shuffle(self.data_list)\n        self.subset = self.data_list[:np.random.randint(500,700)]\n    def __len__(self):\n        return len(self.subset)\n        \n    def get_data_list(self, data_dir):\n        data_list = []\n        for path in glob.glob(data_dir + os.sep + '*'):\n            data_list.append(path)\n        return data_list\n"

In [3]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import sys, os, time, random, pdb
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch
import pickle
import tqdm, pdb
from sklearn.metrics import roc_auc_score

#import config

def get_roc_auc_score(y_true, y_probs):
    '''
    Uses roc_auc_score function from sklearn.metrics to calculate the micro ROC AUC score for a given y_true and y_probs.
    '''

    with open(os.path.join(pkl_dir_path, disease_classes_pkl_path), 'rb') as handle:
        all_classes = pickle.load(handle)
    
    NoFindingIndex = all_classes.index('No Finding')

    if True:
        print('\nNoFindingIndex: ', NoFindingIndex)
        print('y_true.shape, y_probs.shape ', y_true.shape, y_probs.shape)
        GT_and_probs = {'y_true': y_true, 'y_probs': y_probs}
        with open('GT_and_probs', 'wb') as handle:
            pickle.dump(GT_and_probs, handle, protocol = pickle.HIGHEST_PROTOCOL)

    class_roc_auc_list = []    
    useful_classes_roc_auc_list = []
    
    for i in range(y_true.shape[1]):
        class_roc_auc = roc_auc_score(y_true[:, i], y_probs[:, i])
        class_roc_auc_list.append(class_roc_auc)
        if i != NoFindingIndex:
            useful_classes_roc_auc_list.append(class_roc_auc)
    if True:
        print('\nclass_roc_auc_list: ', class_roc_auc_list)
        print('\nuseful_classes_roc_auc_list', useful_classes_roc_auc_list)

    return np.mean(np.array(useful_classes_roc_auc_list))

def make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name):
    '''
    This function makes the following 4 different plots-
    1. mean train loss VS number of epochs
    2. mean val   loss VS number of epochs
    3. batch train loss for all the training   batches VS number of batches
    4. batch val   loss for all the validation batches VS number of batches
    '''
    fig = plt.figure(figsize=(16,16))
    fig.suptitle('loss trends', fontsize=20)
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(222)
    ax3 = fig.add_subplot(223)
    ax4 = fig.add_subplot(224)

    ax1.title.set_text('epoch train loss VS #epochs')
    ax1.set_xlabel('#epochs')
    ax1.set_ylabel('epoch train loss')
    ax1.plot(epoch_train_loss)

    ax2.title.set_text('epoch val loss VS #epochs')
    ax2.set_xlabel('#epochs')
    ax2.set_ylabel('epoch val loss')
    ax2.plot(epoch_val_loss)

    ax3.title.set_text('batch train loss VS #batches')
    ax3.set_xlabel('#batches')
    ax3.set_ylabel('batch train loss')
    ax3.plot(total_train_loss_list)

    ax4.title.set_text('batch val loss VS #batches')
    ax4.set_xlabel('#batches')
    ax4.set_ylabel('batch val loss')
    ax4.plot(total_val_loss_list)
    
    plt.savefig(os.path.join(models_dir,'losses_{}.png'.format(save_name)))

def get_resampled_train_val_dataloaders(XRayTrain_dataset, transform, bs):
    '''
    Resamples the XRaysTrainDataset class object and returns a training and a validation dataloaders, by splitting the sampled dataset in 80-20 ratio.
    '''
    XRayTrain_dataset.resample()

    train_percentage = 0.8
    train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [int(len(XRayTrain_dataset)*train_percentage), len(XRayTrain_dataset)-int(len(XRayTrain_dataset)*train_percentage)])

    print('\n-----Resampled Dataset Information-----')
    print('num images in train_dataset   : {}'.format(len(train_dataset)))
    print('num images in val_dataset     : {}'.format(len(val_dataset)))
    print('---------------------------------------')

    # make dataloaders
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = bs, shuffle = True)
    val_loader   = torch.utils.data.DataLoader(val_dataset,   batch_size = bs, shuffle = not True)

    print('\n-----Resampled Batchloaders Information -----')
    print('num batches in train_loader: {}'.format(len(train_loader)))
    print('num batches in val_loader  : {}'.format(len(val_loader)))
    print('---------------------------------------------\n')

    return train_loader, val_loader
    
def train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval):
    '''
    Takes in the data from the 'train_loader', calculates the loss over it using the 'loss_fn' 
    and optimizes the 'model' using the 'optimizer'  
    
    Also prints the loss and the ROC AUC score for the batches, after every 'log_interval' batches. 
    '''
    model.train()
    
    running_train_loss = 0
    train_loss_list = []

    start_time = time.time()
    for batch_idx, (img, target) in enumerate(train_loader):
        # print(type(img), img.shape) # , np.unique(img))

        img = img.to(device)
        target = target.to(device)
        
        optimizer.zero_grad()    
        out = model(img)        
        loss = loss_fn(out, target)
        running_train_loss += loss.item()*img.shape[0]
        train_loss_list.append(loss.item())

        loss.backward()
        optimizer.step()
        
        if (batch_idx+1)%log_interval == 0:
            # batch metric evaluation
# #             out_detached = out.detach()
# #             batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
            # 'out' is a torch.Tensor and 'roc_auc_score' function first tries to convert it into a numpy array, but since 'out' has requires_grad = True, it throws an error
            # RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead. 
            # so we have to 'detach' the 'out' tensor and then convert it into a numpy array to avoid the error !  

            batch_time = time.time() - start_time
            m, s = divmod(batch_time, 60)
            print('Train Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(train_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
        
        start_time = time.time()
            
    return train_loss_list, running_train_loss/float(len(train_loader.dataset))

def val_epoch(device, val_loader, model, loss_fn, epochs_till_now = None, final_epoch = None, log_interval = 1, test_only = False):
    '''
    It essentially takes in the val_loader/test_loader, the model and the loss function and evaluates
    the loss and the ROC AUC score for all the data in the dataloader.
    
    It also prints the loss and the ROC AUC score for every 'log_interval'th batch, only when 'test_only' is False
    '''
    model.eval()

    running_val_loss = 0
    val_loss_list = []
    val_loader_examples_num = len(val_loader.dataset)

    probs = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
    gt    = np.zeros((val_loader_examples_num, 15), dtype = np.float32)
    k=0

    with torch.no_grad():
        batch_start_time = time.time()    
        for batch_idx, (img, target) in enumerate(val_loader):
            if test_only:
                per = ((batch_idx+1)/len(val_loader))*100
                a_, b_ = divmod(per, 1)
                print(f'{str(batch_idx+1).zfill(len(str(len(val_loader))))}/{str(len(val_loader)).zfill(len(str(len(val_loader))))} ({str(int(a_)).zfill(2)}.{str(int(100*b_)).zfill(2)} %)', end = '\r')
    #         print(type(img), img.shape) # , np.unique(img))

            img = img.to(device)
            target = target.to(device)    
    
            out = model(img)        
            loss = loss_fn(out, target)    
            running_val_loss += loss.item()*img.shape[0]
            val_loss_list.append(loss.item())

            # storing model predictions for metric evaluat`ion 
            probs[k: k + out.shape[0], :] = out.cpu()
            gt[   k: k + out.shape[0], :] = target.cpu()
            k += out.shape[0]

            if ((batch_idx+1)%log_interval == 0) and (not test_only): # only when ((batch_idx + 1) is divisible by log_interval) and (when test_only = False)
                # batch metric evaluation
#                 batch_roc_auc_score = get_roc_auc_score(target, out)

                batch_time = time.time() - batch_start_time
                m, s = divmod(batch_time, 60)
                print('Val Loss   for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(val_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
            
            batch_start_time = time.time()    
            
    # metric scenes
    roc_auc = get_roc_auc_score(gt, probs)

    return val_loss_list, running_val_loss/float(len(val_loader.dataset)), roc_auc

def fit(device, XRayTrain_dataset, train_loader, val_loader, test_loader, model,
                                         loss_fn, optimizer, losses_dict,
                                         epochs_till_now, epochs, 
                                         log_interval, save_interval, 
                                         lr, bs, stage, test_only = False):
    '''
    Trains or Tests the 'model' on the given 'train_loader', 'val_loader', 'test_loader' for 'epochs' number of epochs.
    If training ('test_only' = False), it saves the optimized 'model' and  the loss plots ,after every 'save_interval'th epoch.
    '''
    epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list = losses_dict['epoch_train_loss'], losses_dict['epoch_val_loss'], losses_dict['total_train_loss_list'], losses_dict['total_val_loss_list']

    final_epoch = epochs_till_now + epochs

    if test_only:
        print('\n======= Testing... =======\n')
        test_start_time = time.time()
        test_loss, mean_running_test_loss, test_roc_auc = val_epoch(device, test_loader, model, loss_fn, log_interval, test_only = test_only)
        total_test_time = time.time() - test_start_time
        m, s = divmod(total_test_time, 60)
        print('test_roc_auc: {} in {} mins {} secs'.format(test_roc_auc, int(m), int(s)))
        sys.exit()

    starting_epoch  = epochs_till_now
    print('\n======= Training after epoch #{}... =======\n'.format(epochs_till_now))

    # epoch_train_loss = []
    # epoch_val_loss = []
    
    # total_train_loss_list = []
    # total_val_loss_list = []

    for epoch in range(epochs):

        if starting_epoch != epochs_till_now:
            # resample the train_loader and val_loader
            train_loader, val_loader = get_resampled_train_val_dataloaders(XRayTrain_dataset, config.transform, bs = bs)

        epochs_till_now += 1
        print('============ EPOCH {}/{} ============'.format(epochs_till_now, final_epoch))
        epoch_start_time = time.time()
        
        print('TRAINING')
        train_loss, mean_running_train_loss  =  train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval)
        print('VALIDATION')
        val_loss, mean_running_val_loss, roc_auc     =  val_epoch(device, val_loader, model, loss_fn, epochs_till_now, final_epoch, log_interval)
        
        epoch_train_loss.append(mean_running_train_loss)
        epoch_val_loss.append(mean_running_val_loss)

        total_train_loss_list.extend(train_loss)
        total_val_loss_list.extend(val_loss)

        save_name = 'stage{}_{}_{}'.format(stage, str.split(str(lr), '.')[-1], str(epochs_till_now).zfill(2))

        # the follwoing piece of codw needs to be worked on !!! LATEST DEVELOPMENT TILL HERE
        if ((epoch+1)%save_interval == 0) or test_only:
            save_path = os.path.join(models_dir, '{}.pth'.format(save_name))
            
            torch.save({
            'epochs': epochs_till_now,
            'model': model, # it saves the whole model
            'losses_dict': {'epoch_train_loss': epoch_train_loss, 'epoch_val_loss': epoch_val_loss, 'total_train_loss_list': total_train_loss_list, 'total_val_loss_list': total_val_loss_list}
            }, save_path)
            
            print('\ncheckpoint {} saved'.format(save_path))

            make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name)
            print('loss plots saved !!!')

        print('\nTRAIN LOSS : {}'.format(mean_running_train_loss))
        print('VAL   LOSS : {}'.format(mean_running_val_loss))
        print('VAL ROC_AUC: {}'.format(roc_auc))

        total_epoch_time = time.time() - epoch_start_time
        m, s = divmod(total_epoch_time, 60)
        h, m = divmod(m, 60)
        print('\nEpoch {}/{} took {} h {} m'.format(epochs_till_now, final_epoch, int(h), int(m)))



'''   
def pred_n_write(test_loader, model, save_name):
    res = np.zeros((3000, 15), dtype = np.float32)
    k=0
    for batch_idx, img in tqdm.tqdm(enumerate(test_loader)):
        model.eval()
        with torch.no_grad():
            pred = torch.sigmoid(model(img))
            # print(k)
            res[k: k + pred.shape[0], :] = pred
            k += pred.shape[0]
            
    # write csv
    print('populating the csv')
    submit = pd.DataFrame()
    submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]
    with open('disease_classes.pickle', 'rb') as handle:
        disease_classes = pickle.load(handle)
    
    for idx, col in enumerate(disease_classes):
        if col == 'Hernia':
            submit['Hern'] = res[:, idx]
        elif col == 'Pleural_Thickening':
            submit['Pleural_thickening'] = res[:, idx]
        elif col == 'No Finding':
            submit['No_findings'] = res[:, idx]
        else:
            submit[col] = res[:, idx]
    rand_num = str(random.randint(1000, 9999))
    csv_name = '{}___{}.csv'.format(save_name, rand_num)
    submit.to_csv('res/' + csv_name, index = False)
    print('{} saved !'.format(csv_name))
'''

"   \ndef pred_n_write(test_loader, model, save_name):\n    res = np.zeros((3000, 15), dtype = np.float32)\n    k=0\n    for batch_idx, img in tqdm.tqdm(enumerate(test_loader)):\n        model.eval()\n        with torch.no_grad():\n            pred = torch.sigmoid(model(img))\n            # print(k)\n            res[k: k + pred.shape[0], :] = pred\n            k += pred.shape[0]\n            \n    # write csv\n    print('populating the csv')\n    submit = pd.DataFrame()\n    submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]\n    with open('disease_classes.pickle', 'rb') as handle:\n        disease_classes = pickle.load(handle)\n    \n    for idx, col in enumerate(disease_classes):\n        if col == 'Hernia':\n            submit['Hern'] = res[:, idx]\n        elif col == 'Pleural_Thickening':\n            submit['Pleural_thickening'] = res[:, idx]\n        elif col == 'No Finding':\n            submit['No_findings'] = res[:, idx]\n        else:\n    

In [4]:
###################Temporal model densenet121##################
#######################Using in CheXpert######################

import torch
import  torchvision
import torch.nn as nn
from torchsummary import summary

class DenseNet121(nn.Module):
    
    def __init__(self, out_size, test=False):
        super(DenseNet121, self).__init__()
        self.test = test
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, num_ftrs),
            nn.BatchNorm1d(num_ftrs),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_ftrs, num_ftrs//2),
            nn.BatchNorm1d(num_ftrs//2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(num_ftrs//2, out_size),
        )

    def forward(self, x):
        x = self.densenet121(x)

        if self.test:
            return torch.sigmoid(x)
        else:
            return x

In [5]:
import torch, sys, os, pdb
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    
    def __init__(self, device, gamma = 1.0):
        super(FocalLoss, self).__init__()
        self.device = device
        self.gamma = torch.tensor(gamma, dtype = torch.float32).to(device)
        self.eps = 1e-6
        
#         self.BCE_loss = nn.BCEWithLogitsLoss(reduction='none').to(device)
        
    def forward(self, input, target):
        
        BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none').to(self.device)
#         BCE_loss = self.BCE_loss(input, target)
        pt = torch.exp(-BCE_loss) # prevents nans when probability 0
        F_loss =  (1-pt)**self.gamma * BCE_loss
        
        return F_loss.mean() 

    # def forward(self, input, target):

    #     # input are not the probabilities, they are just the cnn out vector
    #     # input and target shape: (bs, n_classes)
    #     # sigmoid
    #     probs = torch.sigmoid(input)
    #     log_probs = -torch.log(probs)

    #     focal_loss = torch.sum(   torch.pow(1-probs + self.eps, self.gamma).mul(log_probs).mul(target)   , dim=1)
    #     # bce_loss = torch.sum(log_probs.mul(target), dim = 1)
        
    #     return focal_loss.mean() #, bce_loss

if __name__ == '__main__':
    inp = torch.tensor([[1., 0.95], 
                        [.9, 0.3], 
                        [0.6, 0.4]], requires_grad = True)
    target = torch.tensor([[1., 1], 
                        [1, 0], 
                        [0, 0]])

    print('inp\n',inp, '\n')
    print('target\n',target, '\n')

    print('inp.requires_grad:', inp.requires_grad, inp.shape)
    print('target.requires_grad:', target.requires_grad, target.shape)


    #loss = FocalLoss(device,gamma = 2)

    #focal_loss, bce_loss = loss(inp ,target)
    #print('\nbce_loss',bce_loss, '\n')
    #print('\nfocal_loss',focal_loss, '\n')

inp
 tensor([[1.0000, 0.9500],
        [0.9000, 0.3000],
        [0.6000, 0.4000]], requires_grad=True) 

target
 tensor([[1., 1.],
        [1., 0.],
        [0., 0.]]) 

inp.requires_grad: True torch.Size([3, 2])
target.requires_grad: False torch.Size([3, 2])


In [None]:
import argparse
import os, pdb, sys, glob, time
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2

import torch
import torch.nn as nn
import torchvision.models as models 

# import custom dataset classes
#from datasets import XRaysTrainDataset  
#from datasets import XRaysTestDataset

# import neccesary libraries for defining the optimizers
import torch.optim as optim

#from trainer import fit
#import config

def q(text = ''): # easy way to exiting the script. useful while debugging
    print('> ', text)
    sys.exit()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'\ndevice: {device}')
    
parser = argparse.ArgumentParser(description='Following are the arguments that can be passed form the terminal itself ! Cool huh ? :D')
parser.add_argument('--data_path', type = str, default = 'NIH Chest X-rays', help = 'This is the path of the training data')
parser.add_argument('--bs', type = int, default = 16, help = 'batch size')
parser.add_argument('--lr', type = float, default = 0.001, help = 'Learning Rate for the optimizer')
parser.add_argument('--stage', type = int, default = 1, help = 'Stage, it decides which layers of the Neural Net to train')
parser.add_argument('--loss_func', type = str, default = 'FocalLoss', choices = {'BCE', 'FocalLoss'}, help = 'loss function')
parser.add_argument('-r','--resume', action = 'store_false') # args.resume will return True if -r or --resume is used in the terminal
parser.add_argument('--ckpt', type = str, default = '/content/models/stage1_001_01.pth',help = 'Path of the ckeckpoint that you wnat to load')
parser.add_argument('-t','--test', action = 'store_true')   # args.test   will return True if -t or --test   is used in the terminal
parser.add_argument('-f')
args = parser.parse_args()

if args.resume and args.test: # what if --test is not defiend at all ? test case hai ye ek
    q('The flow of this code has been designed either to train the model or to test it.\nPlease choose either --resume or --test')

stage = args.stage
if not args.resume:
    print(f'\nOverwriting stage to 1, as the model training is being done from scratch')
    stage = 1
    
if args.test:
    print('TESTING THE MODEL')
else:
    if args.resume:
        print('RESUMING THE MODEL TRAINING')
    else:
        print('TRAINING THE MODEL FROM SCRATCH')

script_start_time = time.time() # tells the total run time of this script

# mention the path of the data
data_dir = os.path.join('data',args.data_path) # Data_Entry_2017.csv should be present in the mentioned path

# define a function to count the total number of trainable parameters
def count_parameters(model): 
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_parameters/1e6 # in terms of millions

# make the datasets
XRayTrain_dataset = XRaysTrainDataset(data_dir, transform = transform)
train_percentage = 0.8
train_dataset, val_dataset = torch.utils.data.random_split(XRayTrain_dataset, [int(len(XRayTrain_dataset)*train_percentage), len(XRayTrain_dataset)-int(len(XRayTrain_dataset)*train_percentage)])

XRayTest_dataset = XRaysTestDataset(data_dir, transform = transform)

print('\n-----Initial Dataset Information-----')
print('num images in train_dataset   : {}'.format(len(train_dataset)))
print('num images in val_dataset     : {}'.format(len(val_dataset)))
print('num images in XRayTest_dataset: {}'.format(len(XRayTest_dataset)))
print('-------------------------------------')

# make the dataloaders
batch_size = args.bs # 128 by default
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = not True)
test_loader = torch.utils.data.DataLoader(XRayTest_dataset, batch_size = batch_size, shuffle = not True)

print('\n-----Initial Batchloaders Information -----')
print('num batches in train_loader: {}'.format(len(train_loader)))
print('num batches in val_loader  : {}'.format(len(val_loader)))
print('num batches in test_loader : {}'.format(len(test_loader)))
print('-------------------------------------------')

# sanity check
if len(XRayTrain_dataset.all_classes) != 15: # 15 is the unique number of diseases in this dataset
    q('\nnumber of classes not equal to 15 !')

a,b = train_dataset[0]
print('\nwe are working with \nImages shape: {} and \nTarget shape: {}'.format( a.shape, b.shape))

# make models directory, where the models and the loss plots will be saved
if not os.path.exists(models_dir):
    os.mkdir(models_dir)

# define the loss function
if args.loss_func == 'FocalLoss': # by default
    #from losses import FocalLoss
    loss_fn = FocalLoss(device = device, gamma = 2.).to(device)
elif args.loss_func == 'BCE':
    loss_fn = nn.BCEWithLogitsLoss().to(device)

# define the learning rate
lr = args.lr

if not args.test: # training

    # initialize the model if not args.resume
    if not args.resume:
        print('\ntraining from scratch')
        #########################################3
        ######################
        # import pretrained model
        ####model = models.resnet50(pretrained=True) # pretrained = False bydefault
        # change the last linear layer
        ####num_ftrs = model.fc.in_features
        ####model.fc = nn.Linear(num_ftrs, len(XRayTrain_dataset.all_classes)) # 15 output classes 
        num_classes = (len(XRayTrain_dataset.all_classes))
        model = DenseNet121(num_classes)
        model = torch.nn.DataParallel(model)
        model.cuda()

        #checkpoint = '/content/drive/MyDrive/NIH_Dataset_total/weights_Densenet121/weightsDensenet_121_v13'  
        #checkpoint = torch.load(checkpoint)
        #model.load_state_dict(checkpoint['model_state_dict'])
        
        #############################3
        ##############################

        model.to(device)
        
        print('----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters(): # all requires_grad by default, are True initially
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters  
            if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True 
            else:
                param.requires_grad = False

        # since we are not resuming the training of the model
        epochs_till_now = 0

        # making empty lists to collect all the losses
        losses_dict = {'epoch_train_loss': [], 'epoch_val_loss': [], 'total_train_loss_list': [], 'total_val_loss_list': []}

    else:
        if args.ckpt == None:
            q('ERROR: Please select a valid checkpoint to resume from')
            
        print('\nckpt loaded: {}'.format(args.ckpt))
        ckpt = torch.load(os.path.join(models_dir, args.ckpt)) 

        # since we are resuming the training of the model
        epochs_till_now = ckpt['epochs']
        model = ckpt['model']
        model.to(device)
        
        # loading previous loss lists to collect future losses
        losses_dict = ckpt['losses_dict']

    # printing some hyperparameters
    print('\n> loss_fn: {}'.format(loss_fn))
    print('> epochs_till_now: {}'.format(epochs_till_now))
    print('> batch_size: {}'.format(batch_size))
    print('> stage: {}'.format(stage))
    print('> lr: {}'.format(lr))

else: # testing
    if args.ckpt == None:
        q('ERROR: Please select a checkpoint to load the testing model from')
        
    print('\ncheckpoint loaded: {}'.format(args.ckpt))
    ckpt = torch.load(os.path.join(models_dir, args.ckpt)) 

    # since we are resuming the training of the model
    epochs_till_now = ckpt['epochs']
    model = ckpt['model']
    
    # loading previous loss lists to collect future losses
    losses_dict = ckpt['losses_dict']

# make changes(freezing/unfreezing the model's layers) in the following, for training the model for different stages 
if (not args.test) and (args.resume):

    if stage == 1:

        print('\n----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters(): # all requires_grad by default, are True initially
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters  
            if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True 
            else:
                param.requires_grad = False

    elif stage == 2:

        print('\n----- STAGE 2 -----') # only training 'layer3', 'layer4' and 'fc'
        for name, param in model.named_parameters(): 
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters  
            if ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                param.requires_grad = True 
            else:
                param.requires_grad = False

    elif stage == 3:

        print('\n----- STAGE 3 -----') # only training  'layer4' and 'fc'
        for name, param in model.named_parameters(): 
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters  
            if ('layer4' in name) or ('fc' in name):
                param.requires_grad = True 
            else:
                param.requires_grad = False

    elif stage == 4:

        print('\n----- STAGE 4 -----') # only training 'fc'
        for name, param in model.named_parameters(): 
            # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters  
            if ('fc' in name):
                param.requires_grad = True 
            else:
                param.requires_grad = False


if not args.test:
    # checking the layers which are going to be trained (irrespective of args.resume)
    trainable_layers = []
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            layer_name = str.split(name, '.')[0]
            if layer_name not in trainable_layers: 
                trainable_layers.append(layer_name)
    print('\nfollowing are the trainable layers...')
    print(trainable_layers)

    print('\nwe have {} Million trainable parameters here in the {} model'.format(count_parameters(model), model.__class__.__name__))

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)

# make changes in the parameters of the following 'fit' function
fit(device, XRayTrain_dataset, train_loader, val_loader,    
                                        test_loader, model, loss_fn, 
                                        optimizer, losses_dict,
                                        epochs_till_now = epochs_till_now, epochs = 1,
                                        log_interval = 25, save_interval = 1,
                                        lr = lr, bs = batch_size, stage = stage,
                                        test_only = args.test)

script_time = time.time() - script_start_time
m, s = divmod(script_time, 60)
h, m = divmod(m, 60)
print('{} h {}m laga poore script me !'.format(int(h), int(m)))

# ''' 
# This is how the model is trained...
# ##### STAGE 1 ##### FocalLoss lr = 1e-5
# training layers = layer2, layer3, layer4, fc 
# epochs = 2
# ##### STAGE 2 ##### FocalLoss lr = 3e-4
# training layers = layer3, layer4, fc 
# epochs = 5
# ##### STAGE 3 ##### FocalLoss lr = 7e-4
# training layers = layer4, fc 
# epochs = 4
# ##### STAGE 4 ##### FocalLoss lr = 1e-3
# training layers = fc 
# epochs = 3
# '''


device: cuda
RESUMING THE MODEL TRAINING

/content/drive/MyDrive/NIH_Dataset_total/Data_Entry_2017_v2020.csv found: True
self.df.shape: (54999, 2)

train_val_df.pickle: loaded
self.train_val_df.shape: (44597, 2)

Sampling the huuuge training dataset


100%|██████████| 44597/44597 [00:07<00:00, 5952.16it/s]



disease_classes.pickle: already exists

self.all_classes_dict: {'Consolidation': 1455, 'Infiltration': 5800, 'Mass': 1743, 'No Finding': 10000, 'Atelectasis': 4082, 'Effusion': 4166, 'Emphysema': 762, 'Pneumothorax': 1282, 'Pneumonia': 421, 'Cardiomegaly': 863, 'Pleural_Thickening': 1089, 'Nodule': 2198, 'Edema': 642, 'Fibrosis': 792, 'Hernia': 79}

self.df.shape: (54999, 2)

test_df.pickle: loaded
self.test_df.shape: (10402, 2)

-----Initial Dataset Information-----
num images in train_dataset   : 21716
num images in val_dataset     : 5430
num images in XRayTest_dataset: 10402
-------------------------------------

-----Initial Batchloaders Information -----
num batches in train_loader: 1358
num batches in val_loader  : 340
num batches in test_loader : 651
-------------------------------------------

we are working with 
Images shape: torch.Size([3, 224, 224]) and 
Target shape: torch.Size([15])

ckpt loaded: /content/models/stage1_001_01.pth

> loss_fn: FocalLoss()
> epochs_till_now

In [1]:
!cp /content/models -r /content/drive/MyDrive/NIH_Dataset_total/folders/
!cp /content/pickles -r /content/drive/MyDrive/NIH_Dataset_total/folders/
!cp /content/GT_and_probs -r /content/drive/MyDrive/NIH_Dataset_total/folders/


#!cp /content/drive/MyDrive/NIH_Dataset_total/folders/models -r /content/ 
#!cp /content/drive/MyDrive/NIH_Dataset_total/folders/pickles -r /content/
#!cp /content/drive/MyDrive/NIH_Dataset_total/folders/GT_and_probs -r /content/

In [12]:
PATH = 'weightsDensenet_121_v1_2'
torch.save(model.state_dict(), PATH)