In [1]:
## Import Library ##
import warnings 
warnings.filterwarnings('ignore')
import os

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import cv2
from PIL import Image

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# image processing
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, Flip,
    IAAAdditiveGaussianNoise, Transpose
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

In [2]:
!pip install timm
import timm

In [3]:
## Configuration ##
class CFG:
    """Inference-time settings shared by the cells below."""

    # Data loading: images per batch and DataLoader worker processes.
    batch_size = 32
    num_workers = 4

    # Square input resolution handed to Resize: `size_swin` is used for the
    # Swin model, `size` for every other architecture (see main()).
    size = 324
    size_swin = 224

    # Output width of the linear head built in customModel.
    num_classes = 12

    # Fold indices whose checkpoints main() loads and inference() averages.
    trn_fold = [0, 1, 2, 3]

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
## Directory path ##
# Folder holding the per-fold weights. The trailing slash matters: main()
# builds checkpoint file names by plain string concatenation onto this value.
ckpt_path = "../input/d/minjuro/plant-seedlings-cls-checkpoint/checkpoint_swin/"

# Folder containing the test images that TestDataset reads from.
TEST_PATH = "../input/plant-seedlings-classification/test"

In [5]:
# Lookup from a predicted class index (the argmax over the averaged
# probabilities) to the species name written into the submission file.
# The position of each name in the list is its class index.
convert_classes = dict(enumerate([
    'Black-grass',
    'Charlock',
    'Cleavers',
    'Common Chickweed',
    'Common wheat',
    'Fat Hen',
    'Loose Silky-bent',
    'Maize',
    'Scentless Mayweed',
    'Shepherds Purse',
    'Small-flowered Cranesbill',
    'Sugar beet',
]))

In [6]:
## load test dataset ##
pathToTestData='../input/plant-seedlings-classification/test'

# Gather every file name first and build the frame once. DataFrame.append was
# removed in pandas 2.0, and appending row by row copies the whole frame on
# each iteration.
# Only the bare file name is kept; TestDataset joins it with TEST_PATH, so the
# images are expected to sit directly inside the test folder.
test_file_names = [
    filename
    for _dirname, _subdirs, filenames in os.walk(pathToTestData)
    for filename in filenames
]

# dtype=object keeps the 'file' column string-like even when nothing is found.
test = pd.DataFrame({'file': test_file_names}, dtype=object)
test.head(3)

In [7]:
class TestDataset(Dataset):
    """Dataset that yields unlabeled test images for inference.

    Args:
        df: DataFrame with a 'file' column holding image file names.
        transform: Optional callable invoked as ``transform(image=...)`` whose
            result is indexed with ``['image']`` (the albumentations calling
            convention). With ``None`` the raw RGB array is returned.
        root: Directory the file names are relative to. ``None`` (the default)
            falls back to the module-level ``TEST_PATH``, looked up when an
            item is accessed.
    """

    def __init__(self, df, transform=None, root=None):
        self.df = df
        self.file_names = df['file'].values
        self.transform = transform
        self.root = root

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if one was given.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        file_name = self.file_names[idx]
        root = TEST_PATH if self.root is None else self.root
        file_path = f'{root}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image

In [8]:
## Transform ##
def get_transforms(*, size, data):
    """Build the deterministic preprocessing pipeline for a data split.

    Args:
        size: Target height and width passed to ``Resize``.
        data: Split name, ``'train'`` or ``'valid'``. Both splits currently
            produce the same pipeline (no augmentation is applied).

    Returns:
        A ``Compose`` of ``Resize`` -> ``Normalize`` -> ``ToTensorV2``.

    Raises:
        ValueError: if ``data`` is not a recognized split. Previously an
            unknown split silently returned ``None``, which made the dataset
            skip preprocessing altogether.
    """
    if data not in ('train', 'valid'):
        raise ValueError(
            f"Unknown data split {data!r}; expected 'train' or 'valid'."
        )

    return Compose([
        Resize(size, size),
        # Per-channel mean/std: the commonly used ImageNet statistics.
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])

In [9]:
## Model ##
class customModel(nn.Module):
    """A timm backbone followed by a linear layer with CFG.num_classes outputs.

    The attribute names ``model`` and ``fc_layer`` are the prefixes of this
    module's state-dict keys, so they must stay as they are for saved
    checkpoints to load.
    """

    def __init__(self, model_name, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # The head consumes a 1000-wide vector from the backbone.
        # NOTE(review): presumably the backbone's default 1000-class output; confirm.
        head = nn.Linear(1000, CFG.num_classes)
        self.model = backbone
        self.fc_layer = head

    def forward(self, x):
        """Pass ``x`` through the backbone, then through the linear head."""
        return self.fc_layer(self.model(x))

In [10]:
## inf function ##
def inference(model, states, test_loader, device):
    """Predict with an ensemble by averaging softmax outputs over checkpoints.

    For every batch, each checkpoint in ``states`` is loaded into ``model`` in
    turn, the model is run in eval mode without gradients, and the softmax
    (over dim 1) of its output is collected. The per-checkpoint probabilities
    are then averaged for that batch.

    Args:
        model: Module that the checkpoint weights are loaded into.
        states: Iterable of dicts whose ``'model'`` entry is a state dict.
        test_loader: Sized iterable yielding image batches (tensors).
        device: Device the model and each batch are moved to.

    Returns:
        NumPy array made by concatenating the averaged per-batch probabilities
        along the first axis.
    """
    model.to(device)
    batch_probs = []
    progress = tqdm(test_loader, total=len(test_loader))
    for images in progress:
        images = images.to(device)
        per_checkpoint = []
        for checkpoint in states:
            model.load_state_dict(checkpoint['model'])
            model.eval()
            with torch.no_grad():
                logits = model(images)
            per_checkpoint.append(logits.softmax(1).to('cpu').numpy())
        # Mean across checkpoints (axis 0 of the stacked list).
        batch_probs.append(np.mean(per_checkpoint, axis=0))
    return np.concatenate(batch_probs)

In [11]:
def main(ckpt_path, mod_name):
    """Run fold-ensembled inference over the test set for one architecture.

    Args:
        ckpt_path: Path prefix (a directory with trailing slash) to which
            ``{mod_name}_fold{fold}_best.pth`` is appended for each fold in
            ``CFG.trn_fold``.
        mod_name: timm model name; also the checkpoint file-name prefix.

    Returns:
        The averaged class probabilities returned by ``inference``.
    """
    model = customModel(mod_name, pretrained=False)

    # map_location=device lets checkpoints that were saved on a GPU load on a
    # CPU-only machine, matching the CPU fallback used when `device` is chosen.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    states = [
        torch.load(ckpt_path+f'{mod_name}_fold{fold}_best.pth', map_location=device)
        for fold in CFG.trn_fold
    ]

    # The Swin model is fed CFG.size_swin inputs; every other model CFG.size.
    if mod_name == "swin_base_patch4_window7_224":
        image_size = CFG.size_swin
    else:
        image_size = CFG.size
    test_dataset = TestDataset(test, transform=get_transforms(size=image_size, data='valid'))

    # shuffle=False keeps predictions aligned with the row order of `test`.
    test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False,
                             num_workers=CFG.num_workers, pin_memory=True)
    predictions = inference(model, states, test_loader, device)
    return predictions

In [12]:
# Ensemble the Swin-Base fold checkpoints over the test set; `result` holds
# the averaged class probabilities that the submission cell takes argmax of.
result = main(
    ckpt_path=ckpt_path,
    mod_name="swin_base_patch4_window7_224",
)

In [13]:
## Submission File ##
# Translate every row's most probable class index to its species name and
# assign the whole column at once. The previous element-wise
# `test.species[i] = ...` was a chained assignment that wrote strings into an
# integer column one cell at a time; it does not update `test` when pandas
# Copy-on-Write is enabled and relied on the index being 0..n-1.
predicted_class_ids = result.argmax(1)
test['species'] = [convert_classes[int(class_id)] for class_id in predicted_class_ids]
test.to_csv('./submission.csv', index=False)