In [1]:
# Imports here
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
import torchvision.models as models
from PIL import Image
import json
from matplotlib.ticker import FormatStrFormatter
from collections import OrderedDict
import torch.quantization
import catalyst as c
from sklearn.metrics import f1_score
from catalyst.contrib.losses import FocalLossMultiClass
import copy
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f18186ab3d0>

In [2]:
import coremltools as ct
from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig

In [3]:
print(torch.__version__)

2.0.0+cu117


In [4]:
print(ct.__version__)

7.0b1


In [5]:
print(c.__version__)

22.04


In [6]:
data_directory = '/opt/infilect/dev/dataset/resize'
# train_directory = data_directory + '/train'
# # validation_directory = data_directory + '/valid'
# test_directory = data_directory + '/test'  

In [7]:

from typing import Callable, Dict, List

from threading import Thread
from queue import Queue


def threadSkeleton(myId: str, workerFunc: Callable, recvQueue: Queue, 
                   sendQueue: Queue, **kwargs):

    print(F" Thread {myId} -> Starting...")
    objProcessed = 0

    while True:

        try: thisJob = recvQueue.get(block=True, timeout=5)
        except: thisJob = None

        if thisJob is None: continue
        if thisJob == "END": break
        thisJob = workerFunc(thisJob, **kwargs)

        if not thisJob is None: sendQueue.put(thisJob)
        objProcessed += 1

    print(F" Thread {myId} -> Ending. Samples processed {objProcessed}")

    pass

class ThreadHandler:

    def __init__(self, nProcesses: int=12) -> None:

        self.nProcesses = nProcesses

        self.processQueue = [Queue() for __ in range(nProcesses)]
        self.resultsQueue = Queue()

        self.currProcessIndex = 0
        self.threadList: List[Thread] = []
        pass
    
    def startThreads(self, workerFunc: Callable, **kwargs):
        for i in range(self.nProcesses):
            tThread = Thread(target=threadSkeleton, args=[i, workerFunc, self.processQueue[i], 
                                                            self.resultsQueue], kwargs=kwargs)
            tThread.start()
            self.threadList.append(tThread)
        pass
    
    def putJob(self, thisJob: Dict):
        self.processQueue[self.currProcessIndex].put(thisJob)
        self.currProcessIndex += 1
        if self.currProcessIndex == self.nProcesses: self.currProcessIndex = 0
        pass

    def endThreads(self,):
        for q in self.processQueue: q.put("END")
        for t in self.threadList: t.join()
        pass

    def getResults(self,) -> List[Dict]: return list(self.resultsQueue.queue)

    pass


In [8]:

from typing import Callable, Dict, Tuple, Set, List
import os
from threading import Thread

import numpy as np
import pandas as pd
from PIL import Image

import torch 
from torch.utils.data import Dataset

class BalancedLoader(Dataset):

    def __init__(self, rootPath: str, augObject: Callable=None, validLbls: Set[str]=set(), 
                cacheImgs: bool=True, minSampPerClass: int=-1, maxSampPerClass: int=-1, 
                nProcesses: int=15) -> None:

        self.rootPath = rootPath
        self.augObject = augObject
        self.validLbls = set(validLbls)
        self.cacheImgs = cacheImgs
        self.minSampPerClass = minSampPerClass
        self.maxSampPerClass = maxSampPerClass
        self.nProcesses = nProcesses

        if cacheImgs: 
            self.dataDict, self.filesPerClass, self.nFiles = self.__cacheImages()
        else: 
            self.dataDict, self.filesPerClass, self.nFiles = self.__getFileNames()
        
        self.allClasses = sorted(list(self.filesPerClass.keys()))
        self.label2Idx, self.idx2Label = self.__tagLabelMap()
        self.nClasses = len(self.allClasses)

        self.__clsItr, self.__balLen = 0, 0
        self.balIdxs = self.getBalancedIndices()
        pass
    
    def __len__(self,) -> int: return self.__balLen

    def __tagLabelMap(self,):
        
        if "notImportant" in self.allClasses:
            excClass = [self.allClasses.pop(self.allClasses.index("notImportant"))]
            self.allClasses = excClass + self.allClasses

        label2Idx, idx2Label = {}, {}
        for i in range(len(self.allClasses)): 
            label2Idx[self.allClasses[i]], idx2Label[i] = i, self.allClasses[i]
        return label2Idx, idx2Label

    def __getFileNames(self,) -> Tuple[Dict, int]:

        dataDict, filesPerClass, totalFiles = {}, {}, 0
        if len(self.validLbls) == 0: self.validLbls = set(os.listdir(self.rootPath))

        for c in os.listdir(self.rootPath):
            thisFiles = pd.Series(os.listdir(F"{self.rootPath}/{c}"))
            if self.minSampPerClass != -1 and len(thisFiles) < self.minSampPerClass: continue
            
            mCls = c if c in self.validLbls else "notImportant"
            dataDict[mCls] = dataDict.get(mCls, [])
            filesPerClass[mCls] = filesPerClass.get(mCls, 0)

            dataDict[mCls] += ( F"{self.rootPath}/{c}/" + thisFiles ).tolist()
            filesPerClass[mCls] = len(dataDict[mCls])
            totalFiles += len(thisFiles)

        return dataDict, filesPerClass, totalFiles
    
    def __cacheImages(self,) -> Dict:
        
        print(" Caching Images....")
        dataDict, filesPerClass, nFiles = self.__getFileNames()
        outDict = {k: [] for k in dataDict.keys()}

        def __loadImage(readImg: Dict) -> None:
            imgCls, imgFile = readImg["imgCls"], readImg["imgFile"]
            tImg = np.asarray(Image.open(imgFile)).copy()
            outDict[imgCls].append(tImg)
        
        threadHndlr = ThreadHandler(self.nProcesses)
        threadHndlr.startThreads(__loadImage)
        allCls, tI, tCls = list(filesPerClass.keys()), 0, 0
        for __ in range(nFiles):
            threadHndlr.putJob({"imgCls": allCls[tCls], "imgFile": dataDict[allCls[tCls]][tI]})
            tI += 1
            if tI < filesPerClass[allCls[tCls]]: continue
            tCls += 1 
            tI = 0
        threadHndlr.endThreads()

        filesPerClass, nFiles = {}, 0
        for c in outDict:
            filesPerClass[c] = len(outDict[c])
            nFiles += len(outDict[c])
        
        print(" Images Cached.")
        return outDict, filesPerClass, nFiles

    def getBalancedIndices(self,):
        
        if self.maxSampPerClass is None: return self.getImbalancedIndices()

        if self.maxSampPerClass == -1:
            for c in self.filesPerClass: 
                self.maxSampPerClass = max(self.maxSampPerClass, self.filesPerClass[c])
            pass
        
        self.__balLen = 0
        balIndices = {}
        for c in self.dataDict.keys():
            tList = np.arange(len(self.dataDict[c]))
            nReps = np.ceil(self.maxSampPerClass/len(tList))
            balIndices[c] = np.repeat(tList, nReps)
            np.random.shuffle(balIndices[c])
            
            balIndices[c] = balIndices[c][:self.maxSampPerClass].tolist()
            self.__balLen += len(balIndices[c])
        
        return balIndices

    def getImbalancedIndices(self,):

        self.__balLen = 0
        balIndices = {}
        for c in self.dataDict.keys():
            balIndices[c] = np.arange(len(self.dataDict[c])).tolist()
            self.__balLen += len(balIndices[c])

        return balIndices

    def __getitem__(self, index) -> Dict:

        if len(self.balIdxs[self.idx2Label[self.__clsItr]]) == 0:
            print(" Shuffling indices...", end="")
            self.__clsItr = 0
            self.balIdxs = self.getBalancedIndices()
            print("Done.")

        tCls = self.idx2Label[self.__clsItr]
        tIdx = self.balIdxs[tCls].pop()
        tImg = self.dataDict[tCls][tIdx]
        
        if isinstance(tImg, str): tImg = np.asarray(Image.open(tImg)).copy()
        else: tImg = tImg.copy()

        if not self.augObject is None: tImg = self.augObject(tImg)[0]

        if "notImportant" in self.allClasses:
            hotEncode = np.zeros(self.nClasses-1, dtype=np.float32)
            if self.__clsItr != 0: hotEncode[self.__clsItr-1] = 1
        else:
            hotEncode = np.zeros(self.nClasses, dtype=np.float32)
            hotEncode[self.__clsItr] = 1

        tImg = torch.from_numpy(tImg).moveaxis(-1, 0)
        hotEncode = torch.from_numpy(hotEncode)
        
        self.__clsItr += 1
        if self.__clsItr >= len(self.idx2Label): self.__clsItr = 0

        # return {"thisX": tImg, "thisY": hotEncode, "className": tCls, "classIdx": self.label2Idx[tCls]}
        return (tImg,), (hotEncode,)
    pass

class SimpleLoader(Dataset):

    def __init__(self, rootPath: str, validLbls: Set[str]=set(), augObject: Callable=None, 
                nProcesses: int=8, cacheImgs: bool=True, label2Idx: Dict=None, 
                idx2Label: Dict=None) -> None:
        super().__init__()

        self.rootPath = rootPath
        self.validLbls = validLbls
        self.augObject = augObject
        self.nProcesses = nProcesses
        self.cacheImgs = cacheImgs

        if self.cacheImgs: self.dataList, self.filesPerClass = self.__cacheImages()
        else: self.dataList, self.filesPerClass = self.getFileNames()
        self.nClasses, self.allClasses = len(self.filesPerClass), list(self.filesPerClass.keys())
        
        if label2Idx is None or idx2Label is None: 
            self.label2Idx, self.idx2Label = self.__tagLabelMap()
        else: self.label2Idx, self.idx2Label = label2Idx, idx2Label

    def __len__(self,): return len(self.dataList)

    def getFileNames(self,):
        
        if len(self.validLbls) == 0: self.validLbls = set(os.listdir(self.rootPath))

        dataList, filesPerClass = [], {}
        for c in os.listdir(self.rootPath):
            thisFiles = pd.Series(os.listdir(F"{self.rootPath}/{c}"))
            thisFiles = F"{self.rootPath}/{c}/" + thisFiles
            mCls = c if c in self.validLbls else "notImportant"
            thisFiles = pd.DataFrame({0: [mCls]*len(thisFiles), 1: thisFiles}).values.tolist()
            dataList += thisFiles
            filesPerClass[mCls] = filesPerClass.get(mCls, 0) + len(thisFiles)
        
        return dataList, filesPerClass    

    def __cacheImages(self,):
        
        allFiles, filesPerImages = self.getFileNames()
        outList = []
        
        def __loadImage(tIdx: int, imgInfo: List):#, clsName: str, imgPath: str):
            
            for i, (clsName, imgPath) in enumerate(imgInfo):
                if ((i+1)%100) == 0: print(F" Thread -> {tIdx} : Processed {i+1} images of {len(imgInfo)}")
                tImg = np.asarray(Image.open(imgPath)).copy()
                outList.append((clsName, tImg))
            print(F" Thread -> {tIdx} : Done.")
        
        flsPerThread = int(len(allFiles) / self.nProcesses)
        threadObjs = []
        for i in range(self.nProcesses):
            thObj = Thread(target=__loadImage, args=(i+1, allFiles[i*flsPerThread:(i+1)*flsPerThread]))
            thObj.start()
            threadObjs.append(thObj)

        for th in threadObjs: th.join()
        return outList, filesPerImages

    def __tagLabelMap(self,):
        if "notImportant" is self.allClasses:
            excClass = [self.allClasses.pop(self.allClasses.index("notImportant"))]
            self.allClasses = excClass + self.allClasses

        label2Idx, idx2Label = {}, {}
        for i in range(len(self.allClasses)): 
            label2Idx[self.allClasses[i]], idx2Label[i] = i, self.allClasses[i]
        return label2Idx, idx2Label

    def __getitem__(self, index):
        
        tCls, tImg = self.dataList[index]
        cIdx = self.label2Idx[tCls]

        if isinstance(tImg, str): tImg = np.asarray(Image.open(tImg)).copy()
        else: tImg = tImg.copy()
        if not self.augObject is None: tImg = self.augObject(tImg)[0]

        if "notImportant" in self.allClasses:
            hotEncode = np.zeros(self.nClasses-1, dtype=np.float32)
            if cIdx != 0: hotEncode[cIdx-1] = 1
        else:
            hotEncode = np.zeros(self.nClasses, dtype=np.float32)
            hotEncode[cIdx] = 1

        tImg = torch.from_numpy(tImg).moveaxis(-1, 0)
        hotEncode = torch.from_numpy(hotEncode)
        
        # return {"thisX": tImg, "thisY": hotEncode, "className": tCls, "classIdx": self.label2Idx[tCls]}
        return (tImg,), (hotEncode,)
    pass



In [9]:

from __future__ import annotations
from typing import Tuple, Callable, Any

import albumentations as AL
import numpy as np

# class ResizeCustom:

#     def __init__(self, newSize: Tuple[int]=None, largestAxis: int=255, smallestSize: int=None,
#                 padSize: Tuple[int]=None) -> None:
        
#         self.newSize = newSize
#         self.largestAxis = largestAxis
#         self.smallestSize = smallestSize
#         self.padSize = padSize

#         if not newSize is None: self.resizeObj = AL.Resize(newSize[0], newSize[1])
#         elif not largestAxis is None: self.resizeObj = AL.LongestMaxSize(largestAxis)
#         elif not smallestSize is None: self.resizeObj = AL.LongestMaxSize(smallestSize)
#         else: self.resizeObj = None
#         pass

#     def __call__(self, inpImage: np.ndarray) -> np.ndarray:

#         if not self.resizeObj is None: inpImage = self.resizeObj(image=inpImage)["image"]
#         if self.padSize is None: return inpImage
#         pass

#     pass

class PadImage:

    def __init__(self, outSize: Tuple[int]=(224, 224)) -> None:
        self.outSize = outSize
        self.maxSize = max(outSize)
        self.resizeObj = AL.LongestMaxSize(self.maxSize)
        pass

    def __call__(self, inpImage: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray:
        
        # if max(inpImage.shape) > self.maxSize: inpImage = self.resizeObj(image=inpImage)["image"]
        inpImage = self.resizeObj(image=inpImage)["image"]
        imgSh = inpImage.shape

        diffY, diffX = (self.outSize[0] - imgSh[0]), (self.outSize[1] - imgSh[1])
        if diffY < 0: diffY = 0
        if diffX < 0: diffX = 0

        pTop, pBot = int(np.floor(diffY/2)), int(np.ceil(diffY/2))
        pLft, pRht = int(np.floor(diffX/2)), int(np.ceil(diffX/2))

        return np.pad(inpImage, ((pTop, pBot), (pLft, pRht), (0, 0)))
    pass

class DataAugmentor:

    def __init__(self, 
                 probHVShift: float=0.5, 
                 probHVFlip: float=0.5, 
                 probRotate: float=0.5, 
                 probZoom: float=0.5, 
                 probRGBShift: float=0.5, 
                 probBrightness: float=0.5,
                 probSharpen: float=0.5, 
                 probEmboss: float=0.5, 
                 hvShiftPer: Tuple[float]=(0., 0.1), 
                 zoomLimits: Tuple[float]=(0.85, 1.15), 
                 rotLimit: float=20, 
                 sharpenAlpha: Tuple[float]=(0.1, 0.2), 
                 sharpenLight: Tuple[float]=(0., 0.), 
                 embossAlpha: Tuple[float]=(0., 1.), 
                 embossStrength: Tuple[float]=(0.1, 0.7), 
                 normImage: bool=True, 
                 bBoxFormat: str="yolo", 
                 KPFormat: str="xy", 
                 newSize: Tuple[int]=None, 
                 beforeAug: Callable=None, 
                 afterAug: Callable=None
                ):

        allTransforms = []
        self.beforeAug = beforeAug

        if probHVShift > 0.0: 
            allTransforms.append( AL.Affine(translate_percent=hvShiftPer, p=probHVShift, mode=0, cval=0) )
        if probHVFlip > 0.0: 
            allTransforms += [AL.HorizontalFlip(p=probHVFlip), AL.VerticalFlip(p=probHVFlip)]
        if probRotate > 0.0: 
            allTransforms.append( AL.Rotate(limit=rotLimit, p=probRotate, border_mode=0) )
        if probZoom > 0.0: 
            allTransforms.append( AL.Affine(scale=zoomLimits, p=probZoom) )
        if probRGBShift > 0.0: 
            allTransforms.append( AL.RGBShift(p=probRGBShift) )    
        if probBrightness > 0.0: 
            allTransforms.append( AL.RandomBrightnessContrast(p=probBrightness) )
        if probSharpen > 0.0:
            allTransforms.append( AL.Sharpen(sharpenAlpha, sharpenLight, p=probSharpen) )
        if probEmboss > 0.0:
            allTransforms.append( AL.Emboss(embossAlpha, embossStrength, p=probEmboss) )

        if normImage: allTransforms.append(AL.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225], p=1.0))
        if not newSize is None: allTransforms.append(AL.Resize(newSize[0], newSize[1]) )

        self.augFunction = AL.Compose(allTransforms,
                                      bbox_params=AL.BboxParams(format=bBoxFormat,),
                                      keypoint_params=AL.KeypointParams(format=KPFormat,
                                                                        remove_invisible=True, 
                                                                        angle_in_degrees=True
                                                                       )
                                     )
        
        self.afterAug = afterAug
        pass
    
    def __call__(self, augImage, bBoxes=[], keyPoints=[], segMask=None):
        
        augImage = np.asarray(augImage).copy()
        if not self.beforeAug is None: augImage = self.beforeAug(inpImage=augImage)

        if not segMask is None: 
            augOut = self.augFunction(image=augImage, bboxes=bBoxes, keypoints=keyPoints, mask=segMask)
        else: augOut = self.augFunction(image=augImage, bboxes=bBoxes, keypoints=keyPoints)

        if not self.afterAug is None: augImage = self.afterAug(inpImage=augImage)
        return ( augOut["image"], augOut["bboxes"], augOut["keypoints"], 
                    ( None if segMask is None else augOut["mask"]) )

    pass

In [10]:
trainAug = DataAugmentor(0., 0., 0., 0., 0., 0., 0., 0.,normImage=True)
# trainAug = DataAugmentor(0., 0., 0., 0., 0., 0., 0., 0., normImage=True, beforeAug=PadImage())
testAug = DataAugmentor(0., 0., 0., 0., 0., 0., 0., 0., normImage=True)

trainSet = BalancedLoader(F"{data_directory}/train", validLbls=set(), maxSampPerClass=100, augObject=trainAug, cacheImgs=False)
testSet = SimpleLoader(F"{data_directory}/test", augObject=testAug, 
               validLbls=set(), cacheImgs=False, label2Idx=trainSet.label2Idx.copy(), idx2Label=trainSet.idx2Label.copy())

In [11]:
# TODO: Using the image datasets and the trainforms, define the dataloaders
trainloader1 = torch.utils.data.DataLoader(trainSet, batch_size = 16, shuffle=True)
# vloader = torch.utils.data.DataLoader(validation_data, batch_size =32,shuffle = True)
testloader1 = torch.utils.data.DataLoader(testSet, batch_size = 16, shuffle = True)

# TODO: Define your transforms for the training, validation, and testing sets
train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], 
                                                            [0.229, 0.224, 0.225])])

test_transforms = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], 
                                                           [0.229, 0.224, 0.225])])

# # validation_transforms = transforms.Compose([transforms.Resize(256),
# #                                             transforms.CenterCrop(224),
# #                                             transforms.ToTensor(),
# #                                             transforms.Normalize([0.485, 0.456, 0.406], 
# #                                                                  [0.229, 0.224, 0.225])])


# # TODO: Load the datasets with ImageFolder
train_data = datasets.ImageFolder(F"{data_directory}/train", transform=train_transforms)
# # validation_data = datasets.ImageFolder(validation_directory, transform=validation_transforms)
test_data = datasets.ImageFolder(F"{data_directory}/test" ,transform = test_transforms)

# TODO: Using the image datasets and the trainforms, define the dataloaders
trainloader2 = torch.utils.data.DataLoader(train_data, batch_size = 8, shuffle=True)
# vloader = torch.utils.data.DataLoader(validation_data, batch_size =32,shuffle = True)
testloader2 = torch.utils.data.DataLoader(test_data, batch_size = 1, shuffle = True)

In [12]:
print("Train Loader 1 Len:- ", len(trainloader1))
# print("Train Loader 2 Len:- ", len(trainloader2))
print("Test Loader 1 Len:- ", len(testloader1))
# print("Test Loader 2 Len:- ", len(testloader2))

Train Loader 1 Len:-  13
Test Loader 1 Len:-  21


import json

with open('/opt/infilect/dev/repos/image_classification/tag_to_json.json', 'r') as f:
    cat_to_name = json.load(f)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if 
                    multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier 
                                                     for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            if self.last_epoch == 0: 
                return [(base_lr * (float(self.last_epoch+1) / self.total_epoch))/2.
                        for base_lr in self.base_lrs]
            else:
                return [base_lr * (float(self.last_epoch) / self.total_epoch)
                        for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
                    for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if epoch is None: epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  
        
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1) * self.last_epoch / self.total_epoch + 1.) 
                         for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 
                param_group['lr'] = lr
        else:
            if epoch is None: self.after_scheduler.step(metrics, None)
            else: self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None: self.after_scheduler.step(None)
                else: self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else: return super(GradualWarmupScheduler, self).step(epoch)
        else: self.step_ReduceLROnPlateau(metrics, epoch)

def setup_model(model_name, num_classes):
    
    model = models.get_model(model_name, weights = "DEFAULT")
    
#     for param in model.parameters():
#         param.requires_grad = False
        
    model.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #     optimizer, milestones=[4], gamma=0.3
    # )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, last_epoch=40, step_size=3, gamma=0.7)
    lrSched = GradualWarmupScheduler(optimizer, 1, 4, scheduler)
    
    model.cuda()
    return model, optimizer, lrSched

import torch.nn as nn
import torchvision.models as models

class ClassHead(nn.Module):
    
    def __init__(self, inpUnits=1280, opOut=2):
        super().__init__()
        self.seqBlock_0 = nn.Sequential(torch.nn.Dropout(0.3), torch.nn.Linear(inpUnits, opOut))

    def forward(self, inpX: torch.Tensor) -> torch.Tensor: return self.seqBlock_0(inpX)
    pass

class BasicModel(nn.Module):
    
    def __init__(self, model_name: str, nClasses: int):
        super().__init__()
        
        self.backBone = models.get_model(model_name, weights = "DEFAULT")
        # self.backBone = torchModels.resnet50(weights=torchModels.ResNet50_Weights.IMAGENET1K_V2)
        self.backBone.fc = nn.Identity()
        self.classHead = ClassHead(2048, nClasses)
        
    def forward(self, inpX: torch.Tensor) -> torch.Tensor:
        embX = self.backBone(inpX)
        out = self.classHead(embX)
        return out

In [15]:
def create_model(model_name , num_classes=2):

    # The number of channels in ResNet18 is divisible by 8.
    # This is required for fast GEMM integer matrix multiplication.
    # model = torchvision.models.resnet18(pretrained=False)
    model = models.get_model(model_name, weights = "DEFAULT")

    # We would use the pretrained ResNet18 as a feature extractor.
    for param in model.parameters():
        param.requires_grad = False
    
    # Modify the last FC layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    return model

In [16]:
model = create_model(model_name = "resnet50")

In [17]:
class QuantizedModel(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedModel, self).__init__()
        # QuantStub converts tensors from floating point to quantized.
        # This will only be used for inputs.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point.
        # This will only be used for outputs.
        self.dequant = torch.quantization.DeQuantStub()
        # FP32 model
        self.model_fp32 = model_fp32

    def forward(self, inpX: torch.Tensor) -> torch.Tensor:
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        embX = self.quant(inpX)
        quantX = self.model_fp32(embX)
        dequantX = self.dequant(quantX)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
#         deQuantX = self.dequant(quantX)
        return dequantX

# model, optimizer, scheduler = setup_model(model_name = "resnet50", num_classes = 2)

In [18]:
# model = BasicModel(model_name = "resnet50", nClasses = 2)
# model = nn.DataParallel(model, device_ids=[0, 1])

In [19]:
lossObj = FocalLossMultiClass()

In [20]:
class MeanMeter:

    def __init__(self,):
        
        """ Simple mean tracker for metrics and losses.
        """

        self.sumValues = {}
        self.countValues = {}
        pass
    
    def __len__(self,) -> int: 
        """ Returns the number of currently tracked metrics/losses/values.

        Returns:
            int: Number of currently tracked values.
        """
        return len(self.sumValues)
    
    def __setitem__(self, index: Union[int, str], value: Union[int, float]) -> None:
        """ Add a new metric to track or update some metric that is being tracked.

        Args:
            index (Union[int, str]): 
                Key/name of the metric to track.
            value (Union[int, float]): 
                New value for the corresponding metric.
        """
        if not index in self.sumValues.keys(): self.sumValues[index], self.countValues[index] = 0, 0
        self.sumValues[index] += value
        self.countValues[index] += 1
        pass
    
    def __getitem__(self, index: Union[int, str]) -> float:
        """ Returns the current mean of a metric denoted by index.

        Args:
            index (Union[int, str]): 
                Key/name of metric.

        Raises:
            KeyError: 
                Raises error if given metric key is not present.

        Returns:
            float: Mean of metric denoted by "index".
        """
        if not index in self.sumValues: raise KeyError(F"Key {index} does not exist.")
        return self.sumValues[index] / self.countValues[index]

    @property
    def keys(self,) -> List[str]: 
        """ Get list of metrics currently being tracked.

        Returns:
            List[str]: List of key names.
        """
        return list(self.sumValues.keys())

    def hasKey(self, index: Union[int, str]) -> bool:
        """ Check if metric denoted by "index" is being tracked.

        Args:
            index (Union[int, str]): 
                Key/name to track.

        Returns:
            bool: True if metric exists.
        """
        return index in self.sumValues.keys()

    def getAll(self,) -> Dict:
        """ Get mean of all currently tracked metrics.

        Returns:
            Dict: Dictionary where each key is the metric key/name/index and the
                corresponding value is it's mean. 
        """
        return {k: self.sumValues[k]/self.countValues[k] for k in self.sumValues.keys()}
        
    def resetValues(self,) -> None:
        """ Clear and remove all metrics currently being tracked.
        """
        self.sumValues, self.countValues = {}, {}
    
    pass

In [21]:
from functools import wraps

def getLatency(callFunc: Callable) -> Callable:

    @wraps(callFunc)
    def wrapperFunc(*args, **kw):
        stTime = time()
        rootReturn = callFunc(*args, **kw)
        enTime = time()
        return rootReturn, enTime-stTime
    return wrapperFunc

In [22]:
from time import time
import copy
class TrainEngine:

    def __init__(self, engineModel, projectName: str, runName: str, savePath: str=None, 
                usageMode: str="Train", torDevice: str=None, useWandb: bool=True):
        """_summary_

        Args:
            engineModel (_type_): _description_
            projectName (str): _description_
            runName (str): _description_
            savePath (str, optional): _description_. Defaults to None.
            usageMode (str, optional): _description_. Defaults to "Train".
            torDevice (str, optional): _description_. Defaults to None.
            useWandb (bool, optional): _description_. Defaults to True.
        """
        
        self.usageMode = usageMode # Train/Validate/Infer/Test
        self.trainConfig = {}
        self.optimObj = None
        self.schedulerObj = None
        self.modelScaler = None
        self.lossObj = None
        
        savePath = F"{projectName}/{runName}" if savePath is None else F"{savePath}/{projectName}/{runName}"
        self.trainConfig["projectName"] = projectName
        self.trainConfig["runName"] = runName
        self.trainConfig["savePath"] = savePath
        self.trainConfig["batchSize"] = 32
        self.trainConfig["nEpochs"] = 20
        self.trainConfig["weightDecay"] = 0.001
        self.trainConfig["trainHalfPrec"] = False
        self.trainConfig["cosineCycle"] = 4
        self.trainConfig["currEpoch"] = 0
        self.trainConfig["logAfterBatch"] = 20
        self.trainConfig["printAfterBatch"] = 5000
        self.trainConfig["useWandB"] = useWandb
        self.trainConfig["logToFile"] = True
        self.trainConfig["logFile"] = {i:self.trainConfig["savePath"] + F"/{i}_Logs.csv"
                                        for i in ["Train", "Test", "Valid", "Eval"]}
        self.trainConfig["currentBest"] = None
        self.trainConfig["monitorKey"] = None
        self.trainConfig["saveIfLess"] = True
        self.trainConfig["evalOnTrain"] = True

        self.__logKeysEpoch = None
        self.__logKeys = None
        self.__logFile = None 
        self.meanMeters = {"Train": None, "Valid": None, "Test": None, "Eval": None}

        self.torDevice = (torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) 
                          if torDevice is None else torDevice)
        self.engineModel = engineModel.to(self.torDevice)
        self.bestModel = copy.deepcopy(self.engineModel)
        
        if self.trainConfig["useWandB"]: self.prepareWandB()
        pass
    
    def __del__(self,): self.cleanUp()
    
    def cleanUp(self,):
        try: wandb.finish()
        except: pass
        try:
            if not self.__logFile is None:  
                for v in self.__logFile.values(): v.close()
        except: pass
    
    def prepareWandB(self,):
        wandb.init(project=self.trainConfig["projectName"])
        wandb.config = self.trainConfig
        wandb.run.name = self.trainConfig["runName"]
        pass
    
    def setupOptimizer(self, optimObj=None, optimParams=None, learningRate=None, 
                       betaVals: Tuple[float]=(0.9, 0.999), weightDecay=None):
        """_summary_

        Args:
            optimObj (_type_, optional): _description_. Defaults to None.
            optimParams (_type_, optional): _description_. Defaults to None.
            learningRate (_type_, optional): _description_. Defaults to None.
            betaVals (Tuple[float], optional): _description_. Defaults to (0.9, 0.999).
            weightDecay (_type_, optional): _description_. Defaults to None.
        """
        
        # Get parameters for default optimizer.
        if weightDecay is None:
            if "weightDecay" in self.trainConfig: weightDecay = self.trainConfig["weightDecay"]
            else: weightDecay = 0.
        if learningRate is None:
            if "learningRate" in self.trainConfig: learningRate = self.trainConfig["learningRate"]
            else: learningRate = 0.0001
        
        # Set object
        if optimObj is None: 
            optimObj = torch.optim.Adam(params=self.engineModel.parameters(), lr=learningRate, 
                                        betas=betaVals, weight_decay=weightDecay)
        self.optimObj = optimObj
        
        # Load parameters if available.
        if optimParams is None: return
        if isinstance(optimParams, str): 
            optimParams = torch.load(optimParams, map_location=self.torDevice)
        self.optimObj.load_state_dict(optimParams)
        pass
    
    def setupLRScheduler(self, schedulerObj=None, schedulerParams=None):
        
        # Get object.
        if schedulerObj is None: return
        if self.optimObj is None: 
            raise Exception("Uninitialized optimizer. Setup optimizer first using setupOptimizer() \
method of the class.")
        
        if schedulerObj == "CosineAnnealing": 
            schedulerObj = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimObj, 
                        self.trainConfig.get("cosineCycle", 4), eta_min=0.00001, verbose=True)
        
        self.schedulerObj = schedulerObj
        
        # Load parameters if available.
        if schedulerParams is None: return
        if isinstance(schedulerParams, str): 
            schedulerParams = torch.load(schedulerParams, map_location=self.torDevice)
        self.schedulerObj.load_state_dict(schedulerParams)
        pass
    
    def createLogFiles(self, openMode: str="w") -> None:
 
        if not self.trainConfig["logToFile"] is None:

            self.__logFile = {}
            for k in self.trainConfig["logFile"]: 
                # Log files for each phase.
                openMode = openMode if os.path.exists(self.trainConfig["logFile"][k]) else "w"
                self.__logFile[k] = open(self.trainConfig["logFile"][k], openMode)

            # Log file for epoch.
            self.__logFile["Epoch_Logs"] = open(F"{self.trainConfig['savePath']}/Epoch_Logs.csv", 
            openMode if os.path.exists(F"{self.trainConfig['savePath']}/Epoch_Logs.csv") else "w")
        pass

    def prepareForTrain(self, optimObj=None, schedulerObj="CosineAnnealing", 
                        engineParams: Union[str, Dict]=None, modelParams: Union[str, Dict]=None, 
                        optimParams: Union[str, Dict]=None, schedulerParams: Union[str, Dict]=None, 
                        lossObj: Callable=None, currEpoch: int=None, nEpochs: int=100, 
                        monitorKey: str=None, saveIfLess: bool=True, batchSize: int=32, 
                        learningRate: float=0.0001, betaVals: Tuple[float]=(0.9, 0.999), 
                        weightDecay: float=0., cosineCycle: int=4):
        
        # Load pre-trained weights and config.
        if not engineParams is None:
            
            # Load params from file if required.
            if isinstance(engineParams, str):
                engineParams = torch.load(engineParams, map_location=self.torDevice)
            
            # Set params.
            modelParams = engineParams["modelParams"]
            optimParams = engineParams["optimParams"]
            schedulerParams = engineParams["schedulerParams"]
            
            # Set train config.
            trainConfig = engineParams["trainConfig"]
            for k in ["projectName", "runName", "savePath", "logFile"]: 
                trainConfig[k] = self.trainConfig[k]
            self.trainConfig = trainConfig
        
        # Load model params if available.
        if not modelParams is None:
            if isinstance(modelParams, str): 
                modelParams = torch.load(modelParams, map_location=self.torDevice)
            self.engineModel.load_state_dict(modelParams)

        # Get train details.
        self.trainConfig["batchSize"] = batchSize
        self.trainConfig["nEpochs"] = nEpochs
        self.trainConfig["weightDecay"] = weightDecay
        self.trainConfig["cosineCycle"] = cosineCycle
        self.trainConfig["monitorKey"] = "totalLoss" if monitorKey is None else monitorKey
        self.trainConfig["saveIfLess"] = saveIfLess 
        if not currEpoch is None: self.trainConfig["currEpoch"] = currEpoch
        
        # Setup objects.
        for k in self.meanMeters.keys(): self.meanMeters[k] = MeanMeter()
        if self.trainConfig["trainHalfPrec"]: self.modelScaler = torch.cuda.amp.GradScaler()
        
        self.setupOptimizer(optimObj, optimParams, learningRate, betaVals, weightDecay)
        self.setupLRScheduler(schedulerObj, schedulerParams)
        self.lossObj = torch.nn.CrossEntropyLoss() if lossObj is None else lossObj
        
        # Manage files and directories.
        os.makedirs(self.trainConfig["savePath"], exist_ok=True)
        os.makedirs(F"{self.trainConfig['savePath']}/checkPoints", exist_ok=True)
        
        # Create log files.
        self.createLogFiles()
        pass
    
    def getLoss(self, thisY: Union[List, Tuple, torch.Tensor], yHat: Union[List, Tuple, torch.Tensor]):
        # Compute Loss
        if (not isinstance(thisY, list)) and (not isinstance(thisY, tuple)): thisY = [thisY]
        if (not isinstance(yHat, list)) and (not isinstance(yHat, tuple)): yHat = [yHat]
        
        # Default cross entropy.
        totalLoss = self.lossObj(yHat[0], thisY[0])
        for i in range(len(thisY)): totalLoss += self.lossObj(yHat[i], thisY[i])
        return totalLoss, {"totalLoss": totalLoss.item()}
    
    def getMetrics(self, thisY: Union[List, Tuple, torch.Tensor], yHat: Union[List, Tuple, torch.Tensor]):
        
        # Compute Metrics
        if (not isinstance(thisY, list)) and (not isinstance(thisY, tuple)): thisY = [thisY]
        if (not isinstance(yHat, list)) and (not isinstance(yHat, tuple)): yHat = [yHat]
        
        # Default accuracy.
        avgAccuracy = 0
        for i in range(len(thisY)): avgAccuracy += (thisY[i] == yHat[i].argmax(1)).sum() / len(thisY[i])
        return {"accuracyValue": (avgAccuracy/len(thisY)).item()}
    
    def startTraining(self, trainLoader, validLoader=None, testLoader=None, ):
        
        # Train/Valid Loops
        if self.optimObj is None: raise Exception("Uninitialized optimizer. Setup optimizer first \
using setupOptimizer() method of the class.")
        
        # Prepare log files.
        self.createLogFiles(openMode="a")
        
        try:
            self.beforeTrain()
            for e in range(self.trainConfig["currEpoch"], self.trainConfig["nEpochs"]+1):
                
                print(F"\n Starting Epoch {e}. Training Phase...")
                stTime = time()
                self.beforeEpoch(epochIdx=e)
                
                # Save learning rate at the start of each epoch to log later.
                epochLogs = {}
                if not self.optimObj is None: epochLogs["LR"] = self.optimObj.param_groups[0]['lr']
                else: epochLogs["LR"] = None

                #  Forward on trainLoader if available.
                if not trainLoader is None:
                    self.beforeTrainEpoch(epchIdx=e)
                    __, trainTime = self.trainInferWithLoader(trainLoader, thisPhase="Train")
                    self.afterTrainEpoch(epchIdx=e)
                else: trainTime = None

                #  Forward on validLoader if available.
                if not validLoader is None:
                    self.beforeValidationEpoch(epochIdx=e)
                    print(F"\n  Starting Validation Phase...")
                    __, validTime = self.trainInferWithLoader(validLoader, thisPhase="Valid")
                    self.afterValidationEpoch(epochIdx=e)
                else: validTime = None

                #  Forward on testLoader if available.
                if not testLoader is None:
                    self.beforeTestingEpoch(epochIdx=e)
                    print(F"\n  Starting Testing Phase...")
                    __, testTime = self.trainInferWithLoader(testLoader, thisPhase="Test")
                    self.afterTestingEpoch(epochIdx=e)
                else: testTime = None

                # Evaluate on trainLoader if required.
                if (not trainLoader is None) and self.trainConfig["evalOnTrain"]:
                    self.beforeEvalEpoch(epochIdx=e)
                    print(F"\n  Starting Evaluation Phase...")
                    __, evalTime = self.trainInferWithLoader(trainLoader, thisPhase="Eval")
                    self.afterEvalEpoch(epochIdx=e)
                else: evalTime = None
                
                # Update current epoch value.
                self.trainConfig["currEpoch"] = e

                # Epoch verbose.
                epochLogs["EpochTime"] = time() - stTime
                epochLogs["TrainTime"], epochLogs["ValidTime"] = trainTime, validTime
                epochLogs["TestTime"], epochLogs["Evaltime"] = testTime, evalTime
                self.verboseEpoch(e, epochLogs=epochLogs)
                
                # Save checkpoints and update meters.
                self.saveCheckPoint(e)
                for k in self.meanMeters: self.meanMeters[k].resetValues()
                
                self.afterEpoch(epochIdx=e)
                pass
            
            self.afterTrain()
        except KeyboardInterrupt: print("Stop Requested. Performing Cleanup...")
        finally: self.cleanUp()
        pass
    
    @getLatency
    def trainInferWithLoader(self, thisLoader, saveResults: bool=False, thisPhase: str="Test"):
        # Train/Valid/Infer/Eval Epochs
        
        # Set model mode.
        if self.usageMode == "Train": 
            self.engineModel.train() if thisPhase == "Train" else self.engineModel.eval()
        else: self.engineModel.eval()
        
        if not self.optimObj is None:
          print(F"   Learning Rate {self.optimObj.param_groups[0]['lr']}")

        for i, (thisX, thisY) in enumerate(thisLoader):
            
            # Forward on batch
#             print("thisY", thisY)

            # Latency helpers.
            stTime = time()

            # Convert X to list if required and put them into device.
            if (not isinstance(thisX, list)) and (not isinstance(thisX, tuple)): thisX = [thisX]
            thisX = [t.to(self.torDevice) for t in thisX]

            # Convert Y to list if required and put them into device.
            if thisY is not None:
                if (not isinstance(thisY, list)) and (not isinstance(thisY, tuple)): thisY = [thisY]
                thisY = [t.to(self.torDevice) for t in thisY]

            # Forward feed on batch.
            self.beforeBatch(batchIdx=i, thisX=thisX, thisY=thisY, thisPhase=thisPhase)
            if self.usageMode == "Train" and thisPhase == "Train": 
                yHat, forwardTime = self.batchForwardTrain(thisX)
#                 print("TRAIN yHAT", yHat)
            else: 
                yHat, forwardTime = self.batchForwardInfer(thisX)
#                 print("INFET yHAT", yHat)
                
            
            # Backward propagate on batch.
            backpropTime = None
            if self.usageMode == "Train" and thisPhase == "Train": 
                thisLosses, backpropTime = self.batchBackward(thisY, yHat)
            elif not thisY is None: __, thisLosses = self.getLoss(thisY, yHat)
            else: thisLosses = None

            # Update batch level latencies.
            batchTime = time() - stTime
            self.meanMeters[thisPhase]["ForwardTime"] = forwardTime
            if not backpropTime is None: self.meanMeters[thisPhase]["BackPropTime"] = backpropTime
            self.meanMeters[thisPhase]["BatchTime"] = batchTime

            # Get metrics.
            thisY = None if thisY is None else [i.detach() for i in thisY]
            #yHat = [i.detach() for i in yHat]
#             print("thisY", thisY, "======================", "yHat", yHat)
            yHat = yHat.detach()
            if thisY is None: thisMetrics = None
            else: thisMetrics = self.getMetrics(thisY, yHat)
            self.afterBatch(batchIdx=i, thisX=thisX, thisY=thisY, thisPhase=thisPhase, yHat=yHat, 
                            thisLosses=thisLosses, thisMetrics=thisMetrics)
            
            # Update mean meters.
            self.updateMeanMeters(thisPhase=thisPhase, newLosses=thisLosses, newMetrics=thisMetrics)

            # Save if required and log the batch.
            if saveResults: 
                self.saveResults(batchIdx=i, inputData=thisX, groundTruth=thisY, modelPreds=yHat, 
                                thisPhase=thisPhase, thisLoss=thisLosses, thisMetrics=thisMetrics)
            self.logAndVerbose(i, thisLosses, thisMetrics, thisPhase=thisPhase)

        # Step scheduler if required.
        if (not self.schedulerObj is None) and thisPhase == "Train" and self.usageMode == "Train":
            self.schedulerObj.step()
        pass
    
    @getLatency
    @torch.no_grad()
    def batchForwardInfer(self, thisX):
        # Batch Forward
        
        # Forward feed data.
        yHat = self.engineModel(*thisX)
        return yHat
    
    @getLatency
    def batchForwardTrain(self, thisX):
        # Batch Forward
        self.optimObj.zero_grad()
        
        # Forward feed data.
        yHat = self.engineModel(*thisX)
        return yHat
    
    @getLatency
    def batchBackward(self, thisY, yHat):
        # Batch Backward

        # Get Loss
        thisLoss, lossForLog = self.getLoss(thisY, yHat)
        # Backward propagate.
        if self.modelScaler is None:
            thisLoss.backward()
            self.optimObj.step()
        else:
            self.modelScaler.scale(thisLoss).backward()
            self.modelScaler.step(self.optimObj)
            self.modelScaler.update()
        
        return lossForLog
    
    def updateMeanMeters(self, thisPhase: str, newLosses: Dict=None, newMetrics: Dict=None) -> None:

        # Get mean trackers.
        trackerObj = self.meanMeters.get(thisPhase, None)
        if trackerObj is None: return

        # Update tracker.
        if not newLosses is None:
            for k in newLosses: trackerObj[k] = newLosses[k]
        if not newMetrics is None:
            for k in newMetrics: trackerObj[k] = newMetrics[k]
        pass

    def writeModelToDisk(self, savePath: str, saveName: str):
        # Create state dict.
        saveDict = {"modelParams": self.engineModel.state_dict(), 
                    "optimParams": self.optimObj.state_dict(), 
                    "schedulerParams": self.schedulerObj.state_dict(),
                    "trainConfig": self.trainConfig.copy()}
        # Save to disk.
        os.makedirs(savePath, exist_ok=True)
        torch.save(saveDict, F"{savePath}/{saveName}")
        pass
    
    def saveCheckPoint(self, epochIdx: int):
        
        # Get monitoring key.
        monitorKey, lessThan = self.trainConfig["monitorKey"], self.trainConfig["saveIfLess"]
        
        # Switch to loss if monitoring metric not available.
        if not self.meanMeters["Valid"].hasKey(monitorKey): 
            print(F"  Key {monitorKey} not found. Will be monitoring loss to save checkpoint.")
            monitorKey = "totalLoss"

        # Save this epoch model.
        self.writeModelToDisk(F"{self.trainConfig['savePath']}/checkPoints/Epoch_{epochIdx}", 
                              saveName="engineParams.pth")

        if lessThan:
            # Skip if no improvement.
            if (not self.trainConfig["currentBest"] is None) and \
                self.meanMeters["Valid"][monitorKey] >= self.trainConfig["currentBest"]: 
                return
            # Save model and update current best.
            print(F"  Lower {monitorKey} achieved. Previous Value {self.trainConfig['currentBest']}\
, Current Value {self.meanMeters['Valid'][monitorKey]}. Saving Model.") 
            self.trainConfig["currentBest"] = self.meanMeters["Valid"][monitorKey]
            savVal = "{:.4f}".format(self.meanMeters["Valid"][monitorKey])
            self.bestModel = copy.deepcopy(self.engineModel)
            self.writeModelToDisk(F"{self.trainConfig['savePath']}/checkPoints/{epochIdx}_{savVal}", 
                                  saveName="engineParams.pth")
        else:
            # Skip if no improvement.
            if (not self.trainConfig["currentBest"] is None) and \
                self.meanMeters["Valid"][monitorKey] <= self.trainConfig["currentBest"]: 
                return
            # Save model and update current best.
            print(F"  Higher {monitorKey} achieved. Previous Value {self.trainConfig['currentBest']}\
, Current Value {self.meanMeters['Valid'][monitorKey]}. Saving Model.")
            self.trainConfig["currentBest"] = self.meanMeters["Valid"][monitorKey]
            savVal = "{:.4f}".format(self.meanMeters["Valid"][monitorKey])
            self.writeModelToDisk(F"{self.trainConfig['savePath']}/checkPoints/{epochIdx}_{savVal}", 
                                  saveName="engineParams.pth")

        pass
    
    def verboseEpoch(self, epochIdx: int, epochLogs: Dict={}) -> None:
        
        # Log to display.
        print(F"  Epoch {epochIdx} Stats.")
        for k in self.meanMeters.keys():
            print(F"   {k} Metrics: ", end="")
            for m in self.meanMeters[k].keys: 
                print(F"{m}:", "{:.4f}".format(self.meanMeters[k][m]), end=", ")
            print("")
        
        # Add column names in log csv.
        if self.__logKeysEpoch is None:
            lKeys = self.meanMeters[list(self.meanMeters.keys())[0]].keys + list(epochLogs.keys())
            self.__logKeysEpoch = "Epoch,Phase,"
            for s in lKeys: self.__logKeysEpoch += F"{s},"
            self.__logFile["Epoch_Logs"].write(F"{self.__logKeysEpoch[:-1]}\n")
            self.__logKeysEpoch = self.__logKeysEpoch[:-1].split(",")
            pass

        # Log to file.
        if "Epoch_Logs" in self.__logFile.keys() and (not self.__logFile["Epoch_Logs"] is None):
            
            for k in self.meanMeters.keys():
                # Get meter.
                thisMeter = self.meanMeters[k]
                if thisMeter is None: continue
                logDict: Dict = thisMeter.getAll()
                logDict.update(epochLogs)

                # Log string to file.
                logString = F"{epochIdx},{k},"
                for m in self.__logKeysEpoch: 
                    if m in ["Epoch", "Phase"]: continue
                    if logDict.get(m, None) is None: logString += "NaN,"
                    else: logString += "{:.4f},".format(logDict[m])  
                self.__logFile["Epoch_Logs"].write(F"{logString[:-1]}\n")
            pass
        
        # Log to wandb if required.
        if not self.trainConfig["useWandB"]: return
        wandbLogs = {}
        for k in self.meanMeters.keys(): 
          thisLogs = { F"EpochLog_{k}_{m}": self.meanMeters[k][m] for m in self.meanMeters[k].keys }
          wandbLogs.update(thisLogs)
        wandb.log(wandbLogs)
        pass

    def logAndVerbose(self, logIdx: int, thisLosses: Dict={}, thisMetrics: Dict={}, 
                      thisPhase: str="Test"):
        
        # Print if required.
        if self.trainConfig["printAfterBatch"] > 0 and logIdx%self.trainConfig["printAfterBatch"] == 0:
            print(F"   Batch {logIdx}")
            print("    Losses : ", end="")
            for k in thisLosses: print(F"{k}:", "{:.4f}".format(thisLosses[k]), end=", ")
            print("\n    Metrics : ", end="")
            for k in thisMetrics: print(F"{k}:", "{:.4f}".format(thisMetrics[k]), end=", ")
            print()
        
        # Add CSV keys once.
        if self.__logKeys is None: 
            
            # Create log keys.
            self.__logKeys = ""
            for k in thisLosses: self.__logKeys += F"{k},"
            for k in thisMetrics: self.__logKeys += F"{k},"
            self.__logKeys = self.__logKeys[:-1]

            # Write keys to file.
            for fileKey, fileObj in self.__logFile.items():
                if fileObj is None or "Epoch_" in fileKey: continue
                fileObj.write(F"{self.__logKeys}\n")
            self.__logKeys = self.__logKeys.split(",")
            pass
        
        # Log to file if required.
        if not self.__logFile is None:
            logMetrics = ""
            for k in self.__logKeys:
                tMetric = thisLosses.get(k, None)
                if tMetric is None: tMetric = thisMetrics.get(k)
                logMetrics += "{:.4f},".format(tMetric)
            self.__logFile[thisPhase].write(F"{logMetrics[:-1]}\n")
        
        # Skip if log to wandb not required.
        if self.trainConfig["logAfterBatch"] < 0 or logIdx%self.trainConfig["logAfterBatch"] != 0: 
            return

        # Log to WandB if required.
        thisMetrics.update(thisLosses)
        thisMetrics = { F"{thisPhase}_{k}": thisMetrics[k] for k in thisMetrics.keys() }
        if self.trainConfig["useWandB"]: wandb.log(thisMetrics)
        pass
    
    # Callbacks
    def saveResults(self, *args, **kwargs): pass
    def beforeTrain(self, *args, **kwargs): pass
    def afterTrain(self, *args, **kwargs): pass

    def beforeEpoch(self, *args, **kwargs): pass
    def afterEpoch(self, *args, **kwargs): pass

    def beforeTrainEpoch(self, *args, **kwargs): pass
    def afterTrainEpoch(self, *args, **kwargs): pass

    def beforeBatch(self, *args, **kwargs): pass
    def afterBatch(self, *args, **kwargs): pass

    def beforeValidationEpoch(self, *args, **kwargs): pass
    def afterValidationEpoch(self, *args, **kwargs): pass

    def beforeTestingEpoch(self, *args, **kwargs): pass
    def afterTestingEpoch(self, *args, **kwargs): pass

    def beforeEvalEpoch(self, *args, **kwargs): pass
    def afterEvalEpoch(self, *args, **kwargs): pass
    pass

In [23]:
class Trainer_Legacy(TrainEngine):

    def __init__(self, engineModel,  projectName: str, runName: str, savePath: str=None, usageMode: str="Train", torDevice: str=None, useWandb: bool=False):
        super().__init__(engineModel, projectName, runName, savePath, usageMode, torDevice, useWandb)
        pass

    def getLoss(self, thisY: Union[List, Tuple, torch.Tensor], yHat: Union[List, Tuple, torch.Tensor]) -> Tuple[torch.Tensor, Dict]:
        
#         print("YHA IN GET LOSS", yHat)

#         yHat = yHat[0]
#         print("YHA[0] IN GET LOSS", yHat[0])
        thisY = thisY[0].argmax(1)
        
        totalLoss = self.lossObj(yHat, thisY)
        outDict = {"totalLoss": totalLoss.item()}
        
        return totalLoss, outDict

    def getMetrics(self, thisY, yHat):
        
        
        thisY = thisY[0].detach().argmax(1).cpu().numpy()
        yHat = yHat.detach().softmax(1).argmax(1).cpu().numpy()
#         print("GET METRIC YHAT[0]", yHat[0])
        
        thisAcc = (thisY == yHat).sum() / len(yHat)
        F1Score = f1_score(thisY, yHat, average=None).mean()
        
        return {"F1_clsHead": F1Score, "Accuracy": thisAcc}
    pass

In [24]:

savePath = os.path.join(os.getcwd(), "models")
EPOCHS = 15
BATCH_SIZE = 16
myEngine = Trainer_Legacy(model, "QUANTIZED_AWARE_TRAINING_2", "raw_resnet50_1", useWandb=False, torDevice="cuda", savePath=savePath)
# myEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0., lossObj=nn.BCEWithLogitsLoss(), nEpochs=EPOCHS, 
#                          learningRate=0.0001, cosineCycle=5)
myEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0.000005, lossObj=lossObj, nEpochs=EPOCHS, 
                         learningRate=0.0000015, cosineCycle=5)
# myEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0., lossObj=None, nEpochs=EPOCHS, 
#                          learningRate=0.0003, cosineCycle=5)


# myEngine.prepareForTrain(engineParams="/opt/infilect/dev/storage2/shyam/Data/KH_USA/KSSB_PastaSauce/splitData/ModelLogs/Sigmoid_KsPs/Run_3/checkPoints/5_0.001/engineParams.pth",
#                          batchSize=BATCH_SIZE, weightDecay=0., lossObj=nn.BCEWithLogitsLoss(), nEpochs=EPOCHS, 
#                          learningRate=0.0001, cosineCycle=5)

myEngine.trainConfig["label2Idx"], myEngine.trainConfig["idx2Label"] = trainSet.label2Idx.copy(), trainSet.idx2Label.copy()

Adjusting learning rate of group 0 to 1.5000e-06.


In [25]:
with open(F"{myEngine.trainConfig['savePath']}/label2Idx.json", 'w') as jF: json.dump(trainSet.label2Idx, jF)
with open(F"{myEngine.trainConfig['savePath']}/idx2Label.json", 'w') as jF: json.dump(trainSet.idx2Label, jF)

In [26]:
# lrSched_ = torch.optim.lr_scheduler.CosineAnnealingLR(myEngine.optimObj, 4, 0.00001)
# lrSched_ = torch.optim.lr_scheduler.StepLR(myEngine.optimObj, last_epoch=EPOCHS, step_size=5, gamma=0.7)
# lrSched = GradualWarmupScheduler(myEngine.optimObj, 1, 4, lrSched_)
# myEngine.setupLRScheduler(lrSched)
# myEngine.optimObj.param_groups[0]["lr"] = 0.00001

In [27]:
# scheduler = torch.optim.lr_scheduler.StepLR(myEngine.optimObj, step_size=3, gamma=0.7)
# lrSched = GradualWarmupScheduler(myEngine.optimObj, 1, 3, scheduler)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(myEngine.optimObj, T_max = 4, eta_min=0.00001, verbose = True)
lrSched = GradualWarmupScheduler(myEngine.optimObj, 1, 4, scheduler)
myEngine.setupLRScheduler(lrSched)

Adjusting learning rate of group 0 to 1.5000e-06.


In [28]:
Engine.startTraining(trainloader1, testloader1)


 Starting Epoch 0. Training Phase...
   Learning Rate 1.875e-07
   Batch 0
    Losses : totalLoss: 0.1731, 
    Metrics : F1_clsHead: 0.3043, Accuracy: 0.4375, 

  Starting Validation Phase...
   Learning Rate 3.75e-07
   Batch 0
    Losses : totalLoss: 0.1679, 
    Metrics : F1_clsHead: 0.3043, Accuracy: 0.4375, 

  Starting Evaluation Phase...
   Learning Rate 3.75e-07
 Shuffling indices...Done.
   Batch 0
    Losses : totalLoss: 0.1713, 
    Metrics : F1_clsHead: 0.3333, Accuracy: 0.5000, 
  Epoch 0 Stats.
   Train Metrics: ForwardTime: 0.1402, BackPropTime: 1.6052, BatchTime: 1.8117, totalLoss: 0.1703, F1_clsHead: 0.3859, Accuracy: 0.5240, 
   Valid Metrics: ForwardTime: 0.0080, BatchTime: 1.5574, totalLoss: 0.1666, F1_clsHead: 0.4080, Accuracy: 0.5339, 
   Test Metrics: 
   Eval Metrics: ForwardTime: 0.0062, BatchTime: 1.4831, totalLoss: 0.1695, F1_clsHead: 0.3527, Accuracy: 0.5096, 
  Lower totalLoss achieved. Previous Value None, Current Value 0.16659169111933028. Saving Model.


 Starting Epoch 8. Training Phase...
   Learning Rate 8.755203820042827e-06
 Shuffling indices...Done.
   Batch 0
    Losses : totalLoss: 0.1706, 
    Metrics : F1_clsHead: 0.4589, Accuracy: 0.5625, 
Adjusting learning rate of group 0 to 1.0000e-05.

  Starting Validation Phase...
   Learning Rate 1e-05
   Batch 0
    Losses : totalLoss: 0.1598, 
    Metrics : F1_clsHead: 0.3766, Accuracy: 0.4375, 

  Starting Evaluation Phase...
   Learning Rate 1e-05
 Shuffling indices...Done.
   Batch 0
    Losses : totalLoss: 0.1575, 
    Metrics : F1_clsHead: 0.4589, Accuracy: 0.5625, 
  Epoch 8 Stats.
   Train Metrics: ForwardTime: 0.0096, BackPropTime: 1.6381, BatchTime: 1.7146, totalLoss: 0.1637, F1_clsHead: 0.3794, Accuracy: 0.5192, 
   Valid Metrics: ForwardTime: 0.0064, BatchTime: 1.5458, totalLoss: 0.1571, F1_clsHead: 0.4840, Accuracy: 0.5792, 
   Test Metrics: 
   Eval Metrics: ForwardTime: 0.0066, BatchTime: 1.5016, totalLoss: 0.1633, F1_clsHead: 0.3955, Accuracy: 0.5288, 
  Lower totalL

   Batch 0
    Losses : totalLoss: 0.1396, 
    Metrics : F1_clsHead: 0.7949, Accuracy: 0.8750, 

  Starting Evaluation Phase...
   Learning Rate 8.755203820042822e-06
 Shuffling indices...Done.
   Batch 0
    Losses : totalLoss: 0.1534, 
    Metrics : F1_clsHead: 0.5636, Accuracy: 0.6250, 
  Epoch 15 Stats.
   Train Metrics: ForwardTime: 0.0100, BackPropTime: 1.6331, BatchTime: 1.7104, totalLoss: 0.1545, F1_clsHead: 0.3794, Accuracy: 0.5192, 
   Valid Metrics: ForwardTime: 0.0069, BatchTime: 1.5494, totalLoss: 0.1482, F1_clsHead: 0.5816, Accuracy: 0.6364, 
   Test Metrics: 
   Eval Metrics: ForwardTime: 0.0069, BatchTime: 1.5013, totalLoss: 0.1536, F1_clsHead: 0.4905, Accuracy: 0.5769, 
  Lower totalLoss achieved. Previous Value 0.14914445720967792, Current Value 0.148232672895704. Saving Model.


In [None]:
myEngine.writeModelToDisk(F"{myEngine.trainConfig['savePath']}/checkPoints/", "FinalParams.pth")
fPath = F"{myEngine.trainConfig['savePath']}/checkPoints/FinalParams.pth"

In [None]:
fPath = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING_1/raw_resnet50_1/checkPoints/Epoch_10/engineParams.pth"

In [None]:
def load_model(model, model_filepath, device):
    model_skeleton = torch.load(model_filepath, device)

    model.load_state_dict(model_skeleton["modelParams"])

    return model

In [None]:
raw_model_resnet50 = load_model(model = model, model_filepath = fPath, device = "cuda")

In [None]:
raw_model_resnet50

x = [["backBone.conv1", "backBone.bn1", "backBone.relu"],
     ["backBone.layer1.0.conv1", "backBone.layer1.0.bn1"], ["backBone.layer1.0.conv2", "backBone.layer1.0.bn2"], ["backBone.layer1.0.conv3", "backBone.layer1.0.bn3", "backBone.layer1.0.relu"], ["backBone.layer1.0.downsample.0", "backBone.layer1.0.downsample.1"], ["backBone.layer1.1.conv1", "backBone.layer1.1.bn1"], ["backBone.layer1.1.conv2", "backBone.layer1.1.bn2"], ["backBone.layer1.1.conv3", "backBone.layer1.1.bn3", "backBone.layer1.1.relu"], ["backBone.layer1.2.conv1", "backBone.layer1.2.bn1"], ["backBone.layer1.2.conv2", "backBone.layer1.2.bn2"], ["backBone.layer1.2.conv3", "backBone.layer1.2.bn3", "backBone.layer1.2.relu"],
     ["backBone.layer2.0.conv1", "backBone.layer2.0.bn1"], ["backBone.layer2.0.conv2", "backBone.layer2.0.bn2"], ["backBone.layer2.0.conv3", "backBone.layer2.0.bn3", "backBone.layer2.0.relu"], ["backBone.layer2.0.downsample.0", "backBone.layer2.0.downsample.1"], ["backBone.layer2.1.conv1", "backBone.layer2.1.bn1"], ["backBone.layer2.1.conv2", "backBone.layer2.1.bn2"], ["backBone.layer2.1.conv3", "backBone.layer2.1.bn3", "backBone.layer2.1.relu"], ["backBone.layer2.2.conv1", "backBone.layer2.2.bn1"], ["backBone.layer2.2.conv2", "backBone.layer2.2.bn2"], ["backBone.layer2.2.conv3", "backBone.layer2.2.bn3", "backBone.layer2.2.relu"], ["backBone.layer2.3.conv1", "backBone.layer2.3.bn1"], ["backBone.layer2.3.conv2", "backBone.layer2.3.bn2"], ["backBone.layer2.3.conv3", "backBone.layer2.3.bn3", "backBone.layer2.3.relu"],
     ["backBone.layer3.0.conv1", "backBone.layer3.0.bn1"], ["backBone.layer3.0.conv2", "backBone.layer3.0.bn2"], ["backBone.layer3.0.conv3", "backBone.layer3.0.bn3", "backBone.layer3.0.relu"], ["backBone.layer3.0.downsample.0", "backBone.layer3.0.downsample.1"], ["backBone.layer3.1.conv1", "backBone.layer3.1.bn1"], ["backBone.layer3.1.conv2", "backBone.layer3.1.bn2"], ["backBone.layer3.1.conv3", "backBone.layer3.1.bn3", "backBone.layer3.1.relu"], ["backBone.layer3.2.conv1", "backBone.layer3.2.bn1"], ["backBone.layer3.2.conv2", "backBone.layer3.2.bn2"], ["backBone.layer3.2.conv3", "backBone.layer3.2.bn3", "backBone.layer3.2.relu"], ["backBone.layer3.3.conv1", "backBone.layer3.3.bn1"], ["backBone.layer3.3.conv2", "backBone.layer3.3.bn2"], ["backBone.layer3.3.conv3", "backBone.layer3.3.bn3", "backBone.layer3.3.relu"],["backBone.layer3.4.conv1", "backBone.layer3.4.bn1"], ["backBone.layer3.4.conv2", "backBone.layer3.4.bn2"], ["backBone.layer3.4.conv3", "backBone.layer3.4.bn3", "backBone.layer3.4.relu"],["backBone.layer3.5.conv1", "backBone.layer3.5.bn1"], ["backBone.layer3.5.conv2", "backBone.layer3.5.bn2"], ["backBone.layer3.5.conv3", "backBone.layer3.5.bn3", "backBone.layer3.5.relu"],
     ["backBone.layer4.0.conv1", "backBone.layer4.0.bn1"], ["backBone.layer4.0.conv2", "backBone.layer4.0.bn2"], ["backBone.layer4.0.conv3", "backBone.layer4.0.bn3", "backBone.layer4.0.relu"], ["backBone.layer4.0.downsample.0", "backBone.layer4.0.downsample.1"], ["backBone.layer4.1.conv1", "backBone.layer4.1.bn1"], ["backBone.layer4.1.conv2", "backBone.layer4.1.bn2"], ["backBone.layer4.1.conv3", "backBone.layer4.1.bn3", "backBone.layer4.1.relu"], ["backBone.layer4.2.conv1", "backBone.layer4.2.bn1"], ["backBone.layer4.2.conv2", "backBone.layer4.2.bn2"], ["backBone.layer4.2.conv3", "backBone.layer4.2.bn3", "backBone.layer4.2.relu"]]

In [None]:
x_raw = [["conv1", "bn1", "relu"],
     ["layer1.0.conv1", "layer1.0.bn1"], ["layer1.0.conv2", "layer1.0.bn2"], ["layer1.0.conv3", "layer1.0.bn3", "layer1.0.relu"], ["layer1.0.downsample.0", "layer1.0.downsample.1"], ["layer1.1.conv1", "layer1.1.bn1"], ["layer1.1.conv2", "layer1.1.bn2"], ["layer1.1.conv3", "layer1.1.bn3", "layer1.1.relu"], ["layer1.2.conv1", "layer1.2.bn1"], ["layer1.2.conv2", "layer1.2.bn2"], ["layer1.2.conv3", "layer1.2.bn3", "layer1.2.relu"],
     ["layer2.0.conv1", "layer2.0.bn1"], ["layer2.0.conv2", "layer2.0.bn2"], ["layer2.0.conv3", "layer2.0.bn3", "layer2.0.relu"], ["layer2.0.downsample.0", "layer2.0.downsample.1"], ["layer2.1.conv1", "layer2.1.bn1"], ["layer2.1.conv2", "layer2.1.bn2"], ["layer2.1.conv3", "layer2.1.bn3", "layer2.1.relu"], ["layer2.2.conv1", "layer2.2.bn1"], ["layer2.2.conv2", "layer2.2.bn2"], ["layer2.2.conv3", "layer2.2.bn3", "layer2.2.relu"], ["layer2.3.conv1", "layer2.3.bn1"], ["layer2.3.conv2", "layer2.3.bn2"], ["layer2.3.conv3", "layer2.3.bn3", "layer2.3.relu"],
     ["layer3.0.conv1", "layer3.0.bn1"], ["layer3.0.conv2", "layer3.0.bn2"], ["layer3.0.conv3", "layer3.0.bn3", "layer3.0.relu"], ["layer3.0.downsample.0", "layer3.0.downsample.1"], ["layer3.1.conv1", "layer3.1.bn1"], ["layer3.1.conv2", "layer3.1.bn2"], ["layer3.1.conv3", "layer3.1.bn3", "layer3.1.relu"], ["layer3.2.conv1", "layer3.2.bn1"], ["layer3.2.conv2", "layer3.2.bn2"], ["layer3.2.conv3", "layer3.2.bn3", "layer3.2.relu"], ["layer3.3.conv1", "layer3.3.bn1"], ["layer3.3.conv2", "layer3.3.bn2"], ["layer3.3.conv3", "layer3.3.bn3", "layer3.3.relu"],["layer3.4.conv1", "layer3.4.bn1"], ["layer3.4.conv2", "layer3.4.bn2"], ["layer3.4.conv3", "layer3.4.bn3", "layer3.4.relu"],["layer3.5.conv1", "layer3.5.bn1"], ["layer3.5.conv2", "layer3.5.bn2"], ["layer3.5.conv3", "layer3.5.bn3", "layer3.5.relu"],
     ["layer4.0.conv1", "layer4.0.bn1"], ["layer4.0.conv2", "layer4.0.bn2"], ["layer4.0.conv3", "layer4.0.bn3", "layer4.0.relu"], ["layer4.0.downsample.0", "layer4.0.downsample.1"], ["layer4.1.conv1", "layer4.1.bn1"], ["layer4.1.conv2", "layer4.1.bn2"], ["layer4.1.conv3", "layer4.1.bn3", "layer4.1.relu"], ["layer4.2.conv1", "layer4.2.bn1"], ["layer4.2.conv2", "layer4.2.bn2"], ["layer4.2.conv3", "layer4.2.bn3", "layer4.2.relu"]]

In [None]:
model

In [None]:
def quantizedModelTrainingPrerequsite(model, nestedList, device):
    #loding the model on cpu device cause cuda doesnt support quantization
    model.to(device)

    #creating a copy of model which will be used for quantization
    fused_model = copy.deepcopy(model)

    # putting model on training model otherwise QAT won't work
    fused_model.train()

    # Fuse the model in place rather manually.
    fused_model = torch.ao.quantization.fuse_modules_qat(fused_model, nestedList, inplace=False)
    return fused_model 

In [None]:
fused_model_resnet50 = quantizedModelTrainingPrerequsite(model = raw_model_resnet50, nestedList = x_raw, device = "cpu")

In [None]:
# Prepare the model for quantization aware training. This inserts observers in
# the model that will observe activation tensors during calibration.
# torch.distributed.init_process_group(backend='nccl', world_size = 1, rank=1)
quantized_model = QuantizedModel(model_fp32 = fused_model_resnet50)
# quantized_model = nn.parallel.DistributedDataParallel(quantized_model)

# Using un-fused model will fail.
# Because there is no quantized layer implementation for a single batch normalization layer.
# quantized_model = QuantizedResNet18(model_fp32=model)
# Select quantization schemes from
# https://pytorch.org/docs/stable/quantization-support.html

In [None]:
quantized_model

In [None]:
def prepareTrainingWithQuantization(model):
    quantization_config = torch.ao.quantization.get_default_qat_qconfig("x86", version=0)
    # Custom quantization configurations
    # quantization_config = torch.quantization.default_qconfig
    # quantization_config = torch.quantization.QConfig(activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8), weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

    model.qconfig = quantization_config

    # Print quantization configurations
    #print(quantized_model.qconfig)

    # https://pytorch.org/docs/stable/_modules/torch/quantization/quantize.html#prepare_qat
    quantized_model = torch.ao.quantization.prepare_qat(model, inplace=False)

    return quantized_model
    pass

In [None]:
quant_model = prepareTrainingWithQuantization(model = quantized_model)

In [None]:
quant_model

In [None]:
savePath = os.path.join(os.getcwd(), "models")
EPOCHS = 10
BATCH_SIZE = 8
# lossObject = nn.CrossEntropyLoss()
lossObj = FocalLossMultiClass()
myQuantEngine = Trainer_Legacy(quant_model, "QUANTIZED_AWARE_TRAINING_1", "quant_resnet50_1", useWandb=False, torDevice="cpu", savePath=savePath)
# myEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0., lossObj=nn.BCEWithLogitsLoss(), nEpochs=EPOCHS, 
#                          learningRate=0.0001, cosineCycle=5)
myQuantEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0.000002, lossObj=lossObj, nEpochs=EPOCHS, 
                         learningRate=0.000015, cosineCycle=5)
# myEngine.prepareForTrain(batchSize=BATCH_SIZE, weightDecay=0., lossObj=None, nEpochs=EPOCHS, 
#                          learningRate=0.0003, cosineCycle=5)


# myEngine.prepareForTrain(engineParams="/opt/infilect/dev/storage2/shyam/Data/KH_USA/KSSB_PastaSauce/splitData/ModelLogs/Sigmoid_KsPs/Run_3/checkPoints/5_0.001/engineParams.pth",
#                          batchSize=BATCH_SIZE, weightDecay=0., lossObj=nn.BCEWithLogitsLoss(), nEpochs=EPOCHS, 
#                          learningRate=0.0001, cosineCycle=5)

myQuantEngine.trainConfig["label2Idx"], myQuantEngine.trainConfig["idx2Label"] = trainSet.label2Idx.copy(), trainSet.idx2Label.copy()

In [None]:
with open(F"{myQuantEngine.trainConfig['savePath']}/label2Idx.json", 'w') as jF: json.dump(trainSet.label2Idx, jF)
with open(F"{myQuantEngine.trainConfig['savePath']}/idx2Label.json", 'w') as jF: json.dump(trainSet.idx2Label, jF)

In [None]:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(myQuantEngine.optimObj, T_max = 1, eta_min=0.00001, verbose = True)
lrSched = GradualWarmupScheduler(myQuantEngine.optimObj, 1, 4, scheduler)
myQuantEngine.setupLRScheduler(lrSched)

In [None]:
myQuantEngine.startTraining(trainloader1, testloader1)

In [None]:
myQuantEngine.writeModelToDisk(F"{myQuantEngine.trainConfig['savePath']}/checkPoints/", "FinalParams.pth")
fPath = F"{myQuantEngine.trainConfig['savePath']}/checkPoints/FinalParams.pth"

In [None]:
x = myQuantEngine.bestModel

In [None]:
modelPath_x = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING/quant_resnet50_2/checkPoints/Epoch_4/engineParams.pth"

In [None]:
engineParams1 = torch.load(modelPath_x, map_location="cuda")
modelParams1 = engineParams1["modelParams"]
quant_model.load_state_dict(engineParams1["modelParams"])

In [None]:
x = x.to('cpu')
x.eval()
model_int8 = torch.ao.quantization.convert(x)

In [None]:
import os
def save_torchscript_model(model, model_dir, model_filename):

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)

In [None]:
model = model_int8
model_dir = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING_1/quant_model"
model_filename = "model_int8_3.pth"
save_torchscript_model(model, model_dir, model_filename)

In [None]:
def load_torchscript_model(model_filepath, device):

    model = torch.jit.load(model_filepath, map_location=device)

    return model

In [None]:
# torch.save(model_int8, F"{myQuantEngine.trainConfig['savePath']}/checkPoints/MODEL_INT8.pth")

# Evaluation

In [None]:
from typing import Union, List, Tuple, Dict
import json
from copy import deepcopy
import os

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, f1_score, accuracy_score
from sklearn.metrics import balanced_accuracy_score, multilabel_confusion_matrix

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.models as torchModels

In [None]:
import torch.nn as nn
import torchvision.models as models
class ClassHead(nn.Module):
    def __init__(self, inpUnits=1280, opOut=2):
        super().__init__()
        self.seqBlock_0 = nn.Sequential(torch.nn.Dropout(0.3), torch.nn.Linear(inpUnits, opOut))
    def forward(self, inpX: torch.Tensor) -> torch.Tensor: return self.seqBlock_0(inpX)
    pass
class BasicModel(nn.Module):
    def __init__(self, model_name: str, nClasses: int):
        super().__init__()
        self.backBone = models.get_model(model_name, weights = "DEFAULT")
        # self.backBone = torchModels.resnet50(weights=torchModels.ResNet50_Weights.IMAGENET1K_V2)
        self.backBone.fc = nn.Identity()
        self.classHead = ClassHead(2048, nClasses)
    def forward(self, inpX: torch.Tensor) -> torch.Tensor:
        embX = self.backBone(inpX)
        out = self.classHead(embX)
        return out

class QuantizedModel(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedModel, self).__init__()
        # QuantStub converts tensors from floating point to quantized.
        # This will only be used for inputs.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point.
        # This will only be used for outputs.
        self.dequant = torch.quantization.DeQuantStub()
        # FP32 model
        self.model_fp32 = model_fp32
    def forward(self, inpX: torch.Tensor) -> torch.Tensor:
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        embX = self.quant(inpX)
        quantX = self.model_fp32(embX)
        out_quant = self.dequant(quantX)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
#         deQuantX = self.dequant(quantX)
        return out_quant 

In [None]:
def savePred(savePath: str, gtName: str, prName: str, savImg: str, imgIdx: int):
    os.makedirs(F"{savePath}/{gtName}/{prName}", exist_ok=True)
    img = Image.open(savImg)
    img.save(F"{savePath}/{gtName}/{prName}/{imgIdx}.png")
    pass

@torch.no_grad()
def onEval(thisModel, evalLoader):
    predVals = { k:[] for k in [50, 60, 70, 80, 90, 95] }
    gtVals = []
    with torch.no_grad():
        for i, s in enumerate(evalLoader):
            print(F" Working on image {i+1} of {len(evalLoader)}")
            
            tX, tY = torch.unsqueeze(s[0][0], 0).to(torDevice), s[1][0].cpu().numpy()
            yH = thisModel(tX)[0].detach().sigmoid().cpu().numpy()
            if tY.sum() == 0: gtVals.append(-1)
            else: gtVals.append(np.argmax(tY))
            # gtVals.append(-1)
            
            yhMax = yH.max()
            for k in predVals.keys(): predVals[k].append(-1)
            if yhMax < 0.5: continue
            
            if yhMax > 0.5: predVals[50][-1] = (np.argmax(yH))
            if yhMax > 0.6: predVals[60][-1] = (np.argmax(yH))
            if yhMax > 0.7: predVals[70][-1] = (np.argmax(yH))
            if yhMax > 0.8: predVals[80][-1] = (np.argmax(yH))
            if yhMax > 0.9: predVals[90][-1] = (np.argmax(yH))
            if yhMax > 0.95: predVals[95][-1] = (np.argmax(yH))
            pass
        pass
    
    return gtVals, predVals

@torch.no_grad()
def onEvalOld(thisModel, evalLoader, idxToLabel, savePath=None):#, validLabels, newLabelIdxMap):
    
    predVals, gtVals = {50:[]}, []
    thisModel.eval()
    with torch.no_grad():
        for i, s in enumerate(evalLoader):
            print(F" Working on image {i+1} of {len(evalLoader)}")
#             print(evalLoader.dataList[i][1])
            
            tX, tY = torch.unsqueeze(s[0][0], 0).to("cpu"), s[1][0].cpu().numpy()
            yH = thisModel(tX).detach().softmax(dim=1).cpu().numpy().argmax(1)
#             print(yH)
            
            if tY.sum() == 0: 
                gtVals.append(-1)
            else: 
                gtVals.append(np.argmax(tY))
                
            predVals[50].append(yH)
            
            if savePath is None: continue
            savePred(savePath, idx2Label[gtVals[-1]], idx2Label[int(yH)], evalLoader.dataList[i][1], i)
        pass
    return gtVals, predVals

@torch.no_grad()
def onEvalSigMax(thisModel, evalLoader, thisSetMap, idxToLabel, labelToIdx, valLabelNames, valLabelIdx):
    
    modelOut, predVals, gtVals = [], {50:[]}, []
    with torch.no_grad():
        for i, s in enumerate(evalLoader):
            print(F" Working on image {i+1} of {len(evalLoader)}")
            
            tX, tY = torch.unsqueeze(s[0][0], 0).to(torDevice), np.argmax(s[1][0].cpu().numpy())
            yHat = thisModel(tX).detach()
            modelOut.append(yHat.cpu())
            yH = yHat.softmax(dim=1).cpu().numpy().argmax(1)[0]
            
            if thisSetMap[tY] in valLabelNames: gtVals.append(labelToIdx[thisSetMap[tY]])
            else: gtVals.append(-1)
            
            # if yHat[0].sigmoid().cpu().numpy()[valLabelIdx].max() < 0.4: predVals[50].append(-1)
            if idxToLabel[yH] in valLabelNames: predVals[50].append(yH)
            else: predVals[50].append(-1)
        pass
    return gtVals, predVals, modelOut

@torch.no_grad()
def onEvalGroup(thisModel, evalLoader, idxToLabel, labelsGroup):#, validLabels, newLabelIdxMap):
    
    predVals, gtVals = {50:[]}, []
    qcCheck = [set(cg) for cg in labelsGroup]
    with torch.no_grad():
        for i, s in enumerate(evalLoader):
            print(F" Working on image {i+1} of {len(evalLoader)}")
            
            tX, tY = torch.unsqueeze(s[0][0], 0).to(torDevice), s[1][0].cpu().numpy()
            yH = thisModel(tX)[0].detach()# .softmax(0).cpu().numpy().argmax()

            if tY.sum() == 0: gtVals.append(-1)
            else: gtVals.append(np.argmax(tY))
            
            for i, cg in enumerate(labelsGroup):
                if not idxToLabel[gtVals[-1]] in qcCheck[i]: continue
                finPred = torch.tensor([yH[testSet.label2Idx[l]] for l in cg]).softmax(0).numpy().argmax()
                finPred = testSet.label2Idx[cg[finPred]]
            predVals[50].append(finPred)
        pass
    return gtVals, predVals

In [None]:
def getMetricJson(gTs, mPs, idx2Label):
    
    outDict, aggMetrics = {}, {k:{} for k in mPs}
    allClasses = np.unique(gTs)
    
    for cI in allClasses:
        outDict[idx2Label[cI+1]] = {}
        for k in mPs:
            outDict[idx2Label[cI+1]][k] = {}
            tPrec = np.nan_to_num((gTs[mPs[k] == cI] == cI).sum() / (mPs[k] == cI).sum())
            tRcall = np.nan_to_num((mPs[k][gTs == cI] == cI).sum() / (gTs == cI).sum())
            outDict[idx2Label[cI+1]][k]["F1Score"] = np.nan_to_num(2 * ((tPrec*tRcall) / (tPrec+tRcall)))
            outDict[idx2Label[cI+1]][k]["precisionVal"] = tPrec
            outDict[idx2Label[cI+1]][k]["recallVal"] = tRcall
            
            aggMetrics[k]["precisionVal"] = aggMetrics[k].get("precisionVal", 0) + tPrec
            aggMetrics[k]["recallVal"] = aggMetrics[k].get("recallVal", 0) + tRcall
            aggMetrics[k]["F1Score"] = aggMetrics[k].get("F1Score", 0) + outDict[idx2Label[cI+1]][k]["F1Score"]
            pass
    
    for k in mPs.keys():
        aggMetrics[k]["Accuracy"] = (gTs == mPs[k]).sum() / len(gTs)
        aggMetrics[k]["BalAccuracy"] = balanced_accuracy_score(gTs, mPs[k])
        
        aggMetrics[k]["precisionVal"] = aggMetrics[k]["precisionVal"] / len(allClasses)
        aggMetrics[k]["recallVal"] = aggMetrics[k]["recallVal"] / len(allClasses)
        aggMetrics[k]["F1Score"] = aggMetrics[k]["F1Score"] / len(allClasses)
        pass
        
    return outDict, aggMetrics

def getMetricJson_Legacy(gTs, mPs, idx2Label):
    
    outDict, aggMetrics = {}, {k:{} for k in mPs}
    allClasses = np.unique(gTs)
    
    for cI in allClasses:
        outDict[idx2Label[cI]] = {}
        for k in mPs:
            outDict[idx2Label[cI]][k] = {}
            tPrec = np.nan_to_num((gTs[mPs[k] == cI] == cI).sum() / (mPs[k] == cI).sum())
            tRcall = np.nan_to_num((mPs[k][gTs == cI] == cI).sum() / (gTs == cI).sum())
            outDict[idx2Label[cI]][k]["F1Score"] = np.nan_to_num(2 * ((tPrec*tRcall) / (tPrec+tRcall)))
            outDict[idx2Label[cI]][k]["precisionVal"] = tPrec
            outDict[idx2Label[cI]][k]["recallVal"] = tRcall
            
            aggMetrics[k]["precisionVal"] = aggMetrics[k].get("precisionVal", 0) + tPrec
            aggMetrics[k]["recallVal"] = aggMetrics[k].get("recallVal", 0) + tRcall
            aggMetrics[k]["F1Score"] = aggMetrics[k].get("F1Score", 0) + outDict[idx2Label[cI]][k]["F1Score"]
            pass
    
    for k in mPs.keys():
        aggMetrics[k]["Accuracy"] = (gTs == mPs[k]).sum() / len(gTs)
        aggMetrics[k]["BalAccuracy"] = balanced_accuracy_score(gTs, mPs[k])
        
        aggMetrics[k]["precisionVal"] = aggMetrics[k]["precisionVal"] / len(allClasses)
        aggMetrics[k]["recallVal"] = aggMetrics[k]["recallVal"] / len(allClasses)
        aggMetrics[k]["F1Score"] = aggMetrics[k]["F1Score"] / len(allClasses)
        pass
        
    return outDict, aggMetrics

In [None]:
import numpy as np
import os
import cv2
import csv

class ClassificationMetrics:
    """Classification metrics class. Create an object of this class
    to individually track metric of different models/datasets/runs.
    """
    def __init__(self, name, classes):
        """Initialise an object of this class.
        
        Args:
            name (str): A name given to a metrics object. Currently not being used for anything but
            could be used for uniquely identify a metrics object by name
            classes (list): A list of names of classes.
        """
        self.name = name
        self.classes = sorted(classes)
        # Check if all elements in the classes list is unique.
        assert len(set(self.classes)) == len(classes)

        self.num_classes = len(classes)
        self.class2int = {c:i for i,c in enumerate(self.classes)}
        self.int2class = {self.class2int[c]:c for c in self.class2int}
        # Initialise the confusion matrix with all zeros
        self.classification_confusion_matrix = np.zeros((len(self.classes),len(self.classes)))

        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.fontScale = 0.8
        self.fontColor = (255, 0, 0)
        self.lineType = 2

    def create_cell(self,color_value,text,shape):
        """Create a cell image to be a part of confusion matrix image. This is just a square
        patch of some given size with a given color value and the text written in the center
        
        Args:
            color_value (str): Background Color of the cell
            text (str): Text to be displayed on the cell
            shape (list): Shape of cell in format [h,w]
        
        Returns:
            numpy.ndarray: A cell image of shape [h,w,3].
        """

        color_value = int(np.clip(color_value,0,255)) if not np.isnan(color_value) else 127
        text = str(text)
        cell_shape = (100,100)
        (label_width, label_height), baseline = cv2.getTextSize(text, self.font, self.fontScale, self.lineType)
        label_patch = np.ones((cell_shape[0], cell_shape[1], 3), np.uint8)*color_value
        
        textX = int((cell_shape[1] - label_width) / 2)
        textY = int((cell_shape[0] + label_height) / 2)
        cv2.putText(label_patch, text, (textX, textY), self.font, self.fontScale, self.fontColor, self.lineType)
        cell = cv2.resize(label_patch,(shape[1],shape[0]))
        return cell
        
    def generate_confusion_matrix_image(self,image_save_path):
        """Generate an image representation of the confusion matrix. This could be displayed in the
        tensorboard while training or could be used while showcasing model performance. This confusion
        matrix also includes the number of samples along with the color coding of cells
        
        Args:
            image_save_path (str): Path to store the confusion matrix image
        
        Returns:
            str: image_save_path, same as what was passed as argument
        """
        confusion_matrix_norm = self.classification_confusion_matrix/self.classification_confusion_matrix.sum(axis=1).reshape(-1,1)
        cell_shape = (80,80)
        confusion_matrix_image_header = np.hstack([self.create_cell(255,self.num_classes,cell_shape)] + [self.create_cell(150,i,cell_shape) for i in range(self.num_classes)])
        confusion_matrix_image = [confusion_matrix_image_header]

        for i in range(self.num_classes):
            confusion_matrix_image_row = [self.create_cell(150,i,cell_shape)]
            for j in range(self.num_classes):
                confusion_matrix_image_cell = self.create_cell(255-confusion_matrix_norm[i][j]*255,int(self.classification_confusion_matrix[i][j]),cell_shape)
                confusion_matrix_image_row.append(confusion_matrix_image_cell)
            confusion_matrix_image.append(np.hstack(confusion_matrix_image_row))
        confusion_matrix_image = np.vstack(confusion_matrix_image)

        if not os.path.exists(os.path.dirname(image_save_path)):
            os.makedirs(os.path.dirname(image_save_path))

        cv2.imwrite(image_save_path,confusion_matrix_image)
        print("Save confusion matrix image at",image_save_path)
        return image_save_path

    def add_sample(self, gt_class, pred_class):
        """Add a sample to the confusion matrix. A sample is one data point containing
        a ground truth and a prediction. This function also check if the ground truth and
        prediction exists in the classes or not
        
        Args:
            gt_class (str): Name of the ground truth class
            pred_class (str): Name of the predicted class
        """
        assert gt_class in self.classes, "class {} not found in classes".format(gt_class)
        assert pred_class in self.classes, "class {} not found in classes".format(pred_class)

        self.classification_confusion_matrix[self.class2int[gt_class]][self.class2int[pred_class]] += 1
    
    def get_precision_recall_accuracy_support(self, average_mode='none'):
        """Get performance metrics of current state of the object. This function will return
        precision, recall, accuracy and support. By default it will return a list for each of
        those of size len(classes) for per class metric. Optionally it can also return average
        over all classes
        
        Args:
            average_mode (bool, optional): If this is set to True, average over all classes will be returned
            If set to false [default], a list of per class metric will be returned. Defaults to False.
        
        Returns:
            list: list of [precision,recall,accuracy,support]
        """
        # Row wise and column wise sum for calculating precision and recall
        hr_sum = np.sum(self.classification_confusion_matrix,axis=1)
        vr_sum = np.sum(self.classification_confusion_matrix,axis=0)

        # Lists for per class precision, recall, accuracy, support
        precisions = []
        recalls = []
        accuracies = []
        supports = []

        for i,c in enumerate(self.classes):
            # tp: True positive is the number that appears on the diagonal

            tp = self.classification_confusion_matrix[i][i]

            # fp: For a given prediction i, False positive are all the predictions that were not actually i.
            fp = vr_sum[i]-tp

            # fn: For a ground truth i, False negative is all the ground truths that were not predicted as i.
            fn = hr_sum[i]-tp
            
            if tp == 0:
                precision = 0.0
                recall = 0.0
                accuracy = 0.0
            else:
                precision = (tp/(tp+fp))
                recall = (tp/(tp+fn))
                accuracy = tp/(tp+fp+fn)
            
            support = (tp+fn)

            if support==0:
                precision = 1.0
                recall = 1.0
                accuracy = 1.0 

            precisions.append(precision)
            recalls.append(recall)
            accuracies.append(accuracy)
            supports.append(support)

        if average_mode == 'weighted':
            # compute weighted average of accuracies, precision and recall.
            weighted_average_accuracy, weighted_average_precision, weighted_average_recall = 0., 0., 0.
            total_instances = sum(supports)
            for num_instances, accuracy, precision, recall in zip(supports, accuracies, precisions, recalls):
                class_fraction = num_instances/total_instances
                weighted_accuracy = class_fraction * accuracy
                weighted_precision = class_fraction * precision
                weighted_recall = class_fraction * recall
                weighted_average_accuracy += weighted_accuracy
                weighted_average_precision += weighted_precision
                weighted_average_recall += weighted_recall

            return weighted_average_precision, weighted_average_recall, weighted_average_accuracy, np.mean(supports)

        elif average_mode == 'unweighted':
            return np.mean(precisions), np.mean(recalls), np.mean(accuracies), np.mean(supports)

        return np.array(precisions), np.array(recalls), np.array(accuracies), np.array(supports)

    def create_confusion_matrix_report(self, csv_path):
        """Create a CSV report of the confusion matrix. This report will also have per class
        accuracy, precision, recall and support.
        
        Args:
            csv_path (str): Path to store the csv report file
        
        Returns:
            str: csv_path, same as passed in parameters
        """
        header = ["confusion matrix"] + self.classes + ["ground truth","precision", "recall", "accuracy","support","prediction","num predictions"]
        rows = [header]

        precisions, recalls, accuracies, supports = self.get_precision_recall_accuracy_support()

        for i,c in enumerate(self.classes):
            row = [c]+[int(x) for x in self.classification_confusion_matrix[i]]
            precision = precisions[i]
            recall = recalls[i]
            accuracy = accuracies[i]
            support = supports[i]

            row.append(c)
            row.append(precision)
            row.append(recall)
            row.append(accuracy)
            row.append(support)

            pred_index = np.argmax(self.classification_confusion_matrix[i])
            row.append(self.classes[pred_index])
            row.append(self.classification_confusion_matrix[i][pred_index])
        #     print(len(row))
            rows.append(row)
        
        footer = [""] + [""]*self.num_classes + ["Average",np.mean(precisions), np.mean(recalls), np.mean(accuracies),"","",""]
        rows.append(footer)
        
        if not os.path.exists(os.path.dirname(csv_path)):
            os.makedirs(os.path.dirname(csv_path))

        with open(csv_path, 'w') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerows(rows)
        print("Saved classification report at",csv_path)

        return csv_path

In [None]:
modelPath = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING_1/quant_model"
dataPath = "/opt/infilect/dev/dataset/resize"
jsnFilePath = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING_1/quant_resnet50_1"
savePath = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZED_AWARE_TRAINING_1/VIZ1"
with open(f"{jsnFilePath}/idx2Label.json", 'r') as jF: idx2Label = json.load(jF)
with open(f"{jsnFilePath}/label2Idx.json", 'r') as jF: label2Idx = json.load(jF)
idx2Label = {int(k):v for k, v in idx2Label.items()}

# with open("/opt/infilect/dev/storage2/shyam/Data/KH_USA/DPD/splitData/ModelLogs/SigMax/Run_1/khusa-brand-dpd-sigmax-1/tag_to_label.json", 'r') as jF: label2Idx = json.load(jF)
# idx2Label = {int(v): k for k, v in label2Idx.items()}


In [None]:
finalLabels = list(idx2Label.values())

In [None]:
finalLabels, len(finalLabels)

In [None]:
testAug = DataAugmentor(0., 0., 0., 0., 0., 0., 0., 0., normImage=True)
# testSet = SimpleLoader(F"{dataPath}/test", augObject=testAug, 
#                validLbls=set(), cacheImgs=True, label2Idx=None, idx2Label=None)
test_set = SimpleLoader(F"{dataPath}/test", augObject=testAug, 
               validLbls=set(), cacheImgs=False, label2Idx=label2Idx.copy(), idx2Label=idx2Label.copy())

In [None]:
testloader = torch.utils.data.DataLoader(test_set, batch_size = 1, shuffle = True)

In [None]:
modelPath = F"{modelPath}/model_int8_3.pth"

In [None]:
# myModel = BasicModel("resnet50", len(idx2Label)).to("cuda")

In [None]:
def load_torchscript_model(model_filepath, device = "cpu"):

    model = torch.jit.load(model_filepath, map_location=device)

    return model

In [None]:
myModel1 = load_torchscript_model(model_filepath = modelPath)

In [None]:
myModel1

In [None]:
# engineParams = torch.load(modelPath, map_location="cpu")
# modelParams = engineParams["modelParams"]
# myModel.load_state_dict(engineParams["modelParams"])

In [None]:
def evaluate_model(model, test_loader, device, criterion=None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy

In [None]:
_, x = evaluate_model(model = myModel1, test_loader = test_set, device = "cpu", criterion=None)

In [None]:
validClassIdx = set([v for k, v in label2Idx.items() if k in finalLabels])
validClassIdx

In [None]:
os.makedirs(savePath, exist_ok = True)
gV_N, pV_N = onEvalOld(myModel1, test_set, idx2Label, f"{savePath}/test_eval_quant_3")

In [None]:
gV_N = np.asarray(gV_N)
pV_N = {k:np.asarray(pV_N[k]) for k in pV_N.keys()}

In [None]:
pV_N[50] = pV_N[50].flatten()

In [None]:
for k in pV_N.keys(): print(k, (pV_N[k] == gV_N).sum() / len(gV_N))

In [None]:
newIdxLabelMap = deepcopy(idx2Label)
# newIdxLabelMap[-1] = "notImportant"

In [None]:
isLegacy = True

In [None]:
if isLegacy: micMetrics, macMatrics = getMetricJson_Legacy(gV_N, pV_N, newIdxLabelMap)
else: micMetrics, macMatrics = getMetricJson(gV_N, pV_N, newIdxLabelMap)

In [None]:
micMetrics

In [None]:
macMatrics

In [None]:
for k in pV_N.keys():
    metClass = ClassificationMetrics("demo", list(newIdxLabelMap.values()))
    if isLegacy: 
        for g, p in zip(gV_N, pV_N[k]): metClass.add_sample(newIdxLabelMap[g], newIdxLabelMap[p])
    else: 
        for g, p in zip(gV_N, pV_N[k]): metClass.add_sample(newIdxLabelMap[g+1], newIdxLabelMap[p+1])
    avgPrec, avgRcall, avgAcc = [np.mean(i) for i in metClass.get_precision_recall_accuracy_support()[:3]]
    print(" Threshhold :", k, "Precision", avgPrec, " Recall", avgRcall, " Accuracy", avgAcc)

In [None]:
metClass = ClassificationMetrics("demo", list(newIdxLabelMap.values()))
for g, p in zip(gV_N, pV_N[50]): 
    if isLegacy: metClass.add_sample(newIdxLabelMap[g], newIdxLabelMap[p])
    else: metClass.add_sample(newIdxLabelMap[g+1], newIdxLabelMap[p+1])
metClass.create_confusion_matrix_report(F"{savePath}/quant_resnet50_2.csv")

# Trace Model

In [None]:
def load_model2(model, model_filepath, device):
    model_skeleton = torch.load(model_filepath, device)
    print(model_skeleton)

    model.load_state_dict(model_skeleton["modelParams"])

    return model

In [None]:
model = quantized_model
model_filepath = "/opt/infilect/dev/repos/image_classification_quant/models/QUANTIZATION/quant_model_2/checkPoints/Epoch_2/engineParams.pth"
device = "cuda"

quant_model = load_model2(model, model_filepath, device)

In [None]:
import os
def trace_model(modelArch, example, pathToSave:str, name:str):
    
    modelArch.to("cuda")
    
    modelArch.eval()
    
    model = torch.jit.trace(modelArch, torch.randn(example))
    
    model.save(os.path.join(pathToSave, f"{name}.pth"))

In [None]:
modelArch = model_int8
example = (1, 3, 1125, 1500)
pathToSave = "/opt/infilect/dev/repos/image_classification_quant/models/traced"
name = "quant_int8_resnet50"
trace_model(modelArch, example, pathToSave, name)


In [None]:
from pathlib import Path
import matplotlib.pyplot as plt

def trainTestPlot(plot, train_accu, test_accu, train_losses, test_losses, model_name):

    if plot:
        Path('plot/').mkdir(parents=True, exist_ok=True)
        plot1 = plt.figure(1)
        plt.plot(train_accu, '-o')
        plt.plot(test_accu, '-o')
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.legend(['Train','Test'])
        plt.title('Train vs Test Accuracy')            
        plt.savefig('plot/'+model_name+'_train_test_acc.png')

        plot2 = plt.figure(2)
        plt.plot(train_losses,'-o')
        plt.plot(test_losses,'-o')
        plt.xlabel('epoch')
        plt.ylabel('losses')
        plt.legend(['Train','Test'])
        plt.title('Train vs Test Losses')
        plt.savefig('plot/'+model_name+'_train_test_loss.png')

In [None]:
trainTestPlot(plot = True, train_accu = train_acc, test_accu = val_acc, train_losses = train_loss, test_losses = val_loss, model_name = "resnet50")

In [None]:
n_epochs = 50
print_every = 10
valid_loss_min = np.Inf
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(trainloader)
model.to(device)

for epoch in range(1, n_epochs+1):
    
    model.train()
    
    running_loss = 0.0
    correct = 0
    total=0
    print(f'Epoch {epoch}\n')
    for batch_idx, (data_, target_) in enumerate(trainloader):
        data_, target_ = data_[0].to(device), target_[0.to(device)
        optimizer.zero_grad()
        data_ = quant(data_)
        outputs = model(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()
        quantizer.step()

        running_loss += loss.item()
        _,pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred==target_).item()
        total += target_.size(0)
        if (batch_idx) % 100 == 0:
            print ('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss/total_step)
    print(f'\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct/total):.4f}')
    batch_loss = 0
    total_t=0
    correct_t=0
    with torch.no_grad():
        model.eval()
        for data_t, target_t in (testloader):
            data_t, target_t = data_t.to(device), target_t.to(device)
            outputs_t = model(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _,pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t==target_t).item()
            total_t += target_t.size(0)
        val_acc.append(100 * correct_t/total_t)
        val_loss.append(batch_loss/len(testloader))
        network_learned = batch_loss < valid_loss_min
        print(f'validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t/total_t):.4f}\n')

        
        if network_learned:
            valid_loss_min = batch_loss
            torch.save(model.state_dict(), f'/opt/infilect/dev/repos/image_classification/models/resnet_{epoch}.pth')
            print('Improvement-Detected, save-model')
    
    lrSched.step()

In [None]:
from torchvision.models import *

def create_model(num_classes=10):

    # The number of channels in ResNet18 is divisible by 8.
    # This is required for fast GEMM integer matrix multiplication.
    # model = torchvision.models.resnet18(pretrained=False)
    model = resnet18(num_classes=num_classes, pretrained=False)

    # We would use the pretrained ResNet18 as a feature extractor.
    # for param in model.parameters():
    #     param.requires_grad = False

    # Modify the last FC layer
    # num_features = model.fc.in_features
    # model.fc = nn.Linear(num_features, 10)

    return model

In [None]:
model1 = create_model(num_classes=2)

In [None]:
model1

In [None]:
for module_name, module in model1.named_children():
    print(module_name)

In [None]:
fused_model1 = torch.ao.quantization.fuse_modules_qat(model1, [["conv1", "bn1", "relu"]], inplace=False)

In [None]:
fused_model1

In [None]:
import torch
yHAT = [torch.tensor([[-0.1099,  0.0676],
        [ 0.0213, -0.1472]], device='cuda:0')]

In [None]:
yHAT[0]

In [None]:
x