In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim

import albumentations
import albumentations.pytorch

from mask_dataset import MaskDataset, get_transforms
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import f1_score
import timm

import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image
import random
from datetime import datetime

In [2]:
train_dir = '/opt/ml/input/data/train'
test_dir = '/opt/ml/input/data/eval'

In [3]:
# Training settings
batch_size = 32
epochs = 30
lr = 0.00003
gamma = 0.7
seed = 42
device = 'cuda'

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [5]:
transform = get_transforms()
dataset = MaskDataset(train_dir, transform)

In [6]:
lengths = [int(len(dataset)*0.8), int(len(dataset)*0.2)]
train_set, val_set = torch.utils.data.random_split(dataset, lengths,generator=torch.Generator().manual_seed(seed))

train_set.dataset.set_transform(transform['train'])
val_set.dataset.set_transform(transform['val'])

In [7]:
train_loader = DataLoader(
    train_set,
    batch_size = batch_size,
    num_workers=2,
    shuffle=True
)
valid_loader = DataLoader(
    val_set,
    batch_size = batch_size,
    num_workers=2,
    shuffle=False
)

In [8]:
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)

['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn

In [9]:
model = timm.create_model('resnet18', pretrained=True).to(device)

In [10]:
today = datetime.today().strftime("%m%d")
path = f'/opt/ml/code/model/{today+"_"+model.__class__.__name__+"_"+"centcrop_only"}'
if not os.path.exists(path):
    os.makedirs(path)

In [11]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=5, gamma=gamma)
lrs = []

In [12]:
best_val_loss = 9999
best_val_f1 = 0
NUM_ACCUM = 4


for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_f1 = 0
    
    for i, data in enumerate(tqdm(train_loader,leave=True)):
        model.train()
        inputs, label = data
        inputs = inputs['image'].to(device)
        label = label.to(device)

        output = model(inputs)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        
#         if i % NUM_ACCUM == 0:
        optimizer.step()

            
        acc = (output.argmax(dim=1) == label).float().mean()

        epoch_accuracy += acc / len(train_loader)
        epoch_f1 += f1_score(output.argmax(dim=1).cpu(), label.cpu(), average='macro') / len(train_loader)
        epoch_loss += loss / len(train_loader)
        
    lrs.append(optimizer.param_groups[0]["lr"])
#     scheduler.step()

    with torch.no_grad():
        model.eval()
        epoch_val_accuracy = 0
        epoch_val_f1 = 0
        epoch_val_loss = 0
        
        for val_batch in valid_loader:
            inputs, label = val_batch
            inputs = inputs['image'].to(device)
            label = label.to(device)

            val_output = model(inputs)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_f1 += f1_score(val_output.argmax(dim=1).cpu(), label.cpu(), average='macro') / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
    

            
    if epoch_val_loss < best_val_loss:
        print("New best model for val loss! saving the model..")
        torch.save(model.state_dict(), os.path.join(f'{path}/{epoch:03}_loss_{epoch_val_loss:4.2}.pt'))
        best_val_loss = epoch_val_loss


    if epoch_val_f1 > best_val_f1:
        print("New best model for val f1! saving the model..")
        torch.save(model.state_dict(), os.path.join(f'{path}/{epoch:03}_f1_{epoch_val_f1:4.2}.pt'))
        best_val_f1 = epoch_val_f1
        best_model = f'{path}/{epoch:03}_f1_{epoch_val_f1:4.2}.pt'
    
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - f1: {epoch_f1:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f} - val_f1: {epoch_val_f1:.4f}\n"
    )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 1 - loss : 1.2212 - acc: 0.7690 - f1: 0.6178 - val_loss : 0.2982 - val_acc: 0.9086 - val_f1: 0.8080



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 2 - loss : 0.2269 - acc: 0.9317 - f1: 0.8539 - val_loss : 0.2065 - val_acc: 0.9349 - val_f1: 0.8580



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 3 - loss : 0.1034 - acc: 0.9717 - f1: 0.9301 - val_loss : 0.1368 - val_acc: 0.9561 - val_f1: 0.8916



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 4 - loss : 0.0507 - acc: 0.9894 - f1: 0.9688 - val_loss : 0.1063 - val_acc: 0.9674 - val_f1: 0.9187



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 5 - loss : 0.0274 - acc: 0.9954 - f1: 0.9896 - val_loss : 0.0899 - val_acc: 0.9745 - val_f1: 0.9327



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 6 - loss : 0.0146 - acc: 0.9983 - f1: 0.9961 - val_loss : 0.0802 - val_acc: 0.9785 - val_f1: 0.9420



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 7 - loss : 0.0081 - acc: 0.9997 - f1: 0.9994 - val_loss : 0.0750 - val_acc: 0.9795 - val_f1: 0.9519



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
Epoch : 8 - loss : 0.0071 - acc: 0.9993 - f1: 0.9992 - val_loss : 0.0734 - val_acc: 0.9785 - val_f1: 0.9441



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 9 - loss : 0.0072 - acc: 0.9992 - f1: 0.9991 - val_loss : 0.1132 - val_acc: 0.9645 - val_f1: 0.9150



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 10 - loss : 0.0108 - acc: 0.9980 - f1: 0.9961 - val_loss : 0.1060 - val_acc: 0.9651 - val_f1: 0.9206



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
Epoch : 11 - loss : 0.0088 - acc: 0.9984 - f1: 0.9976 - val_loss : 0.0675 - val_acc: 0.9785 - val_f1: 0.9435



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 12 - loss : 0.0066 - acc: 0.9987 - f1: 0.9974 - val_loss : 0.0767 - val_acc: 0.9777 - val_f1: 0.9477



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val f1! saving the model..
Epoch : 13 - loss : 0.0044 - acc: 0.9995 - f1: 0.9993 - val_loss : 0.0708 - val_acc: 0.9814 - val_f1: 0.9519



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 14 - loss : 0.0046 - acc: 0.9992 - f1: 0.9983 - val_loss : 0.0759 - val_acc: 0.9793 - val_f1: 0.9460



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 15 - loss : 0.0075 - acc: 0.9980 - f1: 0.9952 - val_loss : 0.0826 - val_acc: 0.9761 - val_f1: 0.9425



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 16 - loss : 0.0062 - acc: 0.9985 - f1: 0.9981 - val_loss : 0.0778 - val_acc: 0.9777 - val_f1: 0.9489



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
Epoch : 17 - loss : 0.0053 - acc: 0.9987 - f1: 0.9974 - val_loss : 0.0636 - val_acc: 0.9806 - val_f1: 0.9512



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 18 - loss : 0.0026 - acc: 0.9995 - f1: 0.9986 - val_loss : 0.0721 - val_acc: 0.9814 - val_f1: 0.9488



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 19 - loss : 0.0036 - acc: 0.9992 - f1: 0.9977 - val_loss : 0.0883 - val_acc: 0.9774 - val_f1: 0.9410



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 20 - loss : 0.0010 - acc: 1.0000 - f1: 1.0000 - val_loss : 0.0596 - val_acc: 0.9856 - val_f1: 0.9644



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 21 - loss : 0.0011 - acc: 0.9999 - f1: 1.0000 - val_loss : 0.0742 - val_acc: 0.9808 - val_f1: 0.9539



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 22 - loss : 0.0012 - acc: 0.9998 - f1: 0.9996 - val_loss : 0.3590 - val_acc: 0.9110 - val_f1: 0.8800



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 23 - loss : 0.0132 - acc: 0.9962 - f1: 0.9931 - val_loss : 0.0767 - val_acc: 0.9758 - val_f1: 0.9392



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 24 - loss : 0.0045 - acc: 0.9989 - f1: 0.9975 - val_loss : 0.0754 - val_acc: 0.9769 - val_f1: 0.9439



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 25 - loss : 0.0020 - acc: 0.9999 - f1: 0.9997 - val_loss : 0.0624 - val_acc: 0.9840 - val_f1: 0.9565



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 26 - loss : 0.0008 - acc: 1.0000 - f1: 1.0000 - val_loss : 0.0646 - val_acc: 0.9840 - val_f1: 0.9612



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
New best model for val f1! saving the model..
Epoch : 27 - loss : 0.0007 - acc: 1.0000 - f1: 1.0000 - val_loss : 0.0549 - val_acc: 0.9863 - val_f1: 0.9700



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 28 - loss : 0.0061 - acc: 0.9980 - f1: 0.9960 - val_loss : 0.1377 - val_acc: 0.9714 - val_f1: 0.9310



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


Epoch : 29 - loss : 0.0096 - acc: 0.9972 - f1: 0.9928 - val_loss : 0.0585 - val_acc: 0.9824 - val_f1: 0.9514



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


New best model for val loss! saving the model..
Epoch : 30 - loss : 0.0021 - acc: 0.9996 - f1: 0.9994 - val_loss : 0.0525 - val_acc: 0.9871 - val_f1: 0.9677



In [13]:
class TestDataset(Dataset):
    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform
        
    def set_transform(self, transform):
        """
        transform 함수를 설정하는 함수입니다.
        """
        self.transform = transform
    
    
    def __getitem__(self, index):
        image = Image.open(self.img_paths[index])

        if self.transform:
            image = self.transform(image=np.array(image))
        return image

    def __len__(self):
        return len(self.img_paths)

In [16]:
model = timm.create_model('resnet18', pretrained=True).to(device)
model.load_state_dict(torch.load(f'{path}/026_f1_0.97.pt'))

<All keys matched successfully>

In [17]:
# meta 데이터와 이미지 경로를 불러옵니다.
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))
image_dir = os.path.join(test_dir, 'images')

# Test Dataset 클래스 객체를 생성하고 DataLoader를 만듭니다.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]

dataset = TestDataset(image_paths, transform)
dataset.set_transform(transform['val'])
loader = DataLoader(
    dataset,
    shuffle=False
)

In [18]:
device = torch.device('cuda')

model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in tqdm(loader):
    with torch.no_grad():
        images = images['image'].to(device)
        pred = model(images.float())
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(test_dir, 'submission_0827_resnet18_1.csv'), index=False)
print('test inference is done!')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12600.0), HTML(value='')))


test inference is done!
