<a href="https://colab.research.google.com/github/sheikmohdimran/Experiments_2021/blob/main/Vision/102_Flowers_Classification_TransferLearning_NativePT_vs_FastAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget -q https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!wget -q https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat
!wget -q https://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat
!tar -xf 102flowers.tgz

In [1]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch
from torch import nn
from collections import OrderedDict
from tqdm import tqdm
from scipy.io import loadmat
from PIL import Image
from torch import optim


In [2]:
split = loadmat('setid.mat')
label = loadmat('imagelabels.mat')

In [3]:
trnid_df=pd.DataFrame(split['trnid'][0], columns = ['id'])
trnid_df=trnid_df.assign(split='train')
tstid_df=pd.DataFrame(split['tstid'][0], columns = ['id'])
tstid_df=tstid_df.assign(split='test')
valid_df=pd.DataFrame(split['valid'][0], columns = ['id'])
valid_df=valid_df.assign(split='valid')

split_df=trnid_df.append(valid_df, ignore_index=True).append(tstid_df, ignore_index=True)

In [4]:
df = pd.DataFrame(label['labels'][0], columns = ['label'])
df['file'] = df.index
df['file']=df['file'].apply(lambda x: x+1).apply(lambda x: '{0:0>5}'.format(x)).apply(lambda x:'jpg/image_'+str(x)+'.jpg')
df['id'] = df.index
df['id']=df['id'].apply(lambda x: x+1)

In [5]:
 df=pd.merge(df, split_df, on="id")
 df=df.drop(columns=['id'])
 df.head()

Unnamed: 0,label,file,split
0,77,jpg/image_00001.jpg,test
1,77,jpg/image_00002.jpg,test
2,77,jpg/image_00003.jpg,test
3,77,jpg/image_00004.jpg,test
4,77,jpg/image_00005.jpg,test


In [6]:
df.dtypes

label     uint8
file     object
split    object
dtype: object

In [7]:
import numpy as np
np.unique(label['labels'][0])

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102], dtype=uint8)

In [8]:
#image=Image.open('jpg/image_06765.jpg')
#image

In [9]:
class CustomImageDataset(Dataset):
    def __init__(self, data_frame, split, transform=None):
        self.img_labels = data_frame[data_frame.split==split]
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = read_image(self.img_labels.iloc[idx, 1])
        label = self.img_labels.iloc[idx, 0].astype(np.long)-1
        if self.transform:
            image = self.transform(image)
        return image, label

In [10]:
train_transform=transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

val_transform=transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [11]:
train_data = CustomImageDataset(data_frame=df,split='train',transform=train_transform)
valid_data = CustomImageDataset(data_frame=df,split='valid',transform=val_transform)

In [12]:
trainloader = DataLoader(train_data, batch_size=32, shuffle=True)
validloader = DataLoader(valid_data, batch_size=256, shuffle=False)

In [13]:
next(iter(trainloader))[1]

tensor([  1,  28,  64,  33,  73,  60,  50,  75,  85,  86,  31,  40, 101,  40,
         58,  32,  95,  84,   9,  81,   7,  30,  77,  55,  30,  14,  66, 100,
         30,  25,   3,  41])

In [44]:
model=models.resnet18(pretrained=True)

In [45]:
for name, module in model.named_children():
    print(name)

conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc


In [46]:
for name, module in model.named_children():
  if name in ['layer4']: #Unfreeze
    for param in module.parameters():
      param.requires_grad = True
  else: #Freeze
    for param in module.parameters():
      param.requires_grad = False

In [47]:
model.fc = nn.Sequential(OrderedDict([
          ('lin1', nn.Linear(512,256)),
          ('relu1', nn.ReLU()),
          ('lin2', nn.Linear(256,102))
        ]))

In [48]:
#model

## Training Loop - Native Pytorch 

In [49]:
num_epochs=20

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

optimizer = optim.SGD(model.parameters(),lr=1e-3,momentum=0.9)
criterion = nn.CrossEntropyLoss()
model = model.to(device)

In [50]:
train_loss=[]
validation_loss=[]
accuracy=[]

for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0
  val_corr = 0
  for inputs, labels in tqdm(trainloader, position=0, leave=True):
    inputs,labels=inputs.to(device),labels.to(device)
    optimizer.zero_grad()
    #with torch.set_grad_enabled(True):
    outputs=model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
  train_loss.append(running_loss/len(trainloader))
  

  with torch.no_grad():
    model.eval()
    val_corr=0
    acc=0
    total=0
    valid_loss=0
    for inputs, labels in tqdm(validloader, position=0, leave=True):
      inputs,labels=inputs.to(device),labels.to(device)
      outputs=model(inputs)
      loss = criterion(outputs, labels)
      valid_loss+= loss.item()
      pred = torch.max(outputs,1)[1]
      val_corr += (labels == pred).sum().item()
      total += labels.size(0)
      acc=val_corr/total
    accuracy.append(acc*100)
    validation_loss.append(valid_loss/len(validloader))


  print(train_loss)    
  print(validation_loss)
  print(accuracy)

100%|██████████| 32/32 [00:06<00:00,  4.58it/s]
100%|██████████| 4/4 [00:07<00:00,  1.95s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003]
[4.5902063846588135]
[1.3725490196078431]


100%|██████████| 32/32 [00:06<00:00,  4.82it/s]
100%|██████████| 4/4 [00:07<00:00,  1.99s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436]
[4.5902063846588135, 4.507440209388733]
[1.3725490196078431, 3.9215686274509802]


100%|██████████| 32/32 [00:06<00:00,  4.71it/s]
100%|██████████| 4/4 [00:07<00:00,  1.93s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905]
[4.5902063846588135, 4.507440209388733, 4.426660776138306]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098]


100%|██████████| 32/32 [00:06<00:00,  4.78it/s]
100%|██████████| 4/4 [00:07<00:00,  1.93s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351]


100%|██████████| 32/32 [00:06<00:00,  4.71it/s]
100%|██████████| 4/4 [00:07<00:00,  1.98s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703]


100%|██████████| 32/32 [00:06<00:00,  4.69it/s]
100%|██████████| 4/4 [00:07<00:00,  1.92s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487]


100%|██████████| 32/32 [00:06<00:00,  4.85it/s]
100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177]


100%|██████████| 32/32 [00:06<00:00,  4.83it/s]
100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451]


100%|██████████| 32/32 [00:06<00:00,  4.85it/s]
100%|██████████| 4/4 [00:07<00:00,  1.95s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606]


100%|██████████| 32/32 [00:06<00:00,  4.67it/s]
100%|██████████| 4/4 [00:07<00:00,  1.95s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0]


100%|██████████| 32/32 [00:06<00:00,  4.83it/s]
100%|██████████| 4/4 [00:07<00:00,  1.90s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628]


100%|██████████| 32/32 [00:06<00:00,  4.67it/s]
100%|██████████| 4/4 [00:07<00:00,  1.96s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406]


100%|██████████| 32/32 [00:06<00:00,  4.79it/s]
100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626]


100%|██████████| 32/32 [00:06<00:00,  4.82it/s]
100%|██████████| 4/4 [00:07<00:00,  1.92s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425]


100%|██████████| 32/32 [00:06<00:00,  4.85it/s]
100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  3%|▎         | 1/32 [00:00<00:06,  5.03it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425, 66.86274509803921]


100%|██████████| 32/32 [00:06<00:00,  4.83it/s]
100%|██████████| 4/4 [00:07<00:00,  1.95s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925, 2.323392190039158]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125, 2.024788558483124]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425, 66.86274509803921, 68.23529411764706]


100%|██████████| 32/32 [00:06<00:00,  4.68it/s]
100%|██████████| 4/4 [00:07<00:00,  1.94s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925, 2.323392190039158, 2.154583405703306]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125, 2.024788558483124, 1.8509829342365265]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425, 66.86274509803921, 68.23529411764706, 71.66666666666667]


100%|██████████| 32/32 [00:06<00:00,  4.81it/s]
100%|██████████| 4/4 [00:07<00:00,  1.92s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925, 2.323392190039158, 2.154583405703306, 2.0040000453591347]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125, 2.024788558483124, 1.8509829342365265, 1.691148817539215]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425, 66.86274509803921, 68.23529411764706, 71.

100%|██████████| 32/32 [00:06<00:00,  4.83it/s]
100%|██████████| 4/4 [00:07<00:00,  1.91s/it]
  0%|          | 0/32 [00:00<?, ?it/s]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925, 2.323392190039158, 2.154583405703306, 2.0040000453591347, 1.836447786539793]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125, 2.024788558483124, 1.8509829342365265, 1.691148817539215, 1.573032259941101]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.509803921568626, 63.725490196078425, 66.

100%|██████████| 32/32 [00:06<00:00,  4.84it/s]
100%|██████████| 4/4 [00:07<00:00,  1.90s/it]

[4.64577229321003, 4.569572538137436, 4.4835668206214905, 4.402775064110756, 4.314853951334953, 4.210940107703209, 4.086755596101284, 3.927594803273678, 3.7422419264912605, 3.569014385342598, 3.388018675148487, 3.146589048206806, 2.9475461915135384, 2.696686938405037, 2.5279562547802925, 2.323392190039158, 2.154583405703306, 2.0040000453591347, 1.836447786539793, 1.701786857098341]
[4.5902063846588135, 4.507440209388733, 4.426660776138306, 4.334953784942627, 4.221819996833801, 4.086189270019531, 3.9317728877067566, 3.7494282722473145, 3.535394072532654, 3.3115472197532654, 3.074109196662903, 2.8567826747894287, 2.614030659198761, 2.3908668160438538, 2.177520751953125, 2.024788558483124, 1.8509829342365265, 1.691148817539215, 1.573032259941101, 1.468008279800415]
[1.3725490196078431, 3.9215686274509802, 8.92156862745098, 14.411764705882351, 23.823529411764703, 31.960784313725487, 37.05882352941177, 43.03921568627451, 48.431372549019606, 50.0, 55.09803921568628, 57.647058823529406, 59.50




## Training Loop - FastAI

In [None]:
!pip install -q fastai==2.4.1

In [35]:
from fastai.vision.all import *
data = DataLoaders(trainloader, validloader)
learn = Learner(data, model, loss_func=criterion, opt_func=Adam, metrics=accuracy)

In [36]:
learn.fine_tune(num_epochs)

epoch,train_loss,valid_loss,accuracy,time
0,4.128293,3.335753,0.268627,00:14


epoch,train_loss,valid_loss,accuracy,time
0,2.262967,1.523196,0.747059,00:15
1,1.71805,1.018243,0.814706,00:14
2,1.262595,0.762076,0.822549,00:15
3,0.987542,0.75051,0.812745,00:14
4,0.850594,0.943815,0.757843,00:14
5,0.813167,0.884433,0.790196,00:14
6,0.731607,0.762097,0.812745,00:14
7,0.628659,0.97492,0.779412,00:15
8,0.550381,0.710221,0.828431,00:14
9,0.482185,0.675275,0.836275,00:14
