-
Notifications
You must be signed in to change notification settings - Fork 6
feat: HuggingFace publish pipeline with safetensors model format #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d60e759
b0c5756
b57644f
b834d87
744daab
bb67dd4
41df983
b3997ba
94007ee
c7198c6
c3727ce
aa3e609
2d5604d
b376fe8
edeafd7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,16 @@ | ||
| # authenticating with annotation platform db | ||
| CONVEX_URL= | ||
| CONVEX_AUTH_TOKEN= | ||
| # for Weights & Biases (not mandatory if logging in with wandb cli) | ||
| WANDB_API_KEY= | ||
| # signtube db credentials | ||
| DB_NAME= | ||
| DB_HOST= | ||
| DB_USER= | ||
| DB_PASS= | ||
| # for interacting with HF api | ||
| HF_TOKEN= | ||
| HF_MODEL_REPO= | ||
| HF_MODEL_REVISION= | ||
| # needed for sparks (HF cache files are owned by root and error is thrown) | ||
| XDG_CACHE_HOME= | ||
|
ziv-lazarov-nagish marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,3 +11,7 @@ uv.lock | |
| run_hpo.sh | ||
| *.pose | ||
| .cache/ | ||
| dist/ | ||
| wandb/ | ||
| *.log | ||
| *.safetensors | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| and writes an ELAN (.eaf) annotation file with SIGN and SENTENCE tiers. | ||
| """ | ||
| import argparse | ||
| import json | ||
| import os | ||
| from functools import lru_cache | ||
| from pathlib import Path | ||
|
|
@@ -13,19 +14,78 @@ | |
| import pympi | ||
| import torch | ||
| from pose_format import Pose | ||
| from safetensors.torch import load_file as load_safetensors | ||
|
|
||
| from sign_language_segmentation.utils.pose import preprocess_pose, compute_velocity | ||
| from sign_language_segmentation.metrics import likeliest_probs_to_segments, filter_segments | ||
| from sign_language_segmentation.model.model import PoseTaggingModel | ||
|
|
||
| _BAKED_IN_DIR = Path(__file__).resolve().parent / "dist" / "2026" | ||
|
|
||
| def _default_model_path() -> str: | ||
| return os.path.join(os.path.dirname(os.path.abspath(__file__)), "dist", "2026", "best.ckpt") | ||
|
|
||
| def resolve_model_path() -> str: | ||
| """Resolve model directory path. | ||
|
|
||
| Priority: MODEL_PATH env > HF_MODEL_REPO env > baked-in package default. | ||
| """ | ||
| # 1. explicit local path | ||
| explicit = os.environ.get("MODEL_PATH") | ||
| if explicit: | ||
| return explicit | ||
|
|
||
| # 2. huggingface hub download | ||
| hf_repo = os.environ.get("HF_MODEL_REPO") | ||
| if hf_repo: | ||
| return _download_from_hf(hf_repo) | ||
|
|
||
| # 3. baked-in default | ||
| return str(_BAKED_IN_DIR) | ||
|
|
||
|
|
||
| def _download_from_hf(repo_id: str) -> str: | ||
| """Download model from HuggingFace Hub. Returns local cache directory.""" | ||
| try: | ||
| from huggingface_hub import snapshot_download | ||
| except ImportError: | ||
| raise ImportError( | ||
| "huggingface_hub is required for HF_MODEL_REPO. " | ||
| "Install with: pip install sign-language-segmentation[hf]" | ||
| ) | ||
| revision = os.environ.get("HF_MODEL_REVISION") | ||
| if not revision: | ||
| raise ValueError("HF_MODEL_REVISION must be set when using HF_MODEL_REPO") | ||
| return snapshot_download( | ||
| repo_id=repo_id, | ||
| revision=revision, | ||
| allow_patterns=["model.safetensors", "config.json"], | ||
| ) | ||
|
|
||
|
|
||
| def _load_from_safetensors(model_dir: str, device: str) -> PoseTaggingModel: | ||
| """Load model from safetensors + config.json directory.""" | ||
| model_dir_path = Path(model_dir) | ||
| with open(model_dir_path / "config.json") as f: | ||
| config = json.load(f) | ||
| # config.json stores tuples as lists — convert pose_dims back | ||
| if "pose_dims" in config: | ||
| config["pose_dims"] = tuple(config["pose_dims"]) | ||
| model = PoseTaggingModel(**config) | ||
| state_dict = load_safetensors(filename=str(model_dir_path / "model.safetensors"), device=device) | ||
| model.load_state_dict(state_dict) | ||
| model = model.to(device) | ||
| model.eval() | ||
| return model | ||
|
|
||
|
|
||
| @lru_cache(maxsize=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
(created by Claude) |
||
| def load_model(model_path: str, device: str = "cpu") -> PoseTaggingModel: | ||
| model = PoseTaggingModel.load_from_checkpoint(model_path, map_location=device) | ||
| def load_model(model_dir: str, device: str = "cpu") -> PoseTaggingModel: | ||
| model_dir_path = Path(model_dir) | ||
| # prefer safetensors if available, fall back to .ckpt | ||
| if (model_dir_path / "model.safetensors").exists(): | ||
| return _load_from_safetensors(model_dir=str(model_dir), device=device) | ||
| # backward compat: load .ckpt directly (model_dir might be a file path) | ||
| ckpt_path = model_dir_path if model_dir_path.suffix == ".ckpt" else model_dir_path / "best.ckpt" | ||
| model = PoseTaggingModel.load_from_checkpoint(checkpoint_path=str(ckpt_path), map_location=device) | ||
| model = model.to(device) | ||
| model.eval() | ||
| return model | ||
|
|
@@ -51,18 +111,18 @@ def run_inference(model: PoseTaggingModel, pose: Pose, device: str) -> dict: | |
| return model(pose_tensor, timestamps=timestamps) | ||
|
|
||
|
|
||
| def segment_pose(pose: Pose, model_path: str = None, device: str = "cpu", | ||
| def segment_pose(pose: Pose, model_dir: str = None, device: str = "cpu", | ||
| min_frames: int = 3, merge_gap: int = 0): | ||
| """Segment a pose into signs and sentences. | ||
|
|
||
| Returns: | ||
| eaf: pympi.Elan.Eaf with SIGN and SENTENCE tiers | ||
| tiers: dict mapping tier name to list of {start, end} segment dicts | ||
| """ | ||
| model_path = model_path or _default_model_path() | ||
| model = load_model(model_path, device) | ||
| model_dir = model_dir or resolve_model_path() | ||
| model = load_model(model_dir=model_dir, device=device) | ||
|
|
||
| log_probs = run_inference(model, pose, device) | ||
| log_probs = run_inference(model=model, pose=pose, device=device) | ||
|
|
||
| fps = pose.body.fps | ||
| seg_fn = likeliest_probs_to_segments | ||
|
|
@@ -103,7 +163,7 @@ def get_args(): | |
| parser.add_argument("--pose", required=True, type=Path, help="input .pose file") | ||
| parser.add_argument("--elan", required=True, type=str, help="output .eaf file path") | ||
| parser.add_argument("--model", default=None, type=str, | ||
| help="path to .ckpt checkpoint (default: dist/2026/best.ckpt)") | ||
| help="path to model directory (safetensors) or .ckpt file") | ||
| parser.add_argument("--video", default=None, type=str, help="video file to link in ELAN") | ||
| parser.add_argument("--subtitles", default=None, type=str, help="path to .srt subtitle file") | ||
| parser.add_argument("--no-pose-link", action="store_true", help="do not link pose file in ELAN") | ||
|
|
@@ -120,21 +180,21 @@ def get_args(): | |
| def main(): | ||
| args = get_args() | ||
|
|
||
| model_path = args.model or _default_model_path() | ||
| if not os.path.exists(model_path): | ||
| model_dir = args.model or resolve_model_path() | ||
| if not os.path.exists(model_dir): | ||
| raise FileNotFoundError( | ||
| f"Model not found: {model_path}\n" | ||
| "Download a checkpoint and place it at dist/2026/best.ckpt, " | ||
| "or pass --model <path>." | ||
| f"Model not found: {model_dir}\n" | ||
| "Set HF_MODEL_REPO env var, pass --model <path>, " | ||
| "or place model files at dist/2026/." | ||
| ) | ||
|
|
||
| print(f"Loading pose: {args.pose}") | ||
| with open(args.pose, "rb") as f: | ||
| pose = Pose.read(f) | ||
|
|
||
| print(f"Loading model: {model_path}") | ||
| print(f"Loading model: {model_dir}") | ||
| print("Running inference...") | ||
| eaf, tiers = segment_pose(pose, model_path=model_path, device=args.device, | ||
| eaf, tiers = segment_pose(pose, model_dir=model_dir, device=args.device, | ||
| min_frames=args.min_frames, merge_gap=args.merge_gap) | ||
|
|
||
| sign_count = len(tiers["SIGN"]) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,9 +65,7 @@ def __init__( | |
| cache_dirty = False | ||
|
|
||
| videos_dir = self.corpus_dir / "videos" | ||
| doc_ids = sorted( | ||
| d.name for d in videos_dir.iterdir() if d.is_dir() | ||
| ) | ||
| doc_ids = sorted(d.name for d in videos_dir.iterdir() if d.is_dir()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This file has whitespace-only ruff reflows (this line, and several more below) that are unrelated to the publish pipeline. Consider splitting them into a separate formatting PR against (created by Claude) |
||
|
|
||
| for doc_id in doc_ids: | ||
| if doc_id in EXCLUDED_IDS or is_joke(self.corpus_dir, doc_id): | ||
|
|
@@ -117,23 +115,26 @@ def __init__( | |
| continue | ||
|
|
||
| person_sentences = [ | ||
| s for s in sentences | ||
| if s["participant"].lower() == person and len(s["glosses"]) > 0 | ||
| s for s in sentences if s["participant"].lower() == person and len(s["glosses"]) > 0 | ||
| ] | ||
| if not person_sentences: | ||
| continue | ||
|
|
||
| all_glosses = [g for s in person_sentences for g in s["glosses"]] | ||
| sentence_spans = [{"start": s["start"], "end": s["end"]} for s in person_sentences] | ||
|
|
||
| self._track_and_filter(cache_key, doc_split, { | ||
| "id": cache_key, | ||
| "pose_path": str(pose_path), | ||
| "fps": fps, | ||
| "total_frames": total_frames, | ||
| "glosses": all_glosses, | ||
| "sentences": sentence_spans, | ||
| }) | ||
| self._track_and_filter( | ||
| cache_key, | ||
| doc_split, | ||
| { | ||
| "id": cache_key, | ||
| "pose_path": str(pose_path), | ||
| "fps": fps, | ||
| "total_frames": total_frames, | ||
| "glosses": all_glosses, | ||
| "sentences": sentence_spans, | ||
| }, | ||
| ) | ||
|
|
||
| if cache_dirty: | ||
| self._save_cache(cache_file, cache) | ||
|
|
@@ -150,9 +151,7 @@ def get_split_manifest(self) -> dict: | |
| return { | ||
| "dataset": self.dataset_name, | ||
| "splits_path": self.splits_path, | ||
| "splits": { | ||
| s.value: sorted(ids) for s, ids in self._all_split_ids.items() | ||
| }, | ||
| "splits": {s.value: sorted(ids) for s, ids in self._all_split_ids.items()}, | ||
| } | ||
|
|
||
| def _load_cache(self, path: Path) -> dict: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from sign_language_segmentation.publish.publish import main, publish |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| --- | ||
| language: dgs | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
(created by Claude) |
||
| tags: | ||
| - sign-language | ||
| - segmentation | ||
| - pose | ||
| - pytorch | ||
| - pytorch-lightning | ||
| library_name: pytorch | ||
| pipeline_tag: other | ||
| {{model_index}} | ||
| --- | ||
|
|
||
| # Sign Language Segmentation | ||
|
|
||
| CNN-medium-attn model with RoPE for sign language segmentation. | ||
| Jointly trained on sign (gloss) and phrase (sentence) BIO tagging. | ||
|
|
||
| **Published:** {{published_at}} | ||
| **Tag:** `{{tag}}` | ||
| **Regression status:** {{regression_status}} | ||
|
|
||
| ## Architecture | ||
|
|
||
| {{architecture_rows}} | ||
|
|
||
| {{eval_section}} | ||
|
|
||
| ## Training Config | ||
|
|
||
| {{training_rows}} | ||
|
|
||
| {{dataset_section}} | ||
|
|
||
| ## Usage | ||
|
|
||
| ```bash | ||
| pip install sign-language-segmentation | ||
| pose_to_segments --pose input.pose --elan output.eaf | ||
| ``` | ||
Uh oh!
There was an error while loading. Please reload this page.