# Shapes and Dimensions

https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/

https://pyro.ai/examples/tensor_shapes.html



**Event Shape:** the shape of a single draw (one event/observation) from one distribution. The elements inside an event may be dependent, and `log_prob` returns one number per event.

**Batch Shape:** the shape of a collection of independent distributions of the same family, possibly with different parameters, bundled into one distribution object. We can't batch a Gaussian together with a Gamma distribution, but we can batch several Gaussians.

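**Sample Shape:** the shape of the set of independent, identically distributed draws taken from the distribution (or from each distribution in the batch). A drawn array has shape `sample_shape + batch_shape + event_shape`.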

Why do these different shapes matter? Consider computing the log likelihood of a vector of two numbers. How many log probabilities should we return? Under one bivariate Gaussian the vector is a single event, so we get one; under a batch of two independent Gaussians it holds two events, so we get two.

> if problems can be cast into a tensor-space operation, vectorization can help speed up many operations that we wish to handle. -- Eric Ma

> apparently same-shaped draws are shaped differently semantically


In [32]:
# import jax 
import jax.numpy as np
from jax import random 

# import numpyro
import numpyro.distributions as dist 


In [61]:
# one draw from one normal
# event: one draw [scalar]
# batch: one normal
d = dist.Normal()
print("Distribution Event Shape:", d.event_shape)
print("Distribution Batch Shape:", d.batch_shape)

draw = d.sample(key=random.PRNGKey(123), sample_shape=(1,)) # data shape = sample_shape + batch_shape + event_shape
print("Data Shapes:",draw.shape)
print("Samples:", draw)

logp = d.log_prob(draw)
print("Log Probability:", logp)

# event shape: ()
# batch shape: ()
# sample shape: (1,)
# data shape: (1,)

Distribution Event Shape: ()
Distribution Batch Shape: ()
Data Shapes: (1,)
Samples: [-0.75307846]
Log Probability: [-1.2025021]


In [60]:
# two draws from one normal 
# the elementary event of **drawing a single number** did not fundamentally change
# (repeat the same event twice)
# event: one draw [scalar]
# batch: one normal
d = dist.Normal()
print("Distribution Event Shape:", d.event_shape)
print("Distribution Batch Shape:", d.batch_shape)

draw = d.sample(key=random.PRNGKey(123), sample_shape=(2,)) # data shape = sample_shape + batch_shape + event_shape
print("Data Shapes:",draw.shape)
print("Samples:", draw)

logp = d.log_prob(draw)
print("Log Probability:", logp)

# event shape: ()
# batch shape: ()
# sample shape: (2,)
# data shape: (2,)

Distribution Event Shape: ()
Distribution Batch Shape: ()
Data Shapes: (2,)
Samples: [-0.03049826  0.49289012]
Log Probability: [-0.9194036 -1.0404088]


In [58]:
# one draw from the first normal alongside one draw from the second normal, 
# then concatenate them into a vector 
# drawing number from INDEPENDENT Gaussians
# event: one draw
# batch: two normals
means = np.array([1., 0.])
stds = np.array([1., 3.])

d = dist.Normal(means, stds)
print("Distribution Event Shape:", d.event_shape)
print("Distribution Batch Shape:", d.batch_shape)

draw = d.sample(key=random.PRNGKey(123), sample_shape=(1,))
print("Data Shapes:",draw.shape)
print("Samples:", draw)

logp = d.log_prob(draw) # computed independently
print("Log Probability:", logp)

# event shape: ()
# batch shape: (2,)
# sample shape: (1,)
# data shape: (1, 2) # rightmost is the batch_shape

Distribution Event Shape: ()
Distribution Batch Shape: (2,)
Data Shapes: (1, 2)
Samples: [[0.96950173 1.4786704 ]]
Log Probability: [[-0.9194036 -2.1390214]]


In [57]:
# one draw from a multivariate Gaussian
# event: one draw with 2 elements [2-element vector]
d = dist.MultivariateNormal(covariance_matrix=np.array([[1., 0.5],[0.5,1.]]))
print("Distribution Event Shape:", d.event_shape, "[2-element vector]")
print("Distribution Batch Shape:", d.batch_shape)

draw = d.sample(key=random.PRNGKey(123), sample_shape=(1,))
print("Data Shapes:",draw.shape)
print("Samples:", draw)

logp = d.log_prob(draw) # computed with respect to the full joint distribution
print("Log Probability:", logp)
# event shape: (2,)
# batch shape: ()
# sample shape: (1,)
# data shape: (1,2) # rightmost is the event_shape

Distribution Event Shape: (2,) [2-element vector]
Distribution Batch Shape: ()
Data Shapes: (1, 2)
Samples: [[-0.03049826  0.41160622]]
Log Probability: [-1.8159714]


In [62]:
# two draws from a multivariate Gaussian
# event: one 2-element vector; drawn twice (sample shape (2,))
d = dist.MultivariateNormal(covariance_matrix=np.array([[1., 0.5],[0.5,1.]]))
print("Distribution Event Shape:", d.event_shape, "[2-element vector]")
print("Distribution Batch Shape:", d.batch_shape)

draw = d.sample(key=random.PRNGKey(123), sample_shape=(2,))
print("Data Shapes:",draw.shape)
print("Samples:", draw)

logp = d.log_prob(draw) # computed with respect to the full joint distribution
print("Log Probability:", logp)
# event shape: (2,)
# batch shape: ()
# sample shape: (2,)
# data shape: (2,2) # rightmost is the event_shape

Distribution Event Shape: (2,) [2-element vector]
Distribution Batch Shape: ()
Data Shapes: (2, 2)
Samples: [[-0.1470326 -2.3694336]
 [ 1.648498   0.5994194]]
Log Probability: [-5.2190027 -3.0865078]
