In [93]:
import os

import torch
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch import nn

In [94]:
def reparametrize(mu, logvar):
    """Draw a latent sample with the reparameterization trick.

    Computes ``mu + eps * std`` where ``std = exp(0.5 * logvar)`` and ``eps``
    is standard-normal noise with the same shape, dtype and device as ``std``.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances; must broadcast against ``mu``.

    Returns:
        Tensor holding the sampled values. Gradients flow to ``mu`` and
        ``logvar``; the noise itself carries no gradient.
    """
    std = torch.exp(0.5 * logvar)
    # randn_like replaces the legacy ``std.data.new(std.size()).normal_()``
    # idiom: it allocates the noise on std's device/dtype without going
    # through the deprecated ``.data`` attribute.
    eps = torch.randn_like(std)
    return mu + eps * std

class VAEModel(nn.Module):
    """Variational autoencoder assembled from an encoder and a decoder.

    The encoder is expected to return a ``(mu, logvar)`` pair, because
    ``forward`` unpacks its result into exactly those two values.
    """

    def __init__(self, encoder, decoder):
        """Store the two sub-networks as ``self.encoder`` and ``self.decoder``."""
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def encode(self, x, **kwargs):
        """Run the encoder on ``x``; extra keyword arguments are ignored."""
        encoded = self.encoder(x)
        return encoded

    def decode(self, z, **kwargs):
        """Decode ``z`` and squash the decoder output with a sigmoid."""
        logits = self.decoder(z)
        return torch.sigmoid(logits)

    def forward(self, x, **kwargs):
        """Encode ``x``, sample a latent and return raw decoder output.

        Returns:
            ``(decoder_output, (mu, logvar))``.
        """
        mu, logvar = self.encode(x)
        latent = reparametrize(mu, logvar)
        # NOTE: self.decoder (not self.decode) is called on purpose, so the
        # output has no sigmoid applied and can be fed to BCEWithLogitsLoss.
        reconstruction_logits = self.decoder(latent)
        return reconstruction_logits, (mu, logvar)

class Flatten3D(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example the result of a transpose) are accepted as well; for
        contiguous inputs the result is the same view as before.
        """
        return x.reshape(x.size(0), -1)

class Unsqueeze3D(nn.Module):
    """Append two trailing singleton dimensions to the input tensor."""

    def forward(self, x):
        """Return ``x`` with two size-1 dimensions added at the end."""
        return x.unsqueeze(-1).unsqueeze(-1)

def _init_layer(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
    if isinstance(m, torch.nn.Linear):
        init.kaiming_normal_(m.weight.data)


def init_layers(modules):
    """Apply ``_init_layer`` to every entry of a name -> module mapping.

    An entry that is iterable (for example a container of layers) has each of
    its direct items initialised; any other entry is initialised itself.
    Only one level of nesting is visited.
    """
    from collections.abc import Iterable

    for name in modules:
        entry = modules[name]
        if not isinstance(entry, Iterable):
            _init_layer(entry)
            continue
        for layer in entry:
            _init_layer(layer)

In [95]:
class BaseImageEncoder(nn.Module):
    """Base class that records the basic dimensions of an image network.

    Subclasses must provide ``forward``; this class only stores the three
    constructor arguments and exposes them through accessor methods.
    """

    def __init__(self, latent_dim, num_channels, image_size):
        """Remember the latent size, channel count and image size."""
        super().__init__()

        self._latent_dim, self._num_channels, self._image_size = (
            latent_dim,
            num_channels,
            image_size,
        )

    def forward(self, *input):
        """Not implemented here; subclasses override this."""
        raise NotImplementedError

    def latent_dim(self):
        """Return the stored latent dimension."""
        return self._latent_dim

    def num_channels(self):
        """Return the stored number of image channels."""
        return self._num_channels

    def image_size(self):
        """Return the stored image size."""
        return self._image_size

class SimpleConv56(BaseImageEncoder):
    """Convolutional encoder for 56x56 images.

    Six convolutions reduce a ``num_channels x 56 x 56`` input to
    ``256 x 1 x 1``; the result is flattened and projected to ``latent_dim``
    values by a final linear layer.

    Args:
        latent_dim: Number of output features of the final linear layer.
        num_channels: Number of channels of the input images.
        image_size: Side length of the input images; must be 56.

    Raises:
        AssertionError: If ``image_size`` is not 56.
    """

    def __init__(self, latent_dim, num_channels, image_size):
        super().__init__(latent_dim, num_channels, image_size)
        # The layer stack below only reaches 1x1 (as Linear(256, ...) needs)
        # when it starts from a 56x56 input.
        assert image_size == 56, 'This model only works with image size 56x56.'

        # Trailing comments give the spatial size after each convolution.
        self.main = nn.Sequential(
            nn.Conv2d(num_channels, 32, 4, 2, 1), #28x28
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1), #14x14
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), #7x7
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 1, 0), #4x4
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1), #2x2
            nn.ReLU(True),
            nn.Conv2d(256, 256, 4, 2, 1), #1x1
            nn.ReLU(True),
            Flatten3D(),
            nn.Linear(256, latent_dim, bias=True)
        )

        # Re-initialise the weights of the layers registered above.
        init_layers(self._modules)

    def forward(self, x):
        """Run the layer stack on ``x`` and return the ``latent_dim`` features."""
        return self.main(x)


class SimpleGaussianConv56(SimpleConv56):
    """``SimpleConv56`` variant whose output is split into ``(mu, logvar)``.

    The parent network is built with ``latent_dim * 2`` output features; the
    first ``latent_dim`` columns are returned as ``mu`` and the remaining
    columns as ``logvar``.
    """

    def __init__(self, latent_dim, num_channels, image_size):
        # The parent sizes its final layer from the value passed here, so it
        # is asked for twice the latent width.
        super().__init__(latent_dim * 2, num_channels, image_size)

        # The parent stored the doubled value; put the real latent size back.
        self._latent_dim = latent_dim

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch ``x``."""
        split_at = self._latent_dim
        stats = self.main(x)
        return stats[:, :split_at], stats[:, split_at:]


In [96]:
class SimpleConv56Decoder(BaseImageEncoder):
    """Transposed-convolution decoder producing 56x56 images.

    A ``(batch, latent_dim)`` input gets two trailing singleton dimensions
    and is then upsampled step by step to ``num_channels x 56 x 56``. No
    activation is applied after the last layer.

    Args:
        latent_dim: Size of the latent vector fed to the decoder.
        num_channels: Number of channels of the generated images.
        image_size: Side length of the generated images; must be 56.

    Raises:
        AssertionError: If ``image_size`` is not 56.
    """

    def __init__(self, latent_dim, num_channels, image_size):
        super().__init__(latent_dim, num_channels, image_size)
        # The layer stack below always ends at 56x56.
        assert image_size == 56, 'This model only works with image size 56x56.'

        # Trailing comments give the spatial size after each layer.
        self.main = nn.Sequential(
            Unsqueeze3D(),
            nn.Conv2d(latent_dim, 256, 1, 2), #1x1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 256, 4, 2, 1), #2x2
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2), #6x6
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 128, 4, 2, 1), #12x12
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2), #26x26
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2), #54x54
            nn.ReLU(True), 
            nn.ConvTranspose2d(64, num_channels, 3, 1) # 56x56
        )
        # output shape = bs x num_channels x 56 x 56

        # Re-initialise the weights of the layers registered above.
        init_layers(self._modules)

    def forward(self, x):
        """Run the layer stack on the latent batch ``x``."""
        return self.main(x)


In [97]:
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Normalisation transform with scalar mean/std. It is constructed here but
# deliberately left out of the pipeline below.
normalize = transforms.Normalize(mean=0.4897, std=0.1285)

# Only the conversion to a tensor is applied to each image.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalize, Disabled for now because of BCELoss
    ]
)

# Images are read from the ./data folder.
ds = ImageFolder(root="./data", transform=transform)

In [98]:
# Dimensions shared by the model and the visualisation cells.
z_dim = 25          # size of the latent vector
num_channels = 3    # channels per image
image_size = 56     # image side length

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [99]:
# Build the VAE from the Gaussian encoder and the matching decoder, then move
# it to the selected device.
model = VAEModel(
    SimpleGaussianConv56(z_dim, num_channels, image_size),
    SimpleConv56Decoder(z_dim, num_channels, image_size),
).to(device)

In [100]:
save_path="./checkpoints/"
# filename="vae_2021-02-04_11-09-53_[488194.2].save"
filename="vae_2021-01-28_14-09-52_26.959888.save"
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved on a GPU can also be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(save_path + filename, map_location=device))
# Switch to evaluation mode; as the last expression this also displays the
# model structure.
model.eval()

VAEModel(
  (encoder): SimpleGaussianConv56(
    (main): Sequential(
      (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Flatten3D()
      (13): Linear(in_features=256, out_features=50, bias=True)
    )
  )
  (decoder): SimpleConv56Decoder(
    (main): Sequential(
      (0): Unsqueeze3D()
      (1): Conv2d(25, 256, kernel_size=(1, 1), stride=(2, 2))
      (2): ReLU(inplace=True)
      (3): ConvTranspose2d

In [101]:
import random
# Latent traversal settings: the sweep runs from -limit to +limit in steps of
# `inter`.
limit = 3
inter = 2 / 3

In [102]:
# Values assigned to a latent unit during the traversal. The end of the range
# is exclusive, so 0.1 is added to keep `limit` itself in the sequence.
interpolation = torch.arange(start=-limit, end=limit + 0.1, step=inter)

In [103]:
# Index of the dataset image to encode. It is fixed to the first image; the
# random choice below is disabled.
# NOTE(review): n_dsets is not defined in the cells shown here.
#rand_idx = random.randint(1, n_dsets-1)
rand_idx = 0

In [104]:
# Take the image part of the selected dataset item, add a batch dimension and
# move it to the device selected earlier (instead of assuming CUDA).
random_img = ds[rand_idx][0].unsqueeze(0).to(device)
# The encoder returns (mu, logvar); keep only mu as the latent code.
random_img_z = model.encode(random_img)[0]

In [105]:
# Show the latent mean vector obtained for the selected image.
random_img_z

tensor([[-0.2995, -0.0123, -0.0139,  0.0058, -0.0096,  0.0086, -0.0059,  0.0032,
          0.0046, -0.0117, -0.0123,  0.0065,  0.0025, -0.0093, -0.0074,  0.0054,
          0.8741, -0.0147, -0.0034,  0.0088,  0.0015,  0.0035, -0.0096,  0.0057,
          0.0125]], device='cuda:0', grad_fn=<SliceBackward>)

In [106]:
# Uniform [0, 1) latent vector placed on the selected device (instead of
# assuming CUDA).
# NOTE(review): this value is not used by the cells shown here.
random_z = torch.rand(1, z_dim).to(device)

In [107]:
# Start with an empty list.
# NOTE(review): a later cell rebinds `gifs` to a tensor built from `samples`.
gifs = []

In [108]:
# Latent traversal: for every latent unit, start again from the encoded image
# and overwrite only that unit with each interpolation value in turn.
samples = []
z_ori = random_img_z
# Decoding is inference only, so no autograd graph is recorded.
with torch.no_grad():
    for z_index in range(z_dim):
        z = z_ori.clone()
        for val in interpolation:
            z[:, z_index] = val
            samples.append(model.decode(z))

# One row per (latent unit, interpolation value) pair, in that nesting order.
samples = torch.cat(samples).cpu()

In [109]:
# Regroup the flat batch as (latent unit, step, C, H, W), then swap the first
# two axes so that gifs[j] holds the images of every latent unit at step j.
# The image dimensions come from the shared settings rather than literals.
gifs = samples.view(
    z_dim, len(interpolation), num_channels, image_size, image_size
).transpose(0, 1)

In [110]:
from torchvision.utils import save_image
import subprocess

def grid2gif(image_str, output_gif, delay=100):
    """Build a looping GIF from image files via the ``magick convert`` command.

    code from:
        https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python/34555939#34555939

    Args:
        image_str: Input file specification, inserted into the command as is.
        output_gif: Path of the GIF to write.
        delay: Value passed to the ``-delay`` option.

    The command is a single string run with ``shell=True``: the arguments are
    not quoted and the command's exit status is not checked.
    """
    command = ' '.join(
        ['magick convert -delay', str(delay), '-loop 0', image_str, output_gif]
    )
    subprocess.call(command, shell=True)

In [111]:
output_dir = "./viz/"
key = "rand"
# Create the output folder first so writing the frames cannot fail on a
# missing directory.
os.makedirs(output_dir, exist_ok=True)

# One PNG per interpolation step; each grid has one column per latent unit.
for j, frame in enumerate(gifs):
    save_image(tensor=frame.cpu(),
               fp=os.path.join(output_dir, '{}_{}.png'.format(key, j)),
               nrow=z_dim, pad_value=1)

# Combine every PNG matching the key into a single animated GIF.
grid2gif(os.path.join(output_dir, key+'*.png'),
         os.path.join(output_dir, key+'.gif'), delay=10)