In [10]:
import os
import io
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from pyspark.sql import SparkSession
from pyspark.ml.torch.distributor import TorchDistributor

def train_fn():
    import torch
    import torch.distributed as dist
    from torchvision import transforms, datasets, models
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    import io, os

    print("=== DISTRIBUTED RESNET18 TRAINING (AMP + SHARDING) ===")

    # Distributed metadata
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    print(f"[Worker {rank}] World size: {world_size}")

    # Dataset path (NFS)
    dataset_path = "/mnt/spark_data/DATASET-RUIDO"
    print(f"[Worker {rank}] Dataset path: {dataset_path}")

    # Image transforms
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Dataset on NFS
    dataset = datasets.ImageFolder(root=dataset_path, transform=train_tf)

    # Distributed sharding
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )

    dataloader = DataLoader(
        dataset,
        batch_size=32,
        sampler=sampler,
        num_workers=4,
        pin_memory=True
    )

    print(f"[Worker {rank}] Total images loaded: {len(dataset)}")

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Worker {rank}] Training on: {device}")

    # Load ResNet18 pretrained
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model = model.to(device)
    model.train()

    # Loss, optimizer, AMP
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    # Training loop
    EPOCHS = 3
    print(f"[Worker {rank}] Starting training for {EPOCHS} epochs")

    for epoch in range(EPOCHS):
        sampler.set_epoch(epoch)
        total_loss = 0.0

        for imgs, labels in dataloader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        print(f"[Worker {rank}] Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")
        torch.cuda.synchronize()

    print(f"[Worker {rank}] Training finished!")

    # Return model only from worker 0
    if rank == 0:
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        buffer.seek(0)
        return buffer.getvalue()

    return None

# SPARK CONFIG

spark = (
    SparkSession.builder
    .appName("BrainTumor-ResNet18-Distributed-IPYNB")
    .master("spark://100.108.67.1:7077")
    .config("spark.executor.instances", "2")
    .config("spark.executor.resource.gpu.amount", "1")
    .config("spark.executor.resource.gpu.discoveryScript", "/usr/local/bin/get-gpus.sh")
    .config("spark.task.resource.gpu.amount", "1")
    .config("spark.executorEnv.NCCL_SOCKET_IFNAME", "tailscale0")
    .getOrCreate()
)

spark

In [11]:
print("Launching distributed training with AMP + SHARDING + 2 GPUs...")

model_bytes = TorchDistributor(
    num_processes=2,
    local_mode=False,
    use_gpu=True
).run(train_fn)

if model_bytes is not None:
    out_path = "/home/piero/brain_resnet18.pt"
    with open(out_path, "wb") as f:
        f.write(model_bytes)
    print(f"Modelo guardado correctamente en: {out_path}")
else:
    print("Worker secundario: no devuelve modelo.")

Launching distributed training with AMP + SHARDING + 2 GPUs...


INFO:TorchDistributor:Started distributed training with 2 executor processes
=== DISTRIBUTED RESNET18 TRAINING (AMP + SHARDING) ===              (0 + 2) / 2]
[Worker 0] World size: 2
[Worker 0] Dataset path: /mnt/spark_data/DATASET-RUIDO
=== DISTRIBUTED RESNET18 TRAINING (AMP + SHARDING) ===
[Worker 1] World size: 2
[Worker 1] Dataset path: /mnt/spark_data/DATASET-RUIDO
[Worker 1] Total images loaded: 5000                                (0 + 2) / 2]
[Worker 1] Training on: cuda
[Worker 1] Starting training for 3 epochs
[Worker 1] Epoch 1/3 - Loss: 13.2996
[Worker 0] Total images loaded: 5000                                (0 + 2) / 2]
[Worker 0] Training on: cuda
[Worker 0] Starting training for 3 epochs
[Worker 1] Epoch 2/3 - Loss: 2.9955
[Worker 1] Epoch 3/3 - Loss: 1.2369
[Worker 1] Training finished!
[Worker 0] Epoch 1/3 - Loss: 14.8137                                (0 + 2) / 2]
[Worker 0] Epoch 2/3 - Loss: 5.9047                                 (0 + 2) / 2]
[Worker 0] Epoch 3/3 -

Modelo guardado correctamente en: /home/piero/brain_resnet18.pt


In [12]:
spark.stop()