9 changes: 9 additions & 0 deletions tinker_cookbook/checkpoint_utils.py
@@ -7,22 +7,27 @@
import tinker

from tinker_cookbook.utils.file_utils import read_jsonl
from tinker_cookbook.utils.trace import get_scope_context, scope

CHECKPOINTS_BASE_NAME = "checkpoints.jsonl"

logger = logging.getLogger(__name__)


@scope
def load_checkpoints_file(log_dir: str) -> list[dict[str, Any]]:
    checkpoint_path = os.path.join(log_dir, CHECKPOINTS_BASE_NAME)
    if not os.path.exists(checkpoint_path):
        logger.info(f"No checkpoints found at {checkpoint_path}")
        return []

    logger.info(f"Reading checkpoints from {checkpoint_path}")
    context = get_scope_context()
    context.attributes["checkpoint_path"] = checkpoint_path
    return read_jsonl(checkpoint_path)


@scope
def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[str, Any] | None:
"""
Get the last checkpoint from the checkpoints.jsonl file in the specified log directory.
Expand All @@ -49,6 +54,7 @@ def get_last_checkpoint(log_dir: str, required_key: str = "state_path") -> dict[
return None


@scope
async def save_checkpoint_async(
    training_client: tinker.TrainingClient,
    name: str,
@@ -72,6 +78,8 @@ async def save_checkpoint_async(

    results = {k: await v.result_async() for k, v in futures.items()}
    paths = {k + "_path": v.path for k, v in results.items()}
    context = get_scope_context()
    context.attributes.update(paths)
    logger.info(f"Saved checkpoints: {paths}")
    full_dict = {"name": name, **loop_state, **paths}
    with open(os.path.join(log_path, "checkpoints.jsonl"), "a") as f:
@@ -80,6 +88,7 @@
    return paths


@scope
def save_checkpoint(
    training_client: tinker.TrainingClient,
    name: str,
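
The trace utilities themselves (@scope, get_scope_context, trace_init) live in tinker_cookbook/utils/trace.py and are not part of this diff. As a rough mental model of the contract the instrumented code above relies on — an illustrative assumption, not the actual implementation; ScopeContext and _emit are made-up names — a contextvars-based sketch could look like this:

# Hypothetical sketch only: the real tinker_cookbook/utils/trace.py is not shown
# in this diff, and ScopeContext/_emit are assumed names for illustration.
import contextvars
import functools
import inspect
import time
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ScopeContext:
    name: str
    attributes: dict[str, Any] = field(default_factory=dict)


_current_scope: contextvars.ContextVar[ScopeContext | None] = contextvars.ContextVar(
    "_current_scope", default=None
)


def get_scope_context() -> ScopeContext:
    ctx = _current_scope.get()
    if ctx is None:
        raise RuntimeError("get_scope_context() called outside a @scope")
    return ctx


def _emit(ctx: ScopeContext, start_s: float, end_s: float) -> None:
    # A real implementation would append one trace event (name, timestamps,
    # attributes, task name) to trace_events.jsonl; this sketch drops it.
    pass


def scope(fn: Callable) -> Callable:
    """Run fn in a named, timed scope; attributes set via get_scope_context() attach to it."""
    if inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            ctx = ScopeContext(name=fn.__qualname__)
            token = _current_scope.set(ctx)
            start = time.monotonic()
            try:
                return await fn(*args, **kwargs)
            finally:
                _emit(ctx, start, time.monotonic())
                _current_scope.reset(token)

        return async_wrapper

    @functools.wraps(fn)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        ctx = ScopeContext(name=fn.__qualname__)
        token = _current_scope.set(ctx)
        start = time.monotonic()
        try:
            return fn(*args, **kwargs)
        finally:
            _emit(ctx, start, time.monotonic())
            _current_scope.reset(token)

    return sync_wrapper
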
44 changes: 41 additions & 3 deletions tinker_cookbook/supervised/train.py
@@ -16,6 +16,7 @@
import chz
import tinker
from tinker.lib.public_interfaces import APIFuture

from tinker_cookbook import checkpoint_utils
from tinker_cookbook.display import colorize_example
from tinker_cookbook.eval.evaluators import (
@@ -31,6 +32,7 @@
from tinker_cookbook.utils import ml_log
from tinker_cookbook.utils.lr_scheduling import compute_schedule_lr_multiplier
from tinker_cookbook.utils.misc_utils import timed
from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init

logger = logging.getLogger(__name__)

@@ -72,6 +74,8 @@ class Config:
    wandb_project: str | None = None
    wandb_name: str | None = None

    enable_trace: bool = False


@dataclass
class SubmittedBatch:
@@ -87,6 +91,7 @@ class SubmittedBatch:
    infrequent_eval_metrics: dict[str, float] | None = None


@scope
async def run_evals(
    evaluators: list[Evaluator],
    training_client: tinker.TrainingClient,
@@ -102,29 +107,42 @@ async def run_evals(
    checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next
    to the same-step training metrics.
    """
    context = get_scope_context()
    context.attributes["step"] = step

    metrics = {}
    sampling_client = None

-   for evaluator in evaluators:
    @scope
    async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:
        context = get_scope_context()
        context.attributes["step"] = step
        context.attributes["evaluator_name"] = type(evaluator).__name__
        if isinstance(evaluator, TrainingClientEvaluator):
-           eval_metrics = await evaluator(training_client)
            context.attributes["evaluator_type"] = "TrainingClientEvaluator"
            return await evaluator(training_client)
        elif isinstance(evaluator, SamplingClientEvaluator):
            context.attributes["evaluator_type"] = "SamplingClientEvaluator"
            # Create sampling client lazily, only when needed
            nonlocal sampling_client
            if sampling_client is None:
                # Snapshot the current pre-step weights and create a new sampling client.
                sampling_client = await training_client.save_weights_and_get_sampling_client_async(
                    f"evals_step_{step}"
                )
-           eval_metrics = await evaluator(sampling_client)
            return await evaluator(sampling_client)
        else:
            raise ValueError(f"Unknown evaluator type: {type(evaluator)}")

    for evaluator in evaluators:
        eval_metrics = await run_evaluator(evaluator)
        # Add test/ prefix to all metrics
        metrics.update({f"test/{k}": v for k, v in eval_metrics.items()})

    return metrics


@scope
async def main(config: Config):
"""Run the standard supervised learning loop used by the supervised recipes.

@@ -156,6 +174,18 @@ async def main(config: Config):
        config=config,
        do_configure_logging_module=True,
    )
    if config.enable_trace:
        # Get and rename the current (main) task
        current_task = asyncio.current_task()
        if current_task is not None:
            current_task.set_name("main")
        trace_events_path = os.path.join(config.log_path, "trace_events.jsonl")
        logger.info(f"Tracing is enabled. Trace events will be saved to {trace_events_path}")
        logger.info(
            f"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/"
        )
        trace_init(output_file=os.path.join(config.log_path, "trace_events.jsonl"))

    service_client = tinker.ServiceClient(base_url=config.base_url)
    load_state_path: str | None = (
        resume_info["state_path"] if resume_info else config.load_checkpoint_path
f"Training for {n_batches} batches x {config.num_epochs} epochs = {n_batches * config.num_epochs} steps"
)

    @scope
    async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
        step = epoch_idx * n_batches + batch_idx
        context = get_scope_context()
        context.attributes["step"] = step
Review comment (Collaborator):
nit: could be nice to have a helper to do
    context.set_attribute('step', step)
or even
    set_scope_attribute('step', step)
which does the get_scope_context() internally
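
The suggested helper could look like the sketch below; set_scope_attribute is hypothetical and not part of this PR, and it assumes the attributes dict shown on the scope context above.

# Hypothetical helper from the review suggestion above; not part of this PR.
from typing import Any

from tinker_cookbook.utils.trace import get_scope_context


def set_scope_attribute(key: str, value: Any) -> None:
    """Attach an attribute to the current trace scope, hiding get_scope_context() from call sites."""
    context = get_scope_context()
    context.attributes[key] = value


# Call sites such as submit_batch could then read:
#     set_scope_attribute("step", step)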


        batch_start_time = time.time()
        metrics: dict[str, int | float | str] = {"epoch": epoch_idx}
        metrics["progress"] = step / progress_denominator
@@ -250,7 +284,11 @@ async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:
            infrequent_eval_metrics=infrequent_eval_metrics,
        )

    @scope
    async def finish_batch(submitted: SubmittedBatch):
        context = get_scope_context()
        context.attributes["step"] = submitted.step

        metrics = submitted.metrics
        metrics["progress"] = min((submitted.step + 1) / progress_denominator, 1.0)

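
After a run with enable_trace=True, the events file can be converted with the command logged in main (python tinker_cookbook/utils/trace.py <log_dir>/trace_events.jsonl trace.json) and opened in chrome://tracing or https://ui.perfetto.dev/. For a quick spot check of the raw JSONL without converting, something like the sketch below works; the "name" field is an assumption, since the event schema is defined by trace_init rather than by this diff.

# Quick, hypothetical spot check of the raw trace output. The exact event schema
# comes from trace_init; the "name" field used below is an assumed example.
import json
from collections import Counter

log_dir = "/tmp/my_run"  # wherever config.log_path pointed
with open(f"{log_dir}/trace_events.jsonl") as f:
    events = [json.loads(line) for line in f if line.strip()]

print(f"{len(events)} trace events")
print(Counter(event.get("name", "<unknown>") for event in events).most_common(10))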