### Implement VAE-Like model via pixyz

In [1]:
from torch.utils import data

from utils import split_dataset
from datasets import OppG
from opportunity import Encoder, ContextEncoder, Predictor

In [2]:
K = 12
L=3

print("Load datasets ...")
dataset_joint = OppG('S2,S3,S4', 'Gestures', l_sample=30, interval=15, T=K+L)
train_dataset_joint, valid_dataset_joint = split_dataset(dataset_joint)
train_loader_joint = data.DataLoader(dataset_joint, batch_size=128, shuffle=True)
valid_loader_joint = data.DataLoader(dataset_joint, batch_size=128, shuffle=False)

Load datasets ...


In [8]:
import torch
from torch import nn
from torch.nn import functional as F
from pixyz.distributions import Normal, Normal
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE


class Inference(Normal):
    def __init__(self, network, network_output, z_size):
        super().__init__(cond_var=["x"], var=["z"], name="q")
        
        self.network = network
        self.network_mu = nn.Linear(network_output, z_size)
        self.network_sigma = nn.Linear(network_output, z_size)
        
    def forward(self, x):
        h = self.network(x)
        return {"loc": self.network_mu(h), "scale": F.softplus(self.network_sigma(h))}
    

class Geneator(Normal):
    def __init__(self):
        super().__init__(cond_var=["z"], var=["x"], name="p")
        self.fc = nn.Linear(z_size, g_enc.output_shape()[1]).cuda()
        self.deconv1 = nn.ConvTranspose2d(20, 40, kernel_size=(1, 3), stride=(1, 2))
        self.deconv2 = nn.ConvTranspose2d(40, 50, kernel_size=(1, 5), stride=(1, 2))
        self.deconv3 = nn.ConvTranspose2d(50, 1, kernel_size=(1, 5), stride=(1, 2), output_padding=(0, 1))
        
    def forward(self, z):
        h = self.fc(z)
        h = h.view(-1, 20, 113, 2)
        h = self.deconv1(h)
        h = self.deconv2(h)
        h = self.deconv3(h)
        return {"loc": h, "scale": torch.tensor(1.0).cuda()}

In [9]:
z_size = 400
g_enc = Encoder(input_shape=dataset_joint.get('input_shape'), hidden_size=None).cuda()
q = Inference(g_enc, network_output=g_enc.output_shape()[1], z_size=z_size).cuda()
p = Geneator().cuda()

# prior
loc = torch.tensor(0.).cuda()
scale = torch.tensor(1.).cuda()
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_size, name="p_prior")

In [15]:
X, _ = train_loader_joint.__iter__().__next__()

In [22]:
q.eval()
mu = q(X[..., 0].float().cuda())['loc']
print(mu[0, 0])

mu = q(X[..., 0].float().cuda())['loc']
print(mu[0, 0])
q.train()

tensor(-0.0179, device='cuda:0', grad_fn=<SelectBackward>)
tensor(-0.0179, device='cuda:0', grad_fn=<SelectBackward>)


Inference(
  (network): Encoder(
    (conv): Sequential(
      (0): Conv2d(1, 50, kernel_size=(1, 5), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(50, 40, kernel_size=(1, 5), stride=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(40, 20, kernel_size=(1, 3), stride=(1, 1))
      (7): ReLU()
      (8): Dropout(p=0.5)
      (9): Flatten()
    )
  )
  (network_mu): Linear(in_features=4520, out_features=400, bias=True)
  (network_sigma): Linear(in_features=4520, out_features=400, bias=True)
)

In [23]:
from torch import optim
kl = KullbackLeibler(q, prior)
model = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr": 0.001})
print(model)

Distributions (for training): 
  q(z|x), p(x|z) 
Loss function: 
  mean(-E_q(z|x)[log p(x|z)] + KL[q(z|x)||p_prior(z)]) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


In [24]:
report_per = 100
for num_iter in range(1000):
    x, _ = train_loader_joint.__iter__().__next__()
    loss = model.train({"x": x[..., 0].float().cuda()})
    
    if ((num_iter+1) % report_per) !=0:
        continue
    print(loss.item())

4127.953125
4194.474609375
3985.358154296875
3970.8154296875
4044.540283203125
4052.95556640625
4344.876953125
4035.642578125
3888.2109375
3900.81884765625


In [132]:
torch.save(q.state_dict(), 'test.pth')