In [1]:
import torch
from torch.distributions import MultivariateNormal
from torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal

from cnp.output import MeanFieldGaussianLayer, InnerprodGaussianLayer, KvvGaussianLayer
from cnp.utils import Gamma

import matplotlib.pyplot as plt

# Mean-field Gaussian layer (`MeanFieldGaussianLayer`)

In [2]:
# Seed torch's global RNG before the first random draw so that every
# `torch.rand` input in the cells below (and any random parameter
# initialisation the layers may perform) is reproducible under
# Restart Kernel -> Run All.
torch.manual_seed(0)

mfg_layer = MeanFieldGaussianLayer()

In [3]:
# Random input of shape (B, T, C) for the mean-field layer.
B, T, C = 1, 7, 2

tensor = torch.rand((B, T, C))

# Display the tuple returned by `mean_and_cov`.
mfg_layer.mean_and_cov(tensor)

(tensor([[0.2061, 0.6669, 0.6223, 0.6853, 0.3618, 0.2198, 0.1493]]),
 tensor([[[1.2321, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.9011, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.7264, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 1.1013, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.7697, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9445, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9628]]]),
 tensor([[[1.9253, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 1.5943, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 1.4196, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 1.7944, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.4629, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.6376, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0

In [4]:
# Log-likelihood of random targets under the mean-field layer.
B, T, C = 2, 7, 2

tensor = torch.rand((B, T, C))
y_target = torch.rand((B, T))  # one target value per (batch, point)

mfg_layer.loglik(tensor, y_target)

tensor(-8.6861, grad_fn=<MeanBackward0>)

In [5]:
# Draw samples from the mean-field layer and show the resulting shape.
B, T, C = 1, 7, 2
num_samples = 3

tensor = torch.rand((B, T, C))

samples = mfg_layer.sample(tensor, num_samples=num_samples)
samples.shape

torch.Size([3, 1, 7])

# Inner-product Gaussian layer (`InnerprodGaussianLayer`)

In [6]:
# Inner-product layer configuration. `num_embedding` and `noise_type` are
# also read by the following cells to size the input channel dimension.
num_embedding, noise_type = 100, "hetero"

innerprod_layer = InnerprodGaussianLayer(
    num_embedding=num_embedding,
    noise_type=noise_type,
)

In [7]:
B, T = 1, 7
# Channel count derived from the layer config: num_embedding + 1, plus one
# more channel when noise_type is "hetero".
C = num_embedding + (2 if noise_type == "hetero" else 1)

tensor = torch.rand((B, T, C))

innerprod_layer.mean_and_cov(tensor)

(tensor([[0.6879, 0.3994, 0.7339, 0.6243, 0.1091, 0.2113, 0.3607]]),
 tensor([[[0.3334, 0.2073, 0.2486, 0.2582, 0.2713, 0.2457, 0.2721],
          [0.2073, 0.2442, 0.2115, 0.2089, 0.2223, 0.2088, 0.2309],
          [0.2486, 0.2115, 0.3269, 0.2445, 0.2751, 0.2292, 0.2538],
          [0.2582, 0.2089, 0.2445, 0.3346, 0.2730, 0.2434, 0.2701],
          [0.2713, 0.2223, 0.2751, 0.2730, 0.3732, 0.2661, 0.2756],
          [0.2457, 0.2088, 0.2292, 0.2434, 0.2661, 0.3023, 0.2510],
          [0.2721, 0.2309, 0.2538, 0.2701, 0.2756, 0.2510, 0.3591]]]),
 tensor([[[1.1523, 0.2073, 0.2486, 0.2582, 0.2713, 0.2457, 0.2721],
          [0.2073, 1.3352, 0.2115, 0.2089, 0.2223, 0.2088, 0.2309],
          [0.2486, 0.2115, 1.1770, 0.2445, 0.2751, 0.2292, 0.2538],
          [0.2582, 0.2089, 0.2445, 1.1028, 0.2730, 0.2434, 0.2701],
          [0.2713, 0.2223, 0.2751, 0.2730, 1.1409, 0.2661, 0.2756],
          [0.2457, 0.2088, 0.2292, 0.2434, 0.2661, 1.0281, 0.2510],
          [0.2721, 0.2309, 0.2538, 0.2701, 0

In [8]:
B = 3
T = 10000
C = num_embedding + 1 + int(noise_type == "hetero")

tensor = torch.rand(B, T, C)
y_target = torch.rand(B, T)

%time _ = innerprod_layer.loglik(tensor, y_target)

CPU times: user 70.4 ms, sys: 36.9 ms, total: 107 ms
Wall time: 52.9 ms


In [9]:
B = 3
T = 10000
C = num_embedding + 1 + int(noise_type == "hetero")
num_samples = 5

tensor = torch.rand(B, T, C)
y_target = torch.rand(B, T)

%time _ = innerprod_layer.loglik(tensor, y_target)

CPU times: user 60.7 ms, sys: 8.58 ms, total: 69.3 ms
Wall time: 43.1 ms


# Kvv Gaussian layer (`KvvGaussianLayer`)

In [10]:
# Kvv layer configuration. Note: this rebinds the `num_embedding` and
# `noise_type` globals that were set for the inner-product section.
num_embedding, noise_type = 10, "homo"

kvv_layer = KvvGaussianLayer(
    num_embedding=num_embedding,
    noise_type=noise_type,
)

In [11]:
B, T = 1, 7
# NOTE(review): C is hard-coded to 12 for num_embedding = 10. The channel
# layout KvvGaussianLayer expects is not visible here -- TODO derive C from
# num_embedding / noise_type, as the inner-product cells do.
C = 12

tensor = torch.rand((B, T, C))

kvv_layer.mean_and_cov(tensor)

(tensor([[0.4410, 0.7584, 0.5710, 0.0774, 0.6414, 0.7342, 0.3879]]),
 tensor([[[0.6693, 0.1476, 0.1149, 0.0953, 0.1111, 0.1153, 0.3803],
          [0.1476, 0.2319, 0.0650, 0.0464, 0.0730, 0.1086, 0.1158],
          [0.1149, 0.0650, 0.1245, 0.0428, 0.0674, 0.0389, 0.1440],
          [0.0953, 0.0464, 0.0428, 0.0390, 0.0495, 0.0237, 0.0877],
          [0.1111, 0.0730, 0.0674, 0.0495, 0.0915, 0.0405, 0.0913],
          [0.1153, 0.1086, 0.0389, 0.0237, 0.0405, 0.1127, 0.0586],
          [0.3803, 0.1158, 0.1440, 0.0877, 0.0913, 0.0586, 0.4498]]]),
 tensor([[[1.3624, 0.1476, 0.1149, 0.0953, 0.1111, 0.1153, 0.3803],
          [0.1476, 0.9250, 0.0650, 0.0464, 0.0730, 0.1086, 0.1158],
          [0.1149, 0.0650, 0.8177, 0.0428, 0.0674, 0.0389, 0.1440],
          [0.0953, 0.0464, 0.0428, 0.7321, 0.0495, 0.0237, 0.0877],
          [0.1111, 0.0730, 0.0674, 0.0495, 0.7846, 0.0405, 0.0913],
          [0.1153, 0.1086, 0.0389, 0.0237, 0.0405, 0.8058, 0.0586],
          [0.3803, 0.1158, 0.1440, 0.0877, 0

In [12]:
# Log-likelihood of random targets under the Kvv layer.
B, T, C = 3, 7, 12

tensor = torch.rand((B, T, C))
y_target = torch.rand((B, T))  # one target value per (batch, point)

kvv_layer.loglik(tensor, y_target)

tensor(-6.8717, grad_fn=<MeanBackward0>)

In [13]:
# Draw samples from the Kvv layer and show the resulting shape.
# `sample` takes no targets, so no `y_target` is created here (the previous
# version built one and never used it); this matches the mean-field sample cell.
B = 3
T = 7
C = 12
num_samples = 5

tensor = torch.rand(B, T, C)

kvv_layer.sample(tensor, num_samples=num_samples).shape

torch.Size([5, 3, 7])