### Load necessary packages

In [None]:
import os

import torch
import torch.optim as optim

from dataset import FontsLoader
from models import AutoEncoder

### Set up the device on which to train

In [None]:
# Train on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

In [None]:
# Hand cached-but-unused memory held by PyTorch's CUDA allocator back to the
# driver (e.g. left over from an earlier run in this kernel). Skipped on CPU.
on_cuda = device.type == 'cuda'
if on_cuda:
    torch.cuda.empty_cache()

### Load the data

In [None]:
# Training data. The training loop reads batch['image'] from each batch and
# calls len() on the loader for progress logging.
get_loader = FontsLoader.get_set_loader
set_loader = get_loader()

### Build the Net

In [None]:
# Auto-encoder hyperparameters, kept in one dict so they are easy to inspect or
# log with the run. in_channels/out_channels of 1 presumably mean single-channel
# (grayscale) glyph images; confirm against the dataset.
autoencoder_kwargs = dict(
    latent_dim=512,
    in_channels=1,
    num_hiddens=256,
    num_res_hiddens=64,
    num_res_layers=4,
    out_channels=1,
)
net = AutoEncoder(**autoencoder_kwargs)
net = net.to(device)

### Train loop

In [None]:
def train(epochs=100, log_every=5):
    """Train the notebook-level ``net`` on ``set_loader``.

    Relies on the globals ``net``, ``set_loader`` and ``device``, plus
    ``optimizer`` and ``lr_scheduler``. The optimizer and scheduler are created
    in a later cell, which must run before this function is called. Each batch
    is moved to ``device``. The optimizer steps once per batch; the
    learning-rate scheduler steps once per epoch, after that epoch's optimizer
    steps.

    Args:
        epochs: Number of full passes over ``set_loader``.
        log_every: Print the current batch loss every ``log_every`` batches
            (must be a positive integer).

    The network is always left in eval mode when this function exits, even
    when training is interrupted (e.g. KeyboardInterrupt in the notebook).
    """
    print('=' * 10 + ' TRAIN ' + '=' * 10, end='\n\n')

    # Loop-invariant: the number of batches per epoch does not change.
    num_batches = len(set_loader)

    net.train()
    try:
        for epoch in range(1, epochs + 1):
            running_loss = 0.0

            for i, batch in enumerate(set_loader, 1):
                images = batch['image'].to(device)

                # Zero grad
                optimizer.zero_grad()

                # Forward: the net returns (z, recon); only the reconstruction
                # is needed for the loss.
                _, recon = net(images)
                loss_value = net.loss_function(images, recon)

                # .item() copies the scalar to the host (a sync on GPU), so
                # read it once per batch and reuse it for both uses below.
                batch_loss = loss_value.item()
                running_loss += batch_loss

                # Backward + update
                loss_value.backward()
                optimizer.step()

                if i % log_every == 0:
                    print(f'==> EPOCH[{epoch}]({i}/{num_batches}): LOSS: {batch_loss}')

            # Decrease LR once per epoch
            lr_scheduler.step()

            # An empty loader would otherwise raise ZeroDivisionError here.
            avg_loss = running_loss / num_batches if num_batches else float('nan')
            print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {avg_loss}')
            print()
    finally:
        # Restore eval mode even if training stops early.
        net.eval()

### Initialize the optimizer and LR scheduler, then train

In [None]:
# Adam with an exponentially decaying learning rate. ExponentialLR multiplies
# the LR by LR_GAMMA on every lr_scheduler.step() call; the training loop calls
# it once per epoch.
LEARNING_RATE = 1e-4
LR_GAMMA = 0.99

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

In [None]:
# Run the full schedule; 100 epochs is also train()'s default.
train(epochs=100)

In [None]:
# Save only the model weights (state_dict). Passing a path lets torch.save open
# and close the file itself. The original passed open(..., 'wb') and never
# closed the handle, relying on garbage collection to flush it.
# NOTE(review): "loss-0.024" in the filename is hard-coded, not taken from this
# run; update (or compute) it before saving a new checkpoint.
CHECKPOINT_PATH = './checkpoints/ae-512-224x224-loss-0.024.pth'
# Create ./checkpoints if needed; otherwise saving raises FileNotFoundError.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
torch.save(net.state_dict(), CHECKPOINT_PATH)