In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F

import pixyz.distributions as pxd
import pixyz.losses as pxl
import pixyz.models as pxm

# SRNN

# Generative model

In [4]:
class ForwardRNN(pxd.Deterministic):
    def __init__(self, u_dim, d_dim):
        super().__init__(cond_var=["u", "d_prev"], var=["d"])

        self.rnn_cell = nn.GRUCell(u_dim, d_dim)
        self.d0 = nn.Parameter(torch.zeros(1, 1, d_dim))

    def forward(self, u, d_prev):
        d = self.rnn_cell(u, d_prev)
        return {"d": d}

class Prior(pxd.Normal):
    def __init__(self, d_dim, z_dim):
        super().__init__(cond_var=["z_prev", "d"], var=["z"])

        self.fc1 = nn.Linear(d_dim + z_dim, 512)
        self.fc21 = nn.Linear(512, z_dim)
        self.fc22 = nn.Linear(512, z_dim)

    def forward(self, z_prev, d):
        h = F.relu(self.fc1(torch.cat([z_prev, d], dim=-1)))
        scale = self.fc21(h)
        loc = F.softplus(self.fc22(h))
        return {"scale": scale, "loc": loc}

class Generator(pxd.Bernoulli):
    def __init__(self, z_dim, d_dim, x_dim):
        super().__init__(cond_var=["z", "d"], var=["x"])

        self.fc1 = nn.Linear(z_dim + d_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, x_dim)

    def forward(self, z, d):
        h = F.relu(self.fc1(torch.cat([z, d], dim=-1)))
        h = F.relu(self.fc2(h))
        probs = torch.sigmoid(self.fc3(h))
        return {"probs": probs}

# Variational Model

In [2]:
class BackwardRNN(pxd.Deterministic):
    def __init__(self, x_dim, d_dim, a_dim):
        super().__init__(cond_var=["x", "d"], var=["a"])

        self.rnn = nn.GRU(x_dim + d_dim, a_dim, bidirectional=True)
        self.a0 = nn.Parameter(torch.zeros(2, 1, a_dim))

    def forward(self, x, d):
        a0 = self.a0.expand(2, x.size(1), self.a0.size(2)).contiguous()
        a, _ = self.rnn(torch.cat([x, d], dim=-1), a0)
        return {"a": a[:, :, a.size(2) // 2:]}


class VariationalPrior(pxd.Normal):
    def __init__(self, z_dim, a_dim):
        super().__init__(cond_var=["z_prev", "a"], var=["z"])

        self.fc1 = nn.Linear(z_dim + a_dim, 512)
        self.fc21 = nn.Linear(512, z_dim)
        self.fc22 = nn.Linear(512, z_dim)

    def forward(self, z_prev, a):
        h = F.relu(self.fc1(torch.cat([z_prev, a], dim=-1)))
        scale = self.fc21(h)
        loc = F.softplus(self.fc22(h))
        return {"scale": scale, "loc": loc}

In [3]:
x_dim = 2
t_dim = 5
device = "cpu"
u_dim = x_dim

# Latent dimension
d_dim = 3
z_dim = 4
a_dim = 2

In [5]:
prior = Prior(d_dim, z_dim).to(device)
frnn = ForwardRNN(u_dim, d_dim).to(device)
decoder = Generator(z_dim, d_dim, x_dim).to(device)
brnn = BackwardRNN(x_dim, d_dim, a_dim).to(device)
encoder = VariationalPrior(z_dim, a_dim).to(device)

In [6]:
print(prior)
print(frnn)
print(decoder)
print(brnn)
print(encoder)

Distribution:
  p(z|z_{prev},d)
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev', 'd'], input_var=['z_prev', 'd'], features_shape=torch.Size([])
    (fc1): Linear(in_features=7, out_features=512, bias=True)
    (fc21): Linear(in_features=512, out_features=4, bias=True)
    (fc22): Linear(in_features=512, out_features=4, bias=True)
  )
Distribution:
  p(d|u,d_{prev})
Network architecture:
  ForwardRNN(
    name=p, distribution_name=Deterministic,
    var=['d'], cond_var=['u', 'd_prev'], input_var=['u', 'd_prev'], features_shape=torch.Size([])
    (rnn_cell): GRUCell(2, 3)
  )
Distribution:
  p(x|z,d)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z', 'd'], input_var=['z', 'd'], features_shape=torch.Size([])
    (fc1): Linear(in_features=7, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc3): Linear(in_features=256, out_features=2

In [7]:
ce = pxl.CrossEntropy(encoder * frnn, decoder)
print(ce)

- \mathbb{E}_{p(z,d|z_{prev},a,u,d_{prev})} \left[\log p(x|z,d) \right]


In [8]:
kl = pxl.KullbackLeibler(encoder, prior)
print(kl)

D_{KL} \left[p(z|z_{prev},a)||p(z|z_{prev},d) \right]


# Sample

In [17]:
device = "cpu"
minibatch_size = 1

data = {
    "z_prev": torch.zeros(minibatch_size, z_dim).to(device),
    "d_prev": torch.zeros(minibatch_size, d_dim).to(device),
    "u": torch.zeros(minibatch_size, x_dim).to(device),
    "dummy": torch.zeros(minibatch_size, x_dim).to(device),
}

In [18]:
sample = (prior * frnn * decoder).sample(data)

In [19]:
sample

{'z_prev': tensor([[0., 0., 0., 0.]]),
 'd_prev': tensor([[0., 0., 0.]]),
 'u': tensor([[0., 0.]]),
 'dummy': tensor([[0., 0.]]),
 'd': tensor([[-0.0011,  0.0067,  0.1337]], grad_fn=<AddBackward0>),
 'z': tensor([[0.7717, 0.7760, 0.6791, 0.6562]]),
 'x': tensor([[0., 0.]])}

In [20]:
(prior * frnn * decoder).input_var

['z_prev', 'u', 'd_prev']

In [40]:
x_t = decoder.sample_mean({"z": sample["z"], "d": sample["d"]})

x_t

tensor([[0.5095, 0.4868]], grad_fn=<SigmoidBackward>)