In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
import torchvision
from torch import optim
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from pathlib import Path
from torch.utils.data import DataLoader
from arc_vae.data_loader import Tasks,torchify, Grids, Grids_n
from arc_vae.vae.models import vae
from arc_vae.vae.models import losses
import torch.nn.functional as F
from arc_vae.vae import training
from arc_vae.utils import arc_to_image, visualize_grids

In [49]:
def get_dataloader(batch_size, eval=False):
    
    tasks = Grids_n(eval=eval)
    loader = DataLoader(tasks, batch_size=batch_size, drop_last=True)
    
    return loader

In [50]:
class VAE(nn.Module):
    def __init__(self,imgChannels=11, zDim=64):
        super(VAE, self).__init__()

        self.batch_size = 1
        self.encConv1 = nn.Conv2d(imgChannels, 50, 3)
        self.encConv2 = nn.Conv2d(50, 100, 3)
        self.encFC1 = nn.Linear(128*2*2, zDim)
        self.encFC2 = nn.Linear(128*2*2, zDim)

        self.decFC1 = nn.Linear(zDim, 128*2*2)
        self.decConv1 = nn.ConvTranspose2d(36, 16, 1)
        self.decConv2 = nn.ConvTranspose2d(16, 11, 1)

    def encoder(self, x):
        
        x = F.relu(self.encConv1(x))
        x = F.relu(self.encConv2(x))
        
        x = x.view(-1, 128*2*2)
        
        mu = self.encFC1(x)
        logVar = self.encFC2(x)
        
        return mu, logVar

    def reparameterize(self, mu, logVar):

        
        std = torch.exp(logVar/2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decoder(self, z):
        
        
        x = F.relu(self.decFC1(z))
        batch_size = x.shape[1]
        x = x.view(batch_size,-1,10,10)
        
        x = F.relu(self.decConv1(x))
        x = torch.sigmoid(self.decConv2(x))
       
        return x

    def forward(self, x):

        mu, logVar = self.encoder(x)
        z = self.reparameterize(mu, logVar)
        
        out = self.decoder(z)
        return out, mu, logVar

In [51]:
def get_model():
    
    model = VAE()
    model.cuda()
    return model

def weight_reset(m):
    
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m,nn.ConvTranspose2d):
        m.reset_parameters()

def get_optimizer(learning_rate, model):
    
    return optim.Adam(model.parameters(), lr=learning_rate)

def train(model, data_loader, optimizer, epochs):
    
    for epoch in range(epochs):
        
        for data in data_loader:
            inputs = data
            
            inputs = inputs.cuda()
            
            out,mu,logVar = model(inputs)
            kl_divergence = 0.5 * torch.sum(-1 - logVar + mu.pow(2) + logVar.exp())
            loss = F.binary_cross_entropy(out, inputs) + kl_divergence
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f'epoch: {epoch}, loss: {loss}')


In [52]:

def main():
    batch_size = 512
    epochs = 50
    train_loader = get_dataloader(batch_size=batch_size)
    model = get_model()
    optimizer = get_optimizer(1e-3, model)
    
    model.apply(weight_reset)
    train(model, train_loader, optimizer, epochs)

In [53]:
if __name__ == "__main__":
    main()

epoch: 0, loss: 0.224081888794899
epoch: 1, loss: 0.22172623872756958
epoch: 2, loss: 0.2204207181930542
epoch: 3, loss: 0.22131764888763428
epoch: 4, loss: 0.21913108229637146
epoch: 5, loss: 0.21926121413707733
epoch: 6, loss: 0.22023163735866547
epoch: 7, loss: 0.22008129954338074
epoch: 8, loss: 0.22105348110198975
epoch: 9, loss: 0.22156380116939545
epoch: 10, loss: 0.2214413583278656
epoch: 11, loss: 0.21934586763381958
epoch: 12, loss: 0.21985502541065216
epoch: 13, loss: 0.2199607789516449
epoch: 14, loss: 0.21956989169120789
epoch: 15, loss: 0.22002233564853668
epoch: 16, loss: 0.2202400118112564
epoch: 17, loss: 0.2203466296195984
epoch: 18, loss: 0.21881918609142303
epoch: 19, loss: 0.22034616768360138
epoch: 20, loss: 0.22035172581672668
epoch: 21, loss: 0.2206840217113495
epoch: 22, loss: 0.22027507424354553
epoch: 23, loss: 0.21947236359119415
epoch: 24, loss: 0.2196587324142456
epoch: 25, loss: 0.21997959911823273
epoch: 26, loss: 0.21914519369602203
epoch: 27, loss: 0.2