In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

from tqdm import tqdm
from glob import glob
import os
import json 
import timm
import torch
from torch import nn
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW

from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations as A

In [2]:
# Class names are the sub-directory names of the PlantVillage root.
# os.listdir() returns entries in arbitrary, platform-dependent order, so the
# listing is sorted explicitly: a saved checkpoint is only meaningful together
# with the exact name <-> index mapping it was trained with.
# Case-insensitive ordering is used because it reproduces the index assignment
# visible in the recorded output of this cell (TODO confirm for the entries
# that are truncated in that output).
data_root = './data/public/PlantVillage'
village_data = sorted(
    (
        entry for entry in os.listdir(data_root)
        # Skip stray files (e.g. OS metadata) so they do not become classes.
        if os.path.isdir(os.path.join(data_root, entry))
    ),
    key=str.lower,
)

# NOTE: the names are historical and read "backwards":
#   label_encoder : class index -> class (directory) name
#   label_decoder : class (directory) name -> class index
label_encoder = dict(enumerate(village_data))
label_decoder = {val: key for key, val in label_encoder.items()}

display(label_decoder)
display(label_encoder)

{'Apple___Apple_scab': 0,
 'Apple___Black_rot': 1,
 'Apple___Cedar_apple_rust': 2,
 'Apple___healthy': 3,
 'Blueberry___healthy': 4,
 'Cherry_(including_sour)___healthy': 5,
 'Cherry_(including_sour)___Powdery_mildew': 6,
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': 7,
 'Corn_(maize)___Common_rust_': 8,
 'Corn_(maize)___healthy': 9,
 'Corn_(maize)___Northern_Leaf_Blight': 10,
 'Grape___Black_rot': 11,
 'Grape___Esca_(Black_Measles)': 12,
 'Grape___healthy': 13,
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': 14,
 'Orange___Haunglongbing_(Citrus_greening)': 15,
 'Peach___Bacterial_spot': 16,
 'Peach___healthy': 17,
 'Pepper,_bell___Bacterial_spot': 18,
 'Pepper,_bell___healthy': 19,
 'Potato___Early_blight': 20,
 'Potato___healthy': 21,
 'Potato___Late_blight': 22,
 'Raspberry___healthy': 23,
 'Soybean___healthy': 24,
 'Squash___Powdery_mildew': 25,
 'Strawberry___healthy': 26,
 'Strawberry___Leaf_scorch': 27,
 'Tomato___Bacterial_spot': 28,
 'Tomato___Early_blight': 29,
 'Toma

{0: 'Apple___Apple_scab',
 1: 'Apple___Black_rot',
 2: 'Apple___Cedar_apple_rust',
 3: 'Apple___healthy',
 4: 'Blueberry___healthy',
 5: 'Cherry_(including_sour)___healthy',
 6: 'Cherry_(including_sour)___Powdery_mildew',
 7: 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 8: 'Corn_(maize)___Common_rust_',
 9: 'Corn_(maize)___healthy',
 10: 'Corn_(maize)___Northern_Leaf_Blight',
 11: 'Grape___Black_rot',
 12: 'Grape___Esca_(Black_Measles)',
 13: 'Grape___healthy',
 14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 15: 'Orange___Haunglongbing_(Citrus_greening)',
 16: 'Peach___Bacterial_spot',
 17: 'Peach___healthy',
 18: 'Pepper,_bell___Bacterial_spot',
 19: 'Pepper,_bell___healthy',
 20: 'Potato___Early_blight',
 21: 'Potato___healthy',
 22: 'Potato___Late_blight',
 23: 'Raspberry___healthy',
 24: 'Soybean___healthy',
 25: 'Squash___Powdery_mildew',
 26: 'Strawberry___healthy',
 27: 'Strawberry___Leaf_scorch',
 28: 'Tomato___Bacterial_spot',
 29: 'Tomato___Early_blight',
 30: '

In [3]:
def rand_bbox(size, lam):
    """Sample a random rectangular patch for CutMix.

    Args:
        size: Shape of the image batch; only ``size[2]`` and ``size[3]`` (the
            two spatial extents) are read.
        lam: Mixing coefficient. The patch covers roughly ``1 - lam`` of the
            image area before clipping. Values outside ``[0, 1]`` are clamped.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: patch bounds where ``bbx*`` index
        dimension 2 and ``bby*`` index dimension 3 of the batch. The patch is
        centred on a uniformly sampled point and clipped to the image, so it
        may be smaller than requested (or empty when ``lam`` is 1).
    """
    # The names W/H follow the original code: W is the extent of dimension 2
    # and H the extent of dimension 3. Callers slice with the same convention.
    W = size[2]
    H = size[3]

    # Clamp so that a slightly out-of-range lam cannot produce sqrt(<0) = nan.
    cut_rat = np.sqrt(1. - np.clip(lam, 0., 1.))
    # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the built-in
    # int performs the same truncation.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Patch centre, sampled uniformly over the image.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

In [4]:
import timm

class DeiT(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained timm model.

    Args:
        model_name_1: timm model identifier passed to ``timm.create_model``.
        n_classes: Forwarded to timm as ``num_classes``.
    """

    def __init__(self, model_name_1, n_classes):
        super().__init__()
        self.model = timm.create_model(
            model_name_1,
            pretrained=True,
            num_classes=n_classes,
        )

    def forward(self, inputs):
        """Return the wrapped model's output for ``inputs``."""
        return self.model(inputs)
    
class EffiV2S(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained timm model.

    Args:
        model_name_2: timm model identifier passed to ``timm.create_model``.
        n_classes: Forwarded to timm as ``num_classes``.
    """

    def __init__(self, model_name_2, n_classes):
        super().__init__()
        self.model = timm.create_model(
            model_name_2,
            pretrained=True,
            num_classes=n_classes,
        )

    def forward(self, inputs):
        """Return the wrapped model's output for ``inputs``."""
        return self.model(inputs)
    
class total_model(nn.Module):
    """Two-branch ensemble that averages the outputs of both branches.

    Args:
        model_name_1: timm identifier used for the ``effi`` branch.
        model_name_2: timm identifier used for the ``DeiT`` branch.
        n_classes: Number of output classes for both branches.

    Note:
        The names are passed positionally, so the parameter names inside the
        ``EffiV2S`` / ``DeiT`` wrappers (``model_name_2`` / ``model_name_1``)
        do not line up with the names used here; only the position matters.
    """

    def __init__(self, model_name_1, model_name_2, n_classes):
        super().__init__()
        # Registration order is kept (effi first, then DeiT) so parameter and
        # state_dict ordering stay the same.
        self.effi = EffiV2S(model_name_1, n_classes)
        self.DeiT = DeiT(model_name_2, n_classes)

    def forward(self, input_224, input_288):
        """Average the two branch outputs.

        ``input_288`` is fed to the ``effi`` branch and ``input_224`` to the
        ``DeiT`` branch.
        """
        effi_out = self.effi(input_288)
        deit_out = self.DeiT(input_224)
        return (effi_out + deit_out) / 2

In [5]:
# --- Run configuration -------------------------------------------------------
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without CUDA instead of failing
# at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 8
n_classes = 38                               # number of PlantVillage class folders
model_name_1 = 'efficientnetv2_rw_s'         # timm id, fed the 288x288 input
model_name_2 = 'deit_small_patch16_224'      # timm id, fed the 224x224 input
learning_rate = 1e-4
epochs = 15
save_path = './model/public_vill_50k_pretrain_EffiDeit.pt'   # best-checkpoint file
num_early_stopping = 3                       # epochs without improvement before stopping

In [6]:
class VillageDataset(Dataset) :
    """PlantVillage images served at two resolutions for the two-branch model.

    Args:
        files: Sequence of image paths laid out as ``.../<class name>/<file>``.
        transform_224: Callable invoked as ``transform(image=img)['image']``
            producing the image for the ``'deit'`` entry.
        transform_288: Same, producing the image for the ``'effi'`` entry.
        mode: Stored on the instance; not read by this class.

    Each item is a dict with:
        ``'effi'``  float32 tensor, channels first, scaled by 1/255
        ``'deit'``  float32 tensor, channels first, scaled by 1/255
        ``'label'`` int64 scalar tensor looked up in the module-level
                    ``label_decoder`` (class name -> index)
    """

    def __init__(self, files, transform_224, transform_288, mode='train') :
        super(VillageDataset, self).__init__()
        self.files = files
        self.mode = mode

        self.transform_224 = transform_224
        self.transform_288 = transform_288

    def __len__(self) :
        return len(self.files)

    @staticmethod
    def _class_name_from_path(file_path):
        """Return the name of the directory that directly contains ``file_path``.

        Both ``\\`` and ``/`` are accepted as separators, so paths produced by
        ``glob`` on Windows (backslashes, possibly mixed) and on POSIX work.
        """
        parts = file_path.replace('\\', '/').split('/')
        if len(parts) < 2:
            raise ValueError(f"cannot derive class directory from path: {file_path!r}")
        return parts[-2]

    def __getitem__(self, idx) :
        file_path = self.files[idx]

        label = label_decoder[self._class_name_from_path(file_path)]

        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of a cryptic cvtColor assertion.
            raise FileNotFoundError(f"could not read image: {file_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # The two transforms run independently on the same source image, so
        # any random augmentation inside them is sampled separately per view.
        img_224 = self.transform_224(image=img)['image']
        img_288 = self.transform_288(image=img)['image']

        # HWC -> CHW
        img_224 = img_224.transpose(2, 0, 1)
        img_288 = img_288.transpose(2, 0, 1)

        return {
            'effi' : torch.tensor(img_288, dtype=torch.float32) / 255.0,
            'deit' : torch.tensor(img_224, dtype=torch.float32) / 255.0,
            'label' : torch.tensor(label, dtype=torch.long),
        }

In [7]:
def _build_transforms(size, augment):
    """Build the albumentations pipeline for one input resolution.

    Args:
        size: Side length of the square output image.
        augment: When True, exactly one of Rotate / HorizontalFlip /
            VerticalFlip is selected per call (``OneOf`` with ``p=1``).
    """
    steps = [A.Resize(size, size)]
    if augment:
        steps.append(
            A.OneOf([
                A.Rotate(),
                A.HorizontalFlip(),
                A.VerticalFlip()
            ], p=1)
        )
    return A.Compose(steps)


train_transforms_224 = _build_transforms(224, augment=True)
train_transforms_288 = _build_transforms(288, augment=True)
val_transforms_224 = _build_transforms(224, augment=False)
val_transforms_288 = _build_transforms(288, augment=False)

# NOTE: the pattern only matches the upper-case `.JPG` extension on
# case-sensitive file systems.
all_files = glob('./data/public/PlantVillage/*/*.JPG')
print("total : ", len(all_files))

# The class name is the directory directly above the file. Normalise the
# separator first so this works for Windows ('\\') and POSIX ('/') paths alike.
label_list = [
    label_decoder[img_path.replace('\\', '/').split('/')[-2]]
    for img_path in all_files
]

# Fixed random_state: without it every run produces a different train/val
# split, so validation scores are not comparable between runs and a reloaded
# checkpoint may be "validated" on images it was trained on.
train, val = train_test_split(
    all_files,
    test_size=0.2,
    shuffle=True,
    stratify=label_list,
    random_state=42,
)
print("train : ", len(train))
print("val : ", len(val))

train_dataset = VillageDataset(train, train_transforms_224, train_transforms_288)
val_dataset = VillageDataset(val, val_transforms_224, val_transforms_288)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

total :  54303
train :  43442
val :  10861


In [8]:
# Build the two-branch ensemble and place it on the training device
# (nn.Module.to returns the module itself, so this can be chained).
model = total_model(model_name_1, model_name_2, n_classes).to(device)

# A single Adam optimiser over the parameters of both branches.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Cross-entropy on the averaged branch outputs.
criterion = nn.CrossEntropyLoss()

In [9]:
def accuracy_function(real, pred):
    """Macro-averaged F1 score of a batch (despite the function's name).

    Args:
        real: Tensor of ground-truth class indices.
        pred: Tensor of per-class scores; the predicted class is the argmax
            over dimension 1.

    Returns:
        The value returned by ``f1_score(..., average='macro')``.
    """
    predicted_classes = pred.argmax(dim=1).cpu()
    targets = real.cpu()
    return f1_score(targets, predicted_classes, average='macro')

def train_step(batch_item, training):
    """Run one optimisation step (with CutMix) or one evaluation step.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        batch_item: Dict with ``'effi'`` (288 input), ``'deit'`` (224 input)
            and ``'label'`` tensors, as produced by the data loader.
        training: ``True`` for a training step (CutMix, backward, optimizer
            step); anything else runs a gradient-free evaluation step.

    Returns:
        ``(loss, score)``: ``loss`` is a 0-dim tensor detached from the
        autograd graph, ``score`` is the batch macro-F1 from
        ``accuracy_function``. During training the score is computed against
        the original (un-mixed) labels.
    """
    img_effi = batch_item['effi'].to(device) # 288
    img_deit = batch_item['deit'].to(device) # 224
    label = batch_item['label'].to(device)

    if training is True:
        model.train()
        optimizer.zero_grad()

        # --- CutMix ----------------------------------------------------------
        # Only needed for training, so the coefficient is sampled here rather
        # than on every call.
        lam = np.random.beta(1.0, 1.0)
        rand_index = torch.randperm(img_deit.size()[0])
        target_a = label
        target_b = label[rand_index]

        # The patch is sampled in the coordinate system of the 224 input.
        bbx1, bby1, bbx2, bby2 = rand_bbox(img_deit.size(), lam)
        img_deit[:, :, bbx1:bbx2, bby1:bby2] = img_deit[rand_index, :, bbx1:bbx2, bby1:bby2]

        # The 288 input has a different resolution: rescale the patch so it
        # covers the same *fraction* of that image. Re-using the raw 224 pixel
        # coordinates would paste a proportionally smaller patch and make the
        # shared `lam` below wrong for this branch.
        scale_x = img_effi.size()[2] / img_deit.size()[2]
        scale_y = img_effi.size()[3] / img_deit.size()[3]
        ex1, ex2 = int(round(bbx1 * scale_x)), int(round(bbx2 * scale_x))
        ey1, ey2 = int(round(bby1 * scale_y)), int(round(bby2 * scale_y))
        img_effi[:, :, ex1:ex2, ey1:ey2] = img_effi[rand_index, :, ex1:ex2, ey1:ey2]

        # Both branches share one lam: the fraction of the image that was NOT
        # replaced, recomputed from the (clipped) patch actually pasted.
        lam = 1 - float((bbx2 - bbx1) * (bby2 - bby1)) / (img_deit.size()[-1] * img_deit.size()[-2])

        # NOTE(review): autocast is used without a torch.cuda.amp.GradScaler,
        # so small fp16 gradients can underflow; consider adding one.
        with torch.cuda.amp.autocast():
            output = model(img_deit, img_effi)
            loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
        loss.backward()
        optimizer.step()
        score = accuracy_function(label, output)
        # Detach so callers that accumulate the loss do not keep autograd
        # history alive across iterations.
        return loss.detach(), score
    else:
        model.eval()
        with torch.no_grad():
            output = model(img_deit, img_effi)
            loss = criterion(output, label)
        score = accuracy_function(label, output)
        return loss, score

In [None]:
# Per-epoch history (plain Python floats, ready for plotting).
loss_plot, val_loss_plot = [], []
metric_plot, val_metric_plot = [], []

# Create the checkpoint directory up front: torch.save does not create missing
# directories, and failing here is cheaper than failing after a full epoch.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

early_stopping = 0
for epoch in range(epochs):
    total_loss, total_val_loss = 0, 0
    total_acc, total_val_acc = 0, 0

    # ---- training ----
    # `total=` is required because enumerate() hides the loader's length from
    # tqdm; without it no progress percentage or ETA is shown.
    tqdm_dataset = tqdm(enumerate(train_loader), total=len(train_loader))
    training = True
    for batch, batch_item in tqdm_dataset:
        batch_loss, batch_acc = train_step(batch_item, training)
        # Convert to a Python float before accumulating: summing the tensors
        # themselves keeps device tensors (and possibly autograd history)
        # alive and leaves tensors in the history lists.
        batch_loss = batch_loss.item()
        total_loss += batch_loss
        total_acc += batch_acc

        tqdm_dataset.set_postfix({
            'Epoch': epoch + 1,
            'Loss': '{:06f}'.format(batch_loss),
            'Mean Loss' : '{:06f}'.format(total_loss/(batch+1)),
            'Mean F-1' : '{:06f}'.format(total_acc/(batch+1))
        })
    loss_plot.append(total_loss/(batch+1))
    metric_plot.append(total_acc/(batch+1))

    # ---- validation ----
    tqdm_dataset = tqdm(enumerate(val_loader), total=len(val_loader))
    training = False
    for batch, batch_item in tqdm_dataset:
        batch_loss, batch_acc = train_step(batch_item, training)
        batch_loss = batch_loss.item()
        total_val_loss += batch_loss
        total_val_acc += batch_acc

        tqdm_dataset.set_postfix({
            'Epoch': epoch + 1,
            'Val Loss': '{:06f}'.format(batch_loss),
            'Mean Val Loss' : '{:06f}'.format(total_val_loss/(batch+1)),
            'Mean Val F-1' : '{:06f}'.format(total_val_acc/(batch+1))
        })
    val_loss_plot.append(total_val_loss/(batch+1))
    val_metric_plot.append(total_val_acc/(batch+1))

    # ---- checkpointing / early stopping ----
    # The tracked metric is the mean of per-batch macro-F1 scores.
    if np.max(val_metric_plot) == val_metric_plot[-1]:
        # Best (or tied-best) epoch so far: save and reset the patience counter.
        torch.save(model.state_dict(), f'{save_path}')
        early_stopping = 0

    elif np.max(val_metric_plot) > val_metric_plot[-1]:
        early_stopping += 1
        print(f"Early Stopping Step : [{early_stopping} / {num_early_stopping}]")

    if early_stopping >= num_early_stopping :
        print("== Early Stop ==")
        break

5431it [21:19,  4.25it/s, Epoch=1, Loss=0.624950, Mean Loss=0.939954, Mean F-1=0.688476]
1358it [02:01, 11.22it/s, Epoch=1, Val Loss=0.009675, Mean Val Loss=0.047106, Mean Val F-1=0.980876]
5431it [23:46,  3.81it/s, Epoch=2, Loss=0.011112, Mean Loss=0.662797, Mean F-1=0.753075]
1358it [02:40,  8.44it/s, Epoch=2, Val Loss=0.006284, Mean Val Loss=0.036876, Mean Val F-1=0.986077]
5431it [28:51,  3.14it/s, Epoch=3, Loss=0.764464, Mean Loss=0.609953, Mean F-1=0.769400]
1358it [04:57,  4.57it/s, Epoch=3, Val Loss=0.007251, Mean Val Loss=0.094473, Mean Val F-1=0.990559]
2736it [16:40,  2.60it/s, Epoch=4, Loss=0.648836, Mean Loss=0.597982, Mean F-1=0.780314]