# Import party!!

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torchvision
import glob, os
import torchvision.transforms.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset , DataLoader
from pytorch_lightning import LightningDataModule,LightningModule
plt.rcParams["savefig.bbox"] = 'tight'

# Helpers

In [None]:
def get_paths(base_folder):
    """Return the paths of every ``.png`` file directly inside *base_folder*.

    The search is not recursive and the result is not sorted here; an empty
    list is returned when nothing matches.
    """
    pattern = base_folder + '/*.png'
    return glob.glob(pattern)

def read_img(path):
    """Load the image file at *path* via ``torchvision.io.read_image``.

    The decoded image is returned exactly as torchvision produced it; no
    scaling or conversion is applied here.
    """
    return torchvision.io.read_image(path)

def display_(df_path,base_dir):
    """Load a CSV, prefix its first column with *base_dir*, and print a summary.

    Args:
        df_path: Path of the CSV file to read.
        base_dir: Directory prepended (with a ``/`` separator) to every value
            in the first column of the CSV.

    Returns:
        The loaded DataFrame with its first column rewritten.

    Raises:
        AttributeError: If the CSV has no ``cultivar`` column.
    """
    df=pd.read_csv(df_path)
    df.iloc[:,0]=base_dir+'/'+df.iloc[:,0]

    # ``display`` is only injected by IPython/Jupyter; fall back to ``print``
    # so the helper also works when run as a plain script.
    try:
        show=display
    except NameError:
        show=print

    # the rest is just a readable summary of the CSV
    print('*'*100)
    print(os.path.basename(df_path))
    show(df)
    # DataFrame.info() prints its own report and returns None, so wrapping it
    # in print() would emit a stray "None" line.
    df.info()
    print(f'unique _ values in cultivar: {df.cultivar.nunique()}')

    return df


def plot_imgs(paths,r=8,c=8,figsize=(20,20)):
    """Show the first ``r*c`` images from *paths* in an ``r`` x ``c`` grid.

    Args:
        paths: Indexable collection of image file paths; entry ``n`` is drawn
            in the n-th cell of the grid.
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    _fig, grid = plt.subplots(r, c, figsize=figsize)
    for cell_no, axis in enumerate(grid.flatten()):
        picture = F.to_pil_image(read_img(paths[cell_no]))
        axis.imshow(picture)
        axis.axis('off')  # hide ticks and frame around each image

    plt.tight_layout()
    plt.show()
    


# Pipeline


In [None]:
class pipeline_basic(Dataset):
    """Map-style dataset yielding ``(image, cultivar)`` pairs from a DataFrame.

    The DataFrame must provide an ``image`` column (file paths) and a
    ``cultivar`` column (labels).
    """

    def __init__(
                self,
                df
                ):
        """Keep a re-indexed copy of *df*.

        The index is reset to 0..n-1 so that integer positions handed out by
        a DataLoader always resolve, even when *df* is a shuffled subset
        (e.g. one half of a train/validation split). The caller's frame is
        left untouched.
        """
        self.df=df.reset_index(drop=True)

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def read_img(self,img):
        """Read the image file at path *img* and divide its values by 255.

        The parameter keeps its original name ``img`` for compatibility, but
        it is a file path, not image data.
        """
        pixels=torchvision.io.read_image(img)
        return pixels/255.0

    def __getitem__(self,idx):
        """Return ``(image, label)`` for the row at integer position *idx*."""
        # positional access: never depends on the DataFrame's index labels
        row=self.df.iloc[idx]
        # use the class's own reader so samples get the /255 scaling
        img=self.read_img(row['image'])
        lab=row['cultivar']

        return img,lab
    
class PL_pipeline(LightningDataModule):
    """LightningDataModule that splits one DataFrame into train/val loaders.

    Args:
        Dataset: Dataset *class* (not an instance); it is called with a
            DataFrame to build the train and validation datasets.
        df: Full DataFrame to be split.
        bs: Batch size used for both loaders.
    """

    def __init__(self,
                Dataset,
                df,
                bs):
        # LightningDataModule sets up internal state in its own __init__
        super().__init__()
        # keep the class itself; it is instantiated per split in the loaders
        self.Dataset=Dataset
        self.df=df
        self.bs=bs

    def setup(self,stage=None):
        """Split ``self.df`` into ``train_df`` and ``val_df``.

        ``stage`` is accepted because Lightning passes it to this hook; it is
        not used here.
        """
        train_df,val_df=train_test_split(self.df)
        # train_test_split keeps the original (now shuffled) index labels;
        # reset them so datasets can index rows by 0..n-1
        self.train_df=train_df.reset_index(drop=True)
        self.val_df=val_df.reset_index(drop=True)

    def train_dataloader(self):
        """Return the DataLoader over the training split."""
        data=self.Dataset(self.train_df)
        return DataLoader(data,batch_size=self.bs)

    def val_dataloader(self):
        """Return the DataLoader over the validation split."""
        data=self.Dataset(self.val_df)
        return DataLoader(data,batch_size=self.bs)

    def training_dataloader(self):
        """Backward-compatible alias for ``train_dataloader``."""
        return self.train_dataloader()

    def validation_dataloader(self):
        """Backward-compatible alias for ``val_dataloader``."""
        return self.val_dataloader()
    
def plot_pipeline(dataset,r=8,c=8,figsize=(20,20)):
    """Show the first ``r*c`` samples of *dataset* in an ``r`` x ``c`` grid.

    Args:
        dataset: Indexable object whose items are ``(image, label)`` pairs;
            item ``n`` is drawn in the n-th cell with its label as the title.
        r: Number of grid rows.
        c: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    _fig, grid = plt.subplots(r, c, figsize=figsize)
    for cell_no, axis in enumerate(grid.flatten()):
        sample = dataset[cell_no]
        picture, label = sample
        axis.imshow(F.to_pil_image(picture))
        axis.set_title(label)
        axis.axis('off')  # hide ticks and frame around each image

    plt.tight_layout()
    plt.show()
    


# Avengers assemble

In [None]:
def main():
    """Load the competition CSVs, print their summaries, and plot a sample grid.

    Reads the training mapping and the sample submission, then builds a
    ``pipeline_basic`` dataset from the training frame and plots it as a
    smoke test of the data pipeline.
    """
    # input locations
    training_folder='../input/sorghum-id-fgvc-9/train_images'
    train_df_path='../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv'
    test_folder='../input/sorghum-id-fgvc-9/test'
    submission_df_path='../input/sorghum-id-fgvc-9/sample_submission.csv'

    frame=display_(train_df_path,training_folder)
    # plot_imgs(frame.image)   # optional: basic plotting of images straight from paths
    display_(submission_df_path,test_folder)

    # testing pipeline
    print('testing pipeline')
    plot_pipeline(pipeline_basic(frame))



main()

In [None]:
# TODO : EDA
# TODO : Pipeline (add batchwise augmentation method from Lightning docs) 
# TODO : Model 
# TODO : trainer 
# TODO : monitor (neptune, W&B) 