In [None]:
'''
Flower classification using PyTorch
dataset: https://www.kaggle.com/alxmamaev/flowers-recognition
'''

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
import os

os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/My Drive/Kaggle"

In [3]:
#changing the working directory
%cd /content/gdrive/My Drive/Kaggle

%pwd

/content/gdrive/My Drive/Kaggle


'/content/gdrive/My Drive/Kaggle'

In [4]:
!kaggle datasets download -d alxmamaev/flowers-recognition

Downloading flowers-recognition.zip to /content/gdrive/My Drive/Kaggle
 97% 436M/450M [00:03<00:00, 151MB/s]
100% 450M/450M [00:03<00:00, 139MB/s]


In [None]:
#unzipping the zip files and deleting the zip files
!unzip \*.zip  && rm *.zip

In [6]:
DIR_PATH = '/content/gdrive/MyDrive/Kaggle/flowers/flowers'

In [7]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# pytorch imports

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

In [8]:
transformations = {
    'train': transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]) # resnet18
    ]),
    'test': transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]) # resnet18
    ])
}

In [20]:
# hyperparamters

learning_rate = 0.01
batch_size = 32
num_epochs = 30
num_classes = 5

# device
device = None
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
print(device)

cuda


In [10]:
total_dataset = torchvision.datasets.ImageFolder(DIR_PATH,transform=transformations['train'])

len(total_dataset),total_dataset[0][0].shape,total_dataset.class_to_idx

(4323,
 torch.Size([3, 224, 224]),
 {'daisy': 0, 'dandelion': 1, 'rose': 2, 'sunflower': 3, 'tulip': 4})

In [11]:
# splitting into train and validation sets

SPLIT_SIZE = 0.8
tot_len = len(total_dataset)

train_size = int(SPLIT_SIZE * tot_len)
val_size = tot_len - train_size

print(f'Training set size = {train_size} \nValidation set size = {val_size}')

train_dataset, val_dataset =  torch.utils.data.random_split(total_dataset,[train_size,val_size])

len(train_dataset),len(val_dataset)

Training set size = 3458 
Validation set size = 865


(3458, 865)

In [21]:
# dataloaders
train_loader = DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=4)

val_loader = DataLoader(dataset=val_dataset,
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=4)

In [22]:
# testing dataloading 

examples = iter(train_loader)
samples,labels = examples.next()
print(samples.shape,labels.shape) # batch_size=32
len(train_loader),len(val_loader)

torch.Size([32, 3, 224, 224]) torch.Size([32])


(109, 28)

In [14]:
# custom CNN model class

class ConvNet(nn.Module):
    def __init__(self,model,num_classes):
        super(ConvNet,self).__init__()
        self.base_model = nn.Sequential(*list(model.children())[:-1]) # model excluding last FC layer
        self.linear1 = nn.Linear(in_features=512,out_features=120)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=120,out_features=num_classes)
    
    def forward(self,x):
        x = self.base_model(x)
        x = torch.flatten(x,1)
        lin = self.linear1(x)
        x = self.relu(lin)
        out = self.linear2(x)
        return lin, out

In [15]:
model = torchvision.models.resnet18(pretrained=True) # base model

model = ConvNet(model,num_classes)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




In [16]:
print(model)

ConvNet(
  (base_model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [23]:
# training loop

n_iters = len(train_loader)

for epoch in range(num_epochs):
    model.train()
    for ii,(images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        _,outputs = model(images)
        loss = criterion(outputs,labels)
        
        # free_gpu_cache()
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if (ii+1)%25 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{ii+1}/{n_iters}], Loss = {loss.item():.6f}')
            
    print('----------------------------------------')
    

Epoch [1/30], Step [25/109], Loss = 0.929877
Epoch [1/30], Step [50/109], Loss = 0.934986
Epoch [1/30], Step [75/109], Loss = 0.829087
Epoch [1/30], Step [100/109], Loss = 0.723132
----------------------------------------
Epoch [2/30], Step [25/109], Loss = 0.766979
Epoch [2/30], Step [50/109], Loss = 0.664721
Epoch [2/30], Step [75/109], Loss = 1.134478
Epoch [2/30], Step [100/109], Loss = 0.784860
----------------------------------------
Epoch [3/30], Step [25/109], Loss = 0.946166
Epoch [3/30], Step [50/109], Loss = 0.822793
Epoch [3/30], Step [75/109], Loss = 0.768476
Epoch [3/30], Step [100/109], Loss = 0.821083
----------------------------------------
Epoch [4/30], Step [25/109], Loss = 0.710896
Epoch [4/30], Step [50/109], Loss = 0.635608
Epoch [4/30], Step [75/109], Loss = 0.937803
Epoch [4/30], Step [100/109], Loss = 0.899614
----------------------------------------
Epoch [5/30], Step [25/109], Loss = 0.738527
Epoch [5/30], Step [50/109], Loss = 1.186588
Epoch [5/30], Step [75

In [42]:
# evaluating model and getting features of every image
features = []

with torch.no_grad():
    # for entire test set
    n_correct = 0
    n_samples = 0

    # for each class label
    n_class_correct = [0 for i in range(num_classes)]
    n_class_samples = [0 for i in range(num_classes)]

    for images,labels in val_loader:

        images = images.to(device)
        labels = labels.to(device)

        ftrs,outputs = model(images)
        features.append(ftrs)

        _,preds = torch.max(outputs,1)
        n_samples += labels.size(0)
        n_correct += (preds == labels).sum().item()
            
    accuracy = n_correct/float(n_samples)

    print(f'Accuracy of model on test set = {(100.0 * accuracy):.4f} %')
        

Accuracy of model on test set = 77.2254 %


In [107]:
ftrs = features.copy() 

In [108]:
for i in range(len(ftrs)):
    ftrs[i] = ftrs[i].cpu().numpy()

ftrs = ftrs[:-1]

In [113]:
type(ftrs),ftrs[0].dtype,ftrs[0][0].dtype

(list, dtype('float32'), dtype('float32'))

In [114]:
ftrs = np.array(ftrs)
ftrs.shape

(27, 32, 120)

In [115]:
n_samples = ftrs.shape[0]*ftrs.shape[1]
n_features = ftrs.shape[2]
ftrs = ftrs.reshape(n_samples,n_features)
print(ftrs.shape)

(864, 120)


In [116]:
# save to csv
ftrs_df = pd.DataFrame(ftrs)
ftrs_df.to_csv('./fc_layer_features.csv',index=False)

# reloading the saved csv into a df

ftrs_df = pd.read_csv('./fc_layer_features.csv')
ftrs_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119
0,-2.682216,-1.809271,-7.484774,-2.597401,-2.129449,-2.684520,-1.210757,-2.449394,-1.321972,-2.159092,-2.654260,-6.571615,-1.947343,-1.002393,-2.354665,-2.960461,-1.089879,17.190150,-2.379287,-5.472676,-2.085195,-2.831443,-3.142085,-2.629379,-2.547121,-2.673542,-3.140305,-3.010859,-1.428673,-1.858961,-2.137040,-1.200971,-4.264568,-10.876523,-2.811174,-2.702208,-6.728081,-0.740339,-3.158659,-2.148518,...,-2.911375,-2.589549,-2.249176,-2.052157,-2.182076,-1.401704,-8.315755,-3.026768,-1.277607,-2.420685,-2.941650,-1.491159,-2.831244,-2.337384,-2.972995,-2.413209,-2.147227,-2.480342,-2.834272,-2.641840,-1.907128,-1.589941,-1.647408,-1.306086,-2.020735,-2.048256,-3.149031,-1.620974,-2.523599,11.364213,-3.131830,-2.930294,-3.315069,-13.240990,-0.846887,-3.146235,-2.276960,-2.013842,-2.581115,-1.946608
1,-0.695757,-0.435507,-2.002212,-0.726669,-0.538524,-0.664133,-0.278109,-0.651841,-0.344545,-0.468672,-0.689619,-1.814749,-0.573759,-0.255574,-0.658513,-0.727305,-0.456727,1.648199,-0.468136,-1.834989,-0.623594,-0.769229,-0.827116,-0.687995,-0.618820,-0.600189,-0.927013,-0.779245,-0.333926,-0.458722,-0.662025,-0.266457,-1.307203,0.200569,-0.758323,-0.785185,-2.230599,-0.383973,-0.849165,-0.558235,...,-0.736843,-0.661113,-0.644412,-0.391596,-0.522938,-0.408781,-2.239814,-0.665891,-0.493303,-0.740793,-0.779765,-0.397665,-0.728332,-0.692928,-0.766485,-0.628347,-0.553270,-0.594672,-0.774251,-0.644996,-0.589726,-0.433961,-0.529941,-0.525891,-0.557111,-0.564597,-0.807183,-0.371622,-0.777322,3.133616,-0.826066,-0.723999,-0.827413,-3.670339,-0.229353,-0.875108,-0.645979,-0.565116,-0.666970,-0.582515
2,-3.683415,-2.383975,-7.987829,-3.820760,-2.868524,-2.551046,-2.027176,-3.118186,-2.900056,-2.489923,-3.546045,-6.697475,-2.816068,-1.266764,-3.159733,-3.798971,-2.099383,-13.021486,-3.016269,-3.881403,-2.431212,-3.022377,-4.002620,-3.675863,-3.111759,-3.487597,-3.955573,-3.568248,-1.862123,-1.802237,-2.628286,-1.267471,-3.496161,-17.233816,-3.172531,-3.675971,-7.465856,-1.691595,-4.155883,-2.769426,...,-3.123573,-2.789018,-3.489793,-3.318862,-1.941323,-2.211892,-12.419865,-3.349534,-1.889787,-3.433095,-4.116605,-1.466856,-2.909222,-3.439175,-3.909019,-2.247212,-2.757628,-2.835237,-3.544679,-3.020590,-2.348225,-2.020951,-2.626538,-2.460883,-3.007290,-2.069873,-4.256525,-2.201931,-3.119349,25.950176,-3.578945,-3.379693,-3.575484,-16.841516,-1.253214,-4.173200,-3.323675,-2.245342,-3.144603,-2.613774
3,-1.173320,-0.750072,-2.802344,-1.164975,-0.711169,-0.806676,-0.493163,-1.012302,-0.555422,-0.630718,-1.227813,-4.627109,-0.638829,-0.369939,-0.996563,-1.230575,-0.879868,0.693294,-0.688844,-1.667879,-0.833348,-1.071949,-1.242746,-1.067801,-1.003154,-0.914465,-1.366012,-1.179207,-0.673505,-0.640495,-0.881471,-0.404645,-1.867932,-0.360532,-1.266818,-1.208169,-2.787689,-0.283506,-1.232506,-0.877398,...,-1.002554,-1.124190,-0.750089,-0.757345,-0.802430,-0.523366,-3.949636,-1.055945,-0.582767,-1.182436,-1.250123,-0.501455,-1.073802,-1.115375,-1.289442,-0.831518,-0.804931,-0.937642,-1.224237,-1.098042,-0.734878,-0.611672,-0.677367,-0.787671,-0.626887,-0.964976,-1.239771,-0.680202,-0.890958,1.209716,-1.218066,-1.198352,-1.300493,-2.789112,-0.310184,-1.450094,-0.960666,-0.891008,-0.980529,-0.829047
4,-1.514406,-1.070283,-7.335757,-1.291664,-1.032857,-1.348833,-0.847980,-1.371439,-0.836292,-0.758627,-1.415049,-3.704118,-1.341327,-0.807291,-1.053635,-1.793361,-1.247617,1.480043,-1.147421,-6.082498,-1.281297,-1.469837,-1.776708,-1.413108,-1.136022,-1.190119,-2.027516,-1.754061,-1.261062,-1.475961,-1.345021,-0.779778,-1.529282,10.805672,-1.561157,-1.382681,-6.432305,-0.803183,-1.792940,-1.256010,...,-1.809721,-1.703669,-1.074445,-0.632277,-1.241656,-0.320203,-3.995121,-1.171972,-1.015934,-1.167211,-1.536554,-0.909973,-1.871817,-1.677236,-1.697805,-1.477156,-1.336026,-0.846316,-1.736524,-1.558902,-1.330040,-0.647472,-1.350899,-0.950122,-1.548696,-1.124324,-1.614228,-0.916950,-1.490492,-6.271286,-1.591411,-1.853100,-1.822645,-10.387329,-0.880690,-1.658471,-1.188071,-0.967906,-1.462414,-1.681509
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
859,-2.149203,-1.167360,-7.465871,-2.228412,-2.173722,-1.692038,-1.409218,-1.888183,-1.583038,-1.170188,-1.944963,-2.856950,-2.212190,-1.312086,-1.808873,-2.291579,-1.908697,-12.570084,-1.613252,-7.302553,-2.012983,-2.134210,-2.642103,-2.316546,-1.121836,-1.555654,-2.921847,-2.262361,-1.459241,-1.673375,-2.134064,-0.682239,-2.306306,9.903104,-2.111704,-2.196830,-7.001608,-1.371686,-2.786876,-1.655471,...,-2.376404,-1.819809,-2.526303,-1.033014,-1.259838,-1.328975,-7.053557,-1.453996,-1.486542,-2.159665,-2.477511,-0.972920,-1.999085,-2.393944,-2.088926,-1.998367,-2.414711,-1.106473,-2.343587,-1.934486,-1.694322,-1.479425,-2.231659,-1.571938,-2.410158,-1.514201,-2.560195,-1.081027,-2.012081,2.323297,-2.243312,-2.248933,-2.388998,-15.652525,-1.163700,-2.620704,-2.018867,-1.379433,-2.045894,-2.172523
860,-1.562800,-0.906254,-3.993460,-1.568198,-1.137905,-1.145966,-0.727874,-1.378477,-0.796791,-0.787036,-1.581448,-5.530391,-0.943061,-0.560086,-1.368658,-1.619732,-1.105445,1.905164,-1.032439,-2.762967,-1.168241,-1.476441,-1.687641,-1.461497,-1.162141,-1.255375,-1.874028,-1.552257,-0.957404,-0.933245,-1.228927,-0.496220,-2.650651,-0.025041,-1.677677,-1.558726,-3.346039,-0.360525,-1.749568,-1.151465,...,-1.374294,-1.493654,-1.094739,-0.976018,-1.078524,-0.745917,-5.226374,-1.416925,-0.734377,-1.617143,-1.683642,-0.699283,-1.408795,-1.408024,-1.665272,-1.169565,-1.279981,-1.165604,-1.530076,-1.529025,-1.003100,-0.874588,-0.907109,-1.033133,-1.002617,-1.228194,-1.703007,-0.888105,-1.225370,2.424272,-1.605450,-1.639443,-1.718546,-4.962667,-0.425299,-1.935863,-1.255676,-1.166074,-1.320367,-1.170232
861,-2.008381,-1.019784,-4.430455,-1.916422,-1.324491,-1.305691,-0.894520,-1.568956,-0.872068,-0.888425,-1.979610,-8.280972,-1.070173,-0.527253,-1.701700,-2.084373,-1.526454,1.316703,-1.115793,-3.027009,-1.468342,-1.907583,-2.031136,-1.721954,-1.473641,-1.364456,-2.336177,-1.885022,-1.269501,-1.179358,-1.669333,-0.532040,-3.446072,1.300768,-2.231837,-1.859068,-4.147035,-0.368525,-2.159190,-1.454647,...,-1.621037,-1.942347,-1.126336,-1.068145,-1.413645,-0.864849,-6.650526,-1.721208,-0.970578,-1.998592,-2.057246,-0.807653,-1.821744,-1.740368,-2.054749,-1.554342,-1.603140,-1.358109,-1.907893,-1.905593,-1.225278,-1.090034,-1.105384,-1.300037,-1.084049,-1.691715,-2.109883,-0.993285,-1.530828,-0.736873,-1.962887,-2.049496,-2.237448,-4.820333,-0.381630,-2.401539,-1.494782,-1.429371,-1.619133,-1.384991
862,-2.881787,-1.817634,-6.522514,-2.881721,-2.671067,-2.889485,-1.451315,-2.841643,-1.286826,-2.315200,-3.055826,-8.696461,-1.973018,-1.220838,-3.103690,-3.349902,-1.233574,20.612210,-2.871333,-5.600273,-2.609855,-3.116448,-3.411882,-3.032839,-2.763036,-2.688051,-3.442422,-3.214019,-1.563149,-2.020676,-2.516314,-1.053416,-4.894775,-12.314918,-3.122335,-3.191750,-5.829114,-0.407290,-3.601918,-2.290954,...,-3.264230,-2.856436,-2.784358,-2.325966,-2.419290,-1.894581,-9.509387,-3.395753,-1.165350,-2.633135,-3.361977,-1.329225,-2.948021,-2.625127,-3.392757,-2.506977,-2.487493,-2.977251,-2.872969,-2.858849,-1.992996,-2.111091,-1.598005,-1.480488,-2.111705,-2.061406,-3.526394,-1.939423,-2.597189,10.692299,-3.307765,-3.530249,-3.647159,-12.394196,-0.662293,-3.628499,-2.556084,-2.441861,-2.952585,-1.825578


In [117]:
# save model
MODEL_PATH = 'flower_classification_model.pth'
torch.save(model.state_dict(),MODEL_PATH)