# Wine Quality Classification - Spark R Jupyter Notebook

## SDSC Summer Institute

## Set up environment

In [None]:
# Start Spark session

# Load SparkR from the R library that ships inside the Spark installation
# pointed to by the SPARK_HOME environment variable.
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))

# Run Spark locally on all available cores with a 2 GB driver.
# The application name is passed explicitly through `appName`; previously it
# sat outside the sparkConfig list as a parenthesised assignment, which only
# reached `appName` by positional accident and left a stray `spark.app.name`
# variable in the calling environment.
sparkR.session(
  master = "local[*]",
  appName = "SparkR Wine Quality Classification",
  sparkConfig = list(spark.driver.memory = "2g")
)

In [None]:
# Report the software in use: R version string, Spark install path, SparkR version

R.version.string
Sys.getenv("SPARK_HOME")
sparkR.version()

## Read in data

In [None]:
# Read data into a Spark dataframe
# Data adapted from: https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv
#
# Exercise: replace <<FILL-IN>> with the path to the CSV file (a quoted string).
# header="true" takes column names from the first line of the file;
# inferSchema="true" asks Spark to infer each column's type from the data.

sdf <- read.df(<<FILL-IN>>, "csv", header="true", inferSchema="true")

In [None]:
# Cache dataframe
# `sdf` is reused by the split cell below, so keep it cached.
# Exercise: replace <<FILL-IN>> with the SparkR call that caches `sdf`.

<<FILL-IN>>

In [None]:
# Examine schema
# Exercise: replace <<FILL-IN>> with the SparkR call that shows the column
# names and types of `sdf`.

<<FILL-IN>>

## Prepare data

In [None]:
# Split into train & test sets
#
# randomSplit() assigns every row of `sdf` to exactly one of the two sets
# (roughly 70% train / 30% test). The earlier sample() + except() approach
# used EXCEPT DISTINCT semantics, which de-duplicated the test set and dropped
# any test row whose values also appeared in the training sample.

seed <- 12345
splits <- randomSplit(sdf, weights = c(0.7, 0.3), seed = seed)
train_sdf <- splits[[1]]
test_sdf <- splits[[2]]
dim(train_sdf)                # Get dimensions of train dataset
<<FILL-IN>>(test_sdf)         # Get dimensions of test dataset

## Train random forest model

In [None]:
# Fit a 30-tree random forest on the training set, predicting `quality` from
# all other columns (formula `quality ~ .`), reusing the seed set in the split cell.
# Exercise: replace <<FILL-IN>> with the model type. Later cells compare the
# predicted and actual quality labels for equality (accuracy, confusion matrix).
model <- spark.randomForest(train_sdf, quality ~ ., type=<<FILL-IN>>, numTrees=30, seed=seed)
# Show the first elements of the fitted model's summary
head(summary(model))

## Evaluate model

In [None]:
# Apply model to test data
# Exercise: replace <<FILL-IN>> with the Spark dataframe to score.

predictions_sdf <- predict(model, <<FILL-IN>>)
# Check what kind of object predict() returned
class(predictions_sdf)

In [None]:
# Bring the predictions back from Spark into a local R data.frame

predictions_df <- collect(predictions_sdf)
class(predictions_df)
head(predictions_df)

In [None]:
# Calculate accuracy
# Accuracy is the fraction of rows in `predictions_df` whose predicted label
# equals the actual `quality` value.
# Exercise: replace <<FILL-IN>> with the value to print.

accuracy <- mean(predictions_df$quality == predictions_df$prediction)
sprintf ("Accuracy on Test Data:  %f", <<FILL-IN>>)

In [None]:
# Confusion matrix
# Rows are the actual quality labels, columns are the model's predictions.
# The dimensions are named so the printed table says which axis is which.

table(
  actual = predictions_df$quality,
  predicted = predictions_df$prediction
)

## Save model

In [None]:
# Save model (NOTE:  Existing model will be overwritten)
# overwrite=TRUE replaces whatever is already saved at the target path.

write.ml(model, "<<FILL-IN>>", overwrite=TRUE)  # Fill in with name to save model to

## Stop cluster

In [None]:
# Shut down the Spark session started at the top of the notebook

sparkR.session.stop()