# Load and Split Data Flow

In [11]:
from pathlib import Path

from metaflow import FlowSpec, Parameter, step, NBRunner
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


class DataLoaderFlow(FlowSpec):
    """Load the iris dataset, split it into train/test sets and save both as CSV.

    Steps run linearly: ``start`` -> ``load`` -> ``split`` -> ``persist`` -> ``end``.

    Parameters:
        test_size: Value forwarded to ``train_test_split(test_size=...)``.
        random_state: Value forwarded to ``train_test_split(random_state=...)``.
        dataset_dir: Directory in which ``train.csv`` and ``test.csv`` are written.
    """

    test_size = Parameter(
        name="test_size",
        help="Ratio of test set",
        default=0.2,
    )
    random_state = Parameter(
        name="random_state",
        help="Random seed",
        default=42,
    )
    dataset_dir = Parameter(
        name="dataset_dir",
        help="Path to save the splits",
        default="./data",
    )

    @step
    def start(self):
        """Resolve the output directory once and announce the run."""
        # Resolve here so later steps write to (and report) one absolute path
        # instead of re-interpreting a possibly relative parameter.
        self.dataset_path = Path(self.dataset_dir).resolve()
        print("Running data loading and splitting flow...")
        self.next(self.load)

    @step
    def load(self):
        """Fetch the iris dataset and keep it as a flow artifact."""
        self.dataset = load_iris()
        self.next(self.split)

    @step
    def split(self):
        """Create a stratified train/test split and wrap each part in a DataFrame.

        Each resulting frame has one column per entry of
        ``dataset["feature_names"]`` plus a trailing ``target`` column.
        """
        data = self.dataset["data"]
        target = self.dataset["target"]
        feature_names = self.dataset["feature_names"]
        train_data, test_data, train_target, test_target = train_test_split(
            data,
            target,
            test_size=self.test_size,
            random_state=self.random_state,
            stratify=target,
        )
        self.train = pd.DataFrame(train_data, columns=feature_names)
        self.train["target"] = train_target
        self.test = pd.DataFrame(test_data, columns=feature_names)
        self.test["target"] = test_target
        self.next(self.persist)

    @step
    def persist(self):
        """Write the train and test frames to ``train.csv`` / ``test.csv``."""
        self.dataset_path.mkdir(parents=True, exist_ok=True)
        self.train.to_csv(self.dataset_path / "train.csv", index=False)
        self.test.to_csv(self.dataset_path / "test.csv", index=False)
        self.next(self.end)

    @step
    def end(self):
        """Report where the splits were written and their feature shapes."""
        # Report the resolved path that `persist` actually wrote to; the raw
        # `dataset_dir` parameter may be relative (default "./data") and so
        # does not by itself identify where the files landed.
        print(f"Split dataset saved to {self.dataset_path}")
        # Shapes are printed without the `target` column, i.e. features only.
        print(f"Shape of saved train set {self.train.drop(columns='target').shape}")
        print(f"Shape of saved test set {self.test.drop(columns='target').shape}")


# Launch the flow from within the notebook, using ./artifacts as the
# runner's base directory, and keep the returned handle in `run`.
run = NBRunner(
    DataLoaderFlow,
    base_dir="./artifacts",
).nbrun()

Metaflow 2.13.9 executing DataLoaderFlow for user:jeera
Validating your flow...
    The graph looks good!
Running pylint...
    Pylint is happy!
2025-02-23 01:48:11.249 Workflow starting (run-id 1740250091248933):
2025-02-23 01:48:11.261 [1740250091248933/start/1 (pid 87457)] Task is starting.
2025-02-23 01:48:11.894 [1740250091248933/start/1 (pid 87457)] Running data loading and splitting flow...
2025-02-23 01:48:11.977 [1740250091248933/start/1 (pid 87457)] Task finished successfully.
2025-02-23 01:48:11.983 [1740250091248933/load/2 (pid 87460)] Task is starting.
2025-02-23 01:48:12.675 [1740250091248933/load/2 (pid 87460)] Task finished successfully.
2025-02-23 01:48:12.681 [1740250091248933/split/3 (pid 87463)] Task is starting.
2025-02-23 01:48:13.388 [1740250091248933/split/3 (pid 87463)] Task finished successfully.
2025-02-23 01:48:13.394 [1740250091248933/persist/4 (pid 87466)] Task is starting.
2025-02-23 01:48:14.071 [1740250091248933/persist/4 (pid 87466)] Task finished succ