In [1]:
from architectures.ResConv64 import *
import torch
import numpy as np
from carla_disentanglement.datasets.dsprites import DSpritesDataset
from models.annealed_vae import AnnealedVAE

import torch.nn as nn
import torch.distributions as td

In [2]:
class BetaVAE_custom_conv(nn.Module):
    """VAE wrapper around an externally supplied encoder/decoder pair.

    The encoder is called as ``mu, lv = enc(x)`` and the decoder as ``dec(z)``.
    The decoder output is passed through a sigmoid before it is returned or
    scored, i.e. it is treated as logits.

    Args:
        enc: Module returning a ``(mu, log_variance)`` pair for an input batch.
        dec: Module mapping a batch of latent codes to reconstruction logits.
        latent: Stored as ``self.latent``; not otherwise used by this class.
        channels: Stored as ``self.channels``; not otherwise used by this class.
        MNIST: When true ``self.img_dim`` is 28, otherwise 64.
    """

    def __init__(self, enc, dec, latent=10, channels=1, MNIST=False):
        super().__init__()

        self.latent = latent
        self.channels = channels
        self.img_dim = 28 if MNIST else 64

        self.encoder = enc
        self.decoder = dec

    def BottomUp(self, x):
        """Encode ``x`` and return contiguous ``(mu, lv)`` tensors."""
        mu, lv = self.encoder(x)
        return mu.contiguous(), lv.contiguous()

    def reparameterize(self, mu, lv):
        """Draw a reparameterised sample from ``Normal(mu, exp(0.5 * lv))``."""
        std = lv.mul(0.5).exp()
        z = td.Normal(mu, std).rsample()
        return z.contiguous()

    def TopDown(self, z):
        """Decode latent codes ``z``; returns the raw decoder output."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent code, decode and apply a sigmoid."""
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return torch.sigmoid(out)

    def calc_loss(self, x, beta):
        """Compute the beta-weighted VAE objective for a batch.

        Args:
            x: Input batch; also used as the binary cross-entropy target, so
                its values must lie in ``[0, 1]``.
            beta: Weight applied to the KL term.

        Returns:
            Tuple ``(loss, kl, nll, out)`` where

            * ``out`` is the sigmoid reconstruction,
            * ``nll`` is the *negated* summed binary cross-entropy divided by
              the batch size. Despite its name it is therefore <= 0, and
              callers use ``-nll`` as the reconstruction loss,
            * ``kl`` is the summed closed-form KL to a standard normal, plus a
              constant ``1e-5``, divided by the batch size,
            * ``loss`` is ``-nll + beta * kl``.
        """
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = torch.sigmoid(self.TopDown(z))

        batch_size = x.shape[0]
        nll = -nn.functional.binary_cross_entropy(out, x, reduction='sum') / batch_size
        # The 1e-5 offset is kept on purpose: other code compares against this
        # exact value, so removing it would shift the reported KL by ~1e-5.
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / batch_size

        return (-nll + kl * beta).contiguous(), kl, nll, out

    def LT_fitted_gauss_2std(self, x, num_var=6, num_traversal=5, gif_fps=5, silent=False):
        """Latent traversals over +-2 std dev of a Gaussian fitted to the batch.

        For each of the first ``num_var`` encoded samples and for every latent
        dimension, the dimension is swept over ``[loc - 2*scale, loc + 2*scale]``.
        ``loc`` is the batch mean of ``mu`` for that dimension. ``scale`` is
        the square root of the mean encoder variance plus the variance of
        ``mu`` across the batch. A GIF and a traversal figure are written under
        ``figures/`` relative to the current working directory.

        NOTE(review): ``os``, ``cycle_interval``, ``save_animation`` and
        ``traversal_plotting`` are not defined in this class; presumably they
        come from the star import at the top of the notebook - confirm.

        Args:
            x: Input batch to encode.
            num_var: Number of leading samples of the batch to traverse.
            num_traversal: Requested number of steps per traversal. Even values
                are bumped to the next odd number; odd values are used as given.
            gif_fps: Frames per second forwarded to ``save_animation``.
            silent: Forwarded to ``traversal_plotting``.

        Returns:
            The list of decoded image batches for the last traversed sample
            (the untouched reconstruction followed by one batch per latent
            dimension), or an empty list when there is nothing to traverse.
        """
        mu, lv = self.BottomUp(x)

        # Only even counts need adjusting. The previous one-line conditional
        # added ``num_traversal`` to itself when it was already odd (5 -> 10).
        if num_traversal % 2 == 0:
            num_traversal += 1

        # Bound up front so an empty selection returns [] instead of raising.
        images = []
        for i, batch_mu in enumerate(mu[:num_var]):
            images = [torch.sigmoid(self.TopDown(batch_mu.unsqueeze(0)))]
            for latent_var in range(batch_mu.shape[0]):
                new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
                loc = mu[:, latent_var].mean()
                total_var = lv[:, latent_var].exp().mean() + mu[:, latent_var].var()
                scale = total_var.sqrt()

                # gif
                new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal,
                                                       loc - 2 * scale, loc + 2 * scale)
                filename = os.path.join(os.getcwd(), "figures/mu_gifs/mu%d_var%d.gif" % (i+1,latent_var+1))
                save_animation(torch.sigmoid(self.TopDown(new_mu)), filename, num_traversal, fps=gif_fps)  #gif

                # Plot: overwrite the same column with an evenly spaced sweep.
                new_mu[:, latent_var] = torch.linspace((loc - 2 * scale).item(),
                                                       (loc + 2 * scale).item(),
                                                       steps=num_traversal)
                images.append(torch.sigmoid(self.TopDown(new_mu)))

            img_name = os.path.join(os.getcwd(), "figures/traversals/Traversal%d.pdf" % (i+1))
            traversal_plotting(images, img_name, num_traversals=num_traversal, silent=silent)  # Traversal image
        return images

    def get_latent(self, x):
        """Return only the encoder mean for ``x``."""
        mu, _ = self.BottomUp(x)
        return mu

In [3]:
# Seed the torch (CPU and CUDA) and NumPy RNGs with one shared value.
seed = 2
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_rng(seed)

In [4]:
# Instantiate the dSprites dataset wrapper with its default constructor arguments.
ds = DSpritesDataset()

In [5]:
# Shared constructor arguments for the encoder and decoder built in the next cell.
z_dim = 10  # also passed to BetaVAE_custom_conv as its `latent` argument
num_channels = 1
image_size = 64

In [6]:
# Build ONE encoder/decoder pair and pass the same instances to both wrappers,
# so net1 and net2 run on identical parameters and their outputs can be
# compared directly in the cells below.
enc = GaussianResConv64(z_dim, num_channels, image_size)
dec = ResConv64Decoder(z_dim, num_channels, image_size)
# NOTE(review): gamma / max_c / iterations_c presumably configure the capacity
# annealing schedule; later cells mirror them as gamma=100, C_max=20,
# C_stop_iter=1e5 - confirm against AnnealedVAE.
net1 = AnnealedVAE(enc, dec, gamma=100.0, max_c=20, iterations_c=1e5, reconstruction='bce')
net2 = BetaVAE_custom_conv(enc, dec, z_dim, num_channels)

cuda


In [7]:
# Sanity check: encoding the same input twice with net1 under the same seed
# should give identical results (differences printed below).
x = ds[0][0].unsqueeze(0)  # first field of the first dataset item, with a leading dim added
net1.model.cpu()

seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
encodings = []
for attempt in range(2):
    for seed_rng in seeders:
        seed_rng(seed)
    encodings.append(net1.model.encode(x))
z1, z2 = encodings

# Element 0 and element 1 of each encoding are compared separately.
for idx in (0, 1):
    print(z1[idx] - z2[idx])

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)


In [8]:
# Same repeatability check as for net1, now through net2.BottomUp.
net2.cpu()

seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
encodings = []
for attempt in range(2):
    for seed_rng in seeders:
        seed_rng(seed)
    encodings.append(net2.BottomUp(x))
z1, z2 = encodings

for idx in (0, 1):
    print(z1[idx] - z2[idx])

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)


In [9]:
# Cross-check: net1's encoder output against net2's for the same input and seed.
seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)

for seed_rng in seeders:
    seed_rng(seed)
z1 = net1.model.encode(x)

for seed_rng in seeders:
    seed_rng(seed)
z2 = net2.BottomUp(x)

# One difference tensor per line, in the same order as before.
print(z1[0] - z2[0], z1[1] - z2[1], sep="\n")

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)


In [10]:
# Cross-check the sampled latent codes: with the RNGs reset before each model,
# both reparameterisation steps should draw the same noise.
seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)

for seed_rng in seeders:
    seed_rng(seed)
z1 = net1.model.encode(x)
z1_ = net1.model.reparametrize(*z1[:2])

for seed_rng in seeders:
    seed_rng(seed)
z2 = net2.BottomUp(x)
z2_ = net2.reparameterize(*z2[:2])

print(z1_ - z2_)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)


In [11]:
# Cross-check the full forward pass of both wrappers under the same seed.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
z1 = net1.model.forward(x)


torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
z2 = net2.forward(x)

# net1's forward returns a tuple whose first element gets a sigmoid applied
# here, while BetaVAE_custom_conv.forward already returns a single
# sigmoid-activated tensor. Indexing z2[0] would therefore pick the first
# *sample* of the batch and only line up through broadcasting (and would
# silently compare every sample against sample 0 for larger batches), so the
# whole tensor is used instead.
recon_diff = torch.sigmoid(z1[0]) - z2
print(recon_diff)
print(recon_diff.sum())


tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]], grad_fn=<SubBackward0>)
tensor(0., grad_fn=<SumBackward0>)


In [12]:
# Rebuild the annealed-capacity loss by hand for both wrappers and print the
# differences term by term.

# --- net1: forward pass plus manually computed regulariser -----------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
y1, (mu1, log_var1), _ = net1.model.forward(x)
recon_loss1 = net1.reconstruction_loss(y1, x).div(1)
# reg_loss1 = net1.beta * net1.gauss_vae_regulariser(mu1, log_var1)
# Capacity grows linearly with net1.train_iter_f and is capped at net1.max_c.
capacity = min(net1.max_c, net1.max_c * net1.train_iter_f / net1.iterations_c)
# Closed-form KL to a standard normal: summed over dim 1, averaged over dim 0.
kld1 = ((torch.exp(log_var1) + mu1*mu1 - log_var1 - 1)/2).sum(1).mean(0)
kld_c1 = (kld1 - capacity).abs()
reg_loss1 = net1.beta *kld_c1
loss1 = recon_loss1 + reg_loss1


# --- net2: same objective assembled from calc_loss outputs -----------------
C_max = torch.Tensor([20])
C_stop_iter = 1e5
global_iter = 0
gamma = 100
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
# loss2_ (and therefore the beta passed here) is not used below; only kl, nll
# and y2 feed into the comparison.
loss2_, kl, nll, y2 = net2.calc_loss(x, beta=99)
# Linear capacity schedule clamped to [0, C_max].
C = torch.clamp(C_max/C_stop_iter*global_iter, 0, C_max.item())
# calc_loss returns nll as the negated BCE, hence the sign flip.
loss2 = -nll + gamma*(kl-C).abs()

print((torch.sigmoid(y1) - y2).sum())
# recon_loss1 is compared against -nll, so the two are added.
print(recon_loss1 + nll)
print(reg_loss1 - gamma*(kl-C).abs())
print(loss1 - loss2)

print(net1.beta - gamma)
# calc_loss adds 1e-5 to the summed KL before dividing by the batch size, so a
# residual of about -1e-5 is expected here for a batch of one.
print(kld1 - kl)
print(capacity - C)



tensor(0., grad_fn=<SumBackward0>)
tensor(0.0005, grad_fn=<AddBackward0>)
tensor([-0.0012], grad_fn=<SubBackward0>)
tensor([-0.0010], grad_fn=<SubBackward0>)
0.0
tensor(-1.1444e-05, grad_fn=<SubBackward0>)
tensor([0.])


In [13]:
# Isolate the KL discrepancy: compare the encoder outputs, then evaluate both
# KL formulas, with and without the 1e-5 offset used in calc_loss.
seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)

for seed_rng in seeders:
    seed_rng(seed)
z1 = net1.model.encode(x)

for seed_rng in seeders:
    seed_rng(seed)
z2 = net2.BottomUp(x)

for idx in (0, 1):
    print(z1[idx] - z2[idx])

# Formula written out as in the net1 comparison cell.
mu1, log_var1 = z1
kl1 = ((torch.exp(log_var1) + mu1*mu1 - log_var1 - 1)/2).sum(1).mean(0)

# Formula as written in BetaVAE_custom_conv.calc_loss.
mu2, lv2 = z2
kl_sum2 = -0.5 * torch.sum(1 + lv2 - mu2.pow(2) - lv2.exp())
kl2 = (kl_sum2 + 1e-5) / x.shape[0]  # with the offset
kl2_ = kl_sum2 / x.shape[0]          # without the offset

print(kl1 - kl2)
print(kl1 - kl2_)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)
tensor(-1.1444e-05, grad_fn=<SubBackward0>)
tensor(0., grad_fn=<SubBackward0>)


In [16]:
# Repeat the loss comparison at iteration 100, this time using net1's own
# regulariser instead of the hand-written one.

# --- net1 ------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
y1, (mu1, log_var1), _ = net1.model.forward(x)
recon_loss1 = net1.reconstruction_loss(y1, x).div(1)
# NOTE(review): the earlier manual computation reads net1.train_iter_f for the
# capacity schedule, whereas this cell sets net1.global_iter - confirm which
# attribute gauss_vae_regulariser actually uses.
net1.global_iter = 100
reg_loss1 = net1.beta * net1.gauss_vae_regulariser(mu1, log_var1)
loss1 = recon_loss1 + reg_loss1


# --- net2 ------------------------------------------------------------------
C_max = torch.Tensor([20])
C_stop_iter = 1e5
global_iter = 100
gamma = 100
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
# loss2_ (and therefore the beta passed here) is not used below.
loss2_, kl, nll, y2 = net2.calc_loss(x, beta=99)
# Linear capacity schedule clamped to [0, C_max].
C = torch.clamp(C_max/C_stop_iter*global_iter, 0, C_max.item())
# calc_loss returns nll as the negated BCE, hence the sign flip.
loss2 = -nll + gamma*(kl-C).abs()

print((torch.sigmoid(y1) - y2).sum())
# recon_loss1 is compared against -nll, so the two are added.
print(recon_loss1 + nll)
print(reg_loss1 - gamma*(kl-C).abs())
print(loss1 - loss2)

tensor(0., grad_fn=<SumBackward0>)
tensor(0.0005, grad_fn=<AddBackward0>)
tensor([-0.0012], grad_fn=<SubBackward0>)
tensor([-0.0010], grad_fn=<SubBackward0>)
