diff --git a/.gitignore b/.gitignore index c3355f9..9f446e6 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ lightning_logs/ uv.lock run_hpo.sh *.pose +*.safetensors .cache/ diff --git a/pyproject.toml b/pyproject.toml index 544cc1a..e78bfe0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + [project] name = "sign-language-segmentation" description = "Sign language pose segmentation model on both the sentence and sign level" @@ -16,6 +20,7 @@ dependencies = [ "pose-anonymization", "scikit-learn", "pytorch-lightning", + "safetensors", ] [project.optional-dependencies] @@ -61,7 +66,7 @@ where = ["."] include = ["sign_language_segmentation*"] [tool.setuptools.package-data] -sign_language_segmentation = ["**/*.json", "**/*.ckpt", "**/*.yaml"] +sign_language_segmentation = ["**/*.json", "**/*.safetensors", "**/*.yaml"] [tool.pytest.ini_options] addopts = "-v" @@ -69,4 +74,3 @@ testpaths = ["sign_language_segmentation"] [project.scripts] pose_to_segments = "sign_language_segmentation.bin:main" -slim_checkpoint = "sign_language_segmentation.slim_checkpoint:main" diff --git a/sign_language_segmentation/bin.py b/sign_language_segmentation/bin.py index 0662d6c..b8189a6 100644 --- a/sign_language_segmentation/bin.py +++ b/sign_language_segmentation/bin.py @@ -5,6 +5,7 @@ and writes an ELAN (.eaf) annotation file with SIGN and SENTENCE tiers. """ import argparse +import json import os from functools import lru_cache from pathlib import Path @@ -13,19 +14,80 @@ import pympi import torch from pose_format import Pose +from safetensors.torch import load_file as load_safetensors from sign_language_segmentation.utils.pose import preprocess_pose, compute_velocity from sign_language_segmentation.metrics import likeliest_probs_to_segments, filter_segments from sign_language_segmentation.model.model import PoseTaggingModel +_BAKED_IN_DIR = Path(__file__).resolve().parent / "dist" / "2026" -def _default_model_path() -> str: - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "dist", "2026", "best.ckpt") + +def resolve_model_path() -> str: + """Resolve model directory path. + + Priority: MODEL_PATH env > HF_MODEL_REPO env > baked-in package default. + """ + # 1. explicit local path + explicit = os.environ.get("MODEL_PATH") + if explicit: + return explicit + + # 2. huggingface hub download + hf_repo = os.environ.get("HF_MODEL_REPO") + if hf_repo: + return _download_from_hf(hf_repo) + + # 3. baked-in default + return str(_BAKED_IN_DIR) + + +def _download_from_hf(repo_id: str) -> str: + """Download model from HuggingFace Hub. Returns local cache directory.""" + try: + from huggingface_hub import snapshot_download + except ImportError: + raise ImportError( + "huggingface_hub is required for HF_MODEL_REPO. " + "Install with: pip install sign-language-segmentation[hf]" + ) + revision = os.environ.get("HF_MODEL_REVISION") + if not revision: + raise ValueError("HF_MODEL_REVISION must be set when using HF_MODEL_REPO") + return snapshot_download( + repo_id=repo_id, + revision=revision, + allow_patterns=["model.safetensors", "config.json"], + ) + + +def _load_from_safetensors(model_dir: str, device: str) -> PoseTaggingModel: + """Load model from safetensors + config.json directory.""" + model_dir_path = Path(model_dir) + with open(model_dir_path / "config.json") as f: + config = json.load(f) + # config.json stores tuples as lists — convert pose_dims back + if "pose_dims" in config: + config["pose_dims"] = tuple(config["pose_dims"]) + model = PoseTaggingModel(**config) + state_dict = load_safetensors(filename=str(model_dir_path / "model.safetensors"), device=device) + model.load_state_dict(state_dict) + model = model.to(device) + model.eval() + return model @lru_cache(maxsize=1) -def load_model(model_path: str, device: str = "cpu") -> PoseTaggingModel: - model = PoseTaggingModel.load_from_checkpoint(model_path, map_location=device) +def load_model(model_dir: str, device: str = "cpu", revision: str = "") -> PoseTaggingModel: + # revision is part of the cache key only — callers pass HF_MODEL_REVISION so a mid-process + # env change invalidates the cache entry instead of silently returning a stale model. + model_dir_path = Path(model_dir) + # prefer safetensors if available, fall back to .ckpt + if (model_dir_path / "model.safetensors").exists(): + return _load_from_safetensors(model_dir=str(model_dir), device=device) + # backward compat: load .ckpt directly (model_dir might be a file path) + ckpt_path = model_dir_path if model_dir_path.suffix == ".ckpt" else model_dir_path / "best.ckpt" + model = PoseTaggingModel.load_from_checkpoint(checkpoint_path=str(ckpt_path), map_location=device) model = model.to(device) model.eval() return model @@ -51,7 +113,7 @@ def run_inference(model: PoseTaggingModel, pose: Pose, device: str) -> dict: return model(pose_tensor, timestamps=timestamps) -def segment_pose(pose: Pose, model_path: str = None, device: str = "cpu", +def segment_pose(pose: Pose, model_dir: str = None, device: str = "cpu", min_frames: int = 3, merge_gap: int = 0): """Segment a pose into signs and sentences. @@ -59,10 +121,11 @@ def segment_pose(pose: Pose, model_path: str = None, device: str = "cpu", eaf: pympi.Elan.Eaf with SIGN and SENTENCE tiers tiers: dict mapping tier name to list of {start, end} segment dicts """ - model_path = model_path or _default_model_path() - model = load_model(model_path, device) + model_dir = model_dir or resolve_model_path() + revision = os.environ.get("HF_MODEL_REVISION", "") + model = load_model(model_dir=model_dir, device=device, revision=revision) - log_probs = run_inference(model, pose, device) + log_probs = run_inference(model=model, pose=pose, device=device) fps = pose.body.fps seg_fn = likeliest_probs_to_segments @@ -103,7 +166,7 @@ def get_args(): parser.add_argument("--pose", required=True, type=Path, help="input .pose file") parser.add_argument("--elan", required=True, type=str, help="output .eaf file path") parser.add_argument("--model", default=None, type=str, - help="path to .ckpt checkpoint (default: dist/2026/best.ckpt)") + help="path to model directory (safetensors) or .ckpt file") parser.add_argument("--video", default=None, type=str, help="video file to link in ELAN") parser.add_argument("--subtitles", default=None, type=str, help="path to .srt subtitle file") parser.add_argument("--no-pose-link", action="store_true", help="do not link pose file in ELAN") @@ -120,21 +183,21 @@ def get_args(): def main(): args = get_args() - model_path = args.model or _default_model_path() - if not os.path.exists(model_path): + model_dir = args.model or resolve_model_path() + if not os.path.exists(model_dir): raise FileNotFoundError( - f"Model not found: {model_path}\n" - "Download a checkpoint and place it at dist/2026/best.ckpt, " - "or pass --model ." + f"Model not found: {model_dir}\n" + "Set HF_MODEL_REPO env var, pass --model , " + "or place model files at dist/2026/." ) print(f"Loading pose: {args.pose}") with open(args.pose, "rb") as f: pose = Pose.read(f) - print(f"Loading model: {model_path}") + print(f"Loading model: {model_dir}") print("Running inference...") - eaf, tiers = segment_pose(pose, model_path=model_path, device=args.device, + eaf, tiers = segment_pose(pose, model_dir=model_dir, device=args.device, min_frames=args.min_frames, merge_gap=args.merge_gap) sign_count = len(tiers["SIGN"]) diff --git a/sign_language_segmentation/dist/2026/config.json b/sign_language_segmentation/dist/2026/config.json new file mode 100644 index 0000000..d5b8e6d --- /dev/null +++ b/sign_language_segmentation/dist/2026/config.json @@ -0,0 +1,20 @@ +{ + "pose_dims": [ + 50, + 6 + ], + "hidden_dim": 384, + "encoder_depth": 4, + "num_classes": 4, + "learning_rate": 0.0005, + "steps_per_epoch": 69, + "max_epochs": 400, + "dice_loss_weight": 1.5, + "attn_nhead": 8, + "attn_ff_mult": 2, + "attn_dropout": 0.1, + "optimizer": "adamw-onecycle", + "fps_aug": true, + "frame_dropout": 0.15, + "num_frames": 1024 +} \ No newline at end of file diff --git a/sign_language_segmentation/dist/2026/best.ckpt b/sign_language_segmentation/dist/2026/model.safetensors similarity index 99% rename from sign_language_segmentation/dist/2026/best.ckpt rename to sign_language_segmentation/dist/2026/model.safetensors index 7d80417..d9e8feb 100644 Binary files a/sign_language_segmentation/dist/2026/best.ckpt and b/sign_language_segmentation/dist/2026/model.safetensors differ