In [None]:
!pip install /kaggle/input/timmwhl/timm-0.4.5-py3-none-any.whl

## imports

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import random
import cv2

import timm
import torch
import torch.optim as optim
from torch.optim import Adam
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision

from tqdm.notebook import tqdm

In [None]:
# Initialization
# Locations of the competition CSV files and image folders (relative to the notebook's working directory).
train_csv_loc= '../input/plant-pathology-2021-fgvc8/train.csv' 
test_csv_loc='../input/plant-pathology-2021-fgvc8/sample_submission.csv'
test_image_loc = '../input/plant-pathology-2021-fgvc8/test_images'
train_image_loc = '../input/plant-pathology-2021-fgvc8/train_images'

# data_csv: training table; test_csv: the sample-submission table, which is later
# passed to inference() and receives the predicted 'labels' column.
data_csv = pd.read_csv(train_csv_loc)
test_csv = pd.read_csv(test_csv_loc)
# First 200 training rows; not referenced again in the cells visible here.
sub_csv = data_csv[:200]

# Side length (pixels) that data_transform resizes every image to.
image_size = 288
# NOTE(review): the network below is built with the literal 'resnet200d' rather than this variable.
model_name = 'resnet200d'
# NOTE(review): presumably meant for seed_everything(); no call to it appears in the visible cells.
seed = 719
# NOTE(review): inference() chooses its own batch size and worker count and does not read these two.
batch_size = 32
num_workers = 0

## Processing the dataset csv

In [None]:
# GPU settings: run on CUDA when PyTorch can see a GPU, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and
    configures cuDNN for deterministic behaviour.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU; covers every visible CUDA device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the determinism requested above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

In [None]:
class parseDataset(Dataset):
    """Dataset that loads the images listed in a DataFrame from one directory.

    Column 0 of ``out_csv`` is used as the image file name. When ``test`` is
    False every remaining column of the row is returned as the target vector,
    cast to ``uint8``.

    Args:
        out_csv: DataFrame with the file name in the first column.
        image_loc: directory the file names are joined onto.
        transform: optional callable applied to the RGB image array.
        test: when True, ``__getitem__`` returns only the image.
    """

    def __init__(self, out_csv, image_loc, transform=None, test=False):
        self.out_csv = out_csv
        self.image_loc = image_loc
        self.transform = transform
        self.test = test

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.out_csv)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            The (optionally transformed) image when ``test`` is True,
            otherwise the list ``[image, categories]``.

        Raises:
            FileNotFoundError: if the image cannot be read from disk.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.image_loc,
                                self.out_csv.iloc[idx, 0])
        image = cv2.imread(img_name)
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque error from cvtColor.
        if image is None:
            raise FileNotFoundError(f"could not read image: {img_name}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)

        if self.test:
            return image

        categories = np.array(self.out_csv.iloc[idx, 1:]).astype(np.uint8)
        return [image, categories]

In [None]:
# Preprocessing pipeline applied to each image: convert to a tensor, resize to
# an image_size x image_size square, then normalise per channel. The mean/std
# values match the widely used ImageNet channel statistics.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((image_size, image_size)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
data_transform = transforms.Compose(_preprocess_steps)

## define model

In [None]:
class ResNet200D(nn.Module):
    """timm backbone with a small two-layer head producing 6 outputs.

    The backbone's own pooling and classifier are replaced by identities so
    that it emits a feature map; this module then applies global average
    pooling and its own head. All backbone parameters are frozen; only the
    head keeps ``requires_grad=True``.

    Args:
        model_name: architecture name handed to ``timm.create_model``.
        pretrained_path: state-dict file loaded into the backbone before its
            classifier is removed. Pass ``None`` to skip this step, e.g. when
            a full checkpoint is loaded into the module afterwards.
    """

    def __init__(self, model_name='resnet200d',
                 pretrained_path='../input/resnet200dpretrainedweights/resnet200d_ra2-bdba9bf9.pth'):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=False)

        # load pretrained weights
        if pretrained_path is not None:
            # map_location keeps the load working on CPU-only machines even if
            # the file was saved from GPU tensors.
            backbone_state = torch.load(pretrained_path, map_location='cpu')
            self.model.load_state_dict(backbone_state)

        # Read the classifier width before the classifier is removed.
        n_features = self.model.fc.in_features
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.Dropout(p=0.2),
            nn.Linear(256, 6),
        )

        # Freeze the backbone; the head defined above is not part of self.model.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return raw (pre-sigmoid) scores of shape ``(batch, 6)``."""
        bs = x.size(0)
        features = self.model(x)
        pooled_features = self.pooling(features).view(bs, -1)
        output = self.fc(pooled_features)
        return output

In [None]:
# Build the classifier and place it on the selected device in one step
# (nn.Module.to returns the module itself).
net = ResNet200D(model_name='resnet200d').to(device)

In [None]:
LAST_WEIGHT_PATH = '../input/plantpathweightsresnet288x288/resnet200d_fold_0_epoch_50.pth'
# The checkpoint is always deserialised onto the CPU; load_state_dict then
# copies the values into net's existing parameters, wherever they live. Both
# branches of the former device check did exactly this, so no check is needed.
net.load_state_dict(torch.load(LAST_WEIGHT_PATH, map_location=torch.device('cpu')))

print("[INFO] weights loaded to " + device)

## Inference on the test set

In [None]:
# Output index (as a string key) -> class name.
labels = {
    '0': 'complex',
    '1': 'frog_eye_leaf_spot',
    '2': 'healthy',
    '3': 'powdery_mildew',
    '4': 'rust',
    '5': 'scab',
}
# Reverse lookup: class name -> integer output index.
inv_labels = {class_name: int(index) for index, class_name in labels.items()}

def get_labels(row, labels, ths):
    """Turn one row of per-class scores into a submission label string.

    A class is selected when its score is strictly greater than the threshold
    at the same position. The selection is then collapsed as follows:

    * nothing selected, or 'healthy' selected -> ``'healthy'``
    * otherwise, 'complex' selected -> ``'complex'``
    * otherwise -> the selected names joined by single spaces, in index order

    Args:
        row: iterable of scores, one per class.
        labels: mapping from the stringified class index to the class name.
        ths: per-class thresholds, indexable by class position.

    Returns:
        The label string for this row.

    Raises:
        ValueError: if ``row`` has more entries than ``ths``, or a selected
            index has no entry in ``labels``. Previously such rows were
            printed and returned unconverted, which would have written raw
            scores into the submission.
    """
    try:
        selected = [labels[str(i)] for i, score in enumerate(row) if score > ths[i]]
    except (IndexError, KeyError) as err:
        raise ValueError(
            f"cannot map scores to labels (row={row!r}): {err!r}"
        ) from err

    if not selected or 'healthy' in selected:
        return 'healthy'
    if 'complex' in selected:
        return 'complex'
    return ' '.join(selected)

In [None]:
# Loss object handed to inference() as an argument; the function body never reads it.
criterion = nn.BCEWithLogitsLoss()

def inference(in_csv, img_loc, device, net, criterion, max_batch_size=25, num_workers=2):
    """Predict per-class probabilities for every image listed in ``in_csv``.

    The network is run in evaluation mode with gradient tracking disabled, so
    dropout is inactive and no autograd graph is built. Its previous
    train/eval mode is restored before returning.

    Args:
        in_csv: DataFrame whose first column holds the image file names.
        img_loc: directory containing those images.
        device: device the image batches are moved to before the forward pass.
        net: model returning raw scores of shape ``(batch, 6)``.
        criterion: accepted for signature compatibility; not used.
        max_batch_size: upper bound on the batch size (the batch size is the
            smaller of this and the number of rows).
        num_workers: worker process count given to the DataLoader.

    Returns:
        ``float32`` array of shape ``(len(in_csv), 6)`` holding sigmoid
        probabilities, in the row order of ``in_csv``.
    """
    n_classes = 6

    # A batch size of 0 is rejected by DataLoader, so handle an empty frame here.
    if len(in_csv) == 0:
        return np.empty((0, n_classes), dtype=np.float32)

    batch_size = min(len(in_csv), max_batch_size)

    # parse dataset and create dataloader
    test_dataset = parseDataset(in_csv, img_loc, transform=data_transform, test=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)

    preds_list = []

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            # get output
            t = tqdm(test_dataloader, desc='testing: ', colour="#557f50")
            for images in t:
                images = images.to(device)
                outputs = net(images)
                preds_list.append(outputs.sigmoid().cpu().numpy())
    finally:
        net.train(was_training)

    # Concatenate along the batch axis: the last batch may be smaller than the
    # others, which np.array() cannot stack into a regular array.
    return np.concatenate(preds_list, axis=0).reshape((-1, n_classes))

# Run the full test-set inference `repeat` times and keep every pass.
repeat = 2
tmp_p_list = [
    inference(test_csv, test_image_loc, device, net, criterion)
    for _ in range(repeat)
]

# average over some inferences
p_list = np.stack(tmp_p_list).mean(axis=0)

# from experiments
ths = [0.21000000000000002, 0.27, 0.51, 0.25, 0.08, 0.45]
# One label string per test image, written into the submission frame.
test_csv['labels'] = [get_labels(probabilities, labels, ths) for probabilities in p_list]

In [None]:
# Preview the first rows of the submission frame instead of dumping all of it.
test_csv.head()

In [None]:
# Write the frame, including the predicted 'labels' column, to the submission file without the index column.
test_csv.to_csv("/kaggle/working/submission.csv", index=False)