In [11]:
from typing import List
from regex import Regex
import wandb
import datetime as dt
import re

def get_all_runs(
    project: str = "opentensor-dev/pretraining-subnet",
    timeout: int = 100,
    per_page: int = 1000,
) -> "wandb.apis.public.Runs":
    """Fetch the paginated collection of runs for a W&B project.

    By default, runs are sorted in descending order by creation time.
    The server-side regex matching is quite poor, so no filter is applied here;
    callers match on run names themselves.

    Args:
        project: "<entity>/<project>" path to query.
        timeout: Timeout (seconds) passed to ``wandb.Api``.
        per_page: Page size used when paginating through the runs.

    Returns:
        The object returned by ``wandb.Api().runs`` (supports indexing/slicing
        and iteration, as used by the caller below).
    """
    api = wandb.Api(timeout=timeout)
    return api.runs(project, per_page=per_page)
    
regex = r'validator-([0-9]{1,3})-2024-.*'

# Keep only validator runs (named like "validator-<uid>-2024-...") among the
# first 5000 runs returned by get_all_runs().
candidate_runs = get_all_runs()[:5000]
filtered_runs = []
for candidate in candidate_runs:
    if re.match(regex, candidate.name):
        filtered_runs.append(candidate)

print(len(filtered_runs))

804


1705106401.4615083


In [38]:
from collections import defaultdict
import dataclasses
from re import S
import sys
from rich.table import Table
from rich.console import Console
import json


# Nanoseconds per unit suffix accepted in duration strings such as "1.23 ms".
# Microseconds are accepted as the Greek small letter mu (U+03BC), the visually
# identical micro sign (U+00B5), and the ASCII fallback "us".
_NANOS_PER_UNIT = {
    "ns": 1,
    "\u03bcs": 1_000,
    "\u00b5s": 1_000,
    "us": 1_000,
    "ms": 1_000_000,
    "s": 1_000_000_000,
    "min": 60 * 1_000_000_000,
}
_NANOS_PER_SECOND = 1_000_000_000


def parse_to_nanos(duration: str) -> int:
    """Convert a duration string like ``"1.23 ms"`` or ``"2.123 min"`` to nanoseconds.

    Args:
        duration: A number and a unit separated by whitespace. Supported units
            are the keys of ``_NANOS_PER_UNIT``.

    Returns:
        The duration rounded to the nearest whole nanosecond.

    Raises:
        ValueError: If the string is not "<number> <unit>" or the unit is unknown.
    """
    tokens = duration.split()
    if len(tokens) != 2:
        raise ValueError(f"Unexpected duration format: {duration}")
    value_str, unit = tokens
    if unit not in _NANOS_PER_UNIT:
        raise ValueError(f"Unknown duration unit {unit!r} in: {duration}")
    # round() rather than int(): float products such as 1.005 * 1000 evaluate to
    # 1004.999..., which int() would truncate to 1004.
    return round(float(value_str) * _NANOS_PER_UNIT[unit])


@dataclasses.dataclass(frozen=True)
class stats:
    """Distribution summary of one perf log; all durations are in seconds."""

    samples: int  # number of measurements (the "N=" field)
    min: float
    median: float
    max: float
    p90: float


# Alternation over every accepted unit, longest first so a shorter unit can
# never pre-empt a longer one that shares its prefix.
_UNIT_PATTERN = "(?:" + "|".join(
    re.escape(u) for u in sorted(_NANOS_PER_UNIT, key=len, reverse=True)
) + ")"
_DURATION_GROUP = rf"([0-9.]+ {_UNIT_PATTERN})"
_SUMMARY_PATTERN = re.compile(
    rf".*N=([0-9]+) \| Min={_DURATION_GROUP} \| Max={_DURATION_GROUP}"
    rf" \| Median={_DURATION_GROUP} \| P90={_DURATION_GROUP}"
)


def parse_summary_str(summary: str) -> stats:
    """Parse a perf summary line into a ``stats`` with durations in seconds.

    Expects text containing
    ``N=<int> | Min=<dur> | Max=<dur> | Median=<dur> | P90=<dur>``, e.g.
    ``"Eval: Load model performance: N=30 | Min=5.11 s | Max=6.33 s | Median=5.52 s | P90=6.06 s"``.

    Raises:
        ValueError: If ``summary`` does not contain a complete summary
            (e.g. a truncated line whose last duration lacks a unit).
    """
    match = _SUMMARY_PATTERN.match(summary)
    if match is None:
        raise ValueError(f"Unrecognized perf summary: {summary!r}")
    samples, min_dur, max_dur, median_dur, p90_dur = match.groups()

    def to_seconds(duration: str) -> float:
        return parse_to_nanos(duration) / _NANOS_PER_SECOND

    return stats(
        samples=int(samples),
        min=to_seconds(min_dur),
        max=to_seconds(max_dur),
        median=to_seconds(median_dur),
        p90=to_seconds(p90_dur),
    )

regex = r'validator-([0-9]{1,3})-2024-.*'

# Map of uid to the timestamp of the most recent run (with perf logs) seen so far.
most_recent_runs = defaultdict(lambda: float("-inf"))
stats_by_uid = {}

for run in filtered_runs:
    uid = int(re.match(regex, run.name).group(1))
    load_log = run.summary.get("load_model_perf_log")
    compute_log = run.summary.get("compute_model_perf_log")
    raw_format = run.summary.get("original_format_json")
    # Skip runs that never logged both perf summaries, or that lack the
    # metadata needed to order them against other runs of the same uid.
    if not (load_log and compute_log and raw_format):
        continue
    # Read this run's own timestamp (previously every run read filtered_runs[0]'s).
    # Presumably a Unix epoch; only its ordering is used here.
    timestamp = json.loads(raw_format)["timestamp"]
    if timestamp > most_recent_runs[uid]:
        stats_by_uid[uid] = (parse_summary_str(load_log), parse_summary_str(compute_log))
        most_recent_runs[uid] = timestamp

table = Table(title="Perf stats")
table.add_column("uid", justify="right", style="cyan", no_wrap=True)
# Durations are in seconds. The "median" columns show the median reported by the
# perf log (they were previously mislabeled as "avg").
for column_name in (
    "samples",
    "load_model_median",
    "load_model_max",
    "eval_model_median",
    "eval_model_max",
):
    table.add_column(column_name, style="magenta")

# Keys of stats_by_uid are already ints, so sort them directly.
uids = sorted(stats_by_uid)
for uid in uids:
    load_stats, eval_stats = stats_by_uid[uid]
    table.add_row(
        str(uid),
        str(load_stats.samples),
        str(load_stats.median),
        str(load_stats.max),
        str(eval_stats.median),
        str(eval_stats.max),
    )

console = Console()
console.print(table)

Eval: Load model performance: N=30 | Min=5.11 s | Max=6.33 s | Median=5.52 s | P90=6.06 s
Eval: Compute loss performance: N=30 | Min=2.28 s | Max=2.81 s | Median=2.50 s | P90=2.68 s
Eval: Load model performance: N=30 | Min=1.40 s | Max=3.27 s | Median=1.56 s | P90=2.59 s
Eval: Compute loss performance: N=30 | Min=3.75 s | Max=5.90 s | Median=3.82 s | P90=4.97 s
Eval: Load model performance: N=30 | Min=902.12 ms | Max=1.08 s | Median=1.05 s | P90=1.08 s
Eval: Compute loss performance: N=30 | Min=1.60 s | Max=1.67 s | Median=1.65 s | P90=1.66 s
Eval: Load model performance: N=31 | Min=1.20 s | Max=4.82 s | Median=2.16 s | P90=4.52 s
Eval: Compute loss performance: N=31 | Min=1.77 s | Max=1.99 s | Median=1.85 s | P90=1.94 s
Eval: Load model performance: N=30 | Min=1.11 s | Max=2.23 s | Median=1.19 s | P90=1.33 s
Eval: Compute loss performance: N=30 | Min=2.24 s | Max=2.65 s | Median=2.26 s | P90=2.59 s
Eval: Load model performance: N=31 | Min=1.65 s | Max=2.95 s | Median=2.75 s | P90=2.85

In [31]:
# Scratch check: run the perf-summary regex against one captured log line and
# show the captured (N, Min, Max, Median, P90) groups.
unit = r'(?:ns|μs|ms|s|min)'
matcher = rf'.*N=([0-9]+) \| Min=([0-9\.]+ {unit}) \| Max=([0-9\.]+ {unit}) \| Median=([0-9\.]+ {unit}) \| P90=([0-9\.]+ {unit})'
summary ="Eval: Load model performance: N=30 | Min=5.11 s | Max=6.33 s | Median=5.52 s | P90=6.06 s"
compiled_matcher = re.compile(matcher)
summary_match = compiled_matcher.match(summary)
groups = summary_match.groups()
print(groups)


('30', '5.11 s', '6.33 s', '5.52 s', '6.06 s')
