In [1]:
TRIAL_ID = '32a_01'
USE_TBX = False
COEFF_REC = 1
COEFF_KLD = 1
objective = 'H'

In [2]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset,DataLoader
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import os

from utils import plot_3d_mesh, VertDataset, ResizeTo
import utils as ut
%load_ext autoreload
%autoreload 2


In [3]:
#Define paramaters
VOL_SIZE = 32
EPOCHS = 500
LEARNING_RATE = 1e-3
LOG_INTERVAL = 3

N = 1    # = None # for full dataset
BATCH_SIZE = 64
Z_DIMS = 64

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_path = pathlib.Path("./for_VAE/npz")

# beta = torch.tensor(4.0, requires_grad=True).float().to(device)

if USE_TBX:
    from tensorboardX import SummaryWriter
    #SummaryWriter encapsulates everything
    log_dir = pathlib.Path('./logs/%s' % TRIAL_ID)
    writer = SummaryWriter(log_dir)

In [4]:
dataset = VertDataset(train_path,n=N, transform=transforms.Compose([ResizeTo(VOL_SIZE),
                                                               transforms.ToTensor()]))
print('Number of samples: ',len(dataset))
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
print('Number of batches: ',len(train_loader))

Number of samples:  1
Number of batches:  1


In [5]:
# x = dataset[0]
# print(x.size())
# print(x.unique())
# plot_3d_mesh(x)
# del x

In [7]:
# from vae32 import Vae32 as VAE
from vae32 import Vae32_v2 as VAE

# model = VAE(debug=False).to(device)
model = VAE(debug=False, z_dim=Z_DIMS).to(device)
weights_init = ut.init_weights(init_type='normal')
model.apply(weights_init)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [8]:
for name, data in model.named_parameters():
    print(name)

beta
encoder.0.weight
encoder.0.bias
encoder.2.weight
encoder.2.bias
encoder.4.weight
encoder.4.bias
encoder.6.weight
encoder.6.bias
encoder.8.weight
encoder.8.bias
encoder.10.weight
encoder.10.bias
decoder.1.weight
decoder.1.bias
decoder.3.weight
decoder.3.bias
decoder.5.weight
decoder.5.bias
decoder.7.weight
decoder.7.bias
decoder.9.weight
decoder.9.bias
decoder.11.weight
decoder.11.bias


In [9]:
def reconstruction_loss(x, x_recon, distribution='gaussian'):
    batch_size = x.size(0)
    assert batch_size != 0

#     x_recon = F.sigmoid(x_recon)
    recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)

    return recon_loss

def kl_divergence(mu, logvar):
    batch_size = mu.size(0)
    assert batch_size != 0
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld

In [11]:
import time
loss_history = list()
iters=0
for epoch in range(1, EPOCHS + 1):
    start = time.time()
################## TRAIN ########################
    model.train()
    train_loss = 0
    
    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        x = data
        recon_batch, mu, logvar = model(x)

        recon_loss = reconstruction_loss(x, recon_batch)
        total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)    

        loss = recon_loss + model.beta * total_kld
    
        loss.backward()
        loss_history.append( loss.item() )
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {:02d} [{:04d}/{:04d} ({:.0f}%)]\tBeta: {:.4f}, Rec: {:.4f}, TotKLD: {:.4f}, MeanKLD: {:.5f}, Loss: {:.5f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                model.beta.item(), recon_loss.item(), total_kld.item(), mean_kld.item(), loss.item() ))
        iters+=1
    #Finished one epoch        
    print('====> Epoch: {:02d} Average loss: {:.5f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


    end = time.time()
    print('Training time for epoch: {:.2f}s'.format(end-start))
        
    print('======================================================')


size_average and reduce args will be deprecated, please use reduction='sum' instead.



NameError: name 'beta' is not defined

In [None]:
#plot loss
SHOW_FROM = 0
plt.plot(loss_history[SHOW_FROM:])

In [None]:
#Check reconstruction
inbatch_idx = 0
with torch.no_grad():
    x = next(iter(train_loader))
    x = x.to(device)
    recon_x, _, _ = model(x)
    
    plot_3d_mesh(x[inbatch_idx])
    plot_3d_mesh(recon_x[inbatch_idx])

In [None]:
#CHECK GENERATION RESULTS
print('Generating samples')
N_samples = 20
with torch.no_grad():
    z_samples = torch.randn(N_samples, Z_DIMS).to(device)
    samples = model.decode(z_samples)
        
for x in samples:
        plot_3d_mesh(x)
#     plot_3d_mesh(ut.do_threshold(x, eps=0.3))
#     plot_3d_mesh(np.round(x))    