In [1]:
import os
import json
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
from sklearn.model_selection import train_test_split
from helper import *

In [2]:
LOG_PATH = Path('LOG')
SHARE_PATH = Path('share')
CHECKPOINT_PATH = Path('checkpoint')

In [3]:
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

if not os.path.exists(CHECKPOINT_PATH):
    os.mkdir(CHECKPOINT_PATH)

In [4]:
logger = get_logger(LOG_PATH / 'conv_train.log', 'conv_train')

In [5]:
random_state = 777

root_path = SHARE_PATH / '1_train+val_210220 upload'
ann_path = root_path / 'Annotation_v2_Train+Val_210208.json'
save_path = CHECKPOINT_PATH / 'resnet50.pth'

image_size = 224
batch_size = 1024
lr = 0.01
epoch = 30
device = 'cuda'
num_classes = 3
n_splits = 5

In [6]:
train_transforms = transforms.Compose([transforms.Grayscale(),
                                       transforms.Resize((image_size, image_size)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(0.5, 0.5)])

test_transforms = transforms.Compose([transforms.Grayscale(),
                                      transforms.Resize((image_size, image_size)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(0.5, 0.5)])

In [7]:
with open(ann_path, 'r') as f:
    json_data = json.load(f)
    
patients = json_data['Patient']

In [8]:
train_patients, valid_patients = train_test_split(patients, test_size=0.2, random_state=random_state)

In [9]:
print(f"TRAIN Patients : {len(train_patients)}")
print(f"VALID Patients : {len(valid_patients)}")

TRAIN Patients : 3944
VALID Patients : 986


In [10]:
train_dataset = SleepConvDataset(train_patients, root_path, train_transforms)
valid_dataset = SleepConvDataset(valid_patients, root_path, test_transforms)

In [11]:
train_loader = DataLoader(train_dataset,
                           batch_size=batch_size,
                           num_workers=8,
                           pin_memory=True,
                           shuffle=True)

valid_loader = DataLoader(valid_dataset,
                           batch_size=batch_size,
                           num_workers=8,
                           pin_memory=True,
                           shuffle=True)

In [12]:
early_stopping = EarlyStopping(verbose=True, path=save_path)

train_total = len(train_dataset)
valid_total = len(valid_dataset)

model = get_resnet50(num_classes, pretrained=True)
model = nn.DataParallel(model)
model = model.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-5, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)

for e in range(0, epoch):
    train_correct, train_loss = train(model, train_loader, optimizer, criterion, device=device)
    train_acc = train_correct / train_total
    train_loss = train_loss / train_total

    valid_correct, valid_loss = valid(model, valid_loader, criterion, device=device)
    valid_acc = valid_correct / valid_total
    valid_loss = valid_loss / valid_total

    scheduler.step()

    logger.info("===============================================================")
    logger.info("===============================================================")
    logger.info(f"||    EPOCH : {epoch} / {e}]   ||")
    logger.info(f"|| [TRAIN ACC : {train_acc}] || [TRAIN LOSS : {train_loss}] ||")
    logger.info(f"|| [VALID ACC : {valid_acc}] || [VALID LOSS : {valid_loss}] ||")
    logger.info("===============================================================")
    logger.info("===============================================================")

    early_stopping(valid_loss, model)

    if early_stopping.early_stop:
        logger.info("Early stopping")
        break

    model.load_state_dict(torch.load(save_path))

100%|██████████| 2782/2782 [3:40:28<00:00,  4.75s/it]  
100%|██████████| 698/698 [56:03<00:00,  4.82s/it]  


Validation loss decreased (inf --> 0.000388).  Saving model ...


100%|██████████| 2782/2782 [3:43:14<00:00,  4.81s/it]  
100%|██████████| 698/698 [56:01<00:00,  4.82s/it]  


Validation loss decreased (0.000388 --> 0.000358).  Saving model ...


100%|██████████| 2782/2782 [3:42:05<00:00,  4.79s/it]  
100%|██████████| 698/698 [56:35<00:00,  4.86s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7


100%|██████████| 2782/2782 [3:42:37<00:00,  4.80s/it]  
100%|██████████| 698/698 [56:15<00:00,  4.84s/it]  


Validation loss decreased (0.000358 --> 0.000354).  Saving model ...


100%|██████████| 2782/2782 [3:42:18<00:00,  4.79s/it]  
100%|██████████| 698/698 [56:05<00:00,  4.82s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 1 out of 7


100%|██████████| 2782/2782 [3:41:21<00:00,  4.77s/it]  
100%|██████████| 698/698 [55:56<00:00,  4.81s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 2 out of 7


100%|██████████| 2782/2782 [3:43:44<00:00,  4.83s/it]  
100%|██████████| 698/698 [56:02<00:00,  4.82s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 3 out of 7


100%|██████████| 2782/2782 [3:40:22<00:00,  4.75s/it]  
100%|██████████| 698/698 [55:35<00:00,  4.78s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 4 out of 7


100%|██████████| 2782/2782 [3:39:29<00:00,  4.73s/it]  
100%|██████████| 698/698 [55:25<00:00,  4.76s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 5 out of 7


100%|██████████| 2782/2782 [3:40:09<00:00,  4.75s/it]  
100%|██████████| 698/698 [55:36<00:00,  4.78s/it]  
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 6 out of 7


100%|██████████| 2782/2782 [3:41:07<00:00,  4.77s/it]  
100%|██████████| 698/698 [55:54<00:00,  4.81s/it]  

EarlyStopping counter: 7 out of 7



