# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import time
import subprocess
import threading
import re
from queue import Queue
from tqdm import tqdm
import random

# Configurable parameters
# -- data / loading --
dataset_size = 100_000
num_workers = 4
# -- model / optimisation --
hidden_dim = 512
num_epochs = 20
learning_rate = 1e-3
lr_step_size = 5  # StepLR: scheduler.step() calls between LR decays
lr_gamma = 0.5    # StepLR: multiplicative LR decay factor
# Batch size per training phase; the training loop moves to the next entry
# every 5 epochs. Scale up for bigger hardware (e.g. [256, 512, 1024] on H100s).
batch_sizes = [128, 256, 512]
# -- monitoring --
monitor_interval = 5  # seconds between nvidia-smi polls

# Check GPU availability
num_gpus = torch.cuda.device_count()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}, Number of GPUs: {num_gpus}")

# Step 1: Synthetic dataset with learnable patterns
class SyntheticImageDataset(Dataset):
    """In-memory noise images whose label is encoded as an offset square patch.

    Every sample is standard-normal noise of shape ``(1, img_size, img_size)``.
    For a sample with label ``k``, a square patch of side ``img_size // 5`` is
    shifted by ``k * 0.5``. The patch's top-left corner is at row
    ``(k % 4) * (img_size // 4)`` and column ``(k // 4) * (img_size // 4)``.
    Label 0 therefore receives a zero offset, i.e. no visible patch. A patch
    that would extend past the image edge is clipped by slicing, and may be
    empty for large labels.

    Args:
        num_samples: Number of samples to generate.
        img_size: Height and width of each image.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
    """

    def __init__(self, num_samples=dataset_size, img_size=28, num_classes=10):
        self.num_classes = num_classes
        self.img_size = img_size
        # Noise is drawn before labels; keep this order so seeded runs
        # reproduce the same data.
        self.data = torch.randn(num_samples, 1, img_size, img_size)
        self.labels = torch.randint(0, num_classes, (num_samples,))

        cell = img_size // 4
        region_size = img_size // 5
        # One masked update per class value instead of a Python loop over every
        # sample: each sample is touched exactly once, so the result is identical.
        for label in range(num_classes):
            row_start = (label % 4) * cell
            col_start = (label // 4) * cell
            mask = self.labels == label
            self.data[
                mask,
                0,
                row_start:row_start + region_size,
                col_start:col_start + region_size,
            ] += label * 0.5

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a ``(1, H, W)`` float tensor and a 0-d int64 tensor."""
        return self.data[idx], self.labels[idx]

dataset = SyntheticImageDataset()

# Step 2: CNN model
class SimpleCNN(nn.Module):
    """Three conv -> max-pool -> ReLU stages followed by a two-layer classifier.

    Channel widths grow ``1 -> hidden_dim // 4 -> hidden_dim // 2 -> hidden_dim``.
    Each convolution is 3x3 with padding 1, so only the 2x2 pools shrink the
    spatial size. After three of them, each side is ``input_size // 8``, and
    the flattened feature map feeds ``fc1``.

    Args:
        hidden_dim: Channels of the last conv stage and width of ``fc1``.
        num_classes: Number of output logits produced by ``fc2``.
        input_size: Side length of the square, single-channel input image.
    """

    def __init__(self, hidden_dim=hidden_dim, num_classes=10, input_size=28):
        super().__init__()
        quarter_width = hidden_dim // 4
        half_width = hidden_dim // 2
        # Submodules are created in this exact order so parameter initialisation
        # and state_dict keys are stable.
        self.conv1 = nn.Conv2d(1, quarter_width, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(quarter_width, half_width, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(half_width, hidden_dim, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        pooled_side = input_size // 8  # three 2x2 pools, each flooring
        self.fc1 = nn.Linear(hidden_dim * pooled_side * pooled_side, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(self.pool(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

model = SimpleCNN()
model = model.to(device)
if num_gpus > 1:
    # Wrap so forward passes are spread across all visible GPUs.
    model = nn.DataParallel(model)

# Step 3: Optimizer, scheduler, loss, scaler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the LR by lr_gamma every lr_step_size calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=lr_step_size,
    gamma=lr_gamma,
)
# Loss scaler for mixed-precision (autocast) training.
scaler = GradScaler()

# Step 4: GPU monitoring thread (handles multiple GPUs)
def monitor_gpu(queue, interval=None):
    """Poll ``nvidia-smi`` until an item is put on *queue*, then print averages.

    Each poll queries ``utilization.gpu`` (%) and ``memory.used`` (MiB) for
    every GPU row that ``nvidia-smi`` reports. It prints the per-GPU values and
    their mean, and records that mean. A failed poll is printed, and the loop
    continues; failures include ``nvidia-smi`` being missing, exiting non-zero,
    timing out, or emitting unparsable output. Once a stop item arrives, the
    means over all successful polls are printed.

    Args:
        queue: ``queue.Queue`` used purely as a stop signal. Putting any item on
            it ends monitoring; the item is consumed.
        interval: Seconds to wait between polls. ``None`` (the default) uses
            the module-level ``monitor_interval``.
    """
    # Local import keeps the file's top-level imports unchanged; it resolves the
    # stdlib module even though the parameter is also named ``queue``.
    from queue import Empty

    if interval is None:
        interval = monitor_interval
    util_samples = []
    mem_samples = []
    while True:
        try:
            output = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used", "--format=csv,noheader,nounits"],
                timeout=30,  # a wedged driver must not hang the monitor forever
            ).decode("utf-8").strip()
            # One "util, mem" row per GPU. Splitting on the comma (rather than
            # extracting digit runs) parses decimal values correctly and makes
            # non-numeric fields raise ValueError.
            gpu_utils = []
            gpu_mems = []
            for line in output.splitlines():
                if not line.strip():
                    continue
                util_field, mem_field = line.split(",")
                gpu_utils.append(float(util_field))
                gpu_mems.append(float(mem_field))
            if not gpu_utils:
                raise ValueError("nvidia-smi reported no GPUs")
            # Average across GPUs for this sample
            avg_util = sum(gpu_utils) / len(gpu_utils)
            avg_mem = sum(gpu_mems) / len(gpu_mems)
            util_samples.append(avg_util)
            mem_samples.append(avg_mem)
            print(f"GPU Util (avg): {avg_util:.1f}%, Mem (avg): {avg_mem:.1f} MiB, Per GPU: {list(zip(gpu_utils, gpu_mems))}")
        except Exception as e:  # best-effort monitor: report and retry next poll
            print(f"GPU monitor error: {e}")
        # Block on the queue instead of sleeping, so a stop request is honoured
        # as soon as it is put rather than after up to `interval` seconds.
        try:
            queue.get(timeout=interval)
        except Empty:
            continue
        break
    if util_samples and mem_samples:
        avg_util = sum(util_samples) / len(util_samples)
        avg_mem = sum(mem_samples) / len(mem_samples)
        print(f"Average GPU Utilization (all GPUs): {avg_util:.2f}%, Average Memory Used (all GPUs): {avg_mem:.2f} MiB")
    else:
        print("No GPU samples collected.")

queue = Queue()
monitor_thread = threading.Thread(target=monitor_gpu, args=(queue,))
monitor_thread.start()

# Step 5: Training loop with dynamic batch size
# The batch size advances one entry of `batch_sizes` every 5 epochs and then
# stays at the last entry. The monitor thread is stopped in `finally`: it is
# not a daemon thread, so an exception (or KeyboardInterrupt) during training
# would otherwise leave it polling and keep the process alive.
start_time = time.time()
try:
    for epoch in range(num_epochs):
        batch_size = batch_sizes[min(epoch // 5, len(batch_sizes) - 1)]
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

        model.train()
        total_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs} (Batch Size: {batch_size})", leave=True)
        for data, labels in progress_bar:
            # Batches come from pinned memory, so the copies can be asynchronous.
            data = data.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            with autocast():
                outputs = model(data)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them so clipping compares the true norm against max_norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)  # skipped if unscale_ found inf/NaN grads
            scaler.update()

            batch_loss = loss.item()  # one host sync per step
            total_loss += batch_loss
            progress_bar.set_postfix({"Batch Loss": f"{batch_loss:.4f}"})

        scheduler.step()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")
finally:
    # Stop monitoring
    queue.put("stop")
    monitor_thread.join()