In [1]:
import os
import time
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
from liburing import (
    O_RDONLY, AT_FDCWD, iovec, io_uring, io_uring_get_sqe,
    io_uring_prep_openat, io_uring_prep_read, io_uring_prep_close,
    io_uring_submit, io_uring_wait_cqe, io_uring_cqe_seen,
    io_uring_cqe, io_uring_queue_init, io_uring_queue_exit,
    io_uring_sqe_set_data64, trap_error
)


class BenchmarkConfig:
    """Configuration for the benchmark"""
    def __init__(self, 
                 data_dir="./benchmark_data",
                 num_files=1000,
                 file_size=4096,  # 4KB per file
                 batch_size=32,
                 num_workers=4,
                 queue_depth=128,
                 num_epochs=3,
                 seed=42):
        self.data_dir = data_dir
        self.num_files = num_files
        self.file_size = file_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.queue_depth = queue_depth
        self.num_epochs = num_epochs
        self.seed = seed
        
        # Create output directory if it doesn't exist
        os.makedirs(data_dir, exist_ok=True)


def generate_benchmark_data(config):
    """Generate synthetic data files for benchmarking"""
    print(f"Generating {config.num_files} benchmark files...")
    
    # Set random seed for reproducibility
    np.random.seed(config.seed)
    
    for i in range(config.num_files):
        # Generate random data (simulating text data)
        data = np.random.randint(0, 256, size=config.file_size, dtype=np.uint8)
        
        # Convert to bytes and save to file
        with open(os.path.join(config.data_dir, f"file_{i}.txt"), "wb") as f:
            f.write(data.tobytes())
    
    print("Benchmark data generation complete!")


class VanillaDataset(Dataset):
    """Standard PyTorch dataset using regular file I/O"""
    def __init__(self, data_dir, file_list, tokenizer, max_length=128):
        self.data_dir = data_dir
        self.file_list = file_list
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        # Load file using regular I/O
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        with open(file_path, 'rb') as f:
            content = f.read()
        
        # Convert bytes to text (assuming UTF-8, but handling errors)
        text = content.decode('utf-8', errors='replace')
        
        # Tokenize text
        encoded = self.tokenizer(
            text, 
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Return the encoded tokens
        return {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze(),
        }


class IoUringHelper:
    """Helper class for io_uring operations"""
    def __init__(self, queue_depth=128):
        self.ring = io_uring()
        self.cqe = io_uring_cqe()
        self.queue_depth = queue_depth
        io_uring_queue_init(queue_depth, self.ring, 0)
    
    def __del__(self):
        """Clean up resources"""
        io_uring_queue_exit(self.ring)
        
    def open_file(self, path):
        """Open a file using io_uring"""
        _path = path if isinstance(path, bytes) else str(path).encode()
        sqe = io_uring_get_sqe(self.ring)
        io_uring_prep_openat(sqe, AT_FDCWD, _path, O_RDONLY, 0)
        io_uring_sqe_set_data64(sqe, 1)
        return self._submit_and_wait()
    
    def read_file(self, fd, length):
        """Read data from file using io_uring"""
        iov = iovec(bytearray(length))
        sqe = io_uring_get_sqe(self.ring)
        io_uring_prep_read(sqe, fd, iov.iov_base, iov.iov_len, 0)
        io_uring_sqe_set_data64(sqe, 2)
        self._submit_and_wait()
        return iov.iov_base
    
    def close_file(self, fd):
        """Close a file using io_uring"""
        sqe = io_uring_get_sqe(self.ring)
        io_uring_prep_close(sqe, fd)
        io_uring_sqe_set_data64(sqe, 3)
        self._submit_and_wait()
    
    def _submit_and_wait(self):
        """Submit operation and wait for completion"""
        io_uring_submit(self.ring)
        io_uring_wait_cqe(self.ring, self.cqe)
        result = trap_error(self.cqe.res)
        io_uring_cqe_seen(self.ring, self.cqe)
        return result


class IoUringDataset(Dataset):
    """Dataset implementation using io_uring for async I/O"""
    def __init__(self, data_dir, file_list, tokenizer, max_length=128, queue_depth=128):
        self.data_dir = data_dir
        self.file_list = file_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.io_helper = IoUringHelper(queue_depth)
    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        # Get file path
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        
        # Read file using io_uring
        try:
            # Open file
            fd = self.io_helper.open_file(file_path)
            
            # Get file size (we could optimize this in a real implementation)
            file_size = os.path.getsize(file_path)
            
            # Read file content
            content = self.io_helper.read_file(fd, file_size)
            
            # Close file
            self.io_helper.close_file(fd)
            
            # Convert bytes to text
            text = content.decode('utf-8', errors='replace')
            
            # Tokenize text
            encoded = self.tokenizer(
                text, 
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            return {
                'input_ids': encoded['input_ids'].squeeze(),
                'attention_mask': encoded['attention_mask'].squeeze(),
            }
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
            # Return empty tensors in case of error
            return {
                'input_ids': torch.zeros(self.max_length, dtype=torch.long),
                'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
            }


def run_benchmark(config, dataloader, model, device, desc=""):
    """Run benchmark for a specific dataloader implementation"""
    start_time = time.time()
    total_batches = 0
    
    # Track timing for each phase
    io_times = []
    compute_times = []
    
    print(f"\nRunning benchmark for {desc}...")
    
    for epoch in range(config.num_epochs):
        epoch_start = time.time()
        
        for batch_idx, batch in enumerate(dataloader):
            # Track I/O time (time to get the batch)
            io_end = time.time()
            io_time = io_end - (compute_end if batch_idx > 0 else epoch_start)
            io_times.append(io_time)
            
            # Move to device and run model (compute)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            
            # Track compute time
            compute_end = time.time()
            compute_time = compute_end - io_end
            compute_times.append(compute_time)
            
            total_batches += 1
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}/{config.num_epochs}, Batch {batch_idx}/{len(dataloader)}")
    
    total_time = time.time() - start_time
    avg_batch_time = total_time / total_batches
    avg_io_time = np.mean(io_times)
    avg_compute_time = np.mean(compute_times)
    
    results = {
        'total_time': total_time,
        'avg_batch_time': avg_batch_time,
        'avg_io_time': avg_io_time,
        'avg_compute_time': avg_compute_time,
        'io_times': io_times,
        'compute_times': compute_times
    }
    
    print(f"Benchmark results for {desc}:")
    print(f"  Total time: {total_time:.4f}s")
    print(f"  Average batch time: {avg_batch_time:.4f}s")
    print(f"  Average I/O time: {avg_io_time:.4f}s")
    print(f"  Average compute time: {avg_compute_time:.4f}s")
    print(f"  I/O percentage: {(avg_io_time / (avg_io_time + avg_compute_time)) * 100:.2f}%")
    
    return results


def plot_results(vanilla_results, iouring_results):
    """Plot comparison of benchmark results"""
    plt.figure(figsize=(15, 10))
    
    # Plot 1: Total times comparison
    plt.subplot(2, 2, 1)
    times = [
        vanilla_results['total_time'],
        iouring_results['total_time']
    ]
    plt.bar(['Vanilla', 'io_uring'], times)
    plt.title('Total Execution Time (s)')
    plt.ylabel('Time (seconds)')
    
    # Plot 2: Average batch times
    plt.subplot(2, 2, 2)
    avg_times = [
        vanilla_results['avg_batch_time'],
        iouring_results['avg_batch_time']
    ]
    plt.bar(['Vanilla', 'io_uring'], avg_times)
    plt.title('Average Batch Processing Time (s)')
    plt.ylabel('Time (seconds)')
    
    # Plot 3: I/O vs Compute breakdown
    plt.subplot(2, 2, 3)
    labels = ['I/O', 'Compute']
    vanilla_breakdown = [vanilla_results['avg_io_time'], vanilla_results['avg_compute_time']]
    iouring_breakdown = [iouring_results['avg_io_time'], iouring_results['avg_compute_time']]
    
    x = np.arange(len(labels))
    width = 0.35
    
    plt.bar(x - width/2, vanilla_breakdown, width, label='Vanilla')
    plt.bar(x + width/2, iouring_breakdown, width, label='io_uring')
    plt.xlabel('Phase')
    plt.ylabel('Time (seconds)')
    plt.title('I/O vs Compute Time')
    plt.xticks(x, labels)
    plt.legend()
    
    # Plot 4: Speedup ratio
    plt.subplot(2, 2, 4)
    io_speedup = vanilla_results['avg_io_time'] / iouring_results['avg_io_time']
    total_speedup = vanilla_results['avg_batch_time'] / iouring_results['avg_batch_time']
    
    plt.bar(['I/O Speedup', 'Total Speedup'], [io_speedup, total_speedup])
    plt.axhline(y=1.0, color='r', linestyle='-', alpha=0.3)
    plt.title('Speedup Ratio (Vanilla / io_uring)')
    plt.ylabel('Ratio')
    
    plt.tight_layout()
    plt.savefig('io_uring_benchmark_results.png')
    plt.show()


def main():
    # Initialize benchmark configuration
    config = BenchmarkConfig()
    
    # Generate benchmark data if needed
    if not os.path.exists(config.data_dir) or len(os.listdir(config.data_dir)) < config.num_files:
        generate_benchmark_data(config)
    
    # Get list of files
    file_list = [f for f in os.listdir(config.data_dir) if f.endswith('.txt')]
    print(f"Found {len(file_list)} files for benchmarking")
    
    # Initialize tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    print(f"Using device: {device}")
    
    # Create datasets
    vanilla_dataset = VanillaDataset(config.data_dir, file_list, tokenizer)
    iouring_dataset = IoUringDataset(config.data_dir, file_list, tokenizer, queue_depth=config.queue_depth)
    
    # Create dataloaders
    vanilla_dataloader = DataLoader(
        vanilla_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=(device.type == 'cuda')
    )
    
    # Note: For IoUringDataset, we need to be careful with num_workers
    # Each worker will create its own IoUringHelper instance
    iouring_dataloader = DataLoader(
        iouring_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=(device.type == 'cuda')
    )
    
    # Run benchmarks
    vanilla_results = run_benchmark(config, vanilla_dataloader, model, device, "Vanilla DataLoader")
    iouring_results = run_benchmark(config, iouring_dataloader, model, device, "io_uring DataLoader")
    
    # Plot results
    plot_results(vanilla_results, iouring_results)
    
    # Print improvement summary
    io_improvement = (vanilla_results['avg_io_time'] - iouring_results['avg_io_time']) / vanilla_results['avg_io_time'] * 100
    total_improvement = (vanilla_results['avg_batch_time'] - iouring_results['avg_batch_time']) / vanilla_results['avg_batch_time'] * 100
    
    print("\nPerformance Improvement Summary:")
    print(f"  I/O Time Improvement: {io_improvement:.2f}%")
    print(f"  Total Time Improvement: {total_improvement:.2f}%")
    
    # Clean up (optional)
    # import shutil
    # shutil.rmtree(config.data_dir)


if __name__ == "__main__":
    main()

Generating 1000 benchmark files...
Benchmark data generation complete!
Found 1000 files for benchmarking
Using device: cpu

Running benchmark for Vanilla DataLoader...
Epoch 1/3, Batch 0/32
Epoch 1/3, Batch 10/32
Epoch 1/3, Batch 20/32
Epoch 1/3, Batch 30/32
Epoch 2/3, Batch 0/32
Epoch 2/3, Batch 10/32
Epoch 2/3, Batch 20/32
Epoch 2/3, Batch 30/32
Epoch 3/3, Batch 0/32
Epoch 3/3, Batch 10/32
Epoch 3/3, Batch 20/32
Epoch 3/3, Batch 30/32
Benchmark results for Vanilla DataLoader:
  Total time: 1198.9607s
  Average batch time: 12.4892s
  Average I/O time: 0.0614s
  Average compute time: 12.4262s
  I/O percentage: 0.49%

Running benchmark for io_uring DataLoader...
Error processing file ./benchmark_data/file_313.txt: expected bytes, __common_define__ foundError processing file ./benchmark_data/file_585.txt: expected bytes, __common_define__ found

Error processing file ./benchmark_data/file_406.txt: expected bytes, __common_define__ foundError processing file ./benchmark_data/file_454.txt:

KeyboardInterrupt: 