# Belief Propagation

In this lab session you will implement the belief propagation (BP) algorithm on a factor graph using the [pgmpy library](http://www.pgmpy.org). By the end of this session you will be able to
- **Create a factor graph** respresenting a **hidden Markov model**.
- **Answer inference queries** using the built in methods of **pgmpy**.
- **Experiment with** the fundamental concepts behind **probabilistic inference and message passing**.



## The Model

We will consider our previous Markov chain model of wheather change over time. For simplicy, we will assume that the **weather is stationary** over the seven days of a week, **and that it can take three different values:** `sunny`, `cloudy` and `rainy`. The probability of rain or clouds after a rainy day is $0.25$ and $0.5$, respectively. A sunny day is also followed by a sunny day or a cloudy day with probabilities $0.7$ and $0.25$, respectively. Finally, if a day is cloudy, the next day will also be cloudy with probability $0.35$ or it will rain with probability $0.4$.

<style type="text/css">@media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style><div class="tg-wrap"><table style="border-collapse:collapse;border-spacing:0" class="tg"><thead><tr><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal" colspan="2" rowspan="2"></th><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal" colspan="3">Today's weather</th></tr><tr><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</th></tr></thead><tbody><tr><td style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:middle;word-break:normal" rowspan="3">Yesterday's weather</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.7</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.05</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.35</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.4</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.5</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td></tr></tbody></table></div>

Let's assume that by some unfortunate reason we are not able to see the weather directly, but we can observe if someone makes use of an umbrella. This will be our observed random variable, which directly depends on the actual weather. The probability of carrying an umbrella is $0.95$ for a rainy day, $0.6$ for a cloudy day, and $0.2$ for a sunny day.  

<style type="text/css">@media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style><div class="tg-wrap"><table style="border-collapse:collapse;border-spacing:0" class="tg"><thead><tr><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal" colspan="2" rowspan="2"></th><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal" colspan="2">Umbrella</th></tr><tr><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">true</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">false</th></tr></thead><tbody><tr><td style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:middle;word-break:normal" rowspan="3">Weather</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.2</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.8</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.6</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.4</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.95</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.05</td></tr></tbody></table></div>


Finally, we know that the first day of the week is sunny, cloudy, or rainy with probabilities $0.3$, $0.4$ and $0.3$, respectively.

## Creating the Factor graph

The model is composed of $14$ random variables, $7$ corresponding to the hidden weather state, and $7$ corresponding to the binary observed variable. Let's create the factors and variables enconding the transition and observational model.

In [None]:
from pgmpy.factors.discrete import DiscreteFactor

weather = ["w1", "w2", "w3", "w4", "w5", "w6", "w7"]
umbrella = ["u1", "u2", "u3", "u4", "u5", "u6", "u7"]
variables = weather + umbrella
state_names = dict([(k, ["sunny", "cloudy", "rainy"]) for k in weather] +
                   [(k, [True, False]) for k in umbrella])

def day_transition(d1, d2): # P(d2|d1)
     return DiscreteFactor(variables=[d2, d1],
                        cardinality=[len(state_names[d2]), len(state_names[d1])],
                        values=[0.7, 0.25, 0.25, # d2=sunny:  d1=sunny, cloudy, rainy
                                0.25, 0.35, 0.5, # d2=cloudy:  d1=sunny, cloudy, rainy
                                0.05, 0.4, 0.25], # d2=rainy:  d1=sunny, cloudy, rainy
                        state_names=state_names)

def take_umbrella_transition(d, u): # P(u|d)
    return DiscreteFactor(variables=[u, d],
                        cardinality=[len(state_names[u]), len(state_names[d])],
                        values=[0.2, 0.6, 0.95,   # u=True:  sunny, cloudy, rainy
                                0.8, 0.4, 0.05], # u=False: sunny, cloudy, rainy
                        state_names=state_names)

p = dict()
p["w1"] = DiscreteFactor(variables=["w1"],
                        cardinality=[len(state_names["w1"])],
                        values=[0.3, 0.4, 0.4],
                        state_names=state_names)

p["w2|w1"] = day_transition("w1", "w2")
p["w3|w2"] = day_transition("w2", "w3")
p["w4|w3"] = day_transition("w3", "w4")
p["w5|w4"] = day_transition("w4", "w5")
p["w6|w5"] = day_transition("w5", "w6")
p["w7|w6"] = day_transition("w6", "w7")

print("Table for P(w2|w1):")
print(p["w2|w1"])

p["u1|w1"] = take_umbrella_transition("w1", "u1")
p["u2|w2"] = take_umbrella_transition("w2", "u2")
p["u3|w3"] = take_umbrella_transition("w3", "u3")
p["u4|w4"] = take_umbrella_transition("w4", "u4")
p["u5|w5"] = take_umbrella_transition("w5", "u5")
p["u6|w6"] = take_umbrella_transition("w6", "u6")
p["u7|w7"] = take_umbrella_transition("w7", "u7")

print("\nTable for P(u3|w3):")
print(p["u1|w1"])


Table for P(w2|w1):
+------------+------------+--------------+
| w2         | w1         |   phi(w2,w1) |
| w2(sunny)  | w1(sunny)  |       0.7000 |
+------------+------------+--------------+
| w2(sunny)  | w1(cloudy) |       0.2500 |
+------------+------------+--------------+
| w2(sunny)  | w1(rainy)  |       0.2500 |
+------------+------------+--------------+
| w2(cloudy) | w1(sunny)  |       0.2500 |
+------------+------------+--------------+
| w2(cloudy) | w1(cloudy) |       0.3500 |
+------------+------------+--------------+
| w2(cloudy) | w1(rainy)  |       0.5000 |
+------------+------------+--------------+
| w2(rainy)  | w1(sunny)  |       0.0500 |
+------------+------------+--------------+
| w2(rainy)  | w1(cloudy) |       0.4000 |
+------------+------------+--------------+
| w2(rainy)  | w1(rainy)  |       0.2500 |
+------------+------------+--------------+

Table for P(u3|w3):
+-----------+------------+--------------+
| u1        | w1         |   phi(u1,w1) |
| u1(True)  | w

Now, let's build the model using the ```FactorGraph``` class. To do so, we first need to add all variable nodes with ```add_nodes_from```. Then, we add each factor ```f``` to the graph with```add_factors(f)``` and, for each of its variables ```v``` (available in ```f.variables```), we create an edge between ```f``` and ```v```  with ```add_edge(f,v)```. Finally, we verify that the factor graph is correct with the ```check_model``` method.

In [33]:
from pgmpy.models import FactorGraph
G = FactorGraph()

assert set(variables) == set([v for f in p.values() for v in f.variables])

### ADD NODES AND EDGES TO THE FACTOR GRAPH
G.add_nodes_from(variables)
for factor in p.values():
    G.add_factors(factor)
    for var in factor.variables:
        G.add_edge(var, factor)


print("Model is ok: ", G.check_model())

print("\nNumber of nodes: ", G.number_of_nodes())
print("Number of edges: ", G.number_of_edges())

Model is ok:  True

Number of nodes:  28
Number of edges:  27


## Running Inference Queries

We will use now some methods implemented in `pgmpy` that perform exact inference.
Consider the following queries:

1. We observe that on **Thursday and Friday the umbrella is being used**. What is the **weather prediction for Saturday**?
2. What is the **weather prediction for Sunday**, with the **same observations**?
3. What is the **weather prediction for Sunday** if, instead, we observe that the **umbrella is *NOT* being used on Thursday and Friday**?
4. What is the **most likely weather sequence for Monday and Tuesday**, if we observed that the **umbrella has been used both days**?


Answer the three previous queries using the `query` method of the `BeliefPropagation` class:
```python
    def query(self, variables, evidence=None, joint=True, show_progress=True):
        """
        Query method using belief propagation.

        Parameters
        ----------
        variables: list
            list of variables for which you want to compute the probability

        evidence: dict
            a dict key, value pair as {var: state_of_var_observed}
            None if no evidence

        joint: boolean
            If True, returns a Joint Distribution over `variables`.
            If False, returns a dict of distributions over each of the `variables`
```
<!-- You will need to normalize the output. -->

In [39]:
from pgmpy.inference import BeliefPropagation
bp = BeliefPropagation(G)

# Weather prediction for saturday, given that on thursday and friday the umbrella was used
query_1 = bp.query(
    variables=["w6"],
    evidence={"u4": True, "u5": True}
)
print("\nP(w6 | u4=True, u5=True):")
print(query_1)

# Weather prediction for sunday, given that on thursday and friday the umbrella was used
query_2 = bp.query(
    variables=["w7"],
    evidence={"u4": True, "u5": True}
)
print("\nP(w7 | u4=True, u5=True):")
print(query_2)

# Weather prediction for sunday, given that on thursday and friday the umbrella was not used
query_3 = bp.query(
    variables=["w7"],
    evidence={"u4": False, "u5": False}
)
print("\nP(w7 | u4=False, u5=False):")
print(query_3)

# Most likely sequence for monday and tuesday given that the umbrella was used both days
# Use joint of w1 and w2 for the query

query_4 = bp.query(
    variables=["w1", "w2"],
    evidence={"u1": True, "u2": True}
)

print("\nP(w1, w2 | u1=True, u2=True):")
print(query_4)

# Check if it sums to 1
total_prob = sum(query_4.values)
print("\nTotal probability for P(w1, w2 | u1=True, u2=True):", total_prob.sum())


# MAP (Most Probable Assignment) for w1 and w2 given that the umbrella was used both days:
mpe_query = bp.map_query(
    variables=["w1", "w2"],
    evidence={"u1": True, "u2": True}
)
print("\nMPE for w1 and w2 given u1=True and u2=True:")
print(mpe_query)


P(w6 | u4=True, u5=True):
+------------+-----------+
| w6         |   phi(w6) |
| w6(sunny)  |    0.3023 |
+------------+-----------+
| w6(cloudy) |    0.4081 |
+------------+-----------+
| w6(rainy)  |    0.2896 |
+------------+-----------+

P(w7 | u4=True, u5=True):
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.3860 |
+------------+-----------+
| w7(cloudy) |    0.3632 |
+------------+-----------+
| w7(rainy)  |    0.2508 |
+------------+-----------+

P(w7 | u4=False, u5=False):
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.5224 |
+------------+-----------+
| w7(cloudy) |    0.3077 |
+------------+-----------+
| w7(rainy)  |    0.1699 |
+------------+-----------+

P(w1, w2 | u1=True, u2=True):
+------------+------------+--------------+
| w1         | w2         |   phi(w1,w2) |
| w1(sunny)  | w2(sunny)  |       0.0212 |
+------------+------------+--------------+
| w1(sunny)  | w2(cloudy) |       0.0227 |
+------------+---------

## Implementing Belief Propagation

Implement your own Belief Propagation algorithm by filling in the provided code skeleton. You will need to provide the code to:
* Add the evidence factors
* Initialize a factor-to-variable message `f->v` and a variable-to-factor message `v->f` for each edge in the graph
* Update messages of each type
* Perform message passing using the collect_evidence and distribute_evidence algorithms
* Compute the marginal using the messages

See function comments for more details. Each comment line ```#IMPLEMENT``` is to be replaced with as many lines of code as you need. Method definitions should remain unchanged.

In [None]:
from functools import reduce
import operator
from collections import defaultdict
from copy import deepcopy

def prod(iterable):
    """
    Helper function to obtain the product of all the items in the iterable
    given as input
    """

    return reduce(operator.mul, iterable, 1)

class MyBeliefPropagation:
    def __init__(self, factor_graph):
        assert factor_graph.check_model()
        self.original_graph = factor_graph
        self.variables = factor_graph.get_variable_nodes()

        self.state_names = dict()
        for f in self.original_graph.factors:
            self.state_names.update(f.state_names)
        self.bp_done = False

    def get_evidence_factors(self, evidence):
        """
        For each evidence variable v, create a factor with p(v=e)=1. Receives a dict of
        evidences, where keys are variables and values are variable states. Returns a list of
        DiscreteFactor.
        """
        # IMPLEMENT

    def set_evidence(self, evidence):
        """
        Generates a new graph with the evidence factors
        evidence (keys: variables, values: states)
        """
        evidence_factors = self.get_evidence_factors(evidence)
        self.working_graph = self.original_graph.copy()
        for f in evidence_factors:
            # IMPLEMENT
        self.bp_done = False

    def factor_ones(self, v):
        """
        Returns a DiscreteFactor for variable v with all ones.
        """
        # IMPLEMENT

    def initialize_messages(self):
        """
        This function creates, for each edge factor-variable, two messages: m(f->v) and
        m(v->f). It initiliazies each message as a DiscreteFactor with all ones. It stores all
        the messages in a dict of dict. Keys of both dicts are either factors or variables.
        Messages are indexed as messages[to][from]. For example, m(x->y) is in messages[y][x].
        It's done this way because it will be useful to get all messages that go to a variable
        or a factor.
        """
        self.messages = defaultdict(dict)
        for f in self.working_graph.get_factors():
            for v in f.variables:
              self.messages[v][f] = self.factor_ones(v)
              self.messages[f][v] = self.factor_ones(v)


    def factor_to_variable(self, f, v):
        """
        Computes message m from factor to variable. It computes it from all messages from all
        other variables to the factor (i.e. all variables connected the factor except v).
        Returns message m.
        """
        assert v in self.variables and f in self.working_graph.factors
        # IMPLEMENT

    def variable_to_factor(self, v, f):
        """
        Computes message m from variable to factor. It computes it from all messages from all
        other factors to the variable (i.e. all factors connected the variable except f).
        Returns message m.
        """
        assert v in self.variables and f in self.working_graph.factors
        # IMPLEMENT

    def update(self, m_to, m_from):
        """
        Performs an update of a message depending on whether it is variable-to-factor or
        factor-to-variable.
        """
        # IMPLEMENT

    def collect_evidence(self, node, parent=None):
        """
        Passes messages from the leaves to the root of the tree.
        The parent argument is used to avoid an infinite recursion.
        """
        for child in self.working_graph.neighbors(node):
            # IMPLEMENT

    def distribute_evidence(self, node, parent=None):
        """
        Passes messages from the root to the leaves of the tree.
        The parent argument is used to avoid an infinite recursion.
        """
        for child in self.working_graph.neighbors(node):
            # IMPLEMENT

    def run_bp(self, root):
        """
        After initializing the messages, this function performs Belief Propagation
        using collect_evidence and distribute_evidence from the given root node.
        """
        assert root in self.variables, "Variable not in the model"
        # IMPLEMENT
        self.bp_done = True

    def get_marginal(self, variable):
        """
        To be used after run_bp. Returns p(variable | evidence) unnormalized.
        """
        assert self.bp_done, "First run BP!"
        # IMPLEMENT

   def get_marginal_subset(self, variables):
        """
        Returns p(variables | evidence) unnormalized.
        """
        assert self.bp_done, "First run BP!"
        # IMPLEMENT

Check that your implementation produces the same results for the first three queries P(w6|u4=t,u5=t), P(w7|u4=t,u5=t), and P(w6|u4=f,u5=f) as the ones given by the `BeliefPropagation` class. Do you need to run BP each time? Justify your answer.

In [None]:
my_bp = MyBeliefPropagation(G)



Use BP to compute the fourth query P(w1,w2|u1=t,u2=t). Note that we cannot do it with the ```get_marginal``` method as it is now. Define a method ```get_marginal_subset``` that receives a list of variables as input and, if they share a common factor, returns the marginal of the subset. Otherwise, throw an error. Check your result with the one produced by the `BeliefPropagation` class.

In [None]:
# Should be exactly the same as with bp from pgmpy