In [None]:
import os
import numpy as np
import pandas as pd
from IPython.display import display

import torch
from torch import nn, optim
from torch.utils.data import DataLoader,Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
from torchvision import datasets, models, transforms

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import preprocessing

from skimage import io
import random

from tqdm.notebook import tqdm

In [None]:
# Seed every random number generator used in this notebook so runs are repeatable.
seed = 33291

for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Train data preprocessing

In [None]:
# Root folder of the competition data; every other path is derived from it.
comp_dir = "../input/sorghum-id-fgvc-9"

# Training images and the CSV that maps each image to its cultivar.
train_img_dir, train_labels_dir = (
    os.path.join(comp_dir, entry)
    for entry in ("train_images", "train_cultivar_mapping.csv")
)

# Test images and the sample submission file.
test_img_dir, test_subm_dir = (
    os.path.join(comp_dir, entry)
    for entry in ("test", "sample_submission.csv")
)

In [None]:
# Load the training labels CSV (the `image` and `cultivar` columns are used below).
train_labels = pd.read_csv(train_labels_dir)

display(train_labels)

In [None]:
#remove rows with a missing cultivar
# Rows are selected with isna() rather than by a hard-coded row label: after
# reset_index() the same label points at a valid row, so dropping it again on a
# re-run of this cell would silently delete good data.
missing_cultivar = train_labels.cultivar.isna()
display(train_labels[missing_cultivar])

train_labels = train_labels[~missing_cultivar].reset_index(drop=True)

In [None]:
# Sanity check: displays the rows whose cultivar is still missing (expected to be none).
train_labels[train_labels.cultivar.isna()]

In [None]:
#encode cultivar by label codes
encoder=preprocessing.LabelEncoder()

# One integer code per cultivar name; the fitted encoder is reused at the end of the
# notebook to turn predicted codes back into cultivar names.
train_labels["cultivar_code"]=encoder.fit_transform(train_labels.cultivar)
train_labels

In [None]:
# Number of training rows per cultivar code.
plt.figure(figsize=[25, 5])
sns.countplot(x=train_labels.cultivar_code)

In [None]:
# Number of distinct cultivar codes in the training labels.
train_labels.cultivar_code.nunique()

In [None]:
#check original images
# The path must name the PNG file itself; a trailing "/" would make it a directory path.
img_name=os.path.join(train_img_dir, '2017-06-01__10-26-27-479.png')
image = io.imread(img_name)
plt.imshow(image)

In [None]:
#Custom dataset for train data
class ImageData(Dataset):
    """Dataset yielding ``(transformed image, label code)`` pairs.

    Parameters
    ----------
    images_dir : str
        Directory that contains the image files.
    labels_df : pandas.DataFrame
        Frame with an ``image`` column (file name inside ``images_dir``) and a
        ``cultivar_code`` column (label of that image).
    transform : callable
        Applied to the array returned by ``io.imread`` before it is returned.
    """

    def __init__(self, images_dir, labels_df, transform):
        super().__init__()
        self.labels_df = labels_df
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in ``labels_df``."""
        return len(self.labels_df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``."""
        # Positional lookup (.iloc): samplers hand out positions in [0, len), which
        # only coincide with index labels when the frame has a default RangeIndex.
        img_name = self.labels_df.image.iloc[index]
        label = self.labels_df.cultivar_code.iloc[index]

        img_path = os.path.join(self.images_dir, img_name)
        image = self.transform(io.imread(img_path))

        return image, label

In [None]:
#train-val splitting
val_split = 0.05
val_len = int(val_split * len(train_labels))

# Shuffle all row positions once: the first `val_len` of them form the validation
# subset, the remainder the training subset.
indices = np.random.permutation(len(train_labels))
val_indices = indices[:val_len]
train_indices = indices[val_len:]

# Each sampler draws (in random order) only from its own subset of positions.
train_sampler, val_sampler = (
    SubsetRandomSampler(train_indices),
    SubsetRandomSampler(val_indices),
)

In [None]:
#define transformations and train-val loaders

# Training pipeline: random rotation / flips / posterize, then a 224 centre crop.
train_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(250),
        transforms.RandomRotation(30, expand=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomPosterize(bits=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Validation pipeline: no random augmentation, only resize, crop and normalisation.
val_transformations = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Both datasets wrap the complete label frame; the samplers decide which rows each
# loader actually visits.
train_img = ImageData(images_dir=train_img_dir, labels_df=train_labels, transform=train_transforms)
val_img = ImageData(images_dir=train_img_dir, labels_df=train_labels, transform=val_transformations)

train_loader = DataLoader(dataset=train_img, batch_size=32, sampler=train_sampler)
val_loader = DataLoader(dataset=val_img, batch_size=32, sampler=val_sampler)

In [None]:
#check images after transformations
# Use the built-in next(): `.next()` is a Python 2-era alias that is not part of the
# iterator protocol and is not available on DataLoader iterators in recent PyTorch.
iterator=iter(train_loader)
img,labels=next(iterator)
# Tensors are channel-first; matplotlib expects channel-last.
plt.imshow(img[0].permute(1, 2, 0))
print(labels[0])

## Define model

In [None]:
#import pretrained efficientnet b0
model=models.efficientnet_b0(pretrained=True)
# Wrapped in DataParallel: the underlying network is accessed as `model.module` below.
model=nn.DataParallel(model)
model

In [None]:
#replace linear layer of the model

# Parameters of the pretrained backbone only. Taking them from `features` (instead of
# every parameter of the network) keeps the classifier out of this list: the old head
# is about to be discarded, and on a re-run of this cell the new head would otherwise
# end up in both optimizer parameter groups, which the optimizer rejects.
prev_params=list(model.module.features.parameters())

# Width of the feature vector fed to the head, read from the current Linear layer
# rather than hard-coded.
in_features=model.module.classifier[1].in_features

# New head: 100 output classes, log-probabilities as output.
model.module.classifier=nn.Sequential(
    nn.Dropout(0.2,inplace =True),
    nn.Linear(in_features,100),
    nn.LogSoftmax(dim=1)
)
model

In [None]:
#optionally load last state of the model to continue training
last_version_model_dir="../input/pytorch-efficientnet-b0"
last_state_path=os.path.join(last_version_model_dir, 'last_state.pth')

if os.path.exists(last_state_path):
    # Load onto the CPU regardless of the device the checkpoint was saved from; the
    # model is moved to `device` afterwards.
    state_dict = torch.load(last_state_path, map_location="cpu")
    model.load_state_dict(state_dict)
else:
    # The step is optional: without a saved state, keep the pretrained weights.
    print(f"No saved state found at {last_state_path}; continuing without it.")

In [None]:
#define criterion, optimizer and scheduler
#for previous layers of model and new linear layer define different learning rate
model = model.to(device)

params_classifier = model.module.classifier.parameters()

# The classifier ends in LogSoftmax, hence the negative log-likelihood loss.
criterion = nn.NLLLoss()

# Group 0: previously collected parameters (small lr); group 1: the new classifier.
optimizer = optim.SGD(
    [
        {"params": prev_params, "lr": 1e-5},
        {"params": params_classifier, "lr": 1e-3},
    ]
)

# One (base_lr, max_lr) pair per parameter group, in the same order as above.
lr_scheduler = optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=[1e-5, 1e-3],
    max_lr=[1e-3, 1e-1],
    step_size_up=100,
    step_size_down=200,
)

## Model training

In [None]:
#initialize values for training (TODO: add reading these values from save file)
total_epochs=0
train_losses=[]
val_losses=[]
val_accuracy_history=[]

#Copy last val loss value to continue train or use np.Inf otherwise
best_val_loss=0.0009597150281320765

In [None]:
#Can be repeated unlimited number of times
epochs=15

for e in range(epochs):
    #train part
    sum_train_loss=0

    sum_val_loss = 0
    val_correct = 0

    model.train()

    for images,labels in train_loader:
        labels=labels.type(torch.LongTensor)

        images=images.to(device)
        labels=labels.to(device)

        optimizer.zero_grad()

        log_ps=model(images)
        loss=criterion(log_ps,labels)
        sum_train_loss+=loss.item()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    #validation part
    else:
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                labels=labels.type(torch.LongTensor)

                images=images.to(device)
                labels=labels.to(device)

                log_ps=model(images)
                loss=criterion(log_ps,labels)
                sum_val_loss+=loss.item()

                ps=torch.exp(log_ps)
                _, top_class = ps.topk(1, dim=1)
                equals_val = top_class == labels.view(*top_class.shape)
                
                val_correct += torch.sum(equals_val.type(torch.FloatTensor)).item()

    #mean losses to compare them:
    train_loss=sum_train_loss/len(train_loader.dataset)
    val_loss = sum_val_loss / len(val_loader.dataset)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    if(val_loss<best_val_loss):
        print("Better val loss have been achieved!")
        best_val_loss=val_loss
        torch.save(model.state_dict(), 'chp.pth')

    val_accuracy=val_correct / val_len
    val_accuracy_history.append(val_accuracy)

    print(f"Epoch: {e+1+total_epochs}/{total_epochs+epochs}\
            Training Loss: {train_loss:.5f}..\
            Val Loss: {val_loss:.5f}..\
            Val Accuracy: {val_accuracy:.5f}")
    
total_epochs+=epochs

In [None]:
#save last state of model
torch.save(model.state_dict(), 'last_state.pth')

#print losses and val accuracy so training can be continued later (TODO: move it to save file)
# The histories live only in memory, so they are printed here to be copied by hand.
print("train_losses: ", train_losses)
print("val_losses: ", val_losses)
print("val_accuracy_history: ", val_accuracy_history)

In [None]:
# Training and validation loss for every recorded epoch.
plt.figure(figsize=[15, 10])
sns.lineplot(data=pd.DataFrame({"train_losses":train_losses,"val_losses":val_losses}))

In [None]:
#skip first 10 epochs
# Both curves are drawn on the same axes; the slices are empty if fewer than 11 epochs
# have been recorded.
plt.figure(figsize=[15, 10])
sns.lineplot(data=train_losses[10:])
sns.lineplot(data=val_losses[10:])

## Predict test images labels

In [None]:
#load model with the best val loss
# 'chp.pth' is written by the training loop only when an epoch's val loss drops below
# best_val_loss, so it exists only if at least one such epoch occurred.
state_dict = torch.load('chp.pth')
model.load_state_dict(state_dict)

In [None]:
# Start from the sample submission without its `cultivar` column; that column is
# re-created from the model's predictions further down.
test_df=pd.read_csv(test_subm_dir).drop(columns="cultivar")
test_df

In [None]:
#define TTA transforms
TTA_type="soft"


def _tta_pipeline(resize, before_resize=(), after_resize=()):
    """Build one test-time pipeline around the shared skeleton.

    Order of operations: ToPILImage, `before_resize` ops, Resize(resize),
    `after_resize` ops, CenterCrop(224), ToTensor, Normalize.
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        *before_resize,
        transforms.Resize(resize),
        *after_resize,
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


#w/o augmentations
tta_transformations_basic=_tta_pipeline(224)

tta_list=[tta_transformations_basic]

# Per-mode settings: (resize before rotation, rotation range, value augmentations).
#   "hard": TTA with greater amount of augmentations than in train
#   "soft": TTA with less amount of augmentations than in train
# Any other TTA_type keeps only the basic pipeline.
_tta_settings={
    "hard": (270, 30, lambda: [transforms.ColorJitter(brightness=0.05,hue=0.05),
                               transforms.RandomPosterize(bits=2)]),
    "soft": (250, 20, lambda: [transforms.RandomPosterize(bits=3)]),
}

if TTA_type in _tta_settings:
    _rot_resize, _rot_degrees, _make_value_ops = _tta_settings[TTA_type]

    #augmentations of rotation
    tta_transformations_rotation=_tta_pipeline(
        _rot_resize,
        after_resize=[
            transforms.RandomRotation(_rot_degrees,expand=True),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ],
    )

    #augmentations of values
    tta_transformations_values=_tta_pipeline(224, before_resize=_make_value_ops())

    tta_list.append(tta_transformations_rotation)
    tta_list.append(tta_transformations_values)

In [None]:
#Custom dataset for test data
class ImageDataTest(Dataset):
    """Dataset yielding, for each test image, one transformed copy per TTA pipeline.

    Parameters
    ----------
    images_dir : str
        Directory that contains the image files.
    test_df : pandas.DataFrame
        Frame with a ``filename`` column (file name inside ``images_dir``).
    transforms : sequence of callables
        Each one is applied to the array returned by ``io.imread``.
    """

    def __init__(self, images_dir, test_df, transforms):
        super().__init__()
        self.test_df = test_df
        self.images_dir = images_dir
        self.transforms = transforms

    def __len__(self):
        """Number of rows in ``test_df``."""
        return len(self.test_df)

    def __getitem__(self, index):
        """Return a list with one transformed image per entry of ``self.transforms``."""
        # Positional lookup (.iloc) so the dataset does not depend on the frame having
        # a default RangeIndex.
        img_name = self.test_df.filename.iloc[index]
        img_path = os.path.join(self.images_dir, img_name)
        image = io.imread(img_path)

        return [t(image) for t in self.transforms]

In [None]:
#init test loader
test_img = ImageDataTest(images_dir = test_img_dir,test_df = test_df,  transforms = tta_list)
# shuffle=False keeps batches in the row order of test_df, which the prediction cell
# relies on when it writes results back by position.
test_loader = DataLoader(dataset = test_img,batch_size=32,shuffle=False)

In [None]:
#check images after transformations
# Use the built-in next(): `.next()` is a Python 2-era alias that is not part of the
# iterator protocol and is not available on DataLoader iterators in recent PyTorch.
iterator=iter(test_loader)
outputs=next(iterator)

# `outputs` holds one batch per TTA pipeline; show the same sample under each of them.
for i,img in enumerate(outputs):
    plt.subplot(1,len(outputs),i+1)
    plt.imshow(img[19].permute(1, 2, 0))

In [None]:
#make prediction
# Predicted class index per test image, collected batch by batch in loader order.
batch_predictions=[]

model.eval()
with torch.no_grad():
    torch.cuda.empty_cache()

    for outputs in tqdm(test_loader):
        # Sum the class probabilities over all TTA views of this batch. The number of
        # classes comes from the model output instead of being hard-coded.
        tta_results=None
        for o in outputs:
            ps=torch.exp(model(o.to(device)))
            tta_results=ps if tta_results is None else tta_results+ps

        # Integer class ids; argmax keeps the batch dimension even for one image.
        batch_predictions.append(tta_results.argmax(dim=1).cpu())

# 1-D integer tensor with one entry per test image (empty if the loader is empty).
prediction=torch.cat(batch_predictions) if batch_predictions else torch.empty(0,dtype=torch.long)

In [None]:
#add predicted labels to values
# At this point the column holds integer class codes, not cultivar names.
test_df["cultivar"]=prediction.cpu().numpy().astype(int)
test_df

In [None]:
#decode test labels
# Map the integer codes back to cultivar names with the encoder fitted on the
# training labels.
test_df["cultivar"]=encoder.inverse_transform(test_df["cultivar"])
test_df

In [None]:
# Write the submission file without the DataFrame index column.
test_df.to_csv("submission.csv",index=False)