In [2]:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.profiler as xp
import torch_xla.debug.metrics as met

# Sanity check: obtain the XLA device, create two random 3x3 tensors on it,
# and print their sum.
dev = xm.xla_device()
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)


# Periodically print the XLA metrics report from a background daemon thread
def monitor_xla(interval=10, report_fn=None, stop_event=None):
    """Start a daemon thread that periodically prints the XLA metrics report.

    Calling ``monitor_xla()`` with no arguments behaves as before: a report is
    printed immediately and then every 10 seconds until the report call raises
    or the process exits.

    Args:
        interval: Seconds to wait between two reports.
        report_fn: Zero-argument callable returning the text to print.
            Defaults to ``met.metrics_report``; the default is resolved only
            when no callable is supplied.
        stop_event: Optional ``threading.Event``. Setting it makes the
            background loop exit at its next wake-up, so the monitor can be
            stopped without restarting the kernel.

    Returns:
        None. The worker thread is named ``"xla-metrics-monitor"``.
    """
    import threading

    if report_fn is None:
        report_fn = met.metrics_report
    if stop_event is None:
        # Never set by anyone else: the loop then runs until failure or exit.
        stop_event = threading.Event()

    def print_xla_metrics():
        while not stop_event.is_set():
            try:
                metrics = report_fn()
                print(f"[XLA METRICS] {metrics}")
            except Exception as e:
                # Best-effort monitor: report the failure once and stop.
                print(f"[XLA METRICS ERROR] {e}")
                break
            # Event.wait acts as an interruptible sleep: it returns early
            # as soon as stop_event is set.
            stop_event.wait(interval)

    thread = threading.Thread(
        target=print_xla_metrics,
        name="xla-metrics-monitor",
        daemon=True,
    )
    thread.start()

# Start the background metrics printer (runs in a daemon thread).
monitor_xla()

tensor([[ 1.8535, -0.0823, -0.2653],
        [-0.0051,  1.3175, -0.5839],
        [-1.2089,  1.2293,  1.3118]], device='xla:0')
[XLA METRICS] Metric: DeviceLockWait
  TotalSamples: 2
  Accumulator: 009.610us
  ValueRate: 02s288ms095.238us / second
  Rate: 476190 / second
  Percentiles: 1%=001.070us; 5%=001.070us; 10%=001.070us; 20%=001.070us; 50%=008.540us; 80%=008.540us; 90%=008.540us; 95%=008.540us; 99%=008.540us
Metric: InputOutputAliasCount
  TotalSamples: 1
  Accumulator: 0.00
  Percentiles: 1%=0.00; 5%=0.00; 10%=0.00; 20%=0.00; 50%=0.00; 80%=0.00; 90%=0.00; 95%=0.00; 99%=0.00
Metric: LazyTracing
  TotalSamples: 10
  Accumulator: 135ms874.190us
  ValueRate: 891ms168.266us / second
  Rate: 66.074 / second
  Percentiles: 1%=000.650us; 5%=000.650us; 10%=001.170us; 20%=006.030us; 50%=079.800us; 80%=997.700us; 90%=133ms274.820us; 95%=133ms274.820us; 99%=133ms274.820us
Metric: TensorToData
  TotalSamples: 1
  Accumulator: 139.820us
  Percentiles: 1%=139.820us; 5%=139.820us; 10%=139.820u

[XLA METRICS] Metric: DeviceLockWait
  TotalSamples: 8
  Accumulator: 047.211us
  ValueRate: 008.922us / second
  Rate: 1.51189 / second
  Percentiles: 1%=001.070us; 5%=001.070us; 10%=001.070us; 20%=001.700us; 50%=008.540us; 80%=010.650us; 90%=010.940us; 95%=010.940us; 99%=010.940us
Metric: InputOutputAliasCount
  TotalSamples: 3
  Accumulator: 0.00
  ValueRate: 0.00 / second
  Rate: 0.437702 / second
  Percentiles: 1%=0.00; 5%=0.00; 10%=0.00; 20%=0.00; 50%=0.00; 80%=0.00; 90%=0.00; 95%=0.00; 99%=0.00
Metric: IrValueTensorToXlaData
  TotalSamples: 220
  Accumulator: 516ms097.514us
  ValueRate: 488ms660.725us / second
  Rate: 207.878 / second
  Percentiles: 1%=031.120us; 5%=033.720us; 10%=041.360us; 20%=085.610us; 50%=575.060us; 80%=003ms750.180us; 90%=009ms127.149us; 95%=011ms146.769us; 99%=015ms007.469us
Metric: LazyTracing
  TotalSamples: 3300
  Accumulator: 01s373ms102.924us
  ValueRate: 645ms638.000us / second
  Rate: 24260.7 / second
  Percentiles: 1%=012.509us; 5%=013.210us; 10%=

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import os

import torch
from torch.optim import AdamW
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.data.data_collator import DataCollatorForLanguageModeling
from transformers.training_args import TrainingArguments
# from accelerate import Accelerator, DataLoaderConfiguration, DistributedType

In [None]:
from streaming_dataset import StatefulShardedDataset, EvaluationShardedDataset
from callback import DynamicSamplingOnEvaluationCallback, StepTimingCallback


In [6]:
# Seed torch's global RNG.
TOP_SEED = 42
torch.manual_seed(TOP_SEED)

# Root directory of the converted corpus; every domain is a sub-directory
# named after the domain itself.
# NOTE(review): absolute, user-specific path — consider making it configurable.
base = '/home/shuyaoli/llm_data/converted_dataset'
domain_dirs = {
    domain: os.path.join(base, domain)
    for domain in (
        'book',
        'arxiv',
        'stackexchange',
        'wiki',
        'c4-rp',
        'cc',
        'github',
    )
}

In [7]:
# Directory handed to EvaluationShardedDataset as its only argument.
EVAL_PATH = '/home/shuyaoli/llm_data/converted_dataset/eval_merge'
eval_dataset = EvaluationShardedDataset(EVAL_PATH)

In [8]:
# One weight per entry of domain_dirs (7 domains, 7 weights).
# NOTE(review): presumably per-domain sampling weights; they sum to 0.7, not
# 1.0 — confirm StatefulShardedDataset normalises them or treats them as
# relative weights.
initial_weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

# Instantiate your master dataset
master_train_dataset = StatefulShardedDataset(
    domain_dirs=domain_dirs,
    initial_weights=initial_weights,
    chunk_size=32  # small for demo
)

Initialized dataset with 7 domains.
  -> Found 7 shards for domain 'book'
  -> Found 4 shards for domain 'arxiv'
  -> Found 4 shards for domain 'stackexchange'
  -> Found 7 shards for domain 'wiki'
  -> Found 23 shards for domain 'c4-rp'
  -> Found 103 shards for domain 'cc'
  -> Found 7 shards for domain 'github'


In [9]:
# Using the Sheared-LLaMA model for continued pretraining.
model_name = "princeton-nlp/Sheared-LLaMA-1.3B-Pruned"

In [10]:
# You will need to be logged into your Hugging Face account and have
# access to meta-llama models for this to work.
# `huggingface-cli login`
# The tokenizer is loaded from the Llama-2 repo, while the weights come from
# `model_name` (set in an earlier cell).
tokenizer_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Set pad token to EOS token for Causal LM
# NOTE(review): with pad == eos, DataCollatorForLanguageModeling(mlm=False)
# presumably also masks genuine EOS positions in the labels — confirm this is
# acceptable for this training run.
tokenizer.pad_token = tokenizer.eos_token

# Load the pretrained causal-LM weights identified by `model_name`.
model = AutoModelForCausalLM.from_pretrained(model_name)

In [11]:
# Trainer configuration for the TPU run: 5000 optimizer steps, batch size 1
# per device with 2 accumulation steps, BF16 for training and evaluation.
training_args = TrainingArguments(
    output_dir="./tpu_eval_sampling_model",
    max_steps=5000,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,     # increase if you need >16 global batch on v4-8
    gradient_checkpointing=False,
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=200, # Run evaluation every 200 steps
    bf16=True,                         # enable BF16
    bf16_full_eval=True,
    dataloader_num_workers=0,  # no dataloader worker subprocesses
    remove_unused_columns=True,
)

In [12]:
# Callback bound to the training dataset.
# NOTE(review): weight_update_fn is a stub that ignores its argument and
# always returns 1, and this callback is commented out of the trainer's
# callback list — confirm whether it is still meant to be used.
eval_sampling_callback = DynamicSamplingOnEvaluationCallback(
    dataset=master_train_dataset,
    weight_update_fn=lambda x: 1
)

In [13]:
from trainer import StreamingTrainer
# Build the trainer from the model, arguments and datasets defined above.
trainer = StreamingTrainer(
    model=model,
    args=training_args,
    train_dataset=master_train_dataset,
    eval_dataset=eval_dataset, # Provide the evaluation dataset
    callbacks=[
        # eval_sampling_callback,  (disabled: dynamic sampling is not registered)
        StepTimingCallback(),
    ], # Only the step-timing callback is active
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # mlm=False: causal-LM collation
)

In [14]:
# Pull a single batch from a freshly built train dataloader and inspect it.
batch = next(iter(trainer.get_train_dataloader()))
print(batch.keys())                     # should include 'input_ids', 'attention_mask', 'labels'
print(batch['input_ids'].device)        # should be xla:0 (or your TPU device)

[StreamingTrainer] Use device: xla:0 for dataloader
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one

[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 2
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 5
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 2
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 4
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 4
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 4
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample fr

In [15]:
# Force XLA compilation to complete
xm.mark_step()

In [16]:
def test_minimal_training():
    """Smoke-test the data path and model with two gradient-free forward passes.

    Builds the train dataloader from the notebook-level ``trainer``, takes two
    consecutive batches from it, runs ``model`` on each under
    ``torch.no_grad()``, and calls ``xm.mark_step()`` after each pass.
    Progress is reported through ``[DEBUG]`` prints; nothing is returned.

    The dataloader iterator is created once and advanced with ``next``.
    Calling ``iter(train_dataloader)`` a second time would build a new
    iterator instead of continuing the first one, so the "second batch"
    would not be the batch that follows the first.
    """
    print("[DEBUG] Getting train dataloader...")
    train_dataloader = trainer.get_train_dataloader()
    # Single iterator shared by both fetches (see docstring).
    batch_iter = iter(train_dataloader)

    print("[DEBUG] Getting first batch...")
    batch = next(batch_iter)
    print(f"[DEBUG] Batch keys: {batch.keys()}")
    print(f"[DEBUG] Batch device: {batch['input_ids'].device}")

    print("[DEBUG] Moving model to TPU...")
    model.to(xm.xla_device())

    print("[DEBUG] Forward pass...")
    with torch.no_grad():
        outputs = model(**batch)
        print(f"[DEBUG] Loss: {outputs.loss}")

    print("[DEBUG] Mark step...")
    xm.mark_step()

    print("[DEBUG] Getting second batch...")
    batch2 = next(batch_iter)
    print(f"[DEBUG] Second batch device: {batch2['input_ids'].device}")

    print("[DEBUG] Second forward pass...")
    with torch.no_grad():
        outputs2 = model(**batch2)
        print(f"[DEBUG] Second loss: {outputs2.loss}")

    print("[DEBUG] Second mark step...")
    xm.mark_step()

    print("[DEBUG] Minimal test completed!")

# Run this instead of trainer.train()
test_minimal_training()

[DEBUG] Getting train dataloader...
[StreamingTrainer] Use device: xla:0 for dataloader
[DEBUG] Getting first batch...
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DA

In [None]:
trainer.train()

[StreamingTrainer] Use device: xla:0 for dataloader
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 3
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 6
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one sample from domain 0
[DATASET DEBUG] Item yielded successfully
[Dataset DEBUG] yielding one

Step,Training Loss,Validation Loss


In [None]:
import os; cpu_count = os.cpu_count()

In [None]:
cpu_count

In [None]:
import torch_xla.core.xla_model as xm

In [None]:
# NOTE(review): xm.xrt_world_size() looks like the legacy XRT-era API; the
# following cell obtains the same value via torch_xla.runtime.world_size().
# This cell may fail on recent torch_xla releases — confirm and remove it.
world_size = xm.xrt_world_size()

In [None]:
import torch_xla.runtime as xr; world_size = xr.world_size()

In [None]:
world_size