## Set Up the Joint Distribution

In [1]:
# Disabling the GPU for now. The environment variable is exported before
# TensorFlow is imported so that it is already in place whenever the CUDA
# runtime gets initialised.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

import maxentep

tfd = tfp.distributions

In [2]:
# TODO: consider a scatter / boolean mask for zeroing entries.
# TODO: decide how the matrix should be normalised.
M = 3
# Non-negative M x M matrix: absolute value of normal draws centred at 1.25.
R = np.abs(np.random.normal(loc=1.25, size=(M, M)))
# Zero out two entries of the last column (rows 0 and 1).
R[[0, 1], 2] = 0
# Rescale every row so that it sums to one.
R = R / R.sum(axis=1, keepdims=True)
R, R.sum(axis=1)

(array([[0.13614885, 0.86385115, 0.        ],
        [0.56534268, 0.43465732, 0.        ],
        [0.6375896 , 0.27006964, 0.09234075]]),
 array([1., 1., 1.]))

In [3]:
# compartment parameters
compartments = ['E', 'A', 'I', 'R']
# Positions 1 and 2 correspond to 'A' and 'I' in `compartments`.
infections_compartments = [1,2]
C = len(compartments)
# Transition matrix between compartments, built with maxentep.
# NOTE(review): the numeric arguments are presumably a time step (1.0) and
# the mean / spread of each transition time -- confirm against maxentep.
tmat = maxentep.TransitionMatrix(1.0, compartments)
tmat.add_norm_dist('E', 'A', 5, 2)
tmat.add_norm_dist('A', 'I', 3, 2)
tmat.add_norm_dist('I', 'R', 14, 4)
# Draw one matrix and keep it as the C x C array T.
T = tmat.sample(1)[0]
# The row sums used to be computed here and silently discarded (only the
# last expression of a cell is displayed). Check them explicitly instead:
# every row of T is expected to sum to one.
np.testing.assert_allclose(np.sum(T, axis=1), 1.0, atol=1e-6)
T

array([[0.62719929, 0.37280071, 0.        , 0.        ],
       [0.        , 0.51438227, 0.48561773, 0.        ],
       [0.        , 0.        , 0.9187771 , 0.0812229 ],
       [0.        , 0.        , 0.        , 1.        ]])

In [4]:
# Starting values: an M x C array of zeros with a single non-zero entry at
# row 0, column 1 (presumably the 'A' compartment, given C = len(compartments)).
start = np.zeros(shape=(M, C))
start[0][1] = 0.1

In [5]:
# Build the infection function from maxentep.
# NOTE(review): 0.1 is presumably the infection rate and
# `infections_compartments` the indices of the infectious compartments --
# confirm against maxentep. `infect` is not referenced again in the cells
# shown here.
infect = maxentep.contact_infection_func(0.1,infections_compartments)

In [6]:
# Scalar Keras input that feeds the three distribution layers below.
i = tf.keras.layers.Input(shape=(1,))
# One distribution layer per quantity, each initialised from the arrays built
# above (R, T, start). The model output shows a TruncatedNormal for 'R' and
# 'rho' and a Dirichlet for 'T'.
R_dist = maxentep.normal_mat_layer(i, R, name='R')
T_dist = maxentep.dirichlet_mat_layer(i, T, name='T')
start_dist = maxentep.normal_mat_layer(i, start, name='rho', clip_high=0.5)

In [7]:
# Single Keras model whose three outputs are the R, T and rho distributions.
model = tf.keras.Model(
    inputs=i,
    outputs=[R_dist, T_dist, start_dist],
)

In [8]:
# Call the model on two zero inputs to obtain the list of three distributions.
# NOTE(review): the zeros look like placeholders -- the resulting batch shapes
# do not carry the input length; confirm the input value is unused.
joint = model(tf.constant([0., 0.]))
joint

[<tfp.distributions.TruncatedNormal 'model_R_dist_TruncatedNormal' batch_shape=[3, 3] event_shape=[] dtype=float32>,
 <tfp.distributions.Dirichlet 'model_T_dist_Dirichlet' batch_shape=[4] event_shape=[4] dtype=float32>,
 <tfp.distributions.TruncatedNormal 'model_rho_dist_TruncatedNormal' batch_shape=[3, 4] event_shape=[] dtype=float32>]

In [54]:
# Draw 10 samples from each of the three distributions.
y = [dist.sample(10) for dist in joint]
# Inspect the Dirichlet (T) output: the distribution, its samples together
# with their sums over the last axis, and the samples' log-probabilities.
print(joint[1])
print(y[1], tf.reduce_sum(y[1], axis=-1))
print(joint[1].log_prob(y[1]))
# Log-probabilities of the R samples under the R distribution.
print(joint[0].log_prob(y[0]))

tfp.distributions.Dirichlet("model_T_dist_Dirichlet", batch_shape=[4], event_shape=[4], dtype=float32)
tf.Tensor(
[[[8.6452013e-01 1.3547990e-01 1.0913425e-39 1.0913425e-39]
  [1.2982105e-39 7.5682509e-01 2.4317490e-01 1.2982105e-39]
  [8.9022950e-40 8.9022950e-40 7.5088096e-01 2.4911901e-01]
  [8.5970081e-40 8.5970081e-40 8.5970081e-40 1.0000000e+00]]

 [[6.1483002e-01 3.8516995e-01 1.1266159e-39 1.1266159e-39]
  [7.0282825e-40 3.1091654e-01 6.8908346e-01 7.0282825e-40]
  [9.3134640e-40 9.3134640e-40 8.7816042e-01 1.2183956e-01]
  [9.1089165e-40 9.1089165e-40 9.1089165e-40 1.0000000e+00]]

 [[6.4645815e-01 3.5354185e-01 1.0044662e-39 1.0044662e-39]
  [1.6010073e-39 2.7170673e-01 7.2829324e-01 1.6010073e-39]
  [1.0813078e-39 1.0813078e-39 9.8302770e-01 1.6972300e-02]
  [1.5873012e-39 1.5873012e-39 1.5873012e-39 1.0000000e+00]]

 [[6.8209678e-01 3.1790322e-01 1.1179966e-39 1.1179966e-39]
  [3.0761262e-39 5.8225292e-01 4.1774711e-01 3.0761262e-39]
  [8.1327019e-40 8.1327019e-40 9.2845798

In [55]:
def negloglik(y, rv_y):
    """Return the negative log-probability of ``y`` under ``rv_y``.

    ``y`` is first clipped to the range [float32 tiny, float32 max]. The
    Dirichlet samples printed earlier contain entries around 1e-39, below
    the smallest normal float32, so the clip is presumably there to keep
    ``log_prob`` finite.

    Args:
        y: observed values (tensor).
        rv_y: distribution object exposing ``log_prob``.

    Returns:
        ``-rv_y.log_prob(clipped_y)``, not reduced.
    """
    limits = np.finfo(np.float32)
    return -rv_y.log_prob(tf.clip_by_value(y, limits.tiny, limits.max))


def negloglikd(y, rv_y):
    """Debugging variant: print the inputs and log-probabilities, then
    return the negated sum of ``rv_y.log_prob(y)`` (no clipping).

    Not used by ``model.compile`` below.
    """
    tf.print(y, rv_y, rv_y.log_prob(y))
    return -tf.reduce_sum(rv_y.log_prob(y))


# One negative-log-likelihood loss per model output (R, T, rho).
# NOTE(review): the Adam step size of 1e-10 is extremely small; the
# 'T-hypers' weights displayed before and after `fit` are identical to the
# printed precision. Confirm this is intentional.
model.compile(tf.optimizers.Adam(1e-10), loss=[negloglik, negloglik, negloglik])

In [56]:
# Show the trainable weights of the 'T-hypers' layer before fitting.
l = model.get_layer(name='T-hypers')
l.w

<tf.Variable 'value:0' shape=(4, 4) dtype=float32, numpy=
array([[0.63320214, 0.43034753, 0.        , 0.        ],
       [0.        , 0.4420422 , 0.5881161 , 0.        ],
       [0.        , 0.        , 0.92711926, 0.17774531],
       [0.        , 0.        , 0.        , 1.        ]], dtype=float32)>

In [57]:
# Fit on 10 zero inputs, one per sample drawn into `y`, for 5 epochs.
model.fit(
    x=tf.constant([0.] * 10),
    y=y,
    epochs=5,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x21407bee2c8>

In [58]:
# Print the layer table (names, output shapes, parameter counts) of the model.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
R-hypers (TrainableInputLayer)  (1, 2, 3, 3)         18          input_1[0][0]                    
__________________________________________________________________________________________________
T-hypers (TrainableInputLayer)  (1, 4, 4)            16          input_1[0][0]                    
__________________________________________________________________________________________________
rho-hypers (TrainableInputLayer (1, 2, 3, 4)         24          input_1[0][0]                    
______________________________________________________________________________________________

In [59]:
# Show the 'T-hypers' weights again after fitting, for comparison with the
# values displayed before `fit`.
l = model.get_layer(name='T-hypers')
l.w

<tf.Variable 'value:0' shape=(4, 4) dtype=float32, numpy=
array([[0.63320214, 0.43034753, 0.        , 0.        ],
       [0.        , 0.4420422 , 0.5881161 , 0.        ],
       [0.        , 0.        , 0.92711926, 0.17774531],
       [0.        , 0.        , 0.        , 1.        ]], dtype=float32)>

In [14]:
# Evaluate the model on a single zero input; `out` is the list of output
# distributions in the order given to tf.keras.Model (R, T, rho).
out = model(tf.constant([0.]))

In [15]:
# Draw a single sample from the T (Dirichlet) output distribution.
# (The stray `c` before `=` was a syntax error.)
x = out[1].sample()

In [16]:
# Negative log-likelihood of the sample `x` under the T distribution.
# NOTE(review): the recorded output (a printed vector followed by a scalar)
# does not match the current `negloglik`, which neither prints nor reduces;
# it is probably stale from an earlier definition -- re-run to confirm.
negloglik(x, out[1])

[-131.116516 -131.354095 -132.689392 -195.680969]


<tf.Tensor: shape=(), dtype=float32, numpy=-590.84094>