# Intro to Flyte 2 - Scalable Dynamic Workflows for ML Pipelines and AI Agents

To run this tutorial, you'll need to either be an existing Union.ai user with Flyte v2 or [sign up for the free Beta access](https://www.union.ai/beta).

[Flyte 2.0](https://github.com/flyteorg/flyte-sdk) is an orchestration platform for building and managing highly scalable, dynamic, and distributed workflows. 

Examples in this notebook:
- Setup (do this first, before running any of the examples)
- Hello Flyte tasks
- ML pipeline - custom environments, caching, reporting, Flyte Files
- Error handling and dynamic infra
- AI Agents with Flyte - coming soon

## Setup

If you're running this notebook locally, we suggest creating a virtual environment and installing the packages with uv. You can follow the [instructions in the README](README.md).
```
uv venv .venv --python=python3.11
source .venv/bin/activate
uv pip install 'flyte>=2.0.0b21' --prerelease=allow
```
If you're running this in Google Colab, run the two setup cells below:

In [None]:

!uv pip install 'flyte>=2.0.0b21' --prerelease=allow
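
To confirm the install worked before moving on, you can check which version Python sees. This is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper name, not part of Flyte.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints a version string if the install above succeeded, otherwise None
print("flyte version:", installed_version("flyte"))
```

If this prints `None`, the package was most likely installed into a different environment than the one the notebook kernel is using.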

##### Flyte Config
If you're running this notebook locally, you can remove `--auth-type headless`.

If you have an existing config, you can skip this step.

In [None]:
!flyte create config \
    --endpoint tryv2.hosted.unionai.cloud \
    --auth-type headless \
    --builder remote \
    --domain development \
    --project flytesnacks

## 👋 Hello Flyte Tasks

A `TaskEnvironment` is the abstraction that defines the hardware and software environment in which one or more tasks run.


In [None]:
import flyte

# Create a task environment with resource specifications
env = flyte.TaskEnvironment(
    name="hello_flyte_simple",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
)


@env.task  # Decorator that defines a Flyte task (a unit of work in an environment)
def process_and_validate(user_input: str) -> dict:
    """Extract features and validate in a single task"""
    clean = user_input.strip().lower()
    
    features = {
        "text": clean,
        "length": len(clean),
        "word_count": len(clean.split()),
        "has_numbers": any(c.isdigit() for c in clean),
        "is_valid": len(clean) > 5 and len(clean.split()) >= 2  # validation criteria
    }
    
    return features

# Tasks can be called within other tasks to create workflows
@env.task 
def analyze_batch_workflow(raw_inputs: list[str]) -> dict:
    """Process all inputs and generate summary statistics"""
    if len(raw_inputs) < 3:
        raise ValueError(f"Need at least 3 samples, got: {len(raw_inputs)}")

    # Use Flyte's map function to process all inputs in parallel
    all_features = list(flyte.map(process_and_validate, raw_inputs))
    
    valid_samples = [f for f in all_features if f["is_valid"]]

    summary = {
        "total_samples": len(all_features),
        "valid_samples": len(valid_samples),
        "avg_length": sum(f["length"] for f in all_features) / len(all_features),
        "ready_for_training": len(valid_samples) >= len(all_features) * 0.7
    }

    print("Processed features:", summary)

    return summary

# Sample data to analyze
sample_reviews = [
    "  This product is amazing! I love it.  ",
    "Great quality and fast shipping", 
    " Bad ",  # Will be invalid (too short)
    "The delivery was delayed by 3 days",
    "Excellent customer service team",
    "Perfect for my home office setup"
]

# Initialize Flyte and run the workflow
flyte.init_from_config(".flyte/config.yaml")
execution = flyte.run(analyze_batch_workflow, raw_inputs=sample_reviews)

print(f"Execution: {execution.name}")
print(f"URL: {execution.url}")
print("Click the link above to view execution details in the Flyte UI 👆")
# Click the signin link to run your Flyte workflow!👇

#### Run tasks locally
Running tasks locally is a great way to debug and iterate quickly without needing to run on a remote cluster.

In [None]:
flyte.init()  # Re-initialize without the config file so tasks run locally
execution = flyte.run(analyze_batch_workflow, raw_inputs=sample_reviews)
print(f"Execution: {execution.name}")
print(f"URL: {execution.url}")

This was a quick example of running a Flyte 2.0 task. The next section shows how to run a full workflow with multiple tasks and dependencies.

Check out the [Flyte 2.0 documentation](https://www.union.ai/docs/v2/flyte/user-guide/) for more examples and details on how to use Flyte.
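
To see what the validation rule from `process_and_validate` does to the sample reviews, you can apply it as a plain function with no Flyte involved. This is a small sketch; `is_valid` is a hypothetical helper that copies the rule from the task above (more than 5 characters and at least 2 words after cleaning).

```python
def is_valid(text: str) -> bool:
    """Same rule as the task: cleaned text has > 5 characters and >= 2 words."""
    clean = text.strip().lower()
    return len(clean) > 5 and len(clean.split()) >= 2


samples = [
    "  This product is amazing! I love it.  ",
    "Great quality and fast shipping",
    " Bad ",  # Too short once whitespace is stripped
    "The delivery was delayed by 3 days",
    "Excellent customer service team",
    "Perfect for my home office setup",
]

valid_count = sum(is_valid(s) for s in samples)
# The workflow marks the batch as ready when at least 70% of samples are valid
ready_for_training = valid_count >= len(samples) * 0.7

print(f"{valid_count} of {len(samples)} valid, ready: {ready_for_training}")
```

Only `" Bad "` fails the rule, so the batch clears the 70% threshold.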

## Build an ML Pipeline (and see more features)

Let's build a simple ML pipeline using sklearn and the iris dataset. We'll use some more advanced features of Flyte like:

- Build custom environments (containers) -> read more in the [docs]() 
- Reusable environments (containers)
- Caching (to speed up workflows)
- Reporting (to visualize results in the UI)
- Flyte Files (to manage data and model artifacts)

This pipeline can be extended to more complex use cases, but this should give you a good starting point.
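
Caching is the least visible of these features, so here is a rough idea of what it buys you. The sketch below is a plain-Python illustration of content-keyed caching and is not Flyte's implementation: it assumes a result can be reused whenever the task's code and inputs are unchanged. `cache_key`, `run_cached`, and the in-memory `_cache` are hypothetical names.

```python
import hashlib
import json
from typing import Any, Callable, Dict, List

_cache: Dict[str, Any] = {}  # Hypothetical in-memory store; a real backend would persist results


def cache_key(func: Callable[..., Any], kwargs: Dict[str, Any]) -> str:
    """Hash the function's bytecode and constants together with its JSON-serialized inputs."""
    code = func.__code__
    payload = code.co_code + repr(code.co_consts).encode() + json.dumps(kwargs, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


def run_cached(func: Callable[..., Any], **kwargs: Any) -> Any:
    """Run `func` only if this (code, inputs) combination has not been seen before."""
    key = cache_key(func, kwargs)
    if key not in _cache:
        _cache[key] = func(**kwargs)
    return _cache[key]


calls: List[int] = []  # Records each real execution so cache hits are visible


def square(x: int) -> int:
    calls.append(x)
    return x * x


first = run_cached(square, x=4)   # Executes the function
second = run_cached(square, x=4)  # Same code and inputs: served from the cache
third = run_cached(square, x=5)   # New input: executes again

print(first, second, third, "executions:", calls)
```

The second call never reaches `square`, which is the effect you want for expensive steps such as downloading or preprocessing a dataset.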

In [None]:
# Import libraries and modules
import io
from textwrap import dedent
from typing import List
import joblib
import matplotlib as mpl
mpl.use("Agg") # Use a non-interactive backend for matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import base64
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

import flyte.report
import flyte
from flyte.io import Dir, File

# Custom environment with scikit-learn installed
env = flyte.TaskEnvironment(
    name="scikit_learn_pipeline",
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    reusable=flyte.ReusePolicy(
        replicas=3,
        idle_ttl=120,
        concurrency=6,
        scaledown_ttl=120,
    ),
    image=flyte.Image.from_debian_base().with_pip_packages("scikit-learn", "pandas",
                                                           "unionai-reuse==0.1.6", "joblib==1.3.2",
                                                           "matplotlib==3.8.3", "seaborn==0.13.2",
                                                           "pyarrow", "fastparquet"),
)

# Helper function to convert a matplotlib figure into an HTML string
def _convert_fig_into_html(fig: mpl.figure.Figure) -> str:
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="png")
    img_base64 = base64.b64encode(img_buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{img_base64}" alt="Rendered Image" />'

@env.task(cache="auto")
async def download_dataset() -> pd.DataFrame:
    iris = load_iris()
    iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
    iris_df['target'] = iris.target
    return iris_df

@env.task(report=True, cache="auto")
async def process_dataset(data_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:

    # Perform the train-test split
    train_df, test_df = train_test_split(data_df,
                                         test_size=0.2,
                                         random_state=42,
                                         stratify=data_df['target'])

    # Seaborn pairplot full dataset
    pairplot = sns.pairplot(data_df, hue="target")

    html_content = _convert_fig_into_html(pairplot.figure)
    plt.close(pairplot.figure)

    await flyte.report.replace.aio(html_content) # Report to Flyte UI
    await flyte.report.flush.aio() # Ensure report is sent

    return train_df, test_df

@env.task
async def train_model(dataset: pd.DataFrame, n_neighbors: int = 3) -> File:
    X_train, y_train = dataset.drop("target", axis="columns"), dataset["target"]
    model = KNeighborsClassifier(n_neighbors=n_neighbors)
    model.fit(X_train, y_train)  # fit() trains the model in place
    out_path = "trained_model.joblib" # Local path to save the model
    joblib.dump(model, out_path) # Save the trained model locally

    return await File.from_local(out_path) # Return an uploaded Flyte File reference


@env.task(report=True)
async def evaluate_model(model_file: File, dataset: pd.DataFrame) -> set[str]:
    local_path = await model_file.download()
    model: KNeighborsClassifier = joblib.load(local_path)

    X_test, y_test = dataset.drop(columns=["target"]), dataset["target"]
    y_pred = model.predict(X_test)

    # Confusion matrix
    fig, ax = plt.subplots(figsize=(4, 4))
    ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=ax)
    html_cm = _convert_fig_into_html(fig)
    plt.close(fig)

    # Text report
    report_txt = classification_report(y_test, y_pred)
    print(dedent(f"""
    Classification Report
    ---------------------
    {report_txt}
    """))
    html_report = f"<pre>{report_txt}</pre>"

    await flyte.report.replace.aio("<h3>Confusion Matrix</h3>" + html_cm + "<h3>Classification Report</h3>" + html_report)
    await flyte.report.flush.aio()

    return {report_txt}


@env.task
async def model_predict(model_file: File, pred_data: List[List[float]]) -> List[int]:
    local_path = await model_file.download()
    model: KNeighborsClassifier = joblib.load(local_path)
    predictions = model.predict(pred_data)
    return predictions.tolist()


@env.task
async def ml_pipeline(n_neighbors: int = 3,
    pred_data: List[List[float]] = [[1.5, 2.3, 1.3, 2.4]]) -> File:
    data_df = await download_dataset()
    train, test = await process_dataset(data_df)
    model = await train_model(dataset=train, n_neighbors=n_neighbors)
    class_report = await evaluate_model(model_file=model, dataset=test)
    print(f"Classification Report: {class_report}")
    pred = await model_predict(model_file=model, pred_data=pred_data)
    print(f"Predictions for {pred_data}: {pred}")
    return model

# Main workflow
if __name__ == "__main__":
    flyte.init_from_config(".flyte/config.yaml")
    # flyte.init() # uncomment to run locally

    # Run the complete pipeline
    execution = flyte.run(ml_pipeline)

    print(f"Execution: {execution.name}")
    print(f"URL: {execution.url}")


#### Outputs & Flyte Remote

Note: if you're on the `tryv2` demo cluster, you won't be able to download stored artifacts directly. On a full deployment or your own cluster, you can download artifacts by passing in your storage authorization credentials.

You can, however, move artifacts out of `tryv2` from within a workflow, for example by sending them to Hugging Face for storage.


In [None]:
# Example coming soon to the notebook

## ⚠️ Error Handling & Dynamic Infrastructure


In [None]:
import asyncio

import flyte
import flyte.errors

env = flyte.TaskEnvironment(
    name="fail",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
)


@env.task
async def oomer(x: int):
    large_list = [0] * 100000000  # ~763 MiB of references on 64-bit CPython, well over the 250Mi limit
    print(len(large_list))


@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42


@env.task
async def failure_recovery() -> int:
    try:
        await oomer(2)
    except flyte.errors.OOMError as e:
        print(f"Failed with oom trying with more resources: {e}, of type {type(e)}, {e.code}")
        try:
            await oomer.override(resources=flyte.Resources(cpu=1, memory="1Gi"))(5)
        except flyte.errors.OOMError as e:
            print(f"Failed with OOM Again giving up: {e}, of type {type(e)}, {e.code}")
            raise
    finally:
        await always_succeeds()

    return await always_succeeds()


if __name__ == "__main__":
    flyte.init_from_config(".flyte/config.yaml")

    run = flyte.run(failure_recovery)
    print(run.url)
    run.wait(run)
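
The recovery pattern in the cell above (run with a small memory limit, catch the out-of-memory error, retry with a larger limit) can be tried without a cluster. This is a simulated sketch: `FakeOOMError`, `run_with_memory`, and `run_with_escalation` are hypothetical stand-ins, and no real memory limit is enforced.

```python
from typing import List, Tuple


class FakeOOMError(Exception):
    """Hypothetical stand-in for an out-of-memory failure."""


def run_with_memory(needed_mib: int, limit_mib: int) -> int:
    """Simulated task: fails if it needs more memory than the limit allows."""
    if needed_mib > limit_mib:
        raise FakeOOMError(f"needed {needed_mib}Mi but limit is {limit_mib}Mi")
    return needed_mib


def run_with_escalation(needed_mib: int, limits_mib: List[int]) -> Tuple[int, int]:
    """Try each limit in order; return (result, limit used) or re-raise the last failure."""
    if not limits_mib:
        raise ValueError("limits_mib must contain at least one limit")
    for attempt, limit in enumerate(limits_mib, start=1):
        try:
            return run_with_memory(needed_mib, limit), limit
        except FakeOOMError as err:
            print(f"Attempt {attempt} failed: {err}")
            if attempt == len(limits_mib):
                raise  # No larger limit left to try, so give up
    raise AssertionError("unreachable")


# A list of 100,000,000 references needs roughly 763 MiB on 64-bit CPython
result, used_limit = run_with_escalation(needed_mib=763, limits_mib=[250, 1024])
print(f"Succeeded with a {used_limit}Mi limit")
```

Keeping the list of limits explicit makes the give-up condition obvious: once the last limit fails, the error propagates to the caller, just as the nested `except` does in the Flyte version.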

## AI Agents & Agentic Workflows

Flyte 2.0's built-in dynamic tasks and workflows make it easy to build agentic workflows that call LLMs and other AI models to make decisions and take actions, with out-of-the-box support for most major agent frameworks and LLM providers.

`Example Coming soon!`

## Create and Manage Secrets
`Example Coming soon!`