# Model inspection

In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F

import pixyz.distributions as pxd
import pixyz.losses as pxl
import pixyz.models as pxm
import pixyz.utils as pxu

In [2]:
# Device that every distribution and tensor in this notebook is moved to.
device = "cpu"

# Model

In [3]:
class GeneratorRNN(pxd.Deterministic):
    """Deterministic recurrent transition p(h | z, u, h_prev).

    Wraps a single ``nn.RNNCell`` whose input is ``z`` and ``u`` concatenated
    on the last axis and whose state is ``h_prev``.
    """

    def __init__(self, z_dim, u_dim, h_dim):
        super().__init__(var=["h"], cond_var=["z", "u", "h_prev"])

        # The cell consumes [z, u] joined along the feature axis.
        cell_input_dim = z_dim + u_dim
        self.rnn_cell = nn.RNNCell(cell_input_dim, h_dim)

    def forward(self, z, u, h_prev):
        """Run one cell step and return the new state as ``{"h": ...}``."""
        cell_input = torch.cat((z, u), dim=-1)
        return {"h": self.rnn_cell(cell_input, h_prev)}

In [4]:
# Print the summary of a small example instance (z_dim=2, u_dim=3, h_dim=4).
print(GeneratorRNN(2, 3, 4))

Distribution:
  p(h|z,u,h_{prev})
Network architecture:
  GeneratorRNN(
    name=p, distribution_name=Deterministic,
    var=['h'], cond_var=['z', 'u', 'h_prev'], input_var=['z', 'u', 'h_prev'], features_shape=torch.Size([])
    (rnn_cell): RNNCell(5, 4)
  )


In [5]:
class Generator(pxd.Bernoulli):
    """Bernoulli decoder p(x | h) parameterised by a three-layer MLP.

    The MLP maps ``h`` through hidden widths 512 and 256 to ``x_dim`` outputs,
    which are squashed by a sigmoid and returned as ``probs``.
    """

    # TODO: `h_prev` is not updated.
    def __init__(self, h_dim, x_dim):
        super().__init__(var=["x"], cond_var=["h"])

        self.fc1 = nn.Linear(h_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, x_dim)

    def forward(self, h):
        """Return ``{"probs": ...}`` computed from ``h``."""
        hidden = F.relu(self.fc1(h))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return {"probs": torch.sigmoid(logits)}

In [6]:
# Print the summary of a small example instance (h_dim=2, x_dim=3).
print(Generator(2, 3))

Distribution:
  p(x|h)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['h'], input_var=['h'], features_shape=torch.Size([])
    (fc1): Linear(in_features=2, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc3): Linear(in_features=256, out_features=3, bias=True)
  )


In [7]:
class InferenceRNN(pxd.Deterministic):
    """Deterministic map from ``x`` to RNN outputs ``h_v``.

    The RNN has ``2 * z_dim`` hidden units and starts from a learnable
    initial state ``h0`` that is shared across the batch.
    """

    def __init__(self, x_dim, z_dim):
        super().__init__(var=["h_v"], cond_var=["x"])

        hidden_size = z_dim * 2
        self.rnn = nn.RNN(x_dim, hidden_size)
        self.h0 = nn.Parameter(torch.zeros(1, 1, hidden_size))

    def forward(self, x):
        """Run the RNN over ``x`` and return its outputs as ``{"h_v": ...}``."""
        # Broadcast the single learned state to x.size(1) entries on axis 1.
        batch_size = x.size(1)
        state_shape = (1, batch_size, self.h0.size(2))
        initial_state = self.h0.expand(*state_shape).contiguous()
        outputs, _final_state = self.rnn(x, initial_state)
        return {"h_v": outputs}

In [8]:
# Print the summary of a small example instance (x_dim=2, z_dim=3).
print(InferenceRNN(2, 3))

Distribution:
  p(h_{v}|x)
Network architecture:
  InferenceRNN(
    name=p, distribution_name=Deterministic,
    var=['h_v'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (rnn): RNN(2, 6)
  )


In [9]:
class Inference(pxd.Normal):
    """Normal distribution q(z | h_v) whose parameters are read from ``h_v``.

    The last axis of ``h_v`` is split in half: the first half is used as
    ``loc`` and the square of the second half as ``scale``.
    """

    def __init__(self):
        super().__init__(cond_var=["h_v"], var=["z"])

    def forward(self, h_v):
        """Split ``h_v`` into ``loc`` and ``scale`` along its feature axis.

        Args:
            h_v: tensor whose last axis holds the concatenated parameters.

        Returns:
            dict with ``"loc"`` and ``"scale"`` tensors.
        """
        # Split on the LAST axis rather than axis 1. For 2-D input the two
        # coincide; for sequence-first 3-D RNN output, axis 1 is the batch
        # axis and splitting there would cut the batch in half.
        half = h_v.size(-1) // 2
        loc = h_v[..., :half]
        # NOTE(review): squaring keeps scale non-negative but allows exact
        # zeros -- confirm that is acceptable for the Normal parameterisation.
        scale = h_v[..., half:] ** 2
        return {"loc": loc, "scale": scale}

In [10]:
# Dimensions of the latent z, recurrent state h, observation x and input u.
z_dim = 2
h_dim = 4
x_dim = 10
u_dim = x_dim  # u is given the same width as x

# Prior over z built with loc 0 and scale 1 and feature shape [z_dim].
prior = pxd.Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                    var=["z"], features_shape=torch.Size([z_dim])).to(device)
# Generative side: recurrent transition p(h|z,u,h_prev) and decoder p(x|h).
grnn = GeneratorRNN(z_dim, u_dim, h_dim).to(device)
decoder = Generator(h_dim, x_dim).to(device)
# Inference side: RNN producing h_v from x, and the Normal over z given h_v.
irnn = InferenceRNN(x_dim, z_dim).to(device)
encoder = Inference().to(device)

In [11]:
# Build a cross-entropy loss from the product grnn * encoder and the decoder,
# then render the resulting loss expression as LaTeX.
ce = pxl.CrossEntropy(grnn * encoder, decoder)  #.expectation(grnn)
pxu.print_latex(ce)

<IPython.core.display.Math object>

# Data sample

In [12]:
# Display the product of decoder, recurrent transition and prior.
(decoder * grnn * prior)

Normal(
  name=p, distribution_name=Normal,
  var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([2])
  (loc): torch.Size([1, 2])
  (scale): torch.Size([1, 2])
)
GeneratorRNN(
  name=p, distribution_name=Deterministic,
  var=['h'], cond_var=['z', 'u', 'h_prev'], input_var=['z', 'u', 'h_prev'], features_shape=torch.Size([])
  (rnn_cell): RNNCell(12, 4)
)
Generator(
  name=p, distribution_name=Bernoulli,
  var=['x'], cond_var=['h'], input_var=['h'], features_shape=torch.Size([])
  (fc1): Linear(in_features=4, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=10, bias=True)
)

In [13]:
minibatch_size = 1

# Zero-valued conditioning inputs for sampling. "u" is sized with u_dim (the
# width the recurrent cell was built with) so it stays correct if u_dim and
# x_dim ever differ. Tensors are allocated directly on the target device.
data = {
    "h_prev": torch.zeros(minibatch_size, h_dim, device=device),
    "u": torch.zeros(minibatch_size, u_dim, device=device),
}

In [14]:
# Draw one sample from decoder * grnn * prior given the inputs in `data`.
sample = (decoder * grnn * prior).sample(data)

In [15]:
# Show every variable held in the sample dict.
sample

{'h_prev': tensor([[0., 0., 0., 0.]]),
 'u': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 'z': tensor([[0.4675, 0.4149]]),
 'h': tensor([[ 0.0749, -0.2037,  0.3348,  0.0514]], grad_fn=<TanhBackward>),
 'x': tensor([[0., 0., 1., 1., 0., 0., 0., 0., 0., 1.]])}

In [16]:
# Evaluate the decoder's mean for the sampled hidden state `h`.
decoder.sample_mean({"h": sample["h"]})

tensor([[0.5168, 0.5016, 0.5092, 0.5175, 0.5024, 0.4851, 0.4986, 0.4875, 0.5098,
         0.4952]], grad_fn=<SigmoidBackward>)

# Distributions class

In [17]:
# Print the summary of the decoder instance built above.
print(decoder)

Distribution:
  p(x|h)
Network architecture:
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['h'], input_var=['h'], features_shape=torch.Size([])
    (fc1): Linear(in_features=4, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=256, bias=True)
    (fc3): Linear(in_features=256, out_features=10, bias=True)
  )


In [18]:
# Inspect the decoder's `replace_params_dict` attribute.
decoder.replace_params_dict

{}

In [19]:
# Show the sample dict again for reference before the next call.
sample

{'h_prev': tensor([[0., 0., 0., 0.]]),
 'u': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 'z': tensor([[0.4675, 0.4149]]),
 'h': tensor([[ 0.0749, -0.2037,  0.3348,  0.0514]], grad_fn=<TanhBackward>),
 'x': tensor([[0., 0., 1., 1., 0., 0., 0., 0., 0., 1.]])}

In [20]:
# `decoder` was declared with cond_var=["h"], so passing `z` is rejected by
# its forward(). Catch the error and print it so the notebook still runs
# end-to-end while showing the failure.
try:
    decoder.sample_mean({"z": sample["z"]})
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: forward() got an unexpected keyword argument 'z'