# Profile `scan` and `scan_layers` in Llama 3

We generate profiles to see what's causing low TPU duty cycle.

In [1]:
%env PJRT_DEVICE=TPU
%env XLA_USE_SPMD=1
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env XLA_SAVE_TENSORS_FILE=profile/graph.log
%env TF_CPP_MIN_LOG_LEVEL=0
%env TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"

env: PJRT_DEVICE=TPU
env: XLA_USE_SPMD=1
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: XLA_SAVE_TENSORS_FILE=profile/graph.log
env: TF_CPP_MIN_LOG_LEVEL=0
env: TF_CPP_VMODULE="xla_graph_executor=5,pjrt_computation_client=3"


In [2]:
import torch
import torch_xla

In [3]:
from datasets import load_dataset

dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.bos_token_id = 128000
tokenizer.eos_token_id = 128001
tokenizer.pad_token_id = tokenizer.eos_token_id 

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"], batch_size=1000)

Map: 100%|██████████| 3760/3760 [00:00<00:00, 9411.36 examples/s]


In [5]:
tokenized_datasets.keys()  # type:ignore

dict_keys(['test', 'train', 'validation'])

In [6]:
tokenized_datasets["train"][1].keys()  # type:ignore

dict_keys(['input_ids', 'attention_mask'])

In [7]:
block_size = 128

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
)

Map: 100%|██████████| 3760/3760 [00:01<00:00, 1895.87 examples/s]


In [8]:
lm_datasets["train"][1].keys(), lm_datasets["validation"][1].keys()  # type:ignore

(dict_keys(['input_ids', 'attention_mask', 'labels']),
 dict_keys(['input_ids', 'attention_mask', 'labels']))

In [9]:
len(lm_datasets["validation"])  # type:ignore

3760

In [10]:
from transformers import LlamaConfig, LlamaForCausalLM

# Define model configuration
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,  # Model size
    num_hidden_layers=32,  # Number of transformer layers
    num_attention_heads=8,  # Number of attention heads
    intermediate_size=1024,  # Size of the hidden feedforward layer
    max_position_embeddings=128,  # Max tokens in a sequence
    use_cache=False,
    unroll_decoders=False,
)

# Instantiate the model
model = LlamaForCausalLM(config)

In [11]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    per_device_train_batch_size=48,
    per_device_eval_batch_size=48,
    num_train_epochs=1,
    max_steps=2500,
    save_strategy="no",
    save_total_limit=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
    gradient_accumulation_steps=1,
    fp16=False,
    bf16=False,
    tpu_num_cores=4,
    push_to_hub=False,
)

2024-09-05 18:08:41.328475: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:41.328550: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:41.328585: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:44.083665: I external/xla/xla/pjrt/pjrt_c_api_client.cc:126] PjRtCApiClient created.
2024-09-05 18:08:44.084753: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:44.084896: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:44.084918: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:44.087101: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:41.328550: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:41.328585: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09

In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"].shuffle(seed=42),  # type:ignore
    eval_dataset=lm_datasets["validation"].shuffle(seed=42),  # type:ignore
    tokenizer=tokenizer,
)

2024-09-05 18:08:44.147054: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
max_steps is given, it will override any value given in num_train_epochs


In [13]:
import os
import torch_xla.debug.profiler as xp

profile_port = 9012
# you can also set profile_logdir to a gs bucket, for example
# profile_logdir = "gs://your_gs_bucket/profile"
profile_logdir = "profile/"
duration_ms = 300000
assert profile_logdir.startswith('gs://') or os.path.exists(profile_logdir)
server = xp.start_server(profile_port)
# Ideally you want to start the profile tracing after the initial compilation, for example
# at step 5.
xp.trace_detached(
    f'localhost:{profile_port}', profile_logdir, duration_ms=duration_ms)

trainer.train()

2024-09-05 18:08:44.340894: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:44.342334: I external/tsl/tsl/profiler/rpc/profiler_server.cc:46] Profiler server listening on [::]:9012 selected port:9012
2024-09-05 18:08:44.343301: I external/tsl/tsl/profiler/rpc/client/capture_profile.cc:213] Profiler delay_ms was 0, start_timestamp_ns set to 1725559724343285169 [2024-09-05T18:08:44.343285169+00:00]
2024-09-05 18:08:44.343387: I external/tsl/tsl/profiler/rpc/client/remote_profiler_session_manager.cc:78] Deadline set to 2024-09-05T18:18:44.343238699+00:00 because max_session_duration_ms was 600000 and session_creation_timestamp_ns was 1725559724343238699 [2024-09-05T18:08:44.343238699+00:00]
2024-09-05 18:08:44.343485: I external/tsl/tsl/profiler/rpc/client/profiler_client.cc:124] Asynchronous gRPC Profile() to localhost:9012
2024-09-05 18:08:44.343546: I external/tsl/tsl/profiler/rpc/client/remote_profiler_session_manager.cc:99] Issued Profile gRPC 

Starting to trace for 300000 ms. Remaining attempt(s): 2


2024-09-05 18:08:45.239329: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:08:45.244979: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization


NOTE: Using scan_layers to speed up compilation


2024-09-05 18:08:45.787813: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622795681024, shape=f32[512,1024]{1,0})
2024-09-05 18:08:45.839157: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622795676272, shape=f32[1024,512]{1,0})
2024-09-05 18:08:45.839206: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622809550368, shape=f32[])
2024-09-05 18:08:45.839235: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622795608544, shape=f32[512,512]{1,0})
2024-09-05 18:08:45.839262: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622795595488, shape=f32[512,512]{1,0})
2024-09-05 18:08:45.839318: I torch_xla/csrc/runtime/pjrt_computation_client.cc:341] ReplicateShardedData (handle=94622795713088, shape=f32[48,128,512]{2,1,0})
2024-09-05 18:08:45.839512: I torch_xla/csrc/runtime/pjrt_computation_clien

Epoch,Training Loss,Validation Loss


2024-09-05 18:08:48.649767: I torch_xla/csrc/runtime/pjrt_computation_client.cc:569] Auto SPMD partitioning disabled.
2024-09-05 18:09:14.602139: I torch_xla/csrc/runtime/pjrt_computation_client.cc:635] memory usage detail = CompiledMemoryStats(generated_code_size_in_bytes=57962496, argument_size_in_bytes=860174336, output_size_in_bytes=2579901440, alias_size_in_bytes=859965952, temp_size_in_bytes=5713100800, host_generated_code_size_in_bytes=0, host_argument_size_in_bytes=0, host_output_size_in_bytes=0, host_alias_size_in_bytes=0, host_temp_size_in_bytes=0)
2024-09-05 18:09:15.755638: I torch_xla/csrc/device.cpp:85] Using SPMD virtual device optimization
2024-09-05 18:09:15.763065: I torch_xla/csrc/runtime/pjrt_computation_client.cc:873] Processing output with shape (f32[], f32[], f32[], f32[128000,512], f32[128000,512], /*index=5*/f32[128000,512], f32[512,512], f32[512,512], f32[512,512], f32[512,512], /*index=10*/f32[512,512], f32[512,512], f32[512,512], f32[512,512], f32[512,512], 

In [None]:
import torch_xla.debug.metrics as met
print(met.short_metrics_report())
met.clear_all()

Counter: CachedCompile
  Value: 2668
Metric: CompileTime
  TotalSamples: 88
  Accumulator: 05m45s242ms492.497us
  ValueRate: 112ms815.668us / second
  Rate: 0.0344962 / second
  Percentiles: 1%=028ms879.307us; 5%=031ms263.066us; 10%=031ms456.126us; 20%=032ms031.136us; 50%=033ms408.706us; 80%=035ms237.815us; 90%=038ms359.085us; 95%=10s142ms329.682us; 99%=02m35s284ms655.907us
Metric: ExecuteReplicatedTime
  TotalSamples: 2756
  Accumulator: 03m54s205ms461.657us
  ValueRate: 067ms553.151us / second
  Rate: 1.2048 / second
  Percentiles: 1%=002ms176.630us; 5%=003ms665.040us; 10%=016ms926.918us; 20%=020ms151.147us; 50%=068ms543.992us; 80%=068ms898.421us; 90%=068ms116.450us; 95%=068ms298.821us; 99%=069ms369.122us
Metric: TransferToDeviceTime
  TotalSamples: 13123
  Accumulator: 01s015ms904.557us
  ValueRate: 473.883us / second
  Rate: 6.09047 / second
  Percentiles: 1%=040.200us; 5%=043.580us; 10%=045.910us; 20%=049.740us; 50%=066.140us; 80%=102.590us; 90%=146.380us; 95%=153.540us; 99%=162.0