In [None]:
#r "nuget:Microsoft.Spark"
#r "nuget:Microsoft.ML"
#r "nuget:Microsoft.ML.FastTree"

In [None]:
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.FastTree;
using Microsoft.Spark;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using static Microsoft.Spark.Sql.Functions;

// Start the Spark backend before running this notebook, e.g.:
//   spark-submit --class org.apache.spark.deploy.dotnet.DotnetRunner --master local D:\3bStudio\Sandbox\spark-program\FirstSparkProgram\bin\Debug\net6.0\microsoft-spark-3-0_2.12-2.1.0.jar debug
// The Spark portal is then available at http://localhost:4040

// Session used by every following cell, registered under the app name "spark-use-ml-model".
var spark = SparkSession.Builder().AppName("spark-use-ml-model").GetOrCreate();

// Set the Spark log level to WARN.
spark.SparkContext.SetLogLevel("WARN");

In [None]:
/// <summary>
/// Input record handed to the ML.NET prediction engine created in <c>PredictFunc</c>.
/// </summary>
/// <remarks>
/// NOTE(review): the member names presumably have to match the input columns the saved model
/// was trained with - confirm against the training pipeline before renaming anything.
/// </remarks>
public class HeartProfile
{
    public float Age;  
    public float Cholesterol;
    public float RestingBP;
    public float FastingBS;
    // Not assigned by PredictFunc; presumably the training label - TODO confirm.
    public bool HeartDisease;
}

/// <summary>
/// Output record returned by the prediction engine for a single <c>HeartProfile</c>.
/// </summary>
public class PredictionSummary
{
    // Bound to the "PredictedLabel" column through the ColumnName attribute.
    [ColumnName("PredictedLabel")]
    public bool Prediction { get; set; }

    // NOTE(review): presumably bound by name to the model's "Probability" output - confirm.
    public float Probability { get; set; }

    // NOTE(review): presumably bound by name to the model's "Score" output - confirm.
    public float Score { get; set; }
}

In [None]:
// Explicit schema for the JSON input: four float columns.
var inputSchema = new StructType(
    new List<StructField>
    {
        new StructField("age", new FloatType()),
        new StructField("cholesterol", new FloatType()),
        new StructField("restingBP", new FloatType()),
        new StructField("fastingBS", new FloatType())
    });

// Read the input file into a DataFrame using the schema declared above.
DataFrame df = spark
    .Read()
    .Schema(inputSchema)
    .Json("D:/3bStudio/Sandbox/3bs-spark-training/resources/ml-input.json");

In [None]:
// Print the DataFrame schema in its one-line textual form.
var schema = df.Schema();
Console.WriteLine(schema.SimpleString);

// `rows` is enumerated here and again by the prediction loop in a later cell.
// Collect() returns an IEnumerable<Row>, so it is materialised once into a list:
// both passes then see the same snapshot and the collect is not repeated per enumeration.
IEnumerable<Row> rows = df.Collect().ToList();

// Dump every collected row to the notebook output.
foreach (Row row in rows)
{
    Console.WriteLine(row);
}

In [None]:
/// <summary>
/// Scores one set of measurements with the model stored in HeartClassification.zip.
/// </summary>
/// <param name="age">Value assigned to <c>HeartProfile.Age</c>.</param>
/// <param name="cholesterol">Value assigned to <c>HeartProfile.Cholesterol</c>.</param>
/// <param name="restingBP">Value assigned to <c>HeartProfile.RestingBP</c>.</param>
/// <param name="fastingBS">Value assigned to <c>HeartProfile.FastingBS</c>.</param>
/// <returns>The <c>PredictionSummary</c> produced by the prediction engine.</returns>
/// <remarks>
/// NOTE(review): the model file is loaded and a new prediction engine is built on every call;
/// consider caching them if this is invoked once per row.
/// </remarks>
static PredictionSummary PredictFunc(float age, float cholesterol, float restingBP, float fastingBS)
{
    var context = new MLContext();

    // The schema reported by Load is not used here, so it is discarded.
    ITransformer trainedModel = context.Model.Load(@"D:\3bStudio\Sandbox\3bs-spark-training\resources\HeartClassification.zip", out _);
    var engine = context.Model.CreatePredictionEngine<HeartProfile, PredictionSummary>(trainedModel);

    // HeartDisease is left at its default; only the four measurements are supplied.
    var profile = new HeartProfile
    {
        Age = age,
        Cholesterol = cholesterol,
        RestingBP = restingBP,
        FastingBS = fastingBS
    };

    return engine.Predict(profile);
}

In [None]:
// Score every collected row locally and keep the summaries for the DataFrame built below.
var predictionResults = new List<PredictionSummary>();
foreach (Row row in rows)
{
    // Cells are read by position, in the order the columns were declared in inputSchema.
    object[] cells = row.Values;
    float age = Convert.ToSingle(cells[0]);
    float cholesterol = Convert.ToSingle(cells[1]);
    float restingBP = Convert.ToSingle(cells[2]);
    float fastingBS = Convert.ToSingle(cells[3]);

    PredictionSummary summary = PredictFunc(age, cholesterol, restingBP, fastingBS);
    predictionResults.Add(summary);

    Console.WriteLine($"predict : {summary.Prediction} (prob= {summary.Probability}, score={summary.Score})");
}

// Turn the predicted labels into a single-column ("Prediction") Spark DataFrame.
var resultDf = spark.CreateDataFrame(
    predictionResults.Select(p => new GenericRow(new object[] { p.Prediction })),
    new StructType(
        new List<StructField>
        {
            new StructField("Prediction", new BooleanType())
        }));

In [None]:
// Print the prediction DataFrame to the notebook output.
resultDf.Show();

In [None]:
// Register the ML.NET scorer as a Spark SQL UDF named "MLudf".
// The UDF receives the four measurements as doubles (the query below casts the columns),
// narrows them to the float inputs PredictFunc expects, and returns only the predicted label.
spark.Udf()
    .Register<double, double, double, double, bool>(
        "MLudf",
        (age, cholesterol, restingBP, fastingBS) =>
            PredictFunc((float)age, (float)cholesterol, (float)restingBP, (float)fastingBS).Prediction);

// Use Spark SQL to call the ML.NET UDF.
// The view name has to be a plain identifier: a hyphen (as in "Heart-data") is not valid
// in an unquoted Spark SQL name, so an underscore is used instead.
df.CreateOrReplaceTempView("heart_data");
DataFrame sqlDf = spark.Sql(
    "SELECT age, cholesterol, restingBP, fastingBS, " +
    "MLudf(CAST(age AS DOUBLE), CAST(cholesterol AS DOUBLE), CAST(restingBP AS DOUBLE), CAST(fastingBS AS DOUBLE)) AS prediction " +
    "FROM heart_data");

// Print out the first 20 rows of data.
// Prevent data getting cut off by setting truncate = 0.
// Shown once only: every Show() re-evaluates the query and therefore re-runs the UDF.
sqlDf.Show(20, 0, false);

spark.Stop();