In [1]:
# %matplotlib widget
import torch
from torch import nn

import math
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# A single expression avoids the earlier dead `device = ''` placeholder, which
# briefly bound the name to a str instead of a torch.device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class Discriminator(nn.Module):
    """MLP that scores a 784-pixel (28x28) image with a probability in [0, 1].

    The layer stack lives in ``self.model``; its ordering determines the
    ``state_dict`` keys (``model.0.weight``, ...), so keep it stable for saved
    checkpoints to remain loadable.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of Sigmoid probabilities.

        Args:
            x: Tensor with a leading batch dimension and 784 elements per
                sample, e.g. ``(batch, 1, 28, 28)``, ``(batch, 28, 28)`` or
                ``(batch, 784)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced image batches, where view raises RuntimeError.
        x = x.reshape(x.size(0), 784)
        return self.model(x)

In [4]:
class Generator(nn.Module):
    """MLP that maps a latent vector to a single-channel 28x28 image.

    Args:
        latent_dim: Length of the latent input vector. Defaults to 4, the size
            the original architecture hard-coded. A model built with a
            different size has a first layer whose weights are incompatible
            with checkpoints trained at another size.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        """Return images of shape ``(batch, 1, 28, 28)`` with values in [-1, 1].

        Args:
            x: ``(batch, latent_dim)`` tensor. A single unbatched
                ``(latent_dim,)`` vector is treated as a batch of one.
        """
        output = self.model(x)
        # -1 lets the batch size come from the output itself, so an unbatched
        # input yields (1, 1, 28, 28) instead of a view size-mismatch error.
        return output.view(-1, 1, 28, 28)

In [5]:
generator = Generator().to(device=device)
discriminator = Discriminator().to(device=device)

# map_location remaps the stored tensors onto `device`, so checkpoints saved
# from a GPU session also load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints;
# on torch >= 1.13 consider passing weights_only=True as well.
generator.load_state_dict(torch.load("./generator_mnist_4.pt", map_location=device))
discriminator.load_state_dict(torch.load("./discriminator_mnist_4.pt", map_location=device))

# Inference mode: among other things this disables the discriminator's Dropout.
generator.eval()
discriminator.eval()

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1, bias=True)
    (10): Sigmoid()
  )
)

In [6]:
@widgets.interact(z_0=(-3, 3, .1),
                  z_1=(-3, 3, .1),
                  z_2=(-3, 3, .1),
                  z_3=(-3, 3, .1),
                 )
def update(z_0=0.0,
           z_1=0.0,
           z_2=0.0,
           z_3=0.0,
          ):
    """Draw the generator's image for the latent vector (z_0, z_1, z_2, z_3).

    ipywidgets calls this each time a slider moves; it renders one small
    figure and returns nothing.
    """
    # Explicit float32 so integer arguments (e.g. update(1, 0, 0, 0)) still
    # match the Linear layer's weight dtype.
    latent = torch.tensor([[z_0, z_1, z_2, z_3]], dtype=torch.float32, device=device)

    # Pure inference: don't build an autograd graph on every slider move.
    with torch.no_grad():
        image = generator(latent)

    fig, ax = plt.subplots(figsize=(1.5, 1.5))
    # Pin the colour limits to the Tanh output range [-1, 1]; otherwise imshow
    # rescales each frame to its own min/max and brightness is not comparable
    # between slider positions.
    ax.imshow(image.cpu().numpy().reshape(28, 28), cmap='gray_r', vmin=-1, vmax=1)
    ax.axis('off')
    # Explicit show renders the figure inside the interact output area.
    plt.show()

interactive(children=(FloatSlider(value=0.0, description='z_0', max=3.0, min=-3.0), FloatSlider(value=0.0, des…