13 changes: 13 additions & 0 deletions .env.example
@@ -1,3 +1,16 @@
# authenticating with annotation platform db
CONVEX_URL=
CONVEX_AUTH_TOKEN=
# for Weights & Biases (not mandatory if logging in with wandb cli)
WANDB_API_KEY=
# signtube db credentials
DB_NAME=
DB_HOST=
DB_USER=
DB_PASS=
# for interacting with HF api
HF_TOKEN=
HF_MODEL_REPO=
HF_MODEL_REVISION=
# needed for sparks (HF cache files are owned by root and an error is thrown)
XDG_CACHE_HOME=
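
For context, a minimal sketch of how the HF_MODEL_* variables above feed the new resolve_model_path() in bin.py (see the diff further down); the repo id and revision here are placeholders:

```python
import os

from sign_language_segmentation.bin import resolve_model_path

# Placeholder values only; MODEL_PATH, if set, takes priority over these.
os.environ["HF_MODEL_REPO"] = "example-org/sign-segmentation"
os.environ["HF_MODEL_REVISION"] = "main"

# Downloads model.safetensors + config.json from the Hub and returns the local snapshot dir.
model_dir = resolve_model_path()
```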
4 changes: 4 additions & 0 deletions .gitignore
@@ -11,3 +11,7 @@ uv.lock
run_hpo.sh
*.pose
.cache/
dist/
wandb/
*.log
*.safetensors
16 changes: 14 additions & 2 deletions pyproject.toml
@@ -1,3 +1,7 @@
[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "sign-language-segmentation"
description = "Sign language pose segmentation model on both the sentence and sign level"
@@ -16,6 +20,7 @@ dependencies = [
"pose-anonymization",
"scikit-learn",
"pytorch-lightning",
"safetensors",
]

[project.optional-dependencies]
@@ -46,6 +51,14 @@ server = [
"Flask-Compress",
"Flask-Caching",
]
hf = [
"huggingface_hub>=0.20.0",
]
publish = [
"sign-language-segmentation[hf]",
"sign-language-segmentation[train]",
]


[tool.ruff]
line-length = 120
@@ -61,12 +74,11 @@ where = ["."]
include = ["sign_language_segmentation*"]

[tool.setuptools.package-data]
sign_language_segmentation = ["**/*.json", "**/*.ckpt", "**/*.yaml"]
sign_language_segmentation = ["**/*.json", "**/*.safetensors", "**/*.yaml", "**/*.md"]

[tool.pytest.ini_options]
addopts = "-v"
testpaths = ["sign_language_segmentation"]

[project.scripts]
pose_to_segments = "sign_language_segmentation.bin:main"
slim_checkpoint = "sign_language_segmentation.slim_checkpoint:main"
92 changes: 76 additions & 16 deletions sign_language_segmentation/bin.py
@@ -5,6 +5,7 @@
and writes an ELAN (.eaf) annotation file with SIGN and SENTENCE tiers.
"""
import argparse
import json
import os
from functools import lru_cache
from pathlib import Path
@@ -13,19 +14,78 @@
import pympi
import torch
from pose_format import Pose
from safetensors.torch import load_file as load_safetensors

from sign_language_segmentation.utils.pose import preprocess_pose, compute_velocity
from sign_language_segmentation.metrics import likeliest_probs_to_segments, filter_segments
from sign_language_segmentation.model.model import PoseTaggingModel

_BAKED_IN_DIR = Path(__file__).resolve().parent / "dist" / "2026"

def _default_model_path() -> str:
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "dist", "2026", "best.ckpt")

def resolve_model_path() -> str:
"""Resolve model directory path.

Priority: MODEL_PATH env > HF_MODEL_REPO env > baked-in package default.
"""
# 1. explicit local path
explicit = os.environ.get("MODEL_PATH")
if explicit:
return explicit

# 2. huggingface hub download
hf_repo = os.environ.get("HF_MODEL_REPO")
if hf_repo:
return _download_from_hf(hf_repo)

# 3. baked-in default
return str(_BAKED_IN_DIR)


def _download_from_hf(repo_id: str) -> str:
"""Download model from HuggingFace Hub. Returns local cache directory."""
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"huggingface_hub is required for HF_MODEL_REPO. "
"Install with: pip install sign-language-segmentation[hf]"
)
revision = os.environ.get("HF_MODEL_REVISION")
if not revision:
raise ValueError("HF_MODEL_REVISION must be set when using HF_MODEL_REPO")
return snapshot_download(
repo_id=repo_id,
revision=revision,
allow_patterns=["model.safetensors", "config.json"],
)


def _load_from_safetensors(model_dir: str, device: str) -> PoseTaggingModel:
"""Load model from safetensors + config.json directory."""
model_dir_path = Path(model_dir)
with open(model_dir_path / "config.json") as f:
config = json.load(f)
# config.json stores tuples as lists — convert pose_dims back
if "pose_dims" in config:
config["pose_dims"] = tuple(config["pose_dims"])
model = PoseTaggingModel(**config)
state_dict = load_safetensors(filename=str(model_dir_path / "model.safetensors"), device=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
return model
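
For orientation, a hypothetical sketch of the inverse step a publish script would perform to produce the files this loader expects; save_for_publish is an illustrative name, not the actual publish-module API:

```python
import json
from pathlib import Path

from safetensors.torch import save_file


def save_for_publish(model, config: dict, out_dir: str) -> None:
    """Illustrative only: write the model.safetensors + config.json pair that
    _load_from_safetensors expects. Tuples such as pose_dims are serialized as
    JSON lists, hence the tuple() round-trip on load."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    save_file(model.state_dict(), str(out / "model.safetensors"))
    with open(out / "config.json", "w") as f:
        json.dump(config, f, indent=2)
```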


@lru_cache(maxsize=1)
@lru_cache keys by (model_dir, device), but when model_dir comes from resolve_model_path() via HF_MODEL_REPO, the actual content depends on HF_MODEL_REVISION too. If anything in the process ever changes HF_MODEL_REVISION between calls (long-running server, test fixtures), this cache silently returns a stale model. Either include the revision in the cache key (e.g. cache on resolve_model_path()'s output, which bakes the revision into the snapshot dir), or document that the cache is process-lifetime and the env must not change.

(created by Claude)
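
One way to implement the first option, sketched below; _load_model_cached and get_model are illustrative names (relying on the lru_cache import already present in bin.py), not code from this PR:

```python
@lru_cache(maxsize=4)
def _load_model_cached(resolved_dir: str, device: str) -> PoseTaggingModel:
    # Caching on the resolved snapshot directory bakes HF_MODEL_REVISION into
    # the key, because snapshot_download returns a revision-specific path.
    return load_model(model_dir=resolved_dir, device=device)


def get_model(device: str = "cpu") -> PoseTaggingModel:
    return _load_model_cached(resolve_model_path(), device)
```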

def load_model(model_path: str, device: str = "cpu") -> PoseTaggingModel:
model = PoseTaggingModel.load_from_checkpoint(model_path, map_location=device)
def load_model(model_dir: str, device: str = "cpu") -> PoseTaggingModel:
model_dir_path = Path(model_dir)
# prefer safetensors if available, fall back to .ckpt
if (model_dir_path / "model.safetensors").exists():
return _load_from_safetensors(model_dir=str(model_dir), device=device)
# backward compat: load .ckpt directly (model_dir might be a file path)
ckpt_path = model_dir_path if model_dir_path.suffix == ".ckpt" else model_dir_path / "best.ckpt"
model = PoseTaggingModel.load_from_checkpoint(checkpoint_path=str(ckpt_path), map_location=device)
model = model.to(device)
model.eval()
return model
@@ -51,18 +111,18 @@ def run_inference(model: PoseTaggingModel, pose: Pose, device: str) -> dict:
return model(pose_tensor, timestamps=timestamps)


def segment_pose(pose: Pose, model_path: str = None, device: str = "cpu",
def segment_pose(pose: Pose, model_dir: str = None, device: str = "cpu",
min_frames: int = 3, merge_gap: int = 0):
"""Segment a pose into signs and sentences.

Returns:
eaf: pympi.Elan.Eaf with SIGN and SENTENCE tiers
tiers: dict mapping tier name to list of {start, end} segment dicts
"""
model_path = model_path or _default_model_path()
model = load_model(model_path, device)
model_dir = model_dir or resolve_model_path()
model = load_model(model_dir=model_dir, device=device)

log_probs = run_inference(model, pose, device)
log_probs = run_inference(model=model, pose=pose, device=device)

fps = pose.body.fps
seg_fn = likeliest_probs_to_segments
@@ -103,7 +163,7 @@ def get_args():
parser.add_argument("--pose", required=True, type=Path, help="input .pose file")
parser.add_argument("--elan", required=True, type=str, help="output .eaf file path")
parser.add_argument("--model", default=None, type=str,
help="path to .ckpt checkpoint (default: dist/2026/best.ckpt)")
help="path to model directory (safetensors) or .ckpt file")
parser.add_argument("--video", default=None, type=str, help="video file to link in ELAN")
parser.add_argument("--subtitles", default=None, type=str, help="path to .srt subtitle file")
parser.add_argument("--no-pose-link", action="store_true", help="do not link pose file in ELAN")
@@ -120,21 +180,21 @@ def get_args():
def main():
args = get_args()

model_path = args.model or _default_model_path()
if not os.path.exists(model_path):
model_dir = args.model or resolve_model_path()
if not os.path.exists(model_dir):
raise FileNotFoundError(
f"Model not found: {model_path}\n"
"Download a checkpoint and place it at dist/2026/best.ckpt, "
"or pass --model <path>."
f"Model not found: {model_dir}\n"
"Set HF_MODEL_REPO env var, pass --model <path>, "
"or place model files at dist/2026/."
)

print(f"Loading pose: {args.pose}")
with open(args.pose, "rb") as f:
pose = Pose.read(f)

print(f"Loading model: {model_path}")
print(f"Loading model: {model_dir}")
print("Running inference...")
eaf, tiers = segment_pose(pose, model_path=model_path, device=args.device,
eaf, tiers = segment_pose(pose, model_dir=model_dir, device=args.device,
min_frames=args.min_frames, merge_gap=args.merge_gap)

sign_count = len(tiers["SIGN"])
@@ -427,8 +427,8 @@ def main() -> None:
parser.add_argument("--gcs_root", type=str, default="/mnt/nas/GCS")
parser.add_argument("--output", type=Path, default=_DEFAULT_ANNOTATIONS_CACHE)
parser.add_argument("--no_score", action="store_true", default=False, help="skip scoring after sync")
parser.add_argument("--model_path", type=str, default="sign_language_segmentation/dist/2026/best.ckpt",
help="model checkpoint for scoring (default: dist/2026/best.ckpt)")
parser.add_argument("--model_path", type=str, default="sign_language_segmentation/dist/2026",
help="model directory or .ckpt path for scoring")
parser.add_argument("--device", type=str, default="gpu")

args = parser.parse_args()
31 changes: 15 additions & 16 deletions sign_language_segmentation/datasets/dgs/dataset.py
@@ -65,9 +65,7 @@ def __init__(
cache_dirty = False

videos_dir = self.corpus_dir / "videos"
doc_ids = sorted(
d.name for d in videos_dir.iterdir() if d.is_dir()
)
doc_ids = sorted(d.name for d in videos_dir.iterdir() if d.is_dir())
This file has whitespace-only ruff reflows (this line, and several more below) that are unrelated to the publish pipeline. Consider splitting them into a separate formatting PR against nagish so this PR's diff stays focused on the publishing feature.

(created by Claude)


for doc_id in doc_ids:
if doc_id in EXCLUDED_IDS or is_joke(self.corpus_dir, doc_id):
@@ -117,23 +115,26 @@ def __init__(
continue

person_sentences = [
s for s in sentences
if s["participant"].lower() == person and len(s["glosses"]) > 0
s for s in sentences if s["participant"].lower() == person and len(s["glosses"]) > 0
]
if not person_sentences:
continue

all_glosses = [g for s in person_sentences for g in s["glosses"]]
sentence_spans = [{"start": s["start"], "end": s["end"]} for s in person_sentences]

self._track_and_filter(cache_key, doc_split, {
"id": cache_key,
"pose_path": str(pose_path),
"fps": fps,
"total_frames": total_frames,
"glosses": all_glosses,
"sentences": sentence_spans,
})
self._track_and_filter(
cache_key,
doc_split,
{
"id": cache_key,
"pose_path": str(pose_path),
"fps": fps,
"total_frames": total_frames,
"glosses": all_glosses,
"sentences": sentence_spans,
},
)

if cache_dirty:
self._save_cache(cache_file, cache)
@@ -150,9 +151,7 @@ def get_split_manifest(self) -> dict:
return {
"dataset": self.dataset_name,
"splits_path": self.splits_path,
"splits": {
s.value: sorted(ids) for s, ids in self._all_split_ids.items()
},
"splits": {s.value: sorted(ids) for s, ids in self._all_split_ids.items()},
}

def _load_cache(self, path: Path) -> dict:
1 change: 1 addition & 0 deletions sign_language_segmentation/publish/__init__.py
@@ -0,0 +1 @@
from sign_language_segmentation.publish.publish import main, publish
40 changes: 40 additions & 0 deletions sign_language_segmentation/publish/model_card_template.md
@@ -0,0 +1,40 @@
---
language: dgs
language: dgs is hardcoded, but the publish pipeline accepts arbitrary --datasets (including all). When the model is trained on non-DGS corpora, this frontmatter will be wrong on the HF model page. Either plumb the language list through generate_model_card as a field derived from the dataset registry, or drop the field entirely and let the {{dataset_section}} below describe coverage.

(created by Claude)
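
A possible shape for that plumbing, sketched below; the registry layout and the per-dataset "language" key are assumptions, not existing code:

```python
def _languages_for(dataset_names: list[str], registry: dict[str, dict]) -> list[str]:
    # registry is assumed to map dataset name -> metadata carrying a "language" code
    return sorted({registry[name]["language"] for name in dataset_names})


def render_language_frontmatter(dataset_names: list[str], registry: dict[str, dict]) -> str:
    # Produces the YAML block that would replace the hardcoded "language: dgs" line.
    langs = _languages_for(dataset_names, registry)
    return "language:\n" + "\n".join(f"- {lang}" for lang in langs)
```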

tags:
- sign-language
- segmentation
- pose
- pytorch
- pytorch-lightning
library_name: pytorch
pipeline_tag: other
{{model_index}}
---

# Sign Language Segmentation

CNN-medium-attn model with RoPE for sign language segmentation.
Jointly trained on sign (gloss) and phrase (sentence) BIO tagging.

**Published:** {{published_at}}
**Tag:** `{{tag}}`
**Regression status:** {{regression_status}}

## Architecture

{{architecture_rows}}

{{eval_section}}

## Training Config

{{training_rows}}

{{dataset_section}}

## Usage

```bash
pip install sign-language-segmentation
pose_to_segments --pose input.pose --elan output.eaf
```
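
A Python-level equivalent of the CLI call above (a sketch, assuming segment_pose is exposed by sign_language_segmentation.bin as in this repository):

```python
from pose_format import Pose

from sign_language_segmentation.bin import segment_pose

with open("input.pose", "rb") as f:
    pose = Pose.read(f)

# model_dir is optional: by default the model is resolved from the MODEL_PATH /
# HF_MODEL_REPO environment variables or the packaged dist/2026 weights.
eaf, tiers = segment_pose(pose, device="cpu")
eaf.to_file("output.eaf")
```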