In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively list every file available under the Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Root folder of the competition dataset (relative to the notebook's working directory).
work_dir = "../input/cassava-leaf-disease-classification"

In [None]:
# Folder holding the training images; derived from work_dir so the dataset
# location is defined in exactly one place.
image_path = work_dir + "/train_images"

In [None]:
import os
import json
import sys
import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data import DataLoader,Dataset
import albumentations as A
from PIL import Image,ImageFile
from tqdm import tqdm_notebook as tqdm
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
from sklearn import model_selection,metrics
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr
import torchvision.models as mdl
import warnings
warnings.simplefilter("ignore")

In [None]:
# Show the top-level entries of the dataset folder.
os.listdir(work_dir)

In [None]:
# Read the label -> disease-name mapping shipped with the dataset.
# JSON object keys are strings, so lookups must use str(label).
label_map_file = work_dir+"/label_num_to_disease_map.json"
with open(label_map_file, 'r') as fh:
    class_labels = json.loads(fh.read())

class_labels


In [None]:
# Load the training annotations; the path is built from work_dir so the dataset
# root is not repeated as a separate literal.
df = pd.read_csv(work_dir + "/train.csv")
df.head()

**IMAGE VISUALIZATION**

In [None]:
plt.figure(figsize=(15,15))
data_sample=df.sample(16).reset_index(drop=True)
for i in range(16):
    plt.subplot(4,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample.label[i])))
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(15,15))
data_sample1=df[df.label==0].sample(8).reset_index(drop=True)
for i in range(8):
    plt.subplot(2,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample1.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample1.label[i])))
plt.tight_layout()
plt.show()



In [None]:
plt.figure(figsize=(15,15))
data_sample2=df[df.label==1].sample(8).reset_index(drop=True)
for i in range(8):
    plt.subplot(2,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample2.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample2.label[i])))
plt.tight_layout()
plt.show()


In [None]:
plt.figure(figsize=(15,15))
data_sample3=df[df.label==2].sample(8).reset_index(drop=True)
for i in range(8):
    plt.subplot(2,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample3.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample3.label[i])))
plt.tight_layout()
plt.show()



In [None]:
plt.figure(figsize=(15,15))
data_sample4=df[df.label==3].sample(8).reset_index(drop=True)
for i in range(8):
    plt.subplot(2,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample4.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample4.label[i])))
plt.tight_layout()
plt.show()


In [None]:
plt.figure(figsize=(15,15))
data_sample5=df[df.label==4].sample(8).reset_index(drop=True)
for i in range(8):
    plt.subplot(2,4,i+1)
    img = cv2.imread(work_dir+ "/train_images/" + data_sample5.image_id[i])
    image = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 
    plt.axis("off")
    plt.imshow(image)
    plt.title(class_labels.get(str(data_sample5.label[i])))
plt.tight_layout()
plt.show()


In [None]:
# Hold out 5% of the rows for validation, stratified on the label column so both
# splits keep the same class proportions; the seed is fixed for repeatability.
df_train, df_valid = model_selection.train_test_split(
    df,
    test_size=0.05,
    random_state=42,
    stratify=df.label.values,
)

In [None]:
# (rows, columns) of the training and validation frames.
(df_train.shape, df_valid.shape)

In [None]:
# Drop the row index carried over from the original frame so both splits are
# numbered 0..n-1.
df_train=df_train.reset_index(drop=True)
df_valid=df_valid.reset_index(drop=True)

In [None]:
# Full file path for every image in each split, in the same order as the frame rows.
train_image_path = [
    os.path.join(image_path, file_name) for file_name in df_train["image_id"].values
]
valid_image_path = [
    os.path.join(image_path, file_name) for file_name in df_valid["image_id"].values
]

In [None]:
# Peek at the first five training image paths.
train_image_path[:5]

In [None]:
# Peek at the first five validation image paths.
valid_image_path[:5]

In [None]:
# Label arrays aligned row-for-row with the image path lists of each split.
train_targets = df_train["label"].values
valid_targets = df_valid["label"].values

In [None]:
# Inspect the training label array.
train_targets

In [None]:
# Inspect the validation label array.
valid_targets

In [None]:
class CassavaDataset(Dataset):
    """Dataset that reads images from disk and pairs each with its label.

    Args:
        image_ids: Sequence of image file paths; each is passed to ``cv2.imread``.
        labels: Sequence of labels, indexed in parallel with ``image_ids``.
        dimension: Optional size passed as the second argument of ``cv2.resize``;
            the resize step is skipped when this is falsy.
        augmentations: Optional callable invoked as ``augmentations(image=img)``
            that returns a mapping with an ``"image"`` entry; skipped when falsy.

    Raises:
        ValueError: If ``image_ids`` and ``labels`` differ in length.
    """

    def __init__(self,image_ids,labels,dimension=None,augmentations=None):
        super().__init__()
        # Fail early: a length mismatch would otherwise surface as an IndexError
        # (or silently ignored labels) in the middle of training.
        if len(image_ids) != len(labels):
            raise ValueError(
                "image_ids and labels must have the same length, got "
                f"{len(image_ids)} and {len(labels)}"
            )
        self.image_ids=image_ids
        self.labels=labels
        self.dim=dimension
        self.augmentations=augmentations

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_ids)

    def __getitem__(self,idx):
        """Load, optionally resize and augment, and return sample ``idx``.

        Returns:
            dict with ``"image"`` (output of ``transforms.ToTensor()``) and
            ``"label"`` (``torch.tensor`` of the stored label).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path=self.image_ids[idx]
        imge=cv2.imread(path)
        # cv2.imread returns None instead of raising when the file is missing or
        # unreadable; report the offending path rather than a cryptic cvtColor error.
        if imge is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes to BGR; convert to RGB before any further processing.
        imge=cv2.cvtColor(imge,cv2.COLOR_BGR2RGB)

        if self.dim:
            imge=cv2.resize(imge,self.dim)

        if self.augmentations:
            imge=self.augmentations(image=imge)["image"]

        return {"image":transforms.ToTensor()(imge),
                "label":torch.tensor(self.labels[idx])}
        

In [None]:
# Side length (pixels) of the square image produced by the final Resize.
image_size = 256

# Training-time augmentation pipeline, assembled step by step.
# NOTE(review): the Cutout hole sizes are derived from image_size (256) but the
# holes are cut from the 500x500 crop, before the final Resize - confirm intended.
blur_or_noise = A.OneOf(
    [
        A.MedianBlur(blur_limit=3),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ],
    p=0.6,
)
geometric_distortion = A.OneOf(
    [
        A.OpticalDistortion(distort_limit=0.7),
        A.GridDistortion(num_steps=2, distort_limit=0.2),
        A.ElasticTransform(alpha=3),
    ],
    p=0.7,
)
single_large_hole = A.Cutout(
    max_h_size=int(image_size * 0.2),
    max_w_size=int(image_size * 0.2),
    num_holes=1,
    p=0.5,
)
three_small_holes = A.Cutout(
    max_h_size=int(image_size * 0.1),
    max_w_size=int(image_size * 0.1),
    num_holes=3,
    p=0.5,
)

train_aug = A.Compose(
    [
        A.RandomCrop(height=500, width=500),
        A.Transpose(p=0.3),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomContrast(limit=0.05, p=0.5),
        blur_or_noise,
        geometric_distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(
            hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5
        ),
        A.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.8
        ),
        single_large_hole,
        three_small_holes,
        A.Resize(image_size, image_size),
    ]
)





In [None]:
# Validation pipeline: deterministic steps only. Random flips, noise, distortions
# and cutouts belong to training; applying them here would make the validation
# metrics noisy and different on every run. The 500x500 crop followed by the resize
# mirrors the training pipeline's crop/resize so the image scale stays the same.
valid_aug = A.Compose([
    A.CenterCrop(height=500, width=500),
    A.Resize(image_size, image_size),
])


In [None]:
# Wrap each split in a CassavaDataset. No fixed pre-resize is requested
# (dimension=None); each augmentation pipeline ends with its own Resize.
train_dataset = CassavaDataset(
    image_ids=train_image_path,
    labels=train_targets,
    dimension=None,
    augmentations=train_aug,
)
valid_dataset = CassavaDataset(
    image_ids=valid_image_path,
    labels=valid_targets,
    dimension=None,
    augmentations=valid_aug,
)


In [None]:
# Preview the first 16 augmented training samples in a 4x4 grid.
plt.figure(figsize=(15,15))

for i in range(16):
    plt.subplot(4,4,i+1)
    # Index the dataset once: every __getitem__ call reads the file and re-runs the
    # random augmentations, so fetching image and label separately did the work twice.
    sample=train_dataset[i]
    image,label=sample["image"],sample["label"]
    # Tensor is channel-first (C, H, W); imshow needs channel-last (H, W, C).
    plt.imshow(image.permute(1, 2, 0))
    plt.axis('off')
    plt.title(class_labels.get(str(label.item())))


plt.tight_layout()
plt.show()

    
    

In [None]:
# Training loader. shuffle=True re-orders the samples every epoch; with
# shuffle=False the model would see the exact same batch sequence each epoch.
train_loader = torch.utils.data.DataLoader(
                                           train_dataset,
                                           batch_size=16,
                                           num_workers=4,
                                           shuffle=True,
                                           pin_memory=False,
                                           drop_last=False,
                                           )


In [None]:
# Validation loader: fixed order (no shuffling) and no dropped final batch, so
# every validation sample is visited exactly once per pass.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    pin_memory=False,
    drop_last=False,
)