### Define dataset class

In [1]:
from datasets import load_dataset
import numpy as np  # or jax.numpy as jnp if needed

class TLDRDataset:
  """Prompt+summary examples from a Hugging Face dataset, tokenised on access.

  Each example is the string ``sample["prompt"] + sample["label"]`` for one row
  of the loaded split. ``__getitem__`` tokenises it (truncating/padding to
  ``max_length``) and returns NumPy ``int32`` arrays that can be fed to JAX.

  Args:
    train_path: Dataset name or path, passed to ``datasets.load_dataset``.
    tokenizer: Callable used as ``tokenizer(text, truncation=True,
      max_length=..., padding="max_length")`` that returns ``"input_ids"`` and
      ``"attention_mask"``. It needs a pad token because every item is padded.
    split: Split name, passed to ``load_dataset``. If it contains ``"valid"``,
      only the first ``max_valid_examples`` examples are kept.
    max_length: Length every tokenised example is truncated/padded to.
    max_valid_examples: Cap on the number of examples kept for a ``"valid"``
      split (first ones in dataset order). ``None`` keeps them all. Defaults to
      2000.
    label_pad_token_id: If not ``None``, positions of ``labels`` where
      ``attention_mask == 0`` (padding) are overwritten with this id, e.g.
      ``-100`` for a loss that ignores that id. The default ``None`` leaves the
      padding tokens in ``labels``; the loss must then be masked with
      ``attention_mask`` (with pad == EOS, unmasked padding trains the model to
      keep emitting EOS).
  """

  def __init__(
    self,
    train_path,
    tokenizer,
    split,
    max_length=550,
    max_valid_examples=2000,
    label_pad_token_id=None,
  ):
    dataset = load_dataset(train_path, split=split)
    self.examples = [sample["prompt"] + sample["label"] for sample in dataset]
    # Validation is subsampled, presumably to keep evaluation cheap.
    if "valid" in split and max_valid_examples is not None:
      self.examples = self.examples[:max_valid_examples]
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.label_pad_token_id = label_pad_token_id

  def __len__(self):
    """Return the number of (possibly capped) examples."""
    return len(self.examples)

  def __getitem__(self, idx):
    """Tokenise example ``idx`` and return model inputs plus LM labels.

    Returns:
      Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``, each an
      ``np.int32`` array of length ``max_length`` (for a tokenizer that honours
      ``padding="max_length"``).
    """
    enc = self.tokenizer(
      self.examples[idx],
      truncation=True,
      max_length=self.max_length,
      padding="max_length",
    )
    input_ids = np.array(enc["input_ids"], dtype=np.int32)
    attention_mask = np.array(enc["attention_mask"], dtype=np.int32)
    # Teacher forcing: labels start as a copy of the inputs (a copy, not an
    # alias, so masking below never touches input_ids).
    # NOTE(review): the loss presumably applies the shift-by-one — confirm.
    labels = input_ids.copy()
    if self.label_pad_token_id is not None:
      labels[attention_mask == 0] = self.label_pad_token_id
    return {
      "input_ids": input_ids,
      "attention_mask": attention_mask,
      "labels": labels,
    }


  from .autonotebook import tqdm as notebook_tqdm


### Load model and tokeniser
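Stick with GPT-2 for now for compatibility, then move to Qwen. GPT-2 has ~124M parameters (~500 MB in fp32); Qwen 0.6B has ~550M parameters (~2 GB in fp32). Beware: three copies of Qwen 0.6B would need ~6 GB for the weights alone, before activations and optimizer state, which may exhaust my 8 GB GPU.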

In [2]:
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# Single source of truth for the checkpoint, so the tokenizer and the model
# cannot drift apart when switching models (e.g. gpt2 -> a Qwen checkpoint).
# NOTE(review): before switching, confirm that transformers ships a Flax
# implementation for the target architecture; FlaxAutoModelForCausalLM can only
# load architectures that have Flax model classes.
MODEL_NAME = "gpt2"

# 1. Tokenizer: loaded exactly as in the PyTorch workflow.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The dataset pads every example to max_length, which requires a pad token;
# reuse EOS as the pad token.
tokenizer.pad_token = tokenizer.eos_token

# 2. Load the Flax (JAX) model. .from_pretrained returns a Flax causal-LM
#    model whose weights live in model.params.
model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=jnp.float32)

# 3. If new tokens are added to the tokenizer, the embeddings must be resized:
#    model = model.resize_token_embeddings(len(tokenizer))
#    NOTE(review): confirm the Flax model class supports this method; if not,
#    the embedding entries in model.params must be resized manually.

# 4. Keep the model's pad id in sync with the tokenizer's (== EOS id here).
model.config.pad_token_id = tokenizer.pad_token_id

# 5. Pull out the parameter dict for training.
params = model.params


TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.
