In [1]:
import numpy as np
import os
import cv2
import random
import torch
import yaml
import math
import shutil
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from modules.dataset import SignatureFigDataset
from torch.utils.data import DataLoader
from models.cvae.cvae import get_model
from models.cvae.train_cvae import train_for_one_epoch, visualize_latent_space

# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Load the experiment configuration. Every later cell reads `config`, so a
# malformed file is fatal: report the YAML error and re-raise instead of
# continuing with `config` undefined.
with open("config/config.yaml", "r") as f:
    try:
        config = yaml.safe_load(f)
    except yaml.YAMLError as exc:
        print(exc)
        raise
# yaml.safe_load returns None for an empty document; fail with a clear message
# rather than a TypeError on the first config[...] lookup.
if not isinstance(config, dict):
    raise ValueError("config/config.yaml must contain a mapping of configuration sections.")
print(config)

# Seed every RNG the pipeline draws from (Python, NumPy, PyTorch) for reproducibility.
seed = config['train_params']['seed']
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# `device` is a torch.device, not a string: comparing it to 'cuda' directly is
# never True, so check its `.type` attribute instead.
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed)

In [2]:
# Load every signature image in data/a_fig as a grayscale array and stack them
# into one (N, H, W, 1) batch for the dataset.
signatures_fig_path = os.path.join("data", "a_fig")

if not os.path.exists(signatures_fig_path):
    raise FileNotFoundError(f"The directory {signatures_fig_path} does not exist.")

signature_figs = []
# Sort the listing: os.listdir order is arbitrary, and a stable sample order keeps
# the seeded DataLoader shuffle reproducible across machines/filesystems.
for signature_fig_file in sorted(os.listdir(signatures_fig_path)):
    signature_fig_file_path = os.path.join(signatures_fig_path, signature_fig_file)
    signature_fig = cv2.imread(signature_fig_file_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when it cannot decode a path
    # (stray non-image files, subdirectories); skip those explicitly.
    if signature_fig is None:
        print(f"Skipping unreadable image: {signature_fig_file_path}")
        continue
    signature_figs.append(signature_fig[np.newaxis, :, :, np.newaxis])  # (1, H, W, 1)

if not signature_figs:
    raise ValueError(f"No readable images found in {signatures_fig_path}.")

# Concatenation requires every image to share the same H and W.
signature_figs = np.concatenate(signature_figs, axis=0)  # (N, H, W, C)

signature_fig_dataset = SignatureFigDataset(signature_figs)

train_dataloader = DataLoader(signature_fig_dataset, batch_size=config['train_params']['batch_size'], shuffle=True)

model = get_model(config).to(device)

num_epochs = config['train_params']['epochs']
optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
# Halve the learning rate when the value passed to scheduler.step() stops
# improving (patience=1).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm against the installed version.
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True)

# Map the configured criterion name to a loss module. Fail fast on an unknown
# name instead of letting `criterion` silently become None.
loss_factories = {
    'l1': torch.nn.L1Loss,
    'l2': torch.nn.MSELoss,
}
crit_name = config['train_params']['crit']
if crit_name not in loss_factories:
    raise ValueError(f"Unknown criterion {crit_name!r}; expected one of {sorted(loss_factories)}.")
criterion = loss_factories[crit_name]()

In [None]:
# Create output directories. NOTE: any previous results under task_name are deleted.
train_params = config['train_params']
task_dir = train_params['task_name']
if os.path.exists(task_dir):
    shutil.rmtree(task_dir)
# makedirs also creates intermediate directories when task_name is a nested path.
os.makedirs(os.path.join(task_dir, train_params['output_train_dir']))

latent_im_path = os.path.join(task_dir,
                              train_params['output_train_dir'],
                              'latent_epoch_{}.jpeg')
ckpt_path = os.path.join(task_dir, train_params['ckpt_name'])


def plot_latent_for_epoch(epoch_label):
    """Save the latent-space plot of the training data as latent_epoch_<epoch_label>.jpeg.

    Runs in eval mode under torch.no_grad() and always switches the model back
    to train mode afterwards, even if plotting raises.
    """
    model.eval()
    try:
        with torch.no_grad():
            visualize_latent_space(config, model, train_dataloader,
                                   save_fig_path=latent_im_path.format(epoch_label), device=device)
    finally:
        model.train()


best_loss = math.inf
# Baseline plot before any training (epoch 0). It honours the same flag as the
# per-epoch plots.
if train_params['save_latent_plot']:
    plot_latent_for_epoch(0)

for epoch_idx in range(num_epochs):
    mean_loss = train_for_one_epoch(epoch_idx, model, train_dataloader, optimizer, criterion, config, device)
    if train_params['save_latent_plot']:
        # The plot is built from train_dataloader, so label it as the train set.
        print('Generating latent plot on train set')
        plot_latent_for_epoch(epoch_idx + 1)
    scheduler.step(mean_loss)
    # Keep only the checkpoint with the lowest epoch-mean training loss seen so far.
    if mean_loss < best_loss:
        print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
        torch.save(model.state_dict(), ckpt_path)
        best_loss = mean_loss
    else:
        print('No Loss Improvement')