In [5]:
# `!export ...` runs in a throwaway subshell, so the variable never reaches the
# kernel process. Set it in-process instead so libraries loaded by this kernel
# can see it. Run this cell before the one that imports torch_xla.
import os

os.environ["PJRT_DEVICE"] = "TPU"

In [1]:
import torch
import torch_xla.core.xla_model as xm

# Smoke test: allocate two random 3x3 tensors on the XLA device and print
# their element-wise sum.
dev = xm.xla_device()
t1 = torch.randn((3, 3), device=dev)
t2 = torch.randn((3, 3), device=dev)
print(torch.add(t1, t2))




tensor([[-2.1709,  0.8687, -0.4139],
        [-0.4169, -2.6102, -1.1311],
        [ 1.7698,  0.3981, -1.6594]], device='xla:0')


In [6]:
import os

In [None]:
# `os.environ` has no `set` attribute, so the bare expression raised
# AttributeError. NOTE(review): assumed to be an unfinished `setdefault` call
# for the PJRT device selection - confirm the intended variable and value.
os.environ.setdefault("PJRT_DEVICE", "TPU")

In [9]:
import argparse
import os

import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed

from accelerate import Accelerator, DataLoaderConfiguration, DistributedType

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import numpy as np
import multiprocessing as mp
from torch.utils.data import IterableDataset, get_worker_info

def worker_init_fn(worker_id):
    """Give the calling DataLoader worker fresh iterators over every source.

    Looks up the dataset attached to the current worker and stores one new
    iterator per entry of ``dataset.sources`` in ``dataset.data_streams``.

    Args:
        worker_id: Index of the worker being initialised (not used here).
    """
    worker_dataset = get_worker_info().dataset
    fresh_streams = []
    for src in worker_dataset.sources:
        # One independent iterator per data source for this worker.
        fresh_streams.append(iter(src))
    worker_dataset.data_streams = fresh_streams

class StatefulStreamingDataset(IterableDataset):
    """
    An endless IterableDataset that interleaves several data sources, choosing
    the source of each item at random according to shared, updatable weights.

    Source indices are drawn in chunks of ``chunk_size`` and the weights are
    read once per chunk, so a call to ``update_weights`` takes effect at the
    next chunk boundary rather than immediately.

    Note: every iterator over this dataset (including each DataLoader worker
    prepared by ``worker_init_fn``) starts each source from its beginning, so
    sources that are not sharded per worker are replayed by every worker.

    Args:
        sources: Sized sequence of re-iterable data sources. A source is
            restarted with ``iter(source)`` whenever it runs out.
        initial_weights: One non-negative weight per source; they are
            normalised to probabilities on every chunk.
        chunk_size: Number of source indices sampled per weight read.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer, or the weights
            do not match the number of sources, contain negative or non-finite
            values, or sum to zero.
    """

    def __init__(self, sources, initial_weights, chunk_size=4096):
        if not isinstance(chunk_size, (int, np.integer)) or chunk_size < 1:
            # A chunk size of 0 would make __iter__ spin forever without yielding.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
        checked = self._validated_weights(initial_weights, len(sources))

        self.sources = sources
        # Use a multiprocessing Array for weights to ensure they are shared
        # across all data loader worker processes.
        self.weights = mp.Array('d', checked.tolist())
        self.chunk_size = chunk_size
        self.data_streams = None

    @staticmethod
    def _validated_weights(weights, expected_len):
        """Return ``weights`` as a 1-D float array, or raise ValueError if unusable."""
        values = np.asarray(list(weights), dtype=float)
        if values.ndim != 1 or values.shape[0] != expected_len:
            raise ValueError(
                f"expected {expected_len} weights (one per source), got shape {values.shape}"
            )
        if not np.all(np.isfinite(values)) or np.any(values < 0):
            raise ValueError("weights must be finite and non-negative")
        if values.sum() <= 0:
            # An all-zero vector cannot be normalised into probabilities.
            raise ValueError("at least one weight must be positive")
        return values

    def _get_weights(self):
        """Reads the current weights from the shared memory array."""
        return np.array(self.weights[:])

    def update_weights(self, new_weights: list[float]):
        """
        Replaces all shared weights with ``new_weights``.

        The new values are validated before anything is written, so a rejected
        update leaves the current weights untouched.

        Raises:
            ValueError: If ``new_weights`` does not hold exactly one finite,
                non-negative weight per source, or if the weights sum to zero.
        """
        checked = self._validated_weights(new_weights, len(self.weights))
        # Hold the lock for the whole write so readers never see a half-updated vector.
        with self.weights.get_lock():
            for i, value in enumerate(checked):
                self.weights[i] = float(value)

    def __iter__(self):
        """
        Yields items forever, sampling the source of each item in chunks.

        Raises:
            RuntimeError: If a source yields nothing even right after being
                restarted (an empty source or a one-shot iterator).
        """
        # Initialize streams if they haven't been, relevant for num_workers=0
        if self.data_streams is None:
            self.data_streams = [iter(source) for source in self.sources]

        num_sources = len(self.sources)
        while True:
            # Read the weights once per chunk and normalise them.
            current_weights = self._get_weights()
            probabilities = current_weights / np.sum(current_weights)

            # Draw a whole chunk of source indices in a single call.
            source_indices = np.random.choice(
                num_sources,
                size=self.chunk_size,
                p=probabilities,
            )

            for source_idx in source_indices:
                try:
                    item = next(self.data_streams[source_idx])
                except StopIteration:
                    # When a source is exhausted, restart it for continuous training
                    worker_info = get_worker_info()
                    worker_id = worker_info.id if worker_info else 0
                    print(f"Worker {worker_id}: Restarting stream {source_idx}.")
                    self.data_streams[source_idx] = iter(self.sources[source_idx])
                    try:
                        item = next(self.data_streams[source_idx])
                    except StopIteration:
                        # A StopIteration escaping a generator becomes an opaque
                        # RuntimeError; raise a descriptive one instead.
                        raise RuntimeError(
                            f"Source {source_idx} is empty or cannot be re-iterated."
                        ) from None
                yield item