In [None]:
import pandas as pd
import os
import torch
from utils.parse import parse_yaml
import json
import numpy as np
from models.resnet50v2_sn_model import Model
from tqdm import tqdm
from PIL import Image
from datetime import datetime
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Class-level tqdm setting, applied to every progress bar created later in the
# notebook; per tqdm's documentation a value of 0 disables its monitor thread.
setattr(tqdm, 'monitor_interval', 0)

# load config

In [None]:
# Parse the project YAML config; the 'whole' section configures the species
# model, with its logger and TensorBoard output switched off for inference.
config = parse_yaml()
whole_config = config['whole']
whole_config['logger'] = whole_config['use_tensorboard'] = False

# load species model

In [None]:
# Species-level model built from the 'whole' config; its argmax is later used
# to choose which per-species disease model scores an image.
species_model = Model(whole_config)
species_model.load(
    'multi_task_ckpt/species/2018Oct16-183710/33.pth'
)

# load apple model

In [None]:
# Apple disease model: logger and TensorBoard disabled before construction.
apple_config = config['apple']
apple_config['logger'] = apple_config['use_tensorboard'] = False
apple_model = Model(apple_config)
apple_model.load(
    'multi_task_ckpt/apple/2018Oct17-110146/78.pth'
)

# load cherry model

In [None]:
# Cherry disease model: logger and TensorBoard disabled before construction.
cherry_config = config['cherry']
cherry_config['logger'] = cherry_config['use_tensorboard'] = False
cherry_model = Model(cherry_config)
cherry_model.load(
    'multi_task_ckpt/cherry/2018Oct17-133954/67.pth'
)

# load corn model

In [None]:
# Corn disease model: logger and TensorBoard disabled before construction.
corn_config = config['corn']
corn_config['logger'] = corn_config['use_tensorboard'] = False
corn_model = Model(corn_config)
corn_model.load(
    'multi_task_ckpt/corn/2018Oct17-142625/21.pth'
)

# load grape model

In [None]:
# Grape disease model: logger and TensorBoard disabled before construction.
grape_config = config['grape']
grape_config['logger'] = grape_config['use_tensorboard'] = False
grape_model = Model(grape_config)
grape_model.load(
    'multi_task_ckpt/grape/2018Oct17-223134/191.pth'
)

# load citrus model

In [None]:
# Citrus disease model: logger and TensorBoard disabled before construction.
citrus_config = config['citrus']
citrus_config['logger'] = citrus_config['use_tensorboard'] = False
citrus_model = Model(citrus_config)
citrus_model.load(
    'multi_task_ckpt/citrus/2018Oct17-223331/9.pth'
)

# load peach model

In [None]:
# Peach disease model: logger and TensorBoard disabled before construction.
peach_config = config['peach']
peach_config['logger'] = peach_config['use_tensorboard'] = False
peach_model = Model(peach_config)
peach_model.load(
    'multi_task_ckpt/peach/2018Oct18-095249/49.pth'
)

# load pepper model

In [None]:
# Pepper disease model: logger and TensorBoard disabled before construction.
pepper_config = config['pepper']
pepper_config['logger'] = pepper_config['use_tensorboard'] = False
pepper_model = Model(pepper_config)
pepper_model.load(
    'multi_task_ckpt/pepper/2018Oct18-143331/55.pth'
)

# load potato model

In [None]:
# Potato disease model: logger and TensorBoard disabled before construction.
potato_config = config['potato']
potato_config['logger'] = potato_config['use_tensorboard'] = False
potato_model = Model(potato_config)
potato_model.load(
    'multi_task_ckpt/potato/2018Oct18-133900/15.pth'
)

# load strawberry model

In [None]:
# Strawberry disease model: logger and TensorBoard disabled before construction.
strawberry_config = config['strawberry']
strawberry_config['logger'] = strawberry_config['use_tensorboard'] = False
strawberry_model = Model(strawberry_config)
strawberry_model.load(
    'multi_task_ckpt/strawberry/2018Oct18-134116/25.pth'
)

# load tomato model

In [None]:
# Tomato disease model: logger and TensorBoard disabled before construction.
tomato_config = config['tomato']
tomato_config['logger'] = tomato_config['use_tensorboard'] = False
tomato_model = Model(tomato_config)
tomato_model.load(
    'multi_task_ckpt/tomato/2018Oct18-152343/28.pth'
)

# validation 

In [None]:
# Validate the two-stage cascade: species_model's argmax selects a species,
# that species' disease model predicts a local class index, and the species'
# offset shifts it into the global `disease_class` label space used by the
# annotation file.

# Species index -> (disease model, offset of that species' first global class).
# Order and offsets follow the species numbering 0..9 used by this notebook;
# assumes species_model emits exactly these 10 classes in this order.
species_heads = [
    (apple_model, 0),
    (cherry_model, 6),
    (corn_model, 9),
    (grape_model, 17),
    (citrus_model, 24),
    (peach_model, 27),
    (pepper_model, 30),
    (potato_model, 33),
    (strawberry_model, 38),
    (tomato_model, 41),
]

# Built once and reused for every image instead of once per iteration.
preprocess = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.39883098, 0.48478258, 0.46141425],
        std=[0.21294028, 0.17201832, 0.1934495])
])

validate_data = pd.read_json(
    'data/ai_challenger_pdr2018_validationset_20180905/AgriculturalDisease_validationset/AgriculturalDisease_validation_annotations.json')
path = 'data/ai_challenger_pdr2018_validationset_20180905/AgriculturalDisease_validationset/images'
count = 0
with tqdm(total=len(validate_data)) as pbar:
    for row in validate_data.itertuples(index=False):
        with Image.open(os.path.join(path, row.image_id)) as img:
            # Normalize uses 3-channel statistics, so force RGB; RGB images are
            # left unchanged by convert().
            img_tensor = preprocess(img.convert('RGB'))[None, :, :, :]

        species_res = species_model.test(img_tensor).cpu().data.numpy()
        species_pred_class = int(np.argmax(species_res, axis=1)[0])

        # Dispatch on the predicted species *index* (not the raw score array)
        # and offset the predicted *index* (not the scores) into global labels.
        disease_model, offset = species_heads[species_pred_class]
        res = disease_model.test(img_tensor).cpu().data.numpy()
        pred_class = int(np.argmax(res, axis=1)[0]) + offset
        if pred_class == row.disease_class:
            count += 1

        pbar.update(1)
        pbar.set_description(
            "Processing {}, Prediction {}, True{}".format(row.image_id, pred_class, row.disease_class))
# Guard against an empty annotation file instead of raising ZeroDivisionError.
acc = count / len(validate_data) if len(validate_data) else float('nan')
print("Validation Accuracy: {}".format(acc))

# evaluation

In [None]:
# Test-set inference with the same species -> disease cascade used for
# validation; each record's disease_class is a global label id. (The original
# cell called an undefined `model`; only the cascade models are loaded here.)

# Species index -> (disease model, offset of that species' first global class);
# assumes species_model emits exactly these 10 classes in this order.
species_heads = [
    (apple_model, 0),
    (cherry_model, 6),
    (corn_model, 9),
    (grape_model, 17),
    (citrus_model, 24),
    (peach_model, 27),
    (pepper_model, 30),
    (potato_model, 33),
    (strawberry_model, 38),
    (tomato_model, 41),
]

# Built once and reused for every image.
preprocess = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.39883098, 0.48478258, 0.46141425],
        std=[0.21294028, 0.17201832, 0.1934495])
])

res_json = []
path = 'data/ai_challenger_pdr2018_testA_20180905/AgriculturalDisease_testA/images'
# os.listdir order is arbitrary; sort so the submission file is reproducible.
files = sorted(os.listdir(path))
with tqdm(total=len(files)) as pbar:
    for img_id in files:
        with Image.open(os.path.join(path, img_id)) as img:
            # Normalize uses 3-channel statistics, so force RGB.
            img_tensor = preprocess(img.convert('RGB'))[None, :, :, :]

        species_res = species_model.test(img_tensor).cpu().data.numpy()
        species_pred_class = int(np.argmax(species_res, axis=1)[0])

        disease_model, offset = species_heads[species_pred_class]
        res = disease_model.test(img_tensor).cpu().data.numpy()
        pred_class = int(np.argmax(res, axis=1)[0]) + offset

        res_json.append({"image_id": img_id, "disease_class": pred_class})
        pbar.update(1)
        pbar.set_description(
            "Processing {}, Prediction {}".format(img_id, pred_class))

# save json

In [None]:
# Write the collected submission records (res_json) to submit/<timestamp>.json,
# using the same timestamp format as the checkpoint directory names.
run_timestamp = datetime.now().strftime("%Y%b%d-%H%M%S")
# open(..., 'w') does not create parent directories; make sure submit/ exists.
os.makedirs('submit', exist_ok=True)
with open('submit/{}.json'.format(run_timestamp), 'w') as json_file:
    json.dump(res_json, json_file)