Delayed Compression: Faster Grokking via Frobenius Regularization After Overfitting

This is a fork of Grokfast with an additional method: Delayed Compression — applying Frobenius norm regularization after a model overfits.

TL;DR: Letting the model memorize first and then forcing compression via weight regularization reaches 95% validation accuracy 1.1-1.8x faster than Grokfast EMA on modular arithmetic tasks.


Results

| Operation | Grokfast EMA | Delayed Compression | Winner | Speedup |
|---|---|---|---|---|
| Addition (+) | 1300 steps | 720 steps | Delayed Compression | 1.81x |
| Subtraction (-) | 1190 steps | 1040 steps | Delayed Compression | 1.14x |
| Multiplication (*) | 710 steps | 710 steps | Tie | 1.00x |
| Division (/) | N/A | 980 steps | Delayed Compression | - |

Steps to reach 95% validation accuracy with optimal hyperparameters for both methods.

Comparison plots for each operation (Addition, Subtraction, Multiplication, Division) are generated in results/.


The Method

The idea is simple: let the model fully memorize the training data, then apply weight regularization to force compression.

import torch

def frobenius_norm_loss(model):
    """Sum of squared Frobenius norms of all weight matrices."""
    frob_loss = 0.0
    for name, param in model.named_parameters():
        if 'weight' in name and param.requires_grad:
            frob_loss += torch.norm(param, p='fro') ** 2
    return frob_loss

# In the training loop, after detecting overfitting:
if train_acc >= 0.99:  # model has memorized the training set
    loss = ce_loss + 0.01 * frobenius_norm_loss(model)
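For context, here is a minimal sketch of how the delayed trigger could sit inside a full training step, reusing the frobenius_norm_loss helper above. The names model, optimizer, inputs, targets, and train_acc are placeholders; the actual loop in main.py may be organized differently.

import torch.nn.functional as F

def train_step(model, optimizer, inputs, targets, train_acc,
               frob_lamb=0.01, frob_threshold=0.99):
    """Illustrative training step with the delayed Frobenius penalty."""
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)
    # Add the compression term only once the model has (nearly) memorized the training set.
    if train_acc >= frob_threshold:
        loss = loss + frob_lamb * frobenius_norm_loss(model)
    loss.backward()
    optimizer.step()
    return loss.item()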

Reproduce Results

Setup

git clone <this-repo>
cd grokfast
python -m venv venv
source venv/bin/activate
pip install torch matplotlib tqdm numpy

Quick Test (single operation)

# Grokfast EMA (optimal params)
python main.py --label test_ema --operation "*" --filter ema --alpha 0.95 --lamb 2.0 --weight_decay 0.005 --budget 2000

# Delayed Compression (optimal params)
python main.py --label test_frob --operation "*" --filter none --frob_loss --frob_lamb 0.01 --frob_threshold 0.99 --budget 2000

Full Comparison (all operations)

python run_optimal.py

Runs both methods on +, -, *, / with optimal hyperparameters and generates comparison plots in results/.
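For orientation, the sketch below shows what such a driver conceptually does: loop main.py over the four operations for both methods with the tuned settings documented in this README. The label naming and loop structure are illustrative; the real run_optimal.py may differ.

import subprocess

# Tuned settings as documented in this README (illustrative driver, not the actual script).
METHODS = {
    "ema":  ["--filter", "ema", "--alpha", "0.95", "--lamb", "2.0", "--weight_decay", "0.005"],
    "frob": ["--filter", "none", "--frob_loss", "--frob_lamb", "0.01", "--frob_threshold", "0.99"],
}
OPS = {"+": "add", "-": "sub", "*": "mul", "/": "div"}

for op, op_name in OPS.items():
    for method, extra_args in METHODS.items():
        label = f"optimal_{method}_{op_name}"  # hypothetical naming scheme
        subprocess.run(["python", "main.py", "--label", label, "--operation", op,
                        "--budget", "2000", *extra_args], check=True)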

Hyperparameter Grid Search

python run_grid_search.py

Optimal Hyperparameters

Grokfast EMA (from grid search):

  • alpha = 0.95
  • lamb = 2.0
  • weight_decay = 0.005

Delayed Compression (from grid search):

  • frob_lamb = 0.01
  • frob_threshold = 0.99 (trigger after 99% train accuracy)

Key Files

| File | Description |
|---|---|
| main.py | Training script with both Grokfast and Delayed Compression |
| grokfast.py | Original Grokfast gradient filtering functions |
| run_optimal.py | Run all operations with optimal hyperparameters |
| run_grid_search.py | Hyperparameter sweep for both methods |
| results/ | Generated plots and saved model checkpoints |

New Command-Line Arguments

--frob_loss           Enable Frobenius norm regularization after overfitting
--frob_lamb FLOAT     Frobenius loss coefficient (default: 0.01)
--frob_threshold FLOAT  Train accuracy threshold to trigger frob loss (default: 0.99)
--operation {+,-,*,/}  Modular arithmetic operation (default: *)
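For orientation, these flags could be declared with argparse roughly as below; this is a sketch of the interface described above, not the exact code in main.py.

import argparse

parser = argparse.ArgumentParser()
# Flags added by this fork, with the documented defaults.
parser.add_argument("--frob_loss", action="store_true",
                    help="Enable Frobenius norm regularization after overfitting")
parser.add_argument("--frob_lamb", type=float, default=0.01,
                    help="Frobenius loss coefficient")
parser.add_argument("--frob_threshold", type=float, default=0.99,
                    help="Train accuracy threshold that triggers the Frobenius loss")
parser.add_argument("--operation", type=str, choices=["+", "-", "*", "/"], default="*",
                    help="Modular arithmetic operation")
args = parser.parse_args()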

Citation

This work builds on:

@article{lee2024grokfast,
    title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
    author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
    journal={arXiv preprint arXiv:2405.20233},
    year={2024}
}

Original Grokfast README

Grokfast: Accelerated Grokking by Amplifying Slow Gradients

Jaerin Lee* · Bong Gyun Kang* · Kihoon Kim · Kyoung Mu Lee

Seoul National University

*Denotes equal contribution.

Project ArXiv Github

tl;dr: We accelerate the grokking phenomenon by amplifying low-frequencies of the parameter gradients with an augmented optimizer.

Abstract: One puzzling artifact in machine learning dubbed grokking is where delayed generalization is achieved tenfolds of iterations after near perfect overfitting to the training data. Focusing on the long delay itself on behalf of machine learning practitioners, our goal is to accelerate generalization of a model under grokking phenomenon. By regarding a series of gradients of a parameter over training iterations as a random signal over time, we can spectrally decompose the parameter trajectories under gradient descent into two components: the fast-varying, overfitting-yielding component and the slow-varying, generalization-inducing component. This analysis allows us to accelerate the grokking phenomenon more than × 50 with only a few lines of code that amplifies the slow-varying components of gradients.


Usage

Instructions

Grokfast can be applied by inserting a single line before the optimizer call.

from grokfast import gradfilter_ema

grads = None  # running EMA memory; initialize to None

# ... in the optimization loop.
loss.backward()

# Grokfast (arguments: alpha, lamb)
grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)

optimizer.step()

Arguments

  1. Grokfast (gradfilter_ema)

    • m: nn.Module: Model that contains every trainable parameter.
    • grads: Optional[Dict[str, torch.Tensor]] = None: Running memory (EMA). Initialize by setting it to None.
    • alpha: float = 0.98: Momentum hyperparameter of the EMA.
    • lamb: float = 2.0: Amplifying factor hyperparameter of the filter.
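
For intuition, the filter keeps an exponential moving average of each parameter's gradient (the slow-varying component) and adds an amplified copy of it back to the raw gradient before the optimizer step. Below is a minimal sketch consistent with the arguments above; the actual implementation in grokfast.py may differ in details.

from typing import Dict, Optional
import torch
import torch.nn as nn

def gradfilter_ema_sketch(m: nn.Module,
                          grads: Optional[Dict[str, torch.Tensor]] = None,
                          alpha: float = 0.98,
                          lamb: float = 2.0) -> Dict[str, torch.Tensor]:
    """Illustrative EMA gradient filter: amplify the slow-varying gradient component."""
    if grads is None:
        # First call: initialize the running memory with the current gradients.
        grads = {n: p.grad.detach().clone() for n, p in m.named_parameters()
                 if p.requires_grad and p.grad is not None}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            grads[n] = alpha * grads[n] + (1 - alpha) * p.grad.detach()  # EMA update
            p.grad = p.grad + lamb * grads[n]                            # amplify slow component
    return grads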

License

MIT
