# In [None]:  (notebook cell marker left over from export)
from pyspark.sql import SparkSession
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer, IDF, StringIndexer
from pyspark.ml.classification import NaiveBayes
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col

# Obtain the Spark session for this job (getOrCreate reuses one if it exists).
spark = (
    SparkSession.builder
    .appName("Medical_NaiveBayes")
    .getOrCreate()
)

# --------------------------
#  Load dataset
# --------------------------
# Expected CSV layout: a header row naming the columns "symptoms" and
# "disease_type"; column types are inferred by Spark from the file contents.
data = spark.read.csv(
    "medical_symptoms.csv",
    header=True,
    inferSchema=True,
)

# Quick visual sanity check of the first few rows.
print("=== Sample Data ===")
data.show(5)

# --------------------------
#  Preprocessing
# --------------------------
# Text -> tokens: Tokenizer lowercases and splits "symptoms" into "words";
# StopWordsRemover then strips common stop words into "filtered".
tokenizer = Tokenizer(
    inputCol="symptoms",
    outputCol="words",
)
remover = StopWordsRemover(
    inputCol="words",
    outputCol="filtered",
)

# Tokens -> TF-IDF vectors: term counts ("raw_features") rescaled by inverse
# document frequency into the final "features" column.
vectorizer = CountVectorizer(
    inputCol="filtered",
    outputCol="raw_features",
)
idf = IDF(
    inputCol="raw_features",
    outputCol="features",
)

# Target encoding: map each distinct "disease_type" value to a numeric
# "label". NOTE: with StringIndexer's default ordering the most frequent class
# gets index 0.0, so which class is 0 vs 1 depends on the training data rather
# than on the class name.
indexer = StringIndexer(
    inputCol="disease_type",
    outputCol="label",
)

# --------------------------
#  Split data
# --------------------------
# Rows missing the text or the target cannot be used by the pipeline below:
# a null "symptoms" string cannot be tokenized, and StringIndexer (default
# handleInvalid="error") rejects null labels. Drop them up front so that
# fit()/transform() do not abort mid-job on a single bad row.
clean_data = data.na.drop(subset=["symptoms", "disease_type"])

# 80/20 train/test split; the fixed seed keeps the split repeatable.
train, test = clean_data.randomSplit([0.8, 0.2], seed=42)

# --------------------------
#  Naive Bayes model
# --------------------------
# Multinomial Naive Bayes with additive smoothing of 1.0. It reads the
# default "features" and "label" columns, which are the output columns of the
# IDF and StringIndexer stages.
nb = NaiveBayes(modelType="multinomial", smoothing=1.0)

# Build the pipeline: text preprocessing -> TF-IDF -> label indexing -> model.
pipeline = Pipeline(stages=[tokenizer, remover, vectorizer, idf, indexer, nb])

# Train model: every stage is fitted on the training split only.
model = pipeline.fit(train)

# --------------------------
#  Make predictions
# --------------------------
# Run the fitted pipeline over the held-out split, then preview the input
# text next to the indexed label and the predicted label.
predictions = model.transform(test)
preview = predictions.select("symptoms", "label", "prediction")
preview.show(10, truncate=False)

# --------------------------
#  Evaluate performance
# --------------------------
# One evaluator per metric; each compares the default "label" and
# "prediction" columns of the predictions DataFrame.
evaluator_precision, evaluator_recall, evaluator_f1 = (
    MulticlassClassificationEvaluator(metricName=metric)
    for metric in ("weightedPrecision", "weightedRecall", "f1")
)

# Compute every score before printing so the report comes out in one piece.
precision, recall, f1 = (
    evaluator.evaluate(predictions)
    for evaluator in (evaluator_precision, evaluator_recall, evaluator_f1)
)

print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1-score:  {f1:.3f}")

# Shut down the Spark session now that the job is finished.
spark.stop()