In [1]:
import datetime
import os
from tensorboardX import SummaryWriter

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pylab as plt

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# The only preprocessing step is the conversion of each image to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# MNIST training split, stored under ../chap13/data (downloaded when missing).
train_dataset = datasets.MNIST(
    root="../chap13/data",
    train=True,
    transform=transform,
    download=True,
)

# MNIST test split from the same root directory.
test_dataset = datasets.MNIST(
    root="../chap13/data",
    train=False,
    transform=transform,
    download=True,
)

# Training batches are reshuffled on every pass over the data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=4,
    pin_memory=False,
)

# Evaluation batches keep the dataset order.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)

In [3]:
class Encoder(nn.Module):
    """Two-hidden-layer MLP that maps an input vector to latent mean and log-variance.

    Args:
        input_dim: size of the last dimension of the input passed to ``forward``.
        hidden_dim: width of both hidden layers.
        latent_dim: size of the last dimension of each returned tensor.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.input1 = nn.Linear(input_dim, hidden_dim)
        self.input2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, latent_dim)
        # The attribute is called `var`, but forward() treats its output as the
        # log-variance. The name is kept so state_dict keys stay the same.
        self.var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        # `self.training` is deliberately not assigned here: nn.Module.__init__
        # already initialises it to True, and train()/eval() are the API for
        # changing it.

    def forward(self, x):
        """Return ``(mean, log_var)`` for the input batch ``x``.

        Args:
            x: tensor whose last dimension is ``input_dim``.

        Returns:
            A tuple ``(mean, log_var)``; both have last dimension ``latent_dim``.
        """
        h_ = self.LeakyReLU(self.input1(x))
        h_ = self.LeakyReLU(self.input2(h_))
        mean = self.mean(h_)
        log_var = self.var(h_)
        return mean, log_var

In [4]:
class Decoder(nn.Module):
    """Two-hidden-layer MLP that maps a latent vector to values in (0, 1).

    Args:
        latent_dim: size of the last dimension of the input passed to ``forward``.
        hidden_dim: width of both hidden layers.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in the same order and under the same attribute
        # names so parameter initialisation and state_dict keys are unchanged.
        self.hidden1 = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.hidden2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.output = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Decode ``x`` and squash the result through a sigmoid."""
        activated = x
        for layer in (self.hidden1, self.hidden2):
            activated = self.LeakyReLU(layer(activated))
        return torch.sigmoid(self.output(activated))

In [5]:
class Model(nn.Module):
    """Variational autoencoder that chains an encoder, a sampling step and a decoder.

    Args:
        Encoder: module whose call returns ``(mean, log_var)``.
        Decoder: module that maps a sampled latent tensor to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Draw ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Args:
            mean: mean of the latent distribution.
            var: scale multiplied with the noise. Despite the name, ``forward``
                passes the standard deviation ``exp(0.5 * log_var)`` here.

        Returns:
            The sampled latent tensor, same shape as ``var`` (after broadcasting
            with ``mean``).
        """
        # randn_like already creates the noise on var's device with var's dtype,
        # so no explicit .to(...) is needed and the method no longer depends on
        # a notebook-global `device` that may differ from the model's device.
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            A tuple ``(x_hat, mean, log_var)``.
        """
        mean, log_var = self.Encoder(x)
        # exp(0.5 * log_var) converts the log-variance into a standard deviation.
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var

In [6]:
# Hyperparameters shared by the cells below.
x_dim = 784        # length of one flattened input vector fed to the encoder
hidden_dim = 400   # width of every hidden layer in encoder and decoder
latent_dim = 200   # size of the latent code
epochs = 30        # number of passes over the training set
batch_size = 100   # same value as the batch_size given to the DataLoaders

# Build the two halves, then wrap them and move everything to the chosen device.
encoder = Encoder(x_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, x_dim)
model = Model(encoder, decoder).to(device)

In [7]:
def loss_function(x, x_hat, mean, log_var):
    """Compute the two VAE loss terms, each summed over the whole batch.

    Args:
        x: target tensor with values in [0, 1].
        x_hat: reconstruction with values in [0, 1], same shape as ``x``.
        mean: latent mean.
        log_var: latent log-variance.

    Returns:
        A tuple ``(reconstruction_loss, kl_divergence)`` of scalar tensors:
        the sum-reduced binary cross-entropy between ``x_hat`` and ``x``, and
        ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))``.
    """
    bce_sum = nn.functional.binary_cross_entropy(
        input=x_hat, target=x, reduction='sum')
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kld_sum = -0.5 * kl_terms.sum()
    return bce_sum, kld_sum


# Adam over every parameter of the encoder and decoder, learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
# Directory that receives the TensorBoard event files.
saved_loc = "scalar/"
writer = SummaryWriter(saved_loc)

# Put the model in training mode before the first epoch.
model.train()


def train(epoch, model, train_loader, optimizer):
    """Run one training epoch and log the per-batch losses to TensorBoard.

    Relies on the notebook-level names ``writer``, ``device`` and
    ``loss_function``.

    Args:
        epoch: zero-based epoch index, used for the TensorBoard step and the
            progress messages.
        model: module whose call returns ``(x_hat, mean, log_var)``.
        train_loader: iterable of ``(x, label)`` batches; labels are ignored.
        optimizer: optimizer that updates ``model``'s parameters.
    """
    # test() switches the model to eval mode, so restore training mode at the
    # start of every epoch instead of relying on a one-off call elsewhere.
    model.train()

    steps_per_epoch = len(train_loader)
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        # Flatten using the real batch size so a smaller final batch (or a
        # loader built with a different batch size) does not break the reshape.
        x = x.view(x.size(0), -1).to(device)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        BCE, KLD = loss_function(x, x_hat, mean, log_var)
        loss = BCE + KLD
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        train_loss += loss_value

        # Integer step derived from the loader itself (true division used to
        # produce a float that was tied to the global batch_size).
        global_step = epoch * steps_per_epoch + batch_idx
        writer.add_scalar("Train/Reconstruction Error", BCE.item(), global_step)
        writer.add_scalar("Train/KL-Divergence", KLD.item(), global_step)
        writer.add_scalar("Train/Total Loss", loss_value, global_step)

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss_value / len(x)))

    print("======> Epoch: {} Average loss: {:.4f}".format(
        epoch, train_loss / len(train_loader.dataset)))

In [9]:
def test(epoch, model, test_loader):
    """Evaluate the model on one pass over ``test_loader`` and log to TensorBoard.

    Relies on the notebook-level names ``writer``, ``device`` and
    ``loss_function``.

    Args:
        epoch: zero-based epoch index, used for the TensorBoard step.
        model: module whose call returns ``(x_hat, mean, log_var)``.
        test_loader: iterable of ``(x, label)`` batches; labels are ignored.

    Returns:
        The total loss divided by the number of samples in
        ``test_loader.dataset``.
    """
    model.eval()
    steps_per_epoch = len(test_loader)
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(test_loader):
            # Remember the per-sample image shape before flattening so the
            # comparison grid below can show real images again.
            image_shape = x.shape[1:]
            # Flatten using the real batch size rather than the global one.
            x = x.view(x.size(0), -1).to(device)
            x_hat, mean, log_var = model(x)
            BCE, KLD = loss_function(x, x_hat, mean, log_var)
            loss = BCE + KLD

            # Integer step derived from the loader itself.
            global_step = epoch * steps_per_epoch + batch_idx
            writer.add_scalar("Test/Reconstruction Error", BCE.item(), global_step)
            writer.add_scalar("Test/KL-Divergence", KLD.item(), global_step)
            writer.add_scalar("Test/Total Loss", loss.item(), global_step)
            test_loss += loss.item()

            if batch_idx == 0:
                n = min(x.size(0), 8)
                # make_grid treats a 2-D tensor as one single image, so the
                # flattened rows must be reshaped back into individual images.
                originals = x[:n].reshape(n, *image_shape)
                reconstructions = x_hat[:n].reshape(n, *image_shape)
                comparison = torch.cat([originals, reconstructions])
                # nrow=n puts the originals on the first row and the
                # reconstructions directly underneath, as the tag says.
                grid = torchvision.utils.make_grid(comparison.cpu(), nrow=n)
                writer.add_image(
                    "Test image - Above: Real data, below: reconstruction data", grid, epoch)

    # The accumulated loss used to be discarded; report and return it.
    average_loss = test_loss / len(test_loader.dataset)
    print("======> Epoch: {} Average test loss: {:.4f}".format(
        epoch, average_loss))
    return average_loss

In [10]:
from tqdm.auto import tqdm
# One training pass followed by one evaluation pass per epoch.
try:
    for epoch in tqdm(range(epochs)):
        train(epoch, model, train_loader, optimizer)
        test(epoch, model, test_loader)
        print("\n")
finally:
    # Close the writer even when training fails or is interrupted, so the
    # events logged so far are not left in an open writer.
    writer.close()

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/30 [00:00<?, ?it/s]



  3%|▎         | 1/30 [00:01<00:32,  1.12s/it]





  7%|▋         | 2/30 [00:02<00:30,  1.10s/it]





 10%|█         | 3/30 [00:03<00:28,  1.07s/it]





 13%|█▎        | 4/30 [00:04<00:27,  1.05s/it]





 17%|█▋        | 5/30 [00:05<00:26,  1.06s/it]





 20%|██        | 6/30 [00:06<00:25,  1.05s/it]





 23%|██▎       | 7/30 [00:07<00:24,  1.05s/it]





 27%|██▋       | 8/30 [00:08<00:23,  1.05s/it]





 30%|███       | 9/30 [00:09<00:22,  1.05s/it]





 33%|███▎      | 10/30 [00:10<00:20,  1.04s/it]





 37%|███▋      | 11/30 [00:11<00:19,  1.03s/it]





 40%|████      | 12/30 [00:12<00:18,  1.04s/it]





 43%|████▎     | 13/30 [00:13<00:17,  1.04s/it]





 47%|████▋     | 14/30 [00:14<00:17,  1.06s/it]





 50%|█████     | 15/30 [00:15<00:15,  1.06s/it]





 53%|█████▎    | 16/30 [00:16<00:14,  1.04s/it]





 57%|█████▋    | 17/30 [00:17<00:13,  1.04s/it]





 60%|██████    | 18/30 [00:18<00:12,  1.06s/it]





 63%|██████▎   | 19/30 [00:20<00:11,  1.06s/it]





 67%|██████▋   | 20/30 [00:21<00:10,  1.06s/it]





 70%|███████   | 21/30 [00:22<00:09,  1.07s/it]





 73%|███████▎  | 22/30 [00:23<00:08,  1.05s/it]





 77%|███████▋  | 23/30 [00:24<00:07,  1.05s/it]





 80%|████████  | 24/30 [00:25<00:06,  1.07s/it]





 83%|████████▎ | 25/30 [00:26<00:05,  1.06s/it]





 87%|████████▋ | 26/30 [00:27<00:04,  1.05s/it]





 90%|█████████ | 27/30 [00:28<00:03,  1.04s/it]





 93%|█████████▎| 28/30 [00:29<00:02,  1.05s/it]





 97%|█████████▋| 29/30 [00:30<00:01,  1.07s/it]





100%|██████████| 30/30 [00:31<00:00,  1.06s/it]








In [11]:
# Load the TensorBoard notebook extension and launch TensorBoard on the
# `scalar` log directory, serving on port 9000.
# NOTE(review): the recorded output of this cell reports that port 9000 was
# already in use; stop the other process or choose a free port before re-running.
%load_ext tensorboard
%tensorboard --logdir scalar --port=9000

ERROR: Failed to launch TensorBoard (exited with 255).
Contents of stderr:
TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

E1010 02:03:54.039050 139957028296512 program.py:300] TensorBoard could not bind to port 9000, it was already in use
ERROR: TensorBoard could not bind to port 9000, it was already in use