In [41]:
import mimetypes
from pathlib import Path
import os
import sys
import random
import shutil
from tempfile import TemporaryDirectory
import cv2
import matplotlib.pyplot as plt
import json
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from collections import Counter
import torch as tc
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.optim as optim
import albumentations as A
import torch.nn as nn
from tqdm.notebook import trange, tqdm

In [2]:
dataDir = TemporaryDirectory(dir='/content/sample_data', prefix='github_')
!git clone https://github.com/waruna-wickramasingha/landmark-detection.git $dataDir.name
!git -C $dataDir.name/landmarks checkout feature/datarestructure

Cloning into '/content/sample_data/github_ouhb_gxq'...
remote: Enumerating objects: 1913, done.[K
remote: Counting objects: 100% (6/6), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 1913 (delta 2), reused 4 (delta 1), pack-reused 1907[K
Receiving objects: 100% (1913/1913), 1.91 GiB | 34.36 MiB/s, done.
Resolving deltas: 100% (46/46), done.
Updating files: 100% (1663/1663), done.
Updating files: 100% (3310/3310), done.
Branch 'feature/datarestructure' set up to track remote branch 'feature/datarestructure' from 'origin'.
Switched to a new branch 'feature/datarestructure'


In [3]:
def getFilesList(path: str, desiredExtensionList, recursion=False):
  fileList = []
  path = Path(path)
  if recursion == False:
    for x in path.iterdir():
      if x.suffix.lower() in desiredExtensionList:
        fileList.append(x)
  else:
    for x in path.glob('**/*'):
      if x.suffix.lower() in desiredExtensionList:
        fileList.append(x)
  return fileList

In [4]:
def getImageNameToClassMap(annotationsDir:str):
  aggrImageNameToClassMap = dict()
  annotationsDir = Path(annotationsDir)
  try:
    for x in annotationsDir.iterdir():
      if x.suffix.lower() == ".json":
        with open(x, 'r') as f:
          annotationDict = json.load(f)
          annList = list(map(dict, annotationDict.values()))
          image_name_to_class_map = { it['filename'] : it['regions'][0]['region_attributes']['class'].lower() for it in annList }
          aggrImageNameToClassMap.update(image_name_to_class_map)
  except:
    raise Exception("Failed to aggregate annotated json files")
  finally:
    print("Annotation files aggragated successfully!")
  return aggrImageNameToClassMap

In [5]:
im_name_to_class = getImageNameToClassMap(dataDir.name + '/annotations')
len(im_name_to_class)

Annotation files aggragated successfully!


1635

In [6]:
all_classes = set(im_name_to_class.values())
all_classes

{'aldi wallisdown',
 'aniba',
 'art studios',
 'arts bar',
 'arts bu library',
 'ashley automotive',
 'auds',
 'aush',
 'auss',
 'autg',
 'baboo ji',
 'bailey point',
 'beales',
 'bellaton house',
 'block a and block b arts university',
 'bobbys',
 'bournemouth and poole college',
 'bournemouth gateway building',
 'bu executive business centre',
 'bu international college',
 'bu lansdowne',
 'bu student house',
 'buch',
 'burley court hotel',
 'buta',
 'buth',
 'careers center',
 'cbd flower shop',
 'church',
 'coop store',
 'court royal',
 'courtleigh manor',
 'cranborne house',
 'design and engineering innovation center',
 'dorchester house',
 'dorset house',
 'east cliff urc church',
 'enterprise house',
 'fairways care home',
 'fern arrow roundabout',
 'fusion building',
 'gorscliff court',
 'hilton',
 'home park',
 'hot rocks',
 'iq building',
 'jakey house',
 'kimmeridge house',
 'lester aldridge',
 'lidl bournemouth',
 'lush',
 'mccarthy stone head office',
 'minton lodge hotels

In [7]:
len(all_classes)

99

In [8]:
for k,v in im_name_to_class.items():
  if v == '':
    print(k)

In [9]:
beforeAugImagesRoot = TemporaryDirectory(dir='/content/sample_data', prefix='BeforeAugmentation_')
testImagesRoot = TemporaryDirectory(dir='/content/sample_data', prefix='Test_')

In [10]:
image_extensions = [k for k,v in mimetypes.types_map.items() if 'image/' in v]
random.seed(13)

def reArrangeImagesIntoTrainTest(lm_dir:str, trainDir:str, testDir:str, trainSplit=0.75):
  """
  This method would split the original list of images provided in imagesList 
  by moving them into trainDir and testDir keeping the original file name.
  """
  global image_extensions
  imagesList = getFilesList(lm_dir, image_extensions)
  imagesList = np.array(imagesList)
  randSelections = np.random.permutation(len(imagesList))

  trainTestCutPoint = int(len(imagesList)*trainSplit)
  trainImages = imagesList[randSelections[:trainTestCutPoint]]
  testImages = imagesList[randSelections[trainTestCutPoint:]]

  lm_dir_name = lm_dir.split(os.sep)[-1]

  try:
    #Re-arrange train images
    lm_train_dir = TemporaryDirectory(dir=trainDir, prefix=lm_dir_name)
    print("Creating Train directory={}".format(lm_train_dir.name))
    for p in trainImages:
      imageName = str(p).split(os.sep)[-1]
      newPath = lm_train_dir.name + os.sep + imageName
      shutil.move(p, newPath)

    #Re-arrange test images
    lm_test_dir = TemporaryDirectory(dir=testDir, prefix=lm_dir_name)
    print("Creating Test directory={}".format(lm_test_dir.name))
    for p in testImages:
      imageName = str(p).split(os.sep)[-1]
      newPath = lm_test_dir.name + os.sep + imageName
      shutil.move(p, newPath)
  except:
    raise Exception("Failed to move images into train={} and test={} directories".
                    format(lm_train_dir.name, lm_test_dir.name))
  return (lm_train_dir, lm_test_dir)

In [11]:
landmarks_dir = Path(dataDir.name+os.sep+'landmarks')
tempdirs = []
for lm in landmarks_dir.glob('**/'):
  if lm == landmarks_dir:
    continue
  print("Original direcotry={}".format(lm))
  tempdirs.append(reArrangeImagesIntoTrainTest(str(lm), beforeAugImagesRoot.name, testImagesRoot.name, 0.75))

Original direcotry=/content/sample_data/github_ouhb_gxq/landmarks/Court Royal
Creating Train directory=/content/sample_data/BeforeAugmentation_qsize3tf/Court Royal71rf0ppu
Creating Test directory=/content/sample_data/Test_6j2gl5ac/Court Royal5_oxrazk
Original direcotry=/content/sample_data/github_ouhb_gxq/landmarks/Beales
Creating Train directory=/content/sample_data/BeforeAugmentation_qsize3tf/Bealestk1zrwoo
Creating Test directory=/content/sample_data/Test_6j2gl5ac/Beales31ckztzg
Original direcotry=/content/sample_data/github_ouhb_gxq/landmarks/bournemouth and poole college
Creating Train directory=/content/sample_data/BeforeAugmentation_qsize3tf/bournemouth and poole college9u1kndbo
Creating Test directory=/content/sample_data/Test_6j2gl5ac/bournemouth and poole collegesu0wna_4
Original direcotry=/content/sample_data/github_ouhb_gxq/landmarks/Sprinkles Gelato
Creating Train directory=/content/sample_data/BeforeAugmentation_qsize3tf/Sprinkles Gelatomdn6t23f
Creating Test directory=/c

In [12]:
getFilesList(beforeAugImagesRoot.name, image_extensions, recursion=True)

[PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9/The Kings Arms Bar_4.JPG'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9/The Kings Arms Bar_6.JPG'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9/The Kings Arms Bar_1.JPG'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9/The Kings Arms Bar_3.JPG'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9/The Kings Arms Bar_2.JPG'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/Mcarthy Stone Head Office2qnz1b58/20230327_152406.jpg'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/Mcarthy Stone Head Office2qnz1b58/20230327_152449.jpg'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/Mcarthy Stone Head Office2qnz1b58/20230327_152721.jpg'),
 PosixPath('/content/sample_data/BeforeAugmentation_qsize3tf/Mcart

In [13]:
getFilesList(testImagesRoot.name, image_extensions, recursion=True)

[PosixPath('/content/sample_data/Test_6j2gl5ac/Ashley Automotivebijh1bah/Ashley Automotive_4.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/Ashley Automotivebijh1bah/Ashley Automotive_3.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7471.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7475.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7484.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7495.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7491.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7468.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7473.JPG'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/20230330_174644.jpg'),
 PosixPath('/content/sample_data/Test_6j2gl5ac/dorset housetmuy5wss/IMG_7480.JPG'),
 PosixPath('/content/sample_data/Test

Lets do Augmentation on the training images

In [14]:
def doAugmentation(inputDir: str, outputDirRoot: str, augmentationPipeline, numberOfTargetSamples):
  """
  This function is to perform image augmentations for the images present at the 'inputDir' 
  into the root location as specified at 'outputDirRoot' using the function 'augmentationPipeline' 
  which takes in an image and output an augmented image. The new directory will be created with the same prefix
  as the input directory and it will contain all the original images plus a number of augmented images 
  such that 'numberOfTargetSamples' is reached
  """
  inputDirPath = Path(inputDir)
  originalImagePaths = []

  global image_extensions
  # im_ext = [k for k,v in mimetypes.types_map.items() if 'image/' in v]

  for x in inputDirPath.iterdir():
      if x.suffix.lower() in image_extensions:
          originalImagePaths.append(x)

  origLMDir = inputDir.split(os.sep)[-1]

  try:
      augLMDir = TemporaryDirectory(dir=outputDirRoot, prefix="{}_Aug_".format(origLMDir))
  except:
      raise Exception("Error creating temp dir for augmentations")

  print("Original Directory={}, Augmented Directory={}".format(inputDir, augLMDir.name))

  for im in originalImagePaths:
      imageName = str(im).split(os.sep)[-1]
      outputImagePath = augLMDir.name + os.sep + imageName
      
      try:
          shutil.copy(im, outputImagePath)
      except:
          raise Exception("Failed to copy original file {} to {}".format(im, outputImagePath))

  augCandidates = random.choices(originalImagePaths, k=(numberOfTargetSamples-len(originalImagePaths)))
  aug_ind = 0

  for impath in augCandidates:
      im = cv2.imread(str(impath))
      im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

      augmentedImage = augmentationPipeline(im)
      imageName, imageExt = str(impath).split(os.sep)[-1].split('.')
      outputImagePath = augLMDir.name + os.sep + 'aug' + str(aug_ind) + '_' + imageName + '.' + imageExt
      aug_ind += 1
      try:
          cv2.imwrite(outputImagePath, augmentedImage)
      except:
          raise Exception("Failed to save augmented image to {}".format(outputImagePath))
      
  return augLMDir

In [15]:
transform = A.Compose(
        [A.RandomRotate90(),
        A.RandomBrightnessContrast(brightness_limit=0.8, contrast_limit=0.4, p=0.7),
        # A.CLAHE(p=0.7),
        A.HorizontalFlip(p=0.7),
        # A.Transpose(),
        # A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=0.75),
        A.Blur(blur_limit=3)])

def augmentationPipeline(image):
  return transform(image=image)['image']

In [16]:
augmentationDirRoot = TemporaryDirectory(dir='/content/sample_data', prefix='Augmentations')
augmentationDirRoot.name

'/content/sample_data/Augmentationse8vmr8n5'

In [17]:
%%time
original_dir_before_aug = Path(beforeAugImagesRoot.name)
augmented_dirs = []
for lm in original_dir_before_aug.glob('**/'):
  if lm == original_dir_before_aug:
    continue
  augmented_dirs.append(doAugmentation(str(lm), augmentationDirRoot.name, augmentationPipeline, 50))

Original Directory=/content/sample_data/BeforeAugmentation_qsize3tf/The Kings Arms Bar95hoodg9, Augmented Directory=/content/sample_data/Augmentationse8vmr8n5/The Kings Arms Bar95hoodg9_Aug_a8_8kng2
Original Directory=/content/sample_data/BeforeAugmentation_qsize3tf/Mcarthy Stone Head Office2qnz1b58, Augmented Directory=/content/sample_data/Augmentationse8vmr8n5/Mcarthy Stone Head Office2qnz1b58_Aug_ax7xtl0g
Original Directory=/content/sample_data/BeforeAugmentation_qsize3tf/Ashley Automotive2cdrtzcn, Augmented Directory=/content/sample_data/Augmentationse8vmr8n5/Ashley Automotive2cdrtzcn_Aug_8qt7utp7
Original Directory=/content/sample_data/BeforeAugmentation_qsize3tf/BU Christchurch Househk63_g79, Augmented Directory=/content/sample_data/Augmentationse8vmr8n5/BU Christchurch Househk63_g79_Aug_v9b3szwx
Original Directory=/content/sample_data/BeforeAugmentation_qsize3tf/Baboo Jia6xz99gk, Augmented Directory=/content/sample_data/Augmentationse8vmr8n5/Baboo Jia6xz99gk_Aug_zekq97dm
Origina

In [18]:
i2c = list(all_classes)
c2i = {v:i for i,v in enumerate(i2c)}

In [19]:
i2c

['bournemouth gateway building',
 'east cliff urc church',
 'sound circus bar',
 'dorchester house',
 'hilton',
 'bu executive business centre',
 'oceanorium',
 'sir micheal cobhem library',
 'timebomb tattoo studio',
 'poole house cafe',
 'kimmeridge house',
 'hot rocks',
 'park central',
 'iq building',
 'enterprise house',
 'noman motors',
 'wiggle',
 'the hub',
 'time bomb tattoo studio',
 'old fire station',
 'cranborne house',
 'obscura',
 'sprinkles gelato',
 'tolpuddle annex 3',
 'talbot uni roundabout',
 'minton lodge hotels',
 'oxford point',
 'fern arrow roundabout',
 'buth',
 'autg',
 'dorset house',
 'aush',
 'student center',
 'ocean 80',
 'buta',
 'tesco',
 'aniba',
 'art studios',
 'st john boscombe',
 'the round house',
 'jakey house',
 'burley court hotel',
 'bobbys',
 'viztality',
 'fairways care home',
 'bu lansdowne',
 'noodle bar',
 'arts bu library',
 'church',
 'bournemouth and poole college',
 'fusion building',
 'buch',
 'ashley automotive',
 'premier inn',
 '

In [20]:
c2i

{'bournemouth gateway building': 0,
 'east cliff urc church': 1,
 'sound circus bar': 2,
 'dorchester house': 3,
 'hilton': 4,
 'bu executive business centre': 5,
 'oceanorium': 6,
 'sir micheal cobhem library': 7,
 'timebomb tattoo studio': 8,
 'poole house cafe': 9,
 'kimmeridge house': 10,
 'hot rocks': 11,
 'park central': 12,
 'iq building': 13,
 'enterprise house': 14,
 'noman motors': 15,
 'wiggle': 16,
 'the hub': 17,
 'time bomb tattoo studio': 18,
 'old fire station': 19,
 'cranborne house': 20,
 'obscura': 21,
 'sprinkles gelato': 22,
 'tolpuddle annex 3': 23,
 'talbot uni roundabout': 24,
 'minton lodge hotels': 25,
 'oxford point': 26,
 'fern arrow roundabout': 27,
 'buth': 28,
 'autg': 29,
 'dorset house': 30,
 'aush': 31,
 'student center': 32,
 'ocean 80': 33,
 'buta': 34,
 'tesco': 35,
 'aniba': 36,
 'art studios': 37,
 'st john boscombe': 38,
 'the round house': 39,
 'jakey house': 40,
 'burley court hotel': 41,
 'bobbys': 42,
 'viztality': 43,
 'fairways care home': 

In [37]:
class LandmarkDataSet(Dataset):
  def __init__(self, imPathList, im2cl, c2i, transform):
    self.x = imPathList
    self.im2cl = im2cl
    self.c2i = c2i
    self.transform = transform

  def __len__(self): 
    return len(self.x)
    
  def __getitem__(self, ix): 
    image_filepath = str(self.x[ix])
    image = Image.open(image_filepath)
    
    im_name = str(image_filepath).split(os.sep)[-1]
    if 'aug' in image_filepath:
      orig_im_name = im_name[im_name.find('_')+1:]
    else:
      orig_im_name = im_name

    cls = self.im2cl[orig_im_name]
    label = self.c2i[cls]
    if self.transform is not None:
      image = self.transform(image)
    
    return image, label

In [27]:
train_image_paths_list = getFilesList(augmentationDirRoot.name, image_extensions, recursion=True)
test_image_paths_list = getFilesList(testImagesRoot.name, image_extensions, recursion=True)

In [28]:
im_name_to_class

{'Aniba_1 of_1.jpg': 'aniba',
 'Aniba_1 of_2.jpg': 'aniba',
 'Aniba_1 of_3.jpg': 'aniba',
 'Aniba_1 of_4.jpg': 'aniba',
 'Aniba_1 of_5.jpg': 'aniba',
 'Aniba_1 of_6.jpg': 'aniba',
 'Aniba_1 of_7.jpg': 'aniba',
 'Aniba_1 of_8.jpg': 'aniba',
 'Aniba_1 of_9.jpg': 'aniba',
 'Aniba_1 of_10.jpg': 'aniba',
 'Aniba_1 of_11.jpg': 'aniba',
 'Aniba_1 of_12.jpg': 'aniba',
 'Aniba_1 of_13.jpg': 'aniba',
 'Aniba_1 of_14.jpg': 'aniba',
 'Aniba_1 of_15.jpg': 'aniba',
 'Aniba_1 of_16.jpg': 'aniba',
 'AUDS_1 of_1.jpg': 'auds',
 'AUDS_1 of_2.jpg': 'auds',
 'AUDS_1 of_3.jpg': 'auds',
 'AUDS_1 of_4.jpg': 'auds',
 'AUDS_1 of_5.jpg': 'auds',
 'AUDS_1 of_6.jpg': 'auds',
 'AUDS_1 of_7.jpg': 'auds',
 'AUDS_1 of_8.jpg': 'auds',
 'AUDS_1 of_9.jpg': 'auds',
 'AUDS_1 of_10.jpg': 'auds',
 'AUDS_1 of_11.jpg': 'auds',
 'AUDS_1 of_12.jpg': 'auds',
 'AUSH_1 of_1.jpg': 'aush',
 'AUSH_1 of_2.jpg': 'aush',
 'AUSH_1 of_3.jpg': 'aush',
 'AUSH_1 of_4.jpg': 'aush',
 'AUSH_1 of_5.jpg': 'aush',
 'AUSH_1 of_6.jpg': 'aush',
 'AUSH

In [29]:
im_name_to_class['Bournemouth Gateway Building_10.JPG']

'bournemouth gateway building'

In [36]:
for image_filepath in train_image_paths_list[:20]:
  im_name = str(image_filepath).split(os.sep)[-1]

  if 'aug' in im_name:
    orig_im_name = im_name[im_name.find('_')+1:]
  else:
    orig_im_name = im_name
  
  print("given name={} orig name={}".format(im_name, orig_im_name)) 
  cls = im_name_to_class[orig_im_name]
  print("class={}".format(cls))

given name=aug27_IMG_7908.JPG orig name=IMG_7908.JPG
class=st augustins church
given name=aug21_IMG_7914.JPG orig name=IMG_7914.JPG
class=st augustins church
given name=aug22_IMG_7905.JPG orig name=IMG_7905.JPG
class=st augustins church
given name=aug23_IMG_7908.JPG orig name=IMG_7908.JPG
class=st augustins church
given name=aug9_IMG_7916.JPG orig name=IMG_7916.JPG
class=st augustins church
given name=IMG_7905.JPG orig name=IMG_7905.JPG
class=st augustins church
given name=aug8_IMG_7906.JPG orig name=IMG_7906.JPG
class=st augustins church
given name=aug0_IMG_7909.JPG orig name=IMG_7909.JPG
class=st augustins church
given name=aug32_IMG_7916.JPG orig name=IMG_7916.JPG
class=st augustins church
given name=aug1_IMG_7904.JPG orig name=IMG_7904.JPG
class=st augustins church
given name=IMG_7904.JPG orig name=IMG_7904.JPG
class=st augustins church
given name=aug13_IMG_7914.JPG orig name=IMG_7914.JPG
class=st augustins church
given name=aug11_IMG_7921.JPG orig name=IMG_7921.JPG
class=st august

In [38]:
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

transf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(*imagenet_stats)
    ])

train_ds  = LandmarkDataSet(train_image_paths_list, im_name_to_class, c2i, transf)
test_ds  = LandmarkDataSet(test_image_paths_list, im_name_to_class, c2i, transf)

len(train_ds), len(test_ds)

(4950, 443)

In [39]:
accuracy = lambda y_,yb: (y_.max(dim=1)[1] == yb).float().mean()

def one_epoch(net, loss, dl, opt=None, metric=None):   
  if opt: net.train()  # only affects some layers
  else:   net.eval()

  L, M = [], []
  for xb, yb in tqdm(iter(dl), leave=False):
    xb, yb = xb.cuda(), yb.cuda()
    if opt:
      y_ = net(xb)
      l = loss(y_, yb)
      opt.zero_grad()
      l.backward()
      opt.step()
    else:
      with tc.no_grad():
        y_ = net(xb)
        l = loss(y_, yb)
    L.append(l.detach().cpu().numpy())
    if metric: 
      M.append(metric(y_, yb).cpu().numpy())
          
  return L, M


def fit(net, tr_dl, val_dl, loss=nn.CrossEntropyLoss(), epochs=10, lr=3e-3, wd=1e-3, plot=True):
  opt = optim.Adam(net.parameters(), lr=lr, weight_decay=wd)
  Ltr_hist, Lval_hist = [], []
  for epoch in trange(epochs):
    Ltr,  _    = one_epoch(net, loss, tr_dl,  opt)
    Lval, Aval = one_epoch(net, loss, val_dl, None, accuracy)
    Ltr_hist.append(np.mean(Ltr))
    Lval_hist.append(np.mean(Lval))
    print(f'epoch: {epoch}\ttraining loss: {np.mean(Ltr):0.4f}\tvalidation loss: {np.mean(Lval):0.4f}\tvalidation accuracy: {np.mean(Aval):0.2f}')
      
  # plot the losses     
  if plot:
    _,ax = plt.subplots(1,1,figsize=(16,4))
    ax.plot(1+np.arange(len(Ltr_hist)),Ltr_hist)
    ax.plot(1+np.arange(len(Lval_hist)),Lval_hist)
    ax.grid('on')
    ax.set_xlim(left=1, right=len(Ltr_hist))
    ax.legend(['training loss', 'validation loss']);
      
  return Ltr_hist, Lval_hist

def freeze(md, fr=True):
  ch = list(md.children())
  for c in ch: freeze(c, fr)
  if not ch and not isinstance(md, tc.nn.modules.batchnorm.BatchNorm2d):  # not freezing the BatchNorm layers!
    for p in md.parameters(): 
      #print('---\n', md, p.requires_grad)
      p.requires_grad = not fr

def freeze_to(md, ix=-1, fr=True):
  ch_all = list(md.children())
  for ch in ch_all[:ix]:
    freeze(ch, fr)

In [42]:
bs = 16
num_of_classes = len(i2c)

tr_dl  = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=2)
val_dl = DataLoader(test_ds, batch_size=2*bs, shuffle=False, num_workers=2)

resnet18 = models.resnet18(pretrained=True)
resnet18.fc = nn.Linear(512, num_of_classes)
freeze_to(resnet18, -1, True)

resnet18 = resnet18.cuda()

fit(resnet18, tr_dl, val_dl)

RuntimeError: ignored