## Import Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns

import os
import albumentations as A
from albumentations.pytorch import ToTensor

import torch
import cv2

import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader


from sklearn import metrics, model_selection

%matplotlib inline

## Data

In [None]:
# Training metadata: one row per image, with `image_id` and `label` columns.
df = pd.read_csv(filepath_or_buffer="../input/cassava-leaf-disease-classification/train.csv")

In [None]:
# Display the training table (the notebook renders the cell's last expression).
df

In [None]:
# Class balance: number of images per label, in label order, as a bar chart
# plus the raw counts shown as the cell output.
label_counts = df['label'].value_counts().sort_index()
label_counts.plot.barh()
label_counts

## Visualization by Label (0, 1, 2, 3, 4)

In [None]:
# Image ids belonging to each of the five classes, in label order 0..4.
label0_img, label1_img, label2_img, label3_img, label4_img = (
    df.loc[df['label'] == k, 'image_id'].values for k in range(5)
)

In [None]:
img_path = '../input/cassava-leaf-disease-classification/train_images/'

# Full file path of every image, grouped by class (same order as the id arrays).
label0_img_path, label1_img_path, label2_img_path, label3_img_path, label4_img_path = (
    [os.path.join(img_path, name) for name in ids]
    for ids in (label0_img, label1_img, label2_img, label3_img, label4_img)
)

In [None]:
def show_label_samples(paths, label, n_samples=4):
    """Plot the first images of one class in a 2x2 grid.

    Args:
        paths: image file paths that all belong to ``label``.
        label: class id; used only for the subplot titles.
        n_samples: how many images to draw; the grid is 2x2, so at most 4 fit.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the images.
    """
    plt.figure(figsize=(10, 10))
    for i, path in enumerate(paths[:n_samples]):
        plt.subplot(2, 2, i + 1)
        img = cv2.imread(path)
        # cv2.imread returns None instead of raising on a bad path.
        if img is None:
            raise FileNotFoundError(f"cv2 could not read image: {path}")
        plt.title(f"label:{label}")
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
    plt.show()


# One figure per class, replacing five copy-pasted cells.
for label, paths in enumerate(
    [label0_img_path, label1_img_path, label2_img_path, label3_img_path, label4_img_path]
):
    show_label_samples(paths, label)
        

## Split Train / Validation Data

In [None]:
# Hold out 10% of the rows for validation, stratified on the label so both
# splits keep the same class proportions; fixed seed for reproducibility.
df_train, df_val = model_selection.train_test_split(
    df,
    test_size=0.1,
    stratify=df.label.values,
    random_state=42,
)


In [None]:
# Give each split a fresh 0..n-1 index (the shuffled original index is dropped).
df_train, df_val = (part.reset_index(drop=True) for part in (df_train, df_val))

In [None]:
# Row/column counts of the two splits (validation is about 10% of the rows).
df_train.shape, df_val.shape

In [None]:
img_path = '../input/cassava-leaf-disease-classification/train_images/'

# Full image paths for each split, aligned row-by-row with the split DataFrames.
train_img_path, val_img_path = (
    [os.path.join(img_path, image_id) for image_id in split.image_id.values]
    for split in (df_train, df_val)
)

In [None]:
# Sanity check: one image path per row in each split.
len(train_img_path), len(val_img_path)

In [None]:
# Class labels aligned index-by-index with train_img_path / val_img_path.
train_target, val_target = (split['label'].values for split in (df_train, df_val))

## Define Dataset

In [None]:
class LeafDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs.

    Args:
        img_ids: sequence of image file paths (in this notebook, full paths
            built with ``os.path.join``), read with OpenCV.
        targets: class labels aligned index-by-index with ``img_ids``.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=img)['image']``. When ``None``, the RGB numpy
            array is returned unchanged.
    """

    def __init__(self, img_ids, targets, transform=None):
        self.img_ids = img_ids
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``transform``, return it with its target.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.img_ids[index]
        img = cv2.imread(img_path)
        # cv2.imread reports failure by returning None; fail with a clear
        # message instead of a cryptic assertion inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"cv2 could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR

        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(image=img)['image']

        return img, target
        

## Augmentation

In [None]:
# Training: light geometric augmentation, with each op applied independently
# with probability 0.2, followed by conversion to a torch tensor.
train_transform = A.Compose(
    [
        A.Rotate(15, p=0.2),
        A.VerticalFlip(p=0.2),
        A.HorizontalFlip(p=0.2),
        ToTensor(),
    ]
)

# Validation: tensor conversion only, so evaluation is deterministic.
val_transform = A.Compose([ToTensor()])

In [None]:
train_dataset = LeafDataset(train_img_path, train_target, train_transform)
val_dataset = LeafDataset(val_img_path, val_target, val_transform)

# Shuffle only the training batches; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16, num_workers=2)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=8, num_workers=2)

In [None]:
# Number of samples in each dataset (should match the split sizes above).
len(train_dataset),len(val_dataset)

## ResNet-34 Model

In [None]:
# Install the `pretrainedmodels` package imported in the next cell
# (IPython `!` shell escape: runs pip as a subprocess).
! pip install pretrainedmodels

In [None]:
import pretrainedmodels

# Look up the ResNet-34 constructor in the pretrainedmodels registry and load
# its ImageNet weights.
model_name = 'resnet34'
model = getattr(pretrainedmodels, model_name)(pretrained='imagenet')

In [None]:
# Input width of the pretrained classification head; reused for the new head.
in_features = model.last_linear.in_features

In [None]:
# Swap the pretrained head for a bias-free linear layer with one output per
# distinct label in the training data.
num_classes = np.unique(df.label.values).size
model.last_linear = nn.Linear(in_features, num_classes, bias=False)

In [None]:
# Print the architecture to confirm the replaced `last_linear` head.
model

## Optimizer, Loss, LR Scheduler

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Halve the learning rate once the monitored value (stepped with the mean
# validation loss in the training loop, mode='min') has not improved for
# more than `patience`=2 steps.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, verbose=True,
)

In [None]:
# Use the GPU only when PyTorch can actually see one. The call parentheses
# matter: the bare function object `torch.cuda.is_available` is always truthy,
# which previously forced 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Force a garbage-collection pass before training to free unreferenced Python
# objects; the cell displays the number of unreachable objects found.
import gc
gc.collect()

In [None]:
from tqdm import tqdm_notebook
from sklearn.metrics import accuracy_score

## Train

In [None]:

# Train for NUM_EPOCHS epochs, validating after each one and saving the
# weights with the best validation accuracy to 'checkpoint.pth'.
NUM_EPOCHS = 10
best_score = -1


def _train_one_epoch(net, loader, opt, criterion, target_device):
    """Run one optimisation pass over ``loader``; return the per-batch losses."""
    net.train()
    losses = []
    for inputs, targets in loader:
        inputs = inputs.to(target_device)
        targets = targets.to(target_device)

        opt.zero_grad()
        loss = criterion(net(inputs), targets)
        losses.append(loss.item())

        loss.backward()
        opt.step()
    return losses


def _evaluate(net, loader, criterion, target_device):
    """Score ``net`` on ``loader`` without gradients.

    Returns:
        (per-batch losses, concatenated true labels, concatenated predicted labels)
    """
    net.eval()
    losses, y_true, y_pred = [], [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(target_device)
            targets = targets.to(target_device)

            logits = net(inputs)
            losses.append(criterion(logits, targets).item())

            y_pred.append(logits.argmax(dim=1).cpu().numpy())
            y_true.append(targets.cpu().numpy())
    return losses, np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)


# Move the model once. Checkpointing below no longer calls model.cpu(), which
# used to relocate the live model and forced a move back every epoch.
model = model.to(device)

for epoch in tqdm_notebook(range(NUM_EPOCHS)):
    train_loss = _train_one_epoch(model, train_dataloader, optimizer, loss_fn, device)
    val_loss, val_true, val_pred = _evaluate(model, val_dataloader, loss_fn, device)

    # sklearn's signature is accuracy_score(y_true, y_pred).
    score = accuracy_score(val_true, val_pred)

    # The scheduler was built with mode='min', so feed it the mean validation loss.
    lr_scheduler.step(np.mean(val_loss))

    print(f" epoch: {epoch+1}, train_loss: {np.round(np.mean(train_loss),4)}, val_loss:{np.round(np.mean(val_loss),4)}, accuracy:{np.round(score,4)}")

    if score > best_score:
        best_score = score

        # Save CPU copies of the weights while leaving the live model on
        # `device`. Rebinding values in the state_dict() mapping keeps its
        # metadata and does not touch the model's own parameters.
        state_dict = model.state_dict()
        for key in list(state_dict.keys()):
            state_dict[key] = state_dict[key].cpu()
        torch.save(state_dict, 'checkpoint.pth')
            
        
        

## Load Model Parameters

In [None]:
# Restore trained weights into the existing `model`. load_state_dict() mutates
# the module in place and returns a NamedTuple of missing/unexpected keys, so
# assigning its result to `model` would throw the network away.
# map_location='cpu' lets a checkpoint saved on any device load here; the
# values are copied into `model`'s existing parameters wherever they live.
# NOTE(review): the training cell writes 'checkpoint.pth' to the working
# directory; this path assumes a copy was uploaded under ../input — confirm.
checkpoint = torch.load('../input/cassava-leaf-disease-classification/checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint)

In [None]:
# Template for the submission file (expected columns and row order).
submission = pd.read_csv(filepath_or_buffer='../input/cassava-leaf-disease-classification/sample_submission.csv')

In [None]:
# Preview the expected submission format.
submission

## Test Image Classification 

In [None]:
# A single test image, used below for a quick end-to-end inference check.
test_image_path = '../input/cassava-leaf-disease-classification/test_images/2216849948.jpg'

In [None]:
# Read the test image (OpenCV loads BGR), convert it to RGB and divide the
# pixel values by 255, producing a float HxWx3 numpy array.
img = cv2.cvtColor(cv2.imread(test_image_path), cv2.COLOR_BGR2RGB) / 255


In [None]:
# `img` is the HxWx3 float array from the previous cell. PyTorch conv nets
# expect N x C x H x W, so move channels first (2, 0, 1) — the old
# permute(1, 2, 0) scrambled the axes — and add a batch dimension.
# NOTE(review): assumes the training-time ToTensor also scaled pixels to
# [0, 1], matching the /255 above — confirm.
img = torch.from_numpy(img).float()
img = img.permute(2, 0, 1).unsqueeze(0)
img = img.to(device)

# Keep model and input on the same device, and use eval mode so BatchNorm
# uses its running statistics, as during validation.
model = model.to(device)
model.eval()
with torch.no_grad():
    result = model(img)  # raw logits, shape (1, num_classes)

predicted_label = int(result.argmax(dim=1).item())
predicted_label

