3 changes: 3 additions & 0 deletions .gitignore
@@ -12,3 +12,6 @@ run_hpo.sh
*.pose
*.safetensors
.cache/
dist/
wandb/
*.log
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -54,6 +54,10 @@ server = [
hf = [
"huggingface_hub>=0.20.0",
]
publish = [
"sign-language-segmentation[hf]",
"sign-language-segmentation[train]",
]

[tool.ruff]
line-length = 120
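The new publish extra composes the existing hf and train extras through self-referential extras, so installing it pulls in both dependency sets at once. With the uv workflow this repository already uses (see the module docstring below), that would look like:

    uv sync --extra publish

or, to invoke the CLI in one shot, uv run --extra publish python -m sign_language_segmentation.publish.publish ...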
1 change: 1 addition & 0 deletions sign_language_segmentation/publish/__init__.py
@@ -0,0 +1 @@
from sign_language_segmentation.publish.publish import main, publish
192 changes: 192 additions & 0 deletions sign_language_segmentation/publish/publish.py
@@ -0,0 +1,192 @@
"""Publish a model checkpoint to HuggingFace Hub.

Converts a PyTorch Lightning .ckpt to safetensors, optionally evaluates it,
runs regression checks against the current production model, generates a
model card, and pushes to the 'weekly' branch on HuggingFace Hub.
A date-based tag (vYYYY.MM.DD) is created only on promotion, which happens
unless the regression check fails or --no-promote is passed; a standalone
--promote run tags the current weekly branch without re-uploading.

Usage:
uv run --extra publish python -m sign_language_segmentation.publish.publish \
--checkpoint path/to/best.ckpt --repo org/model-name
uv run --extra publish python -m sign_language_segmentation.publish.publish \
--repo org/model-name --promote
"""

import argparse
import json
import tempfile
from pathlib import Path

from sign_language_segmentation.publish.utils import (
convert_to_safetensors,
find_split_manifest,
run_evaluation,
check_regression,
generate_model_card,
promote,
get_next_version,
)


def publish(
checkpoint: str,
repo_id: str,
tag: str,
datasets: str,
corpus: str,
poses: str,
device: str,
skip_eval: bool,
metrics_json: str | None,
regression_threshold: float,
no_promote: bool,
) -> None:
"""Main publish workflow."""
from huggingface_hub import HfApi

api = HfApi()

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)

# 1. convert to safetensors
print(f"Converting {checkpoint} to safetensors...")
config = convert_to_safetensors(checkpoint_path=checkpoint, output_dir=tmp_path)
print(f" model.safetensors + config.json written to {tmp_path}")

# 2. load split manifest (needed for eval quality_percentile)
manifest = find_split_manifest(checkpoint_path=checkpoint)

# 3. evaluation
eval_results = None
if not skip_eval:
if metrics_json:
print(f"Loading pre-computed metrics from {metrics_json}")
with open(metrics_json) as f:
eval_results = json.load(f)
else:
print(f"Evaluating on {datasets} dev+test sets...")
eval_results = run_evaluation(
checkpoint_path=checkpoint,
datasets=datasets,
corpus=corpus,
poses=poses,
device=device,
split_manifest=manifest,
)

# save eval results
if eval_results:
with open(tmp_path / "eval_results.json", "w") as f:
json.dump(eval_results, f, indent=2)

# 4. regression check
regression_status = "skipped"
if eval_results and not skip_eval:
regression_status, _ = check_regression(
new_metrics=eval_results,
repo_id=repo_id,
threshold=regression_threshold,
)

# 5. save split manifest
if manifest:
with open(tmp_path / "split_manifest.json", "w") as f:
json.dump(manifest, f, indent=2)

# 6. generate model card
model_card = generate_model_card(
config=config,
eval_results=eval_results,
regression_status=regression_status,
tag=tag,
repo_id=repo_id,
split_manifest=manifest,
)
with open(tmp_path / "README.md", "w") as f:
f.write(model_card)

# 7. push to weekly branch
print(f"Pushing to {repo_id} branch 'weekly'...")
api.create_repo(repo_id=repo_id, exist_ok=True, repo_type="model")
api.create_branch(repo_id=repo_id, branch="weekly", exist_ok=True)
api.upload_folder(
folder_path=str(tmp_path),
repo_id=repo_id,
revision="weekly",
commit_message=f"publish {tag} (regression: {regression_status})",
)

# 8. promote if regression passed — creates a date tag on the weekly branch
if regression_status == "fail":
print("NOT promoting — regression check failed")
elif no_promote:
print("Skipping promotion (--no-promote)")
else:
promote(repo_id=repo_id, tag=tag, revision="weekly")

print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Publish a model checkpoint to HuggingFace Hub",
)
parser.add_argument("--checkpoint", type=str, help="path to .ckpt checkpoint to publish")
parser.add_argument("--repo", type=str, required=True, help="HuggingFace repo ID (e.g. org/model-name)")
parser.add_argument(
"--tag", type=str, default=None, help="version tag (default: vYYYY.MM.DD based on today's date)"
)

# evaluation
parser.add_argument("--datasets", type=str, default="dgs", help="comma-separated dataset names for evaluation")
parser.add_argument("--corpus", type=str, default="/mnt/nas/GCS/sign-external-datasets/dgs-corpus")
parser.add_argument("--poses", type=str, default="/mnt/nas/GCS/sign-mediapipe-holistic-poses")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--skip-eval", action="store_true", help="skip evaluation and regression check")
parser.add_argument(
"--metrics-json", type=str, help="path to pre-computed metrics JSON (alternative to running eval)"
)

# regression / promotion
parser.add_argument(
"--regression-threshold",
type=float,
default=0.005,
help="IoU drop tolerance for regression check (default: 0.005)",
)
parser.add_argument("--no-promote", action="store_true", help="push without tagging or promoting")
parser.add_argument("--promote", action="store_true", help="tag the current weekly branch (no upload)")

args = parser.parse_args()

# resolve version tag
if args.tag is None:
args.tag = get_next_version(repo_id=args.repo)
print(f"Version: {args.tag}")

# standalone promote mode
if args.promote:
promote(repo_id=args.repo, tag=args.tag, revision="weekly")
return

if not args.checkpoint:
parser.error("--checkpoint is required (unless using --promote)")

publish(
checkpoint=args.checkpoint,
repo_id=args.repo,
tag=args.tag,
datasets=args.datasets,
corpus=args.corpus,
poses=args.poses,
device=args.device,
skip_eval=args.skip_eval,
metrics_json=args.metrics_json,
regression_threshold=args.regression_threshold,
no_promote=args.no_promote,
)


if __name__ == "__main__":
main()
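Neither check_regression nor promote appears in this diff (they live in sign_language_segmentation/publish/utils.py). As a reading aid, here is a minimal sketch of the contracts the code and tests imply: the signatures, return shapes, and the "Could not resolve revision" error text come from the call sites and tests, while the function bodies, the eval_results.json baseline lookup, and the tag-sorting logic are assumptions.

import json

from huggingface_hub import HfApi, hf_hub_download


def check_regression(new_metrics: dict, repo_id: str, threshold: float = 0.005) -> tuple[str, dict | None]:
    """Sketch: compare combined dev sign_IoU against the newest version tag."""
    api = HfApi()
    refs = api.list_repo_refs(repo_id=repo_id)
    # assumed convention: version tags look like vYYYY.MM.DD; plain lexicographic
    # sort is only a stand-in here, since unpadded months sort imperfectly
    version_tags = sorted(t.name for t in refs.tags if t.name.startswith("v20"))
    if not version_tags:
        return "no_baseline", None  # nothing to regress against
    baseline_file = hf_hub_download(repo_id=repo_id, filename="eval_results.json", revision=version_tags[-1])
    with open(baseline_file) as f:
        baseline = json.load(f)
    old_iou = baseline["combined"]["dev"]["sign_IoU"]
    new_iou = new_metrics["combined"]["dev"]["sign_IoU"]
    return ("fail" if new_iou < old_iou - threshold else "pass"), baseline


def promote(repo_id: str, tag: str, revision: str) -> None:
    """Sketch: resolve `revision` to a commit and create `tag` pointing at it."""
    api = HfApi()
    refs = api.list_repo_refs(repo_id=repo_id)
    # search both branches and tags, matching how the tests resolve 'weekly'
    target = next(
        (r.target_commit for r in [*refs.branches, *refs.tags] if r.name == revision),
        None,
    )
    if target is None:
        raise ValueError(f"Could not resolve revision {revision!r} in {repo_id}")
    api.create_tag(repo_id=repo_id, tag=tag, revision=target)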
155 changes: 155 additions & 0 deletions sign_language_segmentation/tests/test_publish_cli.py
@@ -0,0 +1,155 @@
"""Tests for publish() CLI orchestration with all HF + eval boundaries mocked."""

from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest


class TestPublishIntegration:
"""End-to-end publish() orchestration with every HF + eval boundary mocked."""

@pytest.fixture
def ckpt_fixture(self, tmp_path):
import torch

ckpt_path = tmp_path / "fake.ckpt"
fake_ckpt = {
"state_dict": {"layer.weight": torch.randn(2, 2)},
"hyper_parameters": {
"hidden_dim": 128,
"encoder_depth": 4,
"attn_nhead": 8,
"attn_ff_mult": 4,
"attn_dropout": 0.1,
"num_frames": 1024,
"pose_dims": (75, 3),
"num_classes": 3,
"learning_rate": 1e-4,
"optimizer": "adam",
"dice_loss_weight": 0.5,
"fps_aug": True,
"frame_dropout": 0.1,
},
}
torch.save(fake_ckpt, ckpt_path)
return str(ckpt_path)

def _eval_results(self):
return {
"ds_a": {
"dev": {"sign_IoU": 0.80, "sentence_IoU": 0.70},
"test": {"sign_IoU": 0.81, "sentence_IoU": 0.71},
},
"combined": {
"dev": {"sign_IoU": 0.82, "sentence_IoU": 0.72},
"test": {"sign_IoU": 0.83, "sentence_IoU": 0.73},
},
}

def _mock_api(self):
mock_api = MagicMock()
# expose 'weekly' as a resolvable ref so promote can tag it, but list no
# vYYYY.MM.DD version tags, so the regression check short-circuits to no_baseline
mock_api.list_repo_refs.return_value = SimpleNamespace(
tags=[SimpleNamespace(name="weekly", target_commit="commit123")],
branches=[],
)
return mock_api

def test_skip_eval_and_no_promote(self, ckpt_fixture):
from sign_language_segmentation.publish.publish import publish

mock_api = self._mock_api()
with patch("huggingface_hub.HfApi", return_value=mock_api):
publish(
checkpoint=ckpt_fixture,
repo_id="fake/repo",
tag="v2026.4.20",
datasets="ds_a",
corpus="",
poses="",
device="cpu",
skip_eval=True,
metrics_json=None,
regression_threshold=0.005,
no_promote=True,
)
mock_api.create_repo.assert_called_once()
mock_api.create_branch.assert_called_once()
mock_api.upload_folder.assert_called_once()
# no_promote=True — no tag should be created
mock_api.create_tag.assert_not_called()

def test_skip_eval_with_promote_creates_tag(self, ckpt_fixture):
from sign_language_segmentation.publish.publish import publish

mock_api = self._mock_api()
with patch("huggingface_hub.HfApi", return_value=mock_api):
publish(
checkpoint=ckpt_fixture,
repo_id="fake/repo",
tag="v2026.4.20",
datasets="ds_a",
corpus="",
poses="",
device="cpu",
skip_eval=True,
metrics_json=None,
regression_threshold=0.005,
no_promote=False,
)
mock_api.create_tag.assert_called_once()
args, kwargs = mock_api.create_tag.call_args
assert kwargs["tag"] == "v2026.4.20"

def test_with_eval_passes_regression_and_promotes(self, ckpt_fixture, tmp_path):
from sign_language_segmentation.publish.publish import publish

mock_api = self._mock_api()
# stub run_evaluation and check_regression at the publish.py binding site,
# since publish.py imports them at module scope
with patch("sign_language_segmentation.publish.publish.run_evaluation",
return_value=self._eval_results()), \
patch("sign_language_segmentation.publish.publish.check_regression",
return_value=("pass", None)), \
patch("huggingface_hub.HfApi", return_value=mock_api):
publish(
checkpoint=ckpt_fixture,
repo_id="fake/repo",
tag="v2026.4.20",
datasets="ds_a",
corpus="",
poses="",
device="cpu",
skip_eval=False,
metrics_json=None,
regression_threshold=0.005,
no_promote=False,
)
mock_api.upload_folder.assert_called_once()
mock_api.create_tag.assert_called_once()

def test_regression_fail_does_not_promote(self, ckpt_fixture):
from sign_language_segmentation.publish.publish import publish

mock_api = self._mock_api()
with patch("sign_language_segmentation.publish.publish.run_evaluation",
return_value=self._eval_results()), \
patch("sign_language_segmentation.publish.publish.check_regression",
return_value=("fail", None)), \
patch("huggingface_hub.HfApi", return_value=mock_api):
publish(
checkpoint=ckpt_fixture,
repo_id="fake/repo",
tag="v2026.4.20",
datasets="ds_a",
corpus="",
poses="",
device="cpu",
skip_eval=False,
metrics_json=None,
regression_threshold=0.005,
no_promote=False,
)
mock_api.upload_folder.assert_called_once()
# regression failed — no promotion
mock_api.create_tag.assert_not_called()
1 change: 1 addition & 0 deletions sign_language_segmentation/tests/test_publish_hf_ops.py
@@ -126,3 +126,4 @@ def test_raises_when_unresolved(self):
with pytest.raises(ValueError, match="Could not resolve revision"):
promote(repo_id="fake/repo", tag="v2026.4.20", revision="nonexistent")
mock_api.create_tag.assert_not_called()
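Because every Hub and evaluation boundary is mocked, both test modules run fully offline; assuming pytest is available in the dev environment, a typical invocation would be:

    uv run pytest sign_language_segmentation/tests/ -k publish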