In [1]:
import torch
from torch import nn
from model import VAE
from torch.distributions import Normal
import torch.nn.functional as F

In [2]:
in_channels = 1
out_channels = 1
intermediate_channels = 32
z_dimensions = 16
sigma_decoder = 0.1
q_z = Normal(torch.tensor(0.), torch.tensor(1.))

vae_with_no_kl = VAE(in_channels, intermediate_channels, 1, z_dimensions, pixelcnn=False, kl=0, mmd=0)
vae_with_kl = VAE(in_channels, intermediate_channels, 1, z_dimensions, pixelcnn=False)
vae_with_mmd = VAE(in_channels, intermediate_channels, 1, z_dimensions, pixelcnn=False, kl=0, mmd=1, require_rsample=False)

pixelvae_with_no_kl = VAE(in_channels, intermediate_channels, 256, z_dimensions, pixelcnn=True,kl=0, sigma_decoder=0, require_rsample=False)
pixelvae_with_kl = VAE(in_channels, intermediate_channels, 256, z_dimensions, pixelcnn=True, kl=1, mmd=0, sigma_decoder=0)
pixelvae_with_mmd = VAE(in_channels, intermediate_channels, 256, z_dimensions, pixelcnn=True, kl=0, mmd=1, require_rsample=False, sigma_decoder=0)


In [23]:
input_1 = torch.randn(10,1,64,64)
target_1 = torch.randn(10,1,64,64)
target_255 = torch.randint(0, 256, (10,64,64)).long()

In [24]:
mu, logvar, encoding, reconstruction = vae_with_no_kl(input_1)
vae_with_no_kl.loss(target_1, mu, logvar, encoding, reconstruction)
print(-Normal(reconstruction, sigma_decoder).log_prob(target_1).sum() / target_1.shape[0])

I am going to rsample from mu and logvar
NLL is Normal related
tensor(201467.1875, grad_fn=<DivBackward0>)


In [25]:
mu, logvar, encoding, reconstruction = vae_with_kl(input_1)
q_z_x = Normal(mu, (0.5 * logvar).exp())
print(vae_with_kl.loss(target_1, mu, logvar, encoding, reconstruction))
print((-Normal(reconstruction, sigma_decoder).log_prob(target_1).sum())/target_1.shape[0] +
     (torch.distributions.kl.kl_divergence(q_z_x, q_z).sum())/target_1.shape[0])

I am going to rsample from mu and logvar
Loss includes kl
NLL is Normal related
tensor(201555.0156, grad_fn=<DivBackward0>)
tensor(201555.0156, grad_fn=<ThAddBackward>)


In [26]:
mu, logvar, encoding, reconstruction = vae_with_mmd(input_1)
mmd = vae_with_mmd.compute_mmd(torch.randn(input_1.shape[0], mu.shape[1]), encoding.view(-1, mu.shape[1]))
print(vae_with_mmd.loss(target_1, mu, logvar, encoding, reconstruction))
print((-Normal(reconstruction, sigma_decoder).log_prob(target_1).sum() + mmd)/input_1.shape[0])

Loss includes MMD
NLL is Normal related
tensor(201791.4688, grad_fn=<DivBackward0>)
tensor(201791.4375, grad_fn=<DivBackward0>)


In [47]:
input_1 = torch.randn(10,1,64,64)
target_1 = torch.randn(10,1,64,64)
target_255 = torch.randint(0, 256, (10,64,64)).long()

mu, logvar, encoding, reconstruction = pixelvae_with_no_kl(input_1)
print(pixelvae_with_no_kl.loss(target_255, mu, logvar, encoding, reconstruction) -
F.cross_entropy(reconstruction, target_255, reduction='none').sum() / target_255.shape[0])

NLL is Cross Entropy
tensor(-0.0859, grad_fn=<ThSubBackward>)


In [29]:
mu, logvar, encoding, reconstruction = pixelvae_with_kl(input_1)
q_z_x = Normal(mu, (0.5 * logvar).exp())

print(pixelvae_with_kl.loss(target_255, mu, logvar, encoding, reconstruction))
print((F.cross_entropy(reconstruction, target_255, reduction='sum'))/target_1.shape[0] +
     (torch.distributions.kl.kl_divergence(q_z_x, q_z).sum())/target_1.shape[0])

I am going to rsample from mu and logvar
Loss includes kl
NLL is Cross Entropy
tensor(22833.8613, grad_fn=<DivBackward0>)
tensor(22833.8613, grad_fn=<ThAddBackward>)


In [31]:
mu, logvar, encoding, reconstruction = pixelvae_with_mmd(input_1)
mmd = pixelvae_with_mmd.compute_mmd(torch.randn(input_1.shape[0], mu.shape[1]), encoding.view(-1, mu.shape[1]))

print(pixelvae_with_mmd.loss(target_255, mu, logvar, encoding, reconstruction))
print(F.cross_entropy(reconstruction, target_255, reduction='sum') / target_255.shape[0]) 

Loss includes MMD
NLL is Cross Entropy
tensor(22830.5156, grad_fn=<DivBackward0>)
tensor(22830.2598, grad_fn=<DivBackward0>)
