In [1]:
import pandas as pd
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination
data = pd.read_csv("heart.csv")

display(data)

Unnamed: 0,Age,Gender,ChestPainType,RestingBloodPressure,Cholesterol,FastingBloodSugar,RestingECG,MaxHeartRate,ExerciseInducedAngina,STDepression,ST_Slope,NumMajorVessels,Thalassemia,HeartDisease
0,52,1,0,125,212,0,1,168,0,1.0,2,2,3,0
1,53,1,0,140,203,1,0,155,1,3.1,0,0,3,0
2,70,1,0,145,174,0,1,125,1,2.6,0,0,3,0
3,61,1,0,148,203,0,1,161,0,0.0,2,1,3,0
4,62,0,0,138,294,1,1,106,0,1.9,1,3,2,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1020,59,1,1,140,221,0,1,164,1,0.0,2,0,2,1
1021,60,1,0,125,258,0,0,141,1,2.8,1,1,3,0
1022,47,1,0,110,275,0,0,118,1,1.0,1,1,2,0
1023,50,0,0,110,254,0,0,159,0,0.0,2,0,2,1


In [2]:
# Displays the summary of the data set
display(data.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1025 entries, 0 to 1024
Data columns (total 14 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   Age                    1025 non-null   int64  
 1   Gender                 1025 non-null   int64  
 2   ChestPainType          1025 non-null   int64  
 3   RestingBloodPressure   1025 non-null   int64  
 4   Cholesterol            1025 non-null   int64  
 5   FastingBloodSugar      1025 non-null   int64  
 6   RestingECG             1025 non-null   int64  
 7   MaxHeartRate           1025 non-null   int64  
 8   ExerciseInducedAngina  1025 non-null   int64  
 9   STDepression           1025 non-null   float64
 10  ST_Slope               1025 non-null   int64  
 11  NumMajorVessels        1025 non-null   int64  
 12  Thalassemia            1025 non-null   int64  
 13  HeartDisease           1025 non-null   int64  
dtypes: float64(1), int64(13)
memory usage: 112.2 KB


None

In [3]:

# Initializes the Bayesian network model with nodes and edges
model = BayesianNetwork([
    ('Age', 'HeartDisease'), 
    ('Gender', 'HeartDisease'),
    ('ChestPainType','HeartDisease'), 
    ('ExerciseInducedAngina','HeartDisease'),
    ('HeartDisease','RestingECG'),
    ('HeartDisease','Cholesterol')
])

In [4]:
# Fits the model to calculate Conditional Probability Distribution (CPD) 
# at each node using estimator MaximumLikelihoodEstimator

model.fit(data, estimator=MaximumLikelihoodEstimator)

MaximumLikelihoodEstimator estimator calculates the Conditional Probability Distributions (CPDs) for each node based on the data.

The CPD for the HeartDisease node is estimated.

CPDs represent the probability of each state of a node given the states of its parent nodes.


In [5]:
cpd_HeartDisease = MaximumLikelihoodEstimator(model, data).estimate_cpd('HeartDisease')

In [6]:
print(cpd_HeartDisease)

+-----------------------+-----+--------------------------+
| Age                   | ... | Age(77)                  |
+-----------------------+-----+--------------------------+
| ChestPainType         | ... | ChestPainType(3)         |
+-----------------------+-----+--------------------------+
| ExerciseInducedAngina | ... | ExerciseInducedAngina(1) |
+-----------------------+-----+--------------------------+
| Gender                | ... | Gender(1)                |
+-----------------------+-----+--------------------------+
| HeartDisease(0)       | ... | 0.5                      |
+-----------------------+-----+--------------------------+
| HeartDisease(1)       | ... | 0.5                      |
+-----------------------+-----+--------------------------+


In [7]:
# The Variable Elimination method leverages the properties of factorization and conditional independence to efficiently eliminate variables from the joint probability distribution.
inference = VariableElimination(model)

In [8]:
# Query to infer probability for heart disease given RestingECG
print(inference.query(variables=['HeartDisease'], evidence={'RestingECG': 1}))

+-----------------+---------------------+
| HeartDisease    |   phi(HeartDisease) |
| HeartDisease(0) |              0.4354 |
+-----------------+---------------------+
| HeartDisease(1) |              0.5646 |
+-----------------+---------------------+
