### Load necessary packages

In [1]:
import torch
import torch.optim as optim

from dataset import FontsLoader
from models import VAE

import matplotlib.pyplot as plt
%matplotlib inline

### Setup the device on which to train

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cpu


### Load the data

In [3]:
set_loader = FontsLoader.get_set_loader(num_workers=12)

### Build the Net

In [4]:
net = VAE(
    latent_dim=64,
    beta=5,
    in_channels=1,
    num_hiddens=128,
    num_res_hiddens=32,
    num_res_layers=2,
    out_channels=1,
).to(device)

## Train loop

In [5]:
def train(epochs=10):
    print('='*10, end='')
    print(' TRAIN', end=' ') 
    print('='*10, end='\n\n')
    net.train()

    for epoch in range(1, epochs+1):
        running_loss = 0

        for i, batch in enumerate(set_loader, 1):
            images = batch['image']
            images = images.to(device)
            
            # Zero grad
            optimizer.zero_grad()
            
            # Forward
            mu, logvar, z, recon = net(images)
            # Compute Loss
            loss_value = net.loss_function(images, recon, mu, logvar)
            running_loss += loss_value.item()
            # Backward
            loss_value.backward()
            # Update
            optimizer.step()

            if i % 100 == 0:
                print(f'==> EPOCH[{epoch}]({i}/{len(set_loader)}): LOSS: {loss_value.item()}')
            
        print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {running_loss/len(set_loader)}')
        print()
        
    net.eval()

### Init optimizer and train

In [6]:
optimizer = optim.Adam(net.parameters(), lr=1e-2)

In [7]:
train()


=====> EPOCH[1] Completed: Avg. LOSS: 2.655204953819554e+21



KeyboardInterrupt: 