In [None]:
import os
import cv2
import timm
import torch
import random
import numpy as np
import ttach as tta
import pandas as pd
import torch.nn as nn
import albumentations
import albumentations.pytorch

from glob import glob
from tqdm.auto import tqdm
from prettyprinter import cpprint
from torch.utils.data import Dataset, DataLoader

In [None]:
class TestDataset(Dataset):
    """Unlabelled image dataset serving every file found in one directory.

    Files are served in sorted filename order. ``os.listdir`` order is
    filesystem-dependent, so sorting makes the prediction order reproducible
    across machines.
    NOTE(review): predictions are later assigned positionally to the rows of the
    submission template CSV -- confirm that the template lists images in this
    (lexicographic) filename order.

    Args:
        data_path: directory containing the test images.
        transform: optional albumentations-style callable invoked as
            ``transform(image=...)`` and returning a dict with key ``'image'``.
    """

    def __init__(self, data_path='./test/0/', transform=None):
        self.data_path = data_path
        self.data = sorted(os.listdir(data_path))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_path = os.path.join(self.data_path, self.data[idx])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        # OpenCV decodes as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return image

In [None]:
# --- Model / checkpoint configuration ---
MODEL_ARC = 'xception'    # timm model name; also the sub-directory of MODEL_DIR holding checkpoints
NUM_CLASSES = 7           # size of the classifier output
MODEL_DIR = './results'   # checkpoints are looked up under MODEL_DIR/MODEL_ARC/<k>_fold/
NUM_FOLD = 5              # number of fold checkpoints ensembled at inference

# --- Data / runtime configuration ---
SEED = 7777
BATCH_SIZE = 32
IMAGE_SIZE = 227          # side length images are resized to before inference

In [None]:
# Use the first GPU when available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda:0') if cuda_available else torch.device('cpu')

In [None]:
# Fix random seeds for reproducible runs.
def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # does not change hashing in the already-running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all GPUs; no-op without CUDA
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick (possibly nondeterministic)
    # algorithms, which defeats `deterministic = True` above.
    torch.backends.cudnn.benchmark = False

In [None]:
# Seed every RNG with the configured SEED before building data/model objects.
seed_everything(seed=SEED)

In [None]:
# Inference preprocessing: resize to the model input size, normalise per channel,
# then convert to a torch tensor.
# NOTE(review): mean/std look like dataset-specific statistics -- confirm they match training.
test_steps = [
    albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
    albumentations.Normalize(
        mean=(0.4569, 0.5074, 0.5557),
        std=(0.2888, 0.2743, 0.2829),
    ),
    albumentations.pytorch.transforms.ToTensorV2(),
]
test_transform = albumentations.Compose(test_steps)

In [None]:
test_dataset = TestDataset(transform=test_transform)
# shuffle=False keeps predictions in dataset order; the submission relies on that order.
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=BATCH_SIZE)

## Model

In [None]:
class PretrainedModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``timm`` classification backbone.

    Despite the name, the backbone is created with ``pretrained=False``; weights
    are loaded afterwards through ``load_state_dict`` (keys live under ``net.``).

    Args:
        model_arc: timm model name passed to ``timm.create_model``.
        num_classes: size of the classifier output.
    """

    def __init__(self, model_arc='swin_tiny_patch4_window7_224', num_classes=7):
        super().__init__()
        self.net = timm.create_model(
            model_arc,
            num_classes=num_classes,
            pretrained=False,
        )

    def forward(self, x):
        """Return the backbone's raw output for input batch ``x``."""
        return self.net(x)

In [None]:
# Build the backbone (random init) and move it to the inference device.
model = PretrainedModel(num_classes=NUM_CLASSES, model_arc=MODEL_ARC)
model.to(device)

In [None]:
def _latest_checkpoint(fold: int) -> str:
    """Return the last ``*.pth`` file, in sorted filename order, for fold ``fold``.

    ``glob`` returns paths in arbitrary filesystem order, so sort before picking.
    NOTE(review): lexicographic order -- assumes checkpoint names sort in save
    order (e.g. zero-padded epoch numbers); confirm against the training script.

    Raises:
        FileNotFoundError: if the fold directory holds no ``*.pth`` file.
    """
    pattern = os.path.join(MODEL_DIR, MODEL_ARC, f'{fold}_fold', '*.pth')
    candidates = sorted(glob(pattern))
    if not candidates:
        raise FileNotFoundError(f'No checkpoint matching {pattern}')
    return candidates[-1]


# One state_dict per fold (folds are numbered from 1). map_location lets
# GPU-saved checkpoints load when `device` fell back to the CPU.
states = [torch.load(_latest_checkpoint(k), map_location=device) for k in range(1, NUM_FOLD + 1)]

In [None]:
# Test-time augmentation: horizontal flip only. Further ttach transforms such as
# tta.VerticalFlip() or tta.Multiply(factors=[0.9, 1, 1.1]) can be appended to the list.
transforms = tta.Compose([tta.HorizontalFlip()])

In [None]:
# Fold-ensemble inference with TTA.
# The wrapper keeps a reference to `model`, so loading a new fold's weights into
# `model` is seen by the wrapper: build it (and set eval mode) once, not per batch.
model.eval()
tta_model = tta.ClassificationTTAWrapper(model, transforms).to(device).eval()

probs = []  # NOTE: collects predicted class *indices* per batch, not probabilities
save_ = []  # collects fold-averaged logits per batch
with torch.no_grad():
    for images in tqdm(test_loader):
        images = images.to(device)
        fold_logits = []
        for state in states:
            # load_state_dict does not change train/eval mode, so eval() above still holds.
            model.load_state_dict(state)
            fold_logits.append(tta_model(images).cpu().numpy())
        # Average the raw outputs across folds, then take the top class.
        batch_mean = np.mean(fold_logits, axis=0)
        save_.append(batch_mean)
        probs.append(batch_mean.argmax(-1))
save_ = np.concatenate(save_)
probs = np.concatenate(probs)

In [None]:
# Submission template; predictions are written into it positionally further down.
df = pd.read_csv(filepath_or_buffer='./test_answer_sample_.csv')

In [None]:
# Number of test images predicted (should match the template's row count).
probs.shape[0]

In [None]:
# Shape of the fold-averaged logits array.
np.shape(save_)

In [None]:
# Persist the averaged logits (e.g. for later blending across architectures).
logits_path = f'./{MODEL_ARC}.npy'
np.save(logits_path, save_)

In [None]:
# Write predicted class indices into the template, row by row in loader order.
df = df.assign(**{'answer value': probs})

In [None]:
submission_path = f'submission_{MODEL_ARC}.csv'
df.to_csv(submission_path, index=False)