In [1]:
import logging
import logging.config
import sys

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader

from dataset import MNISTDataset
from model import MNISTResNet50
from utils import train, test

# Configure logging from the INI-style config file and get this notebook's logger.
# `fileConfig` lives in the `logging.config` submodule, which is NOT loaded by a
# plain `import logging`; it needs an explicit `import logging.config`, otherwise
# this line can raise AttributeError on a fresh interpreter.
logging.config.fileConfig(fname='train_resnet50.conf')
logger = logging.getLogger(__name__)


In [2]:
# Load the labelled training data and hold out the final 20% of rows as a
# validation set. shuffle=False makes the split deterministic (no RNG involved).
val_fraction = 0.2
labelled_df = pd.read_csv('train.csv')
train_df, val_df = train_test_split(labelled_df, test_size=val_fraction, shuffle=False)

# train_test_split keeps the original row labels, so val_df's index starts at
# len(train_df) instead of 0. MNISTDataset is not visible here; if it looks rows
# up by label (df.loc[i] / df[col][i]) rather than by position (df.iloc[i]), the
# validation loader would raise KeyError. Resetting both indices makes labels
# equal positions 0..len-1 and is a no-op for position-based access.
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

batch_size = 128
# Reshuffle training batches every epoch; keep validation order fixed.
train_loader = DataLoader(MNISTDataset(train_df), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(MNISTDataset(val_df), batch_size=batch_size, shuffle=False)


In [3]:
# Seed torch's global RNG so Restart & Run All reproduces the run. The training
# DataLoader's per-epoch shuffle order is drawn from this generator when each
# epoch's iterator is created. The model's random weight initialisation also
# draws from it, assuming MNISTResNet50 uses torch's default generator.
# This does not make nondeterministic CUDA/cuDNN kernels deterministic.
seed = 0
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MNISTResNet50()
model.to(device)  # nn.Module.to moves the parameters in place

learning_rate = 0.01
momentum = 0.9
# Construct the optimizer after moving the model, as PyTorch recommends.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)


In [None]:
# Train for a fixed number of epochs. After each epoch, log one CSV row of
# train/val accuracy (after a header row). Then save the model to disk.
epoch_num = 20
logger.info('train_acc,val_acc')
for epoch in range(1, epoch_num + 1):
    # Set the module mode explicitly instead of relying on utils.train/utils.test.
    # If test() leaves the model in eval mode and train() does not switch back,
    # every epoch after the first would train with Dropout disabled and BatchNorm
    # frozen on running stats. If test() runs in train mode, evaluation would
    # update BatchNorm running statistics with validation data. Both calls are
    # redundant but harmless if the helpers already manage the mode.
    model.train()
    train(model, device, train_loader, optimizer)
    model.eval()
    train_error_rate = test(model, device, train_loader)
    val_error_rate = test(model, device, val_loader)
    # Presumably test() returns an error rate (per the variable names), so
    # 1 - rate is the accuracy. Lazy %-formatting yields the same "a,b" message.
    logger.info('%s,%s', 1 - train_error_rate, 1 - val_error_rate)
# Pickles the whole module, not just its state_dict. Loading it requires the
# MNISTResNet50 class to be importable, and newer PyTorch versions need
# torch.load(..., weights_only=False).
torch.save(model, 'resnet50.pt')
