In [15]:
import time
import weave
from weave.ops_domain import RunSegment
from weave import storage, publish, type_of, use_frontend_devmode
from weave.weave_types import List
import typing
import time
import sys
import numpy as np
from weave.ops import to_arrow

import logging

use_frontend_devmode()
logger = logging.getLogger("run_segment")
# Only attach the stdout handler once: re-executing this cell in the same
# kernel would otherwise add another handler and duplicate every log line.
if not logger.handlers:
    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    logger.addHandler(handler)

# set to logging.INFO for more verbose profiling
logger.setLevel(logging.ERROR)

# serializer = publish   # uses W&B artifacts instead of local artifacts
serializer = storage.save

N_NUMERIC_METRICS = 99  # number of numerical columns in the metrics table


def random_metrics(n: int = 10, starting_step: int = 0, delta_step: int = 1):
    """Create an arrow table of ``n`` rows of random metrics.

    Row ``i`` has ``step = starting_step + i * delta_step``, a ``string_col``
    holding one random uppercase letter, and ``N_NUMERIC_METRICS`` float
    columns ``metric0``, ``metric1``, ... drawn with ``np.random.random()``
    (uniform on [0, 1)).

    Parameters
    ----------
    n: int, default 10
        Number of rows to generate; must be at least 1.
    starting_step: int, default 0
        Step value of the first row; must be at least 0.
    delta_step: int, default 1
        Increment between consecutive rows' steps; must be at least 1.

    Returns
    -------
    The rows converted with ``to_arrow`` using the type ``List(type_of(row0))``.

    Raises
    ------
    ValueError
        If ``n``, ``starting_step`` or ``delta_step`` is out of range.
    """
    if n <= 0:
        raise ValueError("n must be at least 1")
    if starting_step < 0:
        raise ValueError("starting_step must be at least 0")
    if delta_step < 1:
        raise ValueError("delta_step must be an integer greater than or equal to 1.")

    # Loop-invariant: build the alphabet once rather than once per row.
    letters = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    n_columns = N_NUMERIC_METRICS + 2  # step + string_col + numeric metrics

    t_start = time.time()
    raw = [
        {
            "step": starting_step + i * delta_step,
            "string_col": np.random.choice(letters),
            **{f"metric{j}": np.random.random() for j in range(N_NUMERIC_METRICS)},
        }
        for i in range(n)
    ]
    t_stop = time.time()
    # %-style arguments defer formatting until a record is actually emitted,
    # which matters because the logger is normally set to ERROR.
    logger.info(
        "Created raw python metrics table of length %d, width %d in %.2f sec.",
        n,
        n_columns,
        t_stop - t_start,
    )

    t_start = time.time()
    wb_type = List(type_of(raw[0]))
    arrow_form = to_arrow(raw, wb_type=wb_type)
    t_stop = time.time()
    logger.info(
        "Created arrow version of metrics table of length %d, width %d in %.2f sec.",
        n,
        n_columns,
        t_stop - t_start,
    )
    return arrow_form


def create_branch(
    name: str,
    previous_segment: typing.Optional[RunSegment] = None,
    length=10,
    previous_segment_branch_frac=0.8,
) -> RunSegment:
    """Create a new segment and optionally attach it to a previous segment.

    Parameters
    ----------
    name: str
       The name of the segment.
    previous_segment: Optional[RunSegment], default None.
       The parent run segment. If this is a root run segment, use None.
    length: int, default = 10
       The number of history rows to generate for the segment.
    previous_segment_branch_frac: float satisfying 0 < branch_frac <= 1.
       Fraction of the previous segment's metric rows that the new segment
       inherits; the branch point is the last inherited row. A value of 1 sets
       the branch point at the end of the previous segment, and 0.5 includes
       half of the previous segment's metric rows. The fraction must select at
       least one row, i.e. ``frac * len(previous_segment.metrics) >= 1``.

    Returns
    -------
    segment: RunSegment
        The new segment.

    Raises
    ------
    ValueError
        If ``previous_segment_branch_frac`` is outside (0, 1], if ``length`` is
        not positive, or if the fraction selects no rows of the previous segment.
    """
    if not (0 < previous_segment_branch_frac <= 1):
        raise ValueError("branch_frac must satisfy 0 < branch_frac <= 1")

    if length <= 0:
        raise ValueError("Length must be greater than 0.")

    if previous_segment is not None:
        previous_metrics = previous_segment.metrics
        n_previous_metrics = len(previous_metrics)
        if n_previous_metrics > 0:
            # 0-based index of the last previous-segment row the new segment inherits.
            previous_segment_branch_index = (
                int(previous_segment_branch_frac * n_previous_metrics) - 1
            )

            # frac * n < 1 would select no rows of the previous segment at all.
            if previous_segment_branch_index < 0:
                raise ValueError(
                    f"Invalid branch point on RunSegment: previous_segment_branch_index "
                    f"{previous_segment_branch_index} must be between 0 and {len(previous_metrics) - 1}"
                )

            # Assumes steps increase by 1 per row, which holds for segments built
            # here (random_metrics is called with its default delta_step=1).
            previous_segment_branch_step = (
                previous_metrics._index(0)["step"] + previous_segment_branch_index
            )

            # Persist the parent through the module-level `serializer` rather than
            # a hard-coded storage.save, so the serializer toggle is honored.
            ref = serializer(previous_segment)
            new_metrics = random_metrics(
                n=length, starting_step=previous_segment_branch_step + 1
            )

            return RunSegment(name, ref.uri, previous_segment_branch_index, new_metrics)
        # NOTE(review): a previous segment with no metric rows falls through and
        # the new segment is created as an unattached root — confirm intended.
    return RunSegment(name, None, 0, random_metrics(length, 0))


def create_experiment(
    num_steps: int, num_runs: int, branch_frac: float = 0.8
) -> typing.Optional[RunSegment]:
    """Build a chain of ``num_runs`` segments, each branching off the previous one.

    Each segment receives ``num_steps // num_runs`` freshly generated metric
    rows (any remainder of the division is dropped). Every segment after the
    first is created with the previous segment as its parent, branching at
    ``branch_frac``.

    Parameters
    ----------
    num_steps: int
        Total number of new metric rows to generate across all segments.
    num_runs: int
        Number of segments to create; must be at least 1.
    branch_frac: float, default 0.8
        Passed to ``create_branch`` as ``previous_segment_branch_frac``.

    Returns
    -------
    The last segment created.

    Raises
    ------
    ValueError
        If ``num_runs < 1`` or ``num_steps < num_runs`` (each run needs at
        least one step), or anything ``create_branch`` rejects.
    """
    # Previously num_runs == 0 raised ZeroDivisionError and a negative value
    # silently returned None; reject both explicitly.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}.")
    num_steps_per_run = num_steps // num_runs
    if num_steps_per_run < 1:
        raise ValueError(
            f"num_steps ({num_steps}) must be at least num_runs ({num_runs}) "
            "so that each run gets at least one step."
        )

    segment = None
    for i in range(num_runs):
        segment = create_branch(
            f"branch {i}",
            segment,
            length=num_steps_per_run,
            previous_segment_branch_frac=branch_frac,
        )
    return segment

In [16]:
# 500,000 generated metric rows split over 100 chained segments (5,000 each),
# each segment branching at 80% of its predecessor's rows.
last_segment = create_experiment(num_steps=500_000, num_runs=100, branch_frac=0.8)

In [17]:
# Display the final run segment returned by create_experiment.
weave.show(last_segment)

In [10]:
# Display the result of last_segment.experiment() (presumably the segment's
# full branch history — confirm against RunSegment).
# NOTE(review): this cell's execution count (In [10]) is lower than the cells
# above it, so in the saved session it ran out of order; re-run top to bottom.
weave.show(last_segment.experiment())