In [None]:
import wandb
# Log in to Weights & Biases up front, before the run is created with wandb.init() further down.
wandb.login()

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')
from collections import Counter
import shutil
import os
import random
import pytorch_lightning as pl
from types import SimpleNamespace
from torchvision.models import resnet50
from pytorch_lightning.loggers import WandbLogger
%matplotlib inline

In [None]:
## Splitting train to train(80%) and valid(20%)

prefix='/kaggle/input/inaturalist12k/Data/inaturalist_12K/'

data_prefix='/kaggle/working/'

classes=['Amphibia', 'Animalia', 'Arachnida', 'Aves', 'Fungi', 'Insecta', 'Mammalia', 'Mollusca', 'Plantae', 'Reptilia']

# The splittedVal directory doubles as the "already split" marker: when it exists,
# nothing below runs, so re-executing this cell does not move any more images.
flag=os.path.exists(data_prefix+'splittedVal')

valid_split=0.2
# Fixed seed so the same images land in the validation split on every fresh run.
SPLIT_SEED=42
if not flag:
    # Private generator: reproducible without touching the global `random` state.
    split_rng = random.Random(SPLIT_SEED)
    # Work on a copy under data_prefix; the moves below only touch that copy.
    for each in ['train','val']:
        shutil.copytree(prefix+each,data_prefix+each)
    os.mkdir(data_prefix+"splittedVal")
    for each in classes:
        class_dir = data_prefix+'train/'+each+'/'
        # os.listdir returns entries in arbitrary order; sort first so the seeded
        # shuffle produces the same permutation every time.
        images = sorted(os.listdir(class_dir))
        split_rng.shuffle(images)
        valid_till=int(len(images)*valid_split)
        os.mkdir(data_prefix+'splittedVal/'+each)
        # The first valid_split fraction of the shuffled files becomes validation data.
        for image_name in images[:valid_till]:
            shutil.move(class_dir+image_name,data_prefix+'splittedVal/'+each)
            

In [None]:
# Data loading: one shared preprocessing pipeline, three ImageFolder datasets, three loaders.

# Every split is resized to 224x224, converted to a tensor and normalised per channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Directories under data_prefix: 'train', 'splittedVal' (validation) and 'val' (used as the test set).
train_dataset, valid_dataset, test_dataset = (
    torchvision.datasets.ImageFolder(root=data_prefix + split_dir, transform=transform)
    for split_dir in ('train', 'splittedVal', 'val')
)

# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
def getActivation(function): #activations
    """Return a new activation module for the given name.

    Args:
        function: Name of the activation: 'ReLU', 'GELU', 'SiLU' or 'SELU'.

    Returns:
        A freshly constructed ``nn.Module``. Any unrecognised name falls back
        to ``nn.ReLU()``.
    """
    activations = {
        'ReLU': nn.ReLU,
        'GELU': nn.GELU,
        'SiLU': nn.SiLU,  # this name used to return nn.SELU(), a different activation
        'SELU': nn.SELU,
    }
    return activations.get(function, nn.ReLU)()

In [None]:
# Building Model
class Model(pl.LightningModule):
    """ResNet-50 with frozen pretrained weights and a new trainable linear head.

    Per-batch loss and accuracy are collected during an epoch, then averaged,
    printed and sent to Weights & Biases in ``on_train_epoch_end``.

    Args:
        learning_rate: Step size for the Adam optimizer that trains the head.
        num_classes: Number of output neurons of the replacement ``fc`` layer.
    """

    def __init__(self, learning_rate=0.0001, num_classes=10):
        super().__init__()
        self.learning_rate = learning_rate

        self.resnet = resnet50(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad = False  # freezing all pretrained layers
        # NOTE(review): requires_grad=False does not stop BatchNorm running statistics
        # from updating while the module is in train mode -- confirm whether the
        # backbone should additionally be kept in eval mode.
        num_features = self.resnet.fc.in_features
        # The new head is created after the freeze loop, so it is the only trainable part.
        self.resnet.fc = nn.Linear(num_features, num_classes)
        self.loss = nn.CrossEntropyLoss()
        # Per-batch metrics for the current epoch; emptied in on_train_epoch_end.
        self.valid_loss = []
        self.valid_acc = []
        self.train_loss = []
        self.train_acc = []

    def forward(self, x):
        """Return the class logits produced by the ResNet for batch ``x``."""
        return self.resnet(x)

    def configure_optimizers(self):
        """Create an Adam optimizer over the parameters of the ``fc`` head only."""
        return torch.optim.Adam(self.resnet.fc.parameters(), lr=self.learning_rate)

    @staticmethod
    def _mean(values):
        """Average a list of 0-dim tensors into a float; NaN when the list is empty."""
        if not values:
            return float('nan')
        return torch.stack(values).mean().item()

    def training_step(self, batch, batch_idx):
        """Compute loss/accuracy for one training batch, record them, return the loss."""
        X, Y = batch
        output = self(X)
        loss = self.loss(output, Y)
        acc = (output.argmax(dim=1) == Y).float().mean()
        # Store a detached copy: the bookkeeping list must not hold on to autograd state.
        self.train_loss.append(loss.detach())
        self.train_acc.append(acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss/accuracy for one validation batch, record them, return the loss."""
        X, Y = batch
        output = self(X)
        loss = self.loss(output, Y)
        # Lightning's pre-fit sanity check also runs this hook; recording those
        # batches would leak them into the first epoch's validation averages.
        if not self.trainer.sanity_checking:
            acc = (output.argmax(dim=1) == Y).float().mean()
            self.valid_loss.append(loss.detach())
            self.valid_acc.append(acc)
        return loss

    def on_train_epoch_end(self):
        """Average the collected batch metrics, reset the buffers, print and log to W&B."""
        # _mean yields NaN for an empty buffer (e.g. no validation batches this
        # epoch) instead of raising ZeroDivisionError.
        metrics = {
            'train_acc': self._mean(self.train_acc),
            'train_loss': self._mean(self.train_loss),
            'valid_acc': self._mean(self.valid_acc),
            'valid_loss': self._mean(self.valid_loss),
        }
        self.train_acc = []
        self.train_loss = []
        self.valid_loss = []
        self.valid_acc = []
        print(f"Epoch: {self.current_epoch} train accuracy :{metrics['train_acc']:.2f} valid_accuracy :{metrics['valid_acc']:.2f}")
        wandb.log(metrics)

    def predict_step(self, batch, batch_idx):
        """Return the logits for the inputs of ``batch``; the labels are ignored."""
        X, Y = batch
        return self(X)


In [None]:
# Initialize the model and fine-tune it
wandb.init(project='ResNet50 Model')  # open the W&B run before any metric is logged
model = Model()
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=10,
)
trainer.fit(model, train_loader, valid_loader)  # fine tuning the model

In [None]:
# To compute final test accuracy

def calc_acc(data_loader,targets):
  """Return the fraction of samples in ``data_loader`` predicted correctly.

  Uses the module-level ``trainer`` and ``model`` to obtain predictions.

  Args:
      data_loader: Loader to run prediction over.
      targets: Ground-truth class indices, compared element-wise with the
          predictions, so they must be in the same order as the loader's samples.

  Returns:
      Number of matching predictions divided by ``len(targets)``.
  """
  batch_logits = trainer.predict(model, data_loader)
  # Join the per-batch outputs, then take the highest-scoring class per sample.
  predicted = torch.concat(batch_logits).argmax(axis=1).numpy()
  expected = np.array(targets)
  return np.sum(predicted == expected) / len(expected)

# Score the validation and test splits, then report and log both numbers.
valid_accuracy = calc_acc(valid_loader, valid_dataset.targets)
test_accuracy = calc_acc(test_loader, test_dataset.targets)
print(f'valid accuracy: {valid_accuracy:.2f} test accuracy: {test_accuracy:.2f}')

final_metrics = {'test accuracy': test_accuracy, 'valid accuracy': valid_accuracy}
wandb.log(final_metrics)
wandb.finish()  # close the W&B run