# Load and Split Data Flow

In [4]:
from pathlib import Path

from metaflow import FlowSpec, Parameter, step, NBRunner
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


class DataLoaderFlow(FlowSpec):
    """Load the Iris dataset, split it into train/test sets and save them as CSV.

    Steps run linearly: ``start`` -> ``load`` -> ``split`` -> ``persist`` -> ``end``.

    Parameters:
        test_size: Ratio of the data placed in the test set (default ``0.2``).
        random_state: Seed passed to ``train_test_split`` (default ``42``).
        dataset_dir: Directory where ``train.csv`` and ``test.csv`` are written
            (default ``"./data"``).
    """

    test_size = Parameter(
        name="test_size",
        help="Ratio of test set",
        default=0.2
    )
    random_state = Parameter(
        name="random_state",
        help="Random seed",
        default=42
    )
    dataset_dir = Parameter(
        name="dataset_dir",
        help="Path to save the splits",
        default="./data"
    )

    @step
    def start(self):
        """Resolve the output directory to an absolute path and announce the run."""
        # Resolved once here and stored as an artifact so later steps write to
        # (and report) the same absolute location.
        self.dataset_path = Path(self.dataset_dir).resolve()
        print("Running data loading and splitting flow...")
        self.next(self.load)

    @step
    def load(self):
        """Load the Iris dataset and keep it as the ``dataset`` artifact."""
        self.dataset = load_iris()
        self.next(self.split)

    @step
    def split(self):
        """Create stratified train/test DataFrames with a ``target`` column."""
        X = self.dataset["data"]
        y = self.dataset["target"]
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=self.test_size,
            random_state=self.random_state,
            # Stratify on the labels so both splits keep the class proportions.
            stratify=y,
        )
        self.train = pd.DataFrame(X_train, columns=self.dataset["feature_names"])
        self.train["target"] = y_train
        self.test = pd.DataFrame(X_test, columns=self.dataset["feature_names"])
        self.test["target"] = y_test
        self.next(self.persist)

    @step
    def persist(self):
        """Write the splits to ``train.csv`` and ``test.csv`` in the output directory."""
        self.dataset_path.mkdir(parents=True, exist_ok=True)
        self.train.to_csv(self.dataset_path / "train.csv", index=False)
        self.test.to_csv(self.dataset_path / "test.csv", index=False)
        self.next(self.end)

    @step
    def end(self):
        """Report where the splits were saved and the shape of their feature columns."""
        # Print the resolved path that `persist` actually wrote to; the raw
        # `dataset_dir` parameter may be relative (e.g. "./data") and does not
        # show where the files ended up.
        print(f"Split dataset saved to {self.dataset_path}")
        # Shapes cover the feature columns only: the saved CSVs also contain
        # the "target" column, which is dropped here before measuring.
        print(f"Shape of saved train set {self.train.drop(columns='target').shape}")
        print(f"Shape of saved test set {self.test.drop(columns='target').shape}")


# Execute the flow from the notebook, using "./artifacts" as the runner's
# base directory, and keep the returned value in `run` for later cells.
run = NBRunner(
    DataLoaderFlow,
    base_dir="./artifacts",
).nbrun()

Metaflow 2.13.9 executing DataLoaderFlow for user:jeera
Validating your flow...
    The graph looks good!
Running pylint...
    Pylint is happy!
2025-02-23 00:05:03.630 Workflow starting (run-id 1740243903630317):
2025-02-23 00:05:03.641 [1740243903630317/start/1 (pid 86826)] Task is starting.
2025-02-23 00:05:04.290 [1740243903630317/start/1 (pid 86826)] Running data loading and splitting flow...
2025-02-23 00:05:04.380 [1740243903630317/start/1 (pid 86826)] Task finished successfully.
2025-02-23 00:05:04.386 [1740243903630317/load/2 (pid 86829)] Task is starting.
2025-02-23 00:05:05.069 [1740243903630317/load/2 (pid 86829)] Task finished successfully.
2025-02-23 00:05:05.075 [1740243903630317/split/3 (pid 86832)] Task is starting.
2025-02-23 00:05:05.763 [1740243903630317/split/3 (pid 86832)] Task finished successfully.
2025-02-23 00:05:05.769 [1740243903630317/persist/4 (pid 86835)] Task is starting.
2025-02-23 00:05:06.461 [1740243903630317/persist/4 (pid 86835)] Task finished succ