## Spark Background

### What is Spark?
Apache Spark is an open-source, distributed computing framework designed for big data processing. It splits work into tasks that run in parallel across the nodes of a cluster, enabling fast, scalable data analysis and machine learning. The Heart Failure Prediction dataset used here is small (918 rows), but the same Spark code would scale to far larger data.

### Why is Spark a popular framework?
Spark is popular for its speed (in-memory processing), multi-language support (Python, Scala, Java), and versatility in handling SQL queries, streaming, and ML tasks. Its unified engine simplifies big data workflows, making it ideal for data scientists and engineers.

### What is Spark SQL and why does it exist?
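Spark SQL is the Spark module for structured data. It lets you query DataFrames with standard SQL (for example, after registering them as temporary views) as well as through the DataFrame API. It exists to bridge traditional database querying and big data, so users familiar with SQL can analyze large datasets without learning a new query language.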

### What is a Spark DataFrame, and why is it useful?
A Spark DataFrame is a distributed, table-like collection of rows organized into named, typed columns and partitioned across the cluster. It supports SQL queries, integrates with Spark ML, and its operations are optimized by Spark SQL's Catalyst query optimizer before execution, making it a convenient structure for analyzing and modeling the Heart Failure dataset.
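To make "distributed" concrete: an aggregation such as `groupBy("Sex").count()` is computed in two phases, a partial count inside each partition followed by a merge of those small partial results. The pure-Python sketch below mimics that idea with hypothetical rows split across two lists standing in for partitions; it is a conceptual illustration, not Spark's actual implementation.

```python
from collections import Counter


def partial_counts(partition, key):
    """Phase 1: count rows by key inside one partition (no data movement)."""
    return Counter(row[key] for row in partition)


def merge_counts(partials):
    """Phase 2: combine the small per-partition results into final counts."""
    total = Counter()
    for partial in partials:
        total.update(partial)
    return dict(total)


# Hypothetical rows split across two "partitions", standing in for a
# DataFrame spread over executors.
toy_partitions = [
    [{"Sex": "M"}, {"Sex": "F"}, {"Sex": "M"}],
    [{"Sex": "M"}, {"Sex": "F"}],
]
sex_counts = merge_counts(partial_counts(p, "Sex") for p in toy_partitions)
print(sex_counts)  # {'M': 3, 'F': 2}
```

Only the per-partition counters travel between phases, which is why grouped counts stay cheap even when the rows themselves are spread over many machines.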

## Setup and Data Collection

**Dataset**: Heart Failure Prediction from Kaggle ([https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction](https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction)).  
**Why Chosen**: This dataset is ideal for binary classification (HeartDisease: 0 = no disease, 1 = disease), with labeled tabular data (12 columns, 918 rows). Its health focus is engaging, and no PySpark solutions exist in Kaggle’s “Code” tab (verified on April 25, 2025).  
**Data Fetch**: We download the dataset archive (`archive.zip`) from Kaggle, upload it to Colab with `files.upload()`, and unzip it to obtain `heart.csv`.

In [None]:
# Install PySpark and GraphFrames
!pip install pyspark graphframes

# Import SparkSession
from pyspark.sql import SparkSession

# Create a Spark session with GraphFrames support.
# The package coordinate must match the installed Spark and Scala versions;
# adjust it if pip installs a Spark release other than 3.1.
spark = SparkSession.builder \
    .appName("HeartFailurePrediction") \
    .config("spark.jars.packages", "graphframes:graphframes:0.8.2-spark3.1-s_2.12") \
    .getOrCreate()

# Upload the Kaggle archive (archive.zip) from the local machine and extract heart.csv
from google.colab import files
uploaded = files.upload()
!unzip -o archive.zip

# Load dataset
df = spark.read.csv("/content/heart.csv", header=True, inferSchema=True)
print("First 5 rows of dataset:")
df.show(5)



Saving archive.zip to archive (1).zip
Archive:  archive.zip
  inflating: heart.csv               
First 5 rows of dataset:
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
|Age|Sex|ChestPainType|RestingBP|Cholesterol|FastingBS|RestingECG|MaxHR|ExerciseAngina|Oldpeak|ST_Slope|HeartDisease|
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
| 40|  M|          ATA|      140|        289|        0|    Normal|  172|             N|    0.0|      Up|           0|
| 49|  F|          NAP|      160|        180|        0|    Normal|  156|             N|    1.0|    Flat|           1|
| 37|  M|          ATA|      130|        283|        0|        ST|   98|             N|    0.0|      Up|           0|
| 48|  F|          ASY|      138|        214|        0|    Normal|  108|             Y|    1.5|    Flat|           1|
| 54|  M|          NAP|      150|        195|      

## Data Cleaning

We use Spark SQL to check for duplicates, null values, type inconsistencies, misspellings, and outliers. All checks are shown to verify data quality, even if no changes are needed. The cleaned data is saved and reloaded for further analysis.

In [None]:
# Import required functions
from pyspark.sql.functions import col
from pyspark.sql import functions as F

### Duplicates
We check for duplicate rows and remove them if present.

In [None]:
# Check for duplicates
total_rows = df.count()
unique_rows = df.dropDuplicates().count()
print(f"Total rows: {total_rows}, Unique rows: {unique_rows}, Duplicates: {total_rows - unique_rows}")
if total_rows != unique_rows:
    df = df.dropDuplicates()
    print("Duplicates dropped.")
else:
    print("No duplicates found.")

Total rows: 918, Unique rows: 918, Duplicates: 0
No duplicates found.


### Null Values
We count null values in every column. If `Cholesterol` contained nulls, the code would fill them with the column mean; none are expected for this dataset.

In [None]:
# Check for null values
print("Null values per column:")
df.select([col(c).isNull().cast("int").alias(c) for c in df.columns]).agg(*[F.sum(c).alias(c) for c in df.columns]).show()
if df.filter(col("Cholesterol").isNull()).count() > 0:
    chol_mean = df.select("Cholesterol").agg({"Cholesterol": "mean"}).collect()[0][0]
    df = df.fillna({"Cholesterol": chol_mean})
    print("Filled nulls in Cholesterol with mean.")
else:
    print("No nulls found.")

Null values per column:
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
|Age|Sex|ChestPainType|RestingBP|Cholesterol|FastingBS|RestingECG|MaxHR|ExerciseAngina|Oldpeak|ST_Slope|HeartDisease|
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
|  0|  0|            0|        0|          0|        0|         0|    0|             0|      0|       0|           0|
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+

No nulls found.
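A null count of zero does not mean nothing is missing. The summary statistics in the outlier check below show a minimum of 0 for both `RestingBP` and `Cholesterol`, and neither value is physiologically possible, so they most likely encode missing measurements. A minimal pure-Python sketch of treating such zeros as missing, on hypothetical rows, follows. In PySpark the same idea could be written as `F.when(col("Cholesterol") == 0, None).otherwise(col("Cholesterol"))`, after which the null-handling code above applies.

```python
# Columns where 0 cannot be a real measurement (assumption based on the
# minimum of 0 seen in the summary statistics for these columns).
PLACEHOLDER_COLS = ["RestingBP", "Cholesterol"]

# Hypothetical rows for illustration only.
sample_rows = [
    {"RestingBP": 140, "Cholesterol": 289},
    {"RestingBP": 0, "Cholesterol": 0},
    {"RestingBP": 130, "Cholesterol": 0},
]


def count_zeros(rows, cols):
    """Count rows holding the placeholder value 0 in each column."""
    return {c: sum(1 for r in rows if r[c] == 0) for c in cols}


def zeros_to_none(rows, cols):
    """Return copies of the rows with placeholder zeros replaced by None (missing)."""
    return [
        {k: (None if k in cols and v == 0 else v) for k, v in r.items()}
        for r in rows
    ]


zero_counts = count_zeros(sample_rows, PLACEHOLDER_COLS)
with_missing = zeros_to_none(sample_rows, PLACEHOLDER_COLS)
print(zero_counts)      # {'RestingBP': 1, 'Cholesterol': 2}
print(with_missing[2])  # {'RestingBP': 130, 'Cholesterol': None}
```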


### Type Inconsistency
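We print the schema to confirm the expected types (strings for categorical columns, integers or doubles for numeric ones). `inferSchema` already produced these types, so the casts in the cell below change nothing; they only show how a correction would be applied.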

In [None]:
# Verify data types
print("Schema to verify data types:")
df.printSchema()
# Both columns are already strings, so these casts are no-ops kept for illustration
df = df.withColumn("Sex", col("Sex").cast("string")).withColumn("ChestPainType", col("ChestPainType").cast("string"))
print("Verified/corrected types (Sex and ChestPainType cast to string).")

Schema to verify data types:
root
 |-- Age: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- ChestPainType: string (nullable = true)
 |-- RestingBP: integer (nullable = true)
 |-- Cholesterol: integer (nullable = true)
 |-- FastingBS: integer (nullable = true)
 |-- RestingECG: string (nullable = true)
 |-- MaxHR: integer (nullable = true)
 |-- ExerciseAngina: string (nullable = true)
 |-- Oldpeak: double (nullable = true)
 |-- ST_Slope: string (nullable = true)
 |-- HeartDisease: integer (nullable = true)

Verified/corrected types (Sex and ChestPainType cast to string).


### Misspellings/Mistypings
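We list the distinct values of each categorical column to spot misspellings, then keep only rows whose `ChestPainType` is one of the four documented codes (ASY, NAP, ATA, TA). Every column contains only expected values, so no rows are removed.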

In [None]:
# Check categorical columns for invalid values
print("Checking categorical columns for misspellings:")
for cat_col in ["Sex", "ChestPainType", "RestingECG", "ExerciseAngina", "ST_Slope"]:
    print(f"Values in {cat_col}:")
    df.groupBy(cat_col).count().show()
valid_types = ["ASY", "NAP", "ATA", "TA"]
df = df.filter(col("ChestPainType").isin(valid_types))
print("Verified/corrected categorical values.")

Checking categorical columns for misspellings:
Values in Sex:
+---+-----+
|Sex|count|
+---+-----+
|  F|  193|
|  M|  725|
+---+-----+

Values in ChestPainType:
+-------------+-----+
|ChestPainType|count|
+-------------+-----+
|          NAP|  203|
|          ATA|  173|
|           TA|   46|
|          ASY|  496|
+-------------+-----+

Values in RestingECG:
+----------+-----+
|RestingECG|count|
+----------+-----+
|       LVH|  188|
|    Normal|  552|
|        ST|  178|
+----------+-----+

Values in ExerciseAngina:
+--------------+-----+
|ExerciseAngina|count|
+--------------+-----+
|             Y|  371|
|             N|  547|
+--------------+-----+

Values in ST_Slope:
+--------+-----+
|ST_Slope|count|
+--------+-----+
|    Flat|  460|
|      Up|  395|
|    Down|   63|
+--------+-----+

Verified/corrected categorical values.
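The cell above filters only `ChestPainType`. A sketch of checking every categorical column against an explicit allow-list, using the value sets printed above, is shown below in plain Python with hypothetical rows (the second row contains deliberate typos). In PySpark, the per-column check would be `df.filter(~col(c).isin(list(valid)))`.

```python
# Allowed values per categorical column, taken from the value counts printed above.
ALLOWED_VALUES = {
    "Sex": {"M", "F"},
    "ChestPainType": {"ASY", "NAP", "ATA", "TA"},
    "RestingECG": {"Normal", "ST", "LVH"},
    "ExerciseAngina": {"Y", "N"},
    "ST_Slope": {"Up", "Flat", "Down"},
}


def invalid_values(rows, allowed):
    """Return {column: set of unexpected values} for columns that have any."""
    bad = {}
    for column, valid in allowed.items():
        found = {r[column] for r in rows} - valid
        if found:
            bad[column] = found
    return bad


# Hypothetical rows; the second contains deliberate typos in Sex and ST_Slope.
typo_rows = [
    {"Sex": "M", "ChestPainType": "ATA", "RestingECG": "Normal", "ExerciseAngina": "N", "ST_Slope": "Up"},
    {"Sex": "m", "ChestPainType": "ATA", "RestingECG": "Normal", "ExerciseAngina": "N", "ST_Slope": "Flatt"},
]
print(invalid_values(typo_rows, ALLOWED_VALUES))  # {'Sex': {'m'}, 'ST_Slope': {'Flatt'}}
```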


### Outliers
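We inspect summary statistics for the main numeric columns, then remove `Cholesterol` outliers with the IQR rule: values below Q1 - 1.5 × IQR or above Q3 + 1.5 × IQR are dropped. The quartiles come from `approxQuantile` with a 5% relative error. The summary statistics also show that `RestingBP` and `Cholesterol` both have a minimum of 0, which is not a valid measurement. A second cell applies the same rule to other numeric columns.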

In [None]:
# Check for outliers with summary stats
print("Checking for outliers:")
df.describe(["Age", "RestingBP", "Cholesterol", "MaxHR"]).show()
q1, q3 = df.approxQuantile("Cholesterol", [0.25, 0.75], 0.05)
iqr = q3 - q1
df = df.filter((col("Cholesterol") >= q1 - 1.5 * iqr) & (col("Cholesterol") <= q3 + 1.5 * iqr))
print("Removed outliers in Cholesterol using IQR method.")

Checking for outliers:
+-------+------------------+------------------+------------------+------------------+
|summary|               Age|         RestingBP|       Cholesterol|             MaxHR|
+-------+------------------+------------------+------------------+------------------+
|  count|               918|               918|               918|               918|
|   mean|53.510893246187365|132.39651416122004| 198.7995642701525|136.80936819172112|
| stddev|  9.43261650673202|18.514154119907808|109.38414455220345| 25.46033413825029|
|    min|                28|                 0|                 0|                60|
|    max|                77|               200|               603|               202|
+-------+------------------+------------------+------------------+------------------+

Removed outliers in Cholesterol using IQR method.
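The IQR rule computes Tukey fences from the first and third quartiles. The sketch below uses exact quartiles from Python's `statistics.quantiles` on hypothetical cholesterol-like values. The notebook instead calls `approxQuantile(..., 0.05)`, which tolerates a 5% rank error in the quartiles; passing `0.0` as the relative error makes Spark compute exact quantiles, which is affordable for a dataset of this size.

```python
import statistics


def iqr_fences(values, k=1.5):
    """Return Tukey fences (Q1 - k*IQR, Q3 + k*IQR) using exact quartiles."""
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr


# Hypothetical cholesterol-like values (mg/dl); 400 is the extreme point.
chol_values = [100, 110, 120, 130, 140, 150, 160, 170, 400]
lower, upper = iqr_fences(chol_values)
chol_outliers = [v for v in chol_values if v < lower or v > upper]
print(lower, upper)   # 60.0 220.0
print(chol_outliers)  # [400]
```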


In [None]:
# Import
from pyspark.sql.functions import col

# Numeric columns to screen with the IQR rule. FastingBS is excluded: it is
# binary (0/1) with mostly zeros, so its quartiles are typically both 0,
# IQR = 0, and the rule would flag every FastingBS = 1 row as an outlier.
numeric_cols = ["Age", "RestingBP", "MaxHR", "Oldpeak"]

# Loop through each numeric column. This filters df, the DataFrame that is
# saved and reloaded as df_cleaned in the next cell.
for column in numeric_cols:
    print(f"\nChecking outliers for {column}:")
    # Describe stats
    df.describe([column]).show()
    # IQR method (approximate quartiles with 5% relative error)
    q1, q3 = df.approxQuantile(column, [0.25, 0.75], 0.05)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    outliers = df.filter((col(column) < lower_bound) | (col(column) > upper_bound))
    outlier_count = outliers.count()
    print(f"Found {outlier_count} outliers in {column}.")
    if outlier_count > 0:
        df = df.filter((col(column) >= lower_bound) & (col(column) <= upper_bound))
        print(f"Removed outliers from {column}.")



Checking outliers for Age:
+-------+-----------------+
|summary|              Age|
+-------+-----------------+
|  count|              910|
|   mean|53.54065934065934|
| stddev|9.419334772243241|
|    min|               28|
|    max|               77|
+-------+-----------------+

Found 1 outliers in Age.
Removed outliers from Age.

Checking outliers for RestingBP:
+-------+------------------+
|summary|         RestingBP|
+-------+------------------+
|  count|               909|
|   mean|132.45544554455446|
| stddev|18.569533736711016|
|    min|                 0|
|    max|               200|
+-------+------------------+

Found 28 outliers in RestingBP.
Removed outliers from RestingBP.

Checking outliers for MaxHR:
+-------+------------------+
|summary|             MaxHR|
+-------+------------------+
|  count|               881|
|   mean|136.95005675368898|
| stddev| 25.30877054449797|
|    min|                60|
|    max|               202|
+-------+------------------+

Found 2 outlie
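The IQR rule breaks down on binary columns such as `FastingBS`. When well under a quarter of the values are 1, both quartiles are 0, so IQR = 0 and every 1 falls outside the fences. The summary statistics further down (659 rows) report a `FastingBS` mean of 0.0, consistent with every FastingBS = 1 row having been removed. A minimal demonstration on a hypothetical column with 23% ones:

```python
import statistics


def iqr_outliers(values, k=1.5):
    """Return the values outside the Tukey fences computed from exact quartiles."""
    q1, _, q3 = statistics.quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr
    return [v for v in values if v < lower or v > upper]


# Hypothetical binary column: 77 zeros and 23 ones (23% positive).
fasting_bs = [0] * 77 + [1] * 23
flagged = iqr_outliers(fasting_bs)
print(len(flagged))  # 23: every 1 is flagged, because Q1 = Q3 = 0
```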

In [None]:
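# Save and reload. Spark writes a directory of part files at this path,
# and spark.read.csv reads the whole directory back.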
df.write.csv("/content/cleaned_heart.csv", header=True, mode="overwrite")
df_cleaned = spark.read.csv("/content/cleaned_heart.csv", header=True, inferSchema=True)
print("Saved and reloaded cleaned data.")
print("First 5 rows of cleaned dataset:")
df_cleaned.show(5)

Saved and reloaded cleaned data.
First 5 rows of cleaned dataset:
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
|Age|Sex|ChestPainType|RestingBP|Cholesterol|FastingBS|RestingECG|MaxHR|ExerciseAngina|Oldpeak|ST_Slope|HeartDisease|
+---+---+-------------+---------+-----------+---------+----------+-----+--------------+-------+--------+------------+
| 40|  M|          ATA|      140|        289|        0|    Normal|  172|             N|    0.0|      Up|           0|
| 49|  F|          NAP|      160|        180|        0|    Normal|  156|             N|    1.0|    Flat|           1|
| 37|  M|          ATA|      130|        283|        0|        ST|   98|             N|    0.0|      Up|           0|
| 48|  F|          ASY|      138|        214|        0|    Normal|  108|             Y|    1.5|    Flat|           1|
| 54|  M|          NAP|      150|        195|        0|    Normal|  122|             N|    0.0|      Up|    

## Data Exploration

We explore the cleaned dataset using Spark SQL to understand its schema, shape, statistics, column distributions, and correlations. This informs feature selection for machine learning.

In [None]:
# Print schema
print("Dataset schema:")
df_cleaned.printSchema()
print("Columns and Types:", df_cleaned.dtypes)

Dataset schema:
root
 |-- Age: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- ChestPainType: string (nullable = true)
 |-- RestingBP: integer (nullable = true)
 |-- Cholesterol: integer (nullable = true)
 |-- FastingBS: integer (nullable = true)
 |-- RestingECG: string (nullable = true)
 |-- MaxHR: integer (nullable = true)
 |-- ExerciseAngina: string (nullable = true)
 |-- Oldpeak: double (nullable = true)
 |-- ST_Slope: string (nullable = true)
 |-- HeartDisease: integer (nullable = true)

Columns and Types: [('Age', 'int'), ('Sex', 'string'), ('ChestPainType', 'string'), ('RestingBP', 'int'), ('Cholesterol', 'int'), ('FastingBS', 'int'), ('RestingECG', 'string'), ('MaxHR', 'int'), ('ExerciseAngina', 'string'), ('Oldpeak', 'double'), ('ST_Slope', 'string'), ('HeartDisease', 'int')]


In [None]:
# Print shape
rows = df_cleaned.count()
cols = len(df_cleaned.columns)
print(f"Shape: {rows} rows, {cols} columns")

Shape: 910 rows, 12 columns


In [None]:
# Show summary statistics
print("Summary statistics:")
# Full summary statistics without truncating cell values
df_cleaned.describe().show(truncate=False)


Summary statistics:
+-------+-----------------+----+-------------+------------------+-----------------+---------+----------+------------------+--------------+------------------+--------+-------------------+
|summary|              Age| Sex|ChestPainType|         RestingBP|      Cholesterol|FastingBS|RestingECG|             MaxHR|ExerciseAngina|           Oldpeak|ST_Slope|       HeartDisease|
+-------+-----------------+----+-------------+------------------+-----------------+---------+----------+------------------+--------------+------------------+--------+-------------------+
|  count|              659| 659|          659|               659|              659|      659|       659|               659|           659|               659|     659|                659|
|   mean| 52.3247344461305|NULL|         NULL|130.56904400606982|212.3990895295903|      0.0|      NULL|139.06069802731412|          NULL|0.7667678300455234|    NULL|  0.464339908952959|
| stddev|9.484470335814981|NULL|         NULL

### Column Analysis
We analyze 7 key columns: their real-world meaning, units, distribution, and why the distribution is useful.
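- **Age**: Patient’s age in years (numeric). The distribution (28 to 77 in this data) shows which age groups are represented, which matters for risk profiling.
- **Sex**: Biological sex (M/F, categorical). The data is heavily skewed toward male patients (719 M vs. 191 F after cleaning), so results for female patients rest on fewer examples.
- **ChestPainType**: Type of chest pain (categorical): ASY = asymptomatic, NAP = non-anginal pain, ATA = atypical angina, TA = typical angina. ASY is by far the most common value, and symptom prevalence is critical for diagnosis.
- **RestingBP**: Resting blood pressure in mmHg (numeric). The distribution shows typical values (quartiles 120 to 140); a minimum of 0 is physiologically impossible and marks a missing or invalid entry.
- **Cholesterol**: Serum cholesterol in mg/dl (numeric). The raw data is widely spread (standard deviation about 109 mg/dl) and contains zeros (minimum 0), which are placeholders rather than real measurements.
- **MaxHR**: Maximum heart rate achieved, in beats per minute (numeric). Its distribution and its correlation with `HeartDisease` guide feature selection.
- **ExerciseAngina**: Exercise-induced angina (Y/N, categorical). How its values split across `HeartDisease` classes indicates how useful it is as a predictor.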

In [None]:
# Analyze distributions
print("### Column Analysis")
columns = ["Age", "Sex", "ChestPainType", "RestingBP", "Cholesterol", "MaxHR", "ExerciseAngina"]
for col_name in columns:
    print(f"Distribution of {col_name}:")
    if col_name in ["Sex", "ChestPainType", "ExerciseAngina"]:
        df_cleaned.groupBy(col_name).count().show()
    else:
        df_cleaned.select(col_name).summary().show()

### Column Analysis
Distribution of Age:
+-------+-----------------+
|summary|              Age|
+-------+-----------------+
|  count|              910|
|   mean|53.54065934065934|
| stddev|9.419334772243241|
|    min|               28|
|    25%|               47|
|    50%|               54|
|    75%|               60|
|    max|               77|
+-------+-----------------+

Distribution of Sex:
+---+-----+
|Sex|count|
+---+-----+
|  F|  191|
|  M|  719|
+---+-----+

Distribution of ChestPainType:
+-------------+-----+
|ChestPainType|count|
+-------------+-----+
|          NAP|  201|
|          ATA|  172|
|           TA|   46|
|          ASY|  491|
+-------------+-----+

Distribution of RestingBP:
+-------+------------------+
|summary|         RestingBP|
+-------+------------------+
|  count|               910|
|   mean|132.45274725274726|
| stddev|18.559495155609646|
|    min|                 0|
|    25%|               120|
|    50%|               130|
|    75%|               140|
|  

In [None]:
# Register DataFrame as SQL table
df_cleaned.createOrReplaceTempView("heart")
print("Sex distribution via SQL:")
spark.sql("SELECT Sex, COUNT(*) as count FROM heart GROUP BY Sex").show()

Sex distribution via SQL:
+---+-----+
|Sex|count|
+---+-----+
|  F|  191|
|  M|  719|
+---+-----+



### Correlation Analysis
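We compute Pearson correlations between the binary target `HeartDisease` and each of `Age`, `Cholesterol`, and `MaxHR` to identify predictive features. The values quoted below are the ones printed by the code cell that follows:
- **Age (r ≈ 0.29)**: Older patients are moderately more likely to have heart disease, useful for risk profiling.
- **Cholesterol (r ≈ -0.25)**: A weak *negative* correlation, which runs against clinical intuition. It may be an artifact of rows where `Cholesterol` is recorded as 0 (the minimum in the summary statistics), so this feature should be interpreted with caution.
- **MaxHR (r ≈ -0.41)**: The strongest of the three; patients who reach a lower maximum heart rate are more likely to have heart disease, making it a key ML feature.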

In [None]:
# Correlation analysis
from pyspark.sql.functions import corr
print("### Correlation Analysis")
for col_name in ["Age", "Cholesterol", "MaxHR"]:
    corr_value = df_cleaned.select(corr(col_name, "HeartDisease")).collect()[0][0]
    print(f"Correlation between {col_name} and HeartDisease: {corr_value}")

### Correlation Analysis
Correlation between Age and HeartDisease: 0.29182691694934043
Correlation between Cholesterol and HeartDisease: -0.2476366345921841
Correlation between MaxHR and HeartDisease: -0.40621172405561623
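To see how placeholder zeros can flip the sign of a correlation, the plain-Python sketch below computes Pearson's r on hypothetical data. For a binary target this is the point-biserial correlation, which is what Spark's `corr` returns here. If some diseased patients have cholesterol recorded as 0, those values drag the diseased group's mean below the healthy group's, and r turns negative even though the real measurements are higher in the diseased group. Whether this fully explains the -0.25 above is an assumption that would need checking, for example by recomputing the correlation after excluding zeros.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sd_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    sd_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (sd_x * sd_y)


# Hypothetical data: real cholesterol values are higher in the diseased group (label 1).
chol = [200, 210, 220, 230, 240, 250, 260, 270]
disease = [0, 0, 0, 0, 1, 1, 1, 1]
r_real = pearson(chol, disease)

# Add four diseased patients whose cholesterol is recorded as a 0 placeholder.
r_with_zeros = pearson(chol + [0, 0, 0, 0], disease + [1, 1, 1, 1])
print(round(r_real, 2), round(r_with_zeros, 2))  # positive, then negative
```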


## Machine Learning

We apply LogisticRegression for binary classification to predict `HeartDisease` (0 = no disease, 1 = disease). LogisticRegression models the probability of disease as a logistic (sigmoid) function of a weighted sum of the features; its interpretability and effectiveness make it a strong baseline for binary outcomes.
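The logistic model computes z = b + w·x and maps it to a probability with the sigmoid 1 / (1 + e^(-z)). The sketch below is a plain-Python illustration with hypothetical weights; it returns the probability pair in the same [P(class 0), P(class 1)] order as Spark's `probability` column and compares P(class 1) against 0.5, Spark's default `threshold`.

```python
import math


def sigmoid(z):
    """Logistic function mapping any real number to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(weights, bias, features, threshold=0.5):
    """Return ([P(class 0), P(class 1)], predicted label) for one feature vector."""
    z = bias + sum(w * x for w, x in zip(weights, features))
    p1 = sigmoid(z)
    predicted = 1.0 if p1 > threshold else 0.0
    return [1.0 - p1, p1], predicted


# Hypothetical weights and bias for two hypothetical standardized features.
probs, toy_label = predict(weights=[-1.2, 0.8], bias=0.3, features=[0.5, 1.0])
print(probs, toy_label)
```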

In [None]:
# Import ML libraries
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

### Data Preparation
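We convert each categorical column to a numeric index with `StringIndexer` (by default the most frequent value gets index 0) and assemble all eleven features into a single vector column.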

In [None]:
# Index categorical columns
indexers = [
    StringIndexer(inputCol="Sex", outputCol="SexIndex"),
    StringIndexer(inputCol="ChestPainType", outputCol="ChestPainTypeIndex"),
    StringIndexer(inputCol="RestingECG", outputCol="RestingECGIndex"),
    StringIndexer(inputCol="ExerciseAngina", outputCol="ExerciseAnginaIndex"),
    StringIndexer(inputCol="ST_Slope", outputCol="ST_SlopeIndex")
]
pipeline = Pipeline(stages=indexers)
df_indexed = pipeline.fit(df_cleaned).transform(df_cleaned)
assembler = VectorAssembler(
    inputCols=[
        "Age", "SexIndex", "ChestPainTypeIndex", "RestingBP",
        "Cholesterol", "FastingBS", "RestingECGIndex", "MaxHR",
        "ExerciseAnginaIndex", "Oldpeak", "ST_SlopeIndex"
    ],
    outputCol="features"
)
df_final = assembler.transform(df_indexed)
final_data = df_final.select("features", "HeartDisease")
print("Sample prepared data:")
final_data.show(5)

Sample prepared data:
+--------------------+------------+
|            features|HeartDisease|
+--------------------+------------+
|(11,[0,2,3,4,7,10...|           0|
|[49.0,1.0,1.0,160...|           1|
|[37.0,0.0,2.0,130...|           0|
|[48.0,1.0,0.0,138...|           1|
|(11,[0,2,3,4,7,10...|           0|
+--------------------+------------+
only showing top 5 rows
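`StringIndexer` orders labels by frequency, so the most common value gets index 0. The sketch below reproduces that ordering for `ChestPainType` using the counts from the column analysis (ASY 491, NAP 201, ATA 172, TA 46), which matches the indices visible in the sample feature vectors (ASY → 0.0, NAP → 1.0, ATA → 2.0). Feeding these indices straight into logistic regression treats them as ordered numbers, implying ASY < NAP < ATA < TA. One-hot encoding avoids that; in PySpark it would be a `OneHotEncoder(inputCols=[...], outputCols=[...])` stage placed after the indexers, which drops the last category by default. The plain-Python version below, including the `drop_last` behavior, is a sketch under those assumptions.

```python
from collections import Counter


def frequency_index(values):
    """Map labels to indices, most frequent first; ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {lab: i for i, lab in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """One-hot vector; with drop_last the last category becomes all zeros."""
    width = num_categories - 1 if drop_last else num_categories
    vec = [0] * width
    if index < width:
        vec[index] = 1
    return vec


# Category counts from the ChestPainType distribution in the column analysis.
chest_pain = ["ASY"] * 491 + ["NAP"] * 201 + ["ATA"] * 172 + ["TA"] * 46
cp_index = frequency_index(chest_pain)
print(cp_index)  # {'ASY': 0, 'NAP': 1, 'ATA': 2, 'TA': 3}
print({lab: one_hot(i, len(cp_index)) for lab, i in cp_index.items()})
```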



### Train-Test Split
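We split the data roughly 80/20 into training and test sets. `randomSplit` assigns rows at random, so the actual sizes (762 and 148 here) are only approximately 80% and 20%.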

In [None]:
# Split data
train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=42)
print(f"Training set size: {train_data.count()}")
print(f"Testing set size: {test_data.count()}")

Training set size: 762
Testing set size: 148
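Why 762 and 148 rather than exactly 728 and 182? `randomSplit` makes a random assignment per row rather than cutting at a fixed index, so split sizes vary around the requested weights. A plain-Python sketch of per-row assignment follows; it uses Python's `random` module, so it will not reproduce Spark's exact split for `seed=42`.

```python
import random


def random_split(rows, train_frac=0.8, seed=42):
    """Assign each row to train or test independently with probability train_frac."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_frac else test).append(row)
    return train, test


train_rows, test_rows = random_split(list(range(910)))
# Sizes are close to, but generally not exactly, 728 and 182.
print(len(train_rows), len(test_rows))
```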


### Model Training
We train LogisticRegression on the training data.

In [None]:
# Train LogisticRegression
lr = LogisticRegression(labelCol="HeartDisease", featuresCol="features")
lr_model = lr.fit(train_data)
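To check which features the model actually relies on, PySpark exposes `lr_model.coefficients` and `lr_model.intercept`. Spark's `LogisticRegression` standardizes features internally by default but reports coefficients on the original feature scale, so raw magnitudes are not directly comparable across features with different units. A common heuristic multiplies each coefficient by the feature's standard deviation. The sketch below shows how that can reorder a ranking; the coefficients are hypothetical, and the standard deviations are the pre-cleaning values from the outlier summary above.

```python
# Hypothetical coefficients (NOT taken from this notebook's fitted model).
feature_names = ["Age", "Cholesterol", "MaxHR"]
coefficients = [0.03, -0.004, -0.02]
# Pre-cleaning standard deviations from the outlier summary statistics above.
feature_stds = [9.43, 109.38, 25.46]

# Ranking by raw coefficient magnitude ignores the features' different units.
raw_rank = [
    name for name, _ in sorted(zip(feature_names, coefficients), key=lambda t: -abs(t[1]))
]

# Scaling by the standard deviation gives the change in log-odds for a
# one-standard-deviation change in each feature.
scaled_effect = {
    name: abs(c) * s for name, c, s in zip(feature_names, coefficients, feature_stds)
}
scaled_rank = sorted(scaled_effect, key=lambda name: -scaled_effect[name])

print(raw_rank)     # ['Age', 'MaxHR', 'Cholesterol']
print(scaled_rank)  # ['MaxHR', 'Cholesterol', 'Age']
```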

### Predictions
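We predict on the test data and show sample results. The `probability` column holds [P(HeartDisease = 0), P(HeartDisease = 1)], and `prediction` is 1.0 when P(HeartDisease = 1) exceeds the default threshold of 0.5.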

In [None]:
# Predict on test data
test_results = lr_model.transform(test_data)
print("Sample predictions:")
test_results.select("HeartDisease", "prediction", "probability").show(5)

Sample predictions:
+------------+----------+--------------------+
|HeartDisease|prediction|         probability|
+------------+----------+--------------------+
|           0|       0.0|[0.91575074924559...|
|           1|       1.0|[0.39158630223721...|
|           0|       0.0|[0.53564824027045...|
|           0|       0.0|[0.88750777131844...|
|           1|       1.0|[0.09286476473808...|
+------------+----------+--------------------+
only showing top 5 rows



### Evaluation
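We evaluate with:
- **AUC**: Area under the ROC curve, i.e. the probability that the model scores a randomly chosen patient with heart disease higher than one without (0.5 is chance, 1.0 is perfect).
- **Accuracy**: Fraction of correct predictions.
- **F1-score**: Harmonic mean of precision and recall. Spark's `MulticlassClassificationEvaluator` reports the support-weighted average of the per-class F1 scores.

**Results**: On the test set the model reaches an AUC of about 0.935, with accuracy and F1-score both about 0.858, indicating strong discriminative power on this split. The correlation analysis above suggests `MaxHR` and `Age` are among the more informative features, although this evaluation does not measure feature importance directly.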
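AUC has an equivalent pairwise reading: the fraction of (diseased, healthy) patient pairs in which the diseased patient receives the higher score, with ties counting one half. The sketch below computes it directly in plain Python on hypothetical scores; `BinaryClassificationEvaluator` integrates the ROC curve instead of enumerating pairs, and the two formulations are equivalent.

```python
def pairwise_auc(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    correct = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos_scores) * len(neg_scores))


# Hypothetical P(HeartDisease = 1) for three diseased and two healthy patients.
toy_auc = pairwise_auc([0.9, 0.8, 0.4], [0.7, 0.3])
print(toy_auc)  # 5 of the 6 pairs are ranked correctly
```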

In [None]:
# Evaluate model
auc_evaluator = BinaryClassificationEvaluator(labelCol="HeartDisease", rawPredictionCol="rawPrediction")
auc = auc_evaluator.evaluate(test_results)
print("Test AUC (Area under ROC):", auc)
acc_evaluator = MulticlassClassificationEvaluator(labelCol="HeartDisease", predictionCol="prediction", metricName="accuracy")
accuracy = acc_evaluator.evaluate(test_results)
print("Test Accuracy:", accuracy)
f1_evaluator = MulticlassClassificationEvaluator(labelCol="HeartDisease", predictionCol="prediction", metricName="f1")
f1 = f1_evaluator.evaluate(test_results)
print("Test F1-score:", f1)

Test AUC (Area under ROC): 0.9346991037131881
Test Accuracy: 0.8581081081081081
Test F1-score: 0.8581664331607239
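The "f1" metric of `MulticlassClassificationEvaluator` is a weighted F1: per-class F1 scores averaged with weights equal to each class's share of the true labels. The plain-Python sketch below computes it from hypothetical labels so the number can be reproduced by hand. On the Spark side, `test_results.groupBy("HeartDisease", "prediction").count().show()` displays the confusion matrix behind these metrics.

```python
def weighted_f1(y_true, y_pred):
    """Support-weighted average of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    n = len(y_true)
    score = 0.0
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1_c = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        support = sum(1 for t in y_true if t == c)
        score += f1_c * support / n
    return score


# Hypothetical labels: 5 patients, 3 predicted correctly.
y_true = [1, 1, 1, 0, 0]
y_pred = [1, 1, 0, 0, 1]
accuracy_toy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
f1_toy = weighted_f1(y_true, y_pred)
print(accuracy_toy, round(f1_toy, 3))  # 0.6 0.6
```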


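## Graph Analysis with GraphFrames

GraphFrames represents a graph as two DataFrames: vertices, which need an `id` column, and edges, which need `src` and `dst` columns. The examples below use small hand-made graphs, not rows from the heart dataset, to illustrate degree counts and motif finding.

In [None]:
# Import GraphFrames
from graphframes import GraphFrame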

In [None]:
# Create sample graph
nodes = spark.createDataFrame([
    ("1", "Alice", 25),
    ("2", "Bob", 30),
    ("3", "Charlie", 28),
    ("4", "Diana", 27)
], ["id", "name", "age"])
edges = spark.createDataFrame([
    ("1", "2", "friend"),
    ("2", "3", "friend"),
    ("3", "4", "friend"),
    ("4", "1", "friend")
], ["src", "dst", "relationship"])
graph = GraphFrame(nodes, edges)




In [None]:
# Node degrees: degree = in-degree + out-degree. In this 4-node cycle each
# node has one outgoing and one incoming "friend" edge.
print("Degree of nodes (number of friends):")
graph.degrees.show()



Degree of nodes (number of friends):
+---+------+
| id|degree|
+---+------+
|  3|     2|
|  1|     2|
|  2|     2|
|  4|     2|
+---+------+
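`graph.degrees` counts both incoming and outgoing edges of each vertex (GraphFrames also provides `inDegrees` and `outDegrees`). A plain-Python sketch of the same count on the four directed edges above:

```python
from collections import Counter

# Same directed "friend" edges as the sample graph above (src, dst).
edge_list = [("1", "2"), ("2", "3"), ("3", "4"), ("4", "1")]

in_degree = Counter(dst for _, dst in edge_list)
out_degree = Counter(src for src, _ in edge_list)
# degree = in-degree + out-degree, as in graph.degrees
degree = {v: in_degree[v] + out_degree[v] for v in sorted(set(in_degree) | set(out_degree))}
print(degree)  # {'1': 2, '2': 2, '3': 2, '4': 2}
```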



In [None]:
# Create vertices (nodes): hand-made toy patients with an id, a name, and HeartDisease status
vertices = spark.createDataFrame([
    ("1", "PatientA", 1),
    ("2", "PatientB", 0),
    ("3", "PatientC", 1),
    ("4", "PatientD", 0)
], ["id", "name", "HeartDisease"])

# Create edges (relationships between patients)
edges = spark.createDataFrame([
    ("1", "2", "friend"),
    ("2", "3", "friend"),
    ("3", "4", "friend"),
    ("4", "1", "friend")
], ["src", "dst", "relationship"])

# Build GraphFrame
g = GraphFrame(vertices, edges)


In [None]:
# Show vertices (nodes)
print("Vertices (Patients):")
g.vertices.show()

# Show edges (connections)
print("Edges (Relationships):")
g.edges.show()


Vertices (Patients):
+---+--------+------------+
| id|    name|HeartDisease|
+---+--------+------------+
|  1|PatientA|           1|
|  2|PatientB|           0|
|  3|PatientC|           1|
|  4|PatientD|           0|
+---+--------+------------+

Edges (Relationships):
+---+---+------------+
|src|dst|relationship|
+---+---+------------+
|  1|  2|      friend|
|  2|  3|      friend|
|  3|  4|      friend|
|  4|  1|      friend|
+---+---+------------+



In [None]:
# Number of vertices (patients)
print("Number of vertices:", g.vertices.count())

# Number of edges (connections)
print("Number of edges:", g.edges.count())


Number of vertices: 4
Number of edges: 4


In [None]:
# Find every edge whose source patient has HeartDisease = 1 (motif: a -> b)
print("Connections starting from patients with Heart Disease:")
g.find("(a)-[e]->(b)").filter("a.HeartDisease == 1").show()


Connections starting from patients with Heart Disease:
+----------------+--------------+----------------+
|               a|             e|               b|
+----------------+--------------+----------------+
|{1, PatientA, 1}|{1, 2, friend}|{2, PatientB, 0}|
|{3, PatientC, 1}|{3, 4, friend}|{4, PatientD, 0}|
+----------------+--------------+----------------+
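The motif `(a)-[e]->(b)` matches every directed edge together with its source and destination vertices, and the filter keeps matches whose source has `HeartDisease = 1`. A plain-Python equivalent over the same toy graph:

```python
# Toy graph from the cells above: vertex id -> HeartDisease, plus directed edges.
heart_status = {"1": 1, "2": 0, "3": 1, "4": 0}
edge_list = [("1", "2"), ("2", "3"), ("3", "4"), ("4", "1")]

# Motif (a)-[e]->(b): every edge with its endpoints; keep those whose source a has disease.
matches = [(a, b) for a, b in edge_list if heart_status[a] == 1]
print(matches)  # [('1', '2'), ('3', '4')]
```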

