# Test `scan` and `apply_layers` in Llama 3

Hugging Face usage follows https://github.com/huggingface/notebooks/blob/main/examples/language_modeling.ipynb

To test scan, we need to use a custom modification of the transformer repo:
https://github.com/tengyifei/transformers/commit/646a575928d8514f220384c29d27c8b956826a91

In [39]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1


In [2]:
import torch
import torch_xla

In [3]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-2-v1")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=100)

In [5]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [6]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [7]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=100,
)

In [8]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [9]:
len(lm_datasets["validation"])  # type:ignore

3760

In [10]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    max_steps=250,
    save_strategy="no",
    save_total_limit=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    tpu_num_cores=4,
    push_to_hub=False,
)



In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

max_steps is given, it will override any value given in num_train_epochs


In [13]:
trainer.train()

NOTE: Using for loop to run decoder layers


Epoch,Training Loss,Validation Loss
0,9.0329,8.54132


  xldata.append(torch.load(xbio))


TrainOutput(global_step=250, training_loss=9.73523388671875, metrics={'train_runtime': 484.2565, 'train_samples_per_second': 16.52, 'train_steps_per_second': 0.516, 'total_flos': 918253731840000.0, 'train_loss': 9.73523388671875, 'epoch': 0.21777003484320556})

In [14]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 496
Metric: CompileTime
  TotalSamples: 127
  Accumulator: 04m04s891ms765.836us
  ValueRate: 590ms462.892us / second
  Rate: 0.307469 / second
  Percentiles: 1%=022ms650.479us; 5%=022ms066.140us; 10%=022ms327.880us; 20%=023ms554.170us; 50%=025ms904.800us; 80%=027ms582.839us; 90%=029ms554.130us; 95%=050ms745.925us; 99%=01m16s310ms546.313us
Metric: ExecuteReplicatedTime
  TotalSamples: 623
  Accumulator: 17s943ms372.395us
  ValueRate: 041ms067.648us / second
  Rate: 1.51004 / second
  Percentiles: 1%=833.400us; 5%=002ms358.791us; 10%=003ms580.940us; 20%=003ms940.070us; 50%=014ms368.830us; 80%=049ms472.635us; 90%=050ms728.689us; 95%=050ms928.485us; 99%=051ms505.469us
Metric: TransferToDeviceTime
  TotalSamples: 1996
  Accumulator: 162ms571.091us
  ValueRate: 310.012us / second
  Rate: 4.67301 / second
  Percentiles: 1%=038.100us; 5%=041.440us; 10%=043.260us; 20%=046.090us; 50%=057.080us; 80%=079.650us; 90%=113.630us; 95%=120.690us; 99%=138.370us
Metric: Tra

## Train again, this time using scan

In [15]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [16]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


NOTE: Using apply_layers to speed up compilation


Epoch,Training Loss,Validation Loss
0,8.978,8.451576


TrainOutput(global_step=250, training_loss=9.693704956054688, metrics={'train_runtime': 428.5784, 'train_samples_per_second': 18.666, 'train_steps_per_second': 0.583, 'total_flos': 918253731840000.0, 'train_loss': 9.693704956054688, 'epoch': 0.21777003484320556})

In [17]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 616
Metric: CompileTime
  TotalSamples: 7
  Accumulator: 03m44s793ms666.199us
  ValueRate: 408ms842.934us / second
  Rate: 0.01743 / second
  Percentiles: 1%=06s147ms919.325us; 5%=06s147ms919.325us; 10%=06s147ms919.325us; 20%=07s504ms010.619us; 50%=30s006ms548.202us; 80%=32s450ms634.352us; 90%=33s349ms461.745us; 95%=33s349ms461.745us; 99%=33s349ms461.745us
Metric: ExecuteReplicatedTime
  TotalSamples: 623
  Accumulator: 25s664ms659.413us
  ValueRate: 061ms431.209us / second
  Rate: 1.55174 / second
  Percentiles: 1%=831.631us; 5%=009ms238.389us; 10%=009ms321.840us; 20%=010ms608.990us; 50%=016ms623.439us; 80%=077ms214.197us; 90%=080ms723.306us; 95%=080ms400.482us; 99%=081ms456.271us
Metric: TransferToDeviceTime
  TotalSamples: 9901
  Accumulator: 802ms315.374us
  ValueRate: 010ms253.538us / second
  Rate: 126.662 / second
  Percentiles: 1%=041.370us; 5%=047.130us; 10%=051.570us; 20%=065.730us; 50%=076.350us; 80%=092.600us; 90%=121.080us; 95%=135.470us; 99

## Verify the numerical correctness of `apply_layers`

Under the same weights, and the same input tokens, both the for loop based
implementation and `apply_layers` based implementation should produce the same
output tokens.

In [27]:
import torch_xla
input_ids = torch.tensor(tokenized_datasets["train"][3]["input_ids"]).unsqueeze(0).type(torch.LongTensor) # type:ignore
attention_mask = torch.tensor(tokenized_datasets["train"][3]["attention_mask"]).unsqueeze(0) # type:ignore
input_ids = input_ids.to(torch_xla.device())
attention_mask = attention_mask.to(torch_xla.device())
torch_xla.sync()

In [28]:
input_ids

tensor([[128000,   5476,     73,  56761,    912,  86262,     88,   4298,    220,
             18,    551,    366,   3200,     29,  66416,    320,  11002,    551,
          50534,     99,  75267,  16144, 115687,  33710, 123283, 104612,     18,
           1174,  13318,    662,  86262,     88,   4298,    315,    279,  71735,
            220,     18,    883,   1174,  17037,  14183,    311,    439,  86262,
             88,   4298,  66416,  14767,   4994,   6457,   1174,    374,    264,
          39747,   3560,    571,     12,     31,   5737,   2835,   1847,   8040,
            555,  80949,    323,   7972,   5168,   1854,    369,    279,  32365,
          42585,    662,  45894,    304,   6186,    220,    679,     16,    304,
           6457,   1174,    433,    374,    279,   4948,   1847,    304,    279,
          86262,     88,   4298,   4101,    662,    366,   3200,     29,    279,
           1890,  37608,    315,  39747,    323,   1972,    571,     12,     31,
            892,  27120,    

In [29]:
model.model.unroll_decoders = False
model.model.logged_messages = set()
logits = model.forward(input_ids, attention_mask).logits  # type:ignore
logits.shape, logits

NOTE: Using apply_layers to speed up compilation


(torch.Size([1, 128, 128000]),
 tensor([[[-0.4818, -0.7262, -0.3666,  ..., -0.2981, -1.1244, -1.6042],
          [-0.8163, -1.2638, -0.9508,  ..., -0.7345, -2.1637, -1.3916],
          [-0.8062, -1.4021, -1.0180,  ..., -0.8188, -2.2277, -1.2193],
          ...,
          [-0.8461, -1.4137, -0.9949,  ..., -0.9075, -2.0537, -1.0461],
          [-0.8235, -1.4461, -1.0154,  ..., -0.8720, -2.0809, -0.9722],
          [-0.8358, -1.4509, -0.9605,  ..., -0.9073, -2.0669, -1.0069]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [30]:
def pick_token(logits):
  return torch.argmax(logits, dim=-1)

In [31]:
tokens = pick_token(logits)
tokens

tensor([[ 284, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 3200,
           29, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174,  279, 1174, 1174, 1174, 1174,  279, 1174, 1174, 1174,
         1174, 1174, 1174, 1174,  279, 1174, 1174, 1174, 1174, 1174, 1174, 3200,
           29, 1174, 1174, 1174,  279,  279, 1174, 1174,  279, 1174, 1174, 1174,
         1174, 1174,  279,  279, 1174,  279,  279, 1174,  279,  279,  279,  279,
          279,  279,  279, 1174,  279,  279,  279,  279]], device='xla:0')

In [32]:
tokenizer.decode(tokens[0].detach().cpu().numpy().tolist())

' =,,,,,,,,,,unk>,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, the,,,, the,,,,,,, the,,,,,,unk>,,, the the,, the,,,,, the the, the the, the the the the the the the, the the the the'

In [33]:
model.model.unroll_decoders = True
model.model.logged_messages = set()
for_loop_logits = model.forward(input_ids, attention_mask).logits  # type:ignore
for_loop_logits.shape, for_loop_logits

NOTE: Using for loop to run decoder layers


(torch.Size([1, 128, 128000]),
 tensor([[[-0.4818, -0.7262, -0.3666,  ..., -0.2981, -1.1244, -1.6042],
          [-0.8163, -1.2638, -0.9508,  ..., -0.7345, -2.1637, -1.3916],
          [-0.8062, -1.4021, -1.0180,  ..., -0.8188, -2.2277, -1.2193],
          ...,
          [-0.8452, -1.4137, -0.9950,  ..., -0.9081, -2.0534, -1.0463],
          [-0.8245, -1.4468, -1.0145,  ..., -0.8721, -2.0822, -0.9734],
          [-0.8346, -1.4508, -0.9589,  ..., -0.9069, -2.0677, -1.0074]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [35]:
for_loop_tokens = pick_token(for_loop_logits)
for_loop_tokens

tensor([[ 284, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 3200,
           29, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174, 1174,
         1174, 1174, 1174,  279, 1174, 1174, 1174, 1174,  279, 1174, 1174, 1174,
         1174, 1174, 1174, 1174,  279, 1174, 1174, 1174, 1174, 1174, 1174, 3200,
           29, 1174, 1174, 1174,  279,  279, 1174, 1174,  279, 1174, 1174, 1174,
         1174, 1174,  279,  279, 1174,  279,  279, 1174,  279,  279,  279,  279,
          279,  279,  279, 1174,  279,  279,  279,  279]], device='xla:0')

In [36]:
tokenizer.decode(for_loop_tokens[0].detach().cpu().numpy().tolist())

' =,,,,,,,,,,unk>,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, the,,,, the,,,,,,, the,,,,,,unk>,,, the the,, the,,,,, the the, the the, the the the the the the the, the the the the'

In [38]:
# Should be accurate to within 1%
torch.allclose(logits, for_loop_logits, atol=1e-2, rtol=1e-2)

True