In [None]:
import numpy as np
import pandas as pd
import random
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import os
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

In [None]:
class Model(nn.Module):
    """ResNeXt-50 (32x4d) multi-label classifier.

    In training mode ``forward`` returns the BCE-with-logits loss; in eval mode it
    returns, for each image, the list of class indices whose sigmoid probability
    exceeds a threshold.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: number of output logits (one per label).
        """
        super().__init__()
        # torchvision's constructor randomly initialises the network; the final
        # ``fc`` layer is sized to ``num_classes``. No pretrained weights are
        # loaded here — trained weights are expected to be supplied afterwards
        # via ``load_state_dict``.
        self.backbone = models.resnext50_32x4d(num_classes=num_classes)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, images, labels=None, threshold=0.5):
        """Compute the training loss or thresholded multi-label predictions.

        Args:
            images: batch of input images for the backbone
                (presumably ``(batch, 3, H, W)`` — TODO confirm).
            labels: float targets with the same shape as the backbone's logits
                (presumably multi-hot); required in training mode, ignored in
                eval mode.
            threshold: a class is predicted when its sigmoid probability is
                strictly greater than this value (eval mode only).

        Returns:
            Training mode: the BCE-with-logits loss (mean-reduced scalar tensor).
            Eval mode: a list with one entry per image, each a list of predicted
            class indices in ascending order; an entry is empty when no class
            exceeds ``threshold``.
        """
        if self.training:
            logits = self.backbone(images)
            return self.loss(logits, labels)

        # Predictions never need gradients; skipping the autograd graph saves
        # memory during inference.
        with torch.no_grad():
            probs = self.backbone(images).sigmoid()
        return [torch.where(row > threshold)[0].tolist() for row in probs]

In [None]:
class Dataset:
    """Map-style dataset over the Plant Pathology 2021 (FGVC8) test images.

    ``dataset[i]`` returns ``(image_tensor, file_name)``. ``label_map``,
    ``index_to_name`` and ``num_classes`` describe the label space; the index
    order presumably must match the one used to train the checkpoint — keep in
    sync with training.
    """

    def __init__(self, root="/kaggle/input/plant-pathology-2021-fgvc8/test_images"):
        """
        Args:
            root: directory scanned recursively for image files; defaults to the
                Kaggle competition test-image directory.
        """
        self.root = root
        # label -> [Chinese description, class index]. English meanings of the
        # descriptions: complex = "difficult/complex", rust = "rust fungus",
        # scab = "scab, spot disease", frog_eye_leaf_spot = "frog-eye leaf spot",
        # healthy = "healthy", powdery_mildew = "powdery mildew".
        self.label_map = {
            'complex': ["疑难复杂", 0],
            'rust': ["锈菌，生锈", 1],
            'scab': ["疮痂病，斑点病", 2],
            'frog_eye_leaf_spot': ["青蛙眼叶斑", 3],
            'healthy': ["健康的", 4],
            'powdery_mildew': ["白粉病", 5]
        }
        # class index -> label name, used to decode predicted indices.
        self.index_to_name = {index: key for key, (_, index) in self.label_map.items()}
        self.num_classes = len(self.label_map)
        self.files = self._collect_files(self.root)

        self.trans = T.Compose([
            T.Resize((300, 300)),
            T.ToTensor(),
            # Scalar mean/std are applied to every channel; presumably this
            # mirrors the training preprocessing — confirm before changing.
            T.Normalize(mean=0.5, std=1.0)
        ])

    @staticmethod
    def _collect_files(root):
        """Return ``[full_path, file_name]`` pairs for every file under ``root``.

        The pairs are sorted by full path. ``os.walk`` order depends on the
        filesystem, so sorting keeps the item order reproducible.
        """
        files = []
        for dirname, _, filenames in os.walk(root):
            for filename in filenames:
                files.append([os.path.join(dirname, filename), filename])
        files.sort()
        return files

    def __getitem__(self, index):
        """Return ``(transformed_image_tensor, file_name)`` for item ``index``."""
        file, name = self.files[index]
        # Close the file handle promptly, and force 3 channels so grayscale,
        # RGBA or palette images still match the RGB backbone input.
        with Image.open(file) as img:
            image = img.convert("RGB")
        return self.trans(image), name

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

In [None]:
# Run the trained classifier over the test images and write submission.csv.
batch_size = 32
# Fall back to CPU so the notebook still runs on a kernel without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
test_dataset = Dataset()
# Inference needs no shuffling; a fixed order keeps the submission reproducible.
dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=device.startswith("cuda"),  # pinning only helps host->GPU copies
    num_workers=3,
)
model = Model(test_dataset.num_classes)
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load("../input/model1pth/030.pth", map_location="cpu")
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

# One [image_name, space-separated label names] row per test image.
all_predict = []
with torch.no_grad():
    for images, names in dataloader:
        batched_labels = model(images.to(device))
        for name, labels in zip(names, batched_labels):
            label_names = " ".join(test_dataset.index_to_name[index] for index in labels)
            all_predict.append([name, label_names])

data = pd.DataFrame(all_predict, columns=("image", "labels"))
data.to_csv("submission.csv", index=False)
data.head()