In [1]:
!export PJRT_DEVICE=TPU

In [2]:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

dev = xm.xla_device()
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)




tensor([[ 1.8535, -0.0823, -0.2653],
        [-0.0051,  1.3175, -0.5839],
        [-1.2089,  1.2293,  1.3118]], device='xla:0')


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import argparse
import os

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    Trainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType

In [5]:
import numpy as np
from torch.utils.data import IterableDataset, get_worker_info, DataLoader

from dataset import StatefulShardedDataset, EvaluationShardedDataset

In [6]:
TOP_SEED = 42
torch.manual_seed(TOP_SEED)

base = '/home/shuyaoli/llm_data/converted_dataset'
domain_dirs = {
    'book':        os.path.join(base, 'book'),
    'arxiv':       os.path.join(base, 'arxiv'),
    'stackexchange':os.path.join(base, 'stackexchange'),
    'wiki':        os.path.join(base, 'wiki'),
    'c4-rp':       os.path.join(base, 'c4-rp'),
    'cc':          os.path.join(base, 'cc'),
    'github':      os.path.join(base, 'github'),
}

In [7]:
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
import torch_xla.core.xla_model as xm
import torch

In [8]:
from callback import DynamicSamplingOnEvaluationCallback

In [9]:
from torch_xla.debug import metrics
from torch_xla.distributed.parallel_loader import MpDeviceLoader
class StreamingTrainer(Trainer):
    def get_train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # 1) Build a standard DataLoader that uses your HF data_collator
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            num_workers=self.args.dataloader_num_workers,
            collate_fn=self.data_collator,            # ← ensures input_ids, attention_mask, labels
            pin_memory=False,                         # no pinned memory on TPU
            generator=torch.Generator().manual_seed(self.args.seed),
            persistent_workers=True,
            drop_last=True,
        )
        # 2) Wrap it so that every batch is moved onto the TPU device
        return MpDeviceLoader(train_loader, device=self.args.device)

    def get_eval_dataloader(self, eval_dataset=None):
        ds = eval_dataset or self.eval_dataset
        if ds is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_loader = DataLoader(
            ds,
            batch_size=self.args.per_device_eval_batch_size,
            num_workers=self.args.dataloader_num_workers,
            collate_fn=self.data_collator,            # ← again, to get labels for eval loss
            pin_memory=False,
            generator=torch.Generator().manual_seed(self.args.seed + 1),
            persistent_workers=True,
            shuffle=False,
            drop_last=False,
        )
        return MpDeviceLoader(eval_loader, device=self.args.device)

    def training_step(self, model, inputs, num_items_in_batch):
        # run one step
        out = super().training_step(model, inputs, num_items_in_batch)

        # on the very first step, dump some debug info:
        if self.state.global_step == 0:
            xm.master_print("### TPU CHECK ###")
            xm.master_print("  XLA devices:", xm.get_xla_supported_devices())
            xm.master_print("  Current device:", xm.xla_device())
            xm.master_print("  Batch on device:", inputs['input_ids'].device)
            xm.master_print("  Model on device:", next(model.parameters()).device)
            xm.master_print("  TPU config:", os.environ.get("XRT_TPU_CONFIG"))
            xm.master_print(metrics.metrics_report())
        return out        

In [10]:
FINAL_EVAL_BATCH_SIZE = 8
EVAL_PATH = '/home/shuyaoli/llm_data/converted_dataset/eval_merge'
eval_dataset = EvaluationShardedDataset(
    EVAL_PATH,
)

In [11]:
TRAIN_BATCH_SIZE = 8




In [12]:
initial_weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

# Instantiate your master dataset
master_train_dataset = StatefulShardedDataset(
    domain_dirs=domain_dirs,
    initial_weights=initial_weights,
    chunk_size=4  # small for demo
)

Initialized dataset with 7 domains.
  -> Found 7 shards for domain 'book'
  -> Found 4 shards for domain 'arxiv'
  -> Found 4 shards for domain 'stackexchange'
  -> Found 7 shards for domain 'wiki'
  -> Found 23 shards for domain 'c4-rp'
  -> Found 103 shards for domain 'cc'
  -> Found 7 shards for domain 'github'


In [13]:
# Using the Sheared-LLaMA model for continued pretraining.
model_name = "princeton-nlp/Sheared-LLaMA-1.3B-Pruned"

In [14]:
# You will need to be logged into your Hugging Face account and have
# access to meta-llama models for this to work.
# `huggingface-cli login`
tokenizer_name = "meta-llama/Llama-2-7b-hf"

In [15]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Set pad token to EOS token for Causal LM
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

In [16]:
training_args = TrainingArguments(
    output_dir="./tpu_eval_sampling_model",
    max_steps=5000,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    logging_steps=100,
    
    # --- Crucial for this strategy ---
    eval_strategy="steps",
    eval_steps=2, # Run evaluation every 1000 steps
    # ---------------------------------
    
    dataloader_num_workers=30, 
    remove_unused_columns=True,
)

In [17]:
eval_sampling_callback = DynamicSamplingOnEvaluationCallback(
    dataset=master_train_dataset,
    weight_update_fn=lambda x: 1
)

In [18]:
trainer = StreamingTrainer(
    model=model,
    args=training_args,
    train_dataset=master_train_dataset,
    eval_dataset=eval_dataset, # Provide the evaluation dataset
    callbacks=[eval_sampling_callback], # Use the new callback
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

In [20]:
batch = next(iter(trainer.get_train_dataloader()))
print(batch.keys())                     # should include 'input_ids', 'attention_mask', 'labels'
print(batch['input_ids'].device)        # should be xla:0 (or your TPU device)

KeysView({'input_ids': tensor([[ 5809, 29908,   921,  ...,   829,  2271, 29958],
        [  278,   937,   934,  ...,   306,   750,  2715]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='xla:0'), 'labels': tensor([[ 5809, 29908,   921,  ...,   829,  2271, 29958],
        [  278,   937,   934,  ...,   306,   750,  2715]], device='xla:0')})
xla:0


In [None]:
trainer.train()

### TPU CHECK ###
  XLA devices: ['xla:0', 'xla:1', 'xla:2', 'xla:3']
  Current device: xla:0
  Batch on device: xla:0
  Model on device: xla:0
  TPU config: None
Metric: DeviceLockWait
  TotalSamples: 12
  Accumulator: 419ms828.489us
  ValueRate: 001ms204.326us / second
  Rate: 0.0345056 / second
  Percentiles: 1%=001.730us; 5%=001.730us; 10%=001.840us; 20%=003.411us; 50%=022.989us; 80%=118.280us; 90%=209ms247.635us; 95%=209ms293.365us; 99%=209ms293.365us
Metric: InputOutputAliasCount
  TotalSamples: 4
  Accumulator: 0.00
  ValueRate: 0.00 / second
  Rate: 0.0115018 / second
  Percentiles: 1%=0.00; 5%=0.00; 10%=0.00; 20%=0.00; 50%=0.00; 80%=0.00; 90%=0.00; 95%=0.00; 99%=0.00
Metric: IrValueTensorToXlaData
  TotalSamples: 220
  Accumulator: 412ms256.934us
  ValueRate: 924ms494.537us / second
  Rate: 493.354 / second
  Percentiles: 1%=038.220us; 5%=046.951us; 10%=052.730us; 20%=110.850us; 50%=001ms400.680us; 80%=002ms365.610us; 90%=004ms099.969us; 95%=006ms096.239us; 99%=013ms892.569us


In [None]:
import os; cpu_count = os.cpu_count()

In [None]:
cpu_count

In [None]:
import torch_xla.core.xla_model as xm

In [None]:
world_size = xm.xrt_world_size()

In [None]:
import torch_xla.runtime as xr; world_size = xr.world_size()

In [None]:
world_size