# Import

In [None]:
import os
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF

import PIL
import cv2

import timm

import albumentations as A
from albumentations import (
    Compose, OneOf, Normalize, CenterCrop, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, RandomRotate90, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout
    )
from albumentations.pytorch import ToTensorV2

from joblib import Parallel, delayed

from PIL import Image
from PIL import ImageFile

import warnings
warnings.filterwarnings("ignore")

import logging
import time
from contextlib import contextmanager

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Directory Settings

In [None]:
# Root folder of the competition data. The trailing slash matters: image
# paths are formed elsewhere by appending a file name directly to these
# strings.
DATA_DIR = '../input/plant-pathology-2021-fgvc8/'
TEST_DIR = f'{DATA_DIR}test_images/'
TRAIN_DIR = f'{DATA_DIR}train_images/'
# Despite the name, this is the path of the CSV file itself, not a directory.
TRAIN_CSV_DIR = f'{DATA_DIR}train.csv'
train_df = pd.read_csv(TRAIN_CSV_DIR)

In [None]:
TRAIN_DIR

In [None]:
train_df.shape

In [None]:
train_df.head(10)

In [None]:
# Replace the values of the `labels` column with integer class ids.
# The fitted encoder stays available as `le`, so ids can be mapped back
# with `le.inverse_transform`.
le = LabelEncoder()
train_df['labels'] = le.fit_transform(train_df['labels'])
train_df

In [None]:
train_df['labels'].unique()

In [None]:
tmp = train_df["image"].values
tmp

In [None]:
tmp2 = train_df["labels"].values
tmp2

In [None]:
len(train_df)

In [None]:
# Debug switch: when True, everything downstream works on a random 1% sample
# of the rows (presumably to keep notebook iterations short).
# NOTE(review): sample() is called without random_state, so the subset
# differs from run to run.
DEBUG = True
if DEBUG:
    train_df = train_df.sample(frac = 0.01).reset_index(drop = True)
    print(train_df.shape)

# Config

In [None]:
#../input/timm-pytorch-image-models/pytorch-image-models-master/timm/models/resnest.py
# Experiment settings.
# NOTE(review): in the notebook code visible here only CFG['img_size'] is
# read (by get_transforms). The remaining entries do not drive the cells
# shown: the model is built as 'efficientnet_b4', the training loop runs
# range(10) epochs, and the DataLoaders use hard-coded batch sizes.
CFG = {
    'fold_num': 5,
    'seed': 719,
    'model_arch': 'resnet18',
    'img_size': 224,
    'epochs': 3,
    'train_bs': 32,
    'valid_bs': 32,
    'lr': 1e-4,
    'num_workers': 4,
    'accum_iter': 1,
    'verbose_step': 1,
    'device': 'cuda:0',
    'used_folds':[0,2,3],
    'used_epochs': [7,8,9],
    'tta': 4
}

@contextmanager
def timer(name, logger=None, level=logging.DEBUG):
    """Report the wall-clock duration of the enclosed ``with`` block.

    Args:
        name: Label placed in square brackets in both messages.
        logger: Object with a ``log(level, msg)`` method. When ``None`` the
            messages are written with ``print`` instead.
        level: Level passed to ``logger.log``; unused when printing.

    The "done" message is emitted from a ``finally`` clause, so the elapsed
    time is still reported when the block raises. The exception itself
    propagates unchanged.
    """
    emit = print if logger is None else (lambda msg: logger.log(level, msg))
    started = time.time()
    emit(f'[{name}] start')
    try:
        yield
    finally:
        emit(f'[{name}] done in {time.time() - started:.0f} s')

# Names of the 12 label combinations. The length of this list sets the width
# of the model's output layer.
TARGET_COLS = [
    'healthy',
    'scab frog_eye_leaf_spot complex',
    'scab',
    'complex',
    'rust',
    'frog_eye_leaf_spot',
    'powdery_mildew',
    'scab frog_eye_leaf_spot',
    'frog_eye_leaf_spot complex',
    'rust frog_eye_leaf_spot',
    'powdery_mildew complex',
    'rust complex',
]

In [None]:
TARGET_COLS

# Split data

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the rows for validation. Seeding the split with CFG['seed']
# makes it reproducible, so validation losses are comparable between runs.
train, valid = train_test_split(train_df, test_size=0.1, random_state=CFG['seed'])
print(train.shape, valid.shape)

# Dataset

All datasets that represent a map from keys to data samples should subclass `torch.utils.data.Dataset`. All subclasses should overwrite `__getitem__()`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite `__len__()`, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

## CSVファイルからデータを読み込む

In [None]:
#pytorchのDatasetクラスを継承したクラスを作成する
class TrainDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs for the rows of a DataFrame.

    Args:
        train_df: DataFrame with an ``image`` column (file names, appended
            to ``TRAIN_DIR``) and a ``labels`` column (numeric class ids).
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, train_df, transform = None):
        self.train_df = train_df
        self.image_names = train_df["image"].values
        self.labels = train_df["labels"].values
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.train_df)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            Tuple ``(image, label)``: the RGB image (passed through
            ``transform`` when one was given) and the label as a float
            scalar tensor.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_path = TRAIN_DIR + self.image_names[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None; fail here with the path instead of inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        label = torch.tensor(self.labels[idx]).float()
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image, label

In [None]:
class TestDataset(Dataset):
    """Dataset yielding unlabeled images for inference.

    Args:
        train_df: DataFrame with an ``image`` column holding file names
            that are appended to ``TEST_DIR``. (The parameter keeps its
            historical name; it receives the test rows.)
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, train_df, transform = None):
        self.train_df = train_df
        self.image_names = train_df["image"].values
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.train_df)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Returns:
            The RGB image, passed through ``transform`` when one was given.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_path = TEST_DIR + self.image_names[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None; fail here with the path instead of inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image

In [None]:
def get_transforms(*, data):
    """Build the albumentations pipeline for one phase.

    Args:
        data: ``'train'`` or ``'valid'`` (keyword-only). Both pipelines
            resize to ``CFG['img_size']``, normalize and convert to a
            tensor; ``'train'`` additionally applies a random resized crop
            and a random horizontal flip.

    Returns:
        A ``Compose`` object to be called as ``pipeline(image=array)``.

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'valid'``.
            Without this check an unknown phase would fall through, return
            ``None`` and silently disable all preprocessing.
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    size = CFG['img_size']
    steps = [Resize(size, size)]
    if data == 'train':
        steps += [
            RandomResizedCrop(size, size, scale=(0.85, 1.0)),
            HorizontalFlip(p=0.5),
        ]
    steps += [
        # NOTE(review): the first mean value is 0.48, while the commonly used
        # ImageNet mean is 0.485 - confirm whether this is intentional.
        Normalize(
            mean=[0.48, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
    return Compose(steps)

In [None]:
# Build the training set from the `train` split only. Using the full
# `train_df` would also train on the rows held out in `valid`.
train_dataset = TrainDataset(train, transform = get_transforms(data = 'train'))
train_dataset[0]

In [None]:
# Visual sanity check: draw the first five training samples with their labels.
for i in range(5):
    image, label = train_dataset[i]
    # Only index 0 of the returned image is drawn (presumably the first
    # channel of the normalized tensor, so colors are not representative).
    plt.imshow(image[0])
    plt.title(f'label: {label}')
    plt.show() 

In [None]:
# Shuffled training batches; drop_last discards a final incomplete batch.
# NOTE(review): batch_size and num_workers are hard-coded here rather than
# taken from CFG['train_bs'] / CFG['num_workers'].
train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = True, num_workers = 2, drop_last = True)
train_dataset[0]

In [None]:
# Validate on the held-out `valid` split. Validating on `train_df` would score
# the model on the same rows it was trained on and make the loss meaningless
# for checkpoint selection.
valid_dataset = TrainDataset(valid, transform = get_transforms(data = 'valid'))
valid_loader = DataLoader(valid_dataset, batch_size = 32, shuffle = False)
valid_dataset[0]

In [None]:
from pprint import pprint
pprint(timm.list_models(pretrained = True))

In [None]:
class EfficientNetB4(nn.Module):
    """timm ``efficientnet_b4`` with a classifier sized to ``TARGET_COLS``."""

    def __init__(self):
        super().__init__()
        # pretrained = False: no pretrained weights are requested from timm.
        backbone = timm.create_model(model_name = 'efficientnet_b4', pretrained = False)
        # Swap the stock classification layer for one with an output per
        # entry of TARGET_COLS.
        backbone.classifier = nn.Linear(backbone.classifier.in_features, len(TARGET_COLS))
        self.efficientnetb4 = backbone

    def forward(self, x):
        """Return the raw classifier outputs (logits) for the batch ``x``."""
        return self.efficientnetb4(x)
    
# Create the network on the selected device and show which device that is.
model = EfficientNetB4().to(DEVICE)
print(DEVICE)

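With a batch size of 16 and 12 classes, the loss target must have the same shape as the model output, `[16, 12]`. The labels, however, arrive as a flat vector of class ids with shape `[16]`:
**tensor([3., 1., 9., 9., 3., 3., 1., 1., 9., 3., 0., 9., 1., 9., 1., 9.],
       device='cuda:0')**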

## Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 12]))
Target `y`: **tensor([3., 1., 9., 9., 3., 3., 1., 1., 9., 3., 0., 9., 1., 9., 1., 9.])**<br>
predは**tensor([[ 3.3222e-03, -6.3023e-02, -9.8400e-02,  5.8039e-02,  3.5130e-02,
         -3.7843e-02,  1.0272e-01,  5.2852e-02,  8.2996e-02, -6.2220e-02,
          1.7350e-02,  1.0375e-01],
        [ 4.6247e-02, -1.2377e-02, -2.6026e-02, -1.2311e-02,  8.9585e-02,
         -3.0924e-02,  3.8724e-02,  4.6246e-02, -4.5003e-02,  4.3646e-04,
         -6.2441e-02,  1.3156e-02],
        [-6.2680e-02,  2.9954e-02, -6.1780e-02, -4.3327e-02,  9.0033e-02,
         -7.6068e-02,  6.6231e-02, -2.4851e-04,  1.4230e-02, -2.5715e-03,
          2.7242e-02,  8.6976e-02],
        [-8.5043e-02, -1.7327e-02, -5.5150e-02,  1.6316e-02,  5.3273e-02,
         -8.4093e-02,  5.6221e-02,  2.5884e-02,  1.6135e-01, -3.7092e-03,
          2.5990e-02,  1.5509e-01],
        [ 2.3410e-02, -1.7971e-02,  1.9343e-03,  2.8661e-02,  1.5071e-01,
         -5.5160e-02,  1.0588e-01,  1.3510e-01, -8.2632e-02,  2.4116e-02,
         -2.6500e-02, -7.0092e-02],
        [ 3.7134e-02, -2.0467e-02, -1.3438e-02, -6.4493e-03,  8.6707e-02,
         -5.6956e-02,  8.9093e-02,  6.9034e-02, -4.7279e-02, -9.1909e-03,
         -3.9923e-02, -2.5610e-02],
        [ 8.2467e-03, -5.3051e-02, -5.5619e-02, -2.7187e-03,  5.5326e-02,
         -5.7907e-02,  3.9816e-02, -2.6238e-02,  5.2140e-02, -3.8826e-02,
         -3.4793e-03,  1.0258e-01],
        [ 3.2859e-02, -1.4498e-02,  1.2448e-02,  1.2810e-02,  1.0626e-01,
         -7.5030e-02,  8.7327e-02,  1.0942e-01, -5.5229e-02,  2.2528e-02,
         -3.3347e-02, -5.8272e-02],
        [-4.8386e-02, -1.8281e-02, -3.5381e-02, -2.1936e-02,  1.2283e-01,
         -7.6297e-02,  6.5367e-02,  2.9805e-03,  2.9422e-02, -5.3792e-02,
          2.4696e-03,  1.1347e-01],
        [ 3.9974e-02, -9.4892e-03, -2.2844e-02, -3.9816e-02,  1.1815e-01,
         -6.9634e-02,  6.7499e-02,  6.5805e-02, -9.0425e-02,  4.7688e-03,
         -4.0652e-02, -2.0793e-02],
        [-2.2975e+00, -1.0275e+00,  8.9758e-02,  4.4820e-02,  1.5578e-02,
          6.4363e-01,  3.9178e-01, -6.2643e-01,  1.2151e+00,  5.8945e-01,
          4.7740e-01,  1.2402e+00],
        [ 2.3138e-02, -1.8476e-02, -3.9248e-03,  1.2180e-02,  1.1278e-01,
         -6.4264e-02,  8.0766e-02,  1.0090e-01, -7.6118e-02, -3.9621e-03,
         -3.6371e-02, -1.6388e-02],
        [ 1.1618e-02, -6.2424e-02, -1.4368e-02, -1.5579e-02,  1.3855e-01,
         -5.9373e-02,  5.6734e-02,  5.6160e-02, -3.5171e-02, -2.4320e-02,
         -1.7437e-02,  6.4983e-02],
        [ 4.6946e-02, -2.5251e-02, -3.3545e-02, -3.9220e-02,  8.5006e-02,
         -3.9864e-02,  5.8532e-02,  7.7934e-02, -1.5106e-02,  1.3025e-02,
         -4.7164e-02,  1.7300e-02],
        [ 3.7800e-02, -3.4233e-04, -9.8017e-03, -5.2540e-03,  1.2891e-01,
         -8.4462e-02,  7.8746e-02,  1.0758e-01, -6.7346e-02,  9.4944e-03,
         -3.6443e-02, -3.8671e-02],
        [ 6.0334e-02,  2.2638e-02, -4.2752e-02, -3.1945e-03,  7.4823e-02,
         -7.1255e-02,  7.4871e-02,  5.8197e-02, -1.6085e-02,  1.0187e-02,
         -4.8222e-02, -1.2406e-02]]**

In [None]:
with timer('training'):

    model = EfficientNetB4().to(DEVICE)

    # Per-class binary cross-entropy computed on raw logits.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters())

    best_loss = np.inf
    # The epoch loop sits inside the `with` block so the timer covers the
    # whole training run, not just model construction.
    for epoch in range(10):
        model.train()
        for X, y in train_loader:
            optimizer.zero_grad()
            X = X.float().to(DEVICE)
            y = y.to(DEVICE)
            pred = model(X)
            # The dataset yields one class id per sample (shape [batch]),
            # while BCEWithLogitsLoss requires a target shaped like the
            # logits ([batch, n_classes]); one-hot encoding bridges the two.
            target = F.one_hot(y.long(), num_classes=pred.shape[1]).float()
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for X, y in valid_loader:
                X = X.float().to(DEVICE)
                y = y.to(DEVICE)
                pred = model(X)
                target = F.one_hot(y.long(), num_classes=pred.shape[1]).float()
                valid_loss += criterion(pred, target).item()
        # Mean loss per validation batch.
        valid_loss /= len(valid_loader)
        print(f"EPOCH:{epoch}, Loss:{valid_loss}")
        # Keep only the weights with the lowest validation loss so far.
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), "MyEfficientNetb4.pth")
            print("saved")


In [None]:
# `test_df` is not created in any visible earlier cell, so build it from the
# files found in TEST_DIR. TestDataset only needs an `image` column whose
# values are file names to append to TEST_DIR; sorting keeps the row order
# deterministic.
test_df = pd.DataFrame({'image': sorted(os.listdir(TEST_DIR))})
test_dataset = TestDataset(test_df, transform = get_transforms(data='valid'))
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle = False)

In [None]:
submit_preds = []

# Predict with the best checkpoint written during training instead of the
# last-epoch weights left in `model`. The existence check keeps this cell
# usable when no checkpoint has been saved.
best_weights_path = "MyEfficientNetb4.pth"
if os.path.exists(best_weights_path):
    model.load_state_dict(torch.load(best_weights_path, map_location=DEVICE))

model.eval()
with torch.no_grad():
    for X in test_loader:
        X = X.float().to(DEVICE)
        # Sigmoid turns each logit into an independent per-class score.
        submit_preds.append(model(X).sigmoid().to("cpu"))
# Stack the per-batch results into one (n_images, n_classes) array.
submit_preds = np.concatenate([p.numpy() for p in submit_preds], axis = 0)

In [None]:
# Column names come from TARGET_COLS; the previously referenced
# `TARGET_COLUMNS` is not defined anywhere in the notebook.
# NOTE(review): the training labels were integer-encoded with LabelEncoder,
# which orders classes by sorted name, whereas TARGET_COLS is not sorted.
# Confirm that column i really corresponds to model output i
# (`le.classes_` may be the correct source of names).
submit = pd.DataFrame(submit_preds, columns = TARGET_COLS)
submit.head(10)