In [None]:
import os
import pandas as pd 
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt 
from tqdm import tqdm
import torch 
import torchvision 
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

In [None]:
data_root = "../input/plant-pathology-2021-fgvc8"
train_dir = f"{data_root}/train_images"
test_dir = f"{data_root}/test_images"

In [None]:
train_csv_path = "../input/plant-pathology-2021-fgvc8/train.csv"
df = pd.read_csv(train_csv_path)
df.head()

## Let's look at the number of classes

In [None]:
# Distinct values of the `labels` column, in order of first appearance in the CSV.
classes = df.labels.unique()
print(classes, "\nTotal number of unique labels:", len(classes))

In [None]:
# Assign each label string a contiguous integer id, following the order of `classes`.
label_map = {str(label): class_id for class_id, label in enumerate(classes)}
print(label_map)

In [None]:
# Encode the label strings as integer class ids via `label_map`.
# `map` avoids the object->int downcasting that `replace` performs (deprecated in pandas >= 2.2).
# The dtype guard keeps the cell safe to re-run: an already-encoded column is left alone
# instead of being mapped to NaN.
if not pd.api.types.is_integer_dtype(df["labels"]):
    df = df.assign(labels=df["labels"].map(label_map))
assert df["labels"].notna().all(), "some labels have no entry in label_map"

In [None]:
# parameters
batch_size = 32
image_size = 224
epochs = 10
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

# Plotting a few sample images

In [None]:
# Show one randomly chosen training image per class. The grid grows with the number of classes,
# and unused panels are hidden.
n_cols = 3
n_rows = -(-len(classes) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.set_axis_off()
for ax, class_name in zip(axes.flat, classes):
    # Boolean-mask lookup works regardless of the frame's index; `sample` draws from numpy's global RNG.
    candidates = df.loc[df["labels"] == label_map[str(class_name)], "image"]
    img_name = candidates.sample(1).iloc[0]
    with Image.open(train_dir + "/" + img_name) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    ax.imshow(img)
    ax.set_title(class_name)
plt.show()

## Augmentations, dataset and data loaders

In [None]:
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
def _resolve_transform(*names):
    """Return the first albumentations transform class in ``names`` that is installed.

    The imgaug-backed ``IAA*`` transforms were removed in albumentations 1.x in favour of
    native ones (``Sharpen``, ``Emboss``, ``PiecewiseAffine``). The native ones do not exist in
    0.x releases, so resolving by name keeps the pipeline working on either line.

    Raises:
        AttributeError: if none of ``names`` exists in the installed albumentations.
    """
    for name in names:
        transform_cls = getattr(A, name, None)
        if transform_cls is not None:
            return transform_cls
    raise AttributeError(f"albumentations provides none of: {', '.join(names)}")


def get_train_transform(size=300):
    """Build the training-time augmentation pipeline.

    Args:
        size: side length in pixels that every image is resized to (default 300).

    Returns:
        An ``A.Compose`` called as ``transform(image=<HWC uint8 ndarray>)["image"]``.
        It produces a normalized CHW tensor.
    """
    return A.Compose([
        A.Resize(width=size, height=size, p=1),

        # Random 90-degree rotations, flips and transposes of the leaf.
        A.RandomRotate90(),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.Transpose(),
        A.GaussNoise(p=0.2),
        A.OneOf([
            A.MotionBlur(p=.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        A.OneOf([
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=.1),
            _resolve_transform("PiecewiseAffine", "IAAPiecewiseAffine")(p=0.3),
        ], p=0.2),
        A.OneOf([
            A.CLAHE(clip_limit=2),
            _resolve_transform("Sharpen", "IAASharpen")(),
            _resolve_transform("Emboss", "IAAEmboss")(),
            A.RandomBrightnessContrast(),
        ], p=0.3),
        A.HueSaturationValue(p=0.3),
        # ToTensorV2 only reorders HWC -> CHW and does not rescale, so normalize beforehand
        # (kept after the pixel-level augmentations above, e.g. CLAHE, which expect uint8).
        A.Normalize(),
        ToTensorV2(p=1.0)
    ])  # https://albumentations.ai/docs/examples/example/


def get_valid_transform(size=300):
    """Build the deterministic validation/inference pipeline.

    Args:
        size: side length in pixels that every image is resized to (default 300).

    Returns:
        An ``A.Compose`` that resizes, normalizes with the same statistics as
        ``get_train_transform`` and converts to a CHW tensor.
    """
    return A.Compose([
        A.Resize(width=size, height=size, p=1),
        A.Normalize(),
        ToTensorV2(p=1.0)
    ])

In [None]:
class LeafDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for the rows of ``df``.

    Args:
        df: frame with an ``image`` column (file name inside ``image_dir``) and a ``labels``
            column. Rows are addressed by position, so any index (e.g. after filtering) works.
        image_dir: directory that contains the image files.
        transforms: optional callable invoked as ``transforms(image=<ndarray>)['image']``.
            When ``None``, the raw HWC uint8 RGB array is returned.
    """

    def __init__(self, df, image_dir, transforms=None):
        self.df = df
        self.image_dir = image_dir
        self.transforms = transforms

    def _lookup(self, index: int):
        """Return ``(image file name, label)`` for the ``index``-th row (positional, not by label)."""
        return self.df["image"].iloc[index], self.df["labels"].iloc[index]

    def __getitem__(self, index: int):
        img_id, label = self._lookup(index)
        # convert('RGB') gives every image three channels; the context manager closes the file.
        with Image.open(f"{self.image_dir}/{img_id}") as img:
            image = np.array(img.convert('RGB'))
        if self.transforms is not None:
            image = self.transforms(image=image)['image']

        return image, label

    def __len__(self):
        return len(self.df)
        

In [None]:
# Two views of the same frame: augmented for training, deterministic for validation.
train_dataset = LeafDataset(df, train_dir, transforms=get_train_transform())
valid_dataset = LeafDataset(df, train_dir, transforms=get_valid_transform())

# Hold out `n_valid` images for validation. The permutation uses its own seeded generator,
# so the same images are held out on every run.
n_valid = 200
assert len(train_dataset) > n_valid, "not enough images for the requested validation split"
split_rng = torch.Generator().manual_seed(42)
indices = torch.randperm(len(train_dataset), generator=split_rng).tolist()

train_dataset = torch.utils.data.Subset(train_dataset, indices[n_valid:])
valid_dataset = torch.utils.data.Subset(valid_dataset, indices[:n_valid])

train_data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
valid_data_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)
print(len(train_dataset), len(valid_dataset))

## Let's plot a few augmented images

In [None]:
def visualize(image, labels):
    """Display a CHW image tensor (e.g. a ``make_grid`` mosaic) titled with its batch labels.

    Float tensors (such as normalized network inputs) are min-max rescaled to [0, 1] first,
    because ``imshow`` clips float RGB data outside that range. Integer tensors are shown as-is.
    """
    array = np.transpose(image.detach().cpu().numpy(), (1, 2, 0))
    if np.issubdtype(array.dtype, np.floating):
        lo, hi = array.min(), array.max()
        array = (array - lo) / (hi - lo) if hi > lo else np.zeros_like(array)
    plt.figure(figsize=(20, 10))
    plt.imshow(array)
    plt.title(str(labels.detach().cpu().tolist()), fontsize=20)
    plt.axis('off')


# DataLoader iterators have no `.next()` method on recent PyTorch; use the builtin `next`.
images, labels = next(iter(train_data_loader))
visualize(torchvision.utils.make_grid(images), labels)
plt.show()

# Model training

In [None]:
!pip install efficientnet_pytorch

In [None]:
from efficientnet_pytorch import EfficientNet

# `from_name` builds the b4 architecture with freshly initialized weights. Set
# `use_pretrained = True` to load the ImageNet checkpoint via `from_pretrained` instead
# (this downloads weights, so it needs internet access).
use_pretrained = False
if use_pretrained:
    model = EfficientNet.from_pretrained('efficientnet-b4')
else:
    model = EfficientNet.from_name('efficientnet-b4')

# No "unfreeze" step is needed: every nn.Parameter is created with requires_grad=True.

# Replace the classification head with one that has one output per entry in `classes`.
n_features = model._fc.in_features
model._fc = nn.Linear(n_features, len(classes))
model = model.to(device)  # assignment avoids printing the very long module repr

In [None]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
# Cosine decay from the optimizer's lr down to 1e-6, restarting every 5 scheduler steps.
# `last_epoch=-1` is the default, and `verbose` is omitted because passing it is deprecated
# (and warns) in PyTorch >= 2.2.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-6)

In [None]:
def train_one_epoch(model, epoch, train_data_loader, device, criterion, optimizer, scheduler=None,
                    schd_batch_update=False, accum_steps=2):
    """Train ``model`` for one pass over ``train_data_loader`` using gradient accumulation.

    Gradients from ``accum_steps`` consecutive batches are accumulated, with each loss scaled by
    ``1 / accum_steps``, before every optimizer step. The last batch of the epoch always
    triggers a step, so a partial accumulation window is not lost.

    Args:
        model: network to train; switched to ``train()`` mode.
        epoch: zero-based epoch index, used only in the progress-bar label.
        train_data_loader: iterable of ``(inputs, labels)`` batches that supports ``len``.
        device: device the batches are moved to.
        criterion: loss callable ``criterion(outputs, labels) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler (anything with ``step()``).
        schd_batch_update: if True, step ``scheduler`` after every optimizer step;
            otherwise step it once at the end of the epoch.
        accum_steps: number of batches whose gradients are accumulated per optimizer step.

    Returns:
        Mean per-sample training loss over the epoch, or ``nan`` if the loader was empty.

    Raises:
        ValueError: if ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    model.train()
    running_loss = 0.0
    sample_num = 0
    n_batches = len(train_data_loader)
    pbar = tqdm(enumerate(train_data_loader), total=n_batches)

    optimizer.zero_grad()
    for i, (inputs, labels) in pbar:
        inputs, labels = inputs.to(device).float(), labels.to(device).long()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Scale so the accumulated gradient is the mean over the accumulation window.
        (loss / accum_steps).backward()

        # Weight each batch-mean loss by its batch size to get a true per-sample mean.
        running_loss += loss.item() * labels.shape[0]
        sample_num += labels.shape[0]

        if (i + 1) % accum_steps == 0 or (i + 1) == n_batches:
            optimizer.step()
            # Clear only after stepping, so gradients accumulate across the window.
            optimizer.zero_grad()
            if scheduler is not None and schd_batch_update:
                scheduler.step()

        pbar.set_description(f'epoch {epoch+1} loss: {running_loss/sample_num:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()

    return running_loss / sample_num if sample_num else float('nan')
    
def valid_one_epoch(model, epoch, valid_data_loader, device, criterian, optimizer, scheduler=None,
                    schd_loss_update=False):
    """Evaluate ``model`` on ``valid_data_loader``, printing and returning loss and accuracy.

    Args:
        model: network to evaluate; switched to ``eval()`` mode.
        epoch: zero-based epoch index, used only in the progress-bar label.
        valid_data_loader: iterable of ``(inputs, labels)`` batches that supports ``len``.
        device: device the batches are moved to.
        criterian: loss callable ``criterian(outputs, labels) -> scalar tensor``. The
            parameter keeps its original spelling so keyword callers continue to work.
        optimizer: unused; accepted for signature parity with ``train_one_epoch``.
        scheduler: optional LR scheduler, stepped once after evaluation.
        schd_loss_update: if True, call ``scheduler.step(mean_loss)`` (for loss-driven
            schedulers); otherwise call ``scheduler.step()``.

    Returns:
        ``(mean_loss, accuracy)`` over all validation samples, as floats.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    sample_num = 0
    all_predictions = []
    all_targets = []
    pbar = tqdm(valid_data_loader, total=len(valid_data_loader))

    # Inference only: disable autograd here so the function is safe without an outer no_grad().
    with torch.no_grad():
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device).float(), labels.to(device).long()
            outputs = model(inputs)
            all_predictions.append(torch.argmax(outputs, 1).cpu().numpy())
            all_targets.append(labels.cpu().numpy())
            loss = criterian(outputs, labels)  # the criterion passed by the caller
            loss_sum += loss.item() * labels.shape[0]
            sample_num += labels.shape[0]
            pbar.set_description(f'epoch {epoch+1} loss: {loss_sum/sample_num:.4f}')

    if sample_num == 0:
        raise ValueError("valid_data_loader yielded no samples")

    mean_loss = loss_sum / sample_num
    accuracy = float((np.concatenate(all_predictions) == np.concatenate(all_targets)).mean())
    print('validation accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(mean_loss)
        else:
            scheduler.step()

    return mean_loss, accuracy
        


In [None]:
# NOTE(review): `scheduler` from the optimizer cell is not passed to either function below,
# so Adam keeps its initial learning rate for the whole run.
for epoch in range(epochs):  # `epochs` comes from the parameters cell
    train_one_epoch(model, epoch, train_data_loader, device, criterion, optimizer,
                    scheduler=None, schd_batch_update=False)
    with torch.no_grad():
        valid_one_epoch(model, epoch, valid_data_loader, device, criterion, optimizer,
                        scheduler=None, schd_loss_update=False)
    # One checkpoint per epoch: efficientnet-b4-<epoch>.pth
    torch.save(model.state_dict(), f'efficientnet-b4-{epoch}.pth')

print('Finished Training')

## Continue ...