# Distributed ImageNet Training with Hivemind

This notebook implements distributed training for ImageNet using ResNet50 and Hivemind DHT.

In [None]:
import os
import multiprocessing as mp
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import hivemind
from torchvision.models import resnet50, ResNet50_Weights
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from typing import Optional
import itertools
import warnings
import socket
import glob
import time
import requests
from google.cloud import storage

# Local imports (assuming these files exist in the same directory)
from metrics import ResourceMonitor, TrainingLogger
from datasets import get_webdataset_loader

# Silence the TF32 API deprecation notice; it is informational only.
warnings.filterwarnings("ignore", message=".*Please use the new API settings to control TF32 behavior.*")

# Set start method for multiprocessing.
# With force=True, set_start_method() does not raise RuntimeError for an
# already-configured context; the failure that can actually occur is
# ValueError on platforms that do not offer 'fork' (e.g. Windows). Both are
# tolerated so the notebook keeps running with the platform default.
try:
    mp.set_start_method('fork', force=True)
except (RuntimeError, ValueError) as exc:
    print(f"Could not select the 'fork' start method ({exc}); keeping the platform default.")

## Configuration

We use a `Config` class to replace command-line arguments.

In [None]:
class Config:
    """Run configuration that stands in for command-line arguments.

    Every setting has a default. Keyword arguments override individual
    settings at construction time, e.g. ``Config(batch_size=64)``; assigning
    attributes after construction keeps working as before.

    Raises:
        TypeError: if a keyword does not name an existing setting (this
            catches misspelled option names instead of silently adding a
            new, unused attribute).
    """

    def __init__(self, **overrides):
        self.device = None  # Auto-detect
        self.initial_peer = None  # Multiaddr of initial peer
        self.val_every = 1  # Validate whenever the epoch number is a multiple of this

        # Data loading
        self.data_dir = "gs://caso-estudio-2/imagenet-wds"
        self.num_workers = 0
        self.batch_size = 32  # Per-step local batch size
        self.target_batch_size = 50000  # Passed to hivemind.Optimizer as target_batch_size
        self.val_batches = 100  # Upper bound on batches consumed per validation pass
        self.no_initial_val = False  # When True, do not force validation at epoch 1
        self.epochs = 2000  # Training stops once this many global epochs are reached
        self.host_port = 31337  # TCP port the DHT listens on

        # Hyperparameters
        self.lr = 0.001
        self.scheduler_milestones = [1000, 1600, 1800]
        self.scheduler_gamma = 0.1

        # Automated Peer Discovery
        self.announce_gcs_path = None  # gs:// object to publish this peer's address to
        self.fetch_gcs_path = None  # gs:// object to read an initial peer address from

        # Apply overrides last so that they win over the defaults above.
        for option, value in overrides.items():
            if not hasattr(self, option):
                raise TypeError(f"Unknown configuration option: {option!r}")
            setattr(self, option, value)

# Single configuration instance; the cells below read their settings from `args`.
args = Config()

# Example: Override defaults here if needed
# args.batch_size = 32

## Helper Functions

In [None]:
def build_model(num_classes: int = 1000) -> nn.Module:
    """Create an untrained ResNet50 whose final layer emits `num_classes` logits.

    Args:
        num_classes: Output width of the replacement fully-connected head.

    Returns:
        The ResNet50 module, built with ``weights=None`` and a fresh ``fc`` layer.
    """
    net = resnet50(weights=None)
    # Swap the stock classifier head for one sized to the requested class count.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

def select_device(cli_device: Optional[str]) -> torch.device:
    """Choose the torch device to compute on.

    Args:
        cli_device: Explicit device string; any truthy value is used verbatim.

    Returns:
        The requested device, otherwise MPS if usable, otherwise CUDA if
        available (TF32 is switched on for matmul and cuDNN as a side
        effect), otherwise CPU.
    """
    if cli_device:
        return torch.device(cli_device)

    # Prefer MPS (Apple Silicon); older torch builds may lack the backend entirely.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")

    if not torch.cuda.is_available():
        return torch.device("cpu")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    return torch.device("cuda")

def get_next_log_paths(base_dir="stats"):
    """Reserve the next run number for this host and return its two CSV paths.

    The per-host directory ``<base_dir>/<hostname>`` is created if needed.
    Existing ``run_<N>_system_metrics.csv`` files determine the highest run
    number used so far; names whose ``<N>`` is not an integer are ignored.

    Args:
        base_dir: Root directory holding one sub-directory per hostname.

    Returns:
        ``(system_metrics_path, training_metrics_path)`` for run ``max + 1``.
        The files themselves are not created here.
    """
    host_dir = os.path.join(base_dir, socket.gethostname())
    os.makedirs(host_dir, exist_ok=True)

    # Start from 0 so that an empty directory yields run #1.
    seen_runs = [0]
    for csv_path in glob.glob(os.path.join(host_dir, "run_*_system_metrics.csv")):
        fields = os.path.basename(csv_path).split('_')
        try:
            seen_runs.append(int(fields[1]))
        except (IndexError, ValueError):
            pass

    next_run = max(seen_runs) + 1

    def _run_file(kind):
        return os.path.join(host_dir, f"run_{next_run}_{kind}_metrics.csv")

    sys_metric_path = _run_file("system")
    train_metric_path = _run_file("training")

    print(f"üìÅ Logging metrics to: {host_dir} (Run #{next_run})")
    return sys_metric_path, train_metric_path

In [None]:
def evaluate_accuracy(model: nn.Module, loader, device: torch.device, max_batches: Optional[int] = None) -> tuple:
    """Run a validation pass and report mean loss and top-1 accuracy.

    The model is switched to eval mode for the duration of the pass and put
    back into train mode afterwards if it was training on entry, even when
    the pass is aborted by an exception.

    Args:
        model: Module to evaluate; it must already live on `device`.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        max_batches: If given, at most this many batches are consumed.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch cross-entropy
        losses and the percentage of correctly classified samples. Both are
        0.0 when the loader yields nothing.
    """
    model_was_training = model.training
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    num_batches = 0

    # Work out a progress-bar total; it only affects display, never the loop.
    if max_batches is not None:
        loader_iter = itertools.islice(loader, max_batches)
        total_batches = max_batches
    else:
        loader_iter = loader
        try:
            total_batches = len(loader)
        except TypeError:
            dataset = getattr(loader, 'dataset', None)
            batch_size = getattr(loader, 'batch_size', None)
            # batch_size may be present but None (or 0); dividing by it would raise.
            if dataset is not None and hasattr(dataset, '__len__') and batch_size:
                total_batches = len(dataset) // batch_size
            else:
                total_batches = None

    # Asynchronous host-to-device copies are only requested on CUDA.
    nb = (device.type == "cuda")

    try:
        with torch.no_grad():
            for xb, yb in tqdm(loader_iter, total=total_batches, desc="Validating", leave=False):
                xb = xb.to(device, non_blocking=nb)
                yb = yb.to(device, non_blocking=nb)
                if device.type == "mps":
                    xb = xb.contiguous()

                logits = model(xb)
                loss = F.cross_entropy(logits, yb)

                running_loss += loss.item()
                num_batches += 1

                pred = logits.argmax(dim=1)
                correct += (pred == yb).sum().item()
                total += yb.size(0)
    finally:
        # Restore the caller's mode even if validation failed part-way.
        if model_was_training:
            model.train()

    avg_loss = running_loss / max(1, num_batches)
    accuracy = 100.0 * correct / max(1, total)

    return avg_loss, accuracy

In [None]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, scheduler, out_dir: str, epoch_idx: int, acc: float, filename: str = "best_checkpoint.pt"):
    """Write a training checkpoint to ``<out_dir>/<filename>`` atomically.

    The payload is first saved to a temporary file in the same directory and
    then moved over the destination with ``os.replace``. An interrupted save
    therefore never leaves a truncated checkpoint behind: the previous file,
    if any, stays intact.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"model_state"``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``"opt_state"``.
        scheduler: Optional LR scheduler; ``"scheduler_state"`` is None when falsy.
        out_dir: Destination directory, created if missing.
        epoch_idx: Stored under ``"epoch"``.
        acc: Stored under ``"val_accuracy"``.
        filename: Name of the checkpoint file inside `out_dir`.

    Returns:
        The full path of the written checkpoint.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, filename)
    # Same directory as the target so that os.replace stays on one filesystem.
    tmp_path = f"{path}.tmp.{os.getpid()}"

    payload = {
        "epoch": epoch_idx,
        "val_accuracy": acc,
        "model_state": model.state_dict(),
        "opt_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler else None,
    }

    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present when saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return path

def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer, scheduler, device: torch.device):
    """Restore model, optimizer and (optionally) scheduler state from a checkpoint file.

    Args:
        path: Checkpoint file to read.
        model: Module that receives ``"model_state"``.
        optimizer: Optimizer that receives ``"opt_state"``.
        scheduler: Optional scheduler; restored only when the checkpoint
            carries a non-None ``"scheduler_state"``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(epoch, val_accuracy)`` read from the checkpoint (defaults 0 and
        -1.0 when the keys are absent), or ``(None, -1.0)`` if `path` does
        not exist.
    """
    if not os.path.exists(path):
        return None, -1.0

    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state"])
    optimizer.load_state_dict(state["opt_state"])

    # A missing key and an explicit None both mean "no scheduler state saved".
    saved_scheduler = state.get("scheduler_state")
    if scheduler and saved_scheduler is not None:
        scheduler.load_state_dict(saved_scheduler)

    epoch, acc = state.get("epoch", 0), state.get("val_accuracy", -1.0)

    print(f"‚úì Checkpoint loaded from: {path}")
    print(f"  Epoch: {epoch}, Accuracy: {acc:.2f}%")

    return epoch, acc

## Initialization

In [None]:
# Run-wide constants; the values taken from `args` are snapshotted here.
RUN_ID = "imagenet_resnet50"
BATCH = args.batch_size
TARGET_GLOBAL_BSZ = args.target_batch_size
EPOCHS = args.epochs
LR = args.lr
MATCHMAKING_TIME = 60.0
AVERAGING_TIMEOUT = 120.0
CHECKPOINT_DIR = "./checkpoints"

# Initialize Loggers
sys_log_path, train_log_path = get_next_log_paths()
resource_monitor = ResourceMonitor(log_file=sys_log_path)
resource_monitor.start()
training_logger = TrainingLogger(log_file=train_log_path)

# Device
device = select_device(args.device)
print(f"\nDevice: {device}")
if device.type == "mps":
    try:
        torch.set_float32_matmul_precision("high")
    except Exception as exc:
        # Best-effort tuning: report why it was skipped instead of hiding it,
        # then continue with the default precision.
        print(f"Could not set float32 matmul precision: {exc}")

In [None]:
# DataLoaders (WebDataset)
print(f"Loading ImageNet from {args.data_dir} (WebDataset)")

# Load ImageNet classes; fall back to the standard 1000-way setup when the
# optional imagenet_classes module is not importable.
try:
    from imagenet_classes import IMAGENET_SYNSETS
except ImportError:
    print("Could not import imagenet_classes.py. Defaulting to 1000 classes (ImageNet-1k standard).")
    num_classes = 1000
    classes = None
else:
    classes = IMAGENET_SYNSETS
    num_classes = len(classes)
    print(f"Loaded {num_classes} classes from imagenet_classes.py")

WORKERS = args.num_workers

# Arguments that the train and validation loaders have in common.
_shared_loader_kwargs = dict(
    bucket_name=args.data_dir,
    prefix="",
    batch_size=BATCH,
    num_workers=WORKERS,
    device=device,
    classes=classes,
)

# Train Loader
train_loader = get_webdataset_loader(
    is_train=True,
    total_shards=641,
    train_prefix="train/train",
    **_shared_loader_kwargs,
)

# Val Loader
val_loader = get_webdataset_loader(
    is_train=False,
    val_shards=50,
    val_prefix="val/train",
    **_shared_loader_kwargs,
)

In [None]:
# Models
#
# Two copies of the network are kept: a CPU "master" whose parameters the
# optimizer owns, and a "shadow" on the compute device that runs the
# forward/backward passes.

# MASTER on CPU (for Hivemind)
model = build_model(num_classes=num_classes).to("cpu")

# SHADOW on Device (for compute)
model_on_device = build_model(num_classes=num_classes).to(device)
if device.type == "cuda":
    model_on_device = model_on_device.to(memory_format=torch.channels_last)

# Base Optimizer (steps the CPU master parameters)
base_optimizer = torch.optim.Adam(model.parameters(), lr=LR)

## Peer Discovery & DHT

In [None]:
def _split_gcs_path(gcs_path):
    """Split ``gs://<bucket>/<object>`` into ``(bucket, object)``.

    Raises:
        ValueError: if the path lacks the ``gs://`` scheme or either part is empty.
    """
    if not gcs_path.startswith("gs://"):
        raise ValueError("GCS path must start with gs://")
    bucket, _, blob = gcs_path[5:].partition('/')
    if not bucket or not blob:
        raise ValueError(f"GCS path must look like gs://<bucket>/<object>, got: {gcs_path}")
    return bucket, blob


# Automated Discovery: Fetch
# Only attempted when no initial peer was configured explicitly. Every failure
# in here is reported and then ignored, so the peer can still start standalone.
if args.fetch_gcs_path and not args.initial_peer:
    print(f"üîç Looking for initial peer address in {args.fetch_gcs_path}...")
    try:
        bucket_name, blob_name = _split_gcs_path(args.fetch_gcs_path)

        # The timestamp query parameter makes each request URL unique.
        public_url = f"https://storage.googleapis.com/{bucket_name}/{blob_name}?t={int(time.time())}"
        print(f"   Trying public URL: {public_url}")

        try:
            resp = requests.get(public_url, timeout=10)
            if resp.status_code != 200:
                raise RuntimeError(f"Status code {resp.status_code}")
            content = resp.text.strip()
            if not content:
                # An empty body is not a usable address; fall through to the client.
                raise RuntimeError("empty response body")
            print(f"‚úÖ Found initial peer: {content}")
            args.initial_peer = content
        except Exception as e:
            print(f"‚ö†Ô∏è  Could not fetch from public URL ({e}). Trying GCS client...")
            storage_client = storage.Client.create_anonymous_client()
            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(blob_name)
            content = blob.download_as_text().strip()
            if content:
                print(f"‚úÖ Found initial peer (via Client): {content}")
                args.initial_peer = content
            else:
                print("‚ö†Ô∏è  GCS file found but empty.")
    except Exception as e:
        print(f"‚ö†Ô∏è  Failed to fetch initial peer from GCS: {e}")
        print("   Will attempt to start without initial peer (or as standalone).")

# DHT Setup
# await_ready=False: construction returns immediately; readiness is awaited below.
dht_kwargs = dict(
    host_maddrs=[f"/ip4/0.0.0.0/tcp/{args.host_port}"],
    start=True,
    await_ready=False
)
if args.initial_peer:
    dht_kwargs["initial_peers"] = [args.initial_peer]

print("=== Hivemind DHT ===")
dht = hivemind.DHT(**dht_kwargs)

# Automated Discovery: Announce
if args.announce_gcs_path:
    print(f"üì¢ Announcing this peer to {args.announce_gcs_path}...")
    try:
        try:
            public_ip = requests.get('https://checkip.amazonaws.com', timeout=5).text.strip()
        except Exception as ip_err:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt.
            # NOTE(review): a loopback address is presumably unreachable for
            # remote peers; the fallback is kept as-is but made visible.
            print(f"   Could not determine public IP ({ip_err}); falling back to 127.0.0.1")
            public_ip = "127.0.0.1"

        peer_id = dht.peer_id
        port = args.host_port
        full_address = f"/ip4/{public_ip}/tcp/{port}/p2p/{peer_id}"

        print(f"   Public Address: {full_address}")

        bucket_name, blob_name = _split_gcs_path(args.announce_gcs_path)

        storage_client = storage.Client()
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)

        blob.upload_from_string(full_address)
        print("‚úÖ Address written to GCS successfully.")

    except Exception as e:
        print(f"‚ùå Failed to announce address to GCS: {e}")

# Wait for DHT
print("‚è≥ Waiting for DHT to be ready...")
try:
    dht.wait_until_ready(timeout=60.0)
    print("DHT is ready!")
except TimeoutError:
    print("DHT timed out waiting for readiness. Continuing anyway...")

maddrs = [str(m) for m in dht.get_visible_maddrs()]
print("\n=== Hivemind DHT ===")
for m in maddrs:
    print("VISIBLE_MADDR:", m)

## Training Setup

In [None]:
# Resume bookkeeping. Both values are overwritten below when a checkpoint is
# found; -1.0 is also what load_checkpoint reports when no accuracy is stored.
best_accuracy = -1.0
start_epoch = 0

# Hivemind Optimizer
# Wraps the base optimizer; the run id, batch sizes and timeouts come from the
# constants defined in the initialization cell.
opt = hivemind.Optimizer(
    dht=dht,
    run_id=RUN_ID,
    batch_size_per_step=BATCH,
    target_batch_size=TARGET_GLOBAL_BSZ,
    optimizer=base_optimizer,
    use_local_updates=True,
    matchmaking_time=MATCHMAKING_TIME,
    averaging_timeout=AVERAGING_TIMEOUT,
    verbose=True,
)

# Scheduler
# Attached to the base optimizer (not to `opt`); milestones and gamma come from `args`.
scheduler = torch.optim.lr_scheduler.MultiStepLR(base_optimizer, milestones=args.scheduler_milestones, gamma=args.scheduler_gamma)

# Load Checkpoint
# Preference order: the latest checkpoint first, the best checkpoint as a fallback.
latest_path = os.path.join(CHECKPOINT_DIR, "latest_checkpoint.pt")
best_path = os.path.join(CHECKPOINT_DIR, "best_checkpoint.pt")

if os.path.exists(latest_path):
    print(f"Resuming from LATEST: {latest_path}")
    # State is loaded into the Hivemind optimizer `opt`, not into base_optimizer.
    start_epoch, acc = load_checkpoint(latest_path, model, opt, scheduler, device)
    
    if os.path.exists(best_path):
        # Weights come from the latest checkpoint; the best checkpoint is read
        # only for its recorded accuracy.
        checkpoint = torch.load(best_path, map_location='cpu')
        best_accuracy = checkpoint.get("val_accuracy", acc)
    else:
        best_accuracy = acc
elif os.path.exists(best_path):
    print(f"Resuming from BEST: {best_path}")
    start_epoch, best_accuracy = load_checkpoint(best_path, model, opt, scheduler, device)
else:
    print("No checkpoints found.")

Task exception was never retrieved
future: <Task finished name='Task-62' coro=<DecentralizedAverager._declare_for_download_periodically() done, defined at /home/daniel/distrubuted-ImageNet/.venv/lib/python3.13/site-packages/hivemind/averaging/averager.py:600> exception=RuntimeError('Broken pipe')>
Traceback (most recent call last):
  File "/home/daniel/distrubuted-ImageNet/.venv/lib/python3.13/site-packages/hivemind/averaging/averager.py", line 609, in _declare_for_download_periodically
    self.dht.store(
    ~~~~~~~~~~~~~~^
        download_key,
        ^^^^^^^^^^^^^
    ...<3 lines>...
        return_future=True,
        ^^^^^^^^^^^^^^^^^^^
    ),
    ^
  File "/home/daniel/distrubuted-ImageNet/.venv/lib/python3.13/site-packages/hivemind/dht/dht.py", line 212, in store
    future = MPFuture()
  File "/home/daniel/distrubuted-ImageNet/.venv/lib/python3.13/site-packages/hivemind/utils/mpfuture.py", line 93, in __init__
    self._shared_state_code = SharedBytes.next()
                 

## Training Loop

In [None]:
def _push_master_to_device(master, shadow):
    """Copy parameters and buffers from the CPU master model into the device shadow."""
    with torch.no_grad():
        for p_cpu, p_dev in zip(master.parameters(), shadow.parameters()):
            p_dev.copy_(p_cpu)
        for b_cpu, b_dev in zip(master.buffers(), shadow.buffers()):
            b_dev.copy_(b_cpu)


def _pull_device_to_master(master, shadow):
    """Copy gradients and buffers from the device shadow back into the CPU master."""
    with torch.no_grad():
        for p_cpu, p_dev in zip(master.parameters(), shadow.parameters()):
            if p_dev.grad is not None:
                if p_cpu.grad is None:
                    p_cpu.grad = torch.zeros_like(p_cpu)
                p_cpu.grad.copy_(p_dev.grad)
        for b_cpu, b_dev in zip(master.buffers(), shadow.buffers()):
            b_cpu.copy_(b_dev)


target_epochs = EPOCHS
last_seen_epoch = getattr(opt, "local_epoch", 0)
checkpoint_path = None

# Running training-accuracy counters; reset whenever the epoch changes.
train_correct = 0
train_total = 0

print(f"\nTraining until {target_epochs} global epochs (target_batch_size={TARGET_GLOBAL_BSZ}).")
if best_accuracy > 0:
    print(f"Continuing from best accuracy: {best_accuracy:.2f}%")

# Asynchronous host-to-device copies are only requested on CUDA.
nb = (device.type == "cuda")
reached_target = False

# The try/finally guarantees that the resource monitor is stopped even when
# training is interrupted (e.g. KeyboardInterrupt) or fails with an error.
try:
    with tqdm(total=None) as pbar:
        # The loader is re-iterated until the target epoch count is reached.
        while not reached_target:
            for xb, yb in train_loader:
                xb = xb.to(device, non_blocking=nb)
                yb = yb.to(device, non_blocking=nb)

                # Safety check for labels
                if (yb < 0).any() or (yb >= num_classes).any():
                    invalid_vals = yb[(yb < 0) | (yb >= num_classes)]
                    raise RuntimeError(f"Found invalid labels in batch: {invalid_vals.cpu().numpy()}. Expected range [0, {num_classes}).")

                if device.type == "cuda":
                    xb = xb.to(memory_format=torch.channels_last)
                elif device.type == "mps":
                    xb = xb.contiguous()

                # 1. Sync weights CPU -> Device
                _push_master_to_device(model, model_on_device)

                # 2. Forward/Backward on Device
                model_on_device.train()
                model_on_device.zero_grad()

                logits = model_on_device(xb)
                loss = F.cross_entropy(logits, yb)

                # Calculate training accuracy
                with torch.no_grad():
                    pred = logits.argmax(dim=1)
                    train_correct += (pred == yb).sum().item()
                    train_total += yb.size(0)

                loss.backward()
                # Read the scalar once; it is used for both the progress bar and the log.
                loss_value = loss.item()

                # 3. Sync gradients and buffers Device -> CPU
                _pull_device_to_master(model, model_on_device)

                # 4. Step on CPU (Hivemind)
                opt.step()
                opt.zero_grad()

                current_train_acc = 100.0 * train_correct / max(1, train_total)

                pbar.set_description(
                    f"loss={loss_value:.4f}  train_acc={current_train_acc:.2f}%  epoch_g={getattr(opt,'local_epoch',0)}  best_val={best_accuracy:.2f}%"
                )
                pbar.update()

                # Log training step
                current_lr = scheduler.get_last_lr()[0]
                training_logger.log_step(
                    epoch=getattr(opt, "local_epoch", 0),
                    batch=pbar.n,
                    loss=loss_value,
                    learning_rate=current_lr,
                    accuracy=current_train_acc
                )

                current_epoch = getattr(opt, "local_epoch", last_seen_epoch)
                if current_epoch == last_seen_epoch:
                    continue

                # Epoch finished
                final_train_acc = 100.0 * train_correct / max(1, train_total)
                tqdm.write(f"[Epoch {last_seen_epoch}] Training accuracy: {final_train_acc:.2f}%")

                train_correct = 0
                train_total = 0

                force_initial = (current_epoch == 1 and not args.no_initial_val)
                do_eval = force_initial or (current_epoch % args.val_every == 0) or (current_epoch >= target_epochs)
                if do_eval:
                    tqdm.write(f"Starting validation for epoch {current_epoch}...")

                    # Sync weights to device before eval: opt.step() has just
                    # updated the CPU master, so the shadow copy is stale.
                    _push_master_to_device(model, model_on_device)

                    val_loss, val_acc = evaluate_accuracy(model_on_device, val_loader, device, max_batches=args.val_batches)
                    tqdm.write(f"[Epoch {current_epoch}] Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")

                    training_logger.log_step(
                        epoch=current_epoch,
                        batch=pbar.n,
                        loss=val_loss,
                        learning_rate=scheduler.get_last_lr()[0],
                        accuracy=val_acc
                    )

                    save_checkpoint(model, opt, scheduler, CHECKPOINT_DIR, current_epoch, val_acc, filename="latest_checkpoint.pt")

                    if val_acc > best_accuracy:
                        ckpt_path = save_checkpoint(model, opt, scheduler, CHECKPOINT_DIR, current_epoch, val_acc, filename="best_checkpoint.pt")
                        best_accuracy = val_acc
                        checkpoint_path = ckpt_path
                        tqdm.write(f"‚Üë New best accuracy ({best_accuracy:.2f}%). Saved: {ckpt_path}")
                    else:
                        tqdm.write(f"‚Üî No improvement (best={best_accuracy:.2f}%).")

                # NOTE(review): the scheduler advances once per observed epoch
                # change, even if the epoch counter moved by more than one;
                # confirm this is the intended LR schedule.
                last_seen_epoch = current_epoch
                scheduler.step()

                if current_epoch >= target_epochs:
                    tqdm.write(f"‚úì Reached {current_epoch} global epochs. Finishing...")
                    # Leave the batch loop; the flag ends the outer loop too.
                    reached_target = True
                    break

    if checkpoint_path or best_accuracy > 0:
        print(f"\nTraining finished. Best accuracy: {best_accuracy:.2f}%")
        if checkpoint_path:
            print(f"Best checkpoint: {checkpoint_path}")
    else:
        print("\nTraining finished. No checkpoints saved.")
finally:
    resource_monitor.stop()

Dec 01 14:45:25.789 [INFO] imagenet_resnet50 accumulated 1952 samples for epoch #0 from 1 peers. ETA 32451.42 sec (refresh in 10.00 sec)


KeyboardInterrupt: 

terminate called after throwing an instance of 'std::system_error'
  what():  Broken pipe
