### Load necessary packages

In [1]:
import torch
import torch.optim as optim

from dataset import FontsLoader
from models import VAE

### Set up the device on which to train

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

Device: cuda


In [3]:
# Only relevant on a CUDA device: ask PyTorch to release the unused GPU memory
# it is still caching (e.g. left over from earlier runs in this kernel).
if device.type == 'cuda':
    torch.cuda.empty_cache()

### Load the data

In [4]:
# Obtain the data loader from the project's FontsLoader helper (imported from `dataset`).
# The training loop iterates over it and reads batch['image'] from every item.
# NOTE(review): batch size, shuffling and dataset split are decided inside
# FontsLoader and are not visible here — confirm against dataset.py.
set_loader = FontsLoader.get_set_loader()

### Build the Net

In [5]:
# Hyper-parameters passed to the project's VAE constructor, kept in one place
# so they are easy to inspect and tweak.
# NOTE(review): the exact meaning of each argument (e.g. how `beta` weights the
# loss) is defined in models.VAE — confirm there.
vae_kwargs = dict(
    latent_dim=128,
    beta=5,
    in_channels=1,
    num_hiddens=256,
    num_res_hiddens=64,
    num_res_layers=4,
    out_channels=1,
)

# Build the model, then move it to the selected device.
net = VAE(**vae_kwargs)
net = net.to(device)

## Train loop

In [6]:
def train(epochs=100):
    """Train the notebook-level ``net`` on ``set_loader`` for ``epochs`` epochs.

    The function relies on these notebook globals, which must exist when it is
    called: ``net``, ``set_loader``, ``device``, ``optimizer`` and
    ``lr_scheduler``.

    Per batch it zeroes the gradients, runs the forward pass, computes the loss
    with ``net.loss_function``, back-propagates and steps the optimizer. The
    current batch loss is printed every 5 batches; after each epoch the LR
    scheduler is stepped and the average batch loss is printed.

    Args:
        epochs: Number of full passes over ``set_loader``.

    Raises:
        ValueError: If ``set_loader`` contains no batches (the per-epoch
            average would otherwise fail with a division by zero).

    The model is always switched back to eval mode on exit, including when
    training is stopped early by an exception or a kernel interrupt.
    """
    # The loader length is constant for the whole run; compute it once.
    num_batches = len(set_loader)
    if num_batches == 0:
        raise ValueError('set_loader is empty: there are no batches to train on')

    print('='*10, end='')
    print(' TRAIN', end=' ')
    print('='*10, end='\n\n')
    net.train()

    try:
        for epoch in range(1, epochs+1):
            running_loss = 0.0

            for i, batch in enumerate(set_loader, 1):
                images = batch['image'].to(device)

                # Zero grad
                optimizer.zero_grad()

                # Forward (the sampled latent is returned but not needed here)
                mu, logvar, _z, recon = net(images)
                # Compute Loss
                loss_value = net.loss_function(images, recon, mu, logvar)
                # Read the scalar once and reuse it for the sum and the log line
                batch_loss = loss_value.item()
                running_loss += batch_loss
                # Backward
                loss_value.backward()
                # Update
                optimizer.step()

                if i % 5 == 0:
                    print(f'==> EPOCH[{epoch}]({i}/{num_batches}): LOSS: {batch_loss}')

            # Decrease LR
            lr_scheduler.step()

            print(f'=====> EPOCH[{epoch}] Completed: Avg. LOSS: {running_loss/num_batches}')
            print()
    finally:
        # Runs on normal completion and on interruption alike, so the model is
        # never left in train mode for later inference cells.
        net.eval()

### Initialize the optimizer and train

In [13]:
# Optimizer settings, named so the values are easy to find and change.
# NOTE(review): with this learning rate the logged average loss only moves
# from ~0.3816 (epoch 1) to ~0.3802 (epoch 15) — confirm that such a small
# step size is intended (e.g. a late fine-tuning pass).
INITIAL_LR = 1e-7
LR_DECAY_GAMMA = 0.9

# Adam over all model parameters; the scheduler multiplies the learning rate
# by LR_DECAY_GAMMA each time lr_scheduler.step() is called (once per epoch
# in train()).
optimizer = optim.Adam(net.parameters(), lr=INITIAL_LR)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_DECAY_GAMMA)

In [14]:
# Start training with the default epochs=100. The captured output below stops
# during epoch 16 because the run was interrupted manually (KeyboardInterrupt).
train()


==> EPOCH[1](5/59): LOSS: 0.3812302350997925
==> EPOCH[1](10/59): LOSS: 0.37514233589172363
==> EPOCH[1](15/59): LOSS: 0.3844597041606903
==> EPOCH[1](20/59): LOSS: 0.37093761563301086
==> EPOCH[1](25/59): LOSS: 0.3812357783317566
==> EPOCH[1](30/59): LOSS: 0.3675537407398224
==> EPOCH[1](35/59): LOSS: 0.38718488812446594
==> EPOCH[1](40/59): LOSS: 0.4050031006336212
==> EPOCH[1](45/59): LOSS: 0.37270650267601013
==> EPOCH[1](50/59): LOSS: 0.37772539258003235
==> EPOCH[1](55/59): LOSS: 0.3626149594783783
=====> EPOCH[1] Completed: Avg. LOSS: 0.3816200626098504

==> EPOCH[2](5/59): LOSS: 0.3702203035354614
==> EPOCH[2](10/59): LOSS: 0.3742978870868683
==> EPOCH[2](15/59): LOSS: 0.39860814809799194
==> EPOCH[2](20/59): LOSS: 0.3905717730522156
==> EPOCH[2](25/59): LOSS: 0.3924049735069275
==> EPOCH[2](30/59): LOSS: 0.37736549973487854
==> EPOCH[2](35/59): LOSS: 0.38443100452423096
==> EPOCH[2](40/59): LOSS: 0.389833003282547
==> EPOCH[2](45/59): LOSS: 0.37701845169067383
==> EPOCH[2](50

==> EPOCH[15](25/59): LOSS: 0.3816039562225342
==> EPOCH[15](30/59): LOSS: 0.3754428029060364
==> EPOCH[15](35/59): LOSS: 0.36900588870048523
==> EPOCH[15](40/59): LOSS: 0.3932681381702423
==> EPOCH[15](45/59): LOSS: 0.37762799859046936
==> EPOCH[15](50/59): LOSS: 0.3705005645751953
==> EPOCH[15](55/59): LOSS: 0.3644530475139618
=====> EPOCH[15] Completed: Avg. LOSS: 0.38020595807140156

==> EPOCH[16](5/59): LOSS: 0.3858593702316284
==> EPOCH[16](10/59): LOSS: 0.36046406626701355
==> EPOCH[16](15/59): LOSS: 0.3827281594276428
==> EPOCH[16](20/59): LOSS: 0.3577369153499603
==> EPOCH[16](25/59): LOSS: 0.3624385595321655
==> EPOCH[16](30/59): LOSS: 0.3869953453540802
==> EPOCH[16](35/59): LOSS: 0.3887266516685486
==> EPOCH[16](40/59): LOSS: 0.40283799171447754
==> EPOCH[16](45/59): LOSS: 0.3701143264770508
==> EPOCH[16](50/59): LOSS: 0.39139115810394287


KeyboardInterrupt: 

In [15]:
# Persist only the model weights (state_dict). The path is passed directly so
# that torch.save opens and closes the file itself; handing it a bare
# open(..., 'wb') left the file handle unclosed.
# NOTE(review): assumes the ./checkpoints directory already exists.
torch.save(net.state_dict(), './checkpoints/vae_loss_038.pth')