In [None]:
import os
import io
import pandas as pd
import numpy as np
import mlflow
import dvc.api
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

# Spark session used by the loader functions and the scoring cell.
spark = (
    SparkSession.builder
    .appName("linreg")
    .getOrCreate()
)

# Coordinates of the DVC-tracked dataset: file path, repository and revision.
path = 'dataset/housing.csv'
repo = '../.git'
version = 'v1'  # Git tag

# Ask DVC for the URL of the dataset at the revision above.
data_url = dvc.api.get_url(path=path, repo=repo, rev=version)

# MLflow locations are taken from the environment; a missing variable
# raises KeyError here rather than failing later.
backend_uri = os.environ['MLFLOW_TRACKING_URI']
artifact_uri = os.environ['MLFLOW_ARTIFACT_STORE']
mlflow.set_tracking_uri(backend_uri)

In [None]:
def fetch_data_from_s3(path, repo, version):
    """Load a DVC-tracked CSV at a given revision as a Spark DataFrame.

    The file content is read through ``dvc.api.read``, parsed with pandas,
    written to ``tmp.csv`` in the current working directory and then loaded
    by the module-level ``spark`` session.

    Args:
        path: Path of the tracked file inside the DVC project.
        repo: Location of the DVC project (passed through to DVC).
        version: Revision to read (passed to DVC as ``rev``).

    Returns:
        A Spark DataFrame with the CSV header as column names and
        inferred column types.
    """
    csv_text = dvc.api.read(path=path, repo=repo, rev=version)
    pandas_df = pd.read_csv(io.StringIO(csv_text), sep=',')

    # index=False: without it pandas writes its row index as an unnamed
    # first column, which Spark would then load as an extra bogus column.
    pandas_df.to_csv("tmp.csv", index=False)

    # Spark reads lazily, so tmp.csv must stay on disk while the returned
    # DataFrame is still in use; it is intentionally not deleted here.
    return spark.read.csv("tmp.csv", sep=',', header=True, inferSchema=True)


def fetch_data_from_fs(url):
    """Load a CSV from a file-system location as a Spark DataFrame.

    Uses the same reader options as ``fetch_data_from_s3`` so that both
    loaders hand back the same kind of object; the calling cell relies on
    Spark DataFrame methods (``select``, ``show``, ``withColumn``).

    Args:
        url: Path or URL of the CSV file, readable by the module-level
            ``spark`` session.

    Returns:
        A Spark DataFrame with the CSV header as column names and
        inferred column types.
    """
    return spark.read.csv(url, sep=',', header=True, inferSchema=True)


def fetch_data(url):
    """Load the dataset, choosing the loader from the URL's scheme.

    The text before the first ``:`` in ``url`` is compared,
    case-insensitively, with ``S3``. On a match the data is fetched through
    DVC using the module-level ``path``, ``repo`` and ``version``; otherwise
    ``url`` itself is handed to the file-system loader.

    Args:
        url: Location of the dataset as a string.

    Returns:
        Whatever the selected loader returns.
    """
    scheme, _, _ = url.partition(":")

    # Anything that is not S3 is read directly from the given location.
    if scheme.upper() != "S3":
        return fetch_data_from_fs(url)

    return fetch_data_from_s3(path, repo, version)

In [None]:
# Load the dataset and keep only the column that is passed to the model.
housing = fetch_data(data_url).select("housing_median_age")

# Preview the first five rows.
housing.show(5)

In [None]:
# Model URI passed to MLflow.
logged_model = 'models:/linreg/Staging'

# Wrap the MLflow model as a Spark UDF so it can be applied to a column.
model = mlflow.pyfunc.spark_udf(spark, logged_model)

# Add a 'predictions' column computed from housing_median_age.
df = housing.withColumn(
    'predictions',
    model(f.col("housing_median_age")),
)
df.show()