# Test `scan` and `apply_layers` in Llama 3

Hugging Face usage follows https://github.com/huggingface/notebooks/blob/main/examples/language_modeling.ipynb

To test scan, we need to use a custom modification of the transformer repo:
https://github.com/tengyifei/transformers/commit/646a575928d8514f220384c29d27c8b956826a91

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1


In [2]:
import torch
import torch_xla

In [3]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

  from .autonotebook import tqdm as notebook_tqdm
Downloading data: 100%|██████████| 722k/722k [00:00<00:00, 2.66MB/s]
Downloading data: 100%|██████████| 156M/156M [00:01<00:00, 91.4MB/s] 
Downloading data: 100%|██████████| 156M/156M [00:01<00:00, 94.5MB/s] 
Downloading data: 100%|██████████| 655k/655k [00:00<00:00, 3.16MB/s]
Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 15298.90 examples/s]
Generating train split: 100%|██████████| 1801350/1801350 [00:02<00:00, 874436.10 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 654001.12 examples/s]


In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=1000)

Map: 100%|██████████| 4358/4358 [00:00<00:00, 8769.23 examples/s]
Map: 100%|██████████| 1801350/1801350 [02:31<00:00, 11910.55 examples/s]
Map: 100%|██████████| 3760/3760 [00:00<00:00, 13159.53 examples/s]


In [5]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [6]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [7]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
)

Map: 100%|██████████| 4358/4358 [00:02<00:00, 1850.07 examples/s]
Map: 100%|██████████| 1801350/1801350 [16:50<00:00, 1782.54 examples/s]
Map: 100%|██████████| 3760/3760 [00:02<00:00, 1859.26 examples/s]


In [8]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [9]:
len(lm_datasets["validation"])  # type:ignore

3760

In [10]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=48,
    per_device_eval_batch_size=48,
    num_train_epochs=1,
    max_steps=2500,
    save_strategy="no",
    save_total_limit=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    tpu_num_cores=4,
    push_to_hub=False,
)



In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

max_steps is given, it will override any value given in num_train_epochs


In [13]:
trainer.train()

NOTE: Using for loop to run decoder layers


Epoch,Training Loss,Validation Loss
0,6.2594,6.140633


  xldata.append(torch.load(xbio))


TrainOutput(global_step=2500, training_loss=6.8357990234375, metrics={'train_runtime': 2629.9609, 'train_samples_per_second': 45.628, 'train_steps_per_second': 0.951, 'total_flos': 1.37738059776e+16, 'train_loss': 6.8357990234375, 'epoch': 0.06661515094993205})

In [14]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 2668
Metric: CompileTime
  TotalSamples: 88
  Accumulator: 05m45s242ms492.497us
  ValueRate: 112ms815.668us / second
  Rate: 0.0344962 / second
  Percentiles: 1%=028ms879.307us; 5%=031ms263.066us; 10%=031ms456.126us; 20%=032ms031.136us; 50%=033ms408.706us; 80%=035ms237.815us; 90%=038ms359.085us; 95%=10s142ms329.682us; 99%=02m35s284ms655.907us
Metric: ExecuteReplicatedTime
  TotalSamples: 2756
  Accumulator: 03m54s205ms461.657us
  ValueRate: 067ms553.151us / second
  Rate: 1.2048 / second
  Percentiles: 1%=002ms176.630us; 5%=003ms665.040us; 10%=016ms926.918us; 20%=020ms151.147us; 50%=068ms543.992us; 80%=068ms898.421us; 90%=068ms116.450us; 95%=068ms298.821us; 99%=069ms369.122us
Metric: TransferToDeviceTime
  TotalSamples: 13123
  Accumulator: 01s015ms904.557us
  ValueRate: 473.883us / second
  Rate: 6.09047 / second
  Percentiles: 1%=040.200us; 5%=043.580us; 10%=045.910us; 20%=049.740us; 50%=066.140us; 80%=102.590us; 90%=146.380us; 95%=153.540us; 99%=162.0

## Train again, this time using scan

In [15]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [16]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


NOTE: Using apply_layers to speed up compilation


Epoch,Training Loss,Validation Loss
0,6.26,6.143163


TrainOutput(global_step=2500, training_loss=6.8406908203125, metrics={'train_runtime': 2892.4518, 'train_samples_per_second': 41.487, 'train_steps_per_second': 0.864, 'total_flos': 1.37738059776e+16, 'train_loss': 6.8406908203125, 'epoch': 0.06661515094993205})

In [17]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 2749
Metric: CompileTime
  TotalSamples: 7
  Accumulator: 03m05s196ms971.333us
  ValueRate: 065ms779.664us / second
  Rate: 0.00244853 / second
  Percentiles: 1%=06s777ms973.436us; 5%=06s777ms973.436us; 10%=06s777ms973.436us; 20%=07s744ms350.902us; 50%=35s123ms851.761us; 80%=36s685ms352.966us; 90%=36s955ms241.673us; 95%=36s955ms241.673us; 99%=36s955ms241.673us
Metric: ExecuteReplicatedTime
  TotalSamples: 2756
  Accumulator: 05m32s609ms674.346us
  ValueRate: 094ms342.546us / second
  Rate: 1.09878 / second
  Percentiles: 1%=015ms914.408us; 5%=015ms340.588us; 10%=018ms227.378us; 20%=022ms900.657us; 50%=106ms211.167us; 80%=107ms736.647us; 90%=107ms043.007us; 95%=107ms294.227us; 99%=110ms656.957us
Metric: TransferToDeviceTime
  TotalSamples: 79060
  Accumulator: 07s496ms172.422us
  ValueRate: 010ms147.754us / second
  Rate: 113.826 / second
  Percentiles: 1%=044.970us; 5%=050.630us; 10%=058.370us; 20%=073.030us; 50%=083.550us; 80%=098.291us; 90%=141.440us; 

## Verify the numerical correctness of `apply_layers`

Under the same weights, and the same input tokens, both the for loop based
implementation and `apply_layers` based implementation should produce the same
output tokens.

In [18]:
import torch_xla
input_ids = torch.tensor(tokenized_datasets["train"][3]["input_ids"]).unsqueeze(0).type(torch.LongTensor) # type:ignore
attention_mask = torch.tensor(tokenized_datasets["train"][3]["attention_mask"]).unsqueeze(0) # type:ignore
input_ids = input_ids.to(torch_xla.device())
attention_mask = attention_mask.to(torch_xla.device())
torch_xla.sync()

In [19]:
input_ids

tensor([[128000,   5476,     73,  56761,    912,  86262,     88,   4298,    220,
             18,    551,    366,   3200,     29,  66416,    320,  11002,    551,
          50534,     99,  75267,  16144, 115687,  33710, 123283, 104612,     18,
           1174,  13318,    662,  86262,     88,   4298,    315,    279,  71735,
            220,     18,    883,   1174,  17037,  14183,    311,    439,  86262,
             88,   4298,  66416,  14767,   4994,   6457,   1174,    374,    264,
          39747,   3560,    571,     12,     31,   5737,   2835,   1847,   8040,
            555,  80949,    323,   7972,   5168,   1854,    369,    279,  32365,
          42585,    662,  45894,    304,   6186,    220,    679,     16,    304,
           6457,   1174,    433,    374,    279,   4948,   1847,    304,    279,
          86262,     88,   4298,   4101,    662,  21445,    287,    279,   1890,
          37608,    315,  39747,    323,   1972,    571,     12,     31,    892,
          27120,    439,   1

In [20]:
model.model.unroll_decoders = False
model.model.logged_messages = set()
logits = model.forward(input_ids, attention_mask).logits  # type:ignore
logits.shape, logits

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


NOTE: Using apply_layers to speed up compilation


(torch.Size([1, 128, 128000]),
 tensor([[[-3.5842, -3.6861, -3.5516,  ..., -3.0494, -4.4317, -4.5019],
          [-4.8827, -4.5860, -4.4935,  ..., -4.1767, -4.9484, -4.4475],
          [-4.7807, -4.6876, -4.4286,  ..., -4.4788, -4.8637, -4.3323],
          ...,
          [-4.8265, -4.3731, -4.5441,  ..., -4.9554, -5.2088, -4.9145],
          [-4.5418, -4.1405, -4.4080,  ..., -4.7163, -5.1185, -4.9804],
          [-4.8695, -5.3834, -4.9970,  ..., -4.9628, -5.9179, -5.2406]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [21]:
def pick_token(logits):
  return torch.argmax(logits, dim=-1)

In [22]:
tokens = pick_token(logits)
tokens

tensor([[ 284,   13,   13, 1174,   13,  574,  574, 1174,   16, 1389,  220, 3200,
           29,  720,  720,  220,  366,  220,  883,  883,  883,  883,  883,  883,
          883,  883,  883,  220, 1174,  720, 1174, 1174, 1174,  279,  220,  315,
         1049, 1389,  720,  323,  220,  311,  279,  264,  662, 1174,  662,  662,
          662,  279,  662,  323,  264,  220,  571,  315,   12,   31,  220,  571,
          662,  662,  311,  279,  662,  279,  662,  662,  662,  279,  220,  220,
          662,  720, 1174,  279,  220, 1049,   15, 1174,  279, 1174,  279,  574,
          264,  220,  571, 1174,  279,  220,  315,  662,  662,  662,  720,  374,
          279, 1176,  220, 1174,  279,  571,  279, 1174,   12,   31,  220, 1174,
         1174,  264, 1176,  662,  323,  220,  315, 1174,  311,  279,  220,  571,
          662,  279,  279, 1176,  578,  330,  330,  662]], device='xla:0')

In [23]:
tokenizer.decode(tokens[0].detach().cpu().numpy().tolist())

' =..,. was was,1 – unk> \n \n  <  ) ) ) ) ) ) ) ) ) , \n,,, the  of200 – \n and  to the a.,... the. and a  @ of-@  @.. to the. the... the  . \n, the 2000, the, the was a  @, the  of... \n is the first , the @ the,-@ ,, a first. and  of, to the  @. the the first The " ".'

In [24]:
model.model.unroll_decoders = True
model.model.logged_messages = set()
for_loop_logits = model.forward(input_ids, attention_mask).logits  # type:ignore
for_loop_logits.shape, for_loop_logits

NOTE: Using for loop to run decoder layers


(torch.Size([1, 128, 128000]),
 tensor([[[-3.5840, -3.6860, -3.5528,  ..., -3.0511, -4.4328, -4.5008],
          [-4.8854, -4.5886, -4.4982,  ..., -4.1781, -4.9502, -4.4503],
          [-4.7799, -4.6868, -4.4279,  ..., -4.4778, -4.8628, -4.3292],
          ...,
          [-4.8243, -4.3744, -4.5454,  ..., -4.9558, -5.2097, -4.9143],
          [-4.5432, -4.1429, -4.4069,  ..., -4.7164, -5.1198, -4.9812],
          [-4.8692, -5.3832, -4.9956,  ..., -4.9625, -5.9176, -5.2399]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [25]:
for_loop_tokens = pick_token(for_loop_logits)
for_loop_tokens

tensor([[ 284,   13,   13, 1174,   13,  574,  574, 1174,   16, 1389,  220, 3200,
           29,  720,  720,  220,  366,  220,  883,  883,  883,  883,  883,  883,
          883,  883,  883,  220, 1174,  720, 1174, 1174, 1174,  279,  220,  315,
         1049, 1389,  720,  323,  220,  311,  279,  264,  662, 1174,  662,  662,
          662,  279,  662,  323,  264,  220,  571,  315,   12,   31,  220,  571,
          662,  662,  311,  279,  662,  279,  662,  662,  662,  279,  220,  220,
          662,  720, 1174,  279,  220, 1049,   15, 1174,  279, 1174,  279,  574,
          264, 1176,  571, 1174,  279,  220,  315,  662,  662,  662,  720,  374,
          279, 1176,  220, 1174,  279,  571,  279, 1174,   12,   31,  220, 1174,
         1174,  264, 1176,  662,  323,  220,  315, 1174,  311,  279,  220,  571,
          662,  279,  279, 1176,  578,  330,  330,  662]], device='xla:0')

In [26]:
tokenizer.decode(for_loop_tokens[0].detach().cpu().numpy().tolist())

' =..,. was was,1 – unk> \n \n  <  ) ) ) ) ) ) ) ) ) , \n,,, the  of200 – \n and  to the a.,... the. and a  @ of-@  @.. to the. the... the  . \n, the 2000, the, the was a first @, the  of... \n is the first , the @ the,-@ ,, a first. and  of, to the  @. the the first The " ".'

In [27]:
# Should be accurate to within 1%
torch.allclose(logits, for_loop_logits, atol=1e-2, rtol=1e-2)

True