In [None]:
import sys

# Directories under ../input that hold the tez and efficientnet-pytorch sources.
tez_path = '../input/tez-lib/'
effnet_path = '../input/efficientnet-pytorch/'

# Extend the module search path so both directories are importable.
sys.path.extend([tez_path, effnet_path])

In [None]:
import os

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tez
from tez import enums
from tez.datasets import ImageDataset
from tez.utils import AverageMeter
from tez.callbacks import EarlyStopping

from tqdm import tqdm

import torch
import torch.nn as nn

import torchvision

from sklearn import metrics, model_selection

%matplotlib inline

# **Load data**

In [None]:
# Training metadata; later cells read its `image` and `labels` columns.
dfx = pd.read_csv(
    "../input/plant-pathology-2021-fgvc8/train.csv"
)
dfx.head()

# **Encode labels**

In [None]:
from sklearn.preprocessing import LabelEncoder

# Fit the encoder on the label strings, then map every row to its integer class id.
labelencoder = LabelEncoder().fit(dfx["labels"])
dfx["encoded_labels"] = labelencoder.transform(dfx["labels"])
dfx.head()

In [None]:
# Number of rows per encoded class.
dfx["encoded_labels"].value_counts()

# **Split train, valid & Reset index**

In [None]:
# 80/20 split stratified on the encoded label; both parts get a fresh 0..n-1 index.
df_train, df_valid = (
    part.reset_index(drop=True)
    for part in model_selection.train_test_split(
        dfx,
        test_size=0.2,
        random_state=42,
        stratify=dfx.encoded_labels.values,
    )
)

In [None]:
# (rows, columns) of the training split.
df_train.shape

In [None]:
# (rows, columns) of the validation split.
df_valid.shape

# **Build image paths**

In [None]:
image_path = "../input/resized-plant2021/img_sz_512"

# Full path of every training image, in the same row order as df_train.
train_image_paths = [
    os.path.join(image_path, x) for x in df_train.image.values
]

# Full path of every validation image, in the same row order as df_valid.
# (The original cell built `train_image_paths` a second time here, so no
# validation path list ever existed.)
valid_image_paths = [
    os.path.join(image_path, x) for x in df_valid.image.values
]

In [None]:
# Show the first five training image paths as a quick check.
train_image_paths[:5]

**Set train and validation targets**

In [None]:
# Integer class ids for each split, as arrays aligned with the row order of the frames.
train_target = df_train["encoded_labels"].values
valid_target = df_valid["encoded_labels"].values

In [None]:
# train_target

In [None]:
# valid_target

# **Create train_dataset**

In [None]:
# Training dataset without augmentation.
train_dataset = ImageDataset(image_paths=train_image_paths, targets=train_target, augmentations=None)

**Plot image**

In [None]:
def plot_img(img_dict):
    """Print a sample's target and display its image.

    Args:
        img_dict: mapping with an 'image' tensor (permuted here from
            channel-first to channel-last; values presumably in 0-255 since
            they are divided by 255 — confirm against the dataset) and a
            'targets' entry.
    """
    pixels, label = img_dict['image'], img_dict['targets']
    print(label)
    plt.figure(figsize=(5, 5))
    # matplotlib expects (H, W, C); scale into [0, 1] for display.
    plt.imshow(pixels.permute(1, 2, 0) / 255)

In [None]:
# Preview sample 10 of the training dataset.
plot_img(train_dataset[10])

# **Augmentation**

In [None]:
# Training-time augmentation: random crop resized to 256x256, random
# transpose/flips, then mild colour and brightness/contrast jitter.
# A.Normalize (ImageNet mean/std) was commented out in the original pipeline
# and is not applied here either.
# NOTE(review): the HueSaturationValue limits are 0.2 — confirm this matches
# the scale albumentations expects for the image dtype in use.
train_aug = A.Compose([
    A.RandomResizedCrop(256, 256),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
])

# Validation preprocessing must be deterministic so that metrics are
# comparable from epoch to epoch. The original pipeline also applied random
# transpose/flips and colour jitter here, which made every validation pass
# score a differently perturbed copy of the data. Only the fixed centre crop
# is kept.
valid_aug = A.Compose(
    [
        A.CenterCrop(256, 256, p=1.0),
    ]
)

**Apply augmentation**

In [None]:
train_dataset = ImageDataset(
    image_paths = train_image_paths,
    targets = train_target,
    augmentations = train_aug
)

# The validation set must come from df_valid. The original passed the
# training paths and targets here, so "validation" metrics were measured on
# the training data. Paths are derived inline from df_valid so this cell does
# not depend on any other variable holding them.
valid_dataset = ImageDataset(
    image_paths = [os.path.join(image_path, x) for x in df_valid.image.values],
    targets = valid_target,
    augmentations = valid_aug
)

In [None]:
# Preview sample 10 again, now drawn through the training augmentation pipeline.
plot_img(train_dataset[10])

# **Create Model**

In [None]:
import pickle

# pretrained=True
# pretrained_model = torchvision.models.resnet18(pretrained=pretrained)

# Pkl_Filename = "pretrained_resnet18.pkl"

# with open(Pkl_Filename, 'wb') as file:  
#     pickle.dump(pretrained_model, file)

In [None]:
# Location of the pickled model object that is loaded as `pretrained_model`.
Pkl_Filename = "../input/resnet18-pretrained/pretrained_resnet18.pkl"

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
# only load pickles whose origin you trust.
with open(Pkl_Filename, 'rb') as file:  
    pretrained_model = pickle.load(file)

# pretrained_model

In [None]:
def plot_result(train_loss, train_acc, train_f1, valid_loss, valid_acc, valid_f1):
    """Draw loss, accuracy and F1 curves side by side.

    Each panel overlays the training series and the validation series of one
    quantity; every argument is a sequence of numbers to plot in order.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 7))
    panels = (
        ('Loss', train_loss, valid_loss),
        ('Accuracy', train_acc, valid_acc),
        ('F1 Score', train_f1, valid_f1),
    )
    for ax, (title, train_series, valid_series) in zip(axes, panels):
        ax.plot(train_series, label='Train')
        ax.plot(valid_series, label='Validation')
        ax.set_title(title)
        ax.legend()

In [None]:
# Per-step history buffers filled by PlantModel during training/validation.
train_losses, train_acc, train_f1 = [], [], []
valid_losses, valid_acc, valid_f1 = [], [], []

class PlantModel(tez.Model):
    """Image classifier built on the module-level ``pretrained_model``.

    On top of the usual Tez hooks, this subclass records per-step loss,
    accuracy and weighted F1 into the module-level history lists
    (``train_losses``, ``train_acc``, ``train_f1``, ``valid_losses``,
    ``valid_acc``, ``valid_f1``) and calls ``plot_result`` after every epoch.
    """

    def __init__(self,num_classes):
        """Attach a fresh ``num_classes``-way linear head to the backbone."""
        super().__init__()
        # NOTE: this reuses the module-level `pretrained_model` object and
        # replaces its `fc` layer in place (no copy is made).
        self.convnet = pretrained_model
        self.convnet.fc = nn.Linear(512, num_classes)
        self.step_scheduler_after = "epoch"

    def loss(self, outputs, targets):
        """Return the cross-entropy loss, or None when no targets are given."""
        if targets is None:
            return None
        return nn.CrossEntropyLoss()(outputs, targets)

    def update_metrics(self, losses, monitor):
        """Store the monitored metrics and the average loss for the current model state."""
        self.metrics[self._model_state.value].update(monitor)
        self.metrics[self._model_state.value]["loss"] = losses.avg

    def monitor_metrics(self, outputs, targets):
        """Compute batch accuracy and weighted F1 and log them to the history lists.

        Returns:
            dict with keys "accuracy" and "f1_score".
        """
        outputs = torch.argmax(outputs, dim = 1).cpu().detach().numpy()
        targets = targets.cpu().detach().numpy()
        acc = metrics.accuracy_score(targets, outputs)
        # sklearn expects (y_true, y_pred). The original call passed them
        # swapped, which changes the 'weighted' average because the weights
        # are the per-class support of y_true.
        f1_score = metrics.f1_score(targets, outputs, average='weighted')
        if(self.train_state is not None):
            if(self.train_state == enums.TrainingState.TRAIN_STEP_START):
                train_acc.append(acc)
                train_f1.append(f1_score)
            if(self.train_state == enums.TrainingState.VALID_STEP_START):
                valid_acc.append(acc)
                valid_f1.append(f1_score)
        return{
            "accuracy" : acc,
            "f1_score" : f1_score
        }

    def train_one_epoch(self, data_loader):
        """Run one training epoch and return the average loss.

        Also appends each step's rounded loss to ``train_losses``.
        """
        self.train()
        self.model_state = enums.ModelState.TRAIN
        losses = AverageMeter()
        tk0 = tqdm(data_loader, total=len(data_loader))
        for b_idx, data in enumerate(tk0):
            self.train_state = enums.TrainingState.TRAIN_STEP_START
            # Named `step_metrics` so the sklearn `metrics` module is not shadowed.
            loss, step_metrics = self.train_one_step(data)
            if(self.train_state is not None):
                if(self.train_state == enums.TrainingState.TRAIN_STEP_START):
                    train_losses.append(round(loss.item(),4))

            self.train_state = enums.TrainingState.TRAIN_STEP_END
            losses.update(loss.item(), data_loader.batch_size)
            if b_idx == 0:
                # Meters are created lazily once the metric names are known.
                metrics_meter = {k: AverageMeter() for k in step_metrics}
            monitor = {}
            for m_m in metrics_meter:
                metrics_meter[m_m].update(step_metrics[m_m], data_loader.batch_size)
                monitor[m_m] = metrics_meter[m_m].avg
            self.current_train_step += 1
            tk0.set_postfix(loss=losses.avg, stage="train", **monitor)
        tk0.close()
        self.update_metrics(losses=losses, monitor=monitor)
        return losses.avg

    def validate_one_step(self, data):
        """Evaluate one validation batch; returns (loss, metrics).

        Also appends the rounded loss to ``valid_losses`` while the training
        state is VALID_STEP_START.
        """
        _, loss, step_metrics = self.model_fn(data)
        if(self.train_state is not None):
            if(self.train_state == enums.TrainingState.VALID_STEP_START):
                valid_losses.append(round(loss.item(),4))
        return loss, step_metrics

    def load(self, model_path, device = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Move the model to ``device`` and load the "state_dict" entry of the checkpoint at ``model_path``."""
        self.device = device
        if next(self.parameters()).device != self.device:
            self.to(self.device)
        model_dict = torch.load(model_path, map_location=torch.device(device))
        self.load_state_dict(model_dict["state_dict"])

    def fetch_optimizer(self):
        """Adam over all parameters with lr=1e-3."""
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        return opt

    def fetch_scheduler(self):
        """StepLR that multiplies the learning rate by 0.7 after every epoch.

        The original passed 0.7 as ``step_size``, which is the integer period
        in epochs; with a fractional period the decay condition is
        effectively never met. 0.7 is assumed to be the intended decay
        factor (``gamma``).
        """
        sch = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.7)
        return sch

    def forward(self, image, targets=None):
        """Return (outputs, loss, metrics); loss and metrics are None without targets."""
        outputs = self.convnet(image)
        if targets is not None:
            loss = self.loss(outputs, targets)
            mon_metrics = self.monitor_metrics(outputs, targets)
            return outputs, loss, mon_metrics
        return outputs, None, None

    def fit(
        self,
        train_dataset,
        valid_dataset=None,
        train_sampler=None,
        valid_sampler=None,
        device ='cuda' if torch.cuda.is_available() else 'cpu',
        epochs=10,
        train_bs=16,
        valid_bs=16,
        n_jobs=8,
        callbacks=None,
        fp16=False,
        train_collate_fn=None,
        valid_collate_fn=None,
    ):
        """
        The model fit function. Heavily inspired by tf/keras, this function is the core of Tez and this is the only
        function you need to train your models.

        This override additionally calls ``plot_result`` on the accumulated
        history lists at the end of every epoch.
        """
        self._init_model(
            device=device,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            train_sampler=train_sampler,
            valid_sampler=valid_sampler,
            train_bs=train_bs,
            valid_bs=valid_bs,
            n_jobs=n_jobs,
            callbacks=callbacks,
            fp16=fp16,
            train_collate_fn=train_collate_fn,
            valid_collate_fn=valid_collate_fn,
        )

        for _ in range(epochs):
            # The order of train_state assignments is preserved from the original.
            self.train_state = enums.TrainingState.EPOCH_START
            self.train_state = enums.TrainingState.TRAIN_EPOCH_START
            self.train_one_epoch(self.train_loader)
            self.train_state = enums.TrainingState.TRAIN_EPOCH_END
            if self.valid_loader:
                self.train_state = enums.TrainingState.VALID_EPOCH_START
                self.validate_one_epoch(self.valid_loader)
                self.train_state = enums.TrainingState.VALID_EPOCH_END
            if self.scheduler:
                if self.step_scheduler_after == "epoch":
                    if self.step_scheduler_metric is None:
                        self.scheduler.step()
                    else:
                        step_metric = self.name_to_metric(self.step_scheduler_metric)
                        self.scheduler.step(step_metric)
            self.train_state = enums.TrainingState.EPOCH_END
            plot_result(train_losses, train_acc, train_f1, valid_losses, valid_acc, valid_f1)
            # Stop once the model state has been set to "end".
            if self._model_state.value == "end":
                break
            self.current_epoch += 1
        self.train_state = enums.TrainingState.TRAIN_END
        

**Resnet18 structure**

In [None]:
# torchvision.models.resnet18(pretrained=False)

In [None]:
# dfx.encoded_labels.nunique()

In [None]:
# One output unit per distinct encoded label in the training metadata.
model = PlantModel(num_classes=dfx["encoded_labels"].nunique())

In [None]:
# model

In [None]:
# Sanity check: push a single sample through the model as a batch of one.
# Fetch the sample once: indexing the dataset twice loads the image twice and
# applies the random augmentation twice.
sample = train_dataset[0]
img = sample["image"]
y = sample["targets"]
model(img.unsqueeze(0), y.unsqueeze(0))

**Our custom model**

**Train model**

In [None]:
# Early stopping should track generalisation, so monitor the validation
# accuracy; the original monitored "train_accuracy".
es = EarlyStopping(
    monitor = "valid_accuracy", 
    model_path = "model.bin", 
    patience = 2,
    mode='max'
)

# Fall back to CPU (and disable mixed precision) when no GPU is available
# instead of hard-coding "cuda".
use_cuda = torch.cuda.is_available()

model.fit(
    train_dataset,
    valid_dataset = valid_dataset,
    train_bs = 32,
    valid_bs = 64,
    device = "cuda" if use_cuda else "cpu",
    callbacks = [es],
    fp16 = use_cuda,
    epochs = 50
)

# model.save("model.bin")

# **Predict testset**

**Load test data**

In [None]:
# The sample submission file; its `image` and `labels` columns are read below.
test_dfx = pd.read_csv("../input/plant-pathology-2021-fgvc8/sample_submission.csv")
# NOTE: rebinds `image_path`, which previously held the training image directory.
image_path = "../input/plant-pathology-2021-fgvc8/test_images/" 

# model.load("../input/resnet18-tez/model.bin")

# model

# Pkl_Filename = "../input/resnet18-tez/resnet18_trained_model.pkl"
# with open(Pkl_Filename, 'rb') as file:  
#     model = pickle.load(file)

# test_dfx.head()

**Encode label**

In [None]:
from sklearn.preprocessing import LabelEncoder

# These encoded values only serve as placeholder targets for the test dataset.
# Use a separate encoder so the `labelencoder` fitted on the training labels
# is not silently replaced by one fitted on the submission's labels.
test_labelencoder = LabelEncoder()
test_dfx["encoded_labels"] = test_labelencoder.fit_transform(test_dfx["labels"])
test_dfx.head()

In [None]:
# Full path of every test image, in the row order of test_dfx.
test_image_paths = [os.path.join(image_path, name) for name in test_dfx.image.values]

# Targets passed to the test dataset (the encoded placeholder labels).
test_target = test_dfx["encoded_labels"]

In [None]:
# Test-time pipeline: the same random crop, flips and colour jitter as the
# training pipeline; predictions are averaged over repeated passes in a later
# cell. A.Normalize was commented out in the original and is not applied.
test_aug = A.Compose([
    A.RandomResizedCrop(256, 256),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
])

In [None]:
# Test dataset drawn through the test-time augmentation pipeline.
test_dataset = ImageDataset(image_paths=test_image_paths, targets=test_target, augmentations=test_aug)

# Display the first sample as a quick check.
test_dataset[0]

In [None]:
# Test-time augmentation: run the randomly augmented test set through the
# model several times and average the raw outputs.
N_TTA_ROUNDS = 5

final_preds = None
for _ in range(N_TTA_ROUNDS):
    # Collect all batch outputs first and stack once; the original re-stacked
    # the growing array on every batch and referenced `np` without importing it.
    batch_preds = list(model.predict(test_dataset, batch_size=32, n_jobs=-1))
    round_preds = np.vstack(batch_preds)
    final_preds = round_preds if final_preds is None else final_preds + round_preds
final_preds = final_preds / N_TTA_ROUNDS

In [None]:
# Collapse the averaged outputs to the index of the largest value in each row.
predicted_class_ids = final_preds.argmax(axis=1)
final_preds = predicted_class_ids
final_preds

In [None]:
# Overwrite the placeholder targets with the predicted class ids.
test_dfx["encoded_labels"] = final_preds
test_dfx.head()

In [None]:
# Lookup table from encoded class id to label string: keep one training row
# per distinct label and index it by the encoded id.
lblist = df_train.drop_duplicates(subset=['labels']).set_index("encoded_labels")
lblist

In [None]:
# lblist.at[5, "labels"]

In [None]:
def get_labels(val):
    """Return the label string stored in `lblist` for the encoded class id `val`."""
    label_name = lblist.at[val, "labels"]
    return label_name

In [None]:
# Translate every predicted class id back to its label string.
pred_lists = [get_labels(pred) for pred in final_preds]

pred_lists

In [None]:
# Replace the labels column with the predictions and drop the helper column.
test_dfx = test_dfx.assign(labels=pred_lists).drop(columns=['encoded_labels'])
test_dfx

In [None]:
# Write the submission file without the DataFrame index column.
test_dfx.to_csv("submission.csv", index=False)