Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ lightning_logs/
uv.lock
run_hpo.sh
*.pose
*.safetensors
.cache/
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "sign-language-segmentation"
description = "Sign language pose segmentation model on both the sentence and sign level"
Expand All @@ -16,6 +20,7 @@ dependencies = [
"pose-anonymization",
"scikit-learn",
"pytorch-lightning",
"safetensors",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -61,12 +66,11 @@ where = ["."]
include = ["sign_language_segmentation*"]

[tool.setuptools.package-data]
sign_language_segmentation = ["**/*.json", "**/*.ckpt", "**/*.yaml"]
sign_language_segmentation = ["**/*.json", "**/*.safetensors", "**/*.yaml"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converted to safetensors; otherwise the .ckpt file would not be included when installing via `pip install`.


[tool.pytest.ini_options]
addopts = "-v"
testpaths = ["sign_language_segmentation"]

[project.scripts]
pose_to_segments = "sign_language_segmentation.bin:main"
slim_checkpoint = "sign_language_segmentation.slim_checkpoint:main"
95 changes: 79 additions & 16 deletions sign_language_segmentation/bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
and writes an ELAN (.eaf) annotation file with SIGN and SENTENCE tiers.
"""
import argparse
import json
import os
from functools import lru_cache
from pathlib import Path
Expand All @@ -13,19 +14,80 @@
import pympi
import torch
from pose_format import Pose
from safetensors.torch import load_file as load_safetensors

from sign_language_segmentation.utils.pose import preprocess_pose, compute_velocity
from sign_language_segmentation.metrics import likeliest_probs_to_segments, filter_segments
from sign_language_segmentation.model.model import PoseTaggingModel

# Location of the model files shipped inside the installed package
# (sign_language_segmentation/dist/2026/).
_BAKED_IN_DIR = Path(__file__).resolve().parent / "dist" / "2026"

def _default_model_path() -> str:
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "dist", "2026", "best.ckpt")

def resolve_model_path() -> str:
    """Pick the location the model should be loaded from.

    Lookup order: the ``MODEL_PATH`` environment variable wins, then
    ``HF_MODEL_REPO`` (downloaded via the HuggingFace Hub), and finally
    the model files shipped inside the package.
    """
    # Explicit local path takes priority over everything else.
    local_override = os.environ.get("MODEL_PATH")
    if local_override:
        return local_override

    # Fall back to a HuggingFace Hub download, then the baked-in default.
    repo_id = os.environ.get("HF_MODEL_REPO")
    return _download_from_hf(repo_id) if repo_id else str(_BAKED_IN_DIR)


def _download_from_hf(repo_id: str) -> str:
"""Download model from HuggingFace Hub. Returns local cache directory."""
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"huggingface_hub is required for HF_MODEL_REPO. "
"Install with: pip install sign-language-segmentation[hf]"
)
revision = os.environ.get("HF_MODEL_REVISION")
if not revision:
raise ValueError("HF_MODEL_REVISION must be set when using HF_MODEL_REPO")
return snapshot_download(
repo_id=repo_id,
revision=revision,
allow_patterns=["model.safetensors", "config.json"],
)


def _load_from_safetensors(model_dir: str, device: str) -> PoseTaggingModel:
    """Load a model from a ``model.safetensors`` + ``config.json`` directory.

    Args:
        model_dir: directory containing ``config.json`` (constructor kwargs
            for :class:`PoseTaggingModel`) and ``model.safetensors``
            (the state dict).
        device: torch device string the weights are loaded onto.

    Returns:
        The model moved to *device* and set to eval mode.
    """
    model_dir_path = Path(model_dir)
    # Read config.json as UTF-8 explicitly rather than relying on the
    # platform's default encoding.
    with open(model_dir_path / "config.json", encoding="utf-8") as f:
        config = json.load(f)
    # JSON has no tuple type, so pose_dims round-trips as a list — restore it.
    if "pose_dims" in config:
        config["pose_dims"] = tuple(config["pose_dims"])
    model = PoseTaggingModel(**config)
    state_dict = load_safetensors(filename=str(model_dir_path / "model.safetensors"), device=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


@lru_cache(maxsize=1)
def load_model(model_dir: str, device: str = "cpu", revision: str = "") -> PoseTaggingModel:
    """Load (and memoize) the segmentation model from *model_dir*.

    ``revision`` is unused inside the function body: it only widens the
    lru_cache key, so callers passing HF_MODEL_REVISION invalidate the
    cached entry when the env var changes mid-process instead of silently
    getting a stale model.
    """
    root = Path(model_dir)

    # New-style layout: safetensors weights next to a config.json.
    if (root / "model.safetensors").exists():
        return _load_from_safetensors(model_dir=str(model_dir), device=device)

    # Backward compatibility: model_dir may be a .ckpt file itself, or a
    # directory expected to hold best.ckpt.
    checkpoint = root if root.suffix == ".ckpt" else root / "best.ckpt"
    model = PoseTaggingModel.load_from_checkpoint(checkpoint_path=str(checkpoint), map_location=device)
    model = model.to(device)
    model.eval()
    return model
Expand All @@ -51,18 +113,19 @@ def run_inference(model: PoseTaggingModel, pose: Pose, device: str) -> dict:
return model(pose_tensor, timestamps=timestamps)


def segment_pose(pose: Pose, model_path: str = None, device: str = "cpu",
def segment_pose(pose: Pose, model_dir: str = None, device: str = "cpu",
min_frames: int = 3, merge_gap: int = 0):
"""Segment a pose into signs and sentences.

Returns:
eaf: pympi.Elan.Eaf with SIGN and SENTENCE tiers
tiers: dict mapping tier name to list of {start, end} segment dicts
"""
model_path = model_path or _default_model_path()
model = load_model(model_path, device)
model_dir = model_dir or resolve_model_path()
revision = os.environ.get("HF_MODEL_REVISION", "")
model = load_model(model_dir=model_dir, device=device, revision=revision)

log_probs = run_inference(model, pose, device)
log_probs = run_inference(model=model, pose=pose, device=device)

fps = pose.body.fps
seg_fn = likeliest_probs_to_segments
Expand Down Expand Up @@ -103,7 +166,7 @@ def get_args():
parser.add_argument("--pose", required=True, type=Path, help="input .pose file")
parser.add_argument("--elan", required=True, type=str, help="output .eaf file path")
parser.add_argument("--model", default=None, type=str,
help="path to .ckpt checkpoint (default: dist/2026/best.ckpt)")
help="path to model directory (safetensors) or .ckpt file")
parser.add_argument("--video", default=None, type=str, help="video file to link in ELAN")
parser.add_argument("--subtitles", default=None, type=str, help="path to .srt subtitle file")
parser.add_argument("--no-pose-link", action="store_true", help="do not link pose file in ELAN")
Expand All @@ -120,21 +183,21 @@ def get_args():
def main():
args = get_args()

model_path = args.model or _default_model_path()
if not os.path.exists(model_path):
model_dir = args.model or resolve_model_path()
if not os.path.exists(model_dir):
raise FileNotFoundError(
f"Model not found: {model_path}\n"
"Download a checkpoint and place it at dist/2026/best.ckpt, "
"or pass --model <path>."
f"Model not found: {model_dir}\n"
"Set HF_MODEL_REPO env var, pass --model <path>, "
"or place model files at dist/2026/."
)

print(f"Loading pose: {args.pose}")
with open(args.pose, "rb") as f:
pose = Pose.read(f)

print(f"Loading model: {model_path}")
print(f"Loading model: {model_dir}")
print("Running inference...")
eaf, tiers = segment_pose(pose, model_path=model_path, device=args.device,
eaf, tiers = segment_pose(pose, model_dir=model_dir, device=args.device,
min_frames=args.min_frames, merge_gap=args.merge_gap)

sign_count = len(tiers["SIGN"])
Expand Down
20 changes: 20 additions & 0 deletions sign_language_segmentation/dist/2026/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"pose_dims": [
50,
6
],
"hidden_dim": 384,
"encoder_depth": 4,
"num_classes": 4,
"learning_rate": 0.0005,
"steps_per_epoch": 69,
"max_epochs": 400,
"dice_loss_weight": 1.5,
"attn_nhead": 8,
"attn_ff_mult": 2,
"attn_dropout": 0.1,
"optimizer": "adamw-onecycle",
"fps_aug": true,
"frame_dropout": 0.15,
"num_frames": 1024
}
Binary file not shown.