In [None]:
import warnings
warnings.filterwarnings('ignore')

from glob import glob
import pandas as pd
import numpy as np 
from tqdm import tqdm
import cv2

import os
import timm
import random

import albumentations as A
from albumentations.pytorch import transforms, ToTensorV2

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import json
import joblib

from sklearn.metrics import f1_score, accuracy_score

In [None]:
# Configs: run settings (paths, model name, batch size, ...) are read from this JSON file.
config_path = "./config/swin_base.json"
# JSON is UTF-8 by spec; be explicit so the platform's locale encoding is not used.
# The `with` block closes the file, so no explicit close() is needed.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

In [None]:
# Use CUDA when torch reports it is available, otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
config['DEVICE'] = torch.device(device_name)

In [None]:
class CustomDataset(Dataset):
    """Image dataset driven by a CSV file listing image paths (and optionally labels).

    Args:
        data_path: Anything ``pd.read_csv`` accepts. The CSV must have a ``path``
            column; an ``encoded_label`` column is required unless ``mode == "test"``.
            If a ``kfold`` column exists, ``"train"`` keeps rows whose fold differs
            from ``fold`` and ``"validation"`` keeps rows whose fold equals it.
        size: Value passed to the ``transform`` factory.
        transform: Optional factory called once as ``transform(size)``; the returned
            callable is applied per item as ``pipeline(image=img)['image']``.
        fold: Fold index used for the ``kfold`` split.
        mode: ``"train"``, ``"validation"`` or ``"test"``. Test items contain only
            ``'image'``; other modes also contain ``'label'`` (a ``torch.long`` tensor).

    Raises:
        ValueError: if ``mode`` needs labels but the CSV has no ``encoded_label`` column.
    """

    def __init__(self,
                 data_path,
                 size,
                 transform=None,
                 fold=0,
                 mode="train"):
        self.csv = pd.read_csv(data_path)
        if 'kfold' in self.csv:
            if mode == "train":
                self.csv = self.csv[self.csv['kfold'] != fold]
            elif mode == "validation":
                self.csv = self.csv[self.csv['kfold'] == fold]

        self.path = self.csv['path'].to_list()
        if 'encoded_label' in self.csv:
            self.labels = self.csv['encoded_label'].to_list()
        elif mode != "test":
            # Fail at construction instead of with an AttributeError on the first item.
            raise ValueError(
                f"mode={mode!r} requires an 'encoded_label' column in the CSV")
        self.transform = transform
        self.size = size
        self.mode = mode
        # Build the augmentation pipeline once rather than on every __getitem__ call.
        self._pipeline = transform(size) if transform else None

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        # Image (OpenCV decodes as BGR; convert to RGB)
        image_path = self.path[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable/missing file by returning None, not raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self._pipeline is not None:
            image = self._pipeline(image=image)['image']

        # Only test mode: no label is returned
        if self.mode == "test":
            return {
                'image': image
            }

        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        return {
            'image': image,
            'label': label
        }

In [None]:
def create_validation_transforms(size,
                                 mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)):
    """Build the deterministic evaluation pipeline: resize, normalize, convert to tensor.

    Args:
        size: Output height and width (the image is resized to ``size`` x ``size``).
        mean: Per-channel normalization mean. The default is the ImageNet mean.
        std: Per-channel normalization std. The default is the ImageNet std.

    Returns:
        An ``A.Compose`` pipeline, applied as ``pipeline(image=img)['image']``.
    """
    return A.Compose([
        A.Resize(size, size),
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ])

In [None]:
# Test split: "test" mode applies no kfold filtering and yields images without labels.
test_dataset = CustomDataset(
    config["TEST_CSV"],
    config["SIZE"],
    transform=create_validation_transforms,
    fold=config["FOLD"],
    mode="test",
)
test_loader = DataLoader(
    test_dataset,
    batch_size=config["BATCH_SIZE"],
    num_workers=config["N_WORKERS"],
    shuffle=False,  # keep CSV row order so predictions line up with the submission rows
)

In [None]:
def score_function(real, pred):
    """Return the macro-averaged F1 score of predictions ``pred`` against labels ``real``."""
    return f1_score(real, pred, average="macro")

In [None]:
# Label encoder used to map predicted class indices back to label names.
# NOTE: joblib.load unpickles arbitrary objects, so only load a trusted encoder file.
ENCODER_PATH = 'encoder.pickle'
with open(ENCODER_PATH, 'rb') as encoder_file:
    encoder = joblib.load(encoder_file)

In [None]:
def load_model(model_name = None, pretrained = True, num_classes = 88):
    """Create a timm model with a classification head of ``num_classes`` outputs.

    Args:
        model_name: timm model identifier. Required; the ``None`` default is kept
            only for signature compatibility.
        pretrained: Whether timm should load pretrained weights.
        num_classes: Number of output classes.

    Raises:
        ValueError: if ``model_name`` is empty or ``None``.
    """
    if not model_name:
        # Fail with a clear message instead of an opaque error from inside timm.
        raise ValueError(f"model_name must be a timm model identifier, got {model_name!r}")
    return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

In [None]:
# Every weight is replaced by the checkpoint's state_dict (strict load) in the next cell,
# so skip downloading pretrained weights: it wastes time and fails when offline.
model = load_model(config["MODEL"], pretrained=False, num_classes=config["N_CLASSES"]).to(config["DEVICE"])

In [None]:
best_model_name = f"{config['MODEL_SAVE_PREFIX']}_best.pth"
best_model_path = os.path.join(config['MODEL_SAVE'], best_model_name)
# map_location lets a checkpoint saved from GPU be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
model_data = torch.load(best_model_path, map_location=config["DEVICE"])
print(model_data["epoch"], model_data["score"], model_data["loss"])
model.load_state_dict(model_data['state_dict'])

In [None]:
# Inference: f_score collects the raw per-class model outputs (logits, no softmax) per image;
# f_pred collects the decoded label of the argmax class.
model.eval()

device = config["DEVICE"]
f_pred = []
f_score = []
with torch.no_grad():
    for batch in tqdm(test_loader, total=len(test_loader)):
        # The DataLoader already yields tensors: move/cast instead of re-wrapping with
        # torch.tensor(), which copies the data and emits a (globally suppressed) warning.
        x = batch['image'].to(device, dtype=torch.float32)
        with torch.cuda.amp.autocast():
            pred = model(x)
        # No .detach() needed: gradients are disabled by torch.no_grad().
        f_score.extend(pred.cpu().numpy().tolist())
        y = pred.argmax(1).cpu().numpy().tolist()
        f_pred.extend(encoder.inverse_transform(y))

In [None]:
submission = pd.read_csv(config["TEST_CSV"])
# Predictions are in CSV row order (the test loader does not shuffle); guard the row alignment.
assert len(f_pred) == len(f_score) == len(submission), \
    f"{len(f_pred)} predictions for {len(submission)} test rows"
submission['label'] = f_pred
submission['score'] = f_score
submission = submission.drop(columns=['file_name', 'path'])
submission.head()

In [None]:
# Write predictions without the pandas row index.
submission.to_csv(path_or_buf=config["SUBMISSION_CSV"], index=False)