Skip to content

Load from_pretrained at target dtype; restrict quantize= to int8#348

Merged
Ingvarstep merged 1 commit into
urchade:mainfrom
maxwbuckley:from-pretrained-dtype
Apr 24, 2026
Merged

Load from_pretrained at target dtype; restrict quantize= to int8#348
Ingvarstep merged 1 commit into
urchade:mainfrom
maxwbuckley:from-pretrained-dtype

Conversation

@maxwbuckley
Copy link
Copy Markdown
Contributor

@maxwbuckley maxwbuckley commented Apr 22, 2026

Important

No released wheel is affected. The current PyPI release — gliner 0.2.26, uploaded 2026-03-19 — pre-dates PR #342 (merged 2026-03-31), which introduced the quantize= surface. So removing the pure-downcast aliases here only reaches users installing from main / a git SHA. No one on pip install gliner needs to change anything.

Summary

Two related changes to the precision / quantization surface, landed together because they form one coherent story:

  1. Add dtype= to GLiNER.from_pretrained — load weights directly at the target floating-point precision. Casts each state-dict tensor during the safe_open read and pre-casts the model shell before load_state_dict, so a full fp32 copy is never materialized. Accepts "bf16"/"fp16"/"float32"/... or a floating-point torch.dtype; non-floating dtypes are rejected up front with a clear error that points at quantize="int8" for int paths. Int/bool buffers are left untouched.
  2. Restrict quantize= to "int8" only (per @Ingvarstep's review) — drops the pure-downcast aliases (True, "fp16", "bf16") entirely; those now raise ValueError with a migration message pointing at dtype= / model.to(...). quantize="int8" is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic quant on CPU.

Why dtype= matters

Previously from_pretrained always loaded the state dict at its on-disk precision (typically fp32), copied it into an fp32 model, then required a separate quantize(...) pass to cast down. Two copies co-resident, plus a post-load cast pass. For users who just want bf16/fp16 inference, that was ~2× the peak memory and an extra pass.

This is particularly useful for cold starts and scalable serverless deployments (AWS Lambda, Cloud Run, Modal, RunPod serverless, autoscaled Kubernetes pods, etc.) where startup latency and peak memory directly drive cost and SLA:

  • Shorter cold-start on every new container — one pass instead of load + cast.
  • Lower peak memory lets instances fit on smaller memory tiers and reduces boot-time OOMs.
  • Faster first-inference latency after scale-from-zero.

For CPU-only loads the peak-memory drop is a clean ~2×→1× fp32. For map_location="cuda", state-dict tensors stream to GPU while the shell is CPU-side, so the win is avoiding a simultaneous fp32 GPU state dict + fp32 GPU model (not literal halving of total footprint) plus the separate cast pass.

Why restrict quantize= to int8

Before this PR the quantize= surface was doing two different things under one name:

quantize= Device What happens Real quant?
True / "fp16" GPU model.half() No — pure downcast
True / "fp16" CPU quantize_dynamic(Linear, fp16) Yes — dynamic quant
"bf16" GPU model.bfloat16() No — pure downcast
"bf16" CPU model.bfloat16() No — pure downcast
"int8" GPU torchao int8 weight-only Yes
"int8" CPU quantize_dynamic(Linear, qint8) (FBGEMM) Yes

Three rows are just .to(dtype) with an extra fp32 intermediate — exactly what dtype= now does cheaply. A fourth (CPU fp16 dynamic quant) had no documented speed benefit and was asymmetric with every other CPU/GPU combination. The remaining two int8 rows are the only genuinely distinct behavior. After this PR:

  • quantize=TrueValueError (pointing at quantize='int8' or dtype=).
  • quantize="fp16" | "bf16" (any device) → ValueError (pointing at dtype= or model.to(torch.X)).
  • quantize="int8" → unchanged: torchao on GPU, FBGEMM dynamic quant on CPU.
  • model.quantize(...) takes the same restriction and defaults to "int8" instead of "fp16".

Signature narrows from Union[bool, str] = False to Optional[str] = None. See the callout at the top for why the blast radius is negligible.

Ray Serve integration (gliner.serve)

GLiNERFactory and the gliner.serve CLI are wired through in lockstep:

  • gliner/serve/server.py now passes map_location=config.device + dtype=self.torch_dtype to GLiNER.from_pretrained instead of loading at fp32 and then doing a post-load .to(device=..., dtype=...) cast. Serving cold starts get the same peak-memory win as direct users — which is exactly the autoscaled / serverless use case the core change targets.
  • --quantization CLI choices narrowed from {fp16, bf16, int8, None} to {int8, None}, matching the core API. Precision stays under --dtype (which already accepted float32/float16/fp16/bfloat16/bf16). docs/serving.md CLI reference + GLINER_QUANTIZATION env-var row updated.
  • Without this change, anyone passing --quantization fp16 or --quantization bf16 after the int8 restriction would have hit ValueError at server boot. The serve layer is also post-0.2.26 (PR Feature/serving #346), so no released wheel is affected.

Usage

from gliner import GLiNER
import torch

# Load directly in target precision
model = GLiNER.from_pretrained("urchade/gliner_small-v2.1", dtype="bf16")
model = GLiNER.from_pretrained("urchade/gliner_small-v2.1", dtype=torch.float16, map_location="cuda")

# Post-load, PyTorch's built-in `.to()` covers precision
model.to(torch.bfloat16)

# Real quantization (only remaining `quantize=` value)
model.quantize("int8")  # torchao on GPU, FBGEMM on CPU

Alignment with HuggingFace from_pretrained

This puts GLiNER on the same axis as HuggingFace transformers.PreTrainedModel.from_pretrained, which recently settled on the same split:

  • HF's current precision knob is dtype= (renamed from the now-legacy torch_dtype=), accepting a string or torch.dtype. We use the same parameter name and input surface.
  • HF separates real quantization into quantization_config=BitsAndBytesConfig(...) / AutoGPTQConfig / AutoAWQConfig. Our quantize="int8" plays the same role; restricting quantize= to int8 mirrors HF's decision not to overload the quantization knob for precision changes.

Users coming from HF will find GLiNER.from_pretrained(..., dtype="bf16") immediately familiar.

Known gap (possible follow-up, not in this PR): HF supports low_cpu_mem_usage=True / device_map=, which uses PyTorch's meta device to skip allocating the random-init shell entirely and then materializes weights via load_state_dict(assign=True). That's strictly cheaper than our pre-cast approach (~0 extra bytes for the shell vs ~1× target-dtype model size for ours). It would mean constructing the backbone on torch.device("meta") rather than via from_config, plus switching to assign=True loading — a bigger surgery than this PR. Worth chasing for the last slice of cold-start memory if it matters.

Tests

New tests/test_quantize_and_dtype.py — 50 unit tests covering:

  • _parse_dtype: None passthrough; all string aliases (fp16/float16/half/bf16/bfloat16/fp32/float32/float, case-insensitive); torch.dtype passthrough; unknown strings → ValueError; non-floating torch.dtype (int8/int32/int64/bool/uint8) → ValueError pointing at quantize='int8'; non-str/non-dtype → TypeError.
  • _load_state_dict: default path leaves tensors at stored dtype; dtype= casts floats on read in both the safetensors and torch.load branches; int/bool buffers preserved; idempotent when stored dtype already matches.
  • model.quantize(): "int8" (case-insensitive) routes to _apply_int8_quantization on both CPU and CUDA; precision aliases (fp16/float16/half/bf16/bfloat16) on both devices → ValueError with migration message naming the correct torch.float16/torch.bfloat16 replacement; unknown strings → ValueError; non-string inputs (True, False, None, ints, floats) → TypeError; ONNX model → RuntimeError (and the ONNX check runs before the alias check).
  • from_pretrained(quantize=True): rejected via the shared guard in model.quantize().

Heavier end-to-end tests (real hub checkpoint load with dtype="bf16" + inference parity) are intentionally out of scope for the unit suite — worth a manual smoke before merge.

Test plan

  • pytest tests/test_quantize_and_dtype.py → 50/50 pass locally.
  • ruff check / ruff format --check clean on new and modified files.
  • Manual end-to-end smoke against a real hub checkpoint with dtype="bf16"; confirm inference parity with the legacy quantize="bf16" path.

🤖 Generated with Claude Code

@maxwbuckley maxwbuckley changed the title Load from_pretrained weights directly at target dtype Load from_pretrained at target dtype; deprecate pure-downcast quantize Apr 23, 2026
@maxwbuckley maxwbuckley force-pushed the from-pretrained-dtype branch from ac6adb1 to 5de7bf0 Compare April 23, 2026 08:19
@urchade urchade requested a review from Ingvarstep April 23, 2026 12:05
Copy link
Copy Markdown
Collaborator

@Ingvarstep Ingvarstep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @maxwbuckley , thank you for your contribution, please, see my review.

Comment thread docs/usage.md
import torch

# Either a string or a torch.dtype
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype="bf16", map_location="cuda")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like that it aligns with HF standards right now.

Comment thread gliner/model.py Outdated
f"or only `quantize='int8'` for real quantization.",
stacklevel=2,
)
instance.quantize(quantize_dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's discussable, but maybe we should allow only int8 quantizaiton?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — restricted quantize= to "int8" only in 9f551c0.

  • quantize=True / "fp16" / "bf16" on from_pretrained and model.quantize(...) now raise ValueError with a migration message pointing at dtype= / model.to(torch.X).
  • Dropped the CPU fp16 dynamic-quant path along with the pure-downcast aliases — it had no documented speed benefit and was asymmetric with every other CPU/GPU combination.
  • quantize= signature narrowed from Union[bool, str] = False to Optional[str] = None; model.quantize() default changed from "fp16" to "int8".
  • Added 50 unit tests in tests/test_quantize_and_dtype.py covering the new validation, the dtype= load path, and _load_state_dict cast-on-read.

Since gliner 0.2.26 (current PyPI, 2026-03-19) pre-dates #342, no released wheel ships the old aliases — blast radius is main / git-SHA installs only.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good, thank you.

Comment thread gliner/model.py
state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True)
if dtype is not None:
for k, v in state_dict.items():
if torch.is_tensor(v) and v.is_floating_point() and v.dtype != dtype:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good

@maxwbuckley maxwbuckley force-pushed the from-pretrained-dtype branch 2 times, most recently from 02535c7 to 9f551c0 Compare April 23, 2026 17:22
@maxwbuckley maxwbuckley changed the title Load from_pretrained at target dtype; deprecate pure-downcast quantize Load from_pretrained at target dtype; restrict quantize= to int8 Apr 23, 2026
@maxwbuckley
Copy link
Copy Markdown
Contributor Author

Awesome. Thanks a lot!

Two related changes to the precision / quantization surface, landed
together because they form one coherent story.

1. Add `dtype=` to `GLiNER.from_pretrained`

   Load weights directly at the target floating-point precision. Each
   state-dict tensor is cast during the `safe_open` read and the
   random-init model shell is pre-cast via `instance.model.to(dtype)`
   before `load_state_dict`, so a full fp32 snapshot never co-exists
   with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`,
   `"float32"`, ...) or a floating-point `torch.dtype`; non-floating
   dtypes (e.g. `torch.int8`) are rejected up front with a message
   pointing at `quantize="int8"` for int paths. Int / bool buffers are
   preserved in the state dict.

   Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x
   fp32; for `map_location="cuda"`, the saving is avoiding a
   simultaneous fp32 GPU state dict + fp32 GPU model plus the separate
   post-load cast pass. Matches the `dtype=` surface on
   `transformers.PreTrainedModel.from_pretrained` (string or
   `torch.dtype`, same semantics), so users coming from HF get a
   familiar API.

   Primary target: cold starts and scalable serverless deployments
   (Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where
   startup latency and peak memory drive cost and SLA.

   The Ray-Serve layer (`gliner.serve.GLiNERFactory`) is wired through:
   it now passes `map_location` + `dtype` to `from_pretrained` instead
   of doing a post-load `.to(device=..., dtype=...)` cast, so serving
   cold starts get the same peak-memory win.

2. Restrict `quantize=` to `"int8"` only

   Previously `quantize=` accepted `True`, `"fp16"`, `"bf16"`, and
   `"int8"`. Three of the five effective rows (GPU fp16, GPU bf16, CPU
   bf16) were just `.to(dtype)` with an extra fp32 intermediate — not
   quantization. A fourth (CPU fp16) wrapped `nn.Linear` with
   dynamic-quantized variants but had no documented speed benefit and
   was asymmetric with the other CPU/GPU combinations.

   All non-int8 values now raise `ValueError` with a migration message
   pointing at `dtype=` / `model.to(torch_dtype)`. `quantize="int8"`
   is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic
   quant on CPU. The signature narrows from `Union[bool, str] = False`
   to `Optional[str] = None`; `model.quantize()` defaults to `"int8"`.

   The serving CLI (`gliner.serve --quantization`) is narrowed in
   lockstep to `{int8, None}`; precision stays under `--dtype`.

Docs in `docs/usage.md` and `docs/serving.md` are updated:
- New "Reduced-precision loading (`dtype`)" section in usage.md.
- "Quantization, Compilation & FlashDeBERTa" now shows `dtype="fp16"`
  as the only half-precision path.
- Serving CLI reference and env-var table reflect the int8-only
  `--quantization` / `GLINER_QUANTIZATION` surface.

New tests in `tests/test_quantize_and_dtype.py` (50 cases) cover
`_parse_dtype`, `_load_state_dict` cast-on-read, and the int8-only
`model.quantize(...)` validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@maxwbuckley maxwbuckley force-pushed the from-pretrained-dtype branch from 9f551c0 to b7b819c Compare April 24, 2026 11:45
@Ingvarstep Ingvarstep merged commit 4b9e7f7 into urchade:main Apr 24, 2026
@maxwbuckley maxwbuckley deleted the from-pretrained-dtype branch April 24, 2026 12:34
Ingvarstep pushed a commit that referenced this pull request Apr 26, 2026
Addresses 35 pre-existing test failures that pre-dated #348 (all
confirmed on `f6137fb` prior to my work) across 6 root causes:

**decoder tests (15) — `Span` dataclass vs tuple mismatch**
Commit 18e4237 replaced the tuple output of SpanDecoder/TokenDecoder
etc. with a `Span` dataclass, but the tests were never updated. Tests
previously unpacked spans positionally (`start, end, entity_type, score
= span`, `len(span) == 4`, `span[i]`). Updated to use attribute access
(`span.start`, `span.entity_type`, `span.generated_labels`), and
`TestGreedySearch` now builds `Span(...)` objects instead of raw tuples
(production `greedy_search` does `sorted(spans, key=lambda x: -x.score)`
and requires `Span` objects).

**data_processing tests (16) — Mock/attribute + API drift**
- `mock_config` was bare `Mock()`, so `config.precomputed_prompts_mode`
  and `config.neg_spans_ratio` returned truthy Mocks; `precomputed_prompts`
  silently enabled a fast path that emits `[SEP] tokens` instead of
  `[ENT] label [SEP] tokens`, and `neg_spans_ratio` broke int * Mock
  arithmetic. Pinned both to real values (`False`, `0.0`).
- `UniEncoderTokenProcessor.create_labels` was refactored to take a
  batch dict; tests still called the old 4-positional API. Rewrote via
  a `_make_create_labels_batch` helper.
- `preprocess_example` no longer emits `entities_id`; produces
  `span_idx`/`span_label` tensors (or `None` for empty NER).
- `create_batch_dict` now requires `span_idx` in batch items (even if
  `None` to disable span padding).
- `create_relation_labels` now needs `span_mask`; `rel_class_to_ids`
  is list-per-example.

**modeling tests (3) — shape drift from compile-sync optimization**
`select_decoder_embedding` / `select_target_embedding` were changed in
commit 84994b3 to skip the `lengths.max().item()` GPU→CPU sync under
eager execution, keeping the full N-width output (padding stays zero
via `target_mask`). Tests expected the old truncated shape
(B, max_len, D); updated to (B, N, D) with padding-position checks.
`rel_loss` tests passed `rel_mask=(B, P, 1)` and
`class_mask=(B, C)` — broken by the current `(B, P, C) * (B, P, C)`
contract. Fixed both masks to match documented shapes.

**test collection — `from tests.utils_infer` import**
`tests/test_infer_packing.py` uses the absolute import
`from tests.utils_infer import ...`, which fails under plain `pytest
tests/` because the repo root wasn't on `sys.path`. Added
`[tool.pytest.ini_options]` with `pythonpath = ["."]` and
`testpaths = ["tests"]` to `pyproject.toml` — plain `pytest` from
repo root now works.

**CI: `.github/workflows/tests.yml`**
New workflow runs pytest on every PR to main and every push to main
(plus `workflow_dispatch`). Matrix: Python 3.10 and 3.12. Parallel
`ruff` job scoped to files changed in recent work (full-repo lint is a
separate cleanup — the repo has pre-existing ruff findings outside our
ownership). Concurrency group cancels in-progress runs on new pushes.

To enforce that broken tests can't be merged, the maintainer should
mark this workflow as a required status check in the branch protection
rules for `main` (Settings → Branches → Branch protection rules). CI
alone reports failures; it doesn't block merge without that setting.

Final suite: 265 passed, 1 skipped, 0 failed on Python 3.12 locally
(up from 230 passed / 35 failed / 1 error before).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants