# Gated DeltaNet Research Notebook

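This notebook sets up and explores Gated DeltaNet on Google Colab using the `flash-linear-attention` (`fla`) library. The `fla` kernels are written in Triton and need a GPU, so select a GPU runtime before running the cells below.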

**Paper**: [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464)

## 1. Setup and Installation

```python
# Clone the repository
!git clone https://github.com/sustcsonglin/flash-linear-attention.git
%cd flash-linear-attention
```

```python
# Install dependencies
!pip install -e .
!pip install transformers einops torch
```
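Because the `fla` kernels run on a GPU, it helps to confirm that the Colab runtime actually exposes a CUDA device before going further. A minimal check using only standard PyTorch calls (the helper name `describe_device` is our own, not part of `fla`):

```python
import torch


def describe_device() -> str:
    """Return a short description of the CUDA device visible to PyTorch, if any."""
    if torch.cuda.is_available():
        # Name of the first visible GPU, i.e. the one assigned to this runtime
        return f"CUDA device: {torch.cuda.get_device_name(0)}"
    return "No CUDA device found: switch to a GPU runtime (Runtime > Change runtime type)."


print(describe_device())
```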

## 2. Import Gated DeltaNet

```python
import torch
from fla.layers import GatedDeltaNet
from fla.models import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel

print("Imports successful!")
```

## 3. Basic Layer Usage

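```python
# fla's Triton kernels run on the GPU; bfloat16 follows the examples in the fla README
device, dtype = 'cuda', torch.bfloat16

batch_size = 2
seq_len = 128
hidden_size = 512

# Create a Gated DeltaNet layer
layer = GatedDeltaNet(
    hidden_size=hidden_size,
    expand_v=2.0,
    head_dim=64,
    num_heads=8,
    mode='chunk',
    use_gate=True,
    use_short_conv=True,
).to(device=device, dtype=dtype)

# Test with random input; the layer returns a tuple whose first element is the output
x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
output, *_ = layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
```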
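With the arguments above, the projection sizes follow from simple arithmetic. The sketch below assumes the layer uses key dim = `num_heads * head_dim` and value head dim = `head_dim * expand_v`, with `num_v_heads` defaulting to `num_heads`; this reflects our reading of the `fla` source rather than a documented API, and `gated_deltanet_dims` is a hypothetical helper.

```python
from typing import Optional


def gated_deltanet_dims(num_heads: int, head_dim: int, expand_v: float,
                        num_v_heads: Optional[int] = None) -> dict:
    """Hypothetical helper: projection sizes of a Gated DeltaNet layer (assumed formulas)."""
    num_v_heads = num_heads if num_v_heads is None else num_v_heads
    head_v_dim = int(head_dim * expand_v)      # value heads are wider than key heads by expand_v
    return {
        "key_dim": num_heads * head_dim,         # width of the q and k projections
        "head_v_dim": head_v_dim,
        "value_dim": num_v_heads * head_v_dim,   # width of the v projection
        "state_shape": (head_v_dim, head_dim),   # per-head state S_t in the paper's (d_v x d_k) layout
    }


print(gated_deltanet_dims(num_heads=8, head_dim=64, expand_v=2.0))
```

For the layer above this gives a key width of 512 and a value width of 1024. The output projection maps back to `hidden_size`, which is why the output shape matches the input shape.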

## 4. Model Configuration

```python
# Create a small Gated DeltaNet model
config = GatedDeltaNetConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_heads=12,
    head_dim=64,
    vocab_size=50257,
)

model = GatedDeltaNetForCausalLM(config)
print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters")
```
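The total above mixes the embedding table, the stacked layers, and the language-model head. With `vocab_size=50257` and `hidden_size=768`, a single embedding matrix already holds 50257 × 768 = 38,597,376 (about 38.6M) parameters. A small helper built on standard `torch.nn.Module` methods can break the count down by top-level submodule; `count_parameters_by_child` is a hypothetical name, and the demo runs on a toy module so the block is self-contained.

```python
import torch.nn as nn


def count_parameters_by_child(module: nn.Module) -> dict:
    """Hypothetical helper: parameter count of each direct child module.

    Shared (tied) parameters are counted once per child that owns them, so the
    values can sum to more than sum(p.numel() for p in module.parameters()).
    """
    return {name: sum(p.numel() for p in child.parameters())
            for name, child in module.named_children()}


# Toy demo; in the notebook, try count_parameters_by_child(model)
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
print(count_parameters_by_child(toy))  # {'0': 15, '1': 0, '2': 8}
```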

## 5. Forward Pass Test

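```python
# Test forward pass: fla's Triton kernels need the model and inputs on the GPU.
# eval() because fla may route short sequences to an inference-only recurrent kernel.
model = model.to(device='cuda', dtype=torch.bfloat16).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 64), device='cuda')
with torch.no_grad():
    outputs = model(input_ids)

print(f"Input IDs shape: {input_ids.shape}")
print(f"Logits shape: {outputs.logits.shape}")
```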
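Logits of shape `(batch, seq_len, vocab_size)` are what a next-token training loss consumes. Below is a minimal sketch of the standard causal language-modeling loss in plain PyTorch: the logits at position t are scored against token t + 1, so logits and labels are shifted by one before cross-entropy. `causal_lm_loss` is a hypothetical helper; it is a generic formulation, not necessarily the exact loss path used inside `fla`.

```python
import math

import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Mean next-token cross-entropy: logits at position t are scored against token t + 1."""
    shift_logits = logits[:, :-1, :]   # the last position has no next token to predict
    shift_labels = input_ids[:, 1:]    # the first token is not predicted by anything
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)).float(),  # upcast, e.g. from bfloat16
        shift_labels.reshape(-1),
    )


# Sanity check: all-zero logits give a uniform distribution, so the loss equals log(vocab size)
toy_vocab = 11
ids = torch.randint(0, toy_vocab, (2, 8))
loss = causal_lm_loss(torch.zeros(2, 8, toy_vocab), ids)
print(loss.item(), math.log(toy_vocab))
```

In the notebook, `causal_lm_loss(outputs.logits, input_ids)` applies the same computation to the model output; for an untrained model the value is typically close to log(50257) ≈ 10.8, since the predictions are near uniform.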

## 6. Research Experiments

Add your research experiments below:

```python
# Your experiments here
```


## 7. Key Architecture Components

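### Gated DeltaNet Layer Parameters

- **hidden_size**: Model hidden dimension (size of the layer's input and output)
- **expand_v**: Ratio of the value head dimension to the key head dimension (default: 2.0)
- **head_dim**: Dimension of each query/key head
- **num_heads**: Number of query/key heads
- **num_v_heads**: Number of value heads; grouped value attention (GVA) is used when this exceeds `num_heads`
- **mode**: Kernel mode (`'chunk'` or `'fused_recurrent'`)
- **use_beta**: Use the per-token delta-rule step size β
- **use_gate**: Apply output gating before the output projection
- **use_short_conv**: Apply short causal convolutions to the query, key, and value projections
- **allow_neg_eigval**: Allow negative eigenvalues in the state-transition matrix by widening the range of β from (0, 1) to (0, 2)
- **conv_size**: Kernel size of the short convolutions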
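The effect of `allow_neg_eigval` is easiest to see on the delta-rule transition matrix $I - \beta_t k_t k_t^\top$. For a unit-norm key $k_t$, this matrix has eigenvalue $1 - \beta_t$ along $k_t$ and eigenvalue $1$ in every orthogonal direction. With $\beta_t \in (0, 1)$ all eigenvalues stay in $(0, 1]$, while $\beta_t \in (0, 2)$ lets the eigenvalue along $k_t$ fall into $(-1, 0)$. The check below assumes an L2-normalized key, as in the paper; `transition_eigenvalues` is a hypothetical helper.

```python
import torch


def transition_eigenvalues(k: torch.Tensor, beta: float) -> torch.Tensor:
    """Eigenvalues (ascending) of I - beta * k k^T for an L2-normalized key k."""
    k = k / k.norm()                                   # unit-norm key
    transition = torch.eye(k.numel(), dtype=k.dtype) - beta * torch.outer(k, k)
    return torch.linalg.eigvalsh(transition)           # matrix is symmetric, so eigvalsh applies


k = torch.randn(4, dtype=torch.float64)
print(transition_eigenvalues(k, beta=0.5))  # one eigenvalue 0.5, the rest 1.0
print(transition_eigenvalues(k, beta=1.5))  # one eigenvalue -0.5, the rest 1.0
```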

### Key Operations
- `chunk_gated_delta_rule`: Chunkwise-parallel kernel, used for training
- `fused_recurrent_gated_delta_rule`: Fused token-by-token recurrent kernel, used for inference
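Both kernels compute the gated delta rule from the paper. Writing the per-head state as $S_t \in \mathbb{R}^{d_v \times d_k}$, with decay $\alpha_t \in (0, 1)$ and step size $\beta_t$, the recurrence is

$$S_t = S_{t-1}\,\alpha_t\left(I - \beta_t k_t k_t^\top\right) + \beta_t v_t k_t^\top, \qquad o_t = S_t q_t.$$

The sketch below is a naive token-by-token PyTorch reference for a single head, useful as a slow baseline when comparing against the `fla` kernels. It is an illustrative assumption rather than the library's code: it omits query scaling, q/k normalization, the multi-head layout, and any alternative gate parameterization (for example in log space) that the real kernels may use. `naive_gated_delta_rule` is a hypothetical name.

```python
import torch


def naive_gated_delta_rule(q, k, v, alpha, beta):
    """Single-head gated delta rule, one token at a time (slow reference sketch).

    q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,).
    Returns outputs of shape (T, d_v) and the final state S of shape (d_v, d_k).
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = torch.zeros(d_v, d_k, dtype=v.dtype, device=v.device)
    eye = torch.eye(d_k, dtype=v.dtype, device=v.device)
    outputs = []
    for t in range(T):
        k_t = k[t]
        # Decay the old state and apply the delta-rule transition I - beta_t k_t k_t^T
        S = alpha[t] * S @ (eye - beta[t] * torch.outer(k_t, k_t))
        # Write the new key-value association with strength beta_t
        S = S + beta[t] * torch.outer(v[t], k_t)
        # Read out with the query
        outputs.append(S @ q[t])
    return torch.stack(outputs), S


T, d_k, d_v = 6, 4, 8
q = torch.randn(T, d_k)
k = torch.nn.functional.normalize(torch.randn(T, d_k), dim=-1)  # unit-norm keys, as in the paper
v = torch.randn(T, d_v)
alpha = torch.rand(T)  # decay in [0, 1)
beta = torch.rand(T)   # step size in [0, 1)
o, S = naive_gated_delta_rule(q, k, v, alpha, beta)
print(o.shape, S.shape)  # torch.Size([6, 8]) torch.Size([8, 4])
```

With $\alpha_t = \beta_t = 1$ and the same unit key written twice, the second write fully replaces the first value ($S_2 k = v_2$), which is the overwrite behavior that distinguishes the delta rule from purely additive linear attention.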