In [None]:
!pip install -q pgmpy

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m1.2/1.9 MB[0m [31m35.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.9/1.9 MB[0m [31m41.2 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[?25h

# Belief Propagation

In this lab session you will implement the belief propagation (BP) algorithm on a factor graph using the [pgmpy library](http://www.pgmpy.org). By the end of this session you will be able to
- **Create a factor graph** respresenting a **hidden Markov model**.
- **Answer inference queries** using the built in methods of **pgmpy**.
- **Experiment with** the fundamental concepts behind **probabilistic inference and message passing**.

## The Model

<img src="https://raw.githubusercontent.com/guillermoim/resources/main/makov_weather.png" width=750 height=200>

We will consider our previous Markov chain model of wheather change over time. For simplicy, we will assume that the **weather is stationary** over the seven days of a week, **and that it can take three different values:** `sunny`, `cloudy` and `rainy`. The probability of rain or clouds after a rainy day is $0.25$ and $0.5$, respectively. A sunny day is also followed by a sunny day or a cloudy day with probabilities $0.7$ and $0.25$, respectively. Finally, if a day is cloudy, the next day will also be cloudy with probability $0.35$ or it will rain with probability $0.4$.

<style type="text/css">@media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style><div class="tg-wrap"><table style="border-collapse:collapse;border-spacing:0" class="tg"><thead><tr><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal" colspan="2" rowspan="2"></th><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal" colspan="3">Today's weather</th></tr><tr><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</th></tr></thead><tbody><tr><td style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:middle;word-break:normal" rowspan="3">Yesterday's weather</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.7</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.05</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.35</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.4</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.5</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.25</td></tr></tbody></table></div>

Let's assume that by some unfortunate reason we are not able to see the weather directly, but we can observe if someone makes use of an umbrella. This will be our observed random variable, which directly depends on the actual weather. The probability of carrying an umbrella is $0.95$ for a rainy day, $0.6$ for a cloudy day, and $0.2$ for a sunny day.  

<style type="text/css">@media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style><div class="tg-wrap"><table style="border-collapse:collapse;border-spacing:0" class="tg"><thead><tr><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal" colspan="2" rowspan="2"></th><th style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal" colspan="2">Umbrella</th></tr><tr><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">true</th><th style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">false</th></tr></thead><tbody><tr><td style="border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-weight:normal;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:middle;word-break:normal" rowspan="3">Weather</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">sunny</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.2</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.8</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">cloudy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.6</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.4</td></tr><tr><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;overflow:hidden;padding:10px 5px;text-align:left;vertical-align:top;word-break:normal">rainy</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.95</td><td style="border-color:inherit;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;font-style:italic;overflow:hidden;padding:10px 5px;text-align:center;vertical-align:top;word-break:normal">.05</td></tr></tbody></table></div>


Finally, we know that the first day of the week is sunny, cloudy, or rainy with probabilities $0.3$, $0.4$ and $0.3$, respectively. 

## Creating the Factor graph

The model is composed of $14$ random variables, $7$ corresponding to the hidden weather state, and $7$ corresponding to the binary observed variable. Let's create the factors and variables enconding the transition and observational model.

In [None]:
from pgmpy.factors.discrete import DiscreteFactor

weather = ["w1", "w2", "w3", "w4", "w5", "w6", "w7"]
umbrella = ["u1", "u2", "u3", "u4", "u5", "u6", "u7"]
variables = weather + umbrella
state_names = dict([(k, ["sunny", "cloudy", "rainy"]) for k in weather] +
                   [(k, [True, False]) for k in umbrella])

def day_transition(d1, d2): # P(d2|d1)
     return DiscreteFactor(variables=[d1, d2],
                        cardinality=[len(state_names.get(d1)), len(state_names.get(d2))],
                        values=[0.7, 0.25, 0.05, 0.25, 0.35, 0.4, 0.25, 0.5, 0.25],
                        state_names=state_names)
    
def take_umbrella_transition(d, u): # P(u|d)
    return DiscreteFactor(variables=[d, u],
                        cardinality=[len(state_names.get(d)), len(state_names.get(u))],
                        values=[0.2, 0.8, 0.6, 0.4, 0.95, 0.05],
                        state_names=state_names)

p = dict()
p["w1"] = DiscreteFactor(variables=["w1"],
                        cardinality=[len(state_names.get("w1"))],
                        values=[0.3, 0.4, 0.3],
                        state_names=state_names)

p["w2|w1"] = day_transition("w1", "w2")
### ADD REMAINING FACTORS
p["w3|w2"] = day_transition("w2", "w3")
p["w4|w3"] = day_transition("w3", "w4")
p["w5|w4"] = day_transition("w4", "w5")
p["w6|w5"] = day_transition("w5", "w6")
p["w7|w6"] = day_transition("w6", "w7")

p["u1|w1"] = take_umbrella_transition("w1", "u1")
### ADD REMAINING FACTORS
p["u2|w2"] = take_umbrella_transition("w2", "u2")
p["u3|w3"] = take_umbrella_transition("w3", "u3")
p["u4|w4"] = take_umbrella_transition("w4", "u4")
p["u5|w5"] = take_umbrella_transition("w5", "u5")
p["u6|w6"] = take_umbrella_transition("w6", "u6")
p["u7|w7"] = take_umbrella_transition("w7", "u7")

Now, let's build the model using the ```FactorGraph``` class. To do so, we first need to add all variable nodes with ```add_nodes_from```. Then, we add each factor ```f``` to the graph with```add_factors(f)``` and, for each of its variables ```v``` (available in ```f.variables```), we create an edge between ```f``` and ```v```  with ```add_edge(f,v)```. Finally, we verify that the factor graph is correct with the ```check_model``` method.

In [None]:
from pgmpy.models import FactorGraph
G = FactorGraph()

assert set(variables) == set([v for f in p.values() for v in f.variables])


### ADD NODES AND EDGES TO THE FACTOR GRAPH
G.add_nodes_from(variables)
G.add_factors(p["w2|w1"], p["w3|w2"], p["w4|w3"], p["w5|w4"], p["w6|w5"], p["w7|w6"], p["u1|w1"], p["u2|w2"], p["u3|w3"], p["u4|w4"], p["u5|w5"], p["u6|w6"], p["u7|w7"])
G.add_edge("w1", p["w2|w1"])
G.add_edge("w2", p["w2|w1"])
G.add_edge("w2", p["w3|w2"])
G.add_edge("w3", p["w3|w2"])
G.add_edge("w3", p["w4|w3"])
G.add_edge("w4", p["w4|w3"])
G.add_edge("w4", p["w5|w4"])
G.add_edge("w5", p["w5|w4"])
G.add_edge("w5", p["w6|w5"])
G.add_edge("w6", p["w6|w5"])
G.add_edge("w6", p["w7|w6"])
G.add_edge("w7", p["w7|w6"])
G.add_edge("w1", p["u1|w1"])
G.add_edge("u1", p["u1|w1"])
G.add_edge("w2", p["u2|w2"])
G.add_edge("u2", p["u2|w2"])
G.add_edge("w3", p["u3|w3"])
G.add_edge("u3", p["u3|w3"])
G.add_edge("w4", p["u4|w4"])
G.add_edge("u4", p["u4|w4"])
G.add_edge("w5", p["u5|w5"])
G.add_edge("u5", p["u5|w5"])
G.add_edge("w6", p["u6|w6"])
G.add_edge("u6", p["u6|w6"])
G.add_edge("w7", p["u7|w7"])
G.add_edge("u7", p["u7|w7"])

print("Model is ok: ", G.check_model())                                                                                                          

Model is ok:  True


## Running Inference Queries

We will use now some methods implemented in `pgmpy` that perform exact inference.
Consider the following queries:

1. We observe that on **Thursday and Friday the umbrella is being used**. What is the **weather prediction for Saturday**?
2. What is the **weather prediction for Sunday**, with the **same observations**?
3. What is the **weather prediction for Sunday** if, instead, we observe that the **umbrella is *NOT* being used on Thursday and Friday**? 
4. What is the **most likely weather sequence for Monday and Tuesday**, if we observed that the **umbrella has been used both days**?


Answer the three previous queries using the `query` method of the `BeliefPropagation` class:
```python
    def query(self, variables, evidence=None, joint=True, show_progress=True):
        """
        Query method using belief propagation.

        Parameters
        ----------
        variables: list
            list of variables for which you want to compute the probability

        evidence: dict
            a dict key, value pair as {var: state_of_var_observed}
            None if no evidence

        joint: boolean
            If True, returns a Joint Distribution over `variables`.
            If False, returns a dict of distributions over each of the `variables`
```
You will need to normalize the output.

In [None]:
from pgmpy.inference import BeliefPropagation
bp = BeliefPropagation(G)

# MAKE QUERIES
p["w6|u4&u5=t"] = bp.query(["w6"], {"u4":True, "u5": True}, True)
print(p["w6|u4&u5=t"])
p["w7|u4&u5=t"] = bp.query(["w7"], {"u4":True, "u5": True}, True)
print(p["w7|u4&u5=t"])
p["w7|u4&u5=f"] = bp.query(["w7"], {"u4":False, "u5": False}, True)
print(p["w7|u4&u5=f"])
p["w1,w2|u1&u2=t"] = bp.query(["w1", "w2"], {"u1":True, "u2": True}, True)
print(p["w1,w2|u1&u2=t"])

+------------+-----------+
| w6         |   phi(w6) |
| w6(sunny)  |    0.3026 |
+------------+-----------+
| w6(cloudy) |    0.4080 |
+------------+-----------+
| w6(rainy)  |    0.2894 |
+------------+-----------+
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.3862 |
+------------+-----------+
| w7(cloudy) |    0.3631 |
+------------+-----------+
| w7(rainy)  |    0.2507 |
+------------+-----------+
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.5226 |
+------------+-----------+
| w7(cloudy) |    0.3076 |
+------------+-----------+
| w7(rainy)  |    0.1698 |
+------------+-----------+
+------------+------------+--------------+
| w1         | w2         |   phi(w1,w2) |
| w1(sunny)  | w2(sunny)  |       0.0277 |
+------------+------------+--------------+
| w1(sunny)  | w2(cloudy) |       0.0297 |
+------------+------------+--------------+
| w1(sunny)  | w2(rainy)  |       0.0094 |
+------------+------------+--------------+
| w1(clo

The most likely weather would be that on Monday is rainy and on Tuesday is cloudy, because it has the higher probability.

## Implementing Belief Propagation

Implement your own Belief Propagation algorithm by filling in the provided code skeleton. You will need to provide the code to:
* Add the evidence factors
* Initialize a factor-to-variable message `f->v` and a variable-to-factor message `v->f` for each edge in the graph
* Update messages of each type
* Perform message passing using the collect_evidence and distribute_evidence algorithms
* Compute the marginal using the messages

See function comments for more details. Each comment line ```#IMPLEMENT``` is to be replaced with as many lines of code as you need. Method definitions should remain unchanged.

In [None]:
from traitlets.traitlets import ValidateHandler
from functools import reduce
import operator
from collections import defaultdict
from copy import deepcopy

def prod(iterable):
    
    """
      Returns the product of all the items in the iterable given in as input.
    """

    return reduce(operator.mul, iterable, 1)

class MyBeliefPropagation:
    def __init__(self, factor_graph):
        assert factor_graph.check_model()
        self.original_graph = factor_graph
        self.variables = factor_graph.get_variable_nodes()
        
        self.state_names = dict()
        for f in self.original_graph.factors:
            self.state_names.update(f.state_names)
        self.bp_done = False

    def factor_ones(self, v):
        """
        Returns a DiscreteFactor for variable v with all ones.
        """
        # IMPLEMENT
        return DiscreteFactor(variables=[v],
                        cardinality=[len(self.state_names.get(v))],
                        values=[1]*len(self.state_names.get(v)),
                        state_names=self.state_names)
        
    def initialize_messages(self):
        """
        This function creates, for each edge factor-variable, two messages: m(f->v) and 
        m(v->f). It initiliazies each message as a DiscreteFactor with all ones. It stores all
        the messages in a dict of dict. Keys of both dicts are either factors or variables.
        Messages are indexed as messages[to][from]. For example, m(x->y) is in messages[y][x].
        It's done this way because it will be useful to get all messages that go to a variable
        or a factor.
        """
        self.messages = defaultdict(dict)
        for f in self.working_graph.get_factors():
            # IMPLEMENT
            for v in f.variables:
              self.messages[f][v] = self.factor_ones(v)
              self.messages[v][f] = self.factor_ones(v)
                
    def factor_to_variable(self, f, v):
        """
        Computes message m from factor to variable. It computes it from all messages from all
        other variables to the factor (i.e. all variables connected the factor except v).
        Returns message m.
        """
        assert v in self.variables and f in self.working_graph.factors
        # IMPLEMENT
        M = list(self.messages[f].values())
        M.remove(self.messages[f][v])
        res = f * prod(M)
        vs = []
        for aux in res.variables:
          if aux!= v:
            vs.append(aux)
        return res.marginalize(vs, inplace=False)
 
    def variable_to_factor(self, v, f):
        """
        Computes message m from variable to factor. It computes it from all messages from all
        other factors to the variable (i.e. all factors connected the variable except f).
        Returns message m.
        """        
        assert v in self.variables and f in self.working_graph.factors
        # IMPLEMENT
        M = list(self.messages[v].values())
        M.remove(self.messages[v][f])
        return prod(M)
    
    def get_evidence_factors(self, evidence):
        """
        For each evidence variable v, create a factor with p(v=e)=1. Recieves a dict of
        evidence, where keys are variables and values are variable states. Returns a list of
        DiscreteFactor.
        """
        # IMPLEMENT
        df = []
        for var, state in evidence.items():
          vals = []
          for st in self.state_names.get(var):
            if (st == state):
              vals.append(1)
            else:
              vals.append(0)
          factor = DiscreteFactor(variables=[var],
                        cardinality=[len(self.state_names.get(var))],
                        values = vals,
                        state_names=self.state_names)
          df.append(factor)
        return df
        
    def update(self, m_to, m_from):
        """
        Performs an update of a message depending on whether it is variable-to-factor or
        factor-to-variable.
        """
        # IMPLEMENT
        #print("to:", m_to, "from:", m_from, "\nmessage:", self.messages[m_to][m_from])
        #print("-----------------------------------------------------")
        if (m_to in self.variables):
          self.messages[m_to][m_from] = self.factor_to_variable(m_from, m_to)
          #print("\nupdate:", self.messages[m_to][m_from])
          #print("-----------------------------------------------------")
          #print("-----------------------------------------------------")
        elif (m_to in self.working_graph.get_factors()):
          self.messages[m_to][m_from] = self.variable_to_factor(m_from, m_to)
          #print("\nupdate:", self.messages[m_to][m_from])
          #print("-----------------------------------------------------")
          #print("-----------------------------------------------------")
    
    def collect_evidence(self, node, parent=None):
        """
        Passes messages from the leaves to the root of the tree.
        The parent argument is used to avoid an infinite recursion.
        """
        for child in self.messages[node]:
          if child != parent:
            self.update(node, self.collect_evidence(child, parent=node))
        return node
    
    def distribute_evidence(self, node, parent=None):
        """
        Passes messages from the root to the leaves of the tree.
        The parent argument is used to avoid an infinite recursion.
        """
        for child in self.working_graph.neighbors(node):
          if child != parent:
            self.update(child, node)
            self.distribute_evidence(child, parent=node)
    
    def set_evidence(self, evidence):
        """
        Generates a new graph with the evidence factors
        evidence (keys: variables, values: states)
        """
        evidence_factors = self.get_evidence_factors(evidence)
        self.working_graph = self.original_graph.copy()
        for f in evidence_factors:
            # IMPLEMENT
            self.working_graph.add_factors(f)
            for variable in f.variables:
              self.working_graph.add_edge(variable, f)
        self.bp_done = False
    
    def run_bp(self, root):
        """
        After initializing the messages, this function performs Belief Propagation
        using collect_evidence and distribute_evidence from the given root node.
        """
        assert root in self.variables, "Variable not in the model"
        self.initialize_messages()
        self.distribute_evidence(root)
        self.collect_evidence(root)

        self.bp_done = True

    def get_marginal(self, variable):
        """
        To be used after run_bp. Returns p(variable | evidence) unnormalized.
        """
        assert self.bp_done, "First run BP!"
        marginal = 1
        for factors in self.working_graph.neighbors(variable):
          marginal *= self.messages[variable][factors]
        return marginal 

Check that your implementation produces the same results for the first three queries P(w7|u4=t,u5=t), P(w7|u4=t,u5=t), and P(w6|u4=f,u5=f) as the ones given by the `BeliefPropagation` class. Do you need to run BP each time? Justify your answer.

We only need to run BP when evidences change. Otherwise, we already have the information propagated through all nodes.

In [None]:
from pgmpy.models import FactorGraph
p2 = dict()
variables2 = ["w1, w2"]
p2["w1"] = DiscreteFactor(variables=["w1"],
                        cardinality=[len(state_names.get("w1"))],
                        values=[0.3, 0.4, 0.3],
                        state_names=state_names)

p2["w2|w1"] = day_transition("w1", "w2")
G2 = FactorGraph()
print(p2.values())
#assert set(variables2) == set([v for f in p2.values() for v in f.variables])
G2.add_nodes_from(["w1", "w2"])
G2.add_factors(p2["w2|w1"])
G2.add_edge("w1", p2["w2|w1"])
G2.add_edge("w2", p2["w2|w1"])
print("Model is ok: ", G.check_model()) 

dict_values([<DiscreteFactor representing phi(w1:3) at 0x7f6a803914c0>, <DiscreteFactor representing phi(w1:3, w2:3) at 0x7f6a803914f0>])
Model is ok:  True


In [None]:
p["w2|w1=sunny"] = bp.query(["w2"], {"w1":"rainy"}, True)
print(p["w2|w1=sunny"])

+------------+-----------+
| w2         |   phi(w2) |
| w2(sunny)  |    0.2500 |
+------------+-----------+
| w2(cloudy) |    0.5000 |
+------------+-----------+
| w2(rainy)  |    0.2500 |
+------------+-----------+


In [None]:
my_bp = MyBeliefPropagation(G2)
my_bp.set_evidence({"w1":"sunny"})
my_bp.run_bp("w2")
p2["w2|w1=sunny"] = my_bp.get_marginal("w2").normalize(inplace=False)
print(p2["w2|w1=sunny"])
my_bp.set_evidence({"w1":"cloudy"})
my_bp.run_bp("w2")
p2["w2|w1=cloudy"] = my_bp.get_marginal("w2").normalize(inplace=False)
print(p2["w2|w1=cloudy"])
my_bp.set_evidence({"w1":"rainy"})
my_bp.run_bp("w2")
p2["w2|w1=rainy"] = my_bp.get_marginal("w2").normalize(inplace=False)
print(p2["w2|w1=rainy"])

+------------+-----------+
| w2         |   phi(w2) |
| w2(sunny)  |    0.7000 |
+------------+-----------+
| w2(cloudy) |    0.2500 |
+------------+-----------+
| w2(rainy)  |    0.0500 |
+------------+-----------+
+------------+-----------+
| w2         |   phi(w2) |
| w2(sunny)  |    0.2500 |
+------------+-----------+
| w2(cloudy) |    0.3500 |
+------------+-----------+
| w2(rainy)  |    0.4000 |
+------------+-----------+
+------------+-----------+
| w2         |   phi(w2) |
| w2(sunny)  |    0.2500 |
+------------+-----------+
| w2(cloudy) |    0.5000 |
+------------+-----------+
| w2(rainy)  |    0.2500 |
+------------+-----------+


In [None]:
my_bp = MyBeliefPropagation(G)

my_bp.set_evidence({"u4":True, "u5": True})
my_bp.run_bp("w6")
p["w6|u4=t,u5=t"] = my_bp.get_marginal("w6").normalize(inplace=False)
print(p["w6|u4=t,u5=t"])

p["w7|u4=t,u5=t"] = my_bp.get_marginal("w7").normalize(inplace=False)
my_bp.run_bp("w7")
print(p["w7|u4=t,u5=t"])

my_bp.set_evidence({"u4":False, "u5": False})
my_bp.run_bp("w7")
p["w7|u4=f,u5=f"] = my_bp.get_marginal("w7").normalize(inplace=False)
print(p["w7|u4=f,u5=f"])

+------------+-----------+
| w6         |   phi(w6) |
| w6(sunny)  |    0.3026 |
+------------+-----------+
| w6(cloudy) |    0.4080 |
+------------+-----------+
| w6(rainy)  |    0.2894 |
+------------+-----------+
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.4000 |
+------------+-----------+
| w7(cloudy) |    0.3667 |
+------------+-----------+
| w7(rainy)  |    0.2333 |
+------------+-----------+
+------------+-----------+
| w7         |   phi(w7) |
| w7(sunny)  |    0.5226 |
+------------+-----------+
| w7(cloudy) |    0.3076 |
+------------+-----------+
| w7(rainy)  |    0.1698 |
+------------+-----------+


Use BP to compute the fourth query P(w1,w2|u1=t,u2=t). Note that we cannot do it with the ```get_marginal``` method as it is now. Define a method ```get_marginal_subset``` that receives a list of variables as input and, if they are connected, returns the marginal of the subset. Otherwise, throw an error. Check your result with the one produced by the `BeliefPropagation` class.

In [None]:
#p["w1,w2|u1=t,u2=t"] = my_bp.query(["w1", "w2"], {"u1":True, "u2": True}, True)
#print(p["w1,w2|u1=t,u2=t"])