# Test `scan` and `apply_layers` in Llama 3

Hugging Face usage follows https://github.com/huggingface/notebooks/blob/main/examples/language_modeling.ipynb

To test scan, we need to use a custom modification of the transformer repo:
https://github.com/tengyifei/transformers/commit/646a575928d8514f220384c29d27c8b956826a91

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1


In [2]:
import torch
import torch_xla
import numpy as np
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def reset_seed():
  torch.random.manual_seed(42)
  torch_xla.manual_seed(42)
  np.random.seed(42)

In [4]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=1000)

Map: 100%|██████████| 4358/4358 [00:00<00:00, 5492.63 examples/s]
Map: 100%|██████████| 1801350/1801350 [02:28<00:00, 12148.35 examples/s]
Map: 100%|██████████| 3760/3760 [00:00<00:00, 10411.03 examples/s]


In [6]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [7]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [8]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
)

Map: 100%|██████████| 4358/4358 [00:02<00:00, 1820.75 examples/s]
Map: 100%|██████████| 1801350/1801350 [17:01<00:00, 1762.87 examples/s]
Map: 100%|██████████| 3760/3760 [00:02<00:00, 1816.89 examples/s]


In [9]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [10]:
len(lm_datasets["validation"])  # type:ignore

3760

In [11]:
from transformers import LlamaConfig, LlamaForCausalLM

reset_seed()

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    max_steps=2500,
    save_strategy="no",
    save_total_limit=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    tpu_num_cores=4,
    push_to_hub=False,
)

2024-11-08 07:27:01.068015: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-08 07:27:01.068116: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-08 07:27:01.069501: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    processing_class=tokenizer,
)

NameError: name 'model' is not defined

In [14]:
trainer.train()

NOTE: Using for loop to run decoder layers


Epoch,Training Loss,Validation Loss
0,6.2021,6.074296


  xldata.append(torch.load(xbio))


TrainOutput(global_step=2500, training_loss=6.7848349609375, metrics={'train_runtime': 2166.4797, 'train_samples_per_second': 73.853, 'train_steps_per_second': 1.154, 'total_flos': 1.83650746368e+16, 'train_loss': 6.7848349609375, 'epoch': 0.08881941237076775})

In [15]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 2628
Metric: CompileTime
  TotalSamples: 68
  Accumulator: 04m45s661ms943.517us
  ValueRate: 112ms992.609us / second
  Rate: 0.0338977 / second
  Percentiles: 1%=027ms010.525us; 5%=027ms205.166us; 10%=027ms374.196us; 20%=028ms867.906us; 50%=029ms050.266us; 80%=031ms631.485us; 90%=051ms153.823us; 95%=12s953ms540.568us; 99%=01m12s953ms510.969us
Metric: ExecuteReplicatedTime
  TotalSamples: 2696
  Accumulator: 04m40s292ms833.041us
  ValueRate: 135ms694.627us / second
  Rate: 1.81133 / second
  Percentiles: 1%=002ms342.130us; 5%=003ms790.429us; 10%=022ms129.746us; 20%=086ms793.757us; 50%=086ms367.937us; 80%=087ms787.907us; 90%=087ms078.627us; 95%=087ms355.897us; 99%=089ms792.497us
Metric: TransferToDeviceTime
  TotalSamples: 13063
  Accumulator: 01s116ms795.332us
  ValueRate: 753.166us / second
  Rate: 9.17877 / second
  Percentiles: 1%=042.330us; 5%=045.490us; 10%=047.450us; 20%=050.260us; 50%=066.650us; 80%=129.369us; 90%=153.390us; 95%=158.160us; 99%=180.

## Train again, this time using scan

In [11]:
from transformers import LlamaConfig, LlamaForCausalLM

reset_seed()

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    processing_class=tokenizer,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


NOTE: Using apply_layers to speed up compilation


AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.unsqueeze.default(tensor([...], device='xla:0', size=(1, 128, 64)), 1)

We can see that the for loop and scan model train to the same loss over time.

In [None]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 2689
Metric: CompileTime
  TotalSamples: 7
  Accumulator: 03m53s782ms040.051us
  ValueRate: 055ms660.536us / second
  Rate: 0.00221449 / second
  Percentiles: 1%=07s096ms216.112us; 5%=07s096ms216.112us; 10%=07s096ms216.112us; 20%=08s578ms547.271us; 50%=32s036ms769.307us; 80%=33s879ms083.254us; 90%=36s811ms836.505us; 95%=36s811ms836.505us; 99%=36s811ms836.505us
Metric: ExecuteReplicatedTime
  TotalSamples: 2696
  Accumulator: 06m36s541ms094.942us
  ValueRate: 104ms571.321us / second
  Rate: 0.918485 / second
  Percentiles: 1%=009ms357.388us; 5%=021ms740.458us; 10%=024ms257.247us; 20%=131ms199.266us; 50%=132ms933.316us; 80%=132ms433.307us; 90%=133ms727.416us; 95%=133ms989.236us; 99%=135ms132.886us
Metric: TransferToDeviceTime
  TotalSamples: 78760
  Accumulator: 07s139ms091.466us
  ValueRate: 004ms247.511us / second
  Rate: 41.5308 / second
  Percentiles: 1%=053.389us; 5%=059.270us; 10%=065.400us; 20%=079.460us; 50%=094.330us; 80%=115.780us; 90%=161.000us;

## Verify the numerical correctness of `apply_layers`

Under the same weights, and the same input tokens, both the for loop based
implementation and `apply_layers` based implementation should produce the same
output tokens.

In [None]:
import torch_xla
input_ids = torch.tensor(tokenized_datasets["train"][3]["input_ids"]).unsqueeze(0).type(torch.LongTensor) # type:ignore
attention_mask = torch.tensor(tokenized_datasets["train"][3]["attention_mask"]).unsqueeze(0) # type:ignore
input_ids = input_ids.to(torch_xla.device())
attention_mask = attention_mask.to(torch_xla.device())
torch_xla.sync()

In [None]:
input_ids

tensor([[128000,   5476,     73,  56761,    912,  86262,     88,   4298,    220,
             18,    551,    366,   3200,     29,  66416,    320,  11002,    551,
          50534,     99,  75267,  16144, 115687,  33710, 123283, 104612,     18,
           1174,  13318,    662,  86262,     88,   4298,    315,    279,  71735,
            220,     18,    883,   1174,  17037,  14183,    311,    439,  86262,
             88,   4298,  66416,  14767,   4994,   6457,   1174,    374,    264,
          39747,   3560,    571,     12,     31,   5737,   2835,   1847,   8040,
            555,  80949,    323,   7972,   5168,   1854,    369,    279,  32365,
          42585,    662,  45894,    304,   6186,    220,    679,     16,    304,
           6457,   1174,    433,    374,    279,   4948,   1847,    304,    279,
          86262,     88,   4298,   4101,    662,  21445,    287,    279,   1890,
          37608,    315,  39747,    323,   1972,    571,     12,     31,    892,
          27120,    439,   1

In [None]:
model.model.unroll_decoders = False
model.model.logged_messages = set()
logits = model.forward(input_ids, attention_mask).logits  # type:ignore
logits.shape, logits

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


NOTE: Using apply_layers to speed up compilation


(torch.Size([1, 128, 128000]),
 tensor([[[-3.5386, -3.6623, -3.4869,  ..., -3.0051, -4.4353, -4.4233],
          [-4.9098, -4.5080, -4.4443,  ..., -4.2112, -5.0679, -4.4701],
          [-4.6143, -4.4359, -4.2461,  ..., -4.3088, -4.6903, -4.2176],
          ...,
          [-4.7349, -4.1849, -4.3566,  ..., -4.7857, -5.0580, -4.8042],
          [-4.4976, -4.0455, -4.2522,  ..., -4.6340, -4.9632, -4.8398],
          [-4.8248, -5.4161, -4.9124,  ..., -4.9636, -5.9460, -5.1975]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [None]:
def pick_token(logits):
  return torch.argmax(logits, dim=-1)

In [None]:
tokens = pick_token(logits)
tokens

tensor([[ 284,  315, 1174,  574,   13,  574,  574, 1174,   16, 1389,  220, 3200,
           29,  720,  720,  220,  883,  220,  883,  883,  883,  883,  883,  883,
          883,  883,  883,  220,  883,  720, 1174, 1174,  883,  279,  220,  315,
         1049, 1389,  720,  323,  220,  311,  279,  264,  311,  662,  662,  662,
          662,  279,  662,  323,  264,  220,  571,  315,   12,   31,  220,  571,
          662,  662,  555,  279,  571,  279,  662,  366,  662,  279,  220,  220,
          662,  720, 1174,  279,  220, 1049,   15, 1174,  279, 1174,  279,  574,
          264, 1176,  571, 1174,  279,  220,  315,  662,  662,  662,  720,  374,
          304, 1176,  220,  315,  279,  571,  279, 1174,   12,   31,  220, 1174,
          662,  264, 1176,  662,  323, 1176,  315,  304,  311,  279,  220,  571,
          662,  279,  279, 1176,  366,  330,  330,  662]], device='xla:0')

In [None]:
tokenizer.decode(tokens[0].detach().cpu().numpy().tolist())

' = of, was. was was,1 – unk> \n \n  )  ) ) ) ) ) ) ) ) )  ) \n,, ) the  of200 – \n and  to the a to.... the. and a  @ of-@  @.. by the @ the. <. the  . \n, the 2000, the, the was a first @, the  of... \n is in first  of the @ the,-@ ,. a first. and first of in to the  @. the the first < " ".'

In [None]:
model.model.unroll_decoders = True
model.model.logged_messages = set()
for_loop_logits = model.forward(input_ids, attention_mask).logits  # type:ignore
for_loop_logits.shape, for_loop_logits

NOTE: Using for loop to run decoder layers


(torch.Size([1, 128, 128000]),
 tensor([[[-3.5386, -3.6623, -3.4869,  ..., -3.0051, -4.4353, -4.4233],
          [-4.9098, -4.5080, -4.4443,  ..., -4.2112, -5.0679, -4.4701],
          [-4.6143, -4.4359, -4.2461,  ..., -4.3088, -4.6903, -4.2176],
          ...,
          [-4.7349, -4.1843, -4.3579,  ..., -4.7867, -5.0578, -4.8052],
          [-4.4990, -4.0472, -4.2530,  ..., -4.6336, -4.9631, -4.8392],
          [-4.8254, -5.4172, -4.9121,  ..., -4.9634, -5.9421, -5.1983]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [None]:
for_loop_tokens = pick_token(for_loop_logits)
for_loop_tokens

tensor([[ 284,  315, 1174,  574,   13,  574,  574, 1174,   16, 1389,  220, 3200,
           29,  720,  720,  220,  883,  220,  883,  883,  883,  883,  883,  883,
          883,  883,  883,  220,  883,  720, 1174, 1174,  883,  279,  220,  315,
         1049, 1389,  720,  323,  220,  311,  279,  264,  311,  662,  662,  662,
          662,  279,  662,  323,  264,  220,  571,  315,   12,   31,  220,  571,
          662,  662,  555,  279,  571,  279,  662,  366,  662,  279,  220,  220,
          662,  720, 1174,  279,  220, 1049,   15, 1174,  279, 1174,  279,  574,
          264, 1176,  571, 1174,  279,  220,  315,  662,  662,  662,  720,  374,
          304, 1176,  220,  315,  279,  571,  279, 1174,   12,   31,  220, 1174,
          662,  264, 1176,  662,  323, 1176,  315,  304,  311,  279,  220,  571,
          662,  279,  279, 1176,  366,  330,  330,  662]], device='xla:0')

In [None]:
tokenizer.decode(for_loop_tokens[0].detach().cpu().numpy().tolist())

' = of, was. was was,1 – unk> \n \n  )  ) ) ) ) ) ) ) ) )  ) \n,, ) the  of200 – \n and  to the a to.... the. and a  @ of-@  @.. by the @ the. <. the  . \n, the 2000, the, the was a first @, the  of... \n is in first  of the @ the,-@ ,. a first. and first of in to the  @. the the first < " ".'

In [None]:
# Should be accurate to within 1%
torch.allclose(logits, for_loop_logits, atol=1e-2, rtol=1e-2)

True

In [None]:
# Shouldn't be completely the same
torch.allclose(logits, for_loop_logits, atol=1e-6, rtol=1e-6)

False

# Test the gradients of scan

After I run both scan and for loop versions of the model on the same input, their
gradients should also be similar.

In [None]:
torch_xla.sync()

input_ids.requires_grad_(False)
attention_mask.requires_grad_(False)
labels = input_ids[:, :].clone().contiguous()

# Run for loop model and collect the gradients
reset_seed()
for_loop_grads = []
model.zero_grad()
model.model.zero_grad()
torch_xla.sync()
model.model.unroll_decoders = True
model.model.logged_messages = set()
with torch.enable_grad():
  model(input_ids, attention_mask, labels=labels).loss.backward()  # type: ignore
torch_xla.sync()
for (name, param) in model.named_parameters():
  assert param.grad is not None
  for_loop_grads.append((name, param.grad.clone().detach()))

# Run scan model and collect the gradients
reset_seed()
scan_grads = []
model.zero_grad()
model.model.zero_grad()
torch_xla.sync()
model.model.unroll_decoders = False
model.model.logged_messages = set()
with torch.enable_grad():
  model(input_ids, attention_mask, labels=labels).loss.backward()  # type: ignore
torch_xla.sync()
for (name, param) in model.named_parameters():
  assert param.grad is not None
  scan_grads.append((name, param.grad.clone().detach()))

NOTE: Using for loop to run decoder layers
NOTE: Using apply_layers to speed up compilation


In [None]:
# Compare the gradients
assert len(for_loop_grads) == len(scan_grads)
assert len(for_loop_grads) > 0

In [None]:
for ((for_loop_name, for_loop_grad), (scan_name, scan_grad)) in zip(for_loop_grads, scan_grads):
  assert for_loop_name == scan_name
  assert torch.allclose(for_loop_grad, scan_grad, atol=1e-3, rtol=1e-3), f"{for_loop_name} mismatch by: {torch.max(torch.abs(for_loop_grad - scan_grad))}"
  print(f"Pass: {for_loop_name}")

Pass: model.embed_tokens.weight
Pass: model.layers.0.self_attn.q_proj.weight
Pass: model.layers.0.self_attn.k_proj.weight
Pass: model.layers.0.self_attn.v_proj.weight
Pass: model.layers.0.self_attn.o_proj.weight
Pass: model.layers.0.mlp.gate_proj.weight
Pass: model.layers.0.mlp.up_proj.weight
Pass: model.layers.0.mlp.down_proj.weight
Pass: model.layers.0.input_layernorm.weight
Pass: model.layers.0.post_attention_layernorm.weight
Pass: model.layers.1.self_attn.q_proj.weight
Pass: model.layers.1.self_attn.k_proj.weight
Pass: model.layers.1.self_attn.v_proj.weight
Pass: model.layers.1.self_attn.o_proj.weight
Pass: model.layers.1.mlp.gate_proj.weight
Pass: model.layers.1.mlp.up_proj.weight
Pass: model.layers.1.mlp.down_proj.weight
Pass: model.layers.1.input_layernorm.weight
Pass: model.layers.1.post_attention_layernorm.weight
Pass: model.layers.2.self_attn.q_proj.weight
Pass: model.layers.2.self_attn.k_proj.weight
Pass: model.layers.2.self_attn.v_proj.weight
Pass: model.layers.2.self_attn.

In [None]:
for_loop_grads[3], torch.max(for_loop_grads[3][1]), torch.min(for_loop_grads[3][1])

(('model.layers.0.self_attn.v_proj.weight',
  tensor([[ 1.7305e-03, -1.6064e-03, -2.2568e-03,  ...,  2.7648e-03,
           -9.1544e-04, -2.4204e-04],
          [ 2.0572e-03,  4.0270e-04,  1.7538e-03,  ..., -8.5519e-04,
            1.0854e-03, -8.3036e-04],
          [ 2.5064e-03,  1.0427e-03, -5.9767e-04,  ...,  1.7253e-03,
            3.0151e-03,  5.6954e-03],
          ...,
          [-7.3633e-03,  4.2043e-03,  7.1877e-04,  ..., -5.1985e-03,
           -2.1175e-03,  1.1589e-04],
          [-1.7412e-03, -5.6326e-03, -4.5214e-03,  ...,  6.4733e-03,
           -1.0346e-03,  3.4541e-03],
          [ 7.2834e-04, -2.6041e-03, -9.6229e-05,  ..., -1.0209e-03,
           -2.4465e-03, -3.2896e-03]], device='xla:0')),
 tensor(0.0273, device='xla:0'),
 tensor(-0.0253, device='xla:0'))

In [None]:
scan_grads[3], torch.max(scan_grads[3][1]), torch.min(scan_grads[3][1])

(('model.layers.0.self_attn.v_proj.weight',
  tensor([[ 0.0017, -0.0016, -0.0023,  ...,  0.0028, -0.0009, -0.0002],
          [ 0.0021,  0.0004,  0.0017,  ..., -0.0008,  0.0011, -0.0008],
          [ 0.0025,  0.0010, -0.0006,  ...,  0.0017,  0.0030,  0.0057],
          ...,
          [-0.0073,  0.0042,  0.0007,  ..., -0.0052, -0.0021,  0.0001],
          [-0.0018, -0.0056, -0.0045,  ...,  0.0065, -0.0010,  0.0035],
          [ 0.0007, -0.0026, -0.0001,  ..., -0.0010, -0.0025, -0.0033]],
         device='xla:0')),
 tensor(0.0273, device='xla:0'),
 tensor(-0.0253, device='xla:0'))