In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

from tqdm import tqdm
from glob import glob
import os
import json 
import timm
import torch
from torch import nn
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
import torch.nn.functional as F

from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations as A

# Training Dataset 준비

In [2]:
# 변수 설명 csv 파일 참조
crop = {'1':'딸기','2':'토마토','3':'파프리카','4':'오이','5':'고추','6':'시설포도'}
disease = {'1':{'a1':'딸기잿빛곰팡이병','a2':'딸기흰가루병','b1':'냉해피해','b6':'다량원소결핍 (N)','b7':'다량원소결핍 (P)','b8':'다량원소결핍 (K)'},
           '2':{'a5':'토마토흰가루병','a6':'토마토잿빛곰팡이병','b2':'열과','b3':'칼슘결핍','b6':'다량원소결핍 (N)','b7':'다량원소결핍 (P)','b8':'다량원소결핍 (K)'},
           '3':{'a9':'파프리카흰가루병','a10':'파프리카잘록병','b3':'칼슘결핍','b6':'다량원소결핍 (N)','b7':'다량원소결핍 (P)','b8':'다량원소결핍 (K)'},
           '4':{'a3':'오이노균병','a4':'오이흰가루병','b1':'냉해피해','b6':'다량원소결핍 (N)','b7':'다량원소결핍 (P)','b8':'다량원소결핍 (K)'},
           '5':{'a7':'고추탄저병','a8':'고추흰가루병','b3':'칼슘결핍','b6':'다량원소결핍 (N)','b7':'다량원소결핍 (P)','b8':'다량원소결핍 (K)'},
           '6':{'a11':'시설포도탄저병','a12':'시설포도노균병','b4':'일소피해','b5':'축과병'}}
risk = {'1':'초기','2':'중기','3':'말기'}

In [3]:
label_description = {}
for key, value in disease.items():
    label_description[f'{key}_00_0'] = f'{crop[key]}_정상'
    for disease_code in value:
        for risk_code in risk:
            label = f'{key}_{disease_code}_{risk_code}'
            label_description[label] = f'{crop[key]}_{disease[key][disease_code]}_{risk[risk_code]}'
list(label_description.items())[:10]

[('1_00_0', '딸기_정상'),
 ('1_a1_1', '딸기_딸기잿빛곰팡이병_초기'),
 ('1_a1_2', '딸기_딸기잿빛곰팡이병_중기'),
 ('1_a1_3', '딸기_딸기잿빛곰팡이병_말기'),
 ('1_a2_1', '딸기_딸기흰가루병_초기'),
 ('1_a2_2', '딸기_딸기흰가루병_중기'),
 ('1_a2_3', '딸기_딸기흰가루병_말기'),
 ('1_b1_1', '딸기_냉해피해_초기'),
 ('1_b1_2', '딸기_냉해피해_중기'),
 ('1_b1_3', '딸기_냉해피해_말기')]

In [4]:
# ============= add
labels = pd.read_csv('./data/train.csv')

train_label_encoder = {}
label_cnt = 0
previous_label = '0_00_0'
for i, label in enumerate(tqdm(sorted(labels['label']))) :
    crop_val = label.split('_')[0] # crop
    disease_val = label.split('_')[1] # disease
    risk_val = label.split('_')[2] # risk
    
    tmp_label = f'{crop_val}_{disease_val}_{risk_val}'
    if previous_label != tmp_label :
        train_label_encoder[tmp_label] = label_cnt
        previous_label = tmp_label
        label_cnt += 1
        
train_label_decoder = {val : key for key, val in train_label_encoder.items()}
display(train_label_decoder)
display(train_label_encoder)

100%|██████████████████████████████████████████████████████████████████████████| 5767/5767 [00:00<00:00, 963725.69it/s]


{0: '1_00_0',
 1: '2_00_0',
 2: '2_a5_2',
 3: '3_00_0',
 4: '3_a9_1',
 5: '3_a9_2',
 6: '3_a9_3',
 7: '3_b3_1',
 8: '3_b6_1',
 9: '3_b7_1',
 10: '3_b8_1',
 11: '4_00_0',
 12: '5_00_0',
 13: '5_a7_2',
 14: '5_b6_1',
 15: '5_b7_1',
 16: '5_b8_1',
 17: '6_00_0',
 18: '6_a11_1',
 19: '6_a11_2',
 20: '6_a12_1',
 21: '6_a12_2',
 22: '6_b4_1',
 23: '6_b4_3',
 24: '6_b5_1'}

{'1_00_0': 0,
 '2_00_0': 1,
 '2_a5_2': 2,
 '3_00_0': 3,
 '3_a9_1': 4,
 '3_a9_2': 5,
 '3_a9_3': 6,
 '3_b3_1': 7,
 '3_b6_1': 8,
 '3_b7_1': 9,
 '3_b8_1': 10,
 '4_00_0': 11,
 '5_00_0': 12,
 '5_a7_2': 13,
 '5_b6_1': 14,
 '5_b7_1': 15,
 '5_b8_1': 16,
 '6_00_0': 17,
 '6_a11_1': 18,
 '6_a11_2': 19,
 '6_a12_1': 20,
 '6_a12_2': 21,
 '6_b4_1': 22,
 '6_b4_3': 23,
 '6_b5_1': 24}

# Custom Dataset 선언

In [5]:
class CustomDataset(Dataset):
    def __init__(self, files, transform_224, transform_288, mode='train'):
        self.mode = mode
        self.files = files
        self.label_encoder = train_label_encoder #label_encoder
        
        self.transform_224 = transform_224
        self.transform_288 = transform_288
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, i):
        file = self.files[i]
        file_name = file.split('\\')[-1]
        
        
        # image
        image_path = f'{file}/{file_name}.jpg'
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#         img = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_AREA)
#         img = img.astype(np.float32)/255

        img_224 = self.transform_224(image=img)['image']
        img_288 = self.transform_288(image=img)['image']
        
        img_224 = img_224.transpose(2, 0, 1)
        img_288 = img_288.transpose(2, 0, 1)
        
        if self.mode == 'train':
            json_path = f'{file}/{file_name}.json'
            with open(json_path, 'r') as f:
                json_file = json.load(f)
            
            crop = json_file['annotations']['crop']
            disease = json_file['annotations']['disease']
            risk = json_file['annotations']['risk']
            label = f'{crop}_{disease}_{risk}'
            
            return {
#                 'img' : torch.tensor(img, dtype=torch.float32),
                'effi' : torch.tensor(img_288, dtype=torch.float32) / 255.0,
                'deit' : torch.tensor(img_224, dtype=torch.float32) / 255.0,
                'label' : torch.tensor(self.label_encoder[label], dtype=torch.long)
            }
        else:
            return {
                'effi' : torch.tensor(img_288, dtype=torch.float32) / 255.0,
                'deit' : torch.tensor(img_224, dtype=torch.float32) / 255.0
#                 'img' : torch.tensor(img, dtype=torch.float32)
            }

# Model 선언

In [6]:
import timm

class DeiT(nn.Module):
    def __init__(self, model_name, n_classes):
        super(DeiT, self).__init__()
        self.model = timm.create_model(model_name, num_classes=n_classes, pretrained=True)
    
    def forward(self, inputs):
        output = self.model(inputs)
        return output
    
class EffiV2S(nn.Module):
    def __init__(self, model_name, n_classes):
        super(EffiV2S, self).__init__()
        self.model = timm.create_model(model_name, num_classes=n_classes, pretrained=True)
    
    def forward(self, inputs):
        output = self.model(inputs)
        return output
    
class total_model(nn.Module):
    def __init__(self, model_name_1, model_name_2, n_classes) :
        super(total_model, self).__init__()
        self.effi = EffiV2S(model_name_1, n_classes)
        self.DeiT = DeiT(model_name_2, n_classes)
        
    def forward(self, input_224, input_288) : 
        output1 = self.effi(input_288)
        output2 = self.DeiT(input_224)
        
        output = (output1 + output2) / 2
        return output
    

class EffiDeiT(nn.Module):
    def __init__(self, model_name_1, model_name_2, n_classes, test=False) :
        super(EffiDeiT, self).__init__()
        self.total_model = total_model(model_name_1, model_name_2, 38)
        
        if test == False :
            self.total_model.load_state_dict(torch.load(model_path, map_location=device))
            
        self.fc = nn.Linear(38, n_classes)
        
        
    def forward(self, input_224, input_288) :      
        output = self.total_model(input_224, input_288)
        output = self.fc(output)
        
        return output

# Label 선언

In [7]:
json_path = glob('./data/train/*/*.json')

labels = []
for path in tqdm(json_path) :
    json_file = json.load(open(path, 'r'))
    
    crop = json_file['annotations']['crop']
    disease = json_file['annotations']['disease']
    risk = json_file['annotations']['risk']
    
    label = f'{crop}_{disease}_{risk}'
    labels.append(train_label_encoder[label])

100%|████████████████████████████████████████████████████████████████████████████| 5767/5767 [00:01<00:00, 3001.72it/s]


# transform

In [8]:
train_transforms_224 = A.Compose([
                A.Resize(224 ,224),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

train_transforms_288 = A.Compose([
                A.Resize(288 ,288),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

val_transforms_224 = A.Compose([
    A.Resize(224,224)
])
val_transforms_288 = A.Compose([
    A.Resize(288,288)
])

# Cutmix

In [9]:
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)
 
    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

# Training method

In [10]:
def accuracy_function(real, pred):    
    real = real.cpu()
    pred = torch.argmax(pred, dim=1).cpu()
    score = f1_score(real, pred, average='macro')
    return score

def train_step(batch_item, training):
    img_effi = batch_item['effi'].to(device) # 288
    img_deit = batch_item['deit'].to(device) # 224
    label = batch_item['label'].to(device)

    lam = np.random.beta(1.0, 1.0)
    
    if training is True:
        model.train()
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # add - cutmix
            rand_index = torch.randperm(img_deit.size()[0])
            target_a = label
            target_b = label[rand_index]
            
            # 224 size 기준으로 이미지 crop
            bbx1, bby1, bbx2, bby2 = rand_bbox(img_deit.size(), lam)
            
            img_deit[:, :, bbx1:bbx2, bby1:bby2] = img_deit[rand_index, :, bbx1:bbx2, bby1:bby2]
            img_effi[:, :, bbx1:bbx2, bby1:bby2] = img_effi[rand_index, :, bbx1:bbx2, bby1:bby2]
            
            # lam 값은 공유
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img_deit.size()[-1] * img_deit.size()[-2]))
            
            output = model(img_deit, img_effi)
            loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
#             output = model(img, csv_feature)
#             loss = criterion(output, label) 
        loss.backward()
        optimizer.step()
        score = accuracy_function(label, output)
        return loss, score
    else:
        model.eval()
        with torch.no_grad():
            output = model(img_deit, img_effi)
            loss = criterion(output, label)
        score = accuracy_function(label, output)
        return loss, score

# HyperParameter

In [11]:
device = torch.device("cuda:0")#("cpu")
batch_size = 8
learning_rate = 1e-4
epochs = 15

model_name_1 = 'efficientnetv2_rw_s'
model_name_2 = 'deit_small_patch16_224'
n_classes = 25

model_path = './model/public_vill_50k_pretrain_EffiDeit.pt'
save_path = 'public_vill_50k_pretrain_EffiDeit.pt'

early_stopping_cnt = 4
fold_n = 4

# Call model

In [12]:
model = EffiDeiT(model_name_1, model_name_2, n_classes)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# k fold Training

In [16]:
data_list = glob('./data/train/*')
label_list = labels

kfold = StratifiedKFold(n_splits=fold_n, random_state=13, shuffle=True)

In [17]:
for k, (fold_train, fold_val) in enumerate(kfold.split(data_list, label_list), 1) :
    model = EffiDeiT(model_name_1, model_name_2, n_classes)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    
    train_data_list = []
    val_data_list = []
    
    for k_train in fold_train :
        train_data_list.append(data_list[k_train])
    
    for k_val in fold_val :
        val_data_list.append(data_list[k_val])
    
    print(f"\n\n\n===== k_fold : {k} / {fold_n} =====")
    train_dataset = CustomDataset(train_data_list, train_transforms_224, train_transforms_288)
    val_dataset = CustomDataset(val_data_list, val_transforms_224, val_transforms_288)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    loss_plot, val_loss_plot = [], []
    metric_plot, val_metric_plot = [], []
    
    early_stopping = 0
    for epoch in range(epochs):
        total_loss, total_val_loss = 0, 0
        total_acc, total_val_acc = 0, 0

        tqdm_dataset = tqdm(enumerate(train_dataloader))
        training = True
        for batch, batch_item in tqdm_dataset:
            batch_loss, batch_acc = train_step(batch_item, training)
            total_loss += batch_loss
            total_acc += batch_acc

            tqdm_dataset.set_postfix({
                'Epoch': epoch + 1,
                'Loss': '{:06f}'.format(batch_loss.item()),
                'Mean Loss' : '{:06f}'.format(total_loss/(batch+1)),
                'Mean F-1' : '{:06f}'.format(total_acc/(batch+1))
            })
        loss_plot.append(total_loss/(batch+1))
        metric_plot.append(total_acc/(batch+1))

        tqdm_dataset = tqdm(enumerate(val_dataloader))
        training = False
        for batch, batch_item in tqdm_dataset:
            batch_loss, batch_acc = train_step(batch_item, training)
            total_val_loss += batch_loss
            total_val_acc += batch_acc

            tqdm_dataset.set_postfix({
                'Epoch': epoch + 1,
                'Val Loss': '{:06f}'.format(batch_loss.item()),
                'Mean Val Loss' : '{:06f}'.format(total_val_loss/(batch+1)),
                'Mean Val F-1' : '{:06f}'.format(total_val_acc/(batch+1))
            })
        val_loss_plot.append(total_val_loss/(batch+1))
        val_metric_plot.append(total_val_acc/(batch+1))
        
        # scheduler
#         scheduler.step()
        
        if np.max(val_metric_plot) == val_metric_plot[-1]:
            torch.save(model.state_dict(), f'{k}_{save_path}')
            early_stopping = 0
    
        elif np.max(val_metric_plot) > val_metric_plot[-1]: 
            early_stopping += 1
            print(f"Early Stopping Step : [{early_stopping} / {early_stopping_cnt}]")

        if early_stopping == early_stopping_cnt :
            print("== Early Stop ==")
            break




===== k_fold : 1 / 4 =====


541it [02:25,  3.71it/s, Epoch=1, Loss=1.262334, Mean Loss=1.483260, Mean F-1=0.484079]
181it [00:20,  8.97it/s, Epoch=1, Val Loss=0.534476, Mean Val Loss=0.371048, Mean Val F-1=0.779719]
541it [02:21,  3.83it/s, Epoch=2, Loss=0.374705, Mean Loss=0.984043, Mean F-1=0.643270]
181it [00:20,  8.82it/s, Epoch=2, Val Loss=0.369761, Mean Val Loss=0.215696, Mean Val F-1=0.856405]
541it [02:25,  3.71it/s, Epoch=3, Loss=1.104398, Mean Loss=0.867648, Mean F-1=0.652309]
181it [00:22,  8.14it/s, Epoch=3, Val Loss=0.659879, Mean Val Loss=0.194741, Mean Val F-1=0.875375]
541it [02:36,  3.47it/s, Epoch=4, Loss=0.758780, Mean Loss=0.789214, Mean F-1=0.694373]
181it [00:23,  7.75it/s, Epoch=4, Val Loss=0.295669, Mean Val Loss=0.127029, Mean Val F-1=0.921111]
541it [02:51,  3.15it/s, Epoch=5, Loss=1.186151, Mean Loss=0.721706, Mean F-1=0.731193]
181it [00:26,  6.75it/s, Epoch=5, Val Loss=0.490410, Mean Val Loss=0.129279, Mean Val F-1=0.909031]


Early Stopping Step : [1 / 4


541it [02:56,  3.06it/s, Epoch=6, Loss=0.413052, Mean Loss=0.671426, Mean F-1=0.749833]
181it [00:27,  6.55it/s, Epoch=6, Val Loss=0.275321, Mean Val Loss=0.112037, Mean Val F-1=0.924627]
541it [03:01,  2.97it/s, Epoch=7, Loss=0.486192, Mean Loss=0.658108, Mean F-1=0.747335]
181it [00:29,  6.17it/s, Epoch=7, Val Loss=0.264688, Mean Val Loss=0.111929, Mean Val F-1=0.927304]
541it [03:07,  2.88it/s, Epoch=8, Loss=0.934950, Mean Loss=0.628171, Mean F-1=0.772970]
181it [00:30,  5.94it/s, Epoch=8, Val Loss=0.073486, Mean Val Loss=0.121832, Mean Val F-1=0.915695]


Early Stopping Step : [1 / 4


541it [03:13,  2.80it/s, Epoch=9, Loss=0.519669, Mean Loss=0.626554, Mean F-1=0.764980]
181it [00:32,  5.60it/s, Epoch=9, Val Loss=0.200498, Mean Val Loss=0.100779, Mean Val F-1=0.928080]
541it [03:21,  2.69it/s, Epoch=10, Loss=0.675605, Mean Loss=0.605398, Mean F-1=0.762400]
181it [00:32,  5.61it/s, Epoch=10, Val Loss=0.103660, Mean Val Loss=0.106921, Mean Val F-1=0.927504]


Early Stopping Step : [1 / 4


541it [03:29,  2.58it/s, Epoch=11, Loss=0.729499, Mean Loss=0.576031, Mean F-1=0.782970]
181it [00:37,  4.83it/s, Epoch=11, Val Loss=0.006523, Mean Val Loss=0.100206, Mean Val F-1=0.936547]
541it [03:34,  2.53it/s, Epoch=12, Loss=0.244528, Mean Loss=0.604836, Mean F-1=0.765667]
181it [00:39,  4.64it/s, Epoch=12, Val Loss=0.198404, Mean Val Loss=0.111016, Mean Val F-1=0.931671]


Early Stopping Step : [1 / 4


541it [03:44,  2.41it/s, Epoch=13, Loss=0.015268, Mean Loss=0.579740, Mean F-1=0.788381]
181it [00:42,  4.23it/s, Epoch=13, Val Loss=0.019683, Mean Val Loss=0.175373, Mean Val F-1=0.919309]


Early Stopping Step : [2 / 4


541it [03:52,  2.32it/s, Epoch=14, Loss=0.324875, Mean Loss=0.573758, Mean F-1=0.777259]
181it [00:44,  4.10it/s, Epoch=14, Val Loss=0.051260, Mean Val Loss=0.119307, Mean Val F-1=0.933193]


Early Stopping Step : [3 / 4


541it [03:54,  2.31it/s, Epoch=15, Loss=0.374604, Mean Loss=0.564092, Mean F-1=0.782154]
181it [00:47,  3.85it/s, Epoch=15, Val Loss=0.056598, Mean Val Loss=0.126276, Mean Val F-1=0.935431]


Early Stopping Step : [4 / 4
== Early Stop ==



===== k_fold : 2 / 4 =====


541it [02:53,  3.11it/s, Epoch=1, Loss=1.233305, Mean Loss=1.470769, Mean F-1=0.487548]
181it [00:25,  6.99it/s, Epoch=1, Val Loss=0.048926, Mean Val Loss=0.358544, Mean Val F-1=0.785379]
541it [02:50,  3.18it/s, Epoch=2, Loss=0.350313, Mean Loss=0.990728, Mean F-1=0.635618]
181it [00:25,  7.06it/s, Epoch=2, Val Loss=0.039179, Mean Val Loss=0.249132, Mean Val F-1=0.838644]
541it [02:52,  3.13it/s, Epoch=3, Loss=1.293963, Mean Loss=0.819541, Mean F-1=0.693467]
181it [00:26,  6.72it/s, Epoch=3, Val Loss=0.005963, Mean Val Loss=0.180652, Mean Val F-1=0.875679]
541it [02:56,  3.06it/s, Epoch=4, Loss=0.478967, Mean Loss=0.774001, Mean F-1=0.707493]
181it [00:24,  7.47it/s, Epoch=4, Val Loss=0.009182, Mean Val Loss=0.161686, Mean Val F-1=0.880047]
541it [02:41,  3.35it/s, Epoch=5, Loss=0.820817, Mean Loss=0.704911, Mean F-1=0.733950]
181it [00:24,  7.49it/s, Epoch=5, Val Loss=0.002666, Mean Val Loss=0.123866, Mean Val F-1=0.900634]
541it [02:41,  3.35it/s, Epoch=6, Loss=0.514521, Mean Loss=0

Early Stopping Step : [1 / 4


541it [02:46,  3.25it/s, Epoch=7, Loss=0.490658, Mean Loss=0.670987, Mean F-1=0.735606]
181it [00:26,  6.74it/s, Epoch=7, Val Loss=0.009444, Mean Val Loss=0.128065, Mean Val F-1=0.895962]


Early Stopping Step : [2 / 4


541it [02:51,  3.15it/s, Epoch=8, Loss=0.644825, Mean Loss=0.645003, Mean F-1=0.771374]
181it [00:28,  6.37it/s, Epoch=8, Val Loss=0.001613, Mean Val Loss=0.145521, Mean Val F-1=0.902049]
541it [02:56,  3.06it/s, Epoch=9, Loss=0.700701, Mean Loss=0.597609, Mean F-1=0.783557]
181it [00:29,  6.05it/s, Epoch=9, Val Loss=0.031579, Mean Val Loss=0.124879, Mean Val F-1=0.901232]


Early Stopping Step : [1 / 4


541it [03:01,  2.97it/s, Epoch=10, Loss=0.450081, Mean Loss=0.614080, Mean F-1=0.759407]
181it [00:34,  5.26it/s, Epoch=10, Val Loss=0.008846, Mean Val Loss=0.138348, Mean Val F-1=0.914851]
541it [03:07,  2.89it/s, Epoch=11, Loss=1.243132, Mean Loss=0.600131, Mean F-1=0.780875]
181it [00:36,  4.90it/s, Epoch=11, Val Loss=0.006395, Mean Val Loss=0.108517, Mean Val F-1=0.926798]
541it [03:13,  2.79it/s, Epoch=12, Loss=0.504172, Mean Loss=0.594912, Mean F-1=0.771280]
181it [00:38,  4.70it/s, Epoch=12, Val Loss=0.002203, Mean Val Loss=0.149165, Mean Val F-1=0.933560]
541it [03:19,  2.72it/s, Epoch=13, Loss=0.583256, Mean Loss=0.588367, Mean F-1=0.775559]
181it [00:42,  4.24it/s, Epoch=13, Val Loss=0.004583, Mean Val Loss=0.151618, Mean Val F-1=0.909546]


Early Stopping Step : [1 / 4


541it [03:25,  2.63it/s, Epoch=14, Loss=0.462105, Mean Loss=0.562251, Mean F-1=0.789381]
181it [00:42,  4.25it/s, Epoch=14, Val Loss=0.002372, Mean Val Loss=0.123883, Mean Val F-1=0.923813]


Early Stopping Step : [2 / 4


541it [03:29,  2.58it/s, Epoch=15, Loss=0.660171, Mean Loss=0.559112, Mean F-1=0.793107]
181it [00:44,  4.11it/s, Epoch=15, Val Loss=0.022026, Mean Val Loss=0.178286, Mean Val F-1=0.904915]


Early Stopping Step : [3 / 4



===== k_fold : 3 / 4 =====


541it [02:29,  3.62it/s, Epoch=1, Loss=0.652018, Mean Loss=1.433800, Mean F-1=0.506464]
181it [00:22,  8.17it/s, Epoch=1, Val Loss=0.012445, Mean Val Loss=0.393662, Mean Val F-1=0.766572]
541it [02:27,  3.67it/s, Epoch=2, Loss=0.216567, Mean Loss=0.975432, Mean F-1=0.643431]
181it [00:21,  8.30it/s, Epoch=2, Val Loss=0.009252, Mean Val Loss=0.305133, Mean Val F-1=0.829625]
541it [02:27,  3.66it/s, Epoch=3, Loss=0.645693, Mean Loss=0.830357, Mean F-1=0.675422]
181it [00:21,  8.34it/s, Epoch=3, Val Loss=0.014635, Mean Val Loss=1.883229, Mean Val F-1=0.836898] 
541it [02:29,  3.62it/s, Epoch=4, Loss=0.874368, Mean Loss=0.743409, Mean F-1=0.728040]
181it [00:22,  7.93it/s, Epoch=4, Val Loss=0.017069, Mean Val Loss=0.186955, Mean Val F-1=0.870965]
541it [02:35,  3.49it/s, Epoch=5, Loss=0.616026, Mean Loss=0.723690, Mean F-1=0.729123]
181it [00:24,  7.42it/s, Epoch=5, Val Loss=0.006131, Mean Val Loss=5.327231, Mean Val F-1=0.860966]  


Early Stopping Step : [1 / 4


541it [02:41,  3.34it/s, Epoch=6, Loss=0.699418, Mean Loss=0.683976, Mean F-1=0.743930]
181it [00:25,  7.01it/s, Epoch=6, Val Loss=0.016438, Mean Val Loss=0.400601, Mean Val F-1=0.890412] 
541it [02:50,  3.18it/s, Epoch=7, Loss=0.819117, Mean Loss=0.641112, Mean F-1=0.771144]
181it [00:27,  6.51it/s, Epoch=7, Val Loss=0.001976, Mean Val Loss=24.416039, Mean Val F-1=0.884098]   


Early Stopping Step : [1 / 4


541it [02:56,  3.07it/s, Epoch=8, Loss=0.492369, Mean Loss=0.614641, Mean F-1=0.746468]
181it [00:29,  6.17it/s, Epoch=8, Val Loss=0.003149, Mean Val Loss=1.735569, Mean Val F-1=0.898817]  
541it [03:03,  2.95it/s, Epoch=9, Loss=0.347127, Mean Loss=0.617854, Mean F-1=0.779960]
181it [00:36,  4.92it/s, Epoch=9, Val Loss=0.001965, Mean Val Loss=0.414169, Mean Val F-1=0.886996] 


Early Stopping Step : [1 / 4


541it [03:08,  2.87it/s, Epoch=10, Loss=0.324988, Mean Loss=0.593906, Mean F-1=0.774721]
181it [00:38,  4.75it/s, Epoch=10, Val Loss=0.002819, Mean Val Loss=0.196033, Mean Val F-1=0.902410]
541it [03:14,  2.79it/s, Epoch=11, Loss=0.647460, Mean Loss=0.594664, Mean F-1=0.776433]
181it [00:38,  4.69it/s, Epoch=11, Val Loss=0.005537, Mean Val Loss=0.333523, Mean Val F-1=0.905754] 
541it [03:19,  2.71it/s, Epoch=12, Loss=0.279971, Mean Loss=0.571283, Mean F-1=0.797479]
181it [00:40,  4.47it/s, Epoch=12, Val Loss=0.003217, Mean Val Loss=1.960877, Mean Val F-1=0.890987] 


Early Stopping Step : [1 / 4


541it [03:25,  2.64it/s, Epoch=13, Loss=0.800833, Mean Loss=0.574390, Mean F-1=0.773570]
181it [00:41,  4.32it/s, Epoch=13, Val Loss=0.001056, Mean Val Loss=0.357819, Mean Val F-1=0.918859]
541it [03:31,  2.55it/s, Epoch=14, Loss=0.125304, Mean Loss=0.565356, Mean F-1=0.773215]
181it [00:44,  4.08it/s, Epoch=14, Val Loss=0.003620, Mean Val Loss=3.182777, Mean Val F-1=0.901883]  


Early Stopping Step : [1 / 4


541it [03:42,  2.44it/s, Epoch=15, Loss=0.079399, Mean Loss=0.550638, Mean F-1=0.794100]
181it [00:46,  3.91it/s, Epoch=15, Val Loss=0.001704, Mean Val Loss=3.398019, Mean Val F-1=0.899875]  


Early Stopping Step : [2 / 4



===== k_fold : 4 / 4 =====


541it [02:54,  3.10it/s, Epoch=1, Loss=1.395277, Mean Loss=1.507653, Mean F-1=0.486589]
181it [00:27,  6.50it/s, Epoch=1, Val Loss=0.814030, Mean Val Loss=1.343399, Mean Val F-1=0.776442] 
541it [02:33,  3.52it/s, Epoch=2, Loss=0.597250, Mean Loss=0.973679, Mean F-1=0.627546]
181it [00:23,  7.81it/s, Epoch=2, Val Loss=0.904697, Mean Val Loss=0.335133, Mean Val F-1=0.808348]
541it [02:32,  3.54it/s, Epoch=3, Loss=0.935366, Mean Loss=0.839930, Mean F-1=0.684141]
181it [00:23,  7.84it/s, Epoch=3, Val Loss=0.323997, Mean Val Loss=0.168485, Mean Val F-1=0.886911]
541it [02:32,  3.54it/s, Epoch=4, Loss=0.871123, Mean Loss=0.771637, Mean F-1=0.693241]
181it [00:23,  7.65it/s, Epoch=4, Val Loss=0.103376, Mean Val Loss=0.154185, Mean Val F-1=0.907905]
541it [02:38,  3.41it/s, Epoch=5, Loss=0.795734, Mean Loss=0.698122, Mean F-1=0.734639]
181it [00:25,  7.13it/s, Epoch=5, Val Loss=0.055574, Mean Val Loss=0.105056, Mean Val F-1=0.930435]
541it [02:45,  3.27it/s, Epoch=6, Loss=0.508657, Mean Loss=

Early Stopping Step : [1 / 4


541it [02:54,  3.11it/s, Epoch=7, Loss=0.662048, Mean Loss=0.654567, Mean F-1=0.732482]
181it [00:29,  6.23it/s, Epoch=7, Val Loss=0.000184, Mean Val Loss=0.107817, Mean Val F-1=0.929118]


Early Stopping Step : [2 / 4


541it [03:01,  2.98it/s, Epoch=8, Loss=0.255549, Mean Loss=0.625723, Mean F-1=0.765249]
181it [00:32,  5.51it/s, Epoch=8, Val Loss=0.004176, Mean Val Loss=0.090530, Mean Val F-1=0.945685]
541it [03:08,  2.87it/s, Epoch=9, Loss=1.033882, Mean Loss=0.627510, Mean F-1=0.743888]
181it [00:37,  4.89it/s, Epoch=9, Val Loss=0.002754, Mean Val Loss=0.114787, Mean Val F-1=0.922545]


Early Stopping Step : [1 / 4


541it [03:14,  2.77it/s, Epoch=10, Loss=0.868999, Mean Loss=0.624367, Mean F-1=0.772972]
181it [00:38,  4.65it/s, Epoch=10, Val Loss=0.013982, Mean Val Loss=0.105603, Mean Val F-1=0.927604]


Early Stopping Step : [2 / 4


541it [03:21,  2.68it/s, Epoch=11, Loss=0.484558, Mean Loss=0.589643, Mean F-1=0.782846]
181it [00:41,  4.31it/s, Epoch=11, Val Loss=0.018749, Mean Val Loss=0.226079, Mean Val F-1=0.916993]


Early Stopping Step : [3 / 4


541it [03:28,  2.60it/s, Epoch=12, Loss=0.341924, Mean Loss=0.577737, Mean F-1=0.771209]
181it [00:44,  4.10it/s, Epoch=12, Val Loss=0.004625, Mean Val Loss=0.521098, Mean Val F-1=0.925555]

Early Stopping Step : [4 / 4
== Early Stop ==





# Test Dataset 정의

In [14]:
tta_transforms_224 = A.Compose([
                A.Resize(224 ,224),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])

tta_transforms_288 = A.Compose([
                A.Resize(288 ,288),
                A.OneOf([
                    A.Rotate(),
                    A.HorizontalFlip(),
                    A.VerticalFlip()
                ], p=1)
            ])


test = sorted(glob('data/test/*'))
test_dataset = CustomDataset(test, tta_transforms_224, tta_transforms_288, mode='test')


# test_dataset = CustomDataset(test, val_transforms, mode = 'test')
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Soft Voting

In [15]:
import torch.nn.functional as F

def softvoting(models, img_deit, img_effi, n_classes=25) :

    predicts = torch.zeros(img_deit.size(0), n_classes)
    with torch.no_grad() :
        for model in models :
            output = model(img_deit, img_effi)
            output = F.softmax(output.cpu(), dim=1)
            
            predicts += output

    # 둘다 값은 똑같이 나옴.
    # pred_avg = predicts / len(models)
    # answer = pred_avg.argmax(dim=-1)
    # _, answer2 = torch.max(pred_avg, 1)

    return predicts.detach().cpu() / len(models)

# Prediction

In [16]:
def predict(dataset, models) :
    tqdm_dataset = tqdm(enumerate(dataset))
    results = []
    for batch, batch_item in tqdm_dataset :
        img_effi = batch_item['effi'].to(device) # 288
        img_deit = batch_item['deit'].to(device) # 224

    
#         print(img.shape)
        predictions = softvoting(models, img_deit, img_effi)
        batch_result = [int(torch.argmax(prediction)) for prediction in predictions]
#         print(batch_result)
#         for prediction in predictions :
            
#         results.append(int(torch.argmax(predictions[0])))
#             output = model(img)
#         output = torch.tensor(torch.argmax(output, dim=1), dtype=torch.int32).cpu().numpy()
        results.extend(batch_result)
    return results


model_name_1 = 'efficientnetv2_rw_s'
model_name_2 = 'deit_small_patch16_224'
n_classes = 25

kfold_models_path = glob('./model/k_public_vill_50k_pretrain_EffiDeit/*.pt')
models = []
for kfold_model_path in kfold_models_path :
    model = EffiDeiT(model_name_1, model_name_2, n_classes=25)
    model.load_state_dict(torch.load(kfold_model_path, map_location=device))
    model.to(device).eval()
    models.append(model)
    
preds = predict(test_dataloader, models)




6489it [29:53,  3.62it/s]


In [17]:
preds_cp = preds

In [18]:
preds_cp = np.array([train_label_decoder[int(val)] for val in preds_cp])
submission_csv = pd.read_csv('./data/sample_submission.csv')
submission_csv['label'] = preds_cp
submission_csv.to_csv('./data/k_tta_public_vill_50k_pretrain_EffiDeit.csv', index=False)


# 단일 모델 prediction

In [13]:
def predict(dataset, model) :
    tqdm_dataset = tqdm(enumerate(dataset))
    results = []
    for batch, batch_item in tqdm_dataset :
        img_effi = batch_item['effi'].to(device) # 288
        img_deit = batch_item['deit'].to(device) # 224

    
#         print(img.shape)
        predictions = model(img_deit, img_effi)
        predictions = F.softmax(predictions.cpu(), dim=1)
        batch_result = [int(torch.argmax(prediction.detach().cpu())) for prediction in predictions]
#         print(batch_result)
#         for prediction in predictions :
            
#         results.append(int(torch.argmax(predictions[0])))
#             output = model(img)
#         output = torch.tensor(torch.argmax(output, dim=1), dtype=torch.int32).cpu().numpy()
        results.extend(batch_result)
    return results


model_name_1 = 'efficientnetv2_rw_s'
model_name_2 = 'deit_small_patch16_224'
n_classes = 25

kfold_model_path = '4_9456_public_vill_50k_pretrain_EffiDeit.pt'


model = EffiDeiT(model_name_1, model_name_2, n_classes=25, test=True)
model.load_state_dict(torch.load(kfold_model_path, map_location=device))
model.to(device).eval()



EffiDeiT(
  (total_model): total_model(
    (effi): EffiV2S(
      (model): EfficientNet(
        (conv_stem): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
        (blocks): Sequential(
          (0): Sequential(
            (0): EdgeResidual(
              (conv_exp): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (act1): SiLU(inplace=True)
              (se): Identity()
              (conv_pwl): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (1): EdgeResidual(
              (conv_exp): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding

In [14]:
    
preds = predict(test_dataloader, model)

1it [00:06,  6.89s/it]


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 3.64 GiB already allocated; 0 bytes free; 3.81 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
preds_cp = preds

In [None]:
preds_cp = np.array([train_label_decoder[int(val)] for val in preds_cp])
submission_csv = pd.read_csv('./data/sample_submission.csv')
submission_csv['label'] = preds_cp
submission_csv.to_csv('./data/k_public_vill_50k_pretrain_EffiDeit.csv', index=False)


In [16]:
a = torch.tensor([[ 3.6658, -0.5562, -0.1303,  0.5749, -0.8724, -0.6668, -2.5437,  0.1855,
         -0.2706,  0.0678, -2.0500,  1.9760, -1.0370,  0.3441,  0.0985, -0.5866,
          1.5043,  9.5056, -2.4885, -1.4415,  0.1996, -1.3245,  1.3968, -1.2040,
         -1.8589],
        [ 1.8219, -1.0577,  0.3212,  1.3284, -2.0215, -1.9177, -1.8310, -1.0591,
         -1.0233, -0.4418, -2.0763,  1.7701,  1.8592, -0.8363,  9.2857,  0.1729,
          4.6880,  0.9189, -0.9552, -3.4629, -2.3523, -1.3255, -2.2687, -0.3306,
         -2.9632],
        [ 2.5711, -0.8122, -0.7851,  3.7447,  0.2264, -0.7621, -2.6715,  0.6736,
          0.6371,  0.6878, -0.2916,  7.9526,  0.5164,  0.2882,  0.2075, -1.2009,
          0.3114,  1.2553, -1.6136, -3.8794, -2.2638, -3.2898, -1.2467, -2.5894,
         -2.0315],
        [ 2.8435,  0.1963,  1.5067, 12.1120, -1.7412, -3.2380, -3.3557, -0.2230,
          0.1163,  0.9585, -0.8354,  4.2385,  0.9182,  1.3466,  0.9568,  1.3733,
         -0.6904,  3.4432, -2.1798, -4.2389, -4.0742, -2.4559, -3.0247, -3.1489,
         -2.2941],
        [ 0.8880,  0.2074, -0.1185,  3.3785, -0.1042,  0.1022, -0.2805,  3.8969,
          4.9954, -0.7490,  8.5330,  1.3133, -1.4416,  0.1890, -1.2645,  0.6467,
         -1.5140,  1.0825, -2.3855, -3.1290, -1.3034, -1.7277, -2.6166, -1.6868,
         -2.7182],
        [ 1.9738,  0.9231,  0.7457,  1.3640,  0.5198,  0.2546, -1.9709,  1.8885,
          1.1139, -1.8140, -0.1462,  1.7628, -1.9919, -1.5093,  1.0957, -0.4808,
         -0.0146,  9.5250, -2.0937, -2.2077, -0.2840, -1.3846, -1.4953, -2.0964,
         -2.7014],
        [10.4486, -0.9522,  0.7026,  1.3279, -0.1882, -1.3398, -1.5953,  0.0131,
         -0.2339, -0.4893,  0.9711,  1.7524, -0.6166,  0.2391, -0.1995, -0.3715,
         -0.9501,  0.8616, -0.9770, -1.1656, -0.6833, -1.9415,  0.0552, -1.1305,
         -0.1138],
        [ 1.2428, -0.4092, -0.2502,  3.4074, -0.9933, -2.5447, -1.9779, -1.5831,
         -1.2632,  0.3616, -1.1316,  2.7023,  1.4124, -0.5245,  1.4021,  9.2077,
         -1.4010,  0.3429, -0.4383, -3.2723, -2.3361, -0.8439, -4.3588, -3.1138,
         -0.4072]], device='cuda:0')

In [19]:
batch_result = [int(torch.argmax(ab)) for ab in a]
batch_result

[17, 14, 11, 3, 10, 17, 0, 15]