## Set-up Joint Distribution

In [1]:
import tensorflow_probability as tfp
tfd = tfp.distributions
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import maxentep
# disabling gpu for now
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
# use scatter (bool mask?) to set
# what about normalization?
M = 3
R = np.abs(np.random.normal(size=(M,M), loc=1.25))
# make some elements zero
R[0,2] = 0
R[1,2] = 0
R = R / np.sum(R, axis=1).reshape(M,1)
R, np.sum(R, axis=1)

(array([[0.77574843, 0.22425157, 0.        ],
        [0.60882887, 0.39117113, 0.        ],
        [0.33034764, 0.19889464, 0.47075772]]),
 array([1., 1., 1.]))

In [3]:
rmodel = maxentep.NormmatDist(R, start_var=0.5, clip_high=5)
rmodel.model(0.).sample(1)

<tf.Tensor: shape=(1, 3, 3), dtype=float32, numpy=
array([[[6.9594872e-01, 4.0099382e-01, 2.4493564e-20],
        [1.3441851e+00, 5.4470426e-01, 8.6639479e-21],
        [5.6792980e-01, 6.1962956e-01, 7.3514259e-01]]], dtype=float32)>

In [4]:
# compartment parameters
compartments = ['E', 'A', 'I', 'R']
infections_compartments = [1,2]
C = len(compartments)
tmat = maxentep.TransitionMatrix(1.0, compartments)
tmat.add_norm_dist('E', 'A', 5, 2)
tmat.add_norm_dist('A', 'I', 3, 2)
tmat.add_norm_dist('I', 'R', 14, 4)
T = tmat.sample(1)[0]

In [5]:
tmodel = maxentep.DirichletmatDist(T)
tmodel.model(0.).sample(1)

<tf.Tensor: shape=(1, 4, 4), dtype=float32, numpy=
array([[[9.9940938e-01, 5.9062615e-04, 7.5122335e-39, 7.5122335e-39],
        [4.4568251e-37, 7.7082217e-01, 2.2917783e-01, 4.4568251e-37],
        [1.7510371e-38, 1.7510371e-38, 8.8402444e-01, 1.1597557e-01],
        [1.9057450e-38, 1.9057450e-38, 1.9057450e-38, 1.0000000e+00]]],
      dtype=float32)>

In [6]:
start = np.zeros((M, C))
start[0, 1] = 0.1
startmodel = maxentep.NormmatDist(start, start_var=0.2, clip_high=0.5)

In [7]:
infect = maxentep.contact_infection_func(0.1,infections_compartments)

In [8]:
i = tf.keras.layers.Input((1,))
param_model = maxentep.metapop_parameter_dist(i, start, R, T)

In [9]:
model = tf.keras.Model(inputs=i, outputs=param_model)

In [22]:
joint = model(tf.constant([0.]))
joint

<tfp.distributions.JointDistributionSequential 'model_joint_JointDistributionSequential' batch_shape=[[3, 4], [3, 3], [4]] event_shape=[[], [], [4]] dtype=[float32, float32, float32]>

In [23]:
x = joint.sample(1)

In [24]:
x

[<tf.Tensor: shape=(1, 3, 4), dtype=float32, numpy=
 array([[[2.40553202e-20, 4.39683311e-02, 5.62027302e-21, 1.90130081e-20],
         [6.62274726e-21, 7.64256402e-21, 3.21849999e-20, 1.69896371e-20],
         [2.40534542e-20, 1.31737956e-20, 2.47044825e-20, 7.06086379e-21]]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 3, 3), dtype=float32, numpy=
 array([[[1.3443437e+00, 1.2924173e+00, 1.3025647e-21],
         [1.0860901e+00, 1.0901672e+00, 3.2347516e-21],
         [2.2846625e+00, 1.1550552e+00, 1.2724214e+00]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 4, 4), dtype=float32, numpy=
 array([[[9.8030758e-01, 1.9692412e-02, 8.1010423e-39, 8.1010423e-39],
         [3.1511531e-38, 6.7439014e-01, 3.2560989e-01, 3.1511531e-38],
         [6.6101182e-38, 6.6101182e-38, 9.9997735e-01, 2.2598622e-05],
         [6.6742215e-38, 6.6742215e-38, 6.6742215e-38, 1.0000000e+00]]],
       dtype=float32)>]

In [27]:
joint.prob(x)

InvalidArgumentError: Incompatible shapes: [1,3,4] vs. [1,3,3] [Op:AddV2]

In [35]:
d = tfd.JointDistributionSequential([tfd.Normal(loc=start, scale=1.0), tfd.Dirichlet(R)])
d

<tfp.distributions.JointDistributionSequential 'JointDistributionSequential' batch_shape=[[3, 4], [3]] event_shape=[[], [3]] dtype=[float64, float64]>

In [36]:
x = d.sample(1)
x

[<tf.Tensor: shape=(1, 3, 4), dtype=float64, numpy=
 array([[[-1.47320733, -0.19613861,  1.98884805,  1.76857882],
         [ 0.85429586,  1.74153838, -1.61731307, -0.74101026],
         [-1.89006436, -3.15173859,  0.546037  ,  0.01095902]]])>,
 <tf.Tensor: shape=(1, 3, 3), dtype=float64, numpy=
 array([[[3.34781483e-001, 6.65218517e-001, 1.99386308e-308],
         [3.39271589e-001, 6.60728411e-001, 2.03545292e-308],
         [2.80538290e-006, 9.05398733e-001, 9.45984613e-002]]])>]

In [37]:
d.log_prob(x, )

InvalidArgumentError: Incompatible shapes: [1,3,4] vs. [1,3] [Op:AddV2]