In [8]:
import sys

# Make the project-local modules in Classification_PyTorch (load_data, models, train)
# importable. Guarded so that re-running this cell does not append duplicate entries.
classification_pytorch_dir = '../Classification_PyTorch/'
if classification_pytorch_dir not in sys.path:
    sys.path.append(classification_pytorch_dir)

In [9]:
import os
import argparse

import torch
import yaml

# Project-local modules from ../Classification_PyTorch/ (put on sys.path in the first cell).
import load_data
import models
import train

In [10]:
# Path to the training configuration file inside Classification_PyTorch.
yaml_data = '../Classification_PyTorch/configs/config.yaml'

# safe_load only constructs plain Python objects (no arbitrary object instantiation).
# UTF-8 is explicit so parsing does not depend on the platform's default encoding.
with open(yaml_data, encoding='utf-8') as file:
    config = yaml.safe_load(file)
# NOTE(review): `config` is not referenced by any of the cells shown here — confirm it is still needed.

# Set to False to force CPU training even when CUDA is available.
use_gpu = True

In [11]:
# Training device: the first CUDA GPU when GPU use is requested and CUDA is
# available; CPU otherwise (CUDA availability is only queried when use_gpu is truthy).
device = torch.device('cuda:0' if use_gpu and torch.cuda.is_available() else 'cpu')

In [16]:
# Root directory holding both the dataset (DFBS_Combine) and the training outputs
# (DFBS_Classification). Set the DFBS_PROJECT_ROOT environment variable to point
# elsewhere instead of editing this machine-specific default.
project_root = os.environ.get('DFBS_PROJECT_ROOT', '/home/sargis/Datasets/Stepan')

In [17]:
# Dataset layout: <project_root>/DFBS_Combine/{train,test}; there is no separate validation split.
data_root = '{}/DFBS_Combine'.format(project_root)

train_dir, val_dir, test_dir = (
    os.path.join(data_root, 'train'),
    None,  # no validation set; the validation loader is skipped when this is None
    os.path.join(data_root, 'test'),
)

In [18]:
# Per-split batch sizes handed to load_data.load_images (train / val / test).
train_batch_size, val_batch_size, test_batch_size = 256, 1, 128

In [19]:
# Network / training parameters.
# NOTE(review): 301 presumably makes the last epoch index 300 — confirm against train.train_model.
num_epochs = 301

# Number of classes = number of sub-directories in train_dir. Only directories are
# counted, so stray files (e.g. .DS_Store, README) no longer inflate the class count.
# Assumes a one-folder-per-class layout — TODO confirm against load_data.load_images.
with os.scandir(train_dir) as train_entries:
    num_classes = sum(1 for entry in train_entries if entry.is_dir())

# NOTE(review): presumably the (height, width) the model expects — confirm in models.Model.
input_shape = (160, 50)

In [20]:
# Output root for everything this notebook produces (final models and checkpoints).
root_dir = '{}/DFBS_Classification'.format(project_root)

# Unique, descriptive name of this training run. It names the model and checkpoint
# sub-directories and is passed to train.train_model, so pick a new one per setup.
train_id = 'Default_10_Notebook_Train_1'

In [21]:
# Output directories for this run:
#   model_dir       -> final trained weights (final.pth)
#   checkpoints_dir -> all intermediate checkpoints written during training
model_dir = f'{root_dir}/model/{train_id}'
checkpoints_dir = f'{root_dir}/Checkpoint/{train_id}'

# exist_ok=True removes the check-then-create race and makes re-running this cell a no-op.
for output_dir in (model_dir, checkpoints_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Resume support: to continue training from a saved checkpoint, set start_epoch and
# point load_model_path at the checkpoint file, e.g.
#   start_epoch = 6
#   load_model_path = f'{checkpoints_dir}/5.pth'
start_epoch = 0
load_model_path = None
save_model_path = f"{model_dir}/final.pth"

# Data loaders; the validation/test loaders are replaced by (None, None, None)
# when their directory is not set.
train_data, train_classes, train_proportions = load_data.load_images(train_dir, train_batch_size, 'train')
val_data, val_classes, _ = load_data.load_images(val_dir, val_batch_size, 'val') if val_dir else (None, None, None)
test_data, test_classes, _ = load_data.load_images(test_dir, test_batch_size, 'test') if test_dir else (None, None, None)

print('\nTraining started:')

net = models.Model(num_classes=num_classes, input_shape=input_shape).to(device)
print(net)

if load_model_path:
    # map_location remaps the stored tensors onto the current device, so a checkpoint
    # saved on a GPU can be resumed on a CPU-only machine (and vice versa).
    net.load_state_dict(torch.load(load_model_path, map_location=device))

net = train.train_model(
    net,
    train=train_data,
    val=val_data,
    test=test_data,
    epochs=num_epochs,
    start_epoch=start_epoch,
    device=device,
    model_folder=checkpoints_dir,
    train_id=train_id,
    # NOTE(review): test_classes is None when test_dir is None — confirm train_model tolerates that.
    classes=test_classes,
    train_proportions=train_proportions
)

# Persist the final weights (state_dict only) to {model_dir}/final.pth.
torch.save(net.state_dict(), save_model_path)