
Conversation

@andrewor14 (Contributor) commented Nov 21, 2025

**Summary:** Add `PerBlock` to safe globals so users don't have to do this themselves when they load `config.json` with `PerBlock`.

```
WeightsUnpickler error: Unsupported global: GLOBAL torchao.quantization.granularity.PerBlock was not an allowed global by default. Please use `torch.serialization.add_safe_globals([torchao.quantization.granularity.PerBlock])` or the `torch.serialization.safe_globals([torchao.quantization.granularity.PerBlock])` context manager to allowlist this global if you trust this class/function.
```

**Test Plan:**

```
python test/core/test_config.py -k test_granularity_serialization
```
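Before this change, users had to apply the workaround from the error message themselves: `torch.serialization.add_safe_globals([PerBlock])` before calling `torch.load(..., weights_only=True)`. The allowlisting mechanism behind that error can be sketched with the stdlib `pickle` module; the registry, class, and names below are illustrative stand-ins, not torch's actual implementation:

```python
import io
import pickle

# Illustrative stand-in for torchao.quantization.granularity.PerBlock.
class PerBlock:
    def __init__(self, block_size):
        self.block_size = block_size

# Minimal allowlist registry, analogous to torch.serialization.add_safe_globals.
_SAFE_GLOBALS = {}

def add_safe_globals(objs):
    for obj in objs:
        _SAFE_GLOBALS[(obj.__module__, obj.__qualname__)] = obj

class SafeUnpickler(pickle.Unpickler):
    """Refuses to resolve any global that has not been explicitly allowlisted."""
    def find_class(self, module, name):
        try:
            return _SAFE_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name} was not allowlisted"
            )

payload = pickle.dumps(PerBlock(block_size=(128, 128)))

# Before allowlisting: loading fails, mirroring the WeightsUnpickler error.
try:
    SafeUnpickler(io.BytesIO(payload)).load()
    raise AssertionError("expected UnpicklingError")
except pickle.UnpicklingError:
    pass

# After allowlisting (what this PR does for PerBlock at torchao import time):
add_safe_globals([PerBlock])
loaded = SafeUnpickler(io.BytesIO(payload)).load()
print(loaded.block_size)  # (128, 128)
```

With the PR merged, importing torchao registers `PerBlock` automatically, so the manual `add_safe_globals` call is no longer needed.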

@pytorch-bot bot commented Nov 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3370

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label Nov 21, 2025
@andrewor14 added the "topic: improvement" label Nov 21, 2025
@vkuzo (Contributor) commented Nov 21, 2025

I think we should

  1. put this in the same place where the other granularities are exposed for serialization
  2. add a test (same as we should have for other granularities)

@andrewor14 (Contributor, Author) replied:

> put this in the same place where the other granularities are exposed for serialization

Today these are added in observer.py, I think it just so happens that they're called when loading that file indirectly: https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py#L356

Should I move all of these to granularity.py to be more explicit?

@jerryzh168 (Contributor) commented Nov 21, 2025

> put this in the same place where the other granularities are exposed for serialization
>
> Today these are added in observer.py, I think it just so happens that they're called when loading that file indirectly: main/torchao/quantization/observer.py#L356
>
> Should I move all of these to granularity.py to be more explicit?

yeah, I think just adding all used granularities there would be good, and removing these from observer.py

@andrewor14 force-pushed the per_block_safe_globals branch 2 times, most recently from 2a12d58 to fd31a98 on November 21, 2025 22:47
```
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerTensor
```

Review comment (Contributor):

nit: just `from torchao.quantization import PerTensor` is better I think

```
    Int8WeightOnlyConfig,
)
from torchao.quantization.observer import PerRow, PerTensor
from torchao.quantization.granularity import PerRow, PerTensor
```

Review comment (Contributor):

nit: remove `granularity`

Comment on lines +187 to +194:

```
code = f"""
import torch
import torchao
_ = torch.load('{fname}', weights_only=True)
"""
```

Review comment (Contributor):

why do we need these instead of just `code`?

Reply (@andrewor14, Contributor, Author):

It's because by the time we run the test we have already imported everything. This starts a fresh environment and shows you only need to import torchao for loading to work. Copied from: https://github.com/pytorch/ao/blob/main/test/prototype/mx_formats/test_mx_serialization.py#L36
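The fresh-interpreter pattern referenced here can be sketched generically with the stdlib. The real test writes a torchao config with `torch.save` and imports only `torch` and `torchao` in the child process; the payload and filenames below are illustrative:

```python
import os
import pickle
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    fname = os.path.join(tmp, "obj.pkl")

    # Save an object from the current (fully imported) test process.
    with open(fname, "wb") as f:
        pickle.dump({"granularity": "PerBlock"}, f)

    # Load it in a brand-new interpreter that has none of our test imports,
    # proving the file round-trips without test-environment side effects.
    code = f"""
import pickle
with open({fname!r}, "rb") as f:
    obj = pickle.load(f)
assert obj["granularity"] == "PerBlock"
print("ok")
"""
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True
    )

print(result.stdout.strip())  # ok
```

Running the load in `sys.executable -c` guarantees the child sees only the imports written into `code`, which is exactly what makes the test meaningful.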

@jerryzh168 (Contributor) left a review:

lg, see comments inline

@andrewor14 force-pushed the per_block_safe_globals branch 2 times, most recently from bfba7fc to 150ac89 on November 21, 2025 23:00
andrewor14 added a commit to andrewor14/unsloth that referenced this pull request Nov 21, 2025
**Summary:** Following unslothai#3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370
danielhanchen pushed a commit to unslothai/unsloth that referenced this pull request Nov 22, 2025
* Add 128x128 PerBlock FP8 + RL


* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@andrewor14 force-pushed the per_block_safe_globals branch from 150ac89 to a71f121 on November 22, 2025 23:56
@andrewor14 force-pushed the per_block_safe_globals branch from a71f121 to c2daf56 on November 23, 2025 02:02
@andrewor14 (Contributor, Author) commented:

Test failures don't seem related. Thanks, merging this

@andrewor14 merged commit b55713a into main Nov 23, 2025
16 of 19 checks passed
danielhanchen added a commit to unslothai/unsloth that referenced this pull request Nov 25, 2025
* Enable FP8 + RL training for bf16 models (#3440)

* Enable FP8 + RL training for bf16 models

**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet  (this is in progress: vllm-project/vllm#26327)

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)

# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```
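As a quick sanity check, the 1.33x speedup quoted in this commit message follows directly from the reported runtimes:

```python
# Speedup implied by the train_runtime values reported above.
fp8_runtime = 1725.4337   # seconds
bf16_runtime = 2297.8145  # seconds

speedup = bf16_runtime / fp8_runtime
print(f"{speedup:.2f}x")  # 1.33x
```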

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* _get_inference_mode_context_manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update __init__.py

* Fix/save torchao model loading logic (#3621)

* make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR #314

* fix model loading and clean merged model directory

* revert default quant

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert mapper.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update loader_utils.py

* Update loader_utils.py

* Add 128x128 PerBlock FP8 + RL (#3629)

* Add 128x128 PerBlock FP8 + RL


* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Version

* Update vision.py

* Update rl.py

* Add torch 2.9.1

* Fix auto installer

* Update fp8.py

* Float8

* Update fp8.py

* Update mapper.py

* Update mapper.py

* Update loader_utils.py

* Update loader.py

* Update fp8.py

* Versioning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: andrewor14 <andrewor14@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Roland Tannous <115670425+rolandtannous@users.noreply.github.com>
namgyu-youn pushed a commit to namgyu-youn/ao that referenced this pull request Nov 25, 2025