### Define dataset class

In [None]:
from datasets import load_dataset
import numpy as np  # or jax.numpy as jnp if needed
import pandas as pd
from pathlib import Path

class TLDRDataset:
    """Indexable dataset of TLDR examples read from a local parquet file.

    Each example is the string ``prompt + label`` taken from one row of the
    file. Indexing tokenizes that string on the fly and returns fixed-length
    NumPy arrays.
    """

    # Maximum number of examples kept when the split name contains "valid".
    VALID_LIMIT = 2000

    def __init__(self, data_dir, tokenizer, split, max_length=550):
        """
        Load TLDR dataset from local parquet files.

        Args:
            data_dir: Path to directory containing parquet files
            tokenizer: Tokenizer to use; called as
                ``tokenizer(text, truncation=..., max_length=..., padding=...)``
                and expected to return a mapping with ``input_ids`` and
                ``attention_mask``
            split: 'train', 'valid', or 'test'; selects ``tldr_{split}.parquet``
            max_length: Maximum sequence length (every item is padded/truncated to it)

        Raises:
            FileNotFoundError: if the parquet file for ``split`` does not exist.
        """
        # Load the parquet file
        parquet_file = Path(data_dir) / f"tldr_{split}.parquet"
        if not parquet_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {parquet_file}")

        df = pd.read_parquet(parquet_file)

        # Combine prompt and label for training (teacher forcing).
        # Iterate the two columns directly: DataFrame.iterrows() builds a
        # Series per row, which is far slower for a plain string concat.
        self.examples = [
            prompt + label for prompt, label in zip(df["prompt"], df["label"])
        ]

        # Limit validation set size for faster iteration
        if "valid" in split:
            self.examples = self.examples[: self.VALID_LIMIT]

        self.tokenizer = tokenizer
        self.max_length = max_length

        print(f"Loaded {len(self.examples)} examples from {parquet_file}")

    def __len__(self):
        """Return the number of examples kept for this split."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return int32 arrays of length ``max_length``.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels``.
            ``labels`` is an independent copy of ``input_ids``.
        """
        # Tokenize the text
        enc = self.tokenizer(
            self.examples[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        input_ids = np.array(enc["input_ids"], dtype=np.int32)
        # NOTE: labels mirror input_ids, so padded positions are included
        # unmasked; any loss computed from them must apply attention_mask.
        return {
            "input_ids": input_ids,
            "attention_mask": np.array(enc["attention_mask"], dtype=np.int32),
            "labels": input_ids.copy(),  # teacher forcing
        }


### Load model and tokeniser
Stick with GPT-2 for now for compatibility, then move to Qwen. GPT-2 has 124M parameters (~500 MB); Qwen 0.6B has ~550M parameters (~2 GB). Beware: keeping three copies (3×) of Qwen 0.6B on my 8 GB GPU might exhaust its memory.

In [2]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# Checkpoint used for both the tokenizer and the model; change it in one place
# when switching to another checkpoint.
MODEL_NAME = "gpt2"

# 1. Tokenizer is identical
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to EOS as the pad token only when the tokenizer has none of its
# own, so a checkpoint that ships a real pad token keeps it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Flax (JAX) model
#    .from_pretrained returns a FlaxAutoModelForCausalLM whose weights live in model.params
model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=jnp.float32)

# 3. If you’ve added new tokens, resize just like in PyTorch:
#    model = model.resize_token_embeddings(len(tokenizer))
#    NOTE(review): confirm the Flax model class actually provides
#    resize_token_embeddings before relying on this.

# 4. Make sure padding is configured (keep the model config in sync with the
#    tokenizer's pad token, whichever one it ended up with)
model.config.pad_token_id = tokenizer.pad_token_id

# 5. Pull out the parameter dict for training
params = model.params

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


### Inspect model 

In [3]:
# Inspect JAX/Flax model architecture and parameters
import jax
from jax.tree_util import tree_map
from flax.core import freeze

# Section banner for the inspection report printed by this cell.
print("🔍 JAX/Flax Model Inspection")
print("=" * 50)

# 1. Basic model info
# Everything here is read from the already-loaded `model` and its config;
# printing the full config also dumps every hyperparameter it holds.
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")
print(f"Vocab size: {model.config.vocab_size}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Number of attention heads: {model.config.n_head}")

# Banner for the parameter statistics that follow.
print("\n📊 Parameter Analysis")
print("=" * 30)

# 2. Count parameters (JAX way)
def count_params(params):
    """Return the total number of scalar entries across all leaves of a parameter tree."""
    n_elements = 0
    for leaf in jax.tree_util.tree_leaves(params):
        n_elements += leaf.size
    return n_elements

# Count every element in the parameter tree pulled from the model earlier.
total_params = count_params(params)
print(f"Total parameters: {total_params:,}")  # thousands separators for readability
print(f"Total parameters (millions): {total_params / 1_000_000:.2f}M")

# 3. Inspect parameter structure
print("\n🏗️ Parameter Structure")
print("=" * 25)

def print_param_shapes(params, prefix=""):
    """Recursively print the shape and dtype of every leaf in a parameter tree.

    Args:
        params: Either a mapping (plain ``dict`` or any other
            ``collections.abc.Mapping``, e.g. an immutable parameter
            container) whose values are sub-trees, or a leaf exposing
            ``.shape`` and ``.dtype``.
        prefix: Dotted path of the keys walked so far; empty at the root.

    Prints one line per leaf in the form ``path.to.leaf: <shape> (<dtype>)``.
    """
    from collections.abc import Mapping

    # Check against Mapping rather than dict so that non-dict mapping
    # containers are descended into instead of being treated as a leaf.
    if isinstance(params, Mapping):
        for key, value in params.items():
            print_param_shapes(value, f"{prefix}.{key}" if prefix else key)
    else:
        print(f"{prefix}: {params.shape} ({params.dtype})")

# Walk the whole parameter tree from the root; prints one line per leaf array.
print_param_shapes(params)

# 4. Memory usage estimation
def estimate_memory(params):
    """Return the combined byte size of all leaves in a parameter tree, in MB (1 MB = 1024**2 bytes)."""
    bytes_per_mb = 1024 * 1024
    leaf_sizes = [leaf.nbytes for leaf in jax.tree_util.tree_leaves(params)]
    return sum(leaf_sizes) / bytes_per_mb

# Size of the parameters as currently stored, summed from each leaf's nbytes.
memory_mb = estimate_memory(params)
print(f"\n💾 Estimated memory usage: {memory_mb:.2f} MB")

# 5. Compare with different dtypes
# The measured figure is reported as the float32 number and halved for the
# 16-bit formats; this assumes the params are stored as float32 — the dtypes
# printed in the parameter structure listing should confirm that.
print(f"\n🔢 Memory usage by dtype:")
print(f"  float32: {memory_mb:.2f} MB")
print(f"  float16: {memory_mb / 2:.2f} MB") 
print(f"  bfloat16: {memory_mb / 2:.2f} MB")

🔍 JAX/Flax Model Inspection
Model type: <class 'transformers.models.gpt2.modeling_flax_gpt2.FlaxGPT2LMHeadModel'>
Model config: GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "pad_token_id": 50256,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab

### Load the datasets and preview a sample

In [8]:
# Directory holding the tldr_<split>.parquet files, relative to this notebook.
DATA_DIR = "../data"

# Build the training and validation datasets with the shared tokenizer.
train_dataset = TLDRDataset(DATA_DIR, tokenizer, split="train")
dev_dataset   = TLDRDataset(DATA_DIR, tokenizer, split="valid")

print("📊 Dataset loaded:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(dev_dataset)}")

# Preview a sample
# Indexing tokenizes the example, so the shapes below reflect the padded length.
sample = train_dataset[0]
print("\n📝 Sample data shapes:")
print(f"  input_ids: {sample['input_ids'].shape}")
print(f"  attention_mask: {sample['attention_mask'].shape}")
print(f"  labels: {sample['labels'].shape}")

# Preview actual content
print("\n📋 Sample content:")
# NOTE: despite the label, this prints the first 100 characters of the raw
# example string (before tokenization), not of the token ids.
print(f"  First 100 chars of tokenized text: {train_dataset.examples[0][:100]}...")
print(f"  Input IDs (first 10): {sample['input_ids'][:10]}")
print(f"  Labels (first 10): {sample['labels'][:10]}")


Loaded 116722 examples from ../data/tldr_train.parquet
Loaded 2000 examples from ../data/tldr_valid.parquet
📊 Dataset loaded:
  Training samples: 116722
  Validation samples: 2000

📝 Sample data shapes:
  input_ids: (550,)
  attention_mask: (550,)
  labels: (550,)

📋 Sample content:
  First 100 chars of tokenized text: SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or...
  Input IDs (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]
  Labels (first 10): [   50 10526 22083 49828    25   374    14 39468  5748   198]
