
prototype quant_logger tool for logging weights and activations #3987

Merged
vkuzo merged 5 commits into main from gh/vkuzo/225/head on Mar 5, 2026
Conversation

@vkuzo (Contributor) commented Mar 4, 2026

Summary:

A simple tool to log activation and weight statistics and shapes. No dependencies other than PyTorch; advanced use cases are left to the user to implement via overrides. Features included in this PR:

  • logs fqn, MKN shape, max, avg, std to stdout by default
  • convenience functions are provided to log tensors to disk or to save stats to csv instead of printing to stdout
  • user can also override the logging function to do anything they want
  • only supports F.linear for now, but can be extended to other ops in the future. The design wraps weights, so will only work for ops where one of the args is a weight
  • compile is not officially supported (but basic cases should work)
  • distributed is not supported (but probably can be added later)
  • quantized models are not supported (and no plans to include this, to keep complexity low)
  • serialization not supported (and no plans to include this, to keep complexity low)
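
The max/avg/std fields in each log line can be illustrated with a small dependency-free sketch (the actual tool computes these on torch tensors; note that `torch.std` defaults to the sample standard deviation, while this sketch uses the population form for simplicity):

```python
import math

# Hypothetical sketch of the per-tensor stats that appear in each log line.
# (The real tool computes these on torch tensors; this is an illustration only.)
def tensor_stats(values: list[float]) -> dict[str, float]:
    n = len(values)
    avg = sum(values) / n
    var = sum((v - avg) ** 2 for v in values) / n  # population variance
    return {"max": max(values), "avg": avg, "std": math.sqrt(var)}

s = tensor_stats([1.0, 2.0, 3.0, 4.0])
print(f"max={s['max']:.2f}, avg={s['avg']:.2f}, std={s['std']:.2f}")
# max=4.00, avg=2.50, std=1.12
```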

summary of proposed API:

# the op to log stats for a tensor to stdout, user can override this to log anything they want
@torch.library.custom_op("quant_logger::log_tensor", mutates_args=("x",))
def log_tensor(
    x: torch.Tensor, fqn: str, op: str, tag: str, extra: str | None = None
) -> None: ...

# convenience utils to log entire tensors to files, or to log stats to csv instead of stdout
def enable_log_tensor_save_tensors_to_disk(save_dir): ...
def enable_log_stats_to_file(filename): ...

# user facing API
# 1. log param info
def log_parameter_info(model: torch.nn.Module): ...

# 2. add F.linear activation logger weight wrappers to Linear modules
def add_activation_loggers(model: torch.nn.Module): ...
# for (2), user then needs to run data through the model to see activation logs
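
The "wraps weights" design mentioned above is not shown in this summary. One plausible shape for it, sketched here as an assumption rather than the PR's actual implementation, is a tensor subclass that intercepts `F.linear` via `__torch_function__` whenever it appears as the weight argument:

```python
import torch
import torch.nn.functional as F

# Illustrative sketch only; the PR's actual wrapper may differ.
class LoggedWeight(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is F.linear:
            x, w = args[0], args[1]
            # a real implementation would call the logging op here
            print(f"linear: input={tuple(x.shape)}, weight={tuple(w.shape)}")
        # fall through to the default handler to run the op normally
        return super().__torch_function__(func, types, args, kwargs)

w = torch.randn(3072, 64).as_subclass(LoggedWeight)
y = F.linear(torch.randn(4096, 64), w)
assert y.shape == (4096, 3072)
```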

usage example on a diffusers model:

import torch
from diffusers import DiffusionPipeline
from torchao.prototype.quant_logger.quant_logger import add_activation_loggers, log_parameter_info, reset_counter

# Load model
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Log parameter statistics
print("=" * 70)
print("Parameter statistics:")
print("=" * 70)
log_parameter_info(pipe.transformer)

# Reset logging counter before logging activations
reset_counter()

# Add activation loggers
add_activation_loggers(pipe.transformer)

# Generate one image
print("=" * 70)
print("Activation statistics during inference:")
print("=" * 70)
result = pipe(
    prompt="A cat holding a sign that says hello world",
    height=1024,
    width=1024,
    # `num_inference_steps` is usually 4 for FLUX.1-schnell, but set to 1
    # for the purposes of this demo
    num_inference_steps=1,
    generator=torch.manual_seed(0),
)

output of running the example above:

...
======================================================================
Parameter statistics:
======================================================================
t=param, c=0, fqn='time_text_embed.timestep_embedder.linear_1.weight', op='', max=0.29, avg=-0.00, std=0.01
t=param, c=1, fqn='time_text_embed.timestep_embedder.linear_1.bias', op='', max=0.18, avg=-0.00, std=0.01
t=param, c=2, fqn='time_text_embed.timestep_embedder.linear_2.weight', op='', max=1.16, avg=0.00, std=0.04
t=param, c=3, fqn='time_text_embed.timestep_embedder.linear_2.bias', op='', max=1.11, avg=-0.11, std=0.27
t=param, c=4, fqn='time_text_embed.text_embedder.linear_1.weight', op='', max=0.27, avg=0.00, std=0.02
t=param, c=5, fqn='time_text_embed.text_embedder.linear_1.bias', op='', max=0.16, avg=-0.01, std=0.01
t=param, c=6, fqn='time_text_embed.text_embedder.linear_2.weight', op='', max=1.19, avg=0.01, std=0.06
t=param, c=7, fqn='time_text_embed.text_embedder.linear_2.bias', op='', max=1.11, avg=-0.11, std=0.27
t=param, c=8, fqn='context_embedder.weight', op='', max=0.73, avg=-0.00, std=0.05
t=param, c=9, fqn='context_embedder.bias', op='', max=0.56, avg=-0.00, std=0.05
t=param, c=10, fqn='x_embedder.weight', op='', max=0.57, avg=0.00, std=0.05
t=param, c=11, fqn='x_embedder.bias', op='', max=0.29, avg=0.01, std=0.02
t=param, c=12, fqn='transformer_blocks.0.norm1.linear.weight', op='', max=0.40, avg=0.00, std=0.03
t=param, c=13, fqn='transformer_blocks.0.norm1.linear.bias', op='', max=0.20, avg=-0.02, std=0.03
t=param, c=14, fqn='transformer_blocks.0.norm1_context.linear.weight', op='', max=0.55, avg=0.00, std=0.04
t=param, c=15, fqn='transformer_blocks.0.norm1_context.linear.bias', op='', max=0.31, avg=-0.01, std=0.03
...
======================================================================
Activation statistics during inference:
======================================================================
  0%|          | 0/1 [00:00<?, ?it/s]
t=act, c=0, fqn='x_embedder.weight', op='linear', extra='MKN=4096|64|3072', max=3.33, avg=0.00, std=1.00
t=act, c=1, fqn='time_text_embed.timestep_embedder.linear_1.weight', op='linear', extra='MKN=1|256|3072', max=1.00, avg=0.20, std=0.68
t=act, c=2, fqn='time_text_embed.timestep_embedder.linear_2.weight', op='linear', extra='MKN=1|3072|3072', max=0.78, avg=-0.01, std=0.05
t=act, c=3, fqn='time_text_embed.text_embedder.linear_1.weight', op='linear', extra='MKN=1|768|3072', max=3.98, avg=-0.11, std=1.00
t=act, c=4, fqn='time_text_embed.text_embedder.linear_2.weight', op='linear', extra='MKN=1|3072|3072', max=0.35, avg=-0.01, std=0.05
t=act, c=5, fqn='context_embedder.weight', op='linear', extra='MKN=512|4096|3072', max=6.50, avg=0.00, std=0.12
t=act, c=6, fqn='transformer_blocks.0.norm1.linear.weight', op='linear', extra='MKN=1|3072|18432', max=2.11, avg=0.00, std=0.06
t=act, c=7, fqn='transformer_blocks.0.norm1_context.linear.weight', op='linear', extra='MKN=1|3072|18432', max=2.11, avg=0.00, std=0.06
t=act, c=8, fqn='transformer_blocks.0.attn.to_q.weight', op='linear', extra='MKN=4096|3072|3072', max=13.88, avg=0.00, std=0.32
t=act, c=9, fqn='transformer_blocks.0.attn.to_k.weight', op='linear', extra='MKN=4096|3072|3072', max=13.88, avg=0.00, std=0.32
t=act, c=10, fqn='transformer_blocks.0.attn.to_v.weight', op='linear', extra='MKN=4096|3072|3072', max=13.88, avg=0.00, std=0.32
t=act, c=11, fqn='transformer_blocks.0.attn.add_q_proj.weight', op='linear', extra='MKN=512|3072|3072', max=15.31, avg=-0.00, std=0.33
t=act, c=12, fqn='transformer_blocks.0.attn.add_k_proj.weight', op='linear', extra='MKN=512|3072|3072', max=15.31, avg=-0.00, std=0.33
t=act, c=13, fqn='transformer_blocks.0.attn.add_v_proj.weight', op='linear', extra='MKN=512|3072|3072', max=15.31, avg=-0.00, std=0.33
t=act, c=14, fqn='transformer_blocks.0.attn.to_out.0.weight', op='linear', extra='MKN=4096|3072|3072', max=1.20, avg=-0.00, std=0.12
t=act, c=15, fqn='transformer_blocks.0.attn.to_add_out.weight', op='linear', extra='MKN=512|3072|3072', max=6.97, avg=-0.00, std=0.34
t=act, c=16, fqn='transformer_blocks.0.ff.net.0.proj.weight', op='linear', extra='MKN=4096|3072|12288', max=3.64, avg=0.00, std=0.14
t=act, c=17, fqn='transformer_blocks.0.ff.net.2.weight', op='linear', extra='MKN=4096|12288|3072', max=2.53, avg=-0.00, std=0.13
t=act, c=18, fqn='transformer_blocks.0.ff_context.net.0.proj.weight', op='linear', extra='MKN=512|3072|12288', max=11.31, avg=-0.00, std=0.45
t=act, c=19, fqn='transformer_blocks.0.ff_context.net.2.weight', op='linear', extra='MKN=512|12288|3072', max=40.00, avg=0.00, std=0.29
t=act, c=20, fqn='transformer_blocks.1.norm1.linear.weight', op='linear', extra='MKN=1|3072|18432', max=2.11, avg=0.00, std=0.06
...
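
For reference, the MKN field in the activation lines maps onto F.linear shapes: M is the product of all leading input dims, K is in_features, and N is out_features. A dependency-free sketch (hypothetical helper, not the PR's code):

```python
# Hypothetical helper (not the PR's code) showing how the logged MKN string
# relates to F.linear shapes: input (..., K) @ weight (N, K) -> output (..., N).
def mkn_string(input_shape: tuple[int, ...], weight_shape: tuple[int, int]) -> str:
    *leading, K = input_shape
    N, K_w = weight_shape
    assert K == K_w, "F.linear requires input.shape[-1] == weight.shape[1]"
    M = 1
    for d in leading:
        M *= d  # flatten all leading dims into M
    return f"MKN={M}|{K}|{N}"

# matches the x_embedder activation line above
print(mkn_string((4096, 64), (3072, 64)))  # MKN=4096|64|3072
```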

full output:

(pt_nightly) dev@gpu-dev-d82835d5:~/ao (20260304_quant_logger)$ python docs/source/examples/prototype/quant_logger/example.py 2>&1 | gh gist create
- Creating gist...
✓ Created secret gist
https://gist.github.com/vkuzo/1fffca0974d1f59099f3c0d16a3a1834

I want to check this in so we can have a reproducible tool for getting activation shapes for a blog we are working on with @sayakpaul about diffusion quantization with mxfp8 and nvfp4. It should also be useful in general for various quantization debugging use cases.

This lives in the prototype folder, so no BC guarantees for now. Once we have alignment on the general interface, I will also add docs as part of this PR.

Test Plan:

pytest test/prototype/quant_logger -s

[ghstack-poisoned]
@vkuzo (Author) commented Mar 4, 2026

vkuzo added a commit that referenced this pull request Mar 4, 2026
Summary:

Test Plan:
ghstack-source-id: 03be399
ghstack-comment-id: 3998362506
Pull-Request: #3987
@pytorch-bot commented Mar 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3987

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 8 Pending

As of commit dcf4af7 with merge base d6d423e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 4, 2026
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 4, 2026
Summary:

Test Plan:
ghstack-source-id: 0b41669
ghstack-comment-id: 3998362506
Pull-Request: #3987
@vkuzo vkuzo added the module: core changes affecting multiple modules, e.g. base config/tensor, observers, quant ops label Mar 4, 2026
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 4, 2026
Summary:

Test Plan:
ghstack-source-id: 11c9a58
ghstack-comment-id: 3998362506
Pull-Request: #3987
vkuzo added a commit that referenced this pull request Mar 4, 2026
Summary:

Using the logs from #3987, we now know
that FQNs that look like
`single_transformer_blocks.37.norm.linear.weight` have small shapes,
gate them out.

Test Plan:

TODO run new perf/accuracy bench after this
ghstack-source-id: 5e858f8
ghstack-comment-id: 3999281890
Pull-Request: #3989
@danielvegamyhre (Contributor):
Seems useful to me; I don't have any major feedback or concerns.

@sayakpaul (Contributor) commented Mar 5, 2026

Thanks, this is cool; in general it will be useful for filtering layers out of quantization.

Thoughts / questions:

  • add_activation_loggers vs log_parameter_info: it's counterintuitive to me why it can't be called log_activation_info, similar to log_parameter_info.
  • What does c mean in the logs?
  • Would this work when a module is compiled?
  • Would it be too much work to provide a utility to summarize the activation and parameter statistics? An example below:
    | Layer Group  | Count | Typical MKN           | max(max) | mean(std) | Outliers          |
    |--------------|-------|-----------------------|----------|-----------|-------------------|
    | embedders    | 3     | 4096×64×3072 (var)    | 6.50     | 0.37      | context(max=6.50) |
    | norms        | 4     | 1×3072×18432          | 2.11     | 0.06      | -                 |
    | attn (q/k/v) | 6     | 4096×3072×3072 (var)  | 15.31    | 0.32      | -                 |
    | attn (out)   | 2     | 4096×3072×3072 (var)  | 6.97     | 0.23      | add_out(max=6.97) |
    | ff           | 4     | 4096×3072×12288 (var) | 3.64     | 0.14      | -                 |
    | ff_context   | 4     | 512×3072×12288 (var)  | 40.00    | 0.37      | net.2(max=40.00)  |

@vkuzo (Author) commented Mar 5, 2026

thanks @sayakpaul, here are my thoughts. Also, once we finalize the design, I plan to include doc site content and a short tutorial in this PR to explain everything better.

add_activation_loggers vs log_parameter_info: it's counterintuitive to me why it can't be called log_activation_info, similar to log_parameter_info.

Currently they are named differently because add_activation_loggers requires the user to do an additional step of feeding data through the model to see the logs. We theoretically could add one more API, log_activation_info, which would require an example datum provided by the user. I leaned toward not doing that to keep things simple, since we'll still need the add_activation_loggers API for cases where the user wants to instrument a single building block and it's annoying to figure out the shape of the input to that block. Example below. Lmk your thoughts.

# log parameter info - one liner
log_parameter_info(model)

# log activations - two+ lines
add_activation_loggers(model)
for datum in dataset:
    model(datum)
    
# potential API we could add
# would call `model(datum)` under the hood, less flexible
log_activation_info(model, datum)

What does c mean in the logs?

c is a counter which increments by 1 with every log; this is useful for models which have a loop. It's c instead of counter to save characters on the screen. I do plan to explain this clearly in a docblock / on the docs site. Happy to hear feedback.

Would this work when a module is compiled?

Basic things did work, but I did not test compile very thoroughly; I deleted the compile tests while cleaning up the code to keep things simple. I would be interested to hear which use cases would need compile support for things like numerical debugging.

Would it be too much work to provide a utility to summarize the activation and parameter statistics?

I think this is very easy for a single model, but it's unclear how to generalize to arbitrary models, since the definitions of "layer group", "outlier", etc. can all change. Thoughts on providing an example of this for flux.1 schnell in the tutorial, and ensuring it's easy to adjust to other models?
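
As a rough illustration of what such a model-specific summary could look like, here is a sketch that groups stdout log lines by a keyword in the fqn and tracks the largest `max` per group (the regex and grouping keywords are assumptions tailored to the Flux output above, not part of the tool):

```python
import re
from collections import defaultdict

# Parses fqn and max out of log lines shaped like the tool's stdout output.
LOG_RE = re.compile(r"fqn='([^']+)'.*max=([\d.]+)")

def summarize(log_lines, groups=("attn", "ff", "norm")):
    worst = defaultdict(float)
    for line in log_lines:
        m = LOG_RE.search(line)
        if not m:
            continue
        fqn, mx = m.group(1), float(m.group(2))
        group = next((g for g in groups if g in fqn), "other")
        worst[group] = max(worst[group], mx)
    return dict(worst)

lines = [
    "t=act, c=8, fqn='transformer_blocks.0.attn.to_q.weight', op='linear', max=13.88",
    "t=act, c=17, fqn='transformer_blocks.0.ff.net.2.weight', op='linear', max=2.53",
]
print(summarize(lines))  # {'attn': 13.88, 'ff': 2.53}
```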

@sayakpaul (Contributor):

Currently they are named differently because add_activation_loggers requires the user to do an additional step of feeding data through the model to see the logs. We theoretically could add one more API, log_activation_info, which would require an example datum provided by the user. I leaned toward not doing that to keep things simple, since we'll still need the add_activation_loggers API for cases where the user wants to instrument a single building block and it's annoying to figure out the shape of the input to that block. Example below. Lmk your thoughts.

Makes sense. Your current reasoning explains it. No issues.

Basic things did work but I did not test compile very thoroughly, I deleted compile tests before cleaning up the code to keep things simple. Would be interested to hear what use cases would need compile support for things like numerical debugging.

SG!

Thoughts on providing an example of this for flux.1 schnell in the tutorial, and ensuring it's easy to adjust to other models?

Works for me!

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 5, 2026
Summary:

Test Plan:
ghstack-source-id: 1d95f7f
ghstack-comment-id: 3998362506
Pull-Request: #3987

import torch

counter = [0]
Reviewer (Contributor):
Why do we use list[int] for counter?

vkuzo (Author) replied:

this is because counter[0] is mutable
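
The distinction can be shown in a few lines of plain Python (an illustration of the language rule, not the PR's code):

```python
# A plain int can't be rebound from inside a function without `global`,
# while a list's element can be mutated in place.
count_int = 0
count_list = [0]

def bump_int():
    global count_int  # required: assignment rebinds the module-level name
    count_int += 1

def bump_list():
    count_list[0] += 1  # in-place element mutation; no `global` needed

bump_int()
bump_list()
assert count_int == 1 and count_list[0] == 1
```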

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Mar 5, 2026
Summary:

Test Plan:
ghstack-source-id: 7b89a56
ghstack-comment-id: 3998362506
Pull-Request: #3987
@vkuzo vkuzo merged commit f97ec45 into main Mar 5, 2026
62 of 70 checks passed
