# Train a Model on Union

First, install the necessary packages:

In [None]:
# NOTE(review): `flytekit` is pinned to the same version as in the ImageSpec
# defined later in this notebook, but `union` is unpinned here while the image
# pins `union==0.1.95` — confirm the local/remote versions are meant to differ.
%pip install "flytekit==1.14.0b6" union "pydantic>2" pandas pyarrow scikit-learn joblib

## Create a Dataset

The following code creates a dataset for us to train a model on:

In [1]:
import flytekit as fl
import pandas as pd
import sys


# Container image for the remote tasks: targets the same major.minor Python
# version as the interpreter running this notebook and pins the
# flytekit/union releases installed inside the image.
image = fl.ImageSpec(
    name="jupyter-notebook-workshop",
    python_version=".".join(str(part) for part in sys.version_info[:2]),
    packages=["pandas", "pyarrow", "flytekit==1.14.0b6", "union==0.1.95"],
)

# Reusable decorator: functions declared with ``@task`` use ``image`` as their
# container image.
task = fl.task(container_image=image)


@task
def get_df() -> pd.DataFrame:
    """Build a small, hard-coded five-row student table.

    Returns:
        A DataFrame with the columns ``Name``, ``Age``, ``Grade`` and
        ``PassedTest``.
    """
    columns = {
        "Name": ["Alice", "Bob", "Charlie", "David", "Eva"],
        "Age": [23, 25, 22, 24, 23],
        "Grade": ["A", "B", "A", "C", "B"],
        "PassedTest": [True, False, True, False, True],
    }
    return pd.DataFrame(columns)

Next we create a remote client that can execute tasks on Union Serverless:

In [None]:
from union.remote import UnionRemote

# Client used by the cells below to submit tasks and workflows for remote
# execution. It is constructed with no arguments, so it relies entirely on
# UnionRemote's defaults.
serverless = UnionRemote()

Create the dataframe:

In [None]:
# Submit `get_df` through the remote client; the task takes no parameters,
# hence the empty inputs dict.
exe = serverless.execute(get_df, inputs={})
# Trailing expression so the notebook displays the returned execution object.
exe

Wait for the execution to complete, then load the dataframe into memory:

In [None]:
# Block until the execution finishes, polling with poll_interval=1
# (presumably seconds — confirm against the `wait` signature).
exe.wait(poll_interval=1)
# 'o0' is the key under which the task's single output is read here
# (presumably the default name of the first unnamed output — confirm).
dataframe = exe.outputs['o0']
# Trailing expression so the notebook renders the loaded dataframe.
dataframe

You can now play around with the dataframe directly in the Jupyter runtime:

In [None]:
def local_function(dataframe: pd.DataFrame):
    """Return the total ``Age`` for each distinct ``Grade`` in ``dataframe``.

    The result is indexed by ``Grade`` and holds the summed ``Age`` values.
    """
    ages_by_grade = dataframe.groupby("Grade")["Age"]
    return ages_by_grade.sum()

# Run the aggregation in the notebook process on the dataframe loaded from the
# execution outputs; as the last expression, its result is displayed.
local_function(dataframe)

## Train a Model

Next, we define a task that trains a model, along with a workflow that wraps it:

In [7]:
from flytekit.types.file import FlyteFile


# Decorator for the training task: its container image is derived from `image`
# via `with_packages`, adding the scikit-learn and joblib packages that
# `train_model` imports.
training_task = fl.task(container_image=image.with_packages(["scikit-learn", "joblib"]))

@training_task
def train_model(dataframe: pd.DataFrame) -> FlyteFile:
    """Fit a logistic-regression classifier and return it as a serialized file.

    Args:
        dataframe: Table providing an ``Age`` feature column and a
            ``PassedTest`` target column.

    Returns:
        A ``FlyteFile`` pointing at ``model.pkl``, the joblib dump of the
        fitted model.
    """
    # Imported inside the task body; these are the packages added to the
    # training task's container image.
    import joblib
    from sklearn.linear_model import LogisticRegression

    classifier = LogisticRegression()
    features = dataframe[["Age"]]  # double brackets keep a 2-D, single-column frame
    target = dataframe["PassedTest"]
    classifier.fit(features, target)

    # Serialize the fitted model to a file in the current working directory.
    model_path = "model.pkl"
    with open(model_path, "wb") as model_file:
        joblib.dump(classifier, model_file)

    return FlyteFile(path=model_path)


@fl.workflow
def train_wf(dataframe: pd.DataFrame) -> FlyteFile:
    """Single-step workflow that forwards ``dataframe`` to ``train_model``.

    Returns:
        The ``FlyteFile`` produced by ``train_model``.
    """
    model_file = train_model(dataframe=dataframe)
    return model_file

Execute the training run:

In [None]:
# Submit the training workflow through the remote client, passing the
# in-memory dataframe as its `dataframe` input.
model_exe = serverless.execute(train_wf, inputs={"dataframe": dataframe})
# Trailing expression so the notebook displays the returned execution object.
model_exe

Now let's load the model into the Jupyter runtime:

In [None]:
import joblib

# Block until the training workflow finishes before reading its outputs.
model_exe.wait(poll_interval=1)
# 'o0' is the key of the workflow's single output, declared as FlyteFile in
# `train_wf`.
model_file = model_exe.outputs['o0']

# NOTE(review): passing the output straight to open() assumes it is path-like
# and locally readable (presumably the remote file is fetched on access) —
# confirm.
with open(model_file, "rb") as f:
    # joblib.load unpickles the file; only do this with artifacts you trust,
    # such as this model produced by our own `train_model` task.
    model = joblib.load(f)

# Trailing expression so the notebook displays the loaded estimator.
model

Finally, we generate some predictions with the model:

In [None]:
# Inputs to score: a single "Age" column, matching the feature the model was
# fitted on in `train_model`.
prediction_data = pd.DataFrame({"Age": [23, 25, 22, 24, 23]})

# Last expression of the cell, so the predicted labels are displayed.
model.predict(prediction_data)