In [None]:
import sys

# Make the locally bundled library sources importable (presumably the `tez` and
# `timm` packages imported in the next cell).
sys.path.extend(['../input/tez-lib', '../input/timmmaster'])

In [None]:
import argparse
import os

import cv2
import albumentations
import albumentations.pytorch
import pandas as pd
import numpy as np

import tez
import timm
import torch
import torch.nn as nn
import torchvision

from sklearn import metrics, model_selection, preprocessing
from tez.callbacks import EarlyStopping
from torch.nn import functional as F

In [None]:
INPUT_DIR = "../input/cassava-swin-transformer-tez/"  # holds "<model_name>_fold<k>_best.bin" checkpoints

In [None]:
class CFG:
    """Static settings shared by the dataset, model and inference cells."""

    # Model / input.
    model_name = 'swin_tiny_patch4_window7_224'
    image_size = 224   # images are resized to image_size x image_size
    target_size = 5    # width of the replacement classifier head
    target_col = 'label'

    # Training settings; epochs and batch_size are not used by this inference notebook.
    epochs = 15
    batch_size = 64

    # Cross-validation folds; one checkpoint per entry of trn_fold.
    n_fold = 5
    trn_fold = list(range(n_fold))  # == [0, 1, 2, 3, 4]

In [None]:
class FlowerDataset:
    """Map-style dataset that reads images from disk and applies augmentations.

    Args:
        image_paths: sequence of image file paths, aligned index-by-index with ``targets``.
        targets: sequence of labels, one per image (placeholders are fine at inference).
        augmentations: callable invoked as ``augmentations(image=...)`` that returns a
            dict with an ``"image"`` entry (e.g. an ``albumentations.Compose``).
    """

    def __init__(self, image_paths, targets, augmentations):
        self.image_paths = image_paths
        self.targets = targets
        self.augmentations = augmentations

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        """Return ``{"image": <augmented image>, "targets": <label>}`` for index ``item``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        path = self.image_paths[item]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals missing/undecodable files by returning None rather than
            # raising; without this check cv2.cvtColor fails with an opaque assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV decodes in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.augmentations(image=image)["image"]

        return {
            "image": image,
            "targets": self.targets[item],
        }

In [None]:
class LeafModel(tez.Model):
    """timm backbone whose ``head`` is replaced by a ``CFG.target_size``-way linear layer.

    Args:
        pretrained: forwarded to ``timm.create_model``; pass ``False`` when the weights
            are loaded from a checkpoint afterwards.
    """

    def __init__(self, pretrained = True):
        super().__init__()
        self.model = timm.create_model(model_name = CFG.model_name, pretrained = pretrained)
        # Swap the classifier in place (kept as ``model.head`` -- presumably the
        # parameter names the saved fold checkpoints were written with).
        self.n_features = self.model.head.in_features
        self.model.head = nn.Linear(self.n_features, CFG.target_size)

        # tez steps the scheduler once per epoch on "valid_accuracy", i.e. the
        # "accuracy" entry returned by monitor_metrics on the validation split.
        self.step_scheduler_after = "epoch"
        self.step_scheduler_metric = "valid_accuracy"

    def monitor_metrics(self, outputs, targets):
        """Return ``{"accuracy": float}``: share of rows whose argmax logit equals the target."""
        predicted = torch.argmax(outputs, dim=1)
        accuracy = (predicted == targets).float().mean().item()
        return {"accuracy": accuracy}

    def forward(self, image, targets=None):
        """Run the backbone; also compute loss and metrics when ``targets`` is given.

        Returns:
            ``(logits, loss, metrics)``; ``loss`` and ``metrics`` are ``None`` when
            ``targets`` is ``None``.
        """
        outputs = self.model(image)

        if targets is None:
            return outputs, None, None
        loss = F.cross_entropy(outputs, targets)
        # Named batch_metrics so it does not shadow the sklearn ``metrics`` module.
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

In [None]:
# Deterministic evaluation pipeline: resize to the model's input size, normalize
# with the standard ImageNet statistics, then convert the HWC array to a CHW tensor.
valid_aug = albumentations.Compose(
    transforms=[
        albumentations.Resize(height=CFG.image_size, width=CFG.image_size, p=1.0),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],  # per-channel means
            std=[0.229, 0.224, 0.225],   # per-channel standard deviations
            max_pixel_value=255.0,
            p=1.0,
        ),
        albumentations.pytorch.ToTensorV2(),
    ],
    p=1.0,
)

In [None]:
# The sample submission lists every test image id; its rows become the submission.
dfx = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
image_path = "../input/cassava-leaf-disease-classification/test_images/"
test_image_paths = dfx.image_id.map(lambda image_id: os.path.join(image_path, image_id)).tolist()

# Placeholder labels: the dataset requires targets, but the logits do not depend on them.
test_targets = dfx.label.values

test_dataset = FlowerDataset(image_paths=test_image_paths, targets=test_targets, augmentations=valid_aug)

In [None]:
def predict(fold, batch_size=32, device="cuda", dataset=None):
    """Load the best checkpoint for one fold and run inference with it.

    Args:
        fold: fold index; selects ``f"{CFG.model_name}_fold{fold}_best.bin"`` in ``INPUT_DIR``.
        batch_size: inference batch size (default 32).
        device: device the checkpoint is loaded onto (default ``"cuda"``).
        dataset: dataset to predict on; ``None`` means the module-level ``test_dataset``.

    Returns:
        The result of ``LeafModel.predict`` -- iterated by callers as one array of
        logits per batch (presumably a lazy generator; confirm for the tez version used).
    """
    if dataset is None:
        dataset = test_dataset
    # Backbone weights come from the checkpoint, so skip downloading pretrained ones.
    model = LeafModel(pretrained=False)
    # os.path.join tolerates INPUT_DIR with or without a trailing separator.
    checkpoint_path = os.path.join(INPUT_DIR, f"{CFG.model_name}_fold{fold}_best.bin")
    model.load(checkpoint_path, device=device, weights_only=True)
    return model.predict(dataset, batch_size=batch_size)

In [None]:
# Average the logits of every trained fold listed in CFG.trn_fold.
fold_ids = list(CFG.trn_fold)
if not fold_ids:
    raise ValueError("CFG.trn_fold is empty; there are no folds to ensemble")

final_preds = None
for fold in fold_ids:
    # predict() yields one logits array per batch (presumably shape (batch, n_classes));
    # stack them once instead of re-copying a growing array on every batch.
    fold_logits = np.vstack(list(predict(fold)))
    final_preds = fold_logits if final_preds is None else final_preds + fold_logits

# Divide by the number of folds actually summed, not a hard-coded constant.
final_preds = final_preds / len(fold_ids)

In [None]:
# Predicted class = index of the highest averaged logit. Stored under a new name so
# re-running this cell does not try to argmax an already-reduced 1-D array.
pred_labels = final_preds.argmax(axis=1)

dfx["label"] = pred_labels
dfx.to_csv("submission.csv", index=False)

In [None]:
dfx.head()  # preview the first rows of the written submission