In [1]:
!export PJRT_DEVICE=TPU

In [2]:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

dev = xm.xla_device()
t1 = torch.randn(3,3,device=dev)
t2 = torch.randn(3,3,device=dev)
print(t1 + t2)




tensor([[ 1.8535, -0.0823, -0.2653],
        [-0.0051,  1.3175, -0.5839],
        [-1.2089,  1.2293,  1.3118]], device='xla:0')


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import argparse
import os

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    Trainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType

In [12]:
import numpy as np
from torch.utils.data import IterableDataset, get_worker_info, DataLoader

from dataset import StatefulShardedDataset

In [8]:
TOP_SEED = 42
torch.manual_seed(TOP_SEED)

base = '/home/shuyaoli/llm_data/converted_dataset'
domain_dirs = {
    'book':        os.path.join(base, 'book'),
    'arxiv':       os.path.join(base, 'arxiv'),
    'stackexchange':os.path.join(base, 'stackexchange'),
    'wiki':        os.path.join(base, 'wiki'),
    'c4-rp':       os.path.join(base, 'c4-rp'),
    'cc':          os.path.join(base, 'cc'),
    'github':      os.path.join(base, 'github'),
}

In [9]:
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
import torch_xla.core.xla_model as xm
import torch

In [13]:
class DynamicSamplingCallback(TrainerCallback): # Not used
    """
    A Hugging Face TrainerCallback that dynamically adjusts dataset sampling
    weights based on the training loss every N steps.
    """
    def __init__(
        self,
        dataset: StatefulShardedDataset,
        update_every_n_steps: int = 100,
        # Your logic to map loss to weights goes here
        weight_update_fn: callable = lambda loss: [1.0 / max(loss, 1e-6)] 
    ):
        self.dataset = dataset
        self.update_every_n_steps = update_every_n_steps
        self.weight_update_fn = weight_update_fn
        self.running_losses = []

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Event called at the end of a training step."""
        # Get the latest loss from the trainer's state
        # The log_history contains dicts like {'loss': 3.4, 'learning_rate':...}
        if state.log_history:
            latest_loss = state.log_history[-1].get("loss")
            if latest_loss is not None:
                self.running_losses.append(torch.tensor(latest_loss, device=xm.xla_device()))

        # Check if it's time to perform an update
        if state.global_step > 0 and state.global_step % self.update_every_n_steps == 0:
            if not self.running_losses:
                return # Nothing to do if we haven't collected any losses

            # This must be called on all processes to avoid deadlocks.
            # It gathers the loss tensors from all TPU cores and averages them.
            avg_loss_tensor = xm.mesh_reduce(
                'loss_reduce_tag',
                torch.mean(torch.stack(self.running_losses)),
                lambda x: torch.mean(torch.stack(x))
            )
            # Clear the buffer for the next window
            self.running_losses = []

            # The rest of the logic should only run on the master process
            if xm.is_master_ordinal():
                avg_loss_val = avg_loss_tensor.item()
                
                # 1. Calculate new weights using the provided function
                new_weights = self.weight_update_fn(avg_loss_val)
                
                # 2. Log the information
                print(f"\n--- Step {state.global_step} ---")
                print(f"Avg loss over last {self.update_every_n_steps} steps: {avg_loss_val:.4f}")
                print(f"Updating sampling weights to: {[f'{w:.4f}' for w in new_weights]}")
                
                # 3. Update the dataset's weights directly
                self.dataset.update_weights(new_weights)

In [15]:
class DynamicSamplingOnEvaluationCallback(TrainerCallback):
    """
    A Hugging Face TrainerCallback that dynamically adjusts dataset sampling
    weights based on the evaluation loss.
    """
    def __init__(self, dataset: StatefulShardedDataset, weight_update_fn: callable):
        self.dataset = dataset
        self.weight_update_fn = weight_update_fn

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: dict[str, float], **kwargs):
        """Event called after an evaluation phase."""
        if xm.is_master_ordinal():
            eval_loss = metrics.get("eval_loss")
            if eval_loss is None:
                print("Warning: 'eval_loss' not found in metrics. Skipping weight update.")
                return

            new_weights = self.weight_update_fn(eval_loss)
            print(f"\n--- Evaluation at Step {state.global_step} ---")
            print(f"Evaluation Loss: {eval_loss:.4f}")
            print(f"Updating sampling weights to: {[f'{w:.4f}' for w in new_weights]}")
            self.dataset.update_weights(new_weights)


In [16]:
class StreamingTrainer(Trainer):
    """
    A custom Trainer that overrides the training dataloader to use
    our streaming dataset and the PyTorch/XLA MpDeviceLoader.
    """
    pass
    # def get_train_dataloader(self) -> DataLoader:
    #     if self.train_dataset is None:
    #         raise ValueError("Trainer: training requires a train_dataset.")
        
    #     assert isinstance(self.train_dataset, StatefulStreamingDataset), \
    #         "train_dataset must be an instance of StatefulStreamingDataset"
        
    #     # Use the PyTorch/XLA ParallelLoader for TPUs
    #     return pl.MpDeviceLoader(
    #         self.train_dataset,
    #         device=self.args.device,
    #         batch_size=self.args.per_device_train_batch_size,
    #         num_workers=self.args.dataloader_num_workers,
    #         worker_init_fn=worker_init_fn,
    #     )

In [17]:
FINAL_EVAL_BATCH_SIZE = 8
EVAL_PATH = '/home/shuyaoli/llm_data/LLM-Shearing/for_prune/eval_merge'
eval_dataset = StreamingDataset(
    local=EVAL_PATH,
    batch_size=FINAL_EVAL_BATCH_SIZE
)

NameError: name 'StreamingDataset' is not defined

In [61]:
TRAIN_BATCH_SIZE = 8




In [19]:
initial_weights = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

# Instantiate your master dataset
master_train_dataset = StatefulShardedDataset(
    domain_dirs=domain_dirs,
    initial_weights=initial_weights,
    chunk_size=4  # small for demo
)

Initialized dataset with 7 domains.
  -> Found 7 shards for domain 'book'
  -> Found 4 shards for domain 'arxiv'
  -> Found 4 shards for domain 'stackexchange'
  -> Found 7 shards for domain 'wiki'
  -> Found 23 shards for domain 'c4-rp'
  -> Found 103 shards for domain 'cc'
  -> Found 7 shards for domain 'github'


In [63]:
# Using the Sheared-LLaMA model for continued pretraining.
model_name = "princeton-nlp/Sheared-LLaMA-1.3B-Pruned"

In [64]:
# You will need to be logged into your Hugging Face account and have
# access to meta-llama models for this to work.
# `huggingface-cli login`
tokenizer_name = "meta-llama/Llama-2-7b-hf"

In [65]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# Set pad token to EOS token for Causal LM
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

In [66]:
training_args = TrainingArguments(
    output_dir="./tpu_eval_sampling_model",
    max_steps=5000,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=100,
    
    # --- Crucial for this strategy ---
    eval_strategy="steps",
    eval_steps=128, # Run evaluation every 1000 steps
    # ---------------------------------
    
    dataloader_num_workers=1, # At least 1 worker for worker_init_fn to be called
    remove_unused_columns=True,
)

In [67]:
eval_sampling_callback = DynamicSamplingOnEvaluationCallback(
    dataset=master_train_dataset,
    weight_update_fn=lambda x: 1
)

In [68]:
trainer = StreamingTrainer(
    model=model,
    args=training_args,
    train_dataset=master_train_dataset,
    eval_dataset=eval_dataset, # Provide the evaluation dataset
    callbacks=[eval_sampling_callback], # Use the new callback
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

In [69]:
trainer.train()

Exception in thread Thread-22 (_loader_worker):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/shuyaoli/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/shuyaoli/.local/lib/python3.10/site-packages/torch_xla/distributed/parallel_loader.py", line 165, in _loader_worker
    _, data = next(data_iter)
  File "/home/shuyaoli/.local/lib/python3.10/site-packages/accelerate/data_loader.py", line 564, in __iter__
    dataloader_iter = self.base_dataloader.__iter__()
  File "/home/shuyaoli/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 493, in __iter__
    return self._get_iterator()
  File "/home/shuyaoli/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 424, in _get_ite

RuntimeError: torch_xla/csrc/xla_graph_executor.cpp:689 : Check failed: tensor_data 
*** Begin stack trace ***
	tsl::CurrentStackTrace[abi:cxx11]()
	torch_xla::XLAGraphExecutor::CollectSyncTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&)
	torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
	torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
	torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, bool)
	
	
	
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	
	PyObject_Call
	
	_PyObject_MakeTpCall
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	
	PyObject_Call
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	_PyFunction_Vectorcall
	
	Py_RunMain
	Py_BytesMain
	
	__libc_start_main
*** End stack trace ***


In [1]:
import os; cpu_count = os.cpu_count()

In [2]:
cpu_count

240

In [3]:
import torch_xla.core.xla_model as xm



In [4]:
world_size = xm.xrt_world_size()

AttributeError: module 'torch_xla.core.xla_model' has no attribute 'xrt_world_size'

In [6]:
import torch_xla.runtime as xr; world_size = xr.world_size()

In [8]:
world_size

1