# Test `scan` and `scan_layers` in Mixtral-8x7B

Hugging Face usage follows https://github.com/huggingface/notebooks/blob/main/examples/language_modeling.ipynb

To test `scan`, we need a custom fork of the Hugging Face `transformers` repo:
https://github.com/pytorch-tpu/transformers/commit/24c90b41fc04ad3769b633b41988dc7f33c1175d

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1
%env XLA_USE_BF16=0

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1
env: XLA_USE_BF16=0


In [2]:
import torch
import torch_xla

In [3]:
from datasets import load_dataset

# Load the "wikitext-103-v1" configuration of the Salesforce/wikitext dataset.
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoTokenizer

# Load the Mixtral tokenizer and pin its special-token ids explicitly.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
# Padding reuses the EOS id.
tokenizer.pad_token_id = tokenizer.eos_token_id


def tokenize_function(examples):
    """Tokenize a batch of raw rows into fixed-length sequences.

    Every entry of ``examples["text"]`` is truncated or padded
    (``padding="max_length"``) to a length of 128 tokens.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    return encoded


# Tokenize the dataset in batches of 1000 rows, dropping the raw "text" column.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    remove_columns=["text"],
)

Map: 100%|██████████| 1801350/1801350 [02:00<00:00, 14937.52 examples/s]


In [5]:
# List the splits of the tokenized dataset.
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [6]:
# Show which columns a single tokenized training example carries.
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [7]:
from itertools import chain

# Number of tokens in each example produced by group_texts.
block_size = 128


def group_texts(examples):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

    Args:
        examples: Mapping from column name (it must include ``"input_ids"``) to a
            list of per-row token lists, as passed by ``Dataset.map`` with
            ``batched=True``.

    Returns:
        A dict with the same columns, where every column is a list of chunks of
        exactly ``block_size`` items, plus a ``"labels"`` column that is a
        shallow copy of the ``"input_ids"`` chunk list. A trailing remainder
        shorter than ``block_size`` is dropped.
    """
    # Concatenate all texts. chain.from_iterable flattens in linear time, whereas
    # sum(lists, []) re-copies the accumulated list on every addition (quadratic
    # in the batch size).
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # NOTE: labels mirror input_ids one-to-one. In this notebook the tokenizer pads
    # every row to max_length, so padded positions are included in the labels too.
    result["labels"] = result["input_ids"].copy()
    return result

# Re-chunk every split with group_texts, feeding it 1000 tokenized rows at a time.
lm_datasets = tokenized_datasets.map(group_texts, batch_size=1000, batched=True)

Map: 100%|██████████| 1801350/1801350 [19:42<00:00, 1522.77 examples/s]
Map: 100%|██████████| 3760/3760 [00:02<00:00, 1587.62 examples/s]


In [8]:
# Inspect the columns of one grouped example from the train and validation splits.
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [9]:
# Number of examples in the grouped validation split.
len(lm_datasets["validation"])  # type:ignore

3760

In [10]:
from transformers import AutoTokenizer, AutoConfig, MixtralForCausalLM

model_id = "mistralai/Mixtral-8x7B-v0.1"

# Define model configuration: start from the published Mixtral config and
# override the fields below.
config = AutoConfig.from_pretrained(
    model_id,
    vocab_size=len(tokenizer),
    torch_dtype=torch.bfloat16,
    num_hidden_layers=16,
    num_attention_heads=8,
    hidden_size=128,
    intermediate_size=128,
    num_local_experts=2,
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=True,  # run the decoder layers with a plain for loop
)
# Extra switches set as plain attributes on the config.
# NOTE(review): presumably read by the custom transformers fork — confirm there.
config.flash_attention = True
config.static = False
config.gmm = True
config.gmm_stack = False

# Seed the RNG before the random weight initialization so the starting weights
# are reproducible across kernel restarts instead of differing on every run.
torch.manual_seed(42)

# Instantiate the model
model = MixtralForCausalLM(config)
model = model.to(torch_xla.device())  # type: ignore



In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # Where artifacts would be written.
    output_dir="./results",
    logging_dir='./logs',
    # Run length: max_steps takes precedence over num_train_epochs.
    max_steps=2500,
    num_train_epochs=5,
    # Batching.
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    # Optimizer hyperparameters.
    learning_rate=2e-5,
    weight_decay=0.01,
    # Evaluation, checkpointing and logging are all switched off for this test.
    evaluation_strategy="no",
    save_strategy="no",
    save_total_limit=2,
    logging_strategy="no",
    logging_steps=50,
    # No mixed-precision flags on the Trainer side.
    fp16=False,
    bf16=False,
    # Hardware / hub.
    tpu_num_cores=4,
    push_to_hub=False,
)



In [12]:
from transformers import Trainer

# Build the Trainer around the current model. Both splits are shuffled with a
# fixed seed (42) so the data order is the same every time this cell runs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

Profiling server started: <_XLAC.profiler.ProfilerServer object at 0x7faf401a31f0>


max_steps is given, it will override any value given in num_train_epochs


In [13]:
# Run training with the Trainer configured above.
trainer.train()

NOTE: Using for loop to run decoder layers


Step,Training Loss


TrainOutput(global_step=2500, training_loss=5.98215859375, metrics={'train_runtime': 3742.6639, 'train_samples_per_second': 42.75, 'train_steps_per_second': 0.668, 'total_flos': 826461388800000.0, 'train_loss': 5.98215859375, 'epoch': 0.08881941237076775})

In [14]:
import torch_xla.debug.metrics as met
# Print the short XLA metrics report (compile, execute and transfer timings)
# accumulated so far.
print(met.short_metrics_report())
# Clear the metrics afterwards so the next report only covers what runs after this cell.
met.clear_all()

Counter: CachedCompile
  Value: 2497
Metric: CompileTime
  TotalSamples: 3
  Accumulator: 07m37s533ms809.052us
  ValueRate: 128ms971.758us / second
  Rate: 0.00096818 / second
  Percentiles: 1%=02m01s577ms772.680us; 5%=02m01s577ms772.680us; 10%=02m01s577ms772.680us; 20%=02m01s577ms772.680us; 50%=02m16s523ms340.934us; 80%=02m20s433ms695.438us; 90%=02m20s433ms695.438us; 95%=02m20s433ms695.438us; 99%=02m20s433ms695.438us
Metric: ExecuteReplicatedTime
  TotalSamples: 2500
  Accumulator: 03m40s865ms660.447us
  ValueRate: 042ms255.951us / second
  Rate: 0.669445 / second
  Percentiles: 1%=058ms277.354us; 5%=059ms902.614us; 10%=059ms296.924us; 20%=060ms820.644us; 50%=062ms345.244us; 80%=065ms723.073us; 90%=066ms381.143us; 95%=069ms702.473us; 99%=076ms956.492us
Metric: TransferToDeviceTime
  TotalSamples: 4372771
  Accumulator: 06m38s665ms198.653us
  ValueRate: 177ms609.091us / second
  Rate: 2337.01 / second
  Percentiles: 1%=056.960us; 5%=063.530us; 10%=066.170us; 20%=068.850us; 50%=074.140u

## Train again, this time using scan

In [15]:
# AutoConfig and MixtralForCausalLM are the names this cell actually uses, so
# import them here to keep the cell self-contained. The Llama names are not
# used by this cell and are kept only so that nothing relying on them breaks.
from transformers import AutoConfig, LlamaConfig, LlamaForCausalLM, MixtralForCausalLM

# Define model configuration: identical to the for-loop run except for
# unroll_decoders.
config = AutoConfig.from_pretrained(
    model_id,
    vocab_size=len(tokenizer),
    torch_dtype=torch.bfloat16,
    num_hidden_layers=16,
    num_attention_heads=8,
    hidden_size=128,
    intermediate_size=128,
    num_local_experts=2,
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,  # run the decoder layers through scan_layers
)
# Extra switches set as plain attributes on the config.
# NOTE(review): presumably read by the custom transformers fork — confirm there.
config.flash_attention = True
config.static = False
config.gmm = True
config.gmm_stack = False

# Seed the RNG before the random weight initialization so the starting weights
# are reproducible rather than depending on how much randomness earlier cells
# consumed.
torch.manual_seed(42)

# Instantiate the model
model = MixtralForCausalLM(config)
model = model.to(torch_xla.device())  # type: ignore



In [16]:
from transformers import Trainer

# Build a fresh Trainer around the current model (configured with
# unroll_decoders=False), reusing the same arguments and the same seeded
# shuffles of both splits.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

# Run training for this configuration.
trainer.train()

max_steps is given, it will override any value given in num_train_epochs


NOTE: Using scan_layers to speed up compilation


Step,Training Loss


TrainOutput(global_step=2500, training_loss=5.97009375, metrics={'train_runtime': 2902.9153, 'train_samples_per_second': 55.117, 'train_steps_per_second': 0.861, 'total_flos': 826461388800000.0, 'train_loss': 5.97009375, 'epoch': 0.08881941237076775})

We can see that the for-loop and scan models train to approximately the same loss (5.98 vs. 5.97 after 2500 steps).

In [17]:
import torch_xla.debug.metrics as met
# Print the short XLA metrics report (compile, execute and transfer timings)
# accumulated since the metrics were last cleared.
print(met.short_metrics_report())
# Clear the metrics again so later cells start from a clean slate.
met.clear_all()

Counter: CachedCompile
  Value: 2497
Metric: CompileTime
  TotalSamples: 3
  Accumulator: 01m20s375ms273.037us
  ValueRate: 033ms672.670us / second
  Rate: 0.0012195 / second
  Percentiles: 1%=26s984ms557.394us; 5%=26s984ms557.394us; 10%=26s984ms557.394us; 20%=26s984ms557.394us; 50%=26s301ms178.710us; 80%=28s091ms536.933us; 90%=28s091ms536.933us; 95%=28s091ms536.933us; 99%=28s091ms536.933us
Metric: ExecuteReplicatedTime
  TotalSamples: 2500
  Accumulator: 04m53s907ms346.435us
  ValueRate: 082ms882.578us / second
  Rate: 0.87776 / second
  Percentiles: 1%=091ms976.401us; 5%=091ms457.350us; 10%=092ms846.801us; 20%=092ms303.480us; 50%=093ms262.800us; 80%=094ms094.020us; 90%=095ms686.330us; 95%=095ms262.760us; 99%=096ms311.770us
Metric: TransferToDeviceTime
  TotalSamples: 1532752
  Accumulator: 02m46s218ms925.740us
  ValueRate: 044ms886.994us / second
  Rate: 672.615 / second
  Percentiles: 1%=040.400us; 5%=043.700us; 10%=045.700us; 20%=049.451us; 50%=065.170us; 80%=075.520us; 90%=080.750

## Verify the numerical correctness of `scan_layers`

Under the same weights, and the same input tokens, both the for loop based
implementation and `scan_layers` based implementation should produce the same
output tokens.

In [18]:
import torch_xla
# Build a one-example batch from tokenized training row 3: unsqueeze(0) adds the
# leading batch dimension, and input_ids is cast to a LongTensor.
input_ids = torch.tensor(tokenized_datasets["train"][3]["input_ids"]).unsqueeze(0).type(torch.LongTensor) # type:ignore
attention_mask = torch.tensor(tokenized_datasets["train"][3]["attention_mask"]).unsqueeze(0) # type:ignore
# Move both tensors onto the XLA device.
input_ids = input_ids.to(torch_xla.device())
attention_mask = attention_mask.to(torch_xla.device())
# NOTE(review): presumably flushes the pending lazy XLA work so the tensors are
# materialized on the device before use — confirm against torch_xla docs.
torch_xla.sync()

In [19]:
# Display the token ids of the probe example.
input_ids

tensor([[    1, 28705,  5355, 28768, 28934,   708,   550,  1093, 28724,  3931,
         28705, 28770,   714, 28705,     0, 28705, 23967,  4992,   325,  8092,
           714, 28705, 30842, 30016, 28993, 31428, 30000, 29182, 29753, 30051,
         29306, 29322, 28770,  1200,  8724,   842,   550,  1093, 28724,  3931,
           302,   272, 13711,  2222, 28705, 28770,  1143,  1200, 14473, 11449,
           298,   390,   550,  1093, 28724,  3931, 23967,  4992,  6950,  3536,
          4720,  1200,   349,   264, 12529,   745,  3905,   802, 28733, 28818,
          4543,  3798,  2039,  6202,   486,   318,  4770,   304,  9347, 28723,
         28790,  1522,   354,   272,  6879, 23558,  4194,   522,   842,  1298,
         22246,   297,  4624, 28705, 28750, 28734, 28740, 28740,   297,  4720,
          1200,   378,   349,   272,  4008,  2039,   297,   272,   550,  1093,
         28724,  3931,  3518,   842,  2929,  2193,   288,   272,  1348, 22104,
           302, 12529,   745,   304,  1353,   802, 2

In [20]:
# Switch the decoder stack to the scan_layers code path.
model.model.unroll_decoders = False
# Reset the set of already-logged messages.
# NOTE(review): presumably this lets the fork print its mode NOTE again — confirm.
model.model.logged_messages = set()
# Call the module itself rather than `.forward` so registered hooks run, which
# is also how the gradient test in this notebook invokes the model.
logits = model(input_ids, attention_mask).logits  # type:ignore
logits.shape, logits

NOTE: Using scan_layers to speed up compilation


(torch.Size([1, 128, 32000]),
 tensor([[[ 1.9352,  2.4326,  4.1603,  ..., -2.3325, -2.4542, -2.3538],
          [ 2.2733,  2.7671,  4.7023,  ..., -2.7263, -2.6115, -2.7480],
          [ 2.2035,  2.7508,  4.6773,  ..., -2.7347, -2.6657, -2.7148],
          ...,
          [ 2.2157,  2.8167,  4.7999,  ..., -2.8423, -2.6625, -2.7821],
          [ 2.2212,  2.7764,  4.7414,  ..., -2.8068, -2.6568, -2.7357],
          [ 2.1995,  2.7361,  4.6238,  ..., -2.7696, -2.6364, -2.6633]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [21]:
def pick_token(logits):
  """Greedy pick: return the index of the largest logit along the last dimension."""
  best_ids = torch.argmax(logits, dim=-1)
  return best_ids

In [22]:
# Greedy prediction at every position, taken from `logits`.
tokens = pick_token(logits)
tokens

tensor([[28705,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [23]:
# Decode the predicted ids of the single batch row back into text.
tokenizer.decode(tokens[0].detach().cpu().numpy().tolist())

'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>'

In [24]:
# Switch the decoder stack back to the plain for-loop code path.
model.model.unroll_decoders = True
# Reset the set of already-logged messages.
# NOTE(review): presumably this lets the fork print its mode NOTE again — confirm.
model.model.logged_messages = set()
# Call the module itself rather than `.forward` so registered hooks run, which
# is also how the gradient test in this notebook invokes the model.
for_loop_logits = model(input_ids, attention_mask).logits  # type:ignore
for_loop_logits.shape, for_loop_logits

NOTE: Using for loop to run decoder layers


(torch.Size([1, 128, 32000]),
 tensor([[[ 1.9352,  2.4326,  4.1603,  ..., -2.3325, -2.4542, -2.3538],
          [ 2.2744,  2.7641,  4.7005,  ..., -2.7256, -2.6111, -2.7477],
          [ 2.2037,  2.7484,  4.6760,  ..., -2.7333, -2.6671, -2.7150],
          ...,
          [ 2.2164,  2.8177,  4.8028,  ..., -2.8431, -2.6634, -2.7836],
          [ 2.2204,  2.7747,  4.7416,  ..., -2.8071, -2.6574, -2.7342],
          [ 2.1988,  2.7359,  4.6236,  ..., -2.7689, -2.6370, -2.6630]]],
        device='xla:0', grad_fn=<UnsafeViewBackward0>))

In [25]:
# Greedy prediction at every position, taken from `for_loop_logits`.
for_loop_tokens = pick_token(for_loop_logits)
for_loop_tokens

tensor([[28705,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,  

In [26]:
# Decode the for-loop predictions of the single batch row back into text.
tokenizer.decode(for_loop_tokens[0].detach().cpu().numpy().tolist())

'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>'

In [37]:
# The two sets of logits should agree closely. allclose passes when, elementwise,
# |a - b| <= atol + rtol * |b|; here atol = rtol = 1e-2.
torch.allclose(logits, for_loop_logits, atol=1e-2, rtol=1e-2)

True

In [28]:
# Shouldn't be completely the same: at a much tighter tolerance
# (atol = rtol = 1e-6) this is expected to be False, i.e. the two results are
# close but not identical.
torch.allclose(logits, for_loop_logits, atol=1e-6, rtol=1e-6)

False

## Test the gradients of scan

After I run both scan and for loop versions of the model on the same input, their
gradients should also be similar.

In [29]:
torch_xla.sync()

input_ids.requires_grad_(False)
attention_mask.requires_grad_(False)
# Use the input tokens themselves as the labels for the loss.
labels = input_ids[:, :].clone().contiguous()


def collect_grads(unroll_decoders):
  """Run one forward/backward pass and return a snapshot of every parameter gradient.

  Args:
    unroll_decoders: Value assigned to ``model.model.unroll_decoders``;
      True selects the for-loop path, False selects the scan path.

  Returns:
    A list of ``(parameter_name, gradient)`` tuples in ``named_parameters()``
    order, where each gradient is a detached clone.
  """
  # Same seed before every run so both code paths see the same RNG state.
  torch.manual_seed(42)
  # Drop gradients left over from any previous run.
  model.zero_grad()
  model.model.zero_grad()
  torch_xla.sync()
  model.model.unroll_decoders = unroll_decoders
  model.model.logged_messages = set()
  with torch.enable_grad():
    model(input_ids, attention_mask, labels=labels).loss.backward()  # type: ignore
  torch_xla.sync()
  grads = []
  for name, param in model.named_parameters():
    assert param.grad is not None, f"{name} received no gradient"
    # Clone so the snapshot survives the next zero_grad()/backward().
    grads.append((name, param.grad.clone().detach()))
  return grads


# Run for loop model and collect the gradients
for_loop_grads = collect_grads(unroll_decoders=True)

# Run scan model and collect the gradients
scan_grads = collect_grads(unroll_decoders=False)

NOTE: Using for loop to run decoder layers
NOTE: Using scan_layers to speed up compilation


In [30]:
# Compare the gradients
assert len(for_loop_grads) == len(scan_grads)
assert len(for_loop_grads) > 0

In [36]:
# Walk both gradient lists in lockstep and compare them parameter by parameter.
for ((for_loop_name, for_loop_grad), (scan_name, scan_grad)) in zip(for_loop_grads, scan_grads):
  # The two lists must line up by parameter name.
  assert for_loop_name == scan_name
  # Gradients must agree within atol = rtol = 1e-3; on failure, report the
  # largest absolute elementwise difference.
  assert torch.allclose(for_loop_grad, scan_grad, atol=1e-3, rtol=1e-3), f"{for_loop_name} mismatch by: {torch.max(torch.abs(for_loop_grad - scan_grad))}"
  print(f"Pass: {for_loop_name}")

Pass: model.embed_tokens.weight
Pass: model.layers.0.self_attn.q_proj.weight
Pass: model.layers.0.self_attn.k_proj.weight
Pass: model.layers.0.self_attn.v_proj.weight
Pass: model.layers.0.self_attn.o_proj.weight
Pass: model.layers.0.block_sparse_moe.gate.weight
Pass: model.layers.0.block_sparse_moe.experts.w1
Pass: model.layers.0.block_sparse_moe.experts.w2
Pass: model.layers.0.block_sparse_moe.experts.w3
Pass: model.layers.0.input_layernorm.weight
Pass: model.layers.0.post_attention_layernorm.weight
Pass: model.layers.1.self_attn.q_proj.weight
Pass: model.layers.1.self_attn.k_proj.weight
Pass: model.layers.1.self_attn.v_proj.weight
Pass: model.layers.1.self_attn.o_proj.weight
Pass: model.layers.1.block_sparse_moe.gate.weight
Pass: model.layers.1.block_sparse_moe.experts.w1
Pass: model.layers.1.block_sparse_moe.experts.w2
Pass: model.layers.1.block_sparse_moe.experts.w3
Pass: model.layers.1.input_layernorm.weight
Pass: model.layers.1.post_attention_layernorm.weight
Pass: model.layers.2

In [32]:
# Show entry 3 of the for-loop gradients together with its max and min values.
for_loop_grads[3], torch.max(for_loop_grads[3][1]), torch.min(for_loop_grads[3][1])

(('model.layers.0.self_attn.v_proj.weight',
  tensor([[-6.0332e-04,  5.3868e-04, -7.1599e-04,  ..., -2.9187e-05,
           -6.4175e-05, -3.4787e-04],
          [-1.0868e-03,  1.7916e-03, -2.0205e-04,  ..., -2.7523e-04,
           -4.3805e-04, -1.5460e-03],
          [-1.2247e-03,  5.8597e-04, -1.6040e-03,  ...,  2.2637e-04,
            9.9657e-04, -1.5635e-04],
          ...,
          [ 4.8003e-04, -2.9207e-04,  2.3917e-04,  ...,  1.4600e-04,
           -1.7820e-04,  3.6919e-04],
          [ 1.9950e-03, -1.5000e-03,  1.9700e-03,  ...,  4.3291e-05,
           -8.0997e-04,  1.0016e-03],
          [-3.6173e-03,  1.8253e-03, -2.9197e-03,  ..., -4.3047e-04,
            1.6450e-03, -1.8373e-03]], device='xla:0')),
 tensor(0.0125, device='xla:0'),
 tensor(-0.0125, device='xla:0'))

In [33]:
# Show entry 3 of the scan gradients together with its max and min values.
scan_grads[3], torch.max(scan_grads[3][1]), torch.min(scan_grads[3][1])

(('model.layers.0.self_attn.v_proj.weight',
  tensor([[-5.8174e-04,  5.2028e-04, -6.6052e-04,  ..., -1.5826e-06,
           -4.5698e-05, -3.6250e-04],
          [-1.0702e-03,  1.7425e-03, -2.0061e-04,  ..., -2.7088e-04,
           -4.2271e-04, -1.5163e-03],
          [-1.2026e-03,  5.9066e-04, -1.5438e-03,  ...,  2.6712e-04,
            1.0029e-03, -1.5864e-04],
          ...,
          [ 4.9042e-04, -3.0399e-04,  2.6145e-04,  ...,  1.7872e-04,
           -1.5594e-04,  3.4986e-04],
          [ 2.0169e-03, -1.5067e-03,  1.9870e-03,  ...,  5.5200e-05,
           -8.0024e-04,  1.0151e-03],
          [-3.6087e-03,  1.8497e-03, -2.9272e-03,  ..., -4.1953e-04,
            1.7206e-03, -1.8574e-03]], device='xla:0')),
 tensor(0.0126, device='xla:0'),
 tensor(-0.0127, device='xla:0'))