In [1]:
# Select the TPU runtime for PyTorch/XLA. This has to be set in the kernel's own
# environment: `!export ...` runs in a short-lived subshell, so the variable
# would disappear before torch_xla is imported.
import os

os.environ["PJRT_DEVICE"] = "TPU"

In [2]:
import torch
import torch_xla.core.xla_model as xm

# Smoke test: allocate two random 3x3 tensors on the XLA device and print
# their sum to confirm the device executes work.
dev = xm.xla_device()
t1, t2 = (torch.randn(3, 3, device=dev) for _ in range(2))
print(t1 + t2)




tensor([[ 1.8535, -0.0823, -0.2653],
        [-0.0051,  1.3175, -0.5839],
        [-1.2089,  1.2293,  1.3118]], device='xla:0')


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import os

In [13]:
import argparse
import os

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    Trainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType

In [6]:
import numpy as np
import multiprocessing as mp
from torch.utils.data import IterableDataset, get_worker_info

def worker_init_fn(worker_id):
    """Initializes each data loader worker to have its own data streams.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader`` whose
    dataset exposes ``sources`` and ``data_streams`` attributes.

    Args:
        worker_id: Index of the worker being initialised (unused; the same
            information is read from ``get_worker_info()``).
    """
    worker_info = get_worker_info()
    if worker_info is None:
        # Not running inside a DataLoader worker process: there is no
        # per-worker dataset copy to prepare.
        return

    # PyTorch reseeds only its own RNG per worker. The dataset samples with
    # NumPy's global RNG, so seed it from the per-worker seed; otherwise
    # workers that inherit the parent's RNG state draw identical indices.
    np.random.seed(worker_info.seed % 2**32)

    dataset = worker_info.dataset
    # Each worker gets its own iterator for each data source
    dataset.data_streams = [iter(source) for source in dataset.sources]

class StatefulStreamingDataset(IterableDataset):
    """
    An IterableDataset that samples from multiple data sources based on
    dynamically updatable weights, optimized for batched sampling.

    The stream is infinite: an exhausted source is restarted transparently.

    Args:
        sources: Sequence of iterables, one per data source. Must support
            ``len()`` and indexing.
        initial_weights: One non-negative, finite weight per source. They do
            not need to sum to 1 (they are normalised when sampling) but must
            have a positive sum.
        chunk_size: Number of source indices drawn per sampling call. A weight
            update only takes effect at the next chunk boundary.

    Raises:
        ValueError: If the weights are invalid, their count differs from the
            number of sources, or ``chunk_size`` is smaller than 1.
    """
    def __init__(self, sources, initial_weights, chunk_size=4096):
        initial_weights = [float(w) for w in initial_weights]
        if len(initial_weights) != len(sources):
            raise ValueError(
                f"Expected {len(sources)} weights (one per source), "
                f"got {len(initial_weights)}."
            )
        self._validate_weights(initial_weights)
        if chunk_size < 1:
            # A chunk of size 0 would make __iter__ loop forever without yielding.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.")

        self.sources = sources
        # Use a multiprocessing Array for weights to ensure they are shared
        # across all data loader worker processes.
        self.weights = mp.Array('d', initial_weights)
        self.chunk_size = chunk_size
        self.data_streams = None

    @staticmethod
    def _validate_weights(weights):
        """Raises ValueError unless `weights` can be normalised into probabilities."""
        arr = np.asarray(weights, dtype=np.float64)
        if not np.all(np.isfinite(arr)) or np.any(arr < 0):
            raise ValueError(f"Weights must be finite and non-negative, got {list(weights)}.")
        if arr.sum() <= 0:
            raise ValueError(f"Weights must have a positive sum, got {list(weights)}.")

    def _get_weights(self):
        """Reads the current weights from the shared memory array."""
        return np.array(self.weights[:])

    def update_weights(self, new_weights: list[float]):
        """
        Updates the shared weights. This method is called from the main
        process by a custom callback.

        ``new_weights[i]`` replaces the weight of source ``i``. A list shorter
        than the number of sources updates only that prefix and leaves the
        remaining weights untouched.

        Raises:
            ValueError: If more weights than sources are given, or if the
                resulting weight vector is not finite, non-negative and of
                positive sum. Nothing is written in that case.
        """
        new_weights = [float(w) for w in new_weights]
        if len(new_weights) > len(self.weights):
            raise ValueError(
                f"Got {len(new_weights)} weights for {len(self.weights)} sources."
            )
        with self.weights.get_lock():
            # Validate the vector as it would look after the update, so an
            # invalid update is rejected here rather than failing later
            # inside np.random.choice.
            merged = list(self.weights[:])
            merged[:len(new_weights)] = new_weights
            self._validate_weights(merged)
            for i, weight in enumerate(new_weights):
                self.weights[i] = weight

    def _next_item(self, source_idx):
        """Returns the next item of one source, restarting it once if exhausted."""
        try:
            return next(self.data_streams[source_idx])
        except StopIteration:
            # When a source is exhausted, restart it for continuous training
            worker_info = get_worker_info()
            print(f"Worker {worker_info.id if worker_info else 0}: Restarting stream {source_idx}.")
            self.data_streams[source_idx] = iter(self.sources[source_idx])
            try:
                return next(self.data_streams[source_idx])
            except StopIteration:
                # A freshly restarted source that yields nothing is empty.
                # Raise explicitly instead of leaking StopIteration out of
                # the generator.
                raise RuntimeError(
                    f"Source {source_idx} is empty; cannot sample from it."
                ) from None

    def __iter__(self):
        """The core streaming logic for each worker, using chunked sampling."""
        # Initialize streams if they haven't been, relevant for num_workers=0
        if self.data_streams is None:
            self.data_streams = [iter(source) for source in self.sources]

        num_sources = len(self.sources)
        while True:
            # Read the weights once per chunk and normalise them.
            current_weights = self._get_weights()
            probabilities = current_weights / np.sum(current_weights)

            # Draw the whole chunk of source indices in one call instead of
            # sampling one index per item.
            source_indices = np.random.choice(
                num_sources,
                size=self.chunk_size,
                p=probabilities
            )

            for source_idx in source_indices:
                yield self._next_item(source_idx)

In [7]:
from streaming import Stream, StreamingDataset
from torch.utils.data import DataLoader

In [8]:
# Root directory of the pruning data (hardcoded absolute path). The data-source
# setup cell further down reassigns `local_path` to the same directory without
# the trailing slash.
local_path = '/home/shuyaoli/llm_data/LLM-Shearing/for_prune/'

In [10]:
!python3 dataset.py

Initializing sources...
Created 7 data sources.

Starting DataLoader iteration...
Step 0:
  Batch keys: dict_keys(['set', 'tokens'])
  Tokens tensor length: [8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192, 8192]
  Domain counts in batch: {np.str_('arxiv'): np.int64(13), np.str_('book')

In [9]:
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
import torch_xla.core.xla_model as xm
import torch

In [10]:
class DynamicSamplingCallback(TrainerCallback):
    """
    A Hugging Face TrainerCallback that dynamically adjusts dataset sampling
    weights based on the training loss every N steps.

    Args:
        dataset: The dataset whose ``update_weights`` method receives the new
            sampling weights.
        update_every_n_steps: Interval, in optimizer steps, between weight
            updates. Must be positive.
        weight_update_fn: Maps the averaged loss (a float) to a list of
            weights. The default is a placeholder that returns a single
            weight.

    Raises:
        ValueError: If ``update_every_n_steps`` is not positive.
    """
    def __init__(
        self,
        dataset: StatefulStreamingDataset,
        update_every_n_steps: int = 100,
        # Your logic to map loss to weights goes here
        weight_update_fn: callable = lambda loss: [1.0 / max(loss, 1e-6)] 
    ):
        if update_every_n_steps <= 0:
            # Guards the modulo in on_step_end against division by zero.
            raise ValueError(
                f"update_every_n_steps must be positive, got {update_every_n_steps}."
            )
        self.dataset = dataset
        self.update_every_n_steps = update_every_n_steps
        self.weight_update_fn = weight_update_fn
        # Host-side floats: one entry per consumed log record that has a "loss".
        self.running_losses = []
        # Number of state.log_history entries already consumed.
        self._num_logs_seen = 0

    def _collect_new_losses(self, state: TrainerState):
        """Appends the losses of log entries not seen before to the buffer."""
        log_history = state.log_history or []
        if self._num_logs_seen > len(log_history):
            # History is shorter than what we consumed (e.g. a fresh state):
            # start over.
            self._num_logs_seen = 0
        # Only look at entries added since the previous call. Reading
        # log_history[-1] on every step would count the same entry repeatedly.
        for entry in log_history[self._num_logs_seen:]:
            loss = entry.get("loss")
            if loss is not None:
                self.running_losses.append(float(loss))
        self._num_logs_seen = len(log_history)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Event called at the end of a training step."""
        # The log_history contains dicts like {'loss': 3.4, 'learning_rate':...}
        self._collect_new_losses(state)

        # Check if it's time to perform an update
        if state.global_step > 0 and state.global_step % self.update_every_n_steps == 0:
            if not self.running_losses:
                return # Nothing to do if we haven't collected any losses

            # This must be called on all processes to avoid deadlocks.
            # The local mean is a plain float, so no device tensor is created
            # just to be copied back to the host for the reduction.
            local_mean = sum(self.running_losses) / len(self.running_losses)
            avg_loss_val = float(
                xm.mesh_reduce(
                    'loss_reduce_tag',
                    local_mean,
                    lambda values: sum(values) / len(values)
                )
            )
            # Clear the buffer for the next window
            self.running_losses = []

            # The rest of the logic should only run on the master process
            # NOTE(review): only the master's dataset object is updated here;
            # confirm other processes share (or do not need) these weights.
            if xm.is_master_ordinal():
                # 1. Calculate new weights using the provided function
                new_weights = self.weight_update_fn(avg_loss_val)

                # 2. Log the information
                print(f"\n--- Step {state.global_step} ---")
                print(f"Avg loss over last {self.update_every_n_steps} steps: {avg_loss_val:.4f}")
                print(f"Updating sampling weights to: {[f'{w:.4f}' for w in new_weights]}")

                # 3. Update the dataset's weights directly
                self.dataset.update_weights(new_weights)

In [11]:
class DynamicSamplingOnEvaluationCallback(TrainerCallback):
    """
    TrainerCallback that recomputes the dataset's sampling weights from the
    evaluation loss each time an evaluation finishes.

    Args:
        dataset: Dataset whose ``update_weights`` receives the new weights.
        weight_update_fn: Maps the evaluation loss to a list of weights.
    """
    def __init__(self, dataset: StatefulStreamingDataset, weight_update_fn: callable):
        self.weight_update_fn = weight_update_fn
        self.dataset = dataset

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: dict[str, float], **kwargs):
        """Event called after an evaluation phase."""
        # Only the master ordinal reports and applies the update.
        if not xm.is_master_ordinal():
            return

        eval_loss = metrics.get("eval_loss")
        if eval_loss is None:
            print("Warning: 'eval_loss' not found in metrics. Skipping weight update.")
            return

        new_weights = self.weight_update_fn(eval_loss)
        formatted_weights = [f'{w:.4f}' for w in new_weights]
        print(f"\n--- Evaluation at Step {state.global_step} ---")
        print(f"Evaluation Loss: {eval_loss:.4f}")
        print(f"Updating sampling weights to: {formatted_weights}")
        self.dataset.update_weights(new_weights)


In [14]:
class StreamingTrainer(Trainer):
    """
    A custom Trainer that overrides the training dataloader to use
    our streaming dataset and the PyTorch/XLA MpDeviceLoader.
    """
    def get_train_dataloader(self) -> DataLoader:
        """Builds the training loader for the streaming dataset.

        Returns:
            A ``DataLoader`` over ``self.train_dataset`` wrapped in an
            ``MpDeviceLoader`` targeting ``self.args.device``.

        Raises:
            ValueError: If no training dataset was given to the Trainer.
            AssertionError: If the training dataset is not a
                ``StatefulStreamingDataset``.
        """
        # Imported here because no cell of this notebook imports the module,
        # so the name `pl` would otherwise be undefined.
        import torch_xla.distributed.parallel_loader as pl

        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        
        assert isinstance(self.train_dataset, StatefulStreamingDataset), \
            "train_dataset must be an instance of StatefulStreamingDataset"
        
        # MpDeviceLoader wraps an existing loader; batching, worker processes
        # and per-worker initialisation are DataLoader options, not
        # MpDeviceLoader options, so configure them on the DataLoader.
        host_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            # Use the Trainer's collator so batches have the layout the
            # Trainer was configured for.
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            worker_init_fn=worker_init_fn,
        )
        return pl.MpDeviceLoader(host_loader, self.args.device)

In [25]:
# Evaluation data, read from a single local StreamingDataset directory
# (presumably the merged evaluation set, going by the directory name).
# NOTE(review): FINAL_EVAL_BATCH_SIZE is not defined in any visible cell of
# this notebook — confirm where it is set before running from a fresh kernel.
EVAL_PATH = '/home/shuyaoli/llm_data/LLM-Shearing/for_prune/eval_merge'
eval_dataset = StreamingDataset(
    local=EVAL_PATH,
    batch_size=FINAL_EVAL_BATCH_SIZE
)

In [28]:
# Batch size handed to each per-domain StreamingDataset in the setup cell below.
TRAIN_BATCH_SIZE = 128




In [31]:
# Start new processes with "spawn" rather than the platform default.
mp.set_start_method("spawn", force=True)

# --- Per-domain readers ---
local_path = '/home/shuyaoli/llm_data/LLM-Shearing/for_prune'
stream_names = [
    'book', 'arxiv', 'stackexchange', 'wiki', 'c4-rp', 'cc', 'github'
]

print("Initializing sources...")
# One independent, internally shuffled StreamingDataset per domain directory.
sources = []
for name in stream_names:
    domain_dataset = StreamingDataset(
        local=os.path.join(local_path, name),
        shuffle=True,
        # Batches themselves are formed later by the DataLoader.
        batch_size=TRAIN_BATCH_SIZE,
    )
    sources.append(domain_dataset)

print(f"Created {len(sources)} data sources.")

# --- Weighted wrapper over all domains ---
# Equal starting weight for each of the seven domains; the count must match
# the number of sources.
initial_weights = [0.1 for _ in range(7)]
assert len(initial_weights) == len(sources)

# chunk_size=1 means the sampling weights are re-read before every item.
master_dataset = StatefulStreamingDataset(
    sources=sources,
    initial_weights=initial_weights,
    chunk_size=1,
)

Initializing sources...
Created 7 data sources.
