In [None]:
# !pip install numpy==1.15
!pip install pydicom
# !pip install pylibjpeg 
!pip install pylibjpeg-libjpeg
!pip install python-gdcm

# **This project uses two datasets: the first, CBIS-DDSM, contains JPEG images; the second, RSNA, contains DICOM (.dcm) images.**

In [None]:
# imports the torch_xla package
import torch_xla
torch_xla.__version__
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.pjrt_backend 
import torch_xla.experimental.pjrt as pjrt
import torch_xla.test.test_utils as test_utils

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

In [None]:
import pydicom
# import pylibjpeg
from pydicom import dcmread
import numpy as np
import pandas as pd
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
#import dicomsdl as dicoml
# import pydicom
import re
# from joblib import Parallel, delayed
from tqdm.notebook import tqdm
from multiprocessing import cpu_count
# from PIL import Image
import cv2
import gc
import glob
# import importlib
import os
# import joblib
import time
from torchvision.io import read_image
import torch.nn as nn
from torchvision.models import  densenet161, DenseNet161_Weights, efficientnet_v2_l, EfficientNet_V2_L_Weights, convnext_base, ConvNeXt_Base_Weights
from torchvision.models.convnext import CNBlock
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim

In [None]:
# Directory for model checkpoints (Kaggle's writable output folder). The trailing
# separator is kept so file names can be appended with plain string concatenation.
weight_path_save = os.path.join('/kaggle/working', '')

In [None]:
!printenv

In [None]:
# Run configuration shared by the data pipeline, model factory and training loop.
# NOTE: 'agumentation' is misspelled, but other cells read it under this exact key.
flags = {
    'number_epochs': 70,
    'model_name': 'convnext_base',
    'batch_size': 16,
    'learning_rate': 0.01,
    'weight_decay': 0.005,
    'num_workers': 8,
    'final_freeze': 17,          # passed to make_model as final_freeze_layer_order
    'img_standard_H': 850,       # resize target MyDataset uses for training images
    'img_standard_W': 512,
    'agumentation': {
        'img_H': 360,            # RandomResizedCrop output size / validation resize
        'img_W': 224,
        'scale_min': 0.3,
        'scale_max': 1.2,
        'ratio_min': 1.1,
        'ratio_max': 1.7,
        'alpha': 180.0,          # ElasticTransform alpha (float)
        'sigma': 8.0,            # ElasticTransform sigma (float)
        'mean': 0,               # Normalize mean
        'std': 0.4,              # Normalize std (float)
    },
    'use_crop': False,           # also include CBIS-DDSM cropped ROI images
}

# **Processing first dataset**

In [None]:
# Paths to the CBIS-DDSM metadata CSVs on the Kaggle input mount
cbis_csv_dir = '/kaggle/input/cbis-ddsm-breast-cancer-image-dataset/csv/'
infor_path = cbis_csv_dir + 'dicom_info.csv'
calc_case_train_path = cbis_csv_dir + 'calc_case_description_train_set.csv'
calc_case_test_path = cbis_csv_dir + 'calc_case_description_test_set.csv'
mass_case_train_path = cbis_csv_dir + 'mass_case_description_train_set.csv'
mass_case_test_path = cbis_csv_dir + 'mass_case_description_test_set.csv'

In [None]:
# Read the master DICOM index, then the calc/mass case descriptions (train split, then test split)
master = pd.read_csv(infor_path)
calc_case_train, mass_case_train = map(pd.read_csv, (calc_case_train_path, mass_case_train_path))
calc_case_test, mass_case_test = map(pd.read_csv, (calc_case_test_path, mass_case_test_path))

In [None]:
# Drop index rows whose Laterality is missing
master = master.dropna(axis=0, subset=['Laterality'])
master.head(10)

In [None]:
# Keep only rows whose BitsAllocated is 16
master = master.loc[master['BitsAllocated'].eq(16)]
master.head(10)

In [None]:
# Get patient id in master dataframe
def get_patient_id(row):
    """Split ``row['PatientID']`` into (patient id, tumor type, data type).

    The id is tokenised on '-' and '_'. Tokens 0 and 1 are returned as tumor type
    and data type; tokens 2 and 3 are re-joined with '_' to form the patient id,
    e.g. an id such as 'Mass-Training_P_00001_LEFT_CC' -> ('P_00001', 'Mass', 'Training').
    """
    parts = re.split("-|_", row['PatientID'])
    patient_id = "".join(parts[2:3] + ['_'] + parts[3:4])
    return (patient_id, parts[0], parts[1])

patient_columns = ['PatientID', 'tumor_type', 'data_type']
master[patient_columns] = master.apply(get_patient_id, axis=1, result_type='expand')
master[patient_columns]

In [None]:
# Edit file path of jpg image
def replace_file_path(row):
    """Point a CBIS-DDSM relative path (a string) at the Kaggle input mount.

    Every occurrence of 'CBIS-DDSM' is replaced with the dataset's mount directory.
    """
    kaggle_root = '/kaggle/input/cbis-ddsm-breast-cancer-image-dataset'
    return row.replace('CBIS-DDSM', kaggle_root)
    
master['image_path'] = master['image_path'].map(replace_file_path)
master['image_path']

In [None]:
# Keep full-field mammograms only (cropped ROI images are handled separately)
is_full_image = master['SeriesDescription'] == 'full mammogram images'
master_full = master[is_full_image]
master_full

In [None]:
# Keep cropped ROI images only (full mammograms are handled separately)
is_crop_image = master['SeriesDescription'] == 'cropped images'
master_crop = master[is_crop_image]
master_crop

In [None]:
# Keep only the columns needed to build join keys; renumber rows from 0
selected_columns = ['image_path', 'Laterality', 'PatientID', 'PatientOrientation', 'tumor_type', 'data_type']
master_full = master_full[selected_columns].reset_index(drop=True)
master_full.head(10)

In [None]:
# Keep only the columns needed to build join keys; renumber rows from 0
selected_columns = ['image_path', 'Laterality', 'PatientID', 'PatientOrientation', 'tumor_type', 'data_type']
master_crop = master_crop[selected_columns].reset_index(drop=True)
master_crop.head(10)

In [None]:
def make_key(curr_row):
    """Build the key that joins a master row to the case-description CSVs.

    Format: '<tumor_type>-<data_type>_<PatientID>_<RIGHT|LEFT>_<PatientOrientation>_<image id>',
    where the image id is the parent directory name of ``image_path`` and any
    Laterality other than 'R' maps to 'LEFT'.
    """
    image_dir = curr_row['image_path'].split("/")[-2]
    side = 'RIGHT' if curr_row['Laterality'] == 'R' else 'LEFT'
    prefix = curr_row['tumor_type'] + '-' + curr_row['data_type']
    return '_'.join([prefix, curr_row['PatientID'], side, curr_row['PatientOrientation'], image_dir])
###### make key: one join key per full-mammogram row and per cropped-image row
for image_table in (master_full, master_crop):
    image_table['match_key'] = image_table.apply(make_key, axis=1)

In [None]:
# n = 19
# print('ID: ',master['PatientID'][n], ',tumor: ', master['tumor_type'][n], master['image_path'][n])
# img = cv2.imread(master['image_path'][n])
# plt.imshow(img)
# plt.show()

In [None]:
# Mass case descriptions (train split): keep density, pathology and both file paths
mass_case_train = mass_case_train.loc[:, ['breast_density', 'pathology', 'image file path', 'cropped image file path']]
mass_case_train.head(5)

In [None]:
def make_keys_in_mass_case(curr_row):
    """Return (full-image key, cropped-image key) for one case-description row.

    Each key is '<first path component>_<parent directory of the file>'. For the
    cropped-image path the last two characters of the first component are dropped
    before joining.
    """
    full_parts = curr_row['image file path'].split("/")
    crop_parts = curr_row['cropped image file path'].split("/")
    full_key = full_parts[0] + '_' + full_parts[-2]
    crop_key = crop_parts[0][:-2] + '_' + crop_parts[-2]
    return full_key, crop_key

key_columns = ['match_key1', 'match_key2']
mass_case_train[key_columns] = mass_case_train.apply(make_keys_in_mass_case, axis=1, result_type='expand')
mass_case_train.head(5)

In [None]:
# Mass case descriptions (test split): same column selection and join keys as the train split
mass_case_test = mass_case_test.loc[:, ['breast_density', 'pathology', 'image file path', 'cropped image file path']]
mass_case_test[['match_key1', 'match_key2']] = mass_case_test.apply(make_keys_in_mass_case, axis=1, result_type='expand')
mass_case_test.head(5)

In [None]:
# Calc case descriptions (train split): note this CSV spells the density column 'breast density'
calc_case_train = calc_case_train.loc[:, ['breast density', 'pathology', 'image file path', 'cropped image file path']]
calc_case_train.head(5)

In [None]:
def make_key2_in_calc_case(curr_row):
    """Return (full-image key, cropped-image key) for a calcification case row.

    The calc and mass description CSVs use the same 'image file path' and
    'cropped image file path' columns, so this reuses make_keys_in_mass_case
    instead of keeping a second, identical copy of the key-building logic.
    """
    return make_keys_in_mass_case(curr_row)

calc_case_train[['match_key1', 'match_key2']] = calc_case_train.apply(make_key2_in_calc_case, axis=1, result_type='expand')
# Align the density column name with the mass CSVs
calc_case_train = calc_case_train.rename(columns={'breast density': 'breast_density'})
calc_case_train.head(5)

In [None]:
# Calc case descriptions (test split): keep the useful columns, align the density column
# name with the mass CSVs (renamed once, on a fresh frame rather than in place on a
# slice), then add the two join keys.
calc_case_test = (calc_case_test[['breast density', 'pathology', 'image file path', 'cropped image file path']]
                  .rename(columns={'breast density': 'breast_density'}))
calc_case_test[['match_key1', 'match_key2']] = calc_case_test.apply(make_key2_in_calc_case, axis=1, result_type='expand')
calc_case_test.head(5)

In [None]:
# Merge the case descriptions into the master tables: match_key1 joins full
# mammograms, match_key2 joins cropped ROI images.
# NOTE(review): when several description rows share a key, only the first row
# (and so its pathology) is kept. Confirm this is the intended label.
other_concat = pd.concat([mass_case_train, mass_case_test, calc_case_train, calc_case_test])

def _unique_on(frame, keep_key, drop_key):
    """Drop column ``drop_key``, de-duplicate on ``keep_key`` (first row wins), renumber from 0."""
    return (frame.drop(columns=drop_key)
                 .drop_duplicates(subset=keep_key)
                 .reset_index(drop=True))

other_concat_key1 = _unique_on(other_concat, 'match_key1', 'match_key2')
print("length of other key1: ", len(other_concat_key1))
other_concat_key2 = _unique_on(other_concat, 'match_key2', 'match_key1')
print("length of other key2: ", len(other_concat_key2))

master_full = master_full.merge(other_concat_key1, how='left', left_on='match_key', right_on='match_key1',
                                copy=False, validate='1:m')
master_crop = master_crop.merge(other_concat_key2, how='left', left_on='match_key', right_on='match_key2',
                                validate='1:m')

In [None]:
# Inspect the cropped-image table after merging in the case descriptions
master_crop

In [None]:
# Spot-check one cropped image together with its patient id and tumor type
n = 100
sample_row = master_crop.loc[n]
print('ID: ', sample_row['PatientID'], ',tumor: ', sample_row['tumor_type'], sample_row['image_path'])
img = cv2.imread(sample_row['image_path'])

plt.imshow(img)
plt.show()

In [None]:
# Tag full mammograms (the next cell tags them again after selecting its final columns)
master_full = master_full.assign(image_type='full')
master_full

In [None]:
# Assemble dataset 1: same columns for full and cropped images, tagged by image_type.
# .assign returns a fresh frame, avoiding column assignment on a slice.
final_columns = ['image_path', 'Laterality', 'PatientID', 'PatientOrientation',
                 'tumor_type', 'data_type', 'breast_density', 'pathology']
master_full = master_full[final_columns].assign(image_type='full')
master_crop = master_crop[final_columns].assign(image_type='crop')

# final dataframe master: cropped ROI images are included only when flags['use_crop'] is set
if flags['use_crop']:
    final_master_dataset_1 = pd.concat([master_full, master_crop])
else:
    final_master_dataset_1 = master_full

# Reproducible shuffle. DataFrame.sample returns a new frame, so the result must be
# assigned (previously it was discarded and the shuffle had no effect).
final_master_dataset_1 = final_master_dataset_1.sample(frac=1.0, axis=0, random_state=1000)
final_master_dataset_1

In [None]:
# Make label base benign and malignant
def create_label(value):
    """Map a CBIS-DDSM pathology value to a binary label.

    Returns 1 for 'MALIGNANT', 0 for 'BENIGN' or 'BENIGN_WITHOUT_CALLBACK', and
    None for anything else (e.g. a pathology left missing by the left merge).
    NOTE(review): a None label becomes NaN, which the dataset-integration cell
    appears to fill with 0, i.e. it would be treated as benign.
    """
    if value == 'MALIGNANT':
        return 1
    elif value in ('BENIGN', 'BENIGN_WITHOUT_CALLBACK'):
        return 0
    return None

final_master_dataset_1['label'] = final_master_dataset_1['pathology'].map(create_label)
final_master_dataset_1

In [None]:
# Split dataset 1 by its data_type column ('Training' / 'Test')
def _rows_with_data_type(frame, data_type):
    """Rows of ``frame`` whose data_type equals ``data_type``, renumbered from 0."""
    return frame[frame['data_type'] == data_type].reset_index(drop=True)

final_train_dataset_1 = _rows_with_data_type(final_master_dataset_1, 'Training')
final_test_dataset_1 = _rows_with_data_type(final_master_dataset_1, 'Test')
print("Length of training set of dataset 1: ", len(final_train_dataset_1))
print("Length of test set of dataset 1: ", len(final_test_dataset_1))

In [None]:
# Inspect the dataset-1 training split
final_train_dataset_1

In [None]:
# # unit test dataset 1
# from pydicom import dcmread
# import numpy as np
# idx = 2
# sample_path = final_train_dataset_1.at[idx, "image_path"]
# print(sample_path,", view: ", final_train_dataset_1.at[idx, "Laterality"])

# tensor = read_image(sample_path)
# # transform instance
# aug = transforms.Compose([transforms.RandomResizedCrop(size = (flags['img_H'],flags['img_W']), scale=(0.3, 1.2), ratio=(1.1,1.7) ,antialias = True),
#                           transforms.ElasticTransform(alpha = 150.0, sigma = 8.0),
#                           transforms.RandomAutocontrast(),
#                           transforms.Normalize((0), (0.4))
#                           ])
# tensor = aug(tensor.float())
# plt.imshow(tensor.numpy().squeeze(), cmap = 'gray')
# plt.show()

# **Processing second dataset**

The RSNA test data is not used because it does not contain labels.

In [None]:
# RSNA training metadata (dataset 2): shuffle reproducibly, then keep the first 6000 rows
rsna_subset_size = 6000
meta_dataset_2 = (pd.read_csv("/kaggle/input/rsna-breast-cancer-detection/train.csv")
                  .sample(frac=1.0, random_state=300)
                  .head(rsna_subset_size))
print("Lenght of dataset 2: ", len(meta_dataset_2))

In [None]:
def make_full_path(current_row):
    """Return the RSNA DICOM path '<train_images>/<patient_id>/<image_id>.dcm' for a row."""
    rsna_image_root = "/kaggle/input/rsna-breast-cancer-detection/train_images/"
    return f"{rsna_image_root}{current_row['patient_id']}/{current_row['image_id']}.dcm"
    
# Attach each row's DICOM path, then keep only the columns used downstream
meta_dataset_2["full_path"] = meta_dataset_2.apply(make_full_path, axis=1)
meta_dataset_2 = meta_dataset_2.loc[:, ["full_path", "cancer", "laterality", "view"]]

In [None]:
# Horizontal crop offsets (start_width, end_width) by laterality. MyDataset keeps columns
# [start_width, W - end_width), so:
#   'L' -> (0, default_offset): trim default_offset columns from the right edge
#   'R' -> (default_offset, 0): trim default_offset columns from the left edge
def find_offset(curr_row, default_offset: int = 800):
    """Return the (start_width, end_width) offsets for a row, or None for an unknown laterality."""
    side = curr_row["laterality"]
    if side not in ("L", "R"):
        return None
    return (0, default_offset) if side == "L" else (default_offset, 0)
    
offset_columns = ["start_width", "end_width"]
meta_dataset_2[offset_columns] = meta_dataset_2.apply(find_offset, axis=1, result_type='expand')
meta_dataset_2.head()

In [None]:
# Every RSNA row must have crop offsets. find_offset returns None for a laterality other
# than 'L'/'R'; such gaps would later be filled with 0 (i.e. silently not cropped), so
# fail loudly here instead of only displaying the flag.
missing_offsets = meta_dataset_2[["start_width", "end_width"]].isna().values.any()
assert not missing_offsets, "Some RSNA rows have no crop offsets (unexpected laterality)"
missing_offsets

In [None]:
# Split dataset 2 (already shuffled when loaded) into a first half for training and a
# last half for testing; with an odd length the middle row is left out of both.
half = len(meta_dataset_2) // 2
meta_dataset_2_train = meta_dataset_2.iloc[:half]
meta_dataset_2_test = meta_dataset_2.iloc[len(meta_dataset_2) - half:]

Integrate the dataframe of dataset 2 into the dataframe of dataset 1 with the *.concat()* method.

In [None]:
# Inspect the dataset-2 training half
meta_dataset_2_train

In [None]:
# Combine both sources into one training frame and one test frame.
# Columns present in only one source become NaN and are filled with 0; MyDataset relies
# on image_path == 0 to recognise dataset-2 (DICOM) rows.
# NOTE(review): this also turns a missing dataset-1 'label' into 0 (benign) — confirm.
def _integrate(dataset_1_part, dataset_2_part, seed):
    """Concatenate both sources, shuffle with ``seed``, fill gaps with 0 and renumber rows 0..n-1.

    ``drop=True`` avoids inserting a stale 'index' column holding pre-shuffle positions,
    which the previous ``reset_index(inplace=True)`` added.
    """
    return (pd.concat([dataset_1_part, dataset_2_part], ignore_index=True)
              .sample(frac=1.0, random_state=seed)
              .fillna(0)
              .reset_index(drop=True))

integrated_train_df = _integrate(final_train_dataset_1, meta_dataset_2_train, seed=500)
integrated_test_df = _integrate(final_test_dataset_1, meta_dataset_2_test, seed=300)

In [None]:
# integrated_train_df = integrated_train_df.head(200)

In [None]:
# # unit test dataset 2
# from pydicom import dcmread
# import numpy as np
# idx = 48946
# sample_path = meta_dataset_2_train.at[idx, "full_path"]
# print(sample_path,", view: ", meta_dataset_2_train.at[idx, "laterality"])
# dcm_file = dcmread(sample_path, force = True)
# image_array = dcm_file.pixel_array   #Output from extract dicom is uint16
# #image_array = image_array*255/32767  # scale from range (0,32767) to (0,255) dtype = float64
# image_array = image_array.astype(np.float32)
# H, W = image_array.shape
# image_array = image_array[200: H-200 ,0 + meta_dataset_2_train.at[idx, "start_width"]: W + \
#                           meta_dataset_2_train.at[idx, "end_width"]] 
# print(image_array.shape)
# plt.imshow(image_array.squeeze(), cmap = 'gray')
# plt.show()

# # transform instance
# aug = transforms.Compose([transforms.RandomResizedCrop(size = (flags['img_H'],flags['img_W']), 
#                                                        scale=(0.3, 1.2), ratio=(1.1,1.7) ,antialias = True),
#                           transforms.ElasticTransform(alpha = 150.0, sigma = 8.0),
#                           transforms.RandomAutocontrast(),
#                           transforms.Normalize((0), (0.4))
#                           ])

# image_array = np.expand_dims(image_array,axis = 0)
# tensor = torch.tensor(image_array).float() # cast to torch tensor dtype float32
# tensor = aug(tensor)
# plt.imshow(tensor.numpy().squeeze(), cmap = 'gray')
# plt.show()

In [None]:
# ### Implement Cutout in nn.Module ###
# import torch.nn as nn
# from typing import List, Tuple
# class Cutout(nn.Module):
#     def __init__(self, 
#                  num_holes: int, 
#                  holes_size: Tuple[int, int], 
#                  max_topleft_position_ratio: Tuple[float, float],
#                  min_topleft_position_ratio: Tuple[float, float]
#                 ) -> None:
#         self.num_holes = num_holes
#         self.holes_size = holes_size
#         self.max_topleft_position_ratio = max_topleft_position_ratio
#         self.min_topleft_position_ratio = min_topleft_position_ratio
#         super(Cutout, self).__init()
        
#     def forward(self, x):
#         assert x.ndim == 3, "Input dim of data must be 3"
#         input_shape = x.shape
#         mask = torch.ones((input_shape[1], input_shape[2])) # 2 dimentional tensor
        
        

In [None]:
######### Define class of dataset to read images and labels
class MyDataset(torch.utils.data.Dataset):
    """Images and binary labels collected from the two dataset sources.

    Rows whose ``image_path`` is non-zero come from dataset 1 (image file read with
    ``read_image``, label in ``label``). Rows with ``image_path == 0`` come from
    dataset 2 (DICOM at ``full_path``, cropped by ``start_width``/``end_width``,
    label in ``cancer``). A label value of 0 maps to class 0, anything else to class 1.

    Args:
        integrated_dataframe: rows are looked up by index label ``idx``, so it is
            expected to be indexed 0..n-1.
        flags: run configuration; ``img_standard_H/W`` sizes training images and
            ``agumentation.img_H/W`` sizes validation images.
        dataset_type: ``"training"`` or ``"validation"``.

    Raises:
        ValueError: for any other ``dataset_type`` (this previously left the resize
            transform undefined and only failed later in ``__getitem__``).
    """
    _DATASET_TYPES = ("training", "validation")

    def __init__(self, integrated_dataframe, flags, dataset_type: str):
        if dataset_type not in self._DATASET_TYPES:
            raise ValueError(f"dataset_type must be one of {self._DATASET_TYPES}, got {dataset_type!r}")
        self.integrated_dataframe = integrated_dataframe
        self.flags = flags
        if dataset_type == "training":
            # Larger intermediate size; the augmentation model (see make_model) crops to the final size.
            size = (self.flags['img_standard_H'], self.flags['img_standard_W'])
        else:
            # Validation images are resized straight to the network input size.
            size = (self.flags['agumentation']['img_H'], self.flags['agumentation']['img_W'])
        self._full_img_transform = transforms.Compose([transforms.Resize(size, antialias=True)])

    def __len__(self):
        return len(self.integrated_dataframe)

    def _path_to_arrary(self, idx):
        """Read row ``idx``'s DICOM as a float32 array of shape (1, H - 300, kept width).

        150 rows are trimmed from the top and the bottom, and only columns
        ``[start_width, W - end_width)`` are kept.
        """
        frame = self.integrated_dataframe
        dcm_file = dcmread(frame.at[idx, 'full_path'], force=True)
        image_array = dcm_file.pixel_array.astype(np.float32)

        H, W = image_array.shape
        start = int(frame.at[idx, "start_width"])
        end = int(frame.at[idx, "end_width"])
        image_array = image_array[150: H - 150, start: W - end]
        return np.expand_dims(image_array, axis=0)

    @staticmethod
    def _to_label(value):
        """Class index tensor: 0 stays 0, any other value becomes 1."""
        return torch.tensor(0 if value == 0 else 1, dtype=torch.long)

    def __getitem__(self, idx):
        image_path = self.integrated_dataframe.at[idx, 'image_path']
        if image_path != 0:  # dataset 1: image file on disk
            image = read_image(image_path)
            raw_label = self.integrated_dataframe.at[idx, 'label']
        else:  # dataset 2: DICOM, cropped by _path_to_arrary
            image = torch.tensor(self._path_to_arrary(idx))
            raw_label = self.integrated_dataframe.at[idx, 'cancer']
        image = self._full_img_transform(image.float())
        # Broadcast the channel dim to 3 for the ImageNet backbones (expects 1 or 3 channels).
        image = image.expand(3, *image.shape[1:])
        return image, self._to_label(raw_label)

In [None]:

def _make_aug_model(flags: dict):
    """Build the augmentation pipeline from ``flags['agumentation']``, with gradients disabled."""
    aug = flags['agumentation']
    aug_model = nn.Sequential(
        transforms.RandomResizedCrop(size=(aug['img_H'], aug['img_W']),
                                     scale=(aug['scale_min'], aug['scale_max']),
                                     ratio=(aug['ratio_min'], aug['ratio_max']),
                                     antialias=True),
        transforms.ElasticTransform(alpha=aug['alpha'], sigma=aug['sigma']),
        transforms.RandomAutocontrast(),
        transforms.Normalize((aug['mean']), (aug['std'])),
    )
    for para in aug_model.parameters():
        para.requires_grad = False
    return aug_model


def _unfreeze_features_after(classifier_model, final_freeze_layer_order: int):
    """Freeze everything, then unfreeze ``features`` children past the cut-off and every other top-level child."""
    for params in classifier_model.parameters():
        params.requires_grad = False
    for order, child in enumerate(classifier_model.children()):
        if order == 0 and isinstance(child, torch.nn.Sequential):
            for sub_order, sub_layer in enumerate(child):
                if sub_order > final_freeze_layer_order:
                    for param in sub_layer.parameters():
                        param.requires_grad = True
        else:
            for params in child.parameters():
                params.requires_grad = True


def _unfreeze_convnext_features(features, final_freeze_layer_order: int):
    """Unfreeze ConvNeXt ``features`` after its first ``final_freeze_layer_order`` layers.

    Each CNBlock inside a stage counts as one layer; any other entry (stem,
    downsampling block) counts as one. If more layers are requested frozen than
    exist, the whole feature extractor stays frozen (this previously raised
    TypeError when comparing an index with None).
    """
    # Number of layers contributed by each entry of `features`
    count_layer = []
    for seq in features:
        if isinstance(seq, torch.nn.Sequential):
            n_cnb = sum(isinstance(block, CNBlock) for block in seq.children())
            count_layer.append(n_cnb if n_cnb else 1)
        else:
            count_layer.append(1)

    # Find the entry holding the first trainable layer; entry 0 (the stem) is never chosen.
    stop_index, cnb_stop_id = None, None
    running_total = count_layer[0] if count_layer else 0
    for idx in range(1, len(count_layer)):
        running_total += count_layer[idx]
        if running_total >= final_freeze_layer_order:
            stop_index = idx
            # last frozen CNBlock position inside the stop entry
            cnb_stop_id = count_layer[idx] - (running_total - final_freeze_layer_order) - 1
            break
    if stop_index is None:
        return

    for seq_order, seq in enumerate(features):
        if seq_order == stop_index:
            if count_layer[seq_order] == cnb_stop_id + 1:
                continue  # the whole stop entry stays frozen
            for cnb_id, cnb in enumerate(seq.children()):
                if cnb_id > cnb_stop_id:
                    for params in cnb.parameters():
                        params.requires_grad = True
        elif seq_order > stop_index:
            for params in seq.parameters():
                params.requires_grad = True


def make_model(model_str_name: str, flags: dict ,final_freeze_layer_order: int = 17): 
    """Build a pretrained 2-class classifier and the augmentation model.

    Args:
        model_str_name: "densenet", "efficientnet" or "convnext_base".
        flags: run configuration; its ``agumentation`` sub-dict parameterises the
            augmentation transforms.
        final_freeze_layer_order: how much of the pretrained backbone stays frozen.
            densenet/efficientnet: children of ``features`` with index
            ``<= final_freeze_layer_order`` stay frozen. convnext_base: number of
            leading feature layers kept frozen, counting each CNBlock separately.
            Everything outside ``features`` (the heads) is always trainable.

    Returns:
        (classifier_model, aug_model); classifier_model outputs 2 logits.

    Raises:
        AssertionError: for an unknown ``model_str_name``.
    """
    assert model_str_name in ["densenet","efficientnet", "convnext_base"], "Invalid model name" 

    ### Augmentation Model ###
    aug_model = _make_aug_model(flags)

    if model_str_name == "densenet":
        classifier_model = densenet161(weights=DenseNet161_Weights.IMAGENET1K_V1, progress=False)
        classifier_model.classifier = torch.nn.Sequential(
            nn.Linear(2208, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(inplace=True),
            nn.Linear(1024, 2),
        )
        _unfreeze_features_after(classifier_model, final_freeze_layer_order)
    elif model_str_name == "efficientnet":
        classifier_model = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.IMAGENET1K_V1, progress=False)
        # torchvision's EfficientNet applies its head via `.classifier`. The previous code
        # assigned `.fc`, which forward() never calls, so the 1000-class head stayed in use.
        classifier_model.classifier = torch.nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(inplace=True),
            nn.Linear(512, 2),
        )
        _unfreeze_features_after(classifier_model, final_freeze_layer_order)
    else:  # convnext_base
        classifier_model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1, progress=False)
        # Keep the pretrained 1000-way projection and append a 1000 -> 2 layer.
        classifier_model.classifier.append(torch.nn.Linear(in_features=1000, out_features=2, bias=True))
        for params in classifier_model.parameters():
            params.requires_grad = False
        for order, structure in enumerate(classifier_model.children()):
            if order == 0:  # features
                _unfreeze_convnext_features(structure, final_freeze_layer_order)
            else:
                for params in structure.parameters():
                    params.requires_grad = True
    return classifier_model, aug_model

In [None]:
def _training_loop(input_model,augmentation_model,train_data_loader,train_length:int, optimizer,scheduler,loss_fn, device, scaled_data_step:int =1):
    """Run one training epoch.

    Args:
        input_model: classifier being trained.
        augmentation_model: applied to every batch before the classifier.
        train_data_loader: yields ``(inputs, labels)`` batches (presumably already on
            ``device`` via an MpDeviceLoader — TODO confirm).
        train_length: batches per epoch; at most ``train_length // scaled_data_step``
            batches are consumed.
        optimizer: stepped through ``xm.optimizer_step``.
        scheduler: stepped once per batch, matching the OneCycleLR
            ``steps_per_epoch=train_length`` set up in training_pipeline (it was
            previously stepped once per epoch, so the LR never left warm-up).
        loss_fn: loss on ``(outputs, labels)``.
        device: device for the running accumulators.
        scaled_data_step: divisor that shortens the epoch.

    Returns:
        (mean batch loss, mean batch accuracy) as 0-d tensors, averaged over the
        batches actually processed.
    """
    train_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
    train_corrects = torch.tensor(0.0, dtype=torch.float32, device=device)
    stop_step = train_length // scaled_data_step
    n_steps = 0
    for step, (inputs, labels) in enumerate(train_data_loader):
        if step == stop_step:
            break
        xm.master_print("Step: ", step)
        optimizer.zero_grad()
        # apply augmentation, predict, and backpropagate
        inputs = augmentation_model(inputs)
        outputs = input_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        xm.optimizer_step(optimizer)  # for pjrt, for ddp use optimizer.step() xm.mark_step()
        scheduler.step()
        # detach: otherwise the running sum keeps every step's autograd graph alive
        train_loss = train_loss + loss.detach()
        class_pred = torch.argmax(outputs, 1)
        train_corrects = train_corrects + torch.eq(class_pred, labels).sum() / labels.shape[0]
        n_steps += 1

    # Plain division: torch.div(int, int) raises TypeError (its first operand must be a tensor).
    n_steps = max(n_steps, 1)
    return train_loss / n_steps, train_corrects / n_steps
    
def _validation_loop(model,val_data_loader, val_length, loss_fn, device, scaled_data_step:int = 1):
    with torch.no_grad():
        val_loss = torch.tensor(0.0, dtype = torch.float32,device = device)
        val_corrects = torch.tensor(0, dtype = torch.int16,device = device)
        stop_step = val_length//scaled_data_step
        for i, (test_inputs, test_labels) in enumerate(val_data_loader):
            xm.master_print("val step: ", i)
            if i == stop_step:
                break
            else:
                test_outputs = model(test_inputs)
                current_val_loss = loss_fn(test_outputs, test_labels)
                val_loss += current_val_loss

                # for monitor
                val_class_pred = torch.argmax(test_outputs,1)
                val_corrects = val_corrects + torch.div(torch.eq(val_class_pred,test_labels).sum(),test_labels.shape[0])
        # in epoch to summary
        val_loss = torch.div(val_loss,torch.div(val_length,scaled_data_step))
        val_corrects = torch.div(val_corrects, torch.div(val_length,scaled_data_step))
    return val_loss, val_corrects
    
def training_pipeline(classifier_model,
                     augmentation_model,
                     num_epochs: int, 
                     learning_rate: float, 
                     weight_decay: float,
                     train_loader: pl.MpDeviceLoader, 
                     val_loader: pl.MpDeviceLoader,
                     train_length: int,
                     val_length: int,
                     device):
    """Train ``classifier_model`` for ``num_epochs`` epochs, validating after each one.

    Uses SGD (momentum 0.9) with OneCycleLR (``steps_per_epoch=train_length``).
    Learning rates are multiplied by the replica count.

    Metrics are averaged across all replicas with ``xm.all_reduce``; the scale is
    1/world_size rather than a hard-coded 0.125 over replicas 0..7, which only
    fit 8 replicas.

    Validation batches go through the same ``Normalize`` modules that end the
    augmentation model; previously the classifier was validated on inputs scaled
    differently from the ones it was trained on. The random crop, elastic and
    autocontrast transforms are not applied at validation.

    Returns:
        (train_loss_list, train_acc_list, val_loss_list, val_acc_list), one numpy
        scalar per epoch each.
    """
    world_size = xm.xrt_world_size()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(classifier_model.parameters(), lr=learning_rate * world_size,
                          momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate * world_size * 5,
                                                    steps_per_epoch=train_length,
                                                    epochs=num_epochs,
                                                    anneal_strategy='cos')
    # all_reduce sums over every replica (groups=None); scaling by 1/world_size makes it a mean.
    reduce_scale = 1.0 / world_size

    eval_normalizers = [m for m in augmentation_model.modules() if isinstance(m, transforms.Normalize)]
    eval_model = nn.Sequential(*eval_normalizers, classifier_model) if eval_normalizers else classifier_model

    train_loss_list, train_acc_list = [], []
    val_loss_list, val_acc_list = [], []
    augmentation_model.eval()
    for epoch in range(num_epochs):
        classifier_model.train()
        train_loss, train_acc = _training_loop(classifier_model, augmentation_model, train_loader, train_length,
                                               optimizer, scheduler, loss_fn, device)

        # reduce tensor's values before appending to master
        train_loss, train_acc = xm.all_reduce(reduce_type="sum", inputs=[train_loss, train_acc], scale=reduce_scale)
        xm.master_print("Lets wait")
        xm.wait_device_ops()
        xm.master_print('Epoch {} after reduce, train loss" {}, train accuracy" {}, current process: {} \n'.format(epoch,train_loss, train_acc, xm.get_ordinal()))  #testing
        train_loss_list.append(train_loss.clone().detach().cpu().numpy())
        train_acc_list.append(train_acc.clone().detach().cpu().numpy())

        # Validation process
        classifier_model.eval()
        val_loss, val_acc = _validation_loop(eval_model, val_loader, val_length, loss_fn, device)
        val_loss, val_acc = xm.all_reduce(reduce_type="sum", inputs=[val_loss, val_acc], scale=reduce_scale)
        xm.master_print("Lets wait for validation")
        xm.master_print('Epoch {} after reduce, val loss" {}, val accuracy" {}, current process: {} \n'.format(epoch,val_loss, val_acc, xm.get_ordinal()))  #testing
        val_loss_list.append(val_loss.clone().detach().cpu().numpy())
        val_acc_list.append(val_acc.clone().detach().cpu().numpy())

        # TODO: checkpointing. The earlier commented-out code referenced undefined names
        # (val_running_loss, total_val_loss, model_name); re-add it with
        # torch.save(classifier_model.state_dict(), weight_path_save + ...) once defined.

    return train_loss_list, train_acc_list, val_loss_list, val_acc_list

In [None]:
def draw_result(train_loss_list,train_acc_list,val_loss_list, val_acc_list, model_name: str):
    """Plot training vs. validation curves for ``model_name``: accuracy first, then loss.

    Each metric gets its own figure, x axis = epochs (presumably one list entry per
    epoch, as returned by training_pipeline).
    """
    panels = (
        ("Accuracy", "accuracy", train_acc_list, val_acc_list),
        ("Loss", "loss", train_loss_list, val_loss_list),
    )
    for y_label, metric, train_values, val_values in panels:
        plt.plot(train_values, color='blue', label="training " + metric)
        plt.plot(val_values, color='orange', label="validation " + metric)
        plt.ylabel(y_label)
        plt.xlabel("Epochs")
        plt.legend()
        plt.title("Training and validation " + metric + " of " + model_name)
        plt.show()

In [None]:
class _EpochAdvancingDistributedSampler(torch.utils.data.distributed.DistributedSampler):
    """DistributedSampler that moves on to the next epoch after every full pass.

    `training_pipeline` only receives the device loaders, never the sampler, so
    nothing calls `set_epoch()` between epochs. A plain DistributedSampler would
    then replay the same shuffle order every epoch. Bumping the epoch after each
    `__iter__` gives a new permutation per pass. The permutation is still
    identical across ranks and reproducible, because it is seeded by
    `seed + epoch`. If a caller does call `set_epoch(e)` before iterating, that
    value is used for the pass.
    """

    def __iter__(self):
        # Materialise this pass's indices before advancing the epoch, so the
        # bump only affects the *next* pass.
        indices = list(super().__iter__())
        self.set_epoch(self.epoch + 1)
        return iter(indices)


def _make_distributed_loader(dataset, flags, rank, num_replicas, shuffle):
    """Build this replica's DataLoader over `dataset`, sharded by a DistributedSampler."""
    sampler = _EpochAdvancingDistributedSampler(dataset,
                                                num_replicas=num_replicas,
                                                rank=rank,
                                                shuffle=shuffle)
    # The DataLoader's own shuffle must stay False: ordering is owned by the sampler.
    # drop_last keeps every batch the same size (presumably to avoid XLA
    # recompilation on a ragged last batch — confirm).
    return DataLoader(dataset,
                      batch_size=flags['batch_size'],
                      sampler=sampler,
                      shuffle=False,
                      num_workers=flags['num_workers'],
                      drop_last=True)


def _map_fn(rank, flags):
    """Per-core worker launched by `xmp.spawn`: build model and loaders, train, then plot.

    Args:
        rank: process index supplied by `xmp.spawn`. It is used as this
            replica's rank in the DistributedSamplers.
        flags: run configuration dict (model name, freeze layer, epochs,
            learning rate, weight decay, batch size, num_workers).

    Reads the module-level `training_data` / `validation_data` datasets created
    in the `__main__` block. They only reach the workers because the processes
    are started with `start_method='fork'`.
    """
    # Device configuration
    device = xm.xla_device()
    dist.init_process_group('xla', init_method='pjrt://')

    # Make the model and transfer it to the device.
    model, aug_model = make_model(model_str_name=flags['model_name'],
                                  flags=flags,
                                  final_freeze_layer_order=flags['final_freeze'])
    model = model.to(device)
    aug_model = aug_model.to(device)
    # Start every replica from the master process's parameters.
    pjrt.broadcast_master_param(model)
    pjrt.broadcast_master_param(aug_model)

    # Per-replica loaders. Training is reshuffled every epoch. Validation is
    # kept in a fixed order so the same subset (after drop_last) is scored
    # each epoch.
    number_replicas = xm.xrt_world_size()
    train_dataloader = _make_distributed_loader(training_data, flags, rank, number_replicas, shuffle=True)
    val_dataloader = _make_distributed_loader(validation_data, flags, rank, number_replicas, shuffle=False)

    # Wrap the loaders so batches are transferred to the XLA device.
    xla_train_loader = pl.MpDeviceLoader(train_dataloader, device)
    xla_val_loader = pl.MpDeviceLoader(val_dataloader, device)
    train_loss, train_acc, val_loss, val_acc = training_pipeline(model,
                                                                 aug_model,
                                                                 num_epochs=flags['number_epochs'],
                                                                 learning_rate=flags['learning_rate'],
                                                                 weight_decay=flags['weight_decay'],
                                                                 train_loader=xla_train_loader,
                                                                 val_loader=xla_val_loader,
                                                                 train_length=len(train_dataloader),
                                                                 val_length=len(val_dataloader),
                                                                 device=device)
    # Visualization: only the master process draws the curves.
    if xm.is_master_ordinal():
        draw_result(train_loss, train_acc, val_loss, val_acc, flags['model_name'])

# Main process
if __name__ == '__main__':
    # Remove these TPU-VM variables before launching the PJRT workers
    # (presumably they conflict with PJRT's own process setup — confirm).
    # The default of None keeps a re-run of this cell from raising KeyError
    # once the variables are already gone.
    os.environ.pop('TPU_PROCESS_ADDRESSES', None)
    os.environ.pop('CLOUD_TPU_TASK_ID', None)
    os.environ['PJRT_DEVICE'] = 'TPU'
    ########### Create dataset ############
    torch.manual_seed(300)
    # The first 70% of rows go to training and the rest to validation.
    # NOTE(review): this is an ordered split, not a random one. If
    # integrated_train_df is grouped (e.g. by source dataset or label), the two
    # splits will differ in distribution. Confirm the frame is shuffled upstream.
    n_train = int(len(integrated_train_df) * 0.7)
    train_df = integrated_train_df.iloc[:n_train]
    val_df = integrated_train_df.iloc[n_train:].reset_index(drop=True)
    training_data = MyDataset(train_df, flags, "training")
    validation_data = MyDataset(val_df, flags, "validation")

    test_data = MyDataset(integrated_test_df, flags, "validation")

    if torch.distributed.is_available():
        # Note: in interactive notebooks, you must use start_method='fork'.
        # Fork is also what makes training_data / validation_data visible to _map_fn.
        xmp.spawn(_map_fn, args=(flags,), start_method='fork')
    else:
        print("No distributed")