In [None]:
# Write a program to construct a Bayesian network considering medical data. Use this model to
# demonstrate the diagnosis of heart patients using the standard Heart Disease Data Set (You can use
# Java/Python ML library classes/API.

In [2]:
!pip install pgmpy

import pandas as pd
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import ParameterEstimator, BayesianEstimator
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination


[notice] A new release of pip is available: 24.1.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting pgmpy
  Downloading pgmpy-0.1.26-py3-none-any.whl.metadata (9.1 kB)
Collecting statsmodels (from pgmpy)
  Downloading statsmodels-0.14.4-cp311-cp311-win_amd64.whl.metadata (9.5 kB)
Collecting patsy>=0.5.6 (from statsmodels->pgmpy)
  Downloading patsy-1.0.1-py2.py3-none-any.whl.metadata (3.3 kB)
Downloading pgmpy-0.1.26-py3-none-any.whl (2.0 MB)
   ---------------------------------------- 0.0/2.0 MB ? eta -:--:--
   ---------- ----------------------------- 0.5/2.0 MB 11.3 MB/s eta 0:00:01
   --------------------------------- ------ 1.7/2.0 MB 21.1 MB/s eta 0:00:01
   ------------------------------------ --- 1.8/2.0 MB 16.3 MB/s eta 0:00:01
   ------------------------------------ --- 1.8/2.0 MB 16.3 MB/s eta 0:00:01
   ------------------------------------ --- 1.8/2.0 MB 16.3 MB/s eta 0:00:01
   ------------------------------------ --- 1.8/2.0 MB 16.3 MB/s eta 0:00:01
   ------------------------------------ --- 1.8/2.0 MB 16.3 MB/s eta 0:00:01
   -------------------------------

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
data = pd.read_csv('heart.csv')

In [14]:
data.head()

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,52,1,0,125,212,0,1,168,0,1.0,2,2,3,0
1,53,1,0,140,203,1,0,155,1,3.1,0,0,3,0
2,70,1,0,145,174,0,1,125,1,2.6,0,0,3,0
3,61,1,0,148,203,0,1,161,0,0.0,2,1,3,0
4,62,0,0,138,294,1,1,106,0,1.9,1,3,2,0


In [15]:
data.tail()

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
1020,59,1,1,140,221,0,1,164,1,0.0,2,0,2,1
1021,60,1,0,125,258,0,0,141,1,2.8,1,1,3,0
1022,47,1,0,110,275,0,0,118,1,1.0,1,1,2,0
1023,50,0,0,110,254,0,0,159,0,0.0,2,0,2,1
1024,54,1,0,120,188,0,1,113,0,1.4,1,1,3,0


In [6]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   age           303 non-null    int64  
 1   sex           303 non-null    int64  
 2   cp            303 non-null    int64  
 3   trestbps      303 non-null    int64  
 4   chol          303 non-null    int64  
 5   fbs           303 non-null    int64  
 6   restecg       303 non-null    int64  
 7   thalach       303 non-null    int64  
 8   exang         303 non-null    int64  
 9   oldpeak       303 non-null    float64
 10  slope         303 non-null    int64  
 11  ca            303 non-null    object 
 12  thal          303 non-null    object 
 13  heartdisease  303 non-null    int64  
dtypes: float64(1), int64(11), object(2)
memory usage: 33.3+ KB


In [16]:
data.describe()

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
count,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0,1025.0
mean,54.434146,0.69561,0.942439,131.611707,246.0,0.149268,0.529756,149.114146,0.336585,1.071512,1.385366,0.754146,2.323902,0.513171
std,9.07229,0.460373,1.029641,17.516718,51.59251,0.356527,0.527878,23.005724,0.472772,1.175053,0.617755,1.030798,0.62066,0.50007
min,29.0,0.0,0.0,94.0,126.0,0.0,0.0,71.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,48.0,0.0,0.0,120.0,211.0,0.0,0.0,132.0,0.0,0.0,1.0,0.0,2.0,0.0
50%,56.0,1.0,1.0,130.0,240.0,0.0,1.0,152.0,0.0,0.8,1.0,0.0,2.0,1.0
75%,61.0,1.0,2.0,140.0,275.0,0.0,1.0,166.0,1.0,1.8,2.0,1.0,3.0,1.0
max,77.0,1.0,3.0,200.0,564.0,1.0,2.0,202.0,1.0,6.2,2.0,4.0,3.0,1.0


In [17]:
# creating instance of bayesian network
# here the tuples indicate the directed edges / relationships
# the first variable is considered to influence / is the parent of the second variable
# in the same tuple
# eg. age influences condition/target, cholestrol influences condition/target, etc
model = BayesianNetwork([
    ('age', 'target'),
    ('sex', 'target'),
    ('cp', 'target'),
    ('trestbps', 'target'),
    ('chol', 'target')
])

In [24]:
model.fit(data,estimator=MaximumLikelihoodEstimator)



In [25]:
# to perform inference on probabilistic graphical models
# eliminate irrelevant variables by marginalizing over them and focus on imp variables
inference = VariableElimination(model)

In [26]:
# 21 - records for age 50
# 0 - 9
# 1 - 12

# predict the output (child) based on the influencing factor (parent)
query1 = inference.query(variables=['target'], evidence={'age':50})
print(query1)

# the model uses probability
# P(A/B) = P(A ∩ B) / P(B)

+-----------+---------------+
| target    |   phi(target) |
| target(0) |        0.5001 |
+-----------+---------------+
| target(1) |        0.4999 |
+-----------+---------------+


In [27]:
query2 = inference.query(variables=['target'],evidence={'chol' : 220})
print(query2)

+-----------+---------------+
| target    |   phi(target) |
| target(0) |        0.4996 |
+-----------+---------------+
| target(1) |        0.5004 |
+-----------+---------------+


In [28]:
query3 = inference.query(variables=['target'],evidence={'trestbps' : 140})
print(query3)

+-----------+---------------+
| target    |   phi(target) |
| target(0) |        0.5000 |
+-----------+---------------+
| target(1) |        0.5000 |
+-----------+---------------+
