In [None]:
import os
import torch
import configparser
from dataloader.dataloader import MyDataLoader
from model.Model import Model
import datetime

# Settings

In [None]:
# Load the experiment configuration. ExtendedInterpolation is passed through the
# public constructor argument (instead of assigning the private `_interpolation`
# attribute) so values may use ${section:option} references.
CONFIG_PATH = './setting.config'
config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed; fail fast here rather than with a KeyError on
# config['data'] in a later cell.
if not config.read(CONFIG_PATH):
    raise FileNotFoundError('Config file not found: {}'.format(CONFIG_PATH))

In [None]:
# Run mode: getboolean() accepts configparser's standard boolean spellings
# (true/false, yes/no, on/off, 1/0, case-insensitive) and raises ValueError on
# anything else, so a typo no longer silently switches the notebook to test mode.
is_train = config['data'].getboolean('is_train')
# Checkpoint folder. The train cell treats an empty value as "start a fresh run in
# a new timestamped folder" and a non-empty value as "resume from this folder".
model_save_folder = config['data']['model_save_folder']

# Data Loader

In [None]:
# Build the shared data loader, then create all three split loaders up front.
# The train cell uses all three; the test-only cell uses just test_loader.
dataloader = MyDataLoader(config)
train_loader, valid_loader, test_loader = (
    dataloader.train_loader(),
    dataloader.valid_loader(),
    dataloader.test_loader(),
)

# Train

In [None]:
# Train (or resume training), then evaluate the best checkpoint on the test set.
#   model_save_folder == ''  -> fresh run in a new timestamped ./save-* folder
#   otherwise                -> resume from the checkpoint stored in that folder
if is_train:

    if model_save_folder == '':
        # New training: a microsecond-resolution timestamp gives each run its own folder.
        model_save_folder = './save-{}'.format(datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f'))
        model = Model(config, dataloader=dataloader, model_save_folder=model_save_folder)
        # exist_ok=True: an already-existing folder is not an error. The logger must
        # be configured in every case, so it is no longer nested under an
        # existence check (which skipped it whenever the folder already existed).
        os.makedirs(model_save_folder, exist_ok=True)
        model.setup_logger(model_save_folder)
        print("start training....")
        model.train(train_loader, valid_loader)
    else:
        # Retrain: model.load returns the epoch to resume from and the previous
        # best validation loss, which are passed back into train().
        model = Model(config, dataloader=dataloader, model_save_folder=model_save_folder)
        start_epoch, min_val_loss = model.load(model_save_folder)
        model.setup_logger(model_save_folder)
        print("start retraining from epoch {} ....".format(start_epoch))
        model.train(train_loader, valid_loader, start_epoch=start_epoch, min_val_loss=min_val_loss)

    print("start testing....")
    model_save_path = os.path.join(model_save_folder, 'best_validate_model.pth')
    # NOTE(review): the test-only cell passes builtins_print=True here but this
    # call does not — confirm whether the difference is intentional.
    model.test(test_loader, model_save_path=model_save_path)

# Test

In [None]:
# Test-only mode: evaluate an existing checkpoint without training.
if not is_train:
    # An empty folder is the "start a new training run" setting and points at no
    # checkpoint (the old string formatting would have looked for
    # '/best_validate_model.pth' at the filesystem root); fail early instead.
    if model_save_folder == '':
        raise ValueError(
            "is_train is false but [data] model_save_folder is empty; "
            "set it to the folder containing best_validate_model.pth")
    print("start testing....")

    model = Model(config, dataloader=dataloader, model_save_folder=model_save_folder)

    model.setup_logger(model_save_folder)
    model_save_path = os.path.join(model_save_folder, 'best_validate_model.pth')
    model.test(test_loader, model_save_path=model_save_path, builtins_print=True)