In [None]:
'''
Flower classification using PyTorch
dataset: https://www.kaggle.com/alxmamaev/flowers-recognition
'''

In [1]:
# Mount Google Drive at /content/gdrive so the later cells can read and
# write files under that path.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os

# Tell the Kaggle CLI where to look for its configuration.
# NOTE(review): presumably this Drive folder holds the kaggle.json API
# token - confirm it exists before running the download cell.
os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/My Drive/Kaggle"

In [2]:
# Change the working directory to the Kaggle folder on Drive, then echo it.
# Every relative path used later in the notebook ('./...') resolves against it.
%cd /content/gdrive/My Drive/Kaggle
%pwd

/content/gdrive/My Drive/Kaggle


'/content/gdrive/My Drive/Kaggle'

In [None]:
# Download the flowers-recognition dataset archive into the current working directory.
!kaggle datasets download -d alxmamaev/flowers-recognition

Downloading flowers-recognition.zip to /content/gdrive/My Drive/Kaggle
 97% 436M/450M [00:03<00:00, 151MB/s]
100% 450M/450M [00:03<00:00, 139MB/s]


In [None]:
# Extract every zip archive in the working directory; the archives are
# deleted only if extraction succeeds (the `&&` short-circuits on failure).
!unzip \*.zip  && rm *.zip

In [3]:
# Root folder handed to ImageFolder further down.
# NOTE(review): this path spells the Drive folder "MyDrive" while the earlier
# cells use "My Drive" (with a space) - confirm both resolve to the same mount.
DIR_PATH = '/content/gdrive/MyDrive/Kaggle/flowers/flowers'

In [4]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# pytorch imports

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

In [6]:
# Per-channel statistics passed to Normalize by both pipelines.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def _geometry_steps():
    """Fresh resize + crop steps that bring every image to 224x224."""
    # NOTE(review): the crop size equals the resize size, so the crop looks
    # like a no-op - confirm whether a larger resize was intended.
    return [
        transforms.Resize((224,224)),
        transforms.CenterCrop((224,224)),
    ]


def _tensor_steps():
    """Fresh PIL -> tensor conversion followed by channel normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean,std),
    ]


# NOTE(review): ColorJitter() is called without arguments; per the torchvision
# docs its defaults are all zero, i.e. no jitter is applied - confirm intent.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
]

# 'train' adds the random augmentations between the geometry and tensor steps;
# 'test' is the same pipeline without them.
transformations = {
    'train': transforms.Compose(_geometry_steps() + _train_augmentations + _tensor_steps()),
    'test': transforms.Compose(_geometry_steps() + _tensor_steps()),
}

In [7]:
# hyperparameters

learning_rate = 0.001
batch_size = 8
num_epochs = 50
num_classes = 5

# device: use the GPU when one is visible to torch, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [8]:
# Build one dataset over everything under DIR_PATH, using the 'train' pipeline.
total_dataset = torchvision.datasets.ImageFolder(
    root=DIR_PATH,
    transform=transformations['train'],
)

# Sanity check: number of images, shape of the first transformed image,
# and the class-name -> index mapping.
(len(total_dataset), total_dataset[0][0].shape, total_dataset.class_to_idx)

(4323,
 torch.Size([3, 224, 224]),
 {'daisy': 0, 'dandelion': 1, 'rose': 2, 'sunflower': 3, 'tulip': 4})

In [9]:
# splitting into train and validation sets

SPLIT_SIZE = 0.8
SPLIT_SEED = 42  # fixed seed: re-running this cell yields the same partition
tot_len = len(total_dataset)

train_size = int(SPLIT_SIZE * tot_len)
val_size = tot_len - train_size

print(f'Training set size = {train_size} \nValidation set size = {val_size}')

# Seed the split explicitly. Without it, re-running the cell after training
# would move images the model has already seen into the validation set.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, _val_split = torch.utils.data.random_split(
    total_dataset, [train_size, val_size], generator=split_generator)

# total_dataset applies the random train-time augmentation to every sample.
# Validation images should be deterministic, so re-index the same files
# through a second ImageFolder that uses the 'test' pipeline. Both folders
# enumerate DIR_PATH identically, so the split indices remain valid.
_eval_view = torchvision.datasets.ImageFolder(DIR_PATH, transform=transformations['test'])
val_dataset = torch.utils.data.Subset(_eval_view, _val_split.indices)

len(train_dataset),len(val_dataset)

Training set size = 3458 
Validation set size = 865


(3458, 865)

In [10]:
# dataloaders

# Never ask for more worker processes than the runtime reports CPUs for;
# requesting 4 unconditionally triggered the DataLoader worker warning.
NUM_WORKERS = min(4, os.cpu_count() or 1)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=NUM_WORKERS)

# Validation is read one image at a time and in a fixed order, so features
# extracted from it later line up the same way on every run.
val_loader = DataLoader(dataset=val_dataset,
                       batch_size=1,
                       shuffle=False,
                       num_workers=NUM_WORKERS)

  cpuset_checked))


In [11]:
# testing dataloading

examples = iter(train_loader)
# Use the built-in next(): the iterator's `.next()` method is not part of the
# Python 3 iterator protocol and is unavailable on current PyTorch releases.
samples,labels = next(examples)
print(samples.shape,labels.shape) # batch_size=8
len(train_loader),len(val_loader)

  cpuset_checked))


torch.Size([8, 3, 224, 224]) torch.Size([8])


(433, 865)

In [16]:
# custom CNN model class

class ConvNet(nn.Module):
    """Backbone with its last child removed, followed by a two-layer head.

    The head is ``Linear(in, 512) -> ReLU -> Linear(512, num_classes)``.
    ``forward`` returns both the 512-d activations of the first head layer
    (taken *before* the ReLU) and the class logits.
    """

    def __init__(self,model,num_classes):
        """
        Args:
            model: an ``nn.Module`` whose last child is the layer to drop
                (for torchvision's ResNet-50 that is the final FC layer).
            num_classes: number of output logits.
        """
        super().__init__()
        children = list(model.children())
        self.base_model = nn.Sequential(*children[:-1]) # model excluding last FC layer

        # Size the head from the layer being replaced instead of assuming a
        # ResNet-50 width; fall back to 2048 when that layer does not expose
        # `in_features`, which keeps the original behaviour.
        dropped_layer = children[-1] if children else None
        backbone_width = getattr(dropped_layer, 'in_features', 2048)

        self.linear1 = nn.Linear(in_features=backbone_width,out_features=512)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=512,out_features=num_classes)

    def forward(self,x):
        """Return ``(features, logits)`` for a batch ``x``.

        ``features`` is the output of ``linear1`` before the ReLU, shape
        ``(batch, 512)``; ``logits`` has shape ``(batch, num_classes)``.
        """
        x = self.base_model(x)
        x = torch.flatten(x,1)  # keep the batch axis, flatten everything else
        lin = self.linear1(x)
        x = self.relu(lin)
        out = self.linear2(x)
        return lin, out

In [17]:
# Wrap a pretrained ResNet-50 in ConvNet and move the result to the chosen device.
model = ConvNet(torchvision.models.resnet50(pretrained=True),num_classes).to(device)

# Cross-entropy on the logits; SGD with momentum over all of the model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(),lr=learning_rate,momentum=0.9)

In [18]:
# Dump the module tree to check how the backbone and head are wired together.
print(model)

ConvNet(
  (base_model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6

In [19]:
# training loop

n_iters = len(train_loader)

for epoch in range(num_epochs):
    model.train()

    # steps are counted from 1 so the logged number matches "batch k of n_iters"
    for step,(images,labels) in enumerate(train_loader,start=1):
        images, labels = images.to(device), labels.to(device)

        # the model returns (features, logits); only the logits feed the loss
        _,outputs = model(images)
        loss = criterion(outputs,labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients so the next batch starts clean

        # report the current batch loss every 108 steps
        if step%108 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{n_iters}], Loss = {loss.item():.6f}')

    print('----------------------------------------')
    

  cpuset_checked))


Epoch [1/50], Step [108/433], Loss = 0.757020
Epoch [1/50], Step [216/433], Loss = 0.685718
Epoch [1/50], Step [324/433], Loss = 0.115149
Epoch [1/50], Step [432/433], Loss = 0.439111
----------------------------------------
Epoch [2/50], Step [108/433], Loss = 0.136134
Epoch [2/50], Step [216/433], Loss = 0.049118
Epoch [2/50], Step [324/433], Loss = 0.456234
Epoch [2/50], Step [432/433], Loss = 0.104673
----------------------------------------
Epoch [3/50], Step [108/433], Loss = 0.034409
Epoch [3/50], Step [216/433], Loss = 0.231488
Epoch [3/50], Step [324/433], Loss = 0.066696
Epoch [3/50], Step [432/433], Loss = 0.467270
----------------------------------------
Epoch [4/50], Step [108/433], Loss = 0.026214
Epoch [4/50], Step [216/433], Loss = 0.030503
Epoch [4/50], Step [324/433], Loss = 0.312166
Epoch [4/50], Step [432/433], Loss = 0.314792
----------------------------------------
Epoch [5/50], Step [108/433], Loss = 0.021850
Epoch [5/50], Step [216/433], Loss = 1.182730
Epoch [5

In [30]:
# evaluating model and getting features of every image

def eval_model_extract_features(features,true_labels,model,dataloader,phase):
    """Score ``model`` on ``dataloader`` and collect per-batch features and labels.

    Args:
        features: list that each batch's feature tensor (first element of the
            model's output) is appended to. Mutated in place, so successive
            calls accumulate.
        true_labels: list that each batch's label tensor is appended to.
            Mutated in place.
        model: module whose forward returns ``(features, logits)``.
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.
        phase: name of the split, used only in the printed message.

    Returns:
        The same ``features`` and ``true_labels`` list objects.

    Side effects:
        Puts ``model`` into eval mode (it is not switched back) and prints
        the accuracy. An empty ``dataloader`` prints an "n/a" line instead of
        raising ``ZeroDivisionError``.
    """
    # Run on whatever device the model lives on rather than relying on a
    # notebook-level global; a parameter-less model is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    n_correct = 0
    n_samples = 0

    model.eval()

    with torch.no_grad():
        for images,labels in dataloader:

            images = images.to(run_device)
            labels = labels.to(run_device)

            true_labels.append(labels)

            ftrs,outputs = model(images)
            features.append(ftrs)

            # predicted class = index of the largest logit in each row
            _,preds = torch.max(outputs,1)
            n_samples += labels.size(0)
            n_correct += (preds == labels).sum().item()

    if n_samples == 0:
        # nothing was evaluated, so an accuracy figure would be meaningless
        print(f'Accuracy of model on {phase} set = n/a (no samples)')
    else:
        accuracy = n_correct/float(n_samples)
        print(f'Accuracy of model on {phase} set = {(100.0 * accuracy):.4f} %')

    return features,true_labels
        

In [31]:
# Empty accumulators that the extraction calls append to, one entry per batch.
features, true_labels = [], []

In [32]:
# Rebind train_loader to an unshuffled, one-image-per-batch loader so each
# extracted feature row corresponds to exactly one image.
# NOTE(review): this replaces the shuffled training loader, so re-running the
# training cell afterwards would train with batch_size=1 and no shuffling.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4)

# Appends to `features` / `true_labels`; running this cell twice stores every
# training sample twice.
features,true_labels = eval_model_extract_features(
    features, true_labels, model,
    dataloader=train_loader, phase='training')

print(len(features),len(true_labels))

  cpuset_checked))


Accuracy of model on training set = 99.7687 %
3458 3458


In [35]:
# Append the validation samples after whatever `features` / `true_labels`
# already contain; the printed lengths are therefore cumulative.
features,true_labels = eval_model_extract_features(
    features, true_labels, model,
    dataloader=val_loader, phase='validation')

print(len(features),len(true_labels))

  cpuset_checked))


Accuracy of model on validation set = 94.3353 %
4323 4323


In [36]:
# Shallow copies: the conversions that follow replace list entries without
# touching the original `features` / `true_labels` lists.
ftrs, lbls = list(features), list(true_labels)

In [39]:
# Convert every stored feature tensor to a NumPy array on the host.
# Entries that are already arrays are passed through, so re-running this cell
# is harmless (the old in-place loop failed because ndarray has no `.cpu()`).
ftrs = [f.detach().cpu().numpy() if torch.is_tensor(f) else f for f in ftrs]

ftrs[0].shape

(1, 512)

In [41]:
# Convert every stored label tensor to a NumPy array on the host.
# Entries that are already arrays are passed through, so re-running this cell
# is harmless (the old in-place loop failed because ndarray has no `.cpu()`).
lbls = [l.detach().cpu().numpy() if torch.is_tensor(l) else l for l in lbls]

lbls[0].shape

(1,)

In [42]:
# Show the container types of ftrs and lbls before they are turned into arrays.
type(ftrs),type(lbls)

(list, list)

In [43]:
# Turn the per-batch lists into single arrays; the batch axis of each entry
# is kept, which is why the shapes shown still contain it.
ftrs, lbls = np.array(ftrs), np.array(lbls)

(ftrs.shape, lbls.shape)

((4323, 1, 512), (4323, 1))

In [44]:
# Fold every leading axis into one sample axis, giving (n_samples, n_features).
# Reading the width from the last axis and reshaping with -1 keeps this cell
# safe to re-run: an already 2-D array passes through unchanged, whereas
# indexing shape[2] raised IndexError on the second run.
n_features = ftrs.shape[-1]
ftrs = ftrs.reshape(-1,n_features)
n_samples = ftrs.shape[0]

print(ftrs.shape)

(4323, 512)


In [45]:
# Flatten the labels to one dimension. Counting elements with `.size` rather
# than `shape[0]` also handles label batches holding more than one entry,
# and leaves an already flat array unchanged on a re-run.
n_lbls = lbls.size
lbls = lbls.reshape(n_lbls)

print(lbls.shape)

(4323,)


In [46]:
# Write the feature matrix to CSV: one row per sample, one column per feature.
# The relative path resolves against the current working directory.
features_csv = './resnet50_FC_features_512.csv'
pd.DataFrame(ftrs).to_csv(features_csv,index=False)

# Read the file straight back so the frame used from here on matches what is on disk.
ftrs_df = pd.read_csv(features_csv)
ftrs_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511
0,0.085607,0.021805,0.441109,-0.286769,0.513648,-0.185310,-0.292494,-0.433075,-0.010477,-0.379597,-0.464587,-0.311266,-0.397218,-0.035729,-0.184875,0.005529,-0.485860,-0.084951,1.079289,0.234766,-0.607950,-0.416270,0.270703,0.078726,3.390669,1.175128,1.639136,0.919151,0.297039,1.392895,-0.078611,-0.245920,-0.380388,-1.181374,0.501518,-0.359335,0.949465,-0.596463,0.978446,0.250207,...,0.213398,-0.673682,0.134089,-0.830013,-0.270452,0.274271,-0.492283,1.312896,0.623881,-0.427828,-0.434984,-0.211153,0.655323,0.089081,2.196209,-0.643546,0.143208,1.332753,-0.046068,0.817310,0.190408,-0.307006,0.571560,0.143442,1.586519,0.291949,-0.515978,1.637968,1.365301,0.992635,1.900833,-0.517263,0.046018,-0.363306,-0.605235,-0.088705,0.874085,-0.474211,-0.151722,-0.569436
1,-0.133243,-0.964772,-0.479122,-0.388593,2.520442,-0.078540,1.397684,-0.962195,-0.197336,-0.283352,-0.368340,-0.348090,-0.860821,0.803020,-0.532486,-0.219943,-0.567403,-0.213607,1.202563,1.007898,1.166395,-0.908014,1.859558,1.871797,2.570776,1.382042,2.295932,0.721915,1.490825,0.843888,-0.023431,0.762509,-0.507451,0.378494,0.872305,-0.252291,2.568783,-1.072574,-0.142558,0.200490,...,-0.020628,-0.657437,0.102203,-0.975201,-0.341460,1.338218,-0.867677,1.562890,-0.927345,-0.620522,-0.262975,-0.594635,2.326258,-0.029713,1.235405,-1.075083,-0.299171,1.695334,2.343036,0.426081,0.945150,-0.855601,0.465321,0.578151,2.881992,0.689831,-0.723939,0.376424,1.517536,-0.131685,0.042126,-0.536044,-0.029770,-0.425078,-0.667606,-0.453189,2.844271,-0.952448,1.436085,1.016485
2,0.083105,-0.696422,-0.796738,-0.497731,1.963784,-0.284097,3.261419,-0.800969,-0.506469,-0.441910,0.079485,-0.463492,-0.471498,0.799019,-0.376445,-0.360156,-0.686314,-0.172683,1.956364,-0.891847,3.087634,-0.818431,0.335942,0.611970,0.758249,1.121906,0.802602,0.901728,3.559395,-1.312710,0.931835,2.984989,-0.468388,1.456828,0.267345,-0.318779,3.901358,-0.609226,-0.471357,1.560983,...,-0.130084,-0.486394,0.011836,-0.607265,-0.756973,0.369215,-0.849208,2.028156,0.119567,-0.803319,-0.465819,0.024541,0.522723,-0.393713,0.456446,-1.160259,-0.269853,1.397479,0.577923,-0.016859,1.445956,-0.849506,2.242668,1.023136,1.804294,0.145638,-0.650150,1.232863,2.362357,0.294656,0.325273,-0.361375,0.469979,-0.755104,-0.431179,0.206783,2.577395,-0.640026,0.363521,0.624068
3,-0.071460,-0.553484,0.029384,-0.145030,-0.509019,-0.443575,2.058645,-0.614785,-0.025296,-0.196204,1.508216,-0.386372,-0.322134,0.093845,-0.164296,-0.202908,-0.675448,1.987206,0.743067,-0.112193,1.952012,-0.279601,1.317924,1.228551,-0.335249,-0.084516,0.756525,1.279888,0.214162,-0.109619,1.804547,1.737462,-0.252732,1.941269,-0.171081,-0.347327,-0.324948,-0.587768,0.029561,-0.269150,...,0.598750,1.275778,0.271461,2.004199,-0.408011,1.705920,-0.323414,-0.148019,0.805807,-0.330228,-0.294776,0.718698,0.251789,0.782802,-0.270511,-0.356032,-0.055300,0.247848,1.221817,-0.020865,1.573320,0.149294,-0.046459,0.353816,-0.122537,-0.027578,-0.480010,-0.243383,-0.362624,0.525934,-0.071147,-0.662112,-0.019017,-0.365846,-0.136268,0.428925,-0.099518,-0.303276,2.082908,0.965919
4,0.197739,-0.731074,-0.256347,-1.116051,-1.104760,0.012868,0.499145,-0.610310,-0.062483,-0.172752,-0.808338,-0.300046,-1.091365,-0.175848,-0.590491,-0.383289,-0.950313,1.537892,2.285882,-0.894574,1.005046,-0.548572,0.915498,2.209650,4.332871,1.777773,3.032871,2.438895,0.461631,1.467396,0.488518,0.534241,-0.472278,0.219244,0.072697,-0.203950,-0.583803,-0.649506,1.389083,-0.124601,...,0.184684,-0.562442,0.392000,0.149472,-0.060157,2.243619,-0.548696,1.862436,0.320670,-0.460364,-0.493218,-1.026154,0.948146,1.560519,2.732547,-1.321941,-0.210493,2.269151,1.818774,-0.473236,2.870932,-1.494021,-0.124955,0.283038,2.440798,1.405221,-1.211263,1.504962,0.168726,0.873936,2.443866,-0.567604,-0.187445,-0.365194,-0.753977,-0.016396,0.172039,-0.294501,0.093440,0.641316
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4318,-0.411964,-0.215166,0.397792,0.127926,0.365831,-0.639052,1.061767,-0.250366,-0.247163,-0.370309,1.497652,-0.214139,-0.195388,-0.013514,0.012661,0.188817,-0.336080,1.103943,-0.164219,1.557188,0.655723,-0.521478,1.169153,0.040593,-0.518317,-0.471488,0.012427,0.054357,-0.207444,1.032838,1.547008,0.688899,-0.316939,0.648589,0.135564,-0.549887,-0.142928,-0.493280,-0.090844,-0.563234,...,0.748472,1.338823,0.153386,1.769272,-0.446314,0.650508,-0.419870,-0.274930,0.959066,-0.248627,-0.318770,1.104606,0.445158,0.750738,-0.376530,0.280118,0.080847,-0.069336,0.432121,0.767615,0.103846,0.989550,-0.475103,-0.104094,-0.102964,0.004872,-0.392388,-0.187689,0.181535,0.615621,-0.114013,-0.522579,0.209361,-0.475251,0.033660,-0.036103,-0.219503,-0.180813,2.115536,0.385495
4319,-0.744634,-0.367143,1.707262,0.563260,0.069633,-0.691620,0.129602,-0.526031,-0.377778,-0.736090,2.004293,-0.408265,-0.393246,-0.356447,0.076957,0.251363,-0.789603,1.927533,-0.053897,2.110189,-0.328533,-0.817039,1.364293,-0.739697,0.126688,0.165101,-0.221024,0.270788,-0.458646,2.639428,2.171700,-0.067937,-0.599245,-0.461420,0.408699,-1.055864,-0.621415,-0.718658,0.534098,-0.680208,...,1.283679,2.375924,-0.003406,1.893804,-0.873891,0.238221,-0.773082,-0.402667,2.376142,-0.437616,-0.652478,2.314448,1.110072,1.417839,0.516021,-0.226597,0.085850,0.059442,-0.809649,1.077470,-0.256480,2.136263,-0.387530,-0.117997,-0.276213,-0.259430,-1.149565,0.687346,-0.313412,1.807125,1.162865,-0.855010,0.792749,-0.682653,-0.344341,-0.449044,-0.725507,-0.555504,2.674017,-0.772274
4320,-0.189872,-0.064006,0.198464,0.052106,-0.556809,-0.325992,1.246096,-0.125790,-0.203488,-0.040403,0.882528,-0.256750,-0.152355,-0.164828,-0.004779,0.014405,-0.113635,0.965828,0.237597,-0.088504,1.087871,-0.319717,0.501730,0.358327,-0.670823,-0.460788,-0.022841,0.590250,-0.073192,-0.358912,1.104351,1.078326,-0.239340,1.085713,-0.188321,-0.453035,-0.470226,-0.425583,0.210761,-0.221895,...,0.608380,1.263205,0.002098,1.483538,-0.377799,0.971156,-0.218028,-0.211596,0.677350,-0.143152,-0.142737,0.839846,-0.201976,0.756051,-0.535939,-0.116270,-0.175716,-0.104628,0.470687,0.122938,0.774564,0.282786,-0.140790,0.116737,-0.369623,-0.191232,-0.398529,-0.119287,-0.540150,0.332821,-0.215005,-0.370967,0.168773,-0.028186,-0.061796,0.390675,-0.531841,-0.157892,1.204072,0.653930
4321,0.197524,-0.386280,0.245068,-0.459268,-0.636512,-0.162679,0.479350,-0.480562,-0.025279,-0.069157,-0.346472,-0.249377,-0.406491,0.128068,-0.299416,-0.246644,-0.377301,0.904717,1.483523,-0.725433,0.873144,-0.354503,0.134507,0.702128,2.022075,1.200273,1.347055,1.335216,0.391315,0.567601,0.579339,0.578097,-0.335478,0.249294,-0.303445,-0.234489,0.046308,-0.314374,0.911541,0.213348,...,0.144132,-0.193557,0.294394,0.080515,-0.294245,0.775613,-0.349679,1.201547,0.641203,-0.232427,-0.415436,-0.333580,0.138332,0.737065,1.390217,-0.603217,0.101477,1.131908,0.365596,-0.368327,1.525597,-0.404978,0.468510,0.068937,1.040590,0.312121,-0.710542,1.058605,0.103527,0.811860,1.489462,-0.398147,0.020573,-0.133686,-0.239739,0.008940,-0.189435,-0.142793,-0.277332,0.049284


In [48]:
# Attach the class index of every sample as a trailing 'label' column.
ftrs_df = ftrs_df.assign(label=lbls)

ftrs_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499,500,501,502,503,504,505,506,507,508,509,510,511,label
0,0.085607,0.021805,0.441109,-0.286769,0.513648,-0.18531,-0.292494,-0.433075,-0.010477,-0.379597,-0.464587,-0.311266,-0.397218,-0.035729,-0.184875,0.005529,-0.48586,-0.084951,1.079289,0.234766,-0.60795,-0.41627,0.270703,0.078726,3.390669,1.175128,1.639136,0.919151,0.297039,1.392895,-0.078611,-0.24592,-0.380388,-1.181374,0.501518,-0.359335,0.949465,-0.596463,0.978446,0.250207,...,-0.673682,0.134089,-0.830013,-0.270452,0.274271,-0.492283,1.312896,0.623881,-0.427828,-0.434984,-0.211153,0.655323,0.089081,2.196209,-0.643546,0.143208,1.332753,-0.046068,0.81731,0.190408,-0.307006,0.57156,0.143442,1.586519,0.291949,-0.515978,1.637968,1.365301,0.992635,1.900833,-0.517263,0.046018,-0.363306,-0.605235,-0.088705,0.874085,-0.474211,-0.151722,-0.569436,1
1,-0.133243,-0.964772,-0.479122,-0.388593,2.520442,-0.07854,1.397684,-0.962195,-0.197336,-0.283352,-0.36834,-0.34809,-0.860821,0.80302,-0.532486,-0.219943,-0.567403,-0.213607,1.202563,1.007898,1.166395,-0.908014,1.859558,1.871797,2.570776,1.382042,2.295932,0.721915,1.490825,0.843888,-0.023431,0.762509,-0.507451,0.378494,0.872305,-0.252291,2.568783,-1.072574,-0.142558,0.20049,...,-0.657437,0.102203,-0.975201,-0.34146,1.338218,-0.867677,1.56289,-0.927345,-0.620522,-0.262975,-0.594635,2.326258,-0.029713,1.235405,-1.075083,-0.299171,1.695334,2.343036,0.426081,0.94515,-0.855601,0.465321,0.578151,2.881992,0.689831,-0.723939,0.376424,1.517536,-0.131685,0.042126,-0.536044,-0.02977,-0.425078,-0.667606,-0.453189,2.844271,-0.952448,1.436085,1.016485,3
2,0.083105,-0.696422,-0.796738,-0.497731,1.963784,-0.284097,3.261419,-0.800969,-0.506469,-0.44191,0.079485,-0.463492,-0.471498,0.799019,-0.376445,-0.360156,-0.686314,-0.172683,1.956364,-0.891847,3.087634,-0.818431,0.335942,0.61197,0.758249,1.121906,0.802602,0.901728,3.559395,-1.31271,0.931835,2.984989,-0.468388,1.456828,0.267345,-0.318779,3.901358,-0.609226,-0.471357,1.560983,...,-0.486394,0.011836,-0.607265,-0.756973,0.369215,-0.849208,2.028156,0.119567,-0.803319,-0.465819,0.024541,0.522723,-0.393713,0.456446,-1.160259,-0.269853,1.397479,0.577923,-0.016859,1.445956,-0.849506,2.242668,1.023136,1.804294,0.145638,-0.65015,1.232863,2.362357,0.294656,0.325273,-0.361375,0.469979,-0.755104,-0.431179,0.206783,2.577395,-0.640026,0.363521,0.624068,0
3,-0.07146,-0.553484,0.029384,-0.14503,-0.509019,-0.443575,2.058645,-0.614785,-0.025296,-0.196204,1.508216,-0.386372,-0.322134,0.093845,-0.164296,-0.202908,-0.675448,1.987206,0.743067,-0.112193,1.952012,-0.279601,1.317924,1.228551,-0.335249,-0.084516,0.756525,1.279888,0.214162,-0.109619,1.804547,1.737462,-0.252732,1.941269,-0.171081,-0.347327,-0.324948,-0.587768,0.029561,-0.26915,...,1.275778,0.271461,2.004199,-0.408011,1.70592,-0.323414,-0.148019,0.805807,-0.330228,-0.294776,0.718698,0.251789,0.782802,-0.270511,-0.356032,-0.0553,0.247848,1.221817,-0.020865,1.57332,0.149294,-0.046459,0.353816,-0.122537,-0.027578,-0.48001,-0.243383,-0.362624,0.525934,-0.071147,-0.662112,-0.019017,-0.365846,-0.136268,0.428925,-0.099518,-0.303276,2.082908,0.965919,2
4,0.197739,-0.731074,-0.256347,-1.116051,-1.10476,0.012868,0.499145,-0.61031,-0.062483,-0.172752,-0.808338,-0.300046,-1.091365,-0.175848,-0.590491,-0.383289,-0.950313,1.537892,2.285882,-0.894574,1.005046,-0.548572,0.915498,2.20965,4.332871,1.777773,3.032871,2.438895,0.461631,1.467396,0.488518,0.534241,-0.472278,0.219244,0.072697,-0.20395,-0.583803,-0.649506,1.389083,-0.124601,...,-0.562442,0.392,0.149472,-0.060157,2.243619,-0.548696,1.862436,0.32067,-0.460364,-0.493218,-1.026154,0.948146,1.560519,2.732547,-1.321941,-0.210493,2.269151,1.818774,-0.473236,2.870932,-1.494021,-0.124955,0.283038,2.440798,1.405221,-1.211263,1.504962,0.168726,0.873936,2.443866,-0.567604,-0.187445,-0.365194,-0.753977,-0.016396,0.172039,-0.294501,0.09344,0.641316,1


In [49]:
# Write the features together with the 'label' column to a second CSV in the
# current working directory (row index omitted).
ftrs_df.to_csv('./resnet50_FC_512_features_with_labels.csv',index=False)

print('feature set saved successfully !')

feature set saved successfully !


In [47]:
# save model
# Only the state_dict (parameters and buffers) is written; loading it later
# requires rebuilding the same ConvNet architecture first.
# NOTE(review): this cell's execution count (47) is lower than that of the
# cells above it - confirm the notebook still works under Restart & Run All.
MODEL_PATH = './resnet50_TL_model_94%acc.pth'
torch.save(model.state_dict(),MODEL_PATH)