# Gibbs Sampling

Gibbs sampling is a technique for **generating samples from multi-dimensional distributions**.

This is challenging because the **joint distribution is unknown**, so we cannot sample from it directly. We only know some of the conditional distributions.

For example, for a two-dimensional variable (x, y) we may not know the joint probability p(x, y), but we do know the conditional distributions p(y | x) and p(x | y). Gibbs sampling uses these two conditional distributions to sample from the unknown joint distribution.

In this notebook we use a simple example to understand how Gibbs sampling works.


## Rolling Two Dice

Assume that we have two fair dice.

Let $x$ represent the outcome of the first die and $y$ the sum of the outcomes of the two dice.

Our goal is to generate many (x, y) pairs using Gibbs sampling.

Before we do this, let's start with a simpler case. Assume we know the joint probability distribution of (x, y) and can therefore **sample directly from the joint distribution**.

In [1]:
import random
from typing import Tuple, List, Dict
from collections import defaultdict

## Sampling Directly from Joint Distribution

Since we know the probability distribution of each die, we can generate (x, y) pairs **directly** by rolling each die.

In [2]:
def roll_a_die() -> int:
    return random.choice([1, 2, 3, 4, 5, 6])

def direct_sample() -> Tuple[int, int]:
    d1 = roll_a_die()
    d2 = roll_a_die()
    return d1, d1 + d2

# result_roll_a_die = roll_a_die()

# print(result_roll_a_die)

print("\nResult of rolling two dice (outcome of die 1, sum of the outcome of two dice): ", direct_sample())


Result of rolling two dice (outcome of die 1, sum of the outcome of two dice):  (2, 5)
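
As a reference point for the rest of the notebook, the joint distribution can also be written out exactly. The sketch below enumerates every pair of rolls; the helper name `exact_joint` is introduced here for illustration and is not part of the original notebook. Each of the 36 equally likely rolls maps to a distinct (x, y) pair, so every valid pair has probability 1/36.

```python
from fractions import Fraction
from itertools import product
from typing import Dict, Tuple


def exact_joint() -> Dict[Tuple[int, int], Fraction]:
    """Exact p(x, y), where x is the first die and y is the sum of both dice."""
    joint: Dict[Tuple[int, int], Fraction] = {}
    for d1, d2 in product(range(1, 7), repeat=2):
        # Each (d1, d2) roll has probability 1/36 and gives a unique (x, y).
        joint[(d1, d1 + d2)] = Fraction(1, 36)
    return joint


joint = exact_joint()
print("number of valid (x, y) pairs:", len(joint))
print("total probability:", sum(joint.values()))
```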


## Sampling based on Conditional Distribution

What if we don't know the joint distribution but do know some of the conditional distributions?

Gibbs sampling relies on two conditional distributions:
- We know the distribution of y conditional on x.
- We know the distribution of x conditional on y.

First, let's see how to sample from each of these two conditional distributions on its own.


## Distribution of y Conditional on x

If we know the value of x, y is equally likely to be x + 1, x + 2, x + 3, x + 4, x + 5, or x + 6.

Below we draw a value of y given x, which together with x forms an (x, y) pair.

In [3]:
def random_y_given_x(x: int) -> int:
    """equally likely to be x + 1, x + 2, ... , x + 6"""
    return x + roll_a_die()


# Outcome of die 1
x = 6

print("\nSum of outcome of rolling two dice (y) given the outcome of the first die (x = %d): " 
      % x, random_y_given_x(x))


Sum of outcome of rolling two dice (y) given the outcome of the first die (x = 6):  8


## Distribution of x Conditional on y

If we know that y is 2, then necessarily x is 1 (since the only way two dice can sum to 2 is if both of them are 1). 

If we know y is 3, then x is equally likely to be 1 or 2. 

Similarly, if y is 11, then x has to be either 5 or 6.

In [4]:
def random_x_given_y(y: int) -> int:
    if y <= 7:
        # if the total is 7 or less, the first die is equally likely to be
        # 1, 2, ..., (total - 1)
        return random.randrange(1, y)
    else:
        # if the total is more than 7, the first die is equally likely to be
        # (total - 6), (total - 5), ..., 6
        return random.randrange(y - 6, 7)
    
    
# sum of the sides of two dice (y)
y = 11
    
print("\nGiven the sum of the outcome of two dice (y = %d), outcome of die 1 (x): " 
      % y, random_x_given_y(y))


Given the sum of the outcome of two dice (y = 11), outcome of die 1 (x):  5
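
The bounds used in `random_x_given_y` can be checked by brute force. The following is a minimal sketch, assuming only the dice setup described above; `x_given_y_distribution` is a hypothetical helper name. It lists the first-die values that are consistent with a given sum and assigns each the same probability.

```python
from fractions import Fraction
from typing import Dict


def x_given_y_distribution(y: int) -> Dict[int, Fraction]:
    """Exact p(x | y): uniform over first-die values consistent with the sum y."""
    # x is valid when the second die, y - x, is also between 1 and 6.
    support = [x for x in range(1, 7) if 1 <= y - x <= 6]
    return {x: Fraction(1, len(support)) for x in support}


for y in range(2, 13):
    support = list(x_given_y_distribution(y))
    # Same ranges as the two branches of random_x_given_y.
    expected = list(range(1, y)) if y <= 7 else list(range(y - 6, 7))
    print(y, support, support == expected)
```

For y = 2 the only candidate is x = 1, and for y = 11 the candidates are 5 and 6, matching the cases described in the text.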


## Gibbs Sampling

In Gibbs sampling we use both conditional distributions.

We start with any (valid) values for x and y and then repeatedly alternate between:
- replacing x with a random value picked conditional on y, and
- replacing y with a random value picked conditional on x.

After enough iterations, the resulting values of x and y are approximately a sample from the unconditional joint distribution.
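
The reason the alternation targets the joint distribution can be sketched in one line, assuming the current pair $(x, y)$ is already distributed according to $p(x, y)$. Redrawing $x' \sim p(x' \mid y)$ while keeping $y$ fixed produces the pair $(x', y)$ with probability

$$
\sum_{x} p(x, y)\, p(x' \mid y) = p(y)\, p(x' \mid y) = p(x', y),
$$

so the update leaves the joint distribution unchanged. The same argument applies when $y$ is redrawn given $x$. This shows only that the joint distribution is stationary under the updates; reaching it from an arbitrary starting pair also requires that every valid pair can eventually be reached, which holds for the dice example here.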

In [5]:
def gibbs_sample(num_iters: int = 100) -> Tuple[int, int]:
    x, y = 1, 2  # any valid starting pair works
    for _ in range(num_iters):
        x = random_x_given_y(y)
        y = random_y_given_x(x)
    return x, y

print("Outcome of die 1 (x) & sum of the outcome of two dice (y): ", gibbs_sample(10000))

Outcome of die 1 (x) & sum of the outcome of two dice (y):  (4, 6)
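
How many iterations are "enough" can be explored exactly for this small example. The sketch below is an illustration rather than part of the original notebook: the helpers `x_support`, `one_sweep`, and `total_variation` are hypothetical names. It pushes a whole probability distribution over (x, y) through one Gibbs sweep using exact fractions, checks that the uniform joint is left unchanged, and tracks how far the chain started at (1, 2) is from the joint after each sweep.

```python
from fractions import Fraction
from itertools import product
from typing import Dict, List, Tuple

State = Tuple[int, int]  # (x, y) = (first die, sum of both dice)


def x_support(y: int) -> List[int]:
    """First-die values that are consistent with the sum y."""
    return [x for x in range(1, 7) if 1 <= y - x <= 6]


def one_sweep(dist: Dict[State, Fraction]) -> Dict[State, Fraction]:
    """Exact distribution after one sweep: redraw x given y, then y given x."""
    out: Dict[State, Fraction] = {}
    for (_, y), p in dist.items():
        xs = x_support(y)
        for new_x in xs:              # x | y is uniform over xs
            for roll in range(1, 7):  # y | x is x plus a fair die roll
                state = (new_x, new_x + roll)
                out[state] = out.get(state, Fraction(0)) + p / (6 * len(xs))
    return out


def total_variation(p: Dict[State, Fraction], q: Dict[State, Fraction]) -> Fraction:
    """Half the L1 distance between two distributions over (x, y)."""
    states = set(p) | set(q)
    diffs = (abs(p.get(s, Fraction(0)) - q.get(s, Fraction(0))) for s in states)
    return sum(diffs, Fraction(0)) / 2


target = {(d1, d1 + d2): Fraction(1, 36) for d1, d2 in product(range(1, 7), repeat=2)}

# The true joint distribution is left unchanged by a sweep.
print("joint is stationary:", one_sweep(target) == target)

# Start from the point mass at (1, 2), the same starting pair as gibbs_sample.
dist: Dict[State, Fraction] = {(1, 2): Fraction(1)}
distances: List[Fraction] = []
for sweep in range(1, 11):
    dist = one_sweep(dist)
    distances.append(total_variation(dist, target))
    print(sweep, float(distances[-1]))
```

The printed distances give a rough sense of how quickly the influence of the starting pair fades in this particular example; other problems can mix much more slowly.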


## Direct Sampling vs. Gibbs Sampling

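We compare the results obtained from Gibbs sampling with those from direct sampling.

Each of the 36 possible (x, y) pairs has probability 1/36, so with 10,000 samples every pair should appear about 278 times under either technique. As we draw more samples, the two sets of counts agree more closely. In the output, each pair is followed by [Gibbs count, direct count].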

In [6]:
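def compare_distributions(num_samples: int = 1000) -> Dict[Tuple[int, int], List[int]]:
    # counts[(x, y)] = [count from Gibbs sampling, count from direct sampling]
    counts: Dict[Tuple[int, int], List[int]] = defaultdict(lambda: [0, 0])
    for _ in range(num_samples):
        counts[gibbs_sample()][0] += 1
        counts[direct_sample()][1] += 1
    return counts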

In [7]:
counts = compare_distributions(10000)

for i in counts.items():
    print(i)

((2, 6), [264, 306])
((6, 10), [270, 280])
((3, 4), [240, 270])
((4, 8), [315, 289])
((3, 6), [277, 265])
((3, 7), [271, 283])
((6, 7), [303, 266])
((1, 7), [279, 275])
((4, 5), [268, 290])
((2, 8), [281, 277])
((5, 11), [277, 277])
((2, 5), [267, 298])
((6, 9), [300, 271])
((6, 8), [248, 295])
((4, 6), [286, 245])
((3, 5), [283, 275])
((1, 2), [291, 276])
((2, 3), [262, 277])
((3, 8), [249, 293])
((5, 6), [283, 244])
((4, 7), [283, 281])
((6, 12), [289, 288])
((3, 9), [273, 255])
((1, 4), [306, 277])
((5, 8), [286, 285])
((1, 3), [273, 275])
((1, 5), [294, 250])
((5, 7), [254, 281])
((4, 10), [292, 298])
((2, 4), [262, 296])
((2, 7), [267, 278])
((5, 9), [262, 266])
((4, 9), [295, 300])
((5, 10), [276, 250])
((6, 11), [297, 281])
((1, 6), [277, 287])
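
One way to put a number on "similar" is a Pearson chi-square statistic against equal expected counts. The sketch below is a hedged illustration, not part of the original notebook: `chi_square_uniform` is a hypothetical helper, and the two lists are the Gibbs and direct counts copied from the output above in the order printed. With 36 cells there are 35 degrees of freedom, so for independent samples from the uniform joint the statistic averages about 35. Values far above that would suggest a mismatch.

```python
from typing import Sequence


def chi_square_uniform(observed: Sequence[int]) -> float:
    """Pearson chi-square statistic against equal expected counts in every cell."""
    expected = sum(observed) / len(observed)
    return sum((o - expected) ** 2 / expected for o in observed)


# First entry of each [Gibbs, direct] pair in the output above.
gibbs_counts = [
    264, 270, 240, 315, 277, 271, 303, 279, 268, 281, 277, 267,
    300, 248, 286, 283, 291, 262, 249, 283, 283, 289, 273, 306,
    286, 273, 294, 254, 292, 262, 267, 262, 295, 276, 297, 277,
]
# Second entry of each [Gibbs, direct] pair in the output above.
direct_counts = [
    306, 280, 270, 289, 265, 283, 266, 275, 290, 277, 277, 298,
    271, 295, 245, 275, 276, 277, 293, 244, 281, 288, 255, 277,
    285, 275, 250, 281, 298, 296, 278, 266, 300, 250, 281, 287,
]

print("expected count per pair:", round(sum(gibbs_counts) / len(gibbs_counts), 1))
print("Gibbs chi-square: ", round(chi_square_uniform(gibbs_counts), 1))
print("direct chi-square:", round(chi_square_uniform(direct_counts), 1))
```

Inside the notebook, the same function could be applied to a fresh run with `chi_square_uniform([g for g, _ in counts.values()])`, provided all 36 pairs appear in `counts`.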
