In [1]:
from collections.abc import Iterator

from dataclasses import dataclass
from dataclasses import field

# from prettyprinter import pprint
from pprint import pprint

import itertools


@dataclass(frozen=True)
class Variable:
    """
    Represents a random variable in a Bayesian network.

    Attributes:
        name (str): The name of the variable.
        r (int): The number of possible values (states) the variable can take.
    """

    name: str
    r: int  # Number of possible values

    def __repr__(self) -> str:
        return f'{self.name}'


@dataclass(frozen=True)
class Assignment:
    """
    Represents an assignment of values to variables.

    Attributes:
        values (dict[str, int]): A mapping from variable names to their assigned values.
    """

    values: dict[str, int]

    def __getitem__(self, var_name: str) -> int:
        return self.values[var_name]

    def __repr__(self) -> str:
        return f'Assignment({self.values})'

    def __hash__(self) -> int:
        # Use frozenset of items to make it hashable
        return hash(frozenset(self.values.items()))

    def __eq__(self, other) -> bool:
        if not isinstance(other, Assignment):
            return NotImplemented
        return self.values == other.values


@dataclass
class Factor:
    """
    Represents a factor in the Bayesian network.

    Attributes:
        variables (list[Variable]): Variables involved in the factor.
        table (dict[Assignment, float]): A mapping from variable assignments to probabilities.
    """

    variables: list[Variable]
    table: dict[Assignment, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.var_names = [var.name for var in self.variables]

    def assignments(self) -> list[Assignment]:
        """
        Generates all possible assignments for the variables in this factor.

        Returns:
            list[Assignment]: A list of all possible assignments.
        """
        ranges = [range(1, var.r + 1) for var in self.variables]
        assignments: list[Assignment] = []
        for values in itertools.product(*ranges):
            assignment = Assignment(
                values=dict(zip(self.var_names, values, strict=False))
            )
            assignments.append(assignment)
        return assignments

    def normalize(self) -> None:
        """
        Normalizes the factor so that the sum of probabilities is 1.
        """
        total = sum(self.table.values())
        if total != 0:
            for assignment in self.table:
                self.table[assignment] /= total

    def marginalize(
        self, var_to_marginalize: Variable
    ) -> 'Factor':  # exclude a factor & compute chances
        """
        Marginalizes out a variable from the factor.

        Args:
            var_to_marginalize (Variable): The variable to be marginalized out.

        Returns:
            Factor: A new factor with the variable marginalized out.
        """
        new_vars = [var for var in self.variables if var != var_to_marginalize]
        new_table = {}
        for assignment in self.assignments():
            # Remove the variable to marginalize from the assignment
            new_assignment_values = {
                k: v
                for k, v in assignment.values.items()
                if k != var_to_marginalize.name
            }
            new_assignment = Assignment(new_assignment_values)
            prob = self.table.get(assignment, 0.0)
            # Sum probabilities for the same assignment after removing the variable
            new_table[new_assignment] = (
                new_table.get(new_assignment, 0.0) + prob
            )
        return Factor(new_vars, new_table)

    def __mul__(self, other: 'Factor') -> 'Factor':
        """
        Multiplies this factor with another factor.

        Args:
            other (Factor): The other factor to multiply with.

        Returns:
            Factor: A new factor resulting from the multiplication.
        """
        # Determine the set of variables in the new factor
        new_vars_dict = {
            var.name: var for var in self.variables + other.variables
        }
        new_vars = list(new_vars_dict.values())
        new_var_names = [var.name for var in new_vars]

        # Prepare the table for the new factor
        new_table = {}

        # Generate all possible assignments for new_vars
        ranges = [range(1, var.r + 1) for var in new_vars]
        for values in itertools.product(*ranges):
            assignment_values = dict(zip(new_var_names, values, strict=False))
            assignment = Assignment(assignment_values)

            # Get the values from self and other for this assignment
            self_assignment_values = {
                k: v
                for k, v in assignment_values.items()
                if k in [var.name for var in self.variables]
            }
            self_assignment = Assignment(self_assignment_values)
            self_value = self.table.get(self_assignment, 0.0)

            other_assignment_values = {
                k: v
                for k, v in assignment_values.items()
                if k in [var.name for var in other.variables]
            }
            other_assignment = Assignment(other_assignment_values)
            other_value = other.table.get(other_assignment, 0.0)

            # Multiply the values
            new_table[assignment] = self_value * other_value

        return Factor(new_vars, new_table)

    def __repr__(self) -> str:
        table_str = '\n'.join([str(item) for item in self.table.items()])
        return f'Factor(variables={self.var_names}, table=\n{table_str})'


# Example usage:

# Define variables
x = Variable('x', 2)
y = Variable('y', 2)
z = Variable('z', 2)


# Define a factor table
ft = {
    Assignment(values={'x': 1, 'y': 1, 'z': 1}): 64,
    Assignment(values={'x': 1, 'y': 1, 'z': 2}): 32,
    Assignment(values={'x': 1, 'y': 2, 'z': 1}): 16,
    Assignment(values={'x': 1, 'y': 2, 'z': 2}): 8,
    Assignment(values={'x': 2, 'y': 1, 'z': 1}): 4,
    Assignment(values={'x': 2, 'y': 1, 'z': 2}): 2,
    Assignment(values={'x': 2, 'y': 2, 'z': 1}): 1,
    Assignment(values={'x': 2, 'y': 2, 'z': 2}): 1,
}

# Create a factor
phi = Factor(variables=[x, y, z], table=ft)

print(phi)


print()


# Normalize the factor
phi.normalize()

pprint(phi, indent=2)

# Marginalize out variable 'z'
phi_marginalized = phi.marginalize(z)

print()
pprint(phi_marginalized, indent=2)
# Multiply two factors (assuming another factor 'psi' is defined)

psi = Factor(
    variables=[x, y, z],
    table={
        Assignment(values={'x': 1, 'y': 1, 'z': 1}): -1,
        Assignment(values={'x': 1, 'y': 1, 'z': 2}): -1,
        Assignment(values={'x': 1, 'y': 2, 'z': 1}): -1,
        Assignment(values={'x': 1, 'y': 2, 'z': 2}): 10,
        Assignment(values={'x': 2, 'y': 1, 'z': 1}): -1,
        Assignment(values={'x': 2, 'y': 1, 'z': 2}): -1,
        Assignment(values={'x': 2, 'y': 2, 'z': 1}): -1,
        Assignment(values={'x': 2, 'y': 2, 'z': 2}): -1,
    },
)


phi_psi = phi * psi
print()
pprint(phi_psi)

Factor(variables=['x', 'y', 'z'], table=
(Assignment({'x': 1, 'y': 1, 'z': 1}), 64)
(Assignment({'x': 1, 'y': 1, 'z': 2}), 32)
(Assignment({'x': 1, 'y': 2, 'z': 1}), 16)
(Assignment({'x': 1, 'y': 2, 'z': 2}), 8)
(Assignment({'x': 2, 'y': 1, 'z': 1}), 4)
(Assignment({'x': 2, 'y': 1, 'z': 2}), 2)
(Assignment({'x': 2, 'y': 2, 'z': 1}), 1)
(Assignment({'x': 2, 'y': 2, 'z': 2}), 1))

Factor(variables=['x', 'y', 'z'], table=
(Assignment({'x': 1, 'y': 1, 'z': 1}), 0.5)
(Assignment({'x': 1, 'y': 1, 'z': 2}), 0.25)
(Assignment({'x': 1, 'y': 2, 'z': 1}), 0.125)
(Assignment({'x': 1, 'y': 2, 'z': 2}), 0.0625)
(Assignment({'x': 2, 'y': 1, 'z': 1}), 0.03125)
(Assignment({'x': 2, 'y': 1, 'z': 2}), 0.015625)
(Assignment({'x': 2, 'y': 2, 'z': 1}), 0.0078125)
(Assignment({'x': 2, 'y': 2, 'z': 2}), 0.0078125))

Factor(variables=['x', 'y'], table=
(Assignment({'x': 1, 'y': 1}), 0.75)
(Assignment({'x': 1, 'y': 2}), 0.1875)
(Assignment({'x': 2, 'y': 1}), 0.046875)
(Assignment({'x': 2, 'y': 2}), 0.015625))



### Define Variables

In [2]:
D = Variable('D', 2)  # Disease: 1 - Yes, 2 - No
T = Variable('T', 2)  # Test Result: 1 - Positive, 2 - Negative
S = Variable('S', 2)  # Sympton: 1 - Present, 2 - Absent

### Define Probability Tables

In [3]:
P_D = {Assignment(values=dict(D=1)): 0.01, Assignment(values=dict(D=2)): 0.99}
P_D

{Assignment({'D': 1}): 0.01, Assignment({'D': 2}): 0.99}

In [4]:
P_T_given_D = {
    Assignment(values=dict(T=1, D=1)): 0.95,  # TP: True Positive rate
    Assignment(values=dict(T=2, D=1)): 0.05,  # FT: False Negative rate
    Assignment(values=dict(T=1, D=2)): 0.1,  # FN: False Positive rate
    Assignment(values=dict(T=2, D=2)): 0.9,  # TN: True Negative rate
}
P_T_given_D

{Assignment({'T': 1, 'D': 1}): 0.95,
 Assignment({'T': 2, 'D': 1}): 0.05,
 Assignment({'T': 1, 'D': 2}): 0.1,
 Assignment({'T': 2, 'D': 2}): 0.9}

In [5]:
P_S_given_D = {
    Assignment(values=dict(S=1, D=1)): 0.8,
    Assignment(values=dict(S=2, D=1)): 0.2,
    Assignment(values=dict(S=1, D=2)): 0.3,
    Assignment(values=dict(S=2, D=2)): 0.7,
}
P_S_given_D

{Assignment({'S': 1, 'D': 1}): 0.8,
 Assignment({'S': 2, 'D': 1}): 0.2,
 Assignment({'S': 1, 'D': 2}): 0.3,
 Assignment({'S': 2, 'D': 2}): 0.7}

### Construct the Bayesian Network

In [6]:
F_D = Factor(variables=[D], table=P_D)
F_D

Factor(variables=['D'], table=
(Assignment({'D': 1}), 0.01)
(Assignment({'D': 2}), 0.99))

In [7]:
F_T_given_D = Factor(variables=[D, T], table=P_T_given_D)
F_T_given_D

Factor(variables=['D', 'T'], table=
(Assignment({'T': 1, 'D': 1}), 0.95)
(Assignment({'T': 2, 'D': 1}), 0.05)
(Assignment({'T': 1, 'D': 2}), 0.1)
(Assignment({'T': 2, 'D': 2}), 0.9))

In [8]:
F_S_given_D = Factor(variables=[D, S], table=P_S_given_D)
F_S_given_D

Factor(variables=['D', 'S'], table=
(Assignment({'S': 1, 'D': 1}), 0.8)
(Assignment({'S': 2, 'D': 1}), 0.2)
(Assignment({'S': 1, 'D': 2}), 0.3)
(Assignment({'S': 2, 'D': 2}), 0.7))

### Perform Inference

In [9]:
# P(D | T)
F_DT_joined = F_T_given_D * F_D

In [10]:
# P(D | T=1)
table_F_D_given_T_positive = {
    assignment: prob
    for assignment, prob in F_DT_joined.table.items()
    if assignment.values['T'] == 1
}
print(table_F_D_given_T_positive)

F_D_given_T_positive = Factor(variables=[D], table=table_F_D_given_T_positive)
print()
print(F_D_given_T_positive)
print()
F_D_given_T_positive.normalize()
print(F_D_given_T_positive)

{Assignment({'D': 1, 'T': 1}): 0.0095, Assignment({'D': 2, 'T': 1}): 0.099}

Factor(variables=['D'], table=
(Assignment({'D': 1, 'T': 1}), 0.0095)
(Assignment({'D': 2, 'T': 1}), 0.099))

Factor(variables=['D'], table=
(Assignment({'D': 1, 'T': 1}), 0.08755760368663594)
(Assignment({'D': 2, 'T': 1}), 0.9124423963133641))


### Analyze the Result

`The difference between the prior probability and the posterior probability.`

P(D) - prior probability of disease
P(D=1) = 0.01, meaning that the patient has a 1% chance of having the disease 
without any other information (prior knowledge). Just % of sick people in population

P(D=1 | T=1) - posterior probability of disease given a positive test result.
After seeing the positive test result (T=1), the probability of the disease (D=1) rises significantly: P(D=1 | T=1) ≈ 8.76%. 
This means that even with a positive test, the patient still only has about 8.76% chance of having the disease.

`How the positive test result affects the belief about the disease.`

A positive test significantly increases the belief that the patient could have the disease,
raising the likelihood from 1% to 8.76%.
The test strongly suggests a possible disease, but it it has high chance of wrong diagnosis. Patient should be informed.

`The implications in decision-making, such as whether to recommend treatment.`

Despite the positive test result, the posterior probability of the disease is still relatively low ~9%.
- If disease is severe and the treatment is both non-harmful and urgent treatment could be recommended
- Else, given that the chance of the disease is still under 10%, and the false positive rate is high 10%,
  it might be better to recommend another test.

Thus, the decision depends on the severity of the disease and the associated riskss with treatment, illness and price of tests.

### Extension (Optional)

In [11]:
# P(D | S)
F_DS_joined = F_S_given_D * F_D

In [12]:
table_F_D_given_S_present = {
    assignment: prob
    for assignment, prob in F_DS_joined.table.items()
    if assignment.values['S'] == 1
}
print(table_F_D_given_S_present)

F_D_given_S_present = Factor(variables=[D], table=table_F_D_given_S_present)
print()
print(F_D_given_S_present)
print()
F_D_given_S_present.normalize()
print(F_D_given_S_present)

{Assignment({'D': 1, 'S': 1}): 0.008, Assignment({'D': 2, 'S': 1}): 0.297}

Factor(variables=['D'], table=
(Assignment({'D': 1, 'S': 1}), 0.008)
(Assignment({'D': 2, 'S': 1}), 0.297))

Factor(variables=['D'], table=
(Assignment({'D': 1, 'S': 1}), 0.02622950819672131)
(Assignment({'D': 2, 'S': 1}), 0.9737704918032787))


In [13]:
# P(D | T=1 | S=1)
P_D_given_T1_S1 = F_D_given_S_present * F_D_given_T_positive
print(P_D_given_T1_S1)
P_D_given_T1_S1.normalize()
print(P_D_given_T1_S1)

Factor(variables=['D'], table=
(Assignment({'D': 1}), 0.0)
(Assignment({'D': 2}), 0.0))
Factor(variables=['D'], table=
(Assignment({'D': 1}), 0.0)
(Assignment({'D': 2}), 0.0))


### Reflect