In [4]:
import urllib.request

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.sql import SparkSession

# End-to-end example: download the iris CSV, train a multinomial logistic
# regression classifier on it with Spark ML, and report test-set accuracy.

# Download the iris dataset CSV file and save it locally.
# This runs before Spark is started so that a failed download does not
# leave a SparkSession running.
iris_url = "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"
local_path = "iris.csv"
urllib.request.urlretrieve(iris_url, local_path)

# Fixed seed for the train/test split so that the split, and therefore the
# reported accuracy, is reproducible from one run to the next.
split_seed = 42

# Create a SparkSession
spark = SparkSession.builder \
    .appName("Iris Dataset") \
    .getOrCreate()

try:
    # Load the iris dataset from the local file; the first row is the header
    # and column types are inferred from the data.
    irisData = spark.read.format("csv") \
        .option("header", "true") \
        .option("inferSchema", "true") \
        .load(local_path)

    # Convert the "species" column from string to a numeric "label" column
    labelIndexer = StringIndexer(inputCol="species", outputCol="label")
    irisData = labelIndexer.fit(irisData).transform(irisData)

    # Pack the four measurement columns into a single "features" vector column
    featureColumns = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
    assembler = VectorAssembler(inputCols=featureColumns, outputCol="features")
    assembledData = assembler.transform(irisData)

    # Split the data into training (70%) and test (30%) sets
    trainingData, testData = assembledData.randomSplit([0.7, 0.3], seed=split_seed)

    # Create a Logistic Regression model
    logisticRegression = LogisticRegression(labelCol="label", featuresCol="features")

    # Train the model
    model = logisticRegression.fit(trainingData)

    # Make predictions on the test set
    predictions = model.transform(testData)

    # Evaluate the model: fraction of test rows whose prediction equals the label
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName="accuracy"
    )
    accuracy = evaluator.evaluate(predictions)

    # Show the features and corresponding predictions
    predictions.select("features", "prediction").show(truncate=False)

    # Display the accuracy of the model
    print(f"Accuracy: {accuracy * 100}%")
finally:
    # Stop the SparkSession even if one of the steps above raised,
    # so the session is not left running.
    spark.stop()


+-----------------+----------+
|features         |prediction|
+-----------------+----------+
|[4.3,3.0,1.1,0.1]|0.0       |
|[4.4,2.9,1.4,0.2]|0.0       |
|[4.6,3.2,1.4,0.2]|0.0       |
|[4.8,3.4,1.6,0.2]|0.0       |
|[4.8,3.4,1.9,0.2]|0.0       |
|[4.9,2.5,4.5,1.7]|1.0       |
|[4.9,3.0,1.4,0.2]|0.0       |
|[4.9,3.1,1.5,0.1]|0.0       |
|[5.0,3.3,1.4,0.2]|0.0       |
|[5.0,3.4,1.5,0.2]|0.0       |
|[5.0,3.5,1.6,0.6]|0.0       |
|[5.0,3.6,1.4,0.2]|0.0       |
|[5.1,3.4,1.5,0.2]|0.0       |
|[5.1,3.5,1.4,0.2]|0.0       |
|[5.1,3.8,1.9,0.4]|0.0       |
|[5.2,3.4,1.4,0.2]|0.0       |
|[5.4,3.4,1.5,0.4]|0.0       |
|[5.7,2.6,3.5,1.0]|1.0       |
|[5.7,2.8,4.5,1.3]|1.0       |
|[5.7,2.9,4.2,1.3]|1.0       |
+-----------------+----------+
only showing top 20 rows

Accuracy: 87.17948717948718%
