# Test `scan` and `apply_layers` in Llama 3

Hugging Face usage follows https://github.com/huggingface/notebooks/blob/main/examples/language_modeling.ipynb

To test scan, we need to use a custom modification of the transformer repo:
https://github.com/tengyifei/transformers/commit/646a575928d8514f220384c29d27c8b956826a91

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1


In [2]:
import torch
import torch_xla

In [3]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=1000)

In [5]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [6]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [7]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
)

In [8]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [9]:
len(lm_datasets["validation"])  # type:ignore

3760

In [10]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=48,
    per_device_eval_batch_size=48,
    num_train_epochs=1,
    max_steps=250,
    save_strategy="no",
    save_total_limit=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    tpu_num_cores=4,
    push_to_hub=False,
)



In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

max_steps is given, it will override any value given in num_train_epochs


In [13]:
trainer.train()

NOTE: Using for loop to run decoder layers


Epoch,Training Loss,Validation Loss
0,9.3223,9.280924


  xldata.append(torch.load(xbio))


TrainOutput(global_step=250, training_loss=9.857639526367187, metrics={'train_runtime': 494.3655, 'train_samples_per_second': 24.274, 'train_steps_per_second': 0.506, 'total_flos': 1377380597760000.0, 'train_loss': 9.857639526367187, 'epoch': 0.006661515094993205})

In [14]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 418
Metric: CompileTime
  TotalSamples: 88
  Accumulator: 04m08s065ms120.553us
  ValueRate: 585ms226.465us / second
  Rate: 0.207606 / second
  Percentiles: 1%=025ms831.227us; 5%=025ms416.577us; 10%=026ms513.888us; 20%=026ms870.888us; 50%=027ms288.307us; 80%=029ms990.057us; 90%=046ms746.155us; 95%=09s818ms285.753us; 99%=01m21s677ms248.163us
Metric: ExecuteReplicatedTime
  TotalSamples: 506
  Accumulator: 21s165ms334.633us
  ValueRate: 050ms997.647us / second
  Rate: 1.19529 / second
  Percentiles: 1%=970.529us; 5%=003ms791.809us; 10%=003ms069.959us; 20%=016ms755.449us; 50%=046ms001.205us; 80%=067ms610.363us; 90%=067ms895.353us; 95%=067ms195.213us; 99%=078ms117.212us
Metric: TransferToDeviceTime
  TotalSamples: 1879
  Accumulator: 165ms703.039us
  ValueRate: 359.445us / second
  Rate: 4.22028 / second
  Percentiles: 1%=044.170us; 5%=048.610us; 10%=051.500us; 20%=056.890us; 50%=074.190us; 80%=111.590us; 90%=143.580us; 95%=152.080us; 99%=184.720us
Metric: T

## Train again, this time using scan

In [15]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [16]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


NOTE: Using apply_layers to speed up compilation


Epoch,Training Loss,Validation Loss
0,9.322,9.292122


TrainOutput(global_step=250, training_loss=9.864378784179687, metrics={'train_runtime': 389.1826, 'train_samples_per_second': 30.834, 'train_steps_per_second': 0.642, 'total_flos': 1377380597760000.0, 'train_loss': 9.864378784179687, 'epoch': 0.006661515094993205})

In [17]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 501
Metric: CompileTime
  TotalSamples: 5
  Accumulator: 02m51s682ms544.317us
  ValueRate: 308ms604.997us / second
  Rate: 0.0138959 / second
  Percentiles: 1%=06s818ms504.432us; 5%=06s818ms504.432us; 10%=06s818ms504.432us; 20%=07s942ms549.106us; 50%=26s231ms246.452us; 80%=37s972ms181.958us; 90%=37s972ms181.958us; 95%=37s972ms181.958us; 99%=37s972ms181.958us
Metric: ExecuteReplicatedTime
  TotalSamples: 506
  Accumulator: 30s212ms458.665us
  ValueRate: 084ms995.421us / second
  Rate: 1.40676 / second
  Percentiles: 1%=976.630us; 5%=013ms494.628us; 10%=014ms839.659us; 20%=017ms191.778us; 50%=076ms731.302us; 80%=101ms265.730us; 90%=102ms558.240us; 95%=102ms812.550us; 99%=107ms823.029us
Metric: TransferToDeviceTime
  TotalSamples: 9316
  Accumulator: 856ms251.618us
  ValueRate: 011ms750.635us / second
  Rate: 110.407 / second
  Percentiles: 1%=049.060us; 5%=055.760us; 10%=068.300us; 20%=079.510us; 50%=089.610us; 80%=114.160us; 90%=149.000us; 95%=164.300us; 

## Verify the numerical correctness of `apply_layers`

Under the same weights, and the same input tokens, both the for loop based
implementation and `apply_layers` based implementation should produce the same
output tokens.

In [18]:
import torch_xla
input_ids = torch.tensor(tokenized_datasets["train"][3]["input_ids"]).unsqueeze(0).type(torch.LongTensor) # type:ignore
attention_mask = torch.tensor(tokenized_datasets["train"][3]["attention_mask"]).unsqueeze(0) # type:ignore
input_ids = input_ids.to(torch_xla.device())
attention_mask = attention_mask.to(torch_xla.device())
torch_xla.sync()

In [19]:
input_ids

tensor([[128000,   5476,     73,  56761,    912,  86262,     88,   4298,    220,
             18,    551,    366,   3200,     29,  66416,    320,  11002,    551,
          50534,     99,  75267,  16144, 115687,  33710, 123283, 104612,     18,
           1174,  13318,    662,  86262,     88,   4298,    315,    279,  71735,
            220,     18,    883,   1174,  17037,  14183,    311,    439,  86262,
             88,   4298,  66416,  14767,   4994,   6457,   1174,    374,    264,
          39747,   3560,    571,     12,     31,   5737,   2835,   1847,   8040,
            555,  80949,    323,   7972,   5168,   1854,    369,    279,  32365,
          42585,    662,  45894,    304,   6186,    220,    679,     16,    304,
           6457,   1174,    433,    374,    279,   4948,   1847,    304,    279,
          86262,     88,   4298,   4101,    662,  21445,    287,    279,   1890,
          37608,    315,  39747,    323,   1972,    571,     12,     31,    892,
          27120,    439,   1

In [20]:
model.model.unroll_decoders = False
model.model.logged_messages = set()
logits = model.forward(input_ids, attention_mask).logits  # type:ignore
logits.shape, logits

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


NOTE: Using apply_layers to speed up compilation


(torch.Size([1, 128, 128000]),
 tensor([[[-0.5572, -0.6444, -0.1399,  ..., -0.1027, -1.0320, -1.4400],
          [-0.9679, -1.3558, -0.9182,  ..., -0.6744, -2.1547, -1.2086],
          [-0.9505, -1.4370, -0.9905,  ..., -0.7459, -2.1894, -1.0792],
          ...,
          [-0.9731, -1.5157, -1.0472,  ..., -0.8122, -2.0453, -0.9087],
          [-0.9745, -1.5194, -1.0236,  ..., -0.8299, -2.0414, -0.9306],
          [-0.9767, -1.5013, -1.0378,  ..., -0.8184, -2.0358, -0.9563]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [21]:
def pick_token(logits):
  return torch.argmax(logits, dim=-1)

In [22]:
tokens = pick_token(logits)
tokens

tensor([[ 284, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174]], device='xla:0')

In [23]:
tokenizer.decode(tokens[0].detach().cpu().numpy().tolist())

' =,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'

In [24]:
model.model.unroll_decoders = True
model.model.logged_messages = set()
for_loop_logits = model.forward(input_ids, attention_mask).logits  # type:ignore
for_loop_logits.shape, for_loop_logits

NOTE: Using for loop to run decoder layers


(torch.Size([1, 128, 128000]),
 tensor([[[-0.5572, -0.6444, -0.1399,  ..., -0.1027, -1.0320, -1.4400],
          [-0.9679, -1.3558, -0.9182,  ..., -0.6744, -2.1547, -1.2086],
          [-0.9505, -1.4370, -0.9905,  ..., -0.7459, -2.1894, -1.0792],
          ...,
          [-0.9729, -1.5153, -1.0475,  ..., -0.8124, -2.0456, -0.9089],
          [-0.9745, -1.5185, -1.0242,  ..., -0.8295, -2.0408, -0.9304],
          [-0.9772, -1.5007, -1.0376,  ..., -0.8187, -2.0361, -0.9560]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [25]:
for_loop_tokens = pick_token(for_loop_logits)
for_loop_tokens

tensor([[ 284, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174]], device='xla:0')

In [26]:
tokenizer.decode(for_loop_tokens[0].detach().cpu().numpy().tolist())

' =,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'

In [27]:
# Should be accurate to within 1%
torch.allclose(logits, for_loop_logits, atol=1e-2, rtol=1e-2)

True

In [28]:
# Shouldn't be completely the same
torch.allclose(logits, for_loop_logits, atol=1e-6, rtol=1e-6)

False

# Test the gradients of scan

After I run both scan and for loop versions of the model on the same input, their
gradients should also be similar.

In [52]:
torch_xla.sync()

input_ids.requires_grad_(False)
attention_mask.requires_grad_(False)

# Run for loop model and collect the gradients
torch.manual_seed(42)
for_loop_grads = []
model.zero_grad()
model.model.zero_grad()
torch_xla.sync()
model.model.unroll_decoders = True
model.model.logged_messages = set()
with torch.enable_grad():
  for_loop_logits = model(input_ids, attention_mask).logits  # type: ignore
torch.sum(for_loop_logits).backward()
torch_xla.sync()
for (name, param) in model.named_parameters():
  # print(name, param.shape, param.grad.shape if param.grad is not None else None)
  assert param.grad is not None
  for_loop_grads.append((name, param.grad.clone().detach()))

# Run scan model and collect the gradients
torch.manual_seed(42)
scan_grads = []
model.zero_grad()
model.model.zero_grad()
torch_xla.sync()
model.model.unroll_decoders = False
model.model.logged_messages = set()
with torch.enable_grad():
  scan_logits = model(input_ids, attention_mask).logits  # type: ignore
torch.sum(scan_logits).backward()
torch_xla.sync()
for (name, param) in model.named_parameters():
  # print(name, param.shape, param.grad.shape if param.grad is not None else None)
  assert param.grad is not None
  scan_grads.append((name, param.grad.clone().detach()))

NOTE: Using for loop to run decoder layers
NOTE: Using apply_layers to speed up compilation


In [53]:
# Compare the gradients
assert len(for_loop_grads) == len(scan_grads)
assert len(for_loop_grads) > 0

In [75]:
for ((for_loop_name, for_loop_grad), (scan_name, scan_grad)) in zip(for_loop_grads, scan_grads):
  assert for_loop_name == scan_name
  assert torch.allclose(for_loop_grad, scan_grad, atol=150, rtol=5e-2), f"{for_loop_name} mismatch by: {torch.max(torch.abs(for_loop_grad - scan_grad))}"
  print(f"Pass: {for_loop_name}")

Pass: model.embed_tokens.weight
Pass: model.layers.0.self_attn.q_proj.weight
Pass: model.layers.0.self_attn.k_proj.weight
Pass: model.layers.0.self_attn.v_proj.weight
Pass: model.layers.0.self_attn.o_proj.weight
Pass: model.layers.0.mlp.gate_proj.weight
Pass: model.layers.0.mlp.up_proj.weight
Pass: model.layers.0.mlp.down_proj.weight
Pass: model.layers.0.input_layernorm.weight
Pass: model.layers.0.post_attention_layernorm.weight
Pass: model.layers.1.self_attn.q_proj.weight
Pass: model.layers.1.self_attn.k_proj.weight
Pass: model.layers.1.self_attn.v_proj.weight
Pass: model.layers.1.self_attn.o_proj.weight
Pass: model.layers.1.mlp.gate_proj.weight
Pass: model.layers.1.mlp.up_proj.weight
Pass: model.layers.1.mlp.down_proj.weight
Pass: model.layers.1.input_layernorm.weight
Pass: model.layers.1.post_attention_layernorm.weight
Pass: model.layers.2.self_attn.q_proj.weight
Pass: model.layers.2.self_attn.k_proj.weight
Pass: model.layers.2.self_attn.v_proj.weight
Pass: model.layers.2.self_attn.

In [76]:
for_loop_grads[3]

('model.layers.0.self_attn.v_proj.weight',
 tensor([[-2941.8899,  -524.6237,   955.9371,  ...,  -303.0801,  2649.9629,
           5254.5259],
         [-2609.9597,  1783.7124,   987.2698,  ...,  -145.8702,   593.4179,
           1061.9828],
         [-2566.4500,  -767.1675,   282.2457,  ...,  -266.0605,  1773.7897,
           4529.5469],
         ...,
         [-1018.5014, -5451.6758,  1914.1949,  ...,  3249.9912,  4745.0952,
          10958.9365],
         [ 2208.6790, -4751.6416,  1141.7797,  ...,  2499.1401,  1118.8171,
           2555.9319],
         [-1710.6656,  1940.0388,   103.2529,  ...,  -963.4113,   703.6531,
           2087.0320]], device='xla:0'))

In [77]:
scan_grads[3]

('model.layers.0.self_attn.v_proj.weight',
 tensor([[-2929.3120,  -549.2886,   954.0081,  ...,  -302.0883,  2650.4756,
           5271.4229],
         [-2612.5220,  1782.0955,   991.0074,  ...,  -151.0744,   588.5607,
           1072.1655],
         [-2563.7986,  -755.3365,   280.3193,  ...,  -271.0489,  1768.2189,
           4514.1919],
         ...,
         [-1020.3356, -5457.8105,  1913.4497,  ...,  3252.3542,  4746.0688,
          10957.5186],
         [ 2202.8054, -4738.2310,  1140.2487,  ...,  2488.2073,  1114.0991,
           2532.7532],
         [-1719.0941,  1943.4293,   105.6401,  ...,  -956.2529,   711.2629,
           2101.1606]], device='xla:0'))

In [78]:
torch.abs(for_loop_grads[3][1] - scan_grads[3][1]) / scan_grads[3][1]

tensor([[-0.0043, -0.0449,  0.0020,  ..., -0.0033,  0.0002,  0.0032],
        [-0.0010,  0.0009,  0.0038,  ..., -0.0344,  0.0083,  0.0095],
        [-0.0010, -0.0157,  0.0069,  ..., -0.0184,  0.0032,  0.0034],
        ...,
        [-0.0018, -0.0011,  0.0004,  ...,  0.0007,  0.0002,  0.0001],
        [ 0.0027, -0.0028,  0.0013,  ...,  0.0044,  0.0042,  0.0092],
        [-0.0049,  0.0017,  0.0226,  ..., -0.0075,  0.0107,  0.0067]],
       device='xla:0')