### 0. 초기 환경 셋팅

In [1]:
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from numpy import asarray
from pathlib import Path
from torchvision import models
# from tqdm import tqdm
from functools import lru_cache

print(torch.__version__)
print(pd.__version__)

1.6.0
1.1.3


In [2]:
# set device
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device: {}".format(device))

device: cuda


In [3]:
# for reproducibility
seed = 0

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

### 1. DATA Check

In [4]:
data_train = pd.read_csv("./train/train.tsv", sep='\t', header=None)
print(data_train)

data_test = pd.read_csv("./test/test.tsv", sep='\t', header=None)
print(data_test)

                    0   1   2
0        3_5_1123.jpg   3   5
1       3_20_1048.jpg   3  20
2         4_2_401.jpg   4   2
3         4_7_740.jpg   4   7
4         4_11_93.jpg   4  11
...               ...  ..  ..
15995  13_15_1600.jpg  13  15
15996  13_16_1570.jpg  13  16
15997   13_17_986.jpg  13  17
15998  13_18_4980.jpg  13  18
15999   13_20_282.jpg  13  20

[16000 rows x 3 columns]
             0
0        0.jpg
1        1.jpg
2        2.jpg
3        3.jpg
4        4.jpg
...        ...
3992  3992.jpg
3993  3993.jpg
3994  3994.jpg
3995  3995.jpg
3996  3996.jpg

[3997 rows x 1 columns]


### 2. Dataset & DataLoader

In [5]:
combined_label = {
    (3, 5): 0,
    (3, 20): 1,
    (4, 2): 2,
    (4, 7): 3,
    (4, 11): 4,
    (5, 8): 5,
    (7, 1): 6,
    (7, 20): 7,
    (8, 6): 8,
    (8, 9): 9,
    (10, 20): 10,
    (11, 14): 11,
    (13, 1): 12,
    (13, 6): 13,
    (13, 9): 14,
    (13, 15): 15,
    (13, 16): 16,
    (13, 17): 17,
    (13, 18): 18,
    (13, 20): 19
}

index2label = {v: k for k, v in combined_label.items()}

In [6]:
class PlantImageDataset(Dataset):
    
    def __init__(self, tsv_path, data_path):
        self.data_path = Path(data_path)
        self.data = self._preprocess(tsv_path)
            
    def _preprocess(self, tsv_path):
        data = pd.read_csv(tsv_path, sep='\t', header=None)
        return data
    
    def __len__(self):
        return len(self.data)
    
    # Override 필수 함수
    @lru_cache(maxsize=100000)
    def __getitem__(self, index):
        file_name = Path(self.data.iloc[index][0])
        
        if str(self.data_path) == 'train':
            plant_label = self.data.iloc[index][1]
            disease_label = self.data.iloc[index][2]
        else:
            plant_label = -1
            disease_label = -1
        
        file_path = self.data_path / file_name
        
        data = Image.open(file_path)
        data = asarray(data) # (H, W, C) 
        data = torch.FloatTensor(data)
        data = data.permute(2, 0, 1).contiguous()  # (H, W, C) --> (C, H, W)
        
        # print(data, plant_label, disease_label)
        
        label = combined_label.get((plant_label, disease_label), -1)
        
        if str(self.data_path) == 'train' and label == -1:
            raise Exception('Label Error!!')
        
        return data, torch.tensor(label)

In [7]:
train_dataset = PlantImageDataset(tsv_path="./train/train.tsv", data_path="train")


train_size = int(len(train_dataset) * 0.8)
valid_size = int(len(train_dataset) * 0.2)
assert len(train_dataset) == (train_size + valid_size)

train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size, valid_size])
print(len(train_dataset), len(valid_dataset))

test_dataset = PlantImageDataset(tsv_path="./test/test.tsv", data_path="test")

12800 3200


In [8]:
batch_size = 8 #16

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False)

valid_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          drop_last=False)

test_loader = DataLoader(dataset=test_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          drop_last=False)

In [9]:
# for x, y in test_loader:
#     print(x, y)
#     break

### 3. Model

In [10]:
class PlantModel(nn.Module):
    def __init__(self, label_size=20):
        super(PlantModel, self).__init__()

        self.pre_model = models.resnext101_32x8d()
        #self.pre_model = models.wide_resnet101_2()
        self.pre_model.fc = nn.Linear(2048, label_size)

    def forward(self, x):
        out = self.pre_model(x)
        
        return out

model = PlantModel()
model = model.to(device)
print(model)

PlantModel(
  (pre_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): 

In [11]:
# model = models.wide_resnet50_2()
# model.fc = nn.Linear(2048, 14)

# model = model.to(device)

# print(model)

In [12]:
# model = models.vgg16()
# model.classifier._modules['6'] = nn.Linear(4096, 14)
# # model.fc = nn.Linear(1000, 14)

# model = model.to(device)

# print(model)

### 4. Train

In [13]:
# optimizer
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

criterion = nn.CrossEntropyLoss(reduction='sum').to(device)

In [14]:
def train_one_epoch(model, train_loader, criterion, optimizer=None, device=None, epoch=None):
    model.train() # 학습 모드
    
    total_loss = 0.
    total_num = 0
    
    start_time = time.time()
   
    for batch, (data, label) in enumerate(train_loader):
        # data
        data = data.to(device)
        label = label.to(device)
        
        optimizer.zero_grad()
        hyp = model(data)
        
        loss = criterion(hyp, label)
        
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        total_loss += loss.detach().item()
        total_num += 1
        
        log_interval = 100
        if batch % log_interval == 0 and batch >= 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | grad_norm {:5.2f} | loss {:5.2f} | time {:.4f}'.format(
                    epoch+1, batch+1, len(train_loader), grad_norm, loss.detach().item(), elapsed))
            start_time = time.time()
    
    train_loss = total_loss / total_num
    return train_loss
        
        
def validate_one_epoch(model, valid_loader, criterion, optimizer=None, device=None, epoch=None, print_output=False):
    model.eval() # 평가 모드
    
    total_loss = 0.
    total_num = 0
    
    correct= 0
    total_size = 0
   
    with torch.no_grad():
        for batch, (data, label) in enumerate(valid_loader):
            # data
            data = data.to(device)
            label = label.to(device)
            
            hyp = model(data)
        
            loss = criterion(hyp, label)
            
            if print_output:
                hyp = torch.argmax(hyp, dim=-1)
                correct += (hyp == label).sum().item()

            total_loss += loss.detach().item()
            total_num += 1
            total_size += data.size(0)
    
    valid_loss = total_loss / total_num
    accuracy = correct / total_size
    
    return valid_loss, accuracy

PATH = './models/'

def save_model():
    torch.save(model, PATH + 'model.pt')  # 전체 모델 저장
    torch.save(model.state_dict(), PATH + 'model_state_dict.pt')  # 모델 객체의 state_dict 저장
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, PATH + 'all.tar')

In [15]:
training_epochs = 20 # 50 # 100

for epoch in range(training_epochs):
    
    # train_one_epoch
    train_loss = train_one_epoch(model, train_loader, criterion,  optimizer=optimizer, device=device, epoch=epoch)
    print("| epoch {:3d} | train_loss : {:7.4f}".format(epoch+1, train_loss))
    
    # validate_one_epoch
    valid_loss, accuracy = validate_one_epoch(model, valid_loader, criterion, optimizer=None, device=device, epoch=epoch, print_output=True)
    print("| epoch {:3d} | valid_loss : {:7.4f} accuracy : {:.4f}".format(epoch+1, valid_loss, accuracy))
    
    # save_model
    save_model()
    
    lr_scheduler.step(valid_loss)



| epoch   1 |     1/ 1600 batches | grad_norm 8841.16 | loss 26.05 | time 1.2172
| epoch   1 |   101/ 1600 batches | grad_norm 11.11 | loss 23.83 | time 33.6796
| epoch   1 |   201/ 1600 batches | grad_norm  8.70 | loss 17.20 | time 33.7390
| epoch   1 |   301/ 1600 batches | grad_norm 11.33 | loss 16.58 | time 33.8337
| epoch   1 |   401/ 1600 batches | grad_norm 12.00 | loss 19.12 | time 33.9338
| epoch   1 |   501/ 1600 batches | grad_norm  8.35 | loss 15.76 | time 33.9479
| epoch   1 |   601/ 1600 batches | grad_norm 11.09 | loss 17.08 | time 34.2105
| epoch   1 |   701/ 1600 batches | grad_norm  8.35 | loss 16.69 | time 33.9752
| epoch   1 |   801/ 1600 batches | grad_norm  8.75 | loss 17.55 | time 33.9737
| epoch   1 |   901/ 1600 batches | grad_norm 11.55 | loss 19.78 | time 34.0069
| epoch   1 |  1001/ 1600 batches | grad_norm  7.52 | loss 14.46 | time 34.1090
| epoch   1 |  1101/ 1600 batches | grad_norm  7.83 | loss 18.05 | time 34.0244
| epoch   1 |  1201/ 1600 batches | gra

| epoch   7 |     1/ 1600 batches | grad_norm  5.57 | loss  1.34 | time 0.3531
| epoch   7 |   101/ 1600 batches | grad_norm 14.65 | loss  6.76 | time 31.6718
| epoch   7 |   201/ 1600 batches | grad_norm  5.06 | loss  1.41 | time 31.7184
| epoch   7 |   301/ 1600 batches | grad_norm 10.86 | loss  4.64 | time 31.6974
| epoch   7 |   401/ 1600 batches | grad_norm 11.81 | loss  8.40 | time 31.7001
| epoch   7 |   501/ 1600 batches | grad_norm  6.44 | loss  1.76 | time 31.6652
| epoch   7 |   601/ 1600 batches | grad_norm  1.85 | loss  0.50 | time 31.6826
| epoch   7 |   701/ 1600 batches | grad_norm  2.68 | loss  0.61 | time 31.6804
| epoch   7 |   801/ 1600 batches | grad_norm  9.52 | loss  5.66 | time 31.7028
| epoch   7 |   901/ 1600 batches | grad_norm  6.54 | loss  1.50 | time 31.7331
| epoch   7 |  1001/ 1600 batches | grad_norm  6.79 | loss  4.00 | time 31.6684
| epoch   7 |  1101/ 1600 batches | grad_norm  0.64 | loss  0.09 | time 31.7133
| epoch   7 |  1201/ 1600 batches | grad_

| epoch  12 | valid_loss :  1.2322 accuracy : 0.9486
| epoch  13 |     1/ 1600 batches | grad_norm  8.43 | loss  1.05 | time 0.3537
| epoch  13 |   101/ 1600 batches | grad_norm  0.26 | loss  0.03 | time 31.9055
| epoch  13 |   201/ 1600 batches | grad_norm  4.03 | loss  0.51 | time 31.9435
| epoch  13 |   301/ 1600 batches | grad_norm 12.16 | loss  2.00 | time 32.0420
| epoch  13 |   401/ 1600 batches | grad_norm  0.17 | loss  0.02 | time 32.0891
| epoch  13 |   501/ 1600 batches | grad_norm  1.55 | loss  0.21 | time 32.0095
| epoch  13 |   601/ 1600 batches | grad_norm 16.57 | loss  6.36 | time 32.1582
| epoch  13 |   701/ 1600 batches | grad_norm  2.63 | loss  0.50 | time 32.1584
| epoch  13 |   801/ 1600 batches | grad_norm  7.49 | loss  1.39 | time 32.1073
| epoch  13 |   901/ 1600 batches | grad_norm 12.00 | loss  1.90 | time 32.0241
| epoch  13 |  1001/ 1600 batches | grad_norm 13.57 | loss  2.29 | time 32.0234
| epoch  13 |  1101/ 1600 batches | grad_norm  3.28 | loss  0.31 | t

| epoch  18 | train_loss :  0.5340
| epoch  18 | valid_loss :  0.1311 accuracy : 0.9952
| epoch  19 |     1/ 1600 batches | grad_norm  9.98 | loss  1.28 | time 0.3510
| epoch  19 |   101/ 1600 batches | grad_norm  0.20 | loss  0.02 | time 31.7727
| epoch  19 |   201/ 1600 batches | grad_norm  6.97 | loss  0.60 | time 31.8542
| epoch  19 |   301/ 1600 batches | grad_norm  0.00 | loss  0.00 | time 32.0643
| epoch  19 |   401/ 1600 batches | grad_norm  1.03 | loss  0.13 | time 32.6137
| epoch  19 |   501/ 1600 batches | grad_norm  2.90 | loss  0.31 | time 31.8810
| epoch  19 |   601/ 1600 batches | grad_norm  0.89 | loss  0.09 | time 32.1978
| epoch  19 |   701/ 1600 batches | grad_norm  0.37 | loss  0.05 | time 32.0332
| epoch  19 |   801/ 1600 batches | grad_norm 12.95 | loss  3.02 | time 31.8935
| epoch  19 |   901/ 1600 batches | grad_norm  0.11 | loss  0.01 | time 31.9427
| epoch  19 |  1001/ 1600 batches | grad_norm  5.67 | loss  0.82 | time 31.9499
| epoch  19 |  1101/ 1600 batches

### 5. Inference

In [16]:
result_p = []
result_d = []

for batch, (data, label) in enumerate(test_loader):
    # data
    data = data.to(device)
    label = label.to(device)
    
    hyp = model(data)
    
    pred = torch.argmax(hyp, dim=-1)
    
    sample_num = pred.size(0)
    
    for i in range(sample_num):
        pred_idx = pred[i].detach().item()
        
        pred_plant, pred_disease = index2label[pred_idx]
        result_p.append(pred_plant)
        result_d.append(pred_disease)
        
# print(result)

In [17]:
data_test = pd.read_csv("./test/test.tsv", sep='\t', header=None)
# print(data_test)

data_test[1] = result_p
data_test[2] = result_d

print(data_test)

data_test.to_csv("test_result.tsv", index=False, header=None, sep="\t")

             0   1   2
0        0.jpg   3   5
1        1.jpg   3  20
2        2.jpg   4   2
3        3.jpg   4   7
4        4.jpg   4  11
...        ...  ..  ..
3992  3992.jpg  13   6
3993  3993.jpg  13  16
3994  3994.jpg  13  17
3995  3995.jpg  13  18
3996  3996.jpg  13  20

[3997 rows x 3 columns]
