In [None]:
import json
import heapq
from collections import defaultdict
from itertools import product
from functools import reduce
from typing import List, Dict, Tuple, Any, Union
import operator

In [None]:
# Read in the data from the json file 
# path = 'Sample_Testcase.json'
path = 'Sample_Testcase.json'
with open(path, 'r') as file:
    data = json.load(file)

In [223]:
class Inference:
    def __init__(self, data : dict):
        """
        Initialize the Inference class with the input data.
        
        Parameters:
        -----------
        data : dict
            The input data containing the graphical model details, such as variables, cliques, potentials, and k value.
        
        What to do here:
        ----------------
        - Parse the input data and store necessary attributes (e.g., variables, cliques, potentials, k value).
        - Initialize any data structures required for triangulation, junction tree creation, and message passing.
        
        Refer to the sample test case for the structure of the input data.
        """
        self.testcaseNumber = data.get("TestCaseNumber", 0)
        self.variableCount = data.get("VariablesCount", 0)
        self.potentialCount = data.get("Potentials_count", 0)
        # List of cliques (each clique is a dictionary)
        self.cliquesAndPotentials = data.get("Cliques and Potentials", []) 
        self.k = data.get("k value (in top k)", 0)

        # Build the graph from the cliques. Nodes are labelled from 0 to variableCount - 1
        self.undirectedGraph = {i: set() for i in range(self.variableCount)}
        for clique in self.cliquesAndPotentials:
            nodes = clique.get("cliques", [])
            # For each clique, add edges between each pair of nodes
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    self.undirectedGraph[nodes[i]].add(nodes[j])
                    self.undirectedGraph[nodes[j]].add(nodes[i])
        
        # Define the variables needed to construct the triangulated graph
        self.triangulatedGraph = {}
        self.maximalCliques = []

        # Define the variables needed to construct the junction tree
        self.junctionTree = []

        # Make a dictionary to store the potentials for each clique
        self.cliquePotentials = {}


    def triangulate_and_get_cliques(self):
        """
        Triangulate the undirected graph and extract the maximal cliques.
        
        What to do here:
        ----------------
        - Implement the triangulation algorithm to make the graph chordal.
        - Extract the maximal cliques from the triangulated graph.
        - Store the cliques for later use in junction tree creation.

        Refer to the problem statement for details on triangulation and clique extraction.
        """
        temporaryGraph = {n : set(self.undirectedGraph[n]) for n in self.undirectedGraph}
        chordalGraph = {n : set(self.undirectedGraph[n]) for n in self.undirectedGraph}

        # Initialize the degree info into a min heap
        degreeInfo = {node : len(temporaryGraph[node]) for node in temporaryGraph}
        # Initialize the heap
        degreeHeap = [(degreeInfo[node], node) for node in degreeInfo]
        heapq.heapify(degreeHeap)
        # Maintain a set for the nodes that have been eliminated
        eliminatedNodes = set()
        # Store the cliques from each elimination step
        candidateCliques = []

        # Eliminate the nodes one by one to triangulate the graph
        while degreeHeap:
            currDegree, currNode = heapq.heappop(degreeHeap)
            if currNode in eliminatedNodes or degreeInfo[currNode] != currDegree:
                continue

            # Mark the node as eliminated
            eliminatedNodes.add(currNode)
            # Get the neighbours of the node that are not eliminated
            neighbours = temporaryGraph[currNode] - eliminatedNodes
            # Record the clique as the current node and the neighbours
            currClique = set(neighbours)
            currClique.add(currNode)
            candidateCliques.append(currClique)

            # Add fill edges among the neighours in the chordal graph
            for i in neighbours:
                for j in neighbours:
                    # Add edge in the chordal graph if it is not already present
                    if i == j:
                        continue
                    if j not in chordalGraph[i]:
                        chordalGraph[i].add(j)
                        chordalGraph[j].add(i)
                    # Add edge in the temporary graph if it is not already present
                    if j not in temporaryGraph[i]:
                        temporaryGraph[i].add(j)
                        temporaryGraph[j].add(i)
                        degreeInfo[i] += 1
                        degreeInfo[j] += 1
            
            # Remove the current node from the neighbours
            for neighbour in neighbours:
                temporaryGraph[neighbour].discard(currNode)
            temporaryGraph[currNode].clear()
            # Update the degree info
            degreeInfo[currNode] = 0

            # Update the heap with the new degrees
            for neighbour in neighbours:
                heapq.heappush(degreeHeap, (degreeInfo[neighbour], neighbour))
        
        # After exiting the while loop the graph is now triangulated
        self.triangulatedGraph = chordalGraph
        # Extract the maximal cliques from the candidate cliques by removing the duplicates and subsets
        candidateCliques = list(set(frozenset(clique) for clique in candidateCliques))
        finalCliques = []
        for clique in candidateCliques:
            if not any(clique < others for others in candidateCliques if clique != others):
                finalCliques.append(clique)
        self.maximalCliques = [set(clique) for clique in finalCliques]

    
    def get_junction_tree(self):
        """
        Construct the junction tree from the maximal cliques.
        
        What to do here:
        ----------------
        - Create a junction tree using the maximal cliques obtained from the triangulated graph.
        - For each pair of cliques, compute the common variables.
          Then define the directed separator sets:
              S_ij = clique_i - (clique_i ∩ clique_j)
              S_ji = clique_j - (clique_i ∩ clique_j)
        - Use the size of the common set as the weight to construct a maximum spanning tree.
        - Store the junction tree as a list of tuples:
              (clique_i, clique_j, S_ij, S_ji)
          where S_ij is the separator when a message is passed from clique i to clique j,
          and S_ji is for the reverse direction.
        """
        # Build the weighted graph among the cliques
        cliques = self.maximalCliques
        cliqueCount = len(cliques)

        if cliqueCount <= 1:
            self.junctionTree = []
            return
        
        # Collect the edges of the graph as (weight, i, j) where i and j are the indices of the cliques
        weightedEdges = []
        for i in range(cliqueCount):
            for j in range(i + 1, cliqueCount):
                intersection = cliques[i].intersection(cliques[j])
                weight = len(intersection)
                # Add the edge only if the intersection is non-empty
                if weight > 0:
                    # Store the negative weight to use the min heap as a max heap
                    weightedEdges.append((-weight, frozenset(cliques[i]), frozenset(cliques[j])))
        
        # If there are no edges, then the junction tree is empty
        if not weightedEdges and cliqueCount > 0:
            # Implies that original graph has no intersecting cliques
            self.junctionTree = []
            return
        print("Weighted Edges : ", weightedEdges)
        
        # Initialize the heap with the edges
        heapq.heapify(weightedEdges)

        # Initialize the union-find set to keep track of the connected components
        parent = {tuple(clique): tuple(clique) for clique in cliques}
        rank = {tuple(clique): 0 for clique in cliques}

        # Define the find and union functions for the union-find set
        def find(x):
            if parent[tuple(x)] != tuple(x):
                parent[tuple(x)] = find(parent[tuple(x)])
            return parent[tuple(x)]

        def union(x, y):
            rootX, rootY = find(x), find(y)
            if rootX == rootY:
                return False
            # Union by Rank
            if rank[rootX] > rank[rootY]:
                parent[rootY] = rootX
            elif rank[rootX] < rank[rootY]:
                parent[rootX] = rootY
            else:
                parent[rootY] = rootX
                rank[rootX] += 1
            return True
        
        # Initialize the junction tree as a list of tuples
        spanningTreeEdges = []
        components = cliqueCount
        while weightedEdges and components > 1:
            weight, i, j = heapq.heappop(weightedEdges)
            if union(i, j):
                components -= 1
                spanningTreeEdges.append((i, j, -weight))
        print("Spanning Tree Edges : ", spanningTreeEdges)
        
        # Define the directed separator sets for each edge in the junction tree
        junctionTree = []
        for i, j, weight in spanningTreeEdges:
            intersection = i.intersection(j)
            separatorIJ = i - intersection
            separatorJI = j - intersection
            junctionTree.append((i, j, separatorIJ, separatorJI))
            # junctionTree.append((j, i, separatorJI, separatorIJ))
        
        self.junctionTree = junctionTree

    
    def assign_potentials_to_cliques(self):
        """
        Assign potentials to the cliques in the junction tree.
        
        - Map the given potentials (from the input data) to the corresponding cliques in the junction tree.
        - If multiple potential tables are associated with the same clique, multiply the values element-wise.
        """
        frozenMaximal = [frozenset(clique) for clique in self.maximalCliques]
        
        for frozenClique in frozenMaximal:
            cliqueVars = sorted(list(frozenClique))
            assignments = list(product([0, 1], repeat=len(cliqueVars)))
            potentialTable = defaultdict(lambda: 1)
            
            for potentialDict in self.cliquesAndPotentials:
                potentialVars = potentialDict.get("cliques", [])
                potentialValues = potentialDict.get("potentials", [])
                
                if frozenset(potentialVars).issubset(frozenClique):
                    potVarPositions = [cliqueVars.index(var) for var in potentialVars]
                    potAssignments = list(product([0, 1], repeat=len(potVarPositions)))
                    
                    for assignment in assignments:
                        potAssignment = tuple(assignment[pos] for pos in potVarPositions)
                        potentialTable[assignment] *= potentialValues[potAssignments.index(potAssignment)]
            
            self.cliquePotentials[frozenClique] = potentialTable
    

    def get_z_value(self):
        """
        Compute the partition function (Z value) of the graphical model.
        
        What to do here:
        ----------------
        - Implement the message passing algorithm to compute the partition function (Z value).
        - The Z value is the normalization constant for the probability distribution.
        
        Refer to the problem statement for details on computing the partition function.
        """
        messages = {}
        for clique1, clique2, separator12, separator21 in self.junctionTree:
            messages[(clique1, clique2)] = None
            messages[(clique2, clique1)] = None
        
        def computeMessage(source, destination, separator):
            print(f"[COMPUTE MESSAGE] Called with {source}, {destination}")
            sourceVars = sorted(list(source))
            destVars = sorted(list(destination))
            separatorVars = sorted(list(separator))

            sourceToSeparator = [sourceVars.index(var) for var in separatorVars]
            message = defaultdict(float)
            sepAssignments = list(product([0, 1], repeat=len(separatorVars)))

            sourcePotential = self.cliquePotentials[source]

            for sepAssignment in sepAssignments:
                sumValue = 0

                for sourceAssignment, potential in sourcePotential.items():
                    if all(sourceAssignment[pos] == sepVal for pos, sepVal in zip(sourceToSeparator, sepAssignment)):
                        totalValue = potential

                        for edge in self.junctionTree:
                            otherSource, otherDest, _, _ = edge
                            if otherSource == source and otherDest != destination:
                                if messages[(otherSource, otherDest)] is not None:
                                    sepVars = sorted(list(set(source) & set(otherDest)))
                                    sepIndices = [sourceVars.index(var) for var in sepVars]
                                    messageAssignment = tuple(sourceAssignment[index] for index in sepIndices)
                                    totalValue *= messages[(otherDest, source)][messageAssignment]
                        sumValue += totalValue
                message[sepAssignment] = sumValue
            print(f"[COMPUTE MESSAGE] Returning message : {dict(message)}")
            return dict(message)
        
        # Define the root clique
        rootClique = self.junctionTree[0][0]

        def collectMessages(current, parent = None):
            print("Called computeMessage with current : ", current, " and parent : ", parent)
            for edge in self.junctionTree:
                source, destination, _, _ = edge
                if source == current and destination != parent:
                    print("Reached source == current and destination != parent")
                    collectMessages(destination, current)
                    separator = set(source) & set(destination)
                    messages[(destination, source)] = computeMessage(destination, source, separator)
                    print("Message : ", messages[(destination, source)])
                elif destination == current and source != parent:
                    print("Reached destination == current and source != parent")
                    collectMessages(source, current)
                    separator = set(source) & set(destination)
                    messages[(source, destination)] = computeMessage(source, destination, separator)
                    print("Message : ", messages[(source, destination)])
        
        collectMessages(rootClique)

        # Compute the partition function
        rootPotential = self.cliquePotentials[rootClique]
        zValue = 0

        for assignment, potential in rootPotential.items():
            totalValue = potential
            for edge in self.junctionTree:
                source, destination, _, _ = edge
                if destination == rootClique:
                    if messages[(source, destination)] is not None:
                        print("[GET Z VALUE]Reached destination == rootClique and message is not None")
                        separator = set(source) & set(destination)
                        sepVars = sorted(list(separator))
                        targetVars = sorted(list(destination))
                        sepIndices = [targetVars.index(var) for var in sepVars]
                        messageAssignment = tuple(assignment[index] for index in sepIndices)
                        totalValue *= messages[(source, destination)][messageAssignment]
                        print("[GET Z VALUE]Total value so far : ", totalValue)
                elif source == rootClique:
                    if messages[(destination, source)] is not None:
                        print("[GET Z VALUE]Reached source == rootClique and message is not None")
                        separator = set(source) & set(destination)
                        sepVars = sorted(list(separator))
                        sourceVars = sorted(list(source))
                        sepIndices = [sourceVars.index(var) for var in sepVars]
                        messageAssignment = tuple(assignment[index] for index in sepIndices)
                        totalValue *= messages[(destination, source)][messageAssignment]
            zValue += totalValue
        return zValue

      
            


            
            
            
            





            
            






In [224]:
# Helper Functions to test the implementation
def printUndirectedGraph(graph):
    print("Undirected Graph:")
    for node, neighbors in graph.items():
        print(f"{node}: {sorted(neighbors)}")

def printTriangulatedGraph(graph):
    print("Triangulated Graph:")
    for node, neighbors in graph.items():
        print(f"{node}: {sorted(neighbors)}")

In [225]:
sampleInference = Inference(data[0]['Input'])
print(sampleInference.testcaseNumber)
print(sampleInference.variableCount)
print(sampleInference.potentialCount)
print(sampleInference.cliquesAndPotentials)

1
5
15
[{'clique_size': 1, 'cliques': [0], 'potentials': [11, 4]}, {'clique_size': 1, 'cliques': [1], 'potentials': [8, 12]}, {'clique_size': 1, 'cliques': [2], 'potentials': [19, 3]}, {'clique_size': 1, 'cliques': [3], 'potentials': [11, 6]}, {'clique_size': 1, 'cliques': [4], 'potentials': [12, 8]}, {'clique_size': 2, 'cliques': [2, 4], 'potentials': [20, 18, 3, 6]}, {'clique_size': 2, 'cliques': [0, 2], 'potentials': [19, 6, 4, 1]}, {'clique_size': 2, 'cliques': [2, 4], 'potentials': [15, 18, 12, 14]}, {'clique_size': 1, 'cliques': [4], 'potentials': [5, 9]}, {'clique_size': 2, 'cliques': [1, 4], 'potentials': [14, 5, 20, 20]}, {'clique_size': 2, 'cliques': [0, 3], 'potentials': [8, 8, 13, 12]}, {'clique_size': 2, 'cliques': [0, 1], 'potentials': [5, 3, 8, 11]}, {'clique_size': 2, 'cliques': [2, 3], 'potentials': [17, 15, 18, 2]}, {'clique_size': 2, 'cliques': [0, 3], 'potentials': [9, 7, 17, 9]}, {'clique_size': 2, 'cliques': [0, 2], 'potentials': [9, 19, 20, 9]}]


In [226]:
sampleInference.triangulate_and_get_cliques()
sampleInference.get_junction_tree()
print("Junction Tree")
for clique1, clique2, separator12, separator21 in sampleInference.junctionTree:
    print(f"Clique 1: {sorted(clique1)}")
    print(f"Clique 2: {sorted(clique2)}")
    print(f"Separator 1->2: {sorted(separator12)}")
    print(f"Separator 2->1: {sorted(separator21)}")
    print()

print("Maximal Cliques")
for clique in sampleInference.maximalCliques:
    print(clique)

sampleInference.assign_potentials_to_cliques()
print("Clique Potentials")
for clique, potential in sampleInference.cliquePotentials.items():
    print(f"Clique: {sorted(list(clique))}")
    print("Potential:")
    for assignment, value in potential.items():
        print(f"{assignment}: {value}")
    print()
print(sampleInference.get_z_value())
# sampleInference.get_z_value()

Weighted Edges :  [(-1, frozenset({0, 1, 4}), frozenset({0, 2, 3})), (-2, frozenset({0, 1, 4}), frozenset({0, 2, 4})), (-2, frozenset({0, 2, 3}), frozenset({0, 2, 4}))]
Spanning Tree Edges :  [(frozenset({0, 2, 3}), frozenset({0, 2, 4}), 2), (frozenset({0, 1, 4}), frozenset({0, 2, 4}), 2)]
Junction Tree
Clique 1: [0, 2, 3]
Clique 2: [0, 2, 4]
Separator 1->2: [3]
Separator 2->1: [4]

Clique 1: [0, 1, 4]
Clique 2: [0, 2, 4]
Separator 1->2: [1]
Separator 2->1: [2]

Maximal Cliques
{0, 1, 4}
{0, 2, 3}
{0, 2, 4}
Clique Potentials
Clique: [0, 1, 4]
Potential:
(0, 0, 0): 369600
(0, 0, 1): 158400
(0, 1, 0): 475200
(0, 1, 1): 570240
(1, 0, 0): 215040
(1, 0, 1): 92160
(1, 1, 0): 633600
(1, 1, 1): 760320

Clique: [0, 2, 3]
Potential:
(0, 0, 0): 481189896
(0, 0, 1): 180124560
(0, 1, 0): 53631072
(0, 1, 1): 2528064
(1, 0, 0): 251268160
(1, 0, 1): 59097600
(1, 1, 0): 4725864
(1, 1, 1): 139968

Clique: [0, 2, 4]
Potential:
(0, 0, 0): 643302000
(0, 0, 1): 833719392
(0, 1, 0): 8125920
(0, 1, 1): 227525