
Add pruning-aware training in torchao.prototype.pat #3429

Merged
lisjin merged 3 commits into main from lvj/pat on Feb 23, 2026
Conversation

@lisjin (Contributor) commented Dec 3, 2025

Adding our pruning-aware training (PAT) library as a prototype. The original library is under fairinternal/qpat but we would like to surface it in torchao for broader adoption.

The interface is almost identical to torchao.prototype.parq, but we use (group) Lasso instead of piecewise-affine regularization. More details on code organization and usage can be found in the README.
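For context on the (group) Lasso regularizer mentioned above: its proximal map is block soft-thresholding, which shrinks each parameter group's L2 norm and zeroes out groups that fall below the threshold, and that is what drives structured sparsity. Below is a minimal numpy sketch of the standard operator, not torchao's actual implementation; `prox_group_lasso` is an illustrative name:

```python
import numpy as np

def prox_group_lasso(W, lam):
    """Block soft-thresholding (standard group-Lasso proximal map).
    Each row of W is treated as one group: its L2 norm is shrunk by lam,
    and rows whose norm falls below lam are zeroed (pruned)."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    scale = np.maximum(1.0 - lam / np.maximum(norms, 1e-12), 0.0)
    return W * scale

W = np.array([[3.0, 4.0],   # row norm 5.0 -> shrunk toward zero
              [0.3, 0.4]])  # row norm 0.5 -> pruned entirely
out = prox_group_lasso(W, 1.0)
# -> [[2.4, 3.2], [0.0, 0.0]]
```

Applying this map after each gradient step is the usual proximal-gradient recipe behind group-Lasso pruning.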

@lisjin lisjin requested a review from andrewor14 December 3, 2025 21:54
@pytorch-bot (bot) commented Dec 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3429

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1c753dc with merge base d988122:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Dec 3, 2025
@lisjin added the topic: new feature label on Dec 3, 2025
@lisjin force-pushed the lvj/pat branch 2 times, most recently from ffa338e to 4f78b65, on December 8, 2025 at 14:08
@lisjin (Contributor, Author) commented Dec 8, 2025

@andrewor14 Let me know if anything needs to be cleared up in this diff. I'm hoping to update D88501706 so that it imports from torchao.prototype.pat instead of copying code.

@meta-codesync (bot) commented Dec 8, 2025

@lisjin has imported this pull request. If you are a Meta employee, you can view this in D88638093.

```
    a base optimizer (e.g., SGD or AdamW)
    - update the latent variables for QAT
    Other parameters:
        warmup_steps: int >= 0
```
Contributor:

This is the central API right, can we add an example usage in this docstring?

Contributor (Author):

Good call—I updated the README example to include keyword args like warmup_steps and reg_lambda
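As an illustration of how parameters like `warmup_steps` and `reg_lambda` typically interact in a proximal pruning optimizer, here is a standalone numpy sketch under assumed semantics; this is not the `PruneOptimizer` API, and `train_step` / `prox_l1` are made-up names:

```python
import numpy as np

def prox_l1(w, lam):
    # Soft-thresholding, the L1 (Lasso) proximal map.
    return np.sign(w) * np.maximum(np.abs(w) - lam, 0.0)

def train_step(w, grad, lr, step, warmup_steps, reg_lambda):
    """One proximal-SGD step: plain SGD during warmup, then SGD followed
    by the Lasso prox. (Hypothetical: the real optimizer wraps a
    torch.optim base optimizer such as SGD or AdamW instead.)"""
    w = w - lr * grad
    if step >= warmup_steps:
        w = prox_l1(w, lr * reg_lambda)
    return w

w = np.array([0.05, 1.0])
# During warmup (step 0 < warmup_steps=1): no shrinkage is applied.
w1 = train_step(w, np.zeros(2), lr=0.1, step=0, warmup_steps=1, reg_lambda=1.0)
# After warmup: the small weight is pruned, the large one is shrunk.
w2 = train_step(w1, np.zeros(2), lr=0.1, step=1, warmup_steps=1, reg_lambda=1.0)
# w1 -> [0.05, 1.0], w2 -> [0.0, 0.9]
```

The warmup phase lets the model settle before regularization starts removing weights, which is the usual motivation for a `warmup_steps`-style knob.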

```python
    return out


class MaskedLayerNorm(nn.LayerNorm):
```
Contributor:

Seems like this is not used anywhere other than in tests. Can we delete this? Am I missing something?

Contributor (Author):

I'm hoping to keep this class since it's important for converting pruned models to their compressed inference-ready forms. This functionality can be added to PAT in the future.
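As a rough illustration of the idea being discussed (normalizing only over the channels that survive pruning), here is a guess at the semantics in numpy; the real `MaskedLayerNorm` subclasses `nn.LayerNorm` and may behave differently:

```python
import numpy as np

def masked_layer_norm(x, mask, eps=1e-5):
    """Layer norm restricted to unpruned channels (assumed semantics).
    Statistics are computed only over channels where mask is True, and
    pruned channels are forced to zero in the output."""
    kept = x[..., mask]
    mu = kept.mean(axis=-1, keepdims=True)
    var = kept.var(axis=-1, keepdims=True)
    out = np.zeros_like(x)
    out[..., mask] = (kept - mu) / np.sqrt(var + eps)
    return out

x = np.array([[1.0, 2.0, 3.0, 100.0]])
mask = np.array([True, True, True, False])  # last channel pruned
y = masked_layer_norm(x, mask)
```

The point of such a layer is that a pruned (zeroed) channel should not distort the mean and variance of the surviving ones, which matters when converting a pruned model to a compressed inference form.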

```python
from .pruneopt import PruneOptimizer


class NMSGDOptimizer(PruneOptimizer):
```
Contributor:

General question: I notice a lot of APIs in this PR that are not used or referenced anywhere. Are these all user-facing APIs? If so can we document them somewhere (e.g. main README) and explain how they're related to the main PruneOptimizer API? If they're not user-facing APIs and they're not used, do we still need them?

Some examples:

  • NMSGDOptimizer
  • ProxNuclearNorm
  • all the groupers like QKSVDGrouper

Contributor (Author):

  • The NMSGDOptimizer was written by a summer intern last year and has shown promising results. Since it's an experimental feature, we don't have unit tests for it yet.
  • ProxNuclearNorm is important for applying low-rank pruning to embeddings. Here's an example config.
  • The other groupers are more experimental. It would be great to keep them around so that we can stay in sync with the original repo, but I can also remove them if you'd like.
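For readers unfamiliar with it, the proximal map of the nuclear norm is singular value thresholding, which is why it induces low rank, e.g. for compressing embeddings as mentioned above. A sketch of the standard construction in numpy, not necessarily how `ProxNuclearNorm` is implemented:

```python
import numpy as np

def prox_nuclear_norm(W, lam):
    """Singular value thresholding (standard nuclear-norm proximal map).
    Each singular value is shrunk by lam; values below lam vanish,
    so W is driven toward low rank."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    s = np.maximum(s - lam, 0.0)
    return U @ np.diag(s) @ Vt

# A matrix with one dominant and one weak singular direction.
W = np.diag([5.0, 0.5])
out = prox_nuclear_norm(W, 1.0)
# -> diag([4.0, 0.0]): the weak direction is removed, rank drops to 1.
```

After training, the low-rank factors `U @ diag(s)` and `Vt` can be stored instead of the full matrix, which is the compression payoff.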

Contributor:

I see, we can keep them if we document them somewhere. If they're experimental, we can mark them as such in the README. In general, public APIs should have associated documentation somewhere; otherwise users won't be able to find them.

@andrewor14 (Contributor):

Hi @lisjin, looks good overall. My main comment is my confusion about how the APIs are used; the code snippet in the main README only references one or two of them, so it's unclear to me how the rest are related. It would be great if you could clarify this in the documentation.

Separately do you have any initial results? If so, would be great to include these in the README too.

@lisjin force-pushed the lvj/pat branch 3 times, most recently from 73a572c to 71c2270, on February 12, 2026 at 15:58
@lisjin (Contributor, Author) commented Feb 12, 2026

@andrewor14 Thanks for taking the time to review this back in Dec! I found out in January that the team I was collaborating with no longer needed to use PAT in torchao. However, now @Ninja91 and his team are planning to experiment with PAT. Could you please check that my fixes addressed all your comments? I've also added some initial results on unstructured pruning to the README.

```python
{
    "params": weights,
    "group_type": "pat.group.Dim0Grouper",
    "prox_type": "pat.prox.ProxGroupLasso",
}
```
Contributor:

Should these take in actual classes instead of strings of classes? Seems like it'll be more robust

Contributor (Author):

Ah this usage is actually outdated. I updated it a while back to accept strings like "Dim0Grouper" and "ProxGroupLasso" so that there's no dependency on import structure. The README is fixed to reflect this.
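One common way to accept bare strings like "Dim0Grouper" or "ProxGroupLasso" without tying configs to an import structure is a name-to-class registry. A hypothetical sketch of the pattern; torchao's actual lookup mechanism may differ, and the stub classes below stand in for the real ones:

```python
# Stub classes standing in for the real grouper / prox implementations.
class Dim0Grouper: ...
class ProxGroupLasso: ...

# Registry mapping short class names to classes, so configs can carry
# plain strings instead of dotted import paths.
_REGISTRY = {cls.__name__: cls for cls in (Dim0Grouper, ProxGroupLasso)}

def resolve(name: str):
    """Look up a class by its short name, failing loudly on typos."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown class name: {name!r}") from None

grouper_cls = resolve("Dim0Grouper")
```

Compared with passing class objects directly, strings keep serialized configs portable, at the cost of deferring typo detection from import time to lookup time.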


@lisjin (Contributor, Author) commented Feb 13, 2026

@andrewor14 Thanks for the suggestions again. Here's what I've updated in the latest commit:

  • Removed experimental classes like NMSGDOptimizer, QKGrouper, QKSVDGrouper
  • Documented all remaining grouper and proximal mapping classes in a new table of the README
  • Added underscores to non user-facing methods in distributed_utils.py

Let me know if anything's missing—this is very much a research prototype :)

@andrewor14 (Contributor) left a comment:

Looks good, thanks!

@lisjin lisjin enabled auto-merge (squash) February 23, 2026 14:44
@lisjin lisjin merged commit 2a37912 into main Feb 23, 2026
21 of 22 checks passed
@lisjin lisjin deleted the lvj/pat branch February 23, 2026 15:14

Labels: CLA Signed, topic: new feature
