# Long Context & NTK Decay Demo
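This notebook demonstrates:
- Increasing context length (here: filling mini-GPT's full 128-token window)
- A simple per-token "NTK decay" proxy: the mean absolute change in logits between adjacent positions (a cheap heuristic, not an actual Neural Tangent Kernel computation)
- How mini-GPT handles long sequences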

In [None]:
import torch
from gptx.modules.model import GPTModel

# Demo configuration, kept in one place so the random input can never drift
# out of sync with the model it is fed to (e.g. a sequence longer than
# max_seq_len, or token ids outside the vocabulary).
VOCAB_SIZE = 50257  # presumably the GPT-2 BPE vocabulary size — confirm against the tokenizer
MAX_SEQ_LEN = 128   # kept small for the Free-Tier demo
SEED = 0            # fixed so random weights/tokens (and the printed numbers) are reproducible

torch.manual_seed(SEED)

# Initialize mini-GPT
model = GPTModel(
    vocab_size=VOCAB_SIZE,
    d_model=128,
    n_layers=2,
    n_heads=4,
    max_seq_len=MAX_SEQ_LEN,
)
model.eval()

# Generate a context that fills the model's whole window: shape (1, MAX_SEQ_LEN)
long_input = torch.randint(0, VOCAB_SIZE, (1, MAX_SEQ_LEN))

# Forward pass; no autograd graph is needed for inference.
with torch.no_grad():
    logits = model(long_input)

# NOTE(review): the code below assumes model(...) returns a single tensor of
# shape (batch, seq, vocab) — confirm against GPTModel.forward.
print('Logits shape:', logits.shape)

# "NTK decay" proxy: mean absolute change in logits from each position to the
# next (differences along the sequence axis, dim=1). This is a cheap heuristic
# for how much the output varies token-to-token; it does NOT compute a Neural
# Tangent Kernel, which would require gradients w.r.t. model parameters.
# With fewer than 2 positions the diff is empty and the mean would be NaN.
ntk_decay = torch.diff(logits, dim=1).abs().mean()
print(
    'Approx. NTK decay per token (mean |logit diff| between adjacent positions):',
    ntk_decay.item(),
)