## The Pixyz API is designed around the features of deep generative models
- The deep neural networks (DNNs) that make up a generative model are hidden behind probability distributions
    - A framework that separates defining DNNs from operating on probability distributions (Distribution API)
- Model types and regularization of random variables are expressed as objective functions (error functions)
    - A framework that receives probability distributions and defines objective functions (Loss API)
- Deep generative models are trained by defining an objective function and minimizing it with gradient descent
    - A framework in which the objective function and the optimization algorithm can be set independently (Model API)
<img src="../tutorial_figs/pixyz_API.png">

In [None]:
# install pixyz
!pip install pixyz

In [None]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Overview of how the APIs fit together: implementing a VAE

### 1. Distribution API
- A framework that separates defining DNNs from operating on probability distributions (Distribution API)
- https://pixyz.readthedocs.io/en/latest/distributions.html

<img src="../tutorial_figs/vae_graphicalmodel.png">

We define the following three probability distributions:

Prior: $p(z) = N(z; 0, 1)$

Generator: $p_{\theta}(x|z) = B(x; \lambda = g(z))$

Inference: $q_{\phi}(z|x) = N(z; \mu = f_{\mu}(x), \sigma^2 = f_{\sigma^2}(x))$

In [None]:
from pixyz.distributions import Normal, Bernoulli
from pixyz.utils import print_latex

#### Define prior probability distribution

The prior is a standard Gaussian distribution (mean 0, variance 1) in each of the `z_dim` latent dimensions.

$p(z) = N(z; 0, 1)$

In [None]:
# prior
z_dim = 64
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)
print(prior)

In [None]:
print_latex(prior)
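To see what the prior's log-density amounts to, here is a framework-free sketch in plain Python (not Pixyz code). It assumes that the prior is a product of 64 independent standard normals, which is what `features_shape=[z_dim]` with scalar `loc` and `scale` suggests. Under that assumption the log-density is a sum of one-dimensional terms. The helper name `standard_normal_log_prob` is hypothetical.

```python
import math

def standard_normal_log_prob(z):
    """Hypothetical helper: log N(z; 0, I) for one latent vector, summed over its dimensions."""
    return sum(-0.5 * (zi * zi + math.log(2 * math.pi)) for zi in z)

# At the origin each of the 64 dimensions contributes -0.5 * log(2*pi)
print(standard_normal_log_prob([0.0] * 64))
```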

#### Define generator probability distribution
The generator is a Bernoulli distribution over x given z

$p_{\theta}(x|z) = B(x; \lambda = g(z))$

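To define a distribution whose parameters are computed by a deep neural network, inherit from a Pixyz distribution class (here `Bernoulli`) and implement `forward()`, which returns a dict of the distribution's parameters.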

In [None]:
x_dim = 784
# generative model p(x|z)
# inherit from the pixyz.distributions.Bernoulli class
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
p = Generator().to(device)
print(p)
print_latex(p)
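The pattern above divides the work in two. A subclass only implements `forward()` and returns a dict of distribution parameters, and the base class does the probability computations. The same idea can be mimicked without Pixyz. The sketch below is hypothetical and greatly simplified, and it is not how Pixyz is implemented. `ConditionalBernoulli` owns the log-likelihood, and `ToyGenerator` (a fixed logistic map standing in for the DNN) only maps `z` to `probs`.

```python
import math

class ConditionalBernoulli:
    """Hypothetical base class: owns the probability math, delegates parameters to forward()."""

    def forward(self, **cond):
        raise NotImplementedError  # subclasses map conditioning variables to {"probs": ...}

    def log_prob(self, x, **cond):
        probs = self.forward(**cond)["probs"]
        # Sum of per-pixel Bernoulli log-likelihoods
        return sum(xi * math.log(p) + (1 - xi) * math.log(1 - p)
                   for xi, p in zip(x, probs))

class ToyGenerator(ConditionalBernoulli):
    """Stands in for the DNN: a fixed logistic map from z to pixel probabilities."""

    def forward(self, z):
        return {"probs": [1 / (1 + math.exp(-zi)) for zi in z]}

g = ToyGenerator()
print(g.log_prob([1, 0], z=[0.0, 0.0]))  # equals 2 * log(0.5)
```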

#### Define Inference probability distribution

The inference model is a Gaussian distribution over z given x, whose mean $\mu$ and standard deviation $\sigma$ are computed by a network with parameters $\phi$

$q_{\phi}(z|x) = N(z; \mu = f_{\mu}(x), \sigma^2 = f_{\sigma^2}(x))$

In [None]:
# inference model q(z|x)
# inherit from the pixyz.distributions.Normal class
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

q = Inference().to(device)
print(q)
print_latex(q)
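`F.softplus` is applied to the second head because a Normal's `scale` must be strictly positive, while a linear layer can output any real number. The plain-Python sketch below uses hypothetical helper names and is not Pixyz code. It shows softplus, $\log(1 + e^x)$. It also shows how a sample $z \sim q_\phi(z|x)$ can be written as $z = \mu + \sigma \epsilon$ with $\epsilon \sim N(0, 1)$. This reparameterization is commonly used in VAEs so that gradients can flow through $\mu$ and $\sigma$.

```python
import math
import random

def softplus(x):
    # Numerically stable log(1 + exp(x)); always > 0
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def reparameterized_sample(loc, raw_scale, rng):
    """Hypothetical helper: draw z = mu + sigma * eps, eps ~ N(0, 1), per dimension."""
    scale = [softplus(s) for s in raw_scale]   # map unconstrained outputs to sigma > 0
    eps = [rng.gauss(0.0, 1.0) for _ in loc]   # noise that does not depend on the parameters
    return [m + s * e for m, s, e in zip(loc, scale, eps)]

rng = random.Random(0)
print(softplus(0.0))  # log(2), about 0.6931
print(reparameterized_sample([0.0, 1.0], [-3.0, 2.0], rng))
```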

#### Sampling from a probability distribution
- Any Distribution can be sampled with `.sample()`, regardless of its DNN architecture or distribution type
- In Pixyz, samples are returned as a dict (keys are variable names, values are the sampled tensors)

$z\sim p(z)$

In [None]:
# z ~ p(z)
prior_samples = prior.sample(batch_n=1)
print(prior_samples)
print(prior_samples.keys())
print(prior_samples['z'].shape)

#### Define joint distribution
- A joint distribution can be defined by multiplying distributions
    - It can also be sampled with `.sample()`

$p_{\theta}(x, z) = p_{\theta}(x|z)p(z)$

In [None]:
p_joint = p * prior
print(p_joint)
print_latex(p_joint)

#### Sampling from a joint distribution

$x, z \sim p_{\theta}(x, z) $

In [None]:
p_joint_samples = p_joint.sample(batch_n=1)
print(p_joint_samples)
print(p_joint_samples.keys())
print(p_joint_samples['x'].shape)
print(p_joint_samples['z'].shape)
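Sampling from the product $p_\theta(x|z)p(z)$ amounts to ancestral sampling. First draw $z$ from the prior, then draw $x$ given that $z$, and return both in one dict keyed by variable name. The plain-Python sketch below rests on simplifying assumptions. A toy decoder stands in for the generator network, every pixel shares one probability, and all names are hypothetical.

```python
import math
import random

def sample_joint(batch_n, z_dim, x_dim, rng):
    """Hypothetical helper: ancestral sampling from p(x, z) = p(x|z) p(z) with a toy decoder."""
    zs, xs = [], []
    for _ in range(batch_n):
        z = [rng.gauss(0.0, 1.0) for _ in range(z_dim)]  # z ~ N(0, I)
        # Toy stand-in for g(z): every pixel shares one probability computed from z
        lam = 1 / (1 + math.exp(-sum(z) / z_dim))
        x = [1 if rng.random() < lam else 0 for _ in range(x_dim)]  # x ~ Bernoulli(lam)
        zs.append(z)
        xs.append(x)
    return {"z": zs, "x": xs}  # keyed by variable name, as Pixyz samples are

samples = sample_joint(batch_n=1, z_dim=64, x_dim=784, rng=random.Random(0))
print(samples.keys(), len(samples["x"][0]), len(samples["z"][0]))
```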

#### For a more detailed Distribution API Tutorial
- 01-DistributionAPITutorial.ipynb

### 2. Loss API
- A framework that receives probability distributions and defines objective functions (Loss API)
    - `pixyz.losses.Loss` receives Distributions and defines a loss
        - Arithmetic operations can be applied between Loss objects, so arbitrary losses can be composed
            - -> Formulas from papers can be translated into code easily
- A loss value is evaluated by feeding in data
    - Each Loss is treated as a symbolic expression
        - The probabilistic model can be designed explicitly, independently of data or DNNs -> a define-and-run style framework
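The define-and-run idea can be made concrete with a hypothetical, heavily simplified symbolic loss class in plain Python. It is not Pixyz's `Loss` implementation. It only shows the principle that arithmetic on loss objects builds an expression first, and numbers appear only when data is supplied to `.eval()`. The formulas inside are placeholders.

```python
class SymLoss:
    """Hypothetical symbolic loss: wraps a function from a data dict to a number."""

    def __init__(self, fn, name):
        self.fn = fn
        self.name = name

    def __add__(self, other):
        return SymLoss(lambda d: self.fn(d) + other.fn(d), f"({self.name} + {other.name})")

    def __sub__(self, other):
        return SymLoss(lambda d: self.fn(d) - other.fn(d), f"({self.name} - {other.name})")

    def __neg__(self):
        return SymLoss(lambda d: -self.fn(d), f"-{self.name}")

    def eval(self, data):
        # Numbers are only computed here, once data is supplied
        return self.fn(data)

# 1. Define the expression (placeholder formulas, no data involved yet)
toy_kl = SymLoss(lambda d: 0.5 * d["x"] ** 2, "KL")
toy_reconst = SymLoss(lambda d: -abs(d["x"]), "E[log p]")
toy_loss = toy_kl - toy_reconst
print(toy_loss.name)              # (KL - E[log p])

# 2. Evaluate it on data
print(toy_loss.eval({"x": 2.0}))  # 0.5 * 4 - (-2) = 4.0
```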

The VAE loss (the negative evidence lower bound, averaged over the data):
$$
-\mathcal { L } _ { \mathrm { VAE } } ( \theta , \phi ) =   \mathbb { E } _ { p_{data}( x ) } \left [D _ { \mathrm { KL } } \left[ q _ \phi ( z | x ) \| p ( z ) \right] - \mathbb { E } _ { q _ { \phi } ( z | x ) } \left[\log p _ { \theta } ( x | z ) \right]\right]
$$
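For the distributions chosen above, the loss for a single data point can be written out term by term. Assume $q_\phi(z|x)$ is a diagonal Gaussian with mean $\mu$ and standard deviation $\sigma$ over the $J$ (= `z_dim`) latent dimensions. Then the KL term has a closed form. The reconstruction term is commonly estimated with $L$ Monte Carlo samples $z^{(l)} = \mu + \sigma \odot \epsilon^{(l)}$, $\epsilon^{(l)} \sim N(0, I)$, with $\lambda^{(l)} = g(z^{(l)})$ over the $D = 784$ pixels:

$$
D_{\mathrm{KL}}\left[q_\phi(z|x) \,\|\, p(z)\right] = \frac{1}{2}\sum_{j=1}^{J}\left(\mu_j^2 + \sigma_j^2 - \log\sigma_j^2 - 1\right)
$$

$$
\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] \approx \frac{1}{L}\sum_{l=1}^{L}\sum_{i=1}^{D}\left[x_i \log\lambda_i^{(l)} + (1 - x_i)\log\left(1 - \lambda_i^{(l)}\right)\right]
$$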

#### Define the loss using pixyz.losses

In [None]:
log_p = p.log_prob()
print_latex(log_p)

In [None]:
from pixyz.losses import Expectation as E
reconst = E(q, log_p)
print_latex(reconst)
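`E(q, log_p)` denotes the expectation of $\log p_\theta(x|z)$ over $z \sim q_\phi(z|x)$. One common way to estimate such an expectation is to average the integrand over samples. The plain-Python sketch below shows that estimator with a hypothetical sampler and integrand. It checks the estimate against a case with a known answer, $\mathbb{E}[z^2] = \mu^2 + \sigma^2$ for $z \sim N(\mu, \sigma^2)$. It illustrates the estimator only and does not describe Pixyz internals.

```python
import random

def mc_expectation(sample_fn, f, n_samples):
    """Hypothetical helper: estimate E[f(z)] by averaging f over n_samples draws from sample_fn."""
    return sum(f(sample_fn()) for _ in range(n_samples)) / n_samples

rng = random.Random(0)
mu, sigma = 1.0, 0.5
estimate = mc_expectation(lambda: rng.gauss(mu, sigma), lambda z: z * z, 100_000)
print(estimate)  # close to mu**2 + sigma**2 = 1.25
```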

In [None]:
from pixyz.losses import KullbackLeibler
kl = KullbackLeibler(q, prior)
print_latex(kl)
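`KullbackLeibler(q, prior)` measures the divergence between a diagonal Gaussian and the standard normal prior. For that pair the KL has the well-known closed form $\frac{1}{2}\sum_j(\mu_j^2 + \sigma_j^2 - \log\sigma_j^2 - 1)$. The plain-Python sketch below uses hypothetical helper names. It makes no assumption about whether Pixyz evaluates this term analytically or by sampling. It implements the closed form and checks it against a one-dimensional Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$.

```python
import math
import random

def kl_normal_std_normal(mu, sigma):
    """Hypothetical helper: closed-form KL[N(mu, diag(sigma^2)) || N(0, I)], summed over dimensions."""
    return sum(0.5 * (m * m + s * s - math.log(s * s) - 1.0) for m, s in zip(mu, sigma))

def log_normal(z, m, s):
    # log N(z; m, s^2) for a scalar z
    return -0.5 * (((z - m) / s) ** 2 + math.log(2 * math.pi * s * s))

rng = random.Random(0)
mu, sigma = [0.5], [0.8]
n = 100_000
draws = (rng.gauss(mu[0], sigma[0]) for _ in range(n))  # z ~ q
mc = sum(log_normal(z, mu[0], sigma[0]) - log_normal(z, 0.0, 1.0) for z in draws) / n
print(kl_normal_std_normal(mu, sigma), mc)  # the two values should be close
```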

#### Operations between Loss classes

In [None]:
vae_loss = (kl - reconst).mean()
print_latex(vae_loss)

#### Evaluating the loss on input data
- The loss value is computed with `.eval()`, which takes a dict mapping variable names to data

In [None]:
# dummy data: 4 random vectors in [0, 1) standing in for flattened 28x28 images
dummy_x = torch.rand([4, 784]).to(device)
vae_loss.eval({"x": dummy_x})

#### For a more detailed Loss API Tutorial
- 02-LossAPITutorial.ipynb

### 3. Model API
- A framework in which the objective function and the optimization algorithm can be set independently
- Set the loss and the optimizer, then train on data

In [None]:
from pixyz.models import Model
model = Model(loss=vae_loss, distributions=[p, q],
              optimizer=optim.Adam, optimizer_params={"lr": 1e-3})
print(model)
print_latex(model)

In [None]:
# dummy data in [0, 1), like pixel intensities
dummy_x = torch.rand([10, 784]).to(device)
loss = model.train({"x": dummy_x})
print('Train Loss: {:.4f}'.format(loss))
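The point of the Model API is that the loss and the optimizer are chosen independently and combined only in the training step. The framework-free sketch below illustrates that separation. It uses hypothetical `ToyModel` and `GradientDescent` classes, a scalar parameter, and a hand-written gradient instead of autograd. The same model code works with any optimizer object that exposes a `step` method.

```python
class GradientDescent:
    """Hypothetical optimizer: plain gradient descent on one scalar parameter."""

    def __init__(self, lr):
        self.lr = lr

    def step(self, param, grad):
        return param - self.lr * grad


class ToyModel:
    """Hypothetical model: holds a loss, its gradient and an optimizer, chosen independently."""

    def __init__(self, loss_fn, grad_fn, optimizer, init=0.0):
        self.loss_fn = loss_fn
        self.grad_fn = grad_fn
        self.optimizer = optimizer
        self.theta = init

    def train(self, data):
        g = self.grad_fn(self.theta, data)               # gradient of the loss w.r.t. theta
        self.theta = self.optimizer.step(self.theta, g)  # update is delegated to the optimizer
        return self.loss_fn(self.theta, data)            # loss after the update


def mse_loss(theta, xs):
    # Mean squared distance between theta and the data points
    return sum((theta - x) ** 2 for x in xs) / len(xs)


def mse_grad(theta, xs):
    return sum(2 * (theta - x) for x in xs) / len(xs)


toy_model = ToyModel(mse_loss, mse_grad, GradientDescent(lr=0.1))
for _ in range(100):
    last_loss = toy_model.train([1.0, 2.0, 3.0])
print(toy_model.theta, last_loss)  # theta approaches the data mean 2.0
```

Replacing `GradientDescent` with another object that has a `step(param, grad)` method changes the optimizer without touching the loss. This loosely mirrors passing `optimizer=optim.Adam` to `Model`.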

#### For a more detailed Model API Tutorial
- 03-ModelAPITutorial.ipynb

### Training a VAE on the MNIST dataset

#### Import modules

In [None]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

batch_size = 256
epochs = 20
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


#### Prepare MNIST dataset

In [None]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

#### Import Pixyz modules

In [None]:
from pixyz.distributions import Normal, Bernoulli
from pixyz.losses import KullbackLeibler, Expectation as E
from pixyz.models import Model
from pixyz.utils import print_latex

#### Define probability distributions

In [None]:
x_dim = 784
z_dim = 64


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}

    
# generative model p(x|z)    
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}
    
p = Generator().to(device)
q = Inference().to(device)

prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

In [None]:
print(prior)
print_latex(prior)

In [None]:
print(p)
print_latex(p)

In [None]:
print(q)
print_latex(q)

#### Define Loss

In [None]:
kl = KullbackLeibler(q, prior)
reconst = -p.log_prob().expectation(q)
vae_loss = (kl + reconst).mean()
print_latex(vae_loss)

#### Set optimization algorithm and model

In [None]:
model = Model(loss=vae_loss, distributions=[p, q],
              optimizer=optim.Adam, optimizer_params={"lr": 1e-3})
print(model)
print_latex(model)

In [None]:
def train(epoch):
    train_loss = 0
    #for x, _ in tqdm(train_loader):
    for x, _ in train_loader:
        x = x.to(device)
        loss = model.train({"x": x})
        train_loss += loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

def test(epoch):
    test_loss = 0
    #for x, _ in tqdm(test_loader):
    for x, _ in test_loader:
        x = x.to(device)
        loss = model.test({"x": x})
        test_loss += loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss
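`train()` and `test()` rescale the summed per-batch losses by `batch_size / len(dataset)`. Because `vae_loss` ends with `.mean()`, each value returned by `model.train`/`model.test` is a per-example average over its batch. The rescaling is therefore exact only when every batch is full. With 60,000 training images and `batch_size = 256`, the last batch holds 96 images. The plain-Python sketch below uses hypothetical helpers and made-up loss values. It compares that rescaling with a size-weighted average, which stays exact for a smaller final batch.

```python
def epoch_mean(batch_means, batch_sizes):
    """Hypothetical helper: size-weighted average of per-batch mean losses."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

def approx_epoch_mean(batch_means, batch_size, dataset_len):
    """The rescaling used in train()/test() above."""
    return sum(batch_means) * batch_size / dataset_len

sizes = [256] * 234 + [96]       # 60,000 MNIST training images, batch_size=256
means = [100.0] * 234 + [150.0]  # made-up per-batch mean losses for illustration
print(epoch_mean(means, sizes), approx_epoch_mean(means, 256, 60000))
```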

#### Reconstruction

In [None]:
def plot_reconstrunction(x):
    with torch.no_grad():
        z = q.sample({"x": x}, return_all=False)
        recon_batch = p.sample_mean(z).view(-1, 1, 28, 28)
    
        comparison = torch.cat([x.view(-1, 1, 28, 28), recon_batch]).cpu()
        return comparison

#### Generate images from the latent space

In [None]:
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p.sample_mean({"z": z_sample}).view(-1, 1, 28, 28).cpu()
        return sample

In [None]:
# function to display an image
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
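`torchvision.utils.make_grid` returns an image laid out as (channels, height, width), while `plt.imshow` expects (height, width, channels). The call `np.transpose(npimg, (1, 2, 0))` only reorders the axes. A plain-Python sketch of the same reordering on nested lists follows. The helper name is hypothetical.

```python
def chw_to_hwc(img):
    """Hypothetical helper: reorder a nested-list image from [C][H][W] to [H][W][C]."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    return [[[img[k][i][j] for k in range(c)] for j in range(w)] for i in range(h)]

# A 3-channel, 2x2 image whose values encode (channel, row, col) as 100*c + 10*row + col
img = [[[100 * k + 10 * i + j for j in range(2)] for i in range(2)] for k in range(3)]
hwc = chw_to_hwc(img)
print(hwc[1][0])  # pixel at row 1, col 0 -> [10, 110, 210]
```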

In [None]:
# fixed latent samples (drawn from N(0, 0.5^2)) used to generate images at every epoch
z_sample = 0.5 * torch.randn(64, z_dim).to(device)

# fixed _x for tracking how reconstructions improve over training
_x, _ = next(iter(test_loader))
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    recon = plot_reconstrunction(_x[:8])
    sample = plot_image_from_latent(z_sample)
    
    print('Epoch: {}'.format(epoch))
    print('Reconstruction')
    imshow(torchvision.utils.make_grid(recon))
    print('Generated from prior z:')
    imshow(torchvision.utils.make_grid(sample))