This is a fork of Grokfast with an additional method: Delayed Compression — applying Frobenius norm regularization after a model overfits.
TL;DR: Letting the model memorize first, then forcing compression via weight regularization, reaches grokking 1.1-1.8x faster than Grokfast on modular arithmetic tasks.
| Operation | Grokfast EMA | Delayed Compression | Winner | Speedup |
|---|---|---|---|---|
| Addition (+) | 1300 steps | 720 steps | Delayed Compression | 1.81x |
| Subtraction (-) | 1190 steps | 1040 steps | Delayed Compression | 1.14x |
| Multiplication (*) | 710 steps | 710 steps | Tie | 1.00x |
| Division (/) | N/A | 980 steps | Delayed Compression | - |
Steps to reach 95% validation accuracy with optimal hyperparameters for both methods.
The idea is simple: let the model fully memorize the training data, then apply weight regularization to force compression.
```python
import torch


def frobenius_norm_loss(model):
    """Compute the sum of squared Frobenius norms of all weight matrices."""
    frob_loss = 0.0
    for name, param in model.named_parameters():
        if 'weight' in name and param.requires_grad:
            frob_loss += torch.norm(param, p='fro') ** 2
    return frob_loss


# In the training loop, after detecting overfitting:
if train_acc >= 0.99:  # the model has memorized the training set
    loss = ce_loss + 0.01 * frobenius_norm_loss(model)
```
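A self-contained toy sketch of how the delayed penalty slots into a training loop; the model, data, optimizer, and the 0.99 / 0.01 constants below are illustrative stand-ins, not the actual setup in main.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins so the sketch runs on its own; the real script trains a
# transformer on modular-arithmetic data.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 7))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x, y = torch.randn(256, 16), torch.randint(0, 7, (256,))

train_acc = 0.0
for step in range(1000):
    optimizer.zero_grad()
    logits = model(x)
    ce_loss = F.cross_entropy(logits, y)

    loss = ce_loss
    if train_acc >= 0.99:  # delayed compression: penalty only after memorization
        loss = ce_loss + 0.01 * frobenius_norm_loss(model)  # defined above

    loss.backward()
    optimizer.step()

    # Train accuracy from this step gates the penalty on the next one.
    train_acc = (logits.argmax(dim=-1) == y).float().mean().item()
```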
Installation:

```bash
git clone <this-repo>
cd grokfast
python -m venv venv
source venv/bin/activate
pip install torch matplotlib tqdm numpy
```

Train a single model with either method:

```bash
# Grokfast EMA (optimal params)
python main.py --label test_ema --operation "*" --filter ema --alpha 0.95 --lamb 2.0 --weight_decay 0.005 --budget 2000
# Delayed Compression (optimal params)
python main.py --label test_frob --operation "*" --filter none --frob_loss --frob_lamb 0.01 --frob_threshold 0.99 --budget 2000
```

Run all operations with both methods:

```bash
python run_optimal.py
```

Runs both methods on +, -, *, and / with optimal hyperparameters and generates comparison plots in results/.

Run the hyperparameter sweep:

```bash
python run_grid_search.py
```

Grokfast EMA (from grid search):
- `alpha = 0.95`
- `lamb = 2.0`
- `weight_decay = 0.005`

Delayed Compression (from grid search):
- `frob_lamb = 0.01`
- `frob_threshold = 0.99` (trigger after 99% train accuracy)
| File | Description |
|---|---|
| `main.py` | Training script with both Grokfast and Delayed Compression |
| `grokfast.py` | Original Grokfast gradient filtering functions |
| `run_optimal.py` | Run all operations with optimal hyperparameters |
| `run_grid_search.py` | Hyperparameter sweep for both methods |
| `results/` | Generated plots and saved model checkpoints |
Key flags in main.py:

```
--frob_loss              Enable Frobenius norm regularization after overfitting
--frob_lamb FLOAT        Frobenius loss coefficient (default: 0.01)
--frob_threshold FLOAT   Train accuracy threshold to trigger frob loss (default: 0.99)
--operation {+,-,*,/}    Modular arithmetic operation (default: *)
```
This work builds on:
```bibtex
@article{lee2024grokfast,
    title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
    author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
    journal={arXiv preprint arXiv:2405.20233},
    year={2024}
}
```

---

Jaerin Lee* · Bong Gyun Kang* · Kihoon Kim · Kyoung Mu Lee
*Denotes equal contribution.
tl;dr: We accelerate the grokking phenomenon by amplifying low-frequencies of the parameter gradients with an augmented optimizer.
Abstract: One puzzling artifact in machine learning dubbed grokking is where delayed generalization is achieved tenfolds of iterations after near perfect overfitting to the training data. Focusing on the long delay itself on behalf of machine learning practitioners, our goal is to accelerate generalization of a model under grokking phenomenon. By regarding a series of gradients of a parameter over training iterations as a random signal over time, we can spectrally decompose the parameter trajectories under gradient descent into two components: the fast-varying, overfitting-yielding component and the slow-varying, generalization-inducing component. This analysis allows us to accelerate the grokking phenomenon more than × 50 with only a few lines of code that amplifies the slow-varying components of gradients.
Grokfast can be applied by inserting a single line before the optimizer call.
```python
from grokfast import gradfilter_ema

grads = None

# ... in the optimization loop.
loss.backward()

### Grokfast (has arguments alpha, lamb)
grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)

optimizer.step()
```
Grokfast (`gradfilter_ema`):
- `m: nn.Module`: Model that contains every trainable parameter.
- `grads: Optional[Dict[str, torch.Tensor]] = None`: Running memory (EMA). Initialize by setting it to `None`.
- `alpha: float = 0.98`: Momentum hyperparameter of the EMA.
- `lamb: float = 2.0`: Amplifying factor hyperparameter of the filter.
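The EMA filter amounts to keeping a running average of each parameter's gradient and adding an amplified copy of that slow component back onto the current gradient before the optimizer step. Below is a minimal sketch built only from the argument list above; the actual implementation lives in grokfast.py and the name `gradfilter_ema_sketch` is ours:

```python
from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_ema_sketch(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    # On the first call, seed the running memory with the current gradients.
    if grads is None:
        grads = {n: p.grad.detach().clone()
                 for n, p in m.named_parameters()
                 if p.requires_grad and p.grad is not None}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Exponential moving average of the gradient (the slow-varying component).
            grads[n] = alpha * grads[n] + (1 - alpha) * p.grad.detach()
            # Amplify the slow component before the optimizer consumes the gradient.
            p.grad = p.grad + lamb * grads[n]
    return grads
```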
License: MIT



