# Experiment tracking with MLflow

Load a FiftyOne dataset:

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz

# Fetch CIFAR-10 from the FiftyOne Dataset Zoo
zoo_dataset_name = "cifar10"
dataset = foz.load_zoo_dataset(zoo_dataset_name)

In [None]:
from itertools import chain

import eta.core.utils as etau

from flash.core.classification import FiftyOneLabelsOutput
from flash.image import ImageClassificationData, ImageClassifier
from flash import Trainer

import fiftyone as fo
import fiftyone.utils.splits as fous
import fiftyone.zoo as foz

import fiftyone.utils.mlflow as foum

from pytorch_lightning.loggers import MLFlowLogger


def train_flash_model_with_mlflow(dataset, mlf_logger, pred_field):
    """Finetune a Flash image classifier on ``dataset`` and store predictions.

    The dataset is randomly partitioned into ``train``/``test``/``val``/
    ``pred`` splits via sample tags, a ``resnet18``-backed
    ``ImageClassifier`` is finetuned for a single, batch-limited epoch with
    ``mlf_logger`` attached to the trainer, and the model's predictions on
    the ``pred`` split are written back to the dataset.

    Args:
        dataset: the FiftyOne dataset to split and train on. Labels are read
            from its ``ground_truth`` field. Its sample tags are modified
            in-place
        mlf_logger: the logger instance passed to the Flash ``Trainer``
        pred_field: the name of the sample field in which to store the
            predictions for the ``pred`` split
    """
    splits = {"train": 0.7, "test": 0.1, "val": 0.1, "pred": 0.1}

    # Clear *every* split tag before re-splitting, not just "test". Stale
    # tags (e.g. the "train" tags that zoo datasets ship with, or tags left
    # by a previous call) would otherwise make the `match_tags()` views
    # below overlap and leak evaluation samples into the training split
    dataset.untag_samples(list(splits))

    # Create splits from the dataset
    fous.random_split(dataset, splits)

    train_dataset = dataset.match_tags("train")
    test_dataset = dataset.match_tags("test")
    val_dataset = dataset.match_tags("val")
    predict_dataset = dataset.match_tags("pred")

    # Create the Datamodule
    datamodule = ImageClassificationData.from_fiftyone(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        val_dataset=val_dataset,
        predict_dataset=predict_dataset,
        label_field="ground_truth",
        batch_size=4,
        num_workers=4,
    )

    # Build the model
    model = ImageClassifier(
        backbone="resnet18",
        labels=datamodule.labels,
    )

    # Only a handful of batches are used so that the run finishes quickly
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=5,
        limit_val_batches=5,
        logger=mlf_logger,
    )
    trainer.finetune(model, datamodule=datamodule)

    # Generate predictions as FiftyOne labels
    predictions = trainer.predict(
        model,
        datamodule=datamodule,
        output=FiftyOneLabelsOutput(labels=datamodule.labels),
    )
    predictions = list(chain.from_iterable(predictions))  # flatten batches

    # Map filepaths to predictions
    predictions = {p["filepath"]: p["predictions"] for p in predictions}

    # Add predictions to FiftyOne dataset
    predict_dataset.set_values(
        pred_field, predictions, key_field="filepath",
    )

In [None]:
# Set up the MLFlow logger (the Lightning one, since Flash is used here),
# tracking runs in a local file store
tracking_uri = "file:/tmp/mlruns"
mlf_logger = MLFlowLogger(
    tracking_uri=tracking_uri,
    experiment_name="fiftyone_test",
)

In [None]:
# Train the model; its predictions are stored in this field of the dataset
pred_field = "flash_predictions"
train_flash_model_with_mlflow(
    dataset=dataset,
    mlf_logger=mlf_logger,
    pred_field=pred_field,
)

In [None]:
# Connect FiftyOne to MLFlow. `fields` names the dataset fields that hold the
# ground truth labels and the model's predictions
mlflow_key = "mlflow_run_1"
fields = dict(ground_truth="ground_truth", predictions=pred_field)

foum.connect_flash_mlflogger(
    dataset, mlflow_key, mlf_logger, fields, tracking_uri
)
# OR, given an explicit experiment ID and run ID:
# foum.connect_to_mlflow(
#     dataset, mlflow_key, experiment_id, run_id, fields, tracking_uri
# )

In [None]:
# Launch MLFlow for the run connected to the dataset under `mlflow_key`
# NOTE(review): presumably this opens the MLFlow UI for that run; confirm
# against the `fiftyone.utils.mlflow` implementation
foum.launch_mlflow(dataset, mlflow_key)