### Load necessary packages

In [1]:
import torch
import torch.optim as optim

from dataset import FontsLoader
from models import AutoEncoder

### Setup the device on which to train

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [3]:
if device.type == 'cuda':
    torch.cuda.empty_cache()

### Load the data

In [4]:
set_loader = FontsLoader.get_set_loader()

### Build the Net

In [5]:
net = AutoEncoder(
    latent_dim=512,
    in_channels=1,
    num_hiddens=256,
    num_res_hiddens=64,
    num_res_layers=4,
    out_channels=1
).to(device)

## Train loop

In [6]:
def train(epochs=100):
    print('='*10, end='')
    print(' TRAIN', end=' ') 
    print('='*10, end='\n\n')
    net.train()

    for epoch in range(1, epochs+1):
        running_loss = 0

        for i, batch in enumerate(set_loader, 1):
            images = batch['image']
            images = images.to(device)
            
            # Zero grad
            optimizer.zero_grad()
            
            # Forward
            z, recon = net(images)
            # Compute Loss
            loss_value = net.loss_function(images, recon)
            running_loss += loss_value.item()
            # Backward
            loss_value.backward()
            # Update
            optimizer.step()

            if i % 5 == 0:
                print(f'==> EPOCH[{epoch}]({i}/{len(set_loader)}): LOSS: {loss_value.item()}')
        
        # Decrease LR
        lr_scheduler.step()
        
        print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {running_loss/len(set_loader)}')
        print()
        
    net.eval()

### Init optimizer and train

In [7]:
optimizer = optim.Adam(net.parameters(), lr=1e-3)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [8]:
train()


==> EPOCH[1](5/59): LOSS: 1.2121385335922241
==> EPOCH[1](10/59): LOSS: 0.4766406714916229
==> EPOCH[1](15/59): LOSS: 0.3405570685863495
==> EPOCH[1](20/59): LOSS: 0.263521671295166
==> EPOCH[1](25/59): LOSS: 0.20399627089500427
==> EPOCH[1](30/59): LOSS: 0.1770455539226532
==> EPOCH[1](35/59): LOSS: 0.1739151030778885
==> EPOCH[1](40/59): LOSS: 0.15844199061393738
==> EPOCH[1](45/59): LOSS: 0.1453578770160675
==> EPOCH[1](50/59): LOSS: 0.14417175948619843
==> EPOCH[1](55/59): LOSS: 0.13143311440944672
=====> EPOCH[1] Completed: Avg. LOSS: 0.42104245318194566

==> EPOCH[2](5/59): LOSS: 0.08923208713531494
==> EPOCH[2](10/59): LOSS: 0.07781997323036194
==> EPOCH[2](15/59): LOSS: 0.07281296700239182
==> EPOCH[2](20/59): LOSS: 0.06618686765432358
==> EPOCH[2](25/59): LOSS: 0.06178554147481918
==> EPOCH[2](30/59): LOSS: 0.05492504686117172
==> EPOCH[2](35/59): LOSS: 0.06130988150835037
==> EPOCH[2](40/59): LOSS: 0.052529070526361465
==> EPOCH[2](45/59): LOSS: 0.04821160063147545
==> EPOCH

KeyboardInterrupt: 

In [9]:
torch.save(net.state_dict(), open('./checkpoints/ae-512-224x224-loss-0.028.pth', 'wb'))