In [11]:
import numpy as np
import sklearn
from sklearn import model_selection
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import matplotlib.gridspec as gridspec


def load_data(data_name):
    """Read the two arrays stored in an ``.npz`` archive.

    Parameters
    ----------
    data_name
        Path (or file object) handed straight to ``np.load``. The archive
        must contain the entries ``'data_x'`` and ``'data_y'``.

    Returns
    -------
    tuple
        ``(data_x, data_y)`` as stored in the archive.
    """
    with np.load(data_name) as archive:
        # Both entries are read while the archive is still open.
        return tuple(archive[key] for key in ('data_x', 'data_y'))



## Model definition
## Encoder


class Encoder(torch.nn.Module):
    """Fully connected encoder with a shared trunk and two output heads.

    Layer sizes: ``input_dim -> hidden_dim -> 2*hidden_dim``, followed by
    two parallel linear heads of size ``latent_dim``. ``forward`` returns
    the pair ``(mean, log_variance)``.
    """
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        wide_dim = 2 * hidden_dim
        # Shared trunk.
        self.e1 = torch.nn.Linear(input_dim, hidden_dim)
        self.e2 = torch.nn.Linear(hidden_dim, wide_dim)
        # Parallel heads: e3 yields the mean, e4 the log-variance.
        self.e3 = torch.nn.Linear(wide_dim, latent_dim)
        self.e4 = torch.nn.Linear(wide_dim, latent_dim)

    def forward(self, x):
        """Return ``(mean, log_variance)`` computed from ``x``."""
        hidden = F.leaky_relu(self.e2(F.leaky_relu(self.e1(x))))
        return self.e3(hidden), self.e4(hidden)


## Decoder


class Decoder(torch.nn.Module):
    """Fully connected decoder mapping a latent code back to data space.

    Layer sizes: ``latent_dim -> 2*hidden_dim -> hidden_dim -> output_dim``.
    The last layer is passed through a sigmoid, so every output value lies
    in the open interval (0, 1).
    """
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        wide_dim = 2 * hidden_dim
        self.d1 = torch.nn.Linear(latent_dim, wide_dim)
        self.d2 = torch.nn.Linear(wide_dim, hidden_dim)
        self.d3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Decode the latent batch ``x`` into sigmoid-activated outputs."""
        hidden = x
        # Both hidden layers use leaky-ReLU; only the output layer differs.
        for layer in (self.d1, self.d2):
            hidden = F.leaky_relu(layer(hidden))
        return torch.sigmoid(self.d3(hidden))


## VAE


class VAE(torch.nn.Module):
    """Variational autoencoder composed of an encoder and a decoder.

    Parameters
    ----------
    Encoder
        Callable/module returning ``(mean, log_variance)`` for an input batch.
    Decoder
        Callable/module mapping a latent sample back to data space.

    The capitalised attribute names ``Encoder`` and ``Decoder`` are kept
    because other cells access ``vae.Encoder`` and because the state-dict
    keys of a saved model are derived from them.
    """
    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, log_variance):
        """Draw ``z = mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``torch.randn_like``, so it lives on the
        same device (and has the same dtype) as the inputs. No module-level
        ``DEVICE`` global is needed, which keeps the class usable on its own.
        """
        # exp(0.5 * log(sigma^2)) is the standard deviation, not the variance.
        std = torch.exp(.5 * log_variance)
        epsilon = torch.randn_like(std)
        return mean + std * epsilon

    def forward(self, x):
        """Return ``(reconstruction, mean, log_variance)`` for the batch ``x``."""
        mean, log_variance = self.Encoder(x)
        z = self.reparameterization(mean, log_variance)
        output = self.Decoder(z)
        return output, mean, log_variance


## Loss function

def loss_func(output, x, mean, log_variance):
    """Summed VAE loss: reconstruction term plus KL term.

    Parameters
    ----------
    output, x
        Reconstruction and target, compared with binary cross-entropy
        summed over all elements.
    mean, log_variance
        Encoder outputs used for the KL term
        ``-0.5 * sum(1 + log_variance - mean^2 - exp(log_variance))``.

    Returns
    -------
    Scalar tensor holding the sum of both terms (not averaged over the batch).
    """
    reconstruction = F.binary_cross_entropy(output, x, reduction='sum')
    kl_terms = 1 + log_variance - mean.pow(2) - log_variance.exp()
    divergence = -.5 * torch.sum(kl_terms)
    return reconstruction + divergence


In [12]:
# --- Configuration ----------------------------------------------------------
data_name = 'vae-cvae-challenge.npz'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global RNGs once, before any stochastic step, so that a
# "Restart & Run All" repeats the same random draws. The notebook relies on
# random numbers for weight initialisation, the train/test split, batch
# shuffling, latent sampling and the sample picker (np.random.randint).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

x_dim = 784        # flattened input size (plotting cells reshape to 28 x 28)
batch_size = 1024
hidden_dim = 256
latent_dim = 2     # the latent scatter plot uses exactly two coordinates
epochs = 300
lr = .001

# --- Load the data ----------------------------------------------------------
data_x, data_y = load_data(data_name)
print(data_x.shape, data_y.shape)

(20000, 784) (20000,)


In [13]:
# Hold out 33% of the samples as a test set.
x_train, x_test, y_train, y_test = model_selection.train_test_split(
    data_x, data_y, test_size=.33
)

# Convert every split from NumPy to a torch tensor.
data_train, label_train, data_test, label_test = map(
    torch.from_numpy, (x_train, y_train, x_test, y_test)
)

dataset_train = TensorDataset(data_train, label_train)
dataset_test = TensorDataset(data_test, label_test)

# Only the training set is batched; it is reshuffled every epoch.
dl_train = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)

In [14]:
# Build the two halves with explicit keyword arguments, then wrap them.
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

# Move the model to DEVICE before handing its parameters to the optimizer.
vae = VAE(encoder, decoder).to(DEVICE)
optimizer = Adam(vae.parameters(), lr=lr)

In [15]:
# Per-sample loss of every training batch, in iteration order.
loss_list = []
progress = tqdm(range(epochs))
for epoch in progress:
    losses = 0.0   # summed loss over the whole epoch
    n_seen = 0     # number of samples processed in this epoch
    # Unsupervised training: the labels yielded by the loader are ignored.
    for x, _ in dl_train:
        x = x.to(DEVICE)
        optimizer.zero_grad()
        output, mean, log_variance = vae(x)
        loss = loss_func(output, x, mean, log_variance)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        n_batch = x.shape[0]
        losses += batch_loss
        n_seen += n_batch
        # loss_func returns a sum, so divide by the actual batch size; this
        # also covers the final, possibly smaller, batch without a branch.
        loss_list.append(batch_loss / n_batch)
    # Report the epoch average on the progress bar. It is normalised by the
    # true sample count (not idx * batch_size, which miscounts the samples).
    progress.set_postfix(loss=losses / n_seen)

  0%|▎                                                                                 | 1/300 [00:00<01:08,  4.37it/s]

 training loss -- 393.60636006868805


  1%|▌                                                                                 | 2/300 [00:00<01:03,  4.69it/s]

 training loss -- 222.2914644388052


  1%|▊                                                                                 | 3/300 [00:00<01:05,  4.51it/s]

 training loss -- 211.4545308626615


  2%|█▎                                                                                | 5/300 [00:01<01:02,  4.69it/s]

 training loss -- 206.05731494610126
 training loss -- 202.16491655203012


  2%|█▋                                                                                | 6/300 [00:01<00:59,  4.92it/s]

 training loss -- 198.44672100360577


  2%|█▉                                                                                | 7/300 [00:01<01:00,  4.83it/s]

 training loss -- 195.49705784137433


  3%|██▏                                                                               | 8/300 [00:01<01:01,  4.75it/s]

 training loss -- 193.06100258460413


  3%|██▍                                                                               | 9/300 [00:01<01:02,  4.69it/s]

 training loss -- 189.84704597179706


  4%|██▉                                                                              | 11/300 [00:02<01:00,  4.76it/s]

 training loss -- 184.75047397613525
 training loss -- 181.0928451831524


  4%|███▏                                                                             | 12/300 [00:02<00:58,  4.94it/s]

 training loss -- 178.67774882683386


  4%|███▌                                                                             | 13/300 [00:02<01:03,  4.55it/s]

 training loss -- 176.407762160668





KeyboardInterrupt: 

## Save model

In [None]:
# Persist only the model's state dict (not the whole module) to "vae_model".
torch.save(vae.state_dict(), "vae_model")

In [None]:
import matplotlib.pyplot as plt

# Training curve: loss_list holds one per-sample loss value per batch.
f, ax = plt.subplots(figsize=(16, 5))
ax.plot(loss_list)
ax.set_title("Train loss")
# Label both axes so the figure can be read without the surrounding code.
ax.set_xlabel("Training batch")
ax.set_ylabel("Loss per sample")

## Load model

In [None]:
# Rebuild the architecture, then restore the saved parameters.
vae = VAE(Encoder(x_dim, hidden_dim, latent_dim), Decoder(latent_dim, hidden_dim, x_dim))
# map_location remaps the stored tensors onto the current DEVICE, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
vae.load_state_dict(torch.load("vae_model", map_location=DEVICE))
vae = vae.to(DEVICE)

In [None]:
# Scatter the encoder means of the test set, coloured by label.
fig, axe = plt.subplots(figsize=(16, 10))
# Keep the reassignment: the reconstruction cell feeds data_test to the model
# and therefore needs it on DEVICE as well.
data_test = data_test.to(DEVICE)
# Pure inference: no autograd graph is needed for plotting.
with torch.no_grad():
    mu, log_variance = vae.Encoder(data_test)
mu_np = mu.cpu().numpy()
mu_x = mu_np[:, 0]
mu_y = mu_np[:, 1]
labels = label_test.detach().numpy()
# plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib releases removed.
points = axe.scatter(mu_x, mu_y, c=labels, cmap=plt.get_cmap('Spectral', 10),
                     alpha=1, edgecolors='black')
# Optional: clip outliers by limiting each axis to its own percentiles.
#axe.set_xlim(np.percentile(mu_x, 0.1), np.percentile(mu_x, 99.9))
#axe.set_ylim(np.percentile(mu_y, 0.1), np.percentile(mu_y, 99.9))
axe.set_xlabel("Latent dimension 1")
axe.set_ylabel("Latent dimension 2")
axe.set_title("Encoder means of the test set")
fig.colorbar(points, ax=axe)

In [None]:
# Reconstruct the test set and show a random grid_x x grid_x selection.
# Inference only: without no_grad the forward pass over the whole test set
# would keep an autograd graph alive for no reason.
with torch.no_grad():
    samples, _, _ = vae(data_test)
grid_x = 6
sample_size = grid_x**2
# `ids` is reused by the next cell to show the matching original images.
ids = np.random.randint(0, samples.shape[0], sample_size)
samples = samples[ids].cpu().numpy()
fig_dec = plt.figure(figsize=(grid_x, grid_x))
gs = gridspec.GridSpec(grid_x, grid_x)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    # Each sample is a flat vector of 784 values; show it as a 28 x 28 image.
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

In [None]:
# Original test images at the indices `ids` drawn in the previous cell, laid
# out on the same grid so they can be compared with the reconstructions.
grid_x = 6
sample_size = grid_x**2
samples = data_test[ids].cpu().detach().numpy()

fig_dec = plt.figure(figsize=(grid_x, grid_x))
gs = gridspec.GridSpec(grid_x, grid_x)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    image = sample.reshape(28, 28)
    ax.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    ax.imshow(image, cmap='Greys_r')