Load from_pretrained at target dtype; restrict quantize= to int8#348
Merged
Conversation
5 tasks
quantize
ac6adb1 to
5de7bf0
Compare
Ingvarstep
reviewed
Apr 23, 2026
Collaborator
Ingvarstep
left a comment
There was a problem hiding this comment.
Hi @maxwbuckley , thank you for your contribution, please, see my review.
| import torch | ||
|
|
||
| # Either a string or a torch.dtype | ||
| model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype="bf16", map_location="cuda") |
Collaborator
There was a problem hiding this comment.
I really like that it aligns with HF standards right now.
| f"or only `quantize='int8'` for real quantization.", | ||
| stacklevel=2, | ||
| ) | ||
| instance.quantize(quantize_dtype) |
Collaborator
There was a problem hiding this comment.
It's discussable, but maybe we should allow only int8 quantizaiton?
Contributor
Author
There was a problem hiding this comment.
Done — restricted quantize= to "int8" only in 9f551c0.
quantize=True/"fp16"/"bf16"onfrom_pretrainedandmodel.quantize(...)now raiseValueErrorwith a migration message pointing atdtype=/model.to(torch.X).- Dropped the CPU fp16 dynamic-quant path along with the pure-downcast aliases — it had no documented speed benefit and was asymmetric with every other CPU/GPU combination.
quantize=signature narrowed fromUnion[bool, str] = FalsetoOptional[str] = None;model.quantize()default changed from"fp16"to"int8".- Added 50 unit tests in
tests/test_quantize_and_dtype.pycovering the new validation, thedtype=load path, and_load_state_dictcast-on-read.
Since gliner 0.2.26 (current PyPI, 2026-03-19) pre-dates #342, no released wheel ships the old aliases — blast radius is main / git-SHA installs only.
Collaborator
There was a problem hiding this comment.
It looks good, thank you.
| state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True) | ||
| if dtype is not None: | ||
| for k, v in state_dict.items(): | ||
| if torch.is_tensor(v) and v.is_floating_point() and v.dtype != dtype: |
02535c7 to
9f551c0
Compare
quantizequantize= to int8
Contributor
Author
|
Awesome. Thanks a lot! |
Two related changes to the precision / quantization surface, landed
together because they form one coherent story.
1. Add `dtype=` to `GLiNER.from_pretrained`
Load weights directly at the target floating-point precision. Each
state-dict tensor is cast during the `safe_open` read and the
random-init model shell is pre-cast via `instance.model.to(dtype)`
before `load_state_dict`, so a full fp32 snapshot never co-exists
with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`,
`"float32"`, ...) or a floating-point `torch.dtype`; non-floating
dtypes (e.g. `torch.int8`) are rejected up front with a message
pointing at `quantize="int8"` for int paths. Int / bool buffers are
preserved in the state dict.
Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x
fp32; for `map_location="cuda"`, the saving is avoiding a
simultaneous fp32 GPU state dict + fp32 GPU model plus the separate
post-load cast pass. Matches the `dtype=` surface on
`transformers.PreTrainedModel.from_pretrained` (string or
`torch.dtype`, same semantics), so users coming from HF get a
familiar API.
Primary target: cold starts and scalable serverless deployments
(Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where
startup latency and peak memory drive cost and SLA.
The Ray-Serve layer (`gliner.serve.GLiNERFactory`) is wired through:
it now passes `map_location` + `dtype` to `from_pretrained` instead
of doing a post-load `.to(device=..., dtype=...)` cast, so serving
cold starts get the same peak-memory win.
2. Restrict `quantize=` to `"int8"` only
Previously `quantize=` accepted `True`, `"fp16"`, `"bf16"`, and
`"int8"`. Three of the five effective rows (GPU fp16, GPU bf16, CPU
bf16) were just `.to(dtype)` with an extra fp32 intermediate — not
quantization. A fourth (CPU fp16) wrapped `nn.Linear` with
dynamic-quantized variants but had no documented speed benefit and
was asymmetric with the other CPU/GPU combinations.
All non-int8 values now raise `ValueError` with a migration message
pointing at `dtype=` / `model.to(torch_dtype)`. `quantize="int8"`
is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic
quant on CPU. The signature narrows from `Union[bool, str] = False`
to `Optional[str] = None`; `model.quantize()` defaults to `"int8"`.
The serving CLI (`gliner.serve --quantization`) is narrowed in
lockstep to `{int8, None}`; precision stays under `--dtype`.
Docs in `docs/usage.md` and `docs/serving.md` are updated:
- New "Reduced-precision loading (`dtype`)" section in usage.md.
- "Quantization, Compilation & FlashDeBERTa" now shows `dtype="fp16"`
as the only half-precision path.
- Serving CLI reference and env-var table reflect the int8-only
`--quantization` / `GLINER_QUANTIZATION` surface.
New tests in `tests/test_quantize_and_dtype.py` (50 cases) cover
`_parse_dtype`, `_load_state_dict` cast-on-read, and the int8-only
`model.quantize(...)` validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
9f551c0 to
b7b819c
Compare
4 tasks
Ingvarstep
pushed a commit
that referenced
this pull request
Apr 26, 2026
Addresses 35 pre-existing test failures that pre-dated #348 (all confirmed on `f6137fb` prior to my work) across 6 root causes: **decoder tests (15) — `Span` dataclass vs tuple mismatch** Commit 18e4237 replaced the tuple output of SpanDecoder/TokenDecoder etc. with a `Span` dataclass, but the tests were never updated. Tests previously unpacked spans positionally (`start, end, entity_type, score = span`, `len(span) == 4`, `span[i]`). Updated to use attribute access (`span.start`, `span.entity_type`, `span.generated_labels`), and `TestGreedySearch` now builds `Span(...)` objects instead of raw tuples (production `greedy_search` does `sorted(spans, key=lambda x: -x.score)` and requires `Span` objects). **data_processing tests (16) — Mock/attribute + API drift** - `mock_config` was bare `Mock()`, so `config.precomputed_prompts_mode` and `config.neg_spans_ratio` returned truthy Mocks; `precomputed_prompts` silently enabled a fast path that emits `[SEP] tokens` instead of `[ENT] label [SEP] tokens`, and `neg_spans_ratio` broke int * Mock arithmetic. Pinned both to real values (`False`, `0.0`). - `UniEncoderTokenProcessor.create_labels` was refactored to take a batch dict; tests still called the old 4-positional API. Rewrote via a `_make_create_labels_batch` helper. - `preprocess_example` no longer emits `entities_id`; produces `span_idx`/`span_label` tensors (or `None` for empty NER). - `create_batch_dict` now requires `span_idx` in batch items (even if `None` to disable span padding). - `create_relation_labels` now needs `span_mask`; `rel_class_to_ids` is list-per-example. **modeling tests (3) — shape drift from compile-sync optimization** `select_decoder_embedding` / `select_target_embedding` were changed in commit 84994b3 to skip the `lengths.max().item()` GPU→CPU sync under eager execution, keeping the full N-width output (padding stays zero via `target_mask`). Tests expected the old truncated shape (B, max_len, D); updated to (B, N, D) with padding-position checks. `rel_loss` tests passed `rel_mask=(B, P, 1)` and `class_mask=(B, C)` — broken by the current `(B, P, C) * (B, P, C)` contract. Fixed both masks to match documented shapes. **test collection — `from tests.utils_infer` import** `tests/test_infer_packing.py` uses the absolute import `from tests.utils_infer import ...`, which fails under plain `pytest tests/` because the repo root wasn't on `sys.path`. Added `[tool.pytest.ini_options]` with `pythonpath = ["."]` and `testpaths = ["tests"]` to `pyproject.toml` — plain `pytest` from repo root now works. **CI: `.github/workflows/tests.yml`** New workflow runs pytest on every PR to main and every push to main (plus `workflow_dispatch`). Matrix: Python 3.10 and 3.12. Parallel `ruff` job scoped to files changed in recent work (full-repo lint is a separate cleanup — the repo has pre-existing ruff findings outside our ownership). Concurrency group cancels in-progress runs on new pushes. To enforce that broken tests can't be merged, the maintainer should mark this workflow as a required status check in the branch protection rules for `main` (Settings → Branches → Branch protection rules). CI alone reports failures; it doesn't block merge without that setting. Final suite: 265 passed, 1 skipped, 0 failed on Python 3.12 locally (up from 230 passed / 35 failed / 1 error before). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Important
No released wheel is affected. The current PyPI release —
gliner 0.2.26, uploaded 2026-03-19 — pre-dates PR #342 (merged 2026-03-31), which introduced thequantize=surface. So removing the pure-downcast aliases here only reaches users installing frommain/ a git SHA. No one onpip install glinerneeds to change anything.Summary
Two related changes to the precision / quantization surface, landed together because they form one coherent story:
dtype=toGLiNER.from_pretrained— load weights directly at the target floating-point precision. Casts each state-dict tensor during thesafe_openread and pre-casts the model shell beforeload_state_dict, so a full fp32 copy is never materialized. Accepts"bf16"/"fp16"/"float32"/... or a floating-pointtorch.dtype; non-floating dtypes are rejected up front with a clear error that points atquantize="int8"for int paths. Int/bool buffers are left untouched.quantize=to"int8"only (per @Ingvarstep's review) — drops the pure-downcast aliases (True,"fp16","bf16") entirely; those now raiseValueErrorwith a migration message pointing atdtype=/model.to(...).quantize="int8"is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic quant on CPU.Why
dtype=mattersPreviously
from_pretrainedalways loaded the state dict at its on-disk precision (typically fp32), copied it into an fp32 model, then required a separatequantize(...)pass to cast down. Two copies co-resident, plus a post-load cast pass. For users who just want bf16/fp16 inference, that was ~2× the peak memory and an extra pass.This is particularly useful for cold starts and scalable serverless deployments (AWS Lambda, Cloud Run, Modal, RunPod serverless, autoscaled Kubernetes pods, etc.) where startup latency and peak memory directly drive cost and SLA:
For CPU-only loads the peak-memory drop is a clean ~2×→1× fp32. For
map_location="cuda", state-dict tensors stream to GPU while the shell is CPU-side, so the win is avoiding a simultaneous fp32 GPU state dict + fp32 GPU model (not literal halving of total footprint) plus the separate cast pass.Why restrict
quantize=to int8Before this PR the
quantize=surface was doing two different things under one name:quantize=True/"fp16"model.half()True/"fp16"quantize_dynamic(Linear, fp16)"bf16"model.bfloat16()"bf16"model.bfloat16()"int8""int8"quantize_dynamic(Linear, qint8)(FBGEMM)Three rows are just
.to(dtype)with an extra fp32 intermediate — exactly whatdtype=now does cheaply. A fourth (CPU fp16 dynamic quant) had no documented speed benefit and was asymmetric with every other CPU/GPU combination. The remaining two int8 rows are the only genuinely distinct behavior. After this PR:quantize=True→ValueError(pointing atquantize='int8'ordtype=).quantize="fp16" | "bf16"(any device) →ValueError(pointing atdtype=ormodel.to(torch.X)).quantize="int8"→ unchanged: torchao on GPU, FBGEMM dynamic quant on CPU.model.quantize(...)takes the same restriction and defaults to"int8"instead of"fp16".Signature narrows from
Union[bool, str] = FalsetoOptional[str] = None. See the callout at the top for why the blast radius is negligible.Ray Serve integration (
gliner.serve)GLiNERFactoryand thegliner.serveCLI are wired through in lockstep:gliner/serve/server.pynow passesmap_location=config.device+dtype=self.torch_dtypetoGLiNER.from_pretrainedinstead of loading at fp32 and then doing a post-load.to(device=..., dtype=...)cast. Serving cold starts get the same peak-memory win as direct users — which is exactly the autoscaled / serverless use case the core change targets.--quantizationCLI choices narrowed from{fp16, bf16, int8, None}to{int8, None}, matching the core API. Precision stays under--dtype(which already acceptedfloat32/float16/fp16/bfloat16/bf16).docs/serving.mdCLI reference +GLINER_QUANTIZATIONenv-var row updated.--quantization fp16or--quantization bf16after the int8 restriction would have hitValueErrorat server boot. The serve layer is also post-0.2.26 (PR Feature/serving #346), so no released wheel is affected.Usage
Alignment with HuggingFace
from_pretrainedThis puts GLiNER on the same axis as HuggingFace
transformers.PreTrainedModel.from_pretrained, which recently settled on the same split:dtype=(renamed from the now-legacytorch_dtype=), accepting a string ortorch.dtype. We use the same parameter name and input surface.quantization_config=BitsAndBytesConfig(...)/AutoGPTQConfig/AutoAWQConfig. Ourquantize="int8"plays the same role; restrictingquantize=to int8 mirrors HF's decision not to overload the quantization knob for precision changes.Users coming from HF will find
GLiNER.from_pretrained(..., dtype="bf16")immediately familiar.Known gap (possible follow-up, not in this PR): HF supports
low_cpu_mem_usage=True/device_map=, which uses PyTorch's meta device to skip allocating the random-init shell entirely and then materializes weights viaload_state_dict(assign=True). That's strictly cheaper than our pre-cast approach (~0 extra bytes for the shell vs ~1× target-dtype model size for ours). It would mean constructing the backbone ontorch.device("meta")rather than viafrom_config, plus switching toassign=Trueloading — a bigger surgery than this PR. Worth chasing for the last slice of cold-start memory if it matters.Tests
New
tests/test_quantize_and_dtype.py— 50 unit tests covering:_parse_dtype:Nonepassthrough; all string aliases (fp16/float16/half/bf16/bfloat16/fp32/float32/float, case-insensitive);torch.dtypepassthrough; unknown strings →ValueError; non-floatingtorch.dtype(int8/int32/int64/bool/uint8) →ValueErrorpointing atquantize='int8'; non-str/non-dtype →TypeError._load_state_dict: default path leaves tensors at stored dtype;dtype=casts floats on read in both the safetensors andtorch.loadbranches; int/bool buffers preserved; idempotent when stored dtype already matches.model.quantize():"int8"(case-insensitive) routes to_apply_int8_quantizationon both CPU and CUDA; precision aliases (fp16/float16/half/bf16/bfloat16) on both devices →ValueErrorwith migration message naming the correcttorch.float16/torch.bfloat16replacement; unknown strings →ValueError; non-string inputs (True,False,None, ints, floats) →TypeError; ONNX model →RuntimeError(and the ONNX check runs before the alias check).from_pretrained(quantize=True): rejected via the shared guard inmodel.quantize().Heavier end-to-end tests (real hub checkpoint load with
dtype="bf16"+ inference parity) are intentionally out of scope for the unit suite — worth a manual smoke before merge.Test plan
pytest tests/test_quantize_and_dtype.py→ 50/50 pass locally.ruff check/ruff format --checkclean on new and modified files.dtype="bf16"; confirm inference parity with the legacyquantize="bf16"path.🤖 Generated with Claude Code