In [None]:
# !pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install timm # install pytorch image models
!pip install torchmetrics

In [None]:
import torch
import os
import pandas as pd
import numpy as np
import random 

import albumentations as A
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold
import timm

import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
import torch.nn.functional as F
from torch import nn
import torchmetrics 

In [None]:
!nvidia-smi

In [None]:
class GlobalConstantsConfigure():
    """Single place for every tunable used by this notebook.

    One instance (``gcc``, created right below) is read as a global by the
    data, model and training cells.
    """

    def __init__(self):
        # --- Resuming from an earlier run -------------------------------
        # When True, the model-setup cell restores model / optimizer /
        # scheduler state from ``last_model`` before training.
        self.continue_training = True
        self.last_model = '../input/sorghum-100-cultivar-identification-8/tf_efficientnetv2_m_in21k_45_last.pt'
        # Added to num_epochs to name the "_<N>_last.pt" checkpoint written at
        # the end of this run.
        # NOTE(review): the resumed file is named "_45_last" but this says 55
        # epochs were already done — confirm which one is correct.
        self.num_epochs_done = 55

        # --- Reproducibility / task -------------------------------------
        self.seed = 107
        self.num_classes = 100

        # --- Data --------------------------------------------------------
        # NOTE(review): training_dir and training_size_rate are only referenced
        # by commented-out code in this notebook.
        self.training_dir = '../input/sorghum-id-fgvc-9/train_images'
        self.training_size_rate = 0.8
        self.image_size = 512
        self.batch_size = 8
        self.batch_size_testing = 32
        self.num_workers = 2  # if torch.cuda.is_available() else 4

        # --- Model / checkpoints ----------------------------------------
        self.model_name = 'tf_efficientnetv2_m_in21k'
        # NOTE(review): model_path is not read anywhere in the visible code.
        self.model_path = './tf_efficientnetv2_m_in21k_sgd_50.pt'

        # --- Optimisation -----------------------------------------------
        self.lr = 3e-5
        self.num_epochs = 15
        self.steps_per_decay = 5  # only used by a commented-out scheduler
        # Starting "best validation loss"; assumed larger than any real loss.
        self.biggest_loss = 999

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


gcc = GlobalConstantsConfigure()

In [None]:
def set_seed(seed, deterministic=True):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG
            and PyTorch (CPU and every CUDA device).
        deterministic: when True (default), request deterministic cuDNN kernels
            and disable the cuDNN autotuner. Pass False to trade reproducibility
            for speed.
    """
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python subprocesses launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every CUDA device (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # The autotuner benchmarks candidate kernels at run time and may select a
    # different algorithm on each run, which defeats deterministic=True; it was
    # previously forced on unconditionally.
    torch.backends.cudnn.benchmark = not deterministic
# Seed all RNGs once, before any data splitting or model construction.
set_seed(seed=gcc.seed)

In [None]:
# df = pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')

In [None]:
# df

In [None]:
# for row in df['image']:
#     if not os.path.isfile('../input/sorghum-id-fgvc-9/train_images/' + row):
#         print(row)

In [None]:
# df.drop(df[df['image'] == '.DS_Store'].index, inplace=True)

In [None]:
# df.groupby(['cultivar']).count().describe()

In [None]:
# df.describe()

In [None]:
# label_encoder = preprocessing.LabelEncoder()
# label_encoder.fit(df.cultivar)
# labels = label_encoder.transform(df.cultivar)

In [None]:
# Training image -> cultivar mapping. Rows with a missing value in any column
# are dropped; the two printed counts show how many rows that removed.
df_all = pd.read_csv('../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv')
print(len(df_all))
df_all = df_all.dropna()
print(len(df_all))
df_all.head()

In [None]:
# Class vocabulary in order of first appearance in the CSV. A cultivar's
# position in this list is its integer label, and predictions are mapped back
# through it, so the order must not change between training and inference.
unique_cultivars = df_all["cultivar"].drop_duplicates().tolist()

In [None]:
# Add the image path, the integer label (position in unique_cultivars) and an
# on-disk existence flag, then keep only rows whose image file exists.
# A dict gives O(1) label lookup instead of scanning the list for every row.
cultivar_to_index = {name: idx for idx, name in enumerate(unique_cultivars)}
df_all["file_path"] = df_all["image"].map(lambda name: '../input/sorghum-id-fgvc-9/train_images/' + name)
df_all["cultivar_index"] = df_all["cultivar"].map(cultivar_to_index)
df_all["is_exist"] = df_all["file_path"].map(os.path.exists)
df_all = df_all[df_all["is_exist"]]
df_all.head()

In [None]:
# Stratified hold-out split: a single fold is used for validation (this is not
# cross-validation). VALID_FOLD defaults to the last fold, which is the split
# this cell has always produced. Keeping it stable matters when resuming from
# a checkpoint (gcc.continue_training), so that validation images were not
# trained on — assumes the checkpoint was trained on this split; confirm.
N_SPLITS = 4
VALID_FOLD = N_SPLITS - 1  # any value in [0, N_SPLITS) selects another hold-out fold

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=gcc.seed)
train_idx, valid_idx = list(skf.split(df_all['image'], df_all["cultivar_index"]))[VALID_FOLD]
df_train = df_all.iloc[train_idx]
df_valid = df_all.iloc[valid_idx]

print(f"train size: {len(df_train)}")
print(f"valid size: {len(df_valid)}")

# Per-cultivar counts, to eyeball that stratification balanced the classes.
print(df_train.cultivar.value_counts())
print(df_valid.cultivar.value_counts())

In [None]:
# dirs = df['image'].map(lambda x: '../input/sorghum-id-fgvc-9/train_images/' + x)

In [None]:
# dirs = np.array(dirs.tolist())

In [None]:
# print(len(labels))
# print(len(dirs))

In [None]:
# def path_walks_split_set(dirs, labels):


#     training_dirs = dirs[:int(gcc.training_size_rate * len(dirs))]
#     training_labels = labels[:int(gcc.training_size_rate * len(labels))]
#     validation_dirs = dirs[int(gcc.training_size_rate * len(dirs)):]
#     validation_labels = labels[int(gcc.training_size_rate * len(labels)):] 
    

    
#     return training_dirs, training_labels, validation_dirs, validation_labels

# training_dirs, training_labels, validation_dirs, validation_labels = path_walks_split_set(dirs, labels)

In [None]:
# len(os.listdir('../input/sorghum-id-fgvc-9/train_images'))

In [None]:
# Display the training split (image, cultivar, file_path, cultivar_index, is_exist) as a sanity check.
df_train

In [None]:
# avc

In [None]:
class SorghumDataset(Dataset):
    """Image-classification dataset over parallel sequences of paths and labels.

    Args:
        dirs: indexable sequence of image file paths; ``dirs[i]`` is passed to
            ``cv2.imread``.
        labels: indexable sequence aligned with ``dirs`` holding the class
            index of each image (assumed scalar); returned as a float32 tensor.
        transformation: optional albumentations-style callable, invoked as
            ``transformation(image=img)['image']`` on the RGB image.

    Each item is ``(image, label)``. ``image`` is a float32 CHW tensor: it is
    scaled by 1/255 and then normalised with the ImageNet mean/std.
    """

    # ImageNet statistics, shaped to broadcast over a CHW tensor. Built once
    # instead of creating a Normalize transform on every __getitem__ call; the
    # arithmetic ((x - mean) / std in float32) is the same.
    _MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3, 1, 1)
    _STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3, 1, 1)

    def __init__(self, dirs, labels, transformation=None):
        super(SorghumDataset, self).__init__()
        self.dirs = dirs
        self.labels = labels
        self.transformation = transformation

    def __len__(self):
        return len(self.dirs)

    def __getitem__(self, index):
        path = self.dirs[index]
        image = cv2.imread(path)
        # cv2.imread does not raise for a missing or unreadable file: it returns
        # None, which would otherwise fail obscurely inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transformation:
            image = self.transformation(image=image)['image']

        # HWC -> CHW, scaled by 1/255, as float32; then ImageNet normalisation.
        image = torch.from_numpy((image / 255.).transpose((2, 0, 1))).type(torch.float32)
        image = (image - self._MEAN) / self._STD

        label = torch.from_numpy(np.array(self.labels[index])).type(torch.float32)
        return image, label

In [None]:
# Training-time augmentation: resize to the model input size, then random
# flips, 90-degree rotations and shift/scale/rotate. Colour
# (HueSaturationValue, brightness/contrast/gamma), blur, noise and dropout
# augmentations existed here as commented-out experiments and are disabled.
training_transformation = A.Compose(
    [
        A.Resize(height=gcc.image_size, width=gcc.image_size, p=1.0),
        A.Flip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(p=0.5),
    ]
)

# Validation / test time: deterministic resize only.
validation_transformation = A.Compose(
    [A.Resize(height=gcc.image_size, width=gcc.image_size, p=1.0)]
)

In [None]:
# Datasets over the stratified split: image paths + integer cultivar labels.
training_set = SorghumDataset(df_train.file_path.values, df_train.cultivar_index.values, training_transformation)
validation_set = SorghumDataset(df_valid.file_path.values, df_valid.cultivar_index.values, validation_transformation)


training_dataloader = DataLoader(
    training_set,
    batch_size=gcc.batch_size,
    shuffle=True,
    num_workers=gcc.num_workers,
    pin_memory=True,
    # Keep every training batch full-sized; a size-1 final batch would also
    # break the BatchNorm1d in the classifier head while in train mode.
    drop_last=True,
)
validation_dataloader = DataLoader(
    validation_set,
    batch_size=gcc.batch_size,
    shuffle=False,
    num_workers=gcc.num_workers,
    pin_memory=True,
    # Evaluate on every validation image. Dropping the final partial batch
    # silently excluded up to batch_size - 1 samples from the validation
    # loss/accuracy that decides which checkpoint is kept as "best".
    drop_last=False,
)

In [None]:
class CustomModel(torch.nn.Module):
    """Wrap a backbone and replace its ``classifier`` with an MLP head.

    Head: BatchNorm1d -> Linear(in_features, 512) -> Dropout(0.5) -> ReLU ->
    Linear(512, num_classes). Layer order/indices are fixed so state dicts
    saved by earlier runs (``model.classifier.0`` ... ``model.classifier.4``)
    still load.

    Args:
        model_backbone: module exposing a ``classifier`` submodule. Its own
            ``forward`` is assumed to apply ``classifier`` to the pooled
            features, so replacing it swaps the head.
        num_classes: output width; defaults to ``gcc.num_classes``.
        in_features: width of the features entering ``classifier``. Defaults
            to the backbone's ``num_features`` attribute if it is an int, else
            ``classifier.in_features``, else 1280 (the previously hard-coded
            value).
    """

    def __init__(self, model_backbone, num_classes=None, in_features=None):
        super(CustomModel, self).__init__()
        if num_classes is None:
            num_classes = gcc.num_classes
        if in_features is None:
            in_features = self._infer_in_features(model_backbone)

        self.model = model_backbone
        self.model.classifier = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Linear(in_features, 512),
            nn.Dropout(0.5),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _infer_in_features(backbone, fallback=1280):
        """Best-effort guess of the feature width fed into ``backbone.classifier``."""
        num_features = getattr(backbone, "num_features", None)
        if isinstance(num_features, int):
            return num_features
        in_features = getattr(getattr(backbone, "classifier", None), "in_features", None)
        if isinstance(in_features, int):
            return in_features
        return fallback

    def forward(self, x):
        return self.model(x)


In [None]:
# Pretrained timm backbone; CustomModel swaps in this project's classifier head.
backbone = timm.create_model(gcc.model_name, pretrained=True)
model = CustomModel(backbone)
loss_func = torch.nn.CrossEntropyLoss()
metrics_acc = torchmetrics.Accuracy(threshold=0.0, num_classes=gcc.num_classes)
print(model)

# Move the model before building the optimizer, as the torch.optim docs
# recommend, so the optimizer is constructed over the on-device parameters.
model.to(gcc.device)

trainable_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_parameters, lr=gcc.lr)
# NOTE(review): training_progress() calls scheduler.step() once per *batch*,
# so T_0=5 / T_mult=2 are counted in batches, not epochs (warm restarts after
# 5, then 10, 20, ... more batches). Confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

if gcc.continue_training:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(gcc.last_model, map_location=gcc.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restoring optimizer/scheduler state also restores their learning rates,
    # so gcc.lr has no effect when resuming.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

print('load model done')

In [None]:
# def calc_accuracy(pred, true):
#     # print(pred, true)
#     true = true.type(torch.int64) # label
#     pred = F.softmax(pred, dim = 1)
#     true = torch.zeros(pred.shape[0], pred.shape[1]).scatter_(1, true.unsqueeze(1), 1.)
#     acc = (true.argmax(-1) == pred.argmax(-1)).float().detach().numpy()
#     acc = float(acc.sum() / len(acc))
#     return round(acc, 4)

In [None]:
def training_progress(training_dataloader, loss_func, scheduler):
    """Run one training epoch and return its mean loss and accuracy.

    Trains the notebook-level ``model`` with the notebook-level ``optimizer``
    (both globals), moving each batch to ``gcc.device``.

    Args:
        training_dataloader: iterable of ``(image, label)`` batches; labels
            hold class indices (any numeric dtype; converted with ``.long()``).
        loss_func: criterion called as ``loss_func(logits, target)``; assumed
            to average over the batch (true for the default CrossEntropyLoss).
        scheduler: LR scheduler, stepped once per batch after the optimizer.

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats, averaged over every
        sample seen this epoch.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    print('Learning rate: ', scheduler.get_last_lr())
    for image, label in tqdm(training_dataloader, desc='Iterating through the training set'):
        image = image.to(gcc.device)
        target = label.to(gcc.device).long()

        output = model(image)
        loss = loss_func(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_size = target.size(0)
        # Weight the per-batch mean by batch size so the epoch mean is per sample.
        loss_sum += loss.item() * batch_size
        # Fraction of argmax predictions matching the label: the same value the
        # multiclass torchmetrics Accuracy reports.
        correct += (output.argmax(1) == target).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError('training_dataloader yielded no batches')
    return loss_sum / seen, correct / seen
    

In [None]:
def validation_progress(validation_dataloader, loss_func):
    """Evaluate the notebook-level ``model`` without gradients.

    Args:
        validation_dataloader: iterable of ``(image, label)`` batches; labels
            hold class indices (converted with ``.long()``).
        loss_func: criterion called as ``loss_func(logits, target)``; assumed
            to average over the batch (true for the default CrossEntropyLoss).

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats, averaged over samples
        rather than batches. A smaller final batch therefore counts only as
        much as its samples; with unweighted batch means it counted as a full
        batch.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for image, label in tqdm(validation_dataloader, desc='Iterating through the validation set'):
            image = image.to(gcc.device)
            target = label.to(gcc.device).long()

            output = model(image)
            batch_size = target.size(0)
            loss_sum += loss_func(output, target).item() * batch_size
            # Fraction of argmax predictions matching the label (multiclass accuracy).
            correct += (output.argmax(1) == target).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('validation_dataloader yielded no batches')
    return loss_sum / seen, correct / seen

In [None]:
def training_model(model, training_dataloader, validation_dataloader, loss_func, scheduler):
    """Train for ``gcc.num_epochs`` epochs, validating after each one.

    Checkpoints (model / optimizer / scheduler state dicts) are written to
    ``<model_name>_best.pt`` whenever the validation loss ties or beats the
    best seen so far, and to ``<model_name>_<num_epochs_done + num_epochs>_last.pt``
    after the final epoch.

    Returns:
        ``(training_losses, validation_losses, training_accs, validation_accs)``,
        each a list with one entry per epoch.
    """
    history = {'train_loss': [], 'valid_loss': [], 'train_acc': [], 'valid_acc': []}
    best_loss = gcc.biggest_loss
    final_epoch = gcc.num_epochs - 1
    last_path = gcc.model_name + '_' + str(gcc.num_epochs_done + gcc.num_epochs) + '_last.pt'

    def snapshot():
        # Everything needed to resume training from this point.
        return {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }

    for epoch in range(gcc.num_epochs):
        training_loss, training_acc = training_progress(training_dataloader, loss_func, scheduler)
        history['train_loss'].append(training_loss)
        history['train_acc'].append(training_acc)

        validation_loss, validation_acc = validation_progress(validation_dataloader, loss_func)
        history['valid_loss'].append(validation_loss)
        history['valid_acc'].append(validation_acc)

        if validation_loss <= best_loss:
            best_loss = validation_loss
            torch.save(snapshot(), gcc.model_name + '_best.pt')

        if epoch == final_epoch:
            torch.save(snapshot(), last_path)

        summary = (f'Epoch {epoch + 1}/{gcc.num_epochs} | Training_loss : {training_loss:.3f} | Validation_loss : {validation_loss:.3f}'
                   + f' Training_acc : {training_acc:.3f} | Validation_acc : {validation_acc:.3f}')
        print(summary)

    return history['train_loss'], history['valid_loss'], history['train_acc'], history['valid_acc']


In [None]:
# Train, collecting per-epoch loss/accuracy histories for the plots below.
(training_losses_history,
 validation_losses_history,
 training_acc_history,
 validation_acc_history) = training_model(
    model=model,
    training_dataloader=training_dataloader,
    validation_dataloader=validation_dataloader,
    loss_func=loss_func,
    scheduler=lr_scheduler,
)

In [None]:
# testing_progress(testing_dataloader, loss_func)

In [None]:
def plot_loss_history(model_name, train_loss_history, val_loss_history, num_epochs):
    """Plot per-epoch training and validation loss.

    The x-axis is numbered from 1, matching the "Epoch k/N" lines printed
    during training.

    Args:
        model_name: figure title.
        train_loss_history: one training loss per epoch (floats or 0-d tensors).
        val_loss_history: one validation loss per epoch.
        num_epochs: maximum number of epochs to plot. A shorter history (e.g.
            from an interrupted run) is plotted as far as it goes instead of
            raising a length-mismatch error.
    """
    train = [float(v) for v in train_loss_history][:num_epochs]
    val = [float(v) for v in val_loss_history][:num_epochs]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(np.arange(1, len(train) + 1), train, label='Train Loss', lw=3)
    ax.plot(np.arange(1, len(val) + 1), val, label='Validation Loss', lw=3)

    ax.set_title(f"{model_name}", fontsize=20)
    ax.legend(fontsize=12)
    ax.set_xlabel("Epoch", fontsize=15)
    ax.set_ylabel("Loss", fontsize=15)

    plt.show()

plot_loss_history(gcc.model_name, training_losses_history, validation_losses_history, gcc.num_epochs)

In [None]:
def plot_acc_history(model_name, train_acc_history, val_acc_history, num_epochs):
    """Plot per-epoch training and validation accuracy.

    The x-axis is numbered from 1, matching the "Epoch k/N" lines printed
    during training.

    Args:
        model_name: figure title.
        train_acc_history: one training accuracy per epoch (floats or 0-d
            tensors; both are converted with ``float``).
        val_acc_history: one validation accuracy per epoch.
        num_epochs: maximum number of epochs to plot. A shorter history is
            plotted as far as it goes instead of raising a length-mismatch error.
    """
    train = [float(v) for v in train_acc_history][:num_epochs]
    val = [float(v) for v in val_acc_history][:num_epochs]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(np.arange(1, len(train) + 1), train, label='Training Accuracy', lw=3)
    ax.plot(np.arange(1, len(val) + 1), val, label='Validation Accuracy', lw=3)

    ax.set_title(f"{model_name}", fontsize=20)
    ax.legend(fontsize=12)
    ax.set_xlabel("Epoch", fontsize=15)
    ax.set_ylabel("Accuracy", fontsize=15)

    plt.show()

plot_acc_history(gcc.model_name, training_acc_history, validation_acc_history, gcc.num_epochs)

In [None]:
# Reload the checkpoint with the lowest validation loss for inference.
# map_location keeps this working regardless of the device it was saved on.
checkpoint = torch.load(gcc.model_name + '_best.pt', map_location=gcc.device)

# Inference needs dropout off and BatchNorm using its running statistics.
model.eval()
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Sample submission: provides the test filenames (the "filename" column) and
# the "cultivar" column that the predictions are written into.
sub = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv')
sub.head()

In [None]:
# Turn each test filename into the image's path and add a placeholder label
# column: SorghumDataset requires labels, but inference ignores them.
sub = sub.assign(
    filename=sub["filename"].map(lambda name: '../input/sorghum-id-fgvc-9/test/' + name),
    cultivar=0,
)
sub.head()

In [None]:
# Test images get the same deterministic resize as validation.
# .to_numpy() makes SorghumDataset index positionally, like the train and
# validation datasets, instead of relying on `sub` keeping a default RangeIndex.
testing_dataset = SorghumDataset(sub['filename'].to_numpy(), sub['cultivar'].to_numpy(), validation_transformation)
testing_dataloader = DataLoader(testing_dataset,
                                batch_size=gcc.batch_size_testing,
                                shuffle=False,  # keep predictions aligned with submission rows
                                num_workers=gcc.num_workers)

In [None]:
# Predict a class index for every test image, batch by batch, in loader order.
# Set eval mode here so this cell is correct regardless of which cells ran before.
model.eval()
predictions = []  # one 1-D tensor of predicted class indices per batch
with torch.no_grad():
    for image, _ in tqdm(testing_dataloader):
        logits = model(image.to(gcc.device))
        predictions.append(logits.cpu().argmax(1))

In [None]:
# Flatten the per-batch predictions into one 1-D tensor in test-set order.
# A single torch.cat copies each element once; concatenating inside a loop
# re-copied the growing tensor every iteration (quadratic).
tmp = torch.cat(predictions)

In [None]:
# Map predicted class indices back to cultivar names. This inverts the
# cultivar_index encoding, which is a position in unique_cultivars.
predictions = [unique_cultivars[idx] for idx in tmp.tolist()]

In [None]:
# Re-read the sample submission to get the bare filenames back (an earlier
# cell rewrote sub["filename"] to full paths), attach the predicted cultivars
# and write the submission file.
sub = pd.read_csv('../input/sorghum-id-fgvc-9/sample_submission.csv').assign(cultivar=predictions)
sub.to_csv('submission.csv', index=False)
sub.head()

In [None]:
# !cd /kaggle/working

In [None]:
# from IPython.display import FileLink
# FileLink(r'./tf_efficientnetv2_m_in21k_50_last.pt')