# Import Party!!

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch ,json, torchvision, os, glob
import torchvision

# Simple visualisation


In [None]:
def read_img(path):
    """Load the image file at ``path`` with ``torchvision.io.read_image`` and return the result."""
    return torchvision.io.read_image(path)


def plot_imgs(df,r=8,c=8,figsize=(20,20)):
    """Show the first images listed in ``df`` on an ``r`` x ``c`` grid.

    Args:
        df: DataFrame with a ``directory`` column (image file paths) and a
            ``category`` column (used as each subplot's title).
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False always yields a 2-D array of axes, so flatten() also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axs = plt.subplots(r, c, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    # Hide every cell up front so cells without an image stay blank when
    # ``df`` has fewer rows than the grid has cells.
    for ax in axs:
        ax.axis('off')

    # Positional access (iloc) instead of label lookup: the frame's index is
    # not guaranteed to be 0..n-1 (e.g. after filtering or splitting).
    paths = df.directory.iloc[:len(axs)]
    cats = df.category.iloc[:len(axs)]
    for ax, path, cat in zip(axs, paths, cats):
        img = read_img(path)
        ax.imshow(torchvision.transforms.functional.to_pil_image(img))
        ax.set_title(cat)

    plt.tight_layout()
    plt.show()
    
def display_(path):
    """Load a CSV file, print a quick overview of it and return the DataFrame.

    The overview consists of the frame itself, the ``DataFrame.info()``
    report and the number of unique values in every column.

    Args:
        path: Path of the CSV file, forwarded to ``pd.read_csv``.

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    df = pd.read_csv(path)

    # ``display`` is neither defined nor imported in this file; it is only
    # available as a builtin under IPython/Jupyter. Fall back to print() so
    # the function also works when run as a plain script.
    try:
        show = display
    except NameError:
        show = print
    show(df)

    # DataFrame.info() prints its report itself and returns None, so wrapping
    # it in print() emitted an extra "None" line.
    df.info()
    print("unique values in columns")
    for col in df.columns:
        print(col,"           :          ", df[col].nunique())
    return df
    
    
def main():
    """Load the training CSV, print its overview and plot a grid of its images."""
    train_df_path='../input/herbarium-2022-pandas/train.csv'

    train_df=display_(train_df_path)
    plot_imgs(train_df)
    
    
    
# Invoke the main() defined directly above (CSV overview + image grid).
main()


# Pipeline

In [None]:
from torch.utils.data import DataLoader,Dataset
from pytorch_lightning import LightningDataModule,LightningModule
from sklearn.model_selection import train_test_split
from kornia import image_to_tensor, tensor_to_image
from kornia.augmentation import ColorJitter, RandomChannelShuffle, RandomHorizontalFlip, RandomThinPlateSpline
from torchvision import transforms

class basic_pipe (Dataset):
    """Map-style dataset yielding ``(image, one_hot_label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with a ``directory`` column (image file paths) and a
            ``category`` column (integer class ids).
        num_classes: Length of the one-hot label vector. Defaults to 15504,
            the value that used to be hard-coded in ``get_label``.
        img_size: Target size passed to ``torchvision.transforms.Resize``.
            Defaults to ``(512, 512)``, the previously hard-coded size.
    """

    def __init__(
                self,
                df,
                num_classes=15504,
                img_size=(512,512),
                ):

        self.df=df
        self.num_classes=num_classes
        self.img_size=img_size

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    def read_img(self,path):
        """Read the image at ``path``, resize it to ``img_size`` and divide it by 255.0."""
        img=torchvision.io.read_image(path)
        img=torchvision.transforms.Resize(size=self.img_size)(img)

        return img/255.0

    def get_label(self,info):
        """One-hot encode the class id ``info`` into a vector of ``num_classes`` entries."""
        # one_hot only accepts int64 index tensors; force the dtype rather
        # than relying on what torch infers from the incoming scalar type.
        class_id=torch.tensor(int(info),dtype=torch.long)

        return torch.nn.functional.one_hot(class_id, num_classes=self.num_classes)

    def __getitem__(self, idx):
        """Return the ``(image, one_hot_label)`` pair stored at position ``idx``."""
        # iloc: ``idx`` is a position in [0, len(self)), not an index label,
        # so this also works for frames whose index was not reset.
        img=self.read_img(self.df.directory.iloc[idx])
        lab=self.get_label(self.df.category.iloc[idx])

        return img,lab
    
    
    
class pl_pipeline(LightningDataModule):
    """Data module that splits a DataFrame into train/validation parts and serves loaders.

    Args:
        dataset: Dataset class (or factory) called as ``dataset(df)`` to wrap
            each split.
        df: Full DataFrame; it is split once with ``train_test_split`` using
            that function's default settings.
        bs: Batch size for both loaders.
        num_workers: Number of worker processes for both loaders.
    """

    def __init__(

        self,
        dataset,
        df,
        bs,
        num_workers
                ):
        # The base class initialises its own internal state; skipping this
        # call leaves the data module only partially constructed.
        super().__init__()

        self.dataset=dataset
        self.df=df
        self.bs=bs
        self.num_workers=num_workers

        train_df,val_df=train_test_split(df)
        # drop=True: renumber rows 0..n-1 without inserting the old index as
        # an extra "index" column (which would also fail if such a column
        # already existed).
        self.train_df=train_df.reset_index(drop=True)
        self.val_df=val_df.reset_index(drop=True)

    def _make_loader(self,frame):
        """Wrap ``frame`` in the dataset and return a DataLoader over it."""
        return DataLoader(
            self.dataset(frame),
            batch_size=self.bs,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the DataLoader over the training split."""
        return self._make_loader(self.train_df)

    def val_dataloader(self):
        """Return the DataLoader over the validation split.

        ``val_dataloader`` is the hook name Lightning looks up; under the
        previous name the validation split was never picked up by a Trainer.
        """
        return self._make_loader(self.val_df)

    # Backward-compatible alias for code that still calls the old name.
    validation_dataloader=val_dataloader
        
        
def plot_basic_pipeline(data,r=8,c=8,figsize=(25,25)):
    """Plot the first samples of an indexable dataset on an ``r`` x ``c`` grid.

    Args:
        data: Object supporting ``len(data)`` and ``data[n] -> (img, lab)``,
            where ``img`` is accepted by ``to_pil_image``.
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid.
    _,axs=plt.subplots(r,c,figsize=figsize,squeeze=False)
    axs=axs.flatten()

    # Hide all cells first so unused ones stay blank for short datasets.
    for ax in axs:
        ax.axis('off')

    # zip stops at the shorter side: never index past the end of ``data``.
    for n, ax in zip(range(len(data)), axs):
        img,lab=data[n]
        if torch.is_tensor(lab):
            # A one-hot vector would be rendered as a long tensor repr;
            # show the class index instead.
            lab=int(lab.argmax()) if lab.dim()>0 else lab.item()
        ax.imshow(torchvision.transforms.functional.to_pil_image(img))
        ax.set_title(lab)

    plt.tight_layout()
    plt.show()

def plot_pl_pipeline(imgs,labs,r=8,c=8,figsize=(25,25)):
    """Plot one batch of images with their labels on an ``r`` x ``c`` grid.

    Args:
        imgs: Iterable of images (e.g. a batched tensor), each accepted by
            ``to_pil_image``.
        labs: Iterable of labels aligned with ``imgs``; tensor labels are
            shown as a class index.
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid.
    _,axs=plt.subplots(r,c,figsize=figsize,squeeze=False)
    axs=axs.flatten()

    # Hide all cells first so unused ones stay blank for a short batch.
    for ax in axs:
        ax.axis('off')

    # zip stops at the shortest input, so a batch smaller than the grid
    # (e.g. the last batch of an epoch) no longer raises IndexError.
    for ax, img, lab in zip(axs, imgs, labs):
        if torch.is_tensor(lab):
            # A one-hot row would be rendered as a long tensor repr;
            # show the class index instead.
            lab=int(lab.argmax()) if lab.dim()>0 else lab.item()
        ax.imshow(torchvision.transforms.functional.to_pil_image(img))
        ax.set_title(lab)

    plt.tight_layout()
    plt.show()

    
def main():
    """Build the data module from the training CSV and plot one training batch."""
    train_df_path='../input/herbarium-2022-pandas/train.csv'

    train_df=display_(train_df_path)
    # To debug the dataset on its own (without the data module), uncomment:
    #     data=basic_pipe(train_df)
    #     plot_basic_pipeline(data)
    datamodule=pl_pipeline(
        basic_pipe,
        df=train_df,
        bs=64,  # one batch fills plot_pl_pipeline's default 8x8 grid
        num_workers=1
    )
    img,lab=next(iter(datamodule.train_dataloader()))
    plot_pl_pipeline(img,lab)
    
    
    
# Invoke the main() defined directly above (plots one batch from the data module).
main()

    
    
    
    


# Model and Classifier

In [None]:
# torchvision.models.densenet121(pretrained=False)


In [None]:
class pl_model(LightningModule):
    """Backbone network whose ``classifier`` layer is replaced by a ``num_class``-way linear head.

    Args:
        model_name: Free-form name stored on the instance.
        model: Backbone module exposing a ``classifier`` attribute that has
            ``in_features``. When ``None`` (the default), a torchvision
            DenseNet-121 with ``pretrained=True`` is created for this instance.
        num_class: Number of outputs of the new linear head.
    """

    def __init__(self,
                model_name='default_model_name',
                 model=None,
                 num_class=15504

                ):
        super(pl_model,self).__init__()
        if model is None:
            # Built per instance. As a default argument the network was
            # created once when the class was defined (triggering the weight
            # download at that point) and then shared - and mutated by
            # get_head - across every pl_model instance.
            model=torchvision.models.densenet121(pretrained=True)
        self.model_name=model_name
        self.model=model
        self.linear=torch.nn.Linear
        self.num_class=num_class
        self.get_head(self.model)

    def get_head(self,model):  # basic exp model ( baseline )
        """Replace ``model.classifier`` with a fresh linear layer of ``num_class`` outputs."""
        in_features=model.classifier.in_features
        model.classifier=self.linear(in_features,out_features=self.num_class)

    def forward(self,x):
        """Run ``x`` through the backbone and return its output."""
        # Extension point: additional heads or stages can be added here.
        return self.model(x)

import torchmetrics
import torch.nn.functional as F

class classifier(LightningModule):
    """Lightning training wrapper: cross-entropy loss, accuracy metrics and an Adam optimizer.

    Args:
        model: Network called as ``model(x)`` that returns per-class scores.
        lr: Learning rate for Adam. Defaults to 1e-3.
    """

    def __init__(self,
                model,
                lr=1e-3,
                ):
        super(classifier,self).__init__()
        self.model=model
        self.lr=lr
        # NOTE(review): newer torchmetrics releases require a ``task``
        # argument for Accuracy - confirm against the installed version.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()

    @staticmethod
    def _class_indices(preds,gts):
        """Collapse integer one-hot targets shaped like ``preds`` into class indices.

        Floating-point targets and targets that are already indices are
        returned unchanged.
        """
        if gts.dim()==preds.dim() and not torch.is_floating_point(gts):
            return gts.argmax(dim=1)
        return gts

    def compute_loss(self,preds,gts):
        """Return the cross-entropy between ``preds`` and ``gts``."""
        # F.cross_entropy rejects integer one-hot targets, so convert them
        # to class indices first.
        loss=F.cross_entropy(preds,self._class_indices(preds,gts))
        return loss

    def training_step(self,batch,batch_idx):
        """Compute and log the loss and accuracy for one training batch."""
        x,y=batch
        prediction=self.model(x)
        y=self._class_indices(prediction,y)
        loss=self.compute_loss(prediction,y)
        self.train_acc(prediction,y)
        self.log('train_acc',self.train_acc)
        self.log('loss',loss)

        return {'loss':loss,'matric_1':self.train_acc}

    def validation_step(self,batch,batch_idx):
        """Compute and log the loss and accuracy for one validation batch."""
        x,y=batch
        prediction=self.model(x)
        y=self._class_indices(prediction,y)
        loss=self.compute_loss(prediction,y)

        self.val_acc(prediction,y)
        # Logged under their own keys: 'val_lacc' was a typo and 'loss'
        # collided with the key used by training_step.
        self.log('val_acc',self.val_acc)
        self.log('val_loss',loss)

        # Was ``self.val``, an attribute that does not exist.
        return {'loss':loss, 'val_acc':self.val_acc}

    def configure_optimizers(self):
        """Return the Adam optimizer over the wrapped model's parameters.

        Replaces the former ``optimizer_step(self, model)``, which referenced
        the non-existent ``torch.nn.optim``, built ``Adam`` without
        parameters and shadowed Lightning's own ``optimizer_step`` hook,
        while the ``configure_optimizers`` hook was missing.
        """
        return torch.optim.Adam(self.model.parameters(),lr=self.lr)

    
def plot_predictions(imgs,preds,r=2,c=3,figsize=(20,20)):
    """Plot images titled with their predictions on an ``r`` x ``c`` grid.

    Args:
        imgs: Iterable of images, each accepted by ``to_pil_image``.
        preds: Iterable of predictions aligned with ``imgs``.
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid.
    _,axs=plt.subplots(r,c,figsize=figsize,squeeze=False)
    axs=axs.flatten()

    # Hide all cells first so unused ones stay blank.
    for ax in axs:
        ax.axis('off')

    # zip stops at the shortest input, so fewer images than grid cells no
    # longer raises IndexError.
    for ax, img, pred in zip(axs, imgs, preds):
        if torch.is_tensor(pred) and pred.dim()==0:
            # Show "123" rather than "tensor(123)".
            pred=pred.item()
        ax.imshow(transforms.functional.to_pil_image(img))
        ax.set_title(pred)

    plt.tight_layout()
    plt.show()
        
def main():
    """Run the model on six training images and plot the predicted class ids."""
    train_df_path='../input/herbarium-2022-pandas/train.csv'

    train_df=display_(train_df_path)
    datamodule=pl_pipeline(
        basic_pipe,
        df=train_df,
        bs=64,
        num_workers=1
    )
    imgs,_= next(iter(datamodule.train_dataloader()))
    imgs=imgs[2:8]  # six images: fills plot_predictions' default 2x3 grid
    model=pl_model()
    # Inference only: eval() puts layers such as batch norm and dropout into
    # inference mode, and no_grad() skips building the autograd graph.
    model.eval()
    with torch.no_grad():
        preds=model(imgs)
    plot_predictions(imgs,torch.argmax(preds,dim=1))
    
    
# Invoke the main() defined directly above (plots predictions for a few training images).
main()
