In [1]:
import torch
from torchvision import datasets, transforms
import numpy as np
import os

# Fetch the MNIST training split into ./data. The ToTensor transform only runs
# through dataset indexing, which this cell never does: it reads the raw arrays.
transform = transforms.ToTensor()
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data, targets = mnist.data.numpy(), mnist.targets.numpy()

# Two label-disjoint shards: digits 0-4 and digits 5-9.
shard1_idx = np.isin(targets, list(range(0, 5)))
shard2_idx = np.isin(targets, list(range(5, 10)))
shard1_data, shard1_targets = data[shard1_idx], targets[shard1_idx]
shard2_data, shard2_targets = data[shard2_idx], targets[shard2_idx]

# Persist each shard as a compressed .npz with `data` and `targets` arrays.
os.makedirs('mnist_shards', exist_ok=True)
for shard_path, shard_images, shard_labels in (
    ('mnist_shards/mnist_0_4.npz', shard1_data, shard1_targets),
    ('mnist_shards/mnist_5_9.npz', shard2_data, shard2_targets),
):
    np.savez_compressed(shard_path, data=shard_images, targets=shard_labels)

print(f"Shard 1 (digits 0-4): {shard1_data.shape[0]} samples saved to mnist_shards/mnist_0_4.npz")
print(f"Shard 2 (digits 5-9): {shard2_data.shape[0]} samples saved to mnist_shards/mnist_5_9.npz")

Shard 1 (digits 0-4): 30596 samples saved to mnist_shards/mnist_0_4.npz
Shard 2 (digits 5-9): 29404 samples saved to mnist_shards/mnist_5_9.npz


In [2]:
import numpy as np
import os
from torchvision import datasets, transforms

# Load the MNIST training split, downloading into ./data only if it is not already present.
transform = transforms.ToTensor()
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data, targets = mnist.data.numpy(), mnist.targets.numpy()

# One compressed shard per digit class, each holding `data` (images) and `targets` (labels).
os.makedirs('mnist_shards', exist_ok=True)
for digit in range(10):
    is_digit = targets == digit
    digit_data, digit_targets = data[is_digit], targets[is_digit]
    np.savez_compressed(f'mnist_shards/mnist_{digit}.npz', data=digit_data, targets=digit_targets)
    print(f"Digit {digit}: {digit_data.shape[0]} samples saved to mnist_shards/mnist_{digit}.npz")

Digit 0: 5923 samples saved to mnist_shards/mnist_0.npz
Digit 1: 6742 samples saved to mnist_shards/mnist_1.npz
Digit 2: 5958 samples saved to mnist_shards/mnist_2.npz
Digit 3: 6131 samples saved to mnist_shards/mnist_3.npz
Digit 4: 5842 samples saved to mnist_shards/mnist_4.npz
Digit 5: 5421 samples saved to mnist_shards/mnist_5.npz
Digit 6: 5918 samples saved to mnist_shards/mnist_6.npz
Digit 7: 6265 samples saved to mnist_shards/mnist_7.npz
Digit 8: 5851 samples saved to mnist_shards/mnist_8.npz
Digit 9: 5949 samples saved to mnist_shards/mnist_9.npz


In [3]:
import torch.nn as nn
import torchvision.models as models

class MNISTResNet18(nn.Module):
    """torchvision ResNet-18 (no pretrained weights) adapted to MNIST digits.

    The stem convolution is rebuilt to accept a single input channel and the
    final fully connected layer is replaced with a 10-way classifier.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=None)
        # Single-channel stem; kernel/stride/padding match the stock ResNet conv1.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # One output logit per digit class.
        backbone.fc = nn.Linear(backbone.fc.in_features, 10)
        self.model = backbone

    def forward(self, x):
        """Return class logits for a batch of 1-channel images."""
        logits = self.model(x)
        return logits


In [4]:
%%writefile train_ddp_mnist.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torchvision.models as models

class MNISTResNet18(nn.Module):
    """ResNet-18 with a 1-channel stem and a 10-class head, trained from scratch."""

    def __init__(self):
        super().__init__()
        net = models.resnet18(weights=None)
        # Grayscale input: replace the stem convolution with a 1-input-channel one.
        stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        net.conv1 = stem
        # Ten digit classes: replace the classifier head.
        head = nn.Linear(net.fc.in_features, 10)
        net.fc = head
        self.model = net

    def forward(self, x):
        """Delegate to the wrapped ResNet and return its logits."""
        return self.model(x)

def load_shard(rank):
    """Load the label shard assigned to ``rank`` as a TensorDataset.

    Rank 0 reads digits 0-4 from mnist_shards/mnist_0_4.npz; every other rank
    reads digits 5-9 from mnist_shards/mnist_5_9.npz.

    Returns:
        TensorDataset(X, y). X is float32 with a channel axis inserted at dim 1
        and pixel values divided by 255 (uint8 pixels map to [0, 1]). y is int64.
    """
    shard_file = "mnist_shards/mnist_0_4.npz" if rank == 0 else "mnist_shards/mnist_5_9.npz"
    # np.load on an .npz returns an NpzFile that keeps the file handle open; the
    # context manager closes it once both arrays have been read into memory.
    with np.load(shard_file) as shard:
        X = torch.tensor(shard['data'], dtype=torch.float32).unsqueeze(1) / 255.0
        y = torch.tensor(shard['targets'], dtype=torch.long)
    return TensorDataset(X, y)

def train(rank, world_size):
    """Train a DDP-wrapped MNISTResNet18 on this rank's label shard for 3 epochs.

    Each rank loads its own shard via load_shard(rank), so ranks can hold
    different numbers of samples (the 0-4 and 5-9 shards differ in size). DDP
    all-reduces gradients on every backward pass, so every rank must take the
    same number of steps. Otherwise the rank with more batches can block in a
    collective its peer never joins. The per-epoch step count is therefore agreed
    up front as the minimum batch count across ranks; the larger shard drops its
    surplus batches each epoch (a different subset each time, since the loader
    shuffles).

    Args:
        rank: process index supplied by mp.spawn; also selects the shard.
        world_size: total number of processes in the group.
    """
    print('training on ', rank)
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    try:
        dataset = load_shard(rank)
        loader = DataLoader(dataset, batch_size=128, shuffle=True)
        device = torch.device('cpu')
        # Place the module on its device before wrapping it in DDP.
        model = nn.parallel.DistributedDataParallel(MNISTResNet18().to(device))
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        # Every rank contributes its batch count; all receive the minimum.
        steps = torch.tensor([len(loader)], dtype=torch.int64)
        dist.all_reduce(steps, op=dist.ReduceOp.MIN)
        steps_per_epoch = int(steps.item())

        for epoch in range(3):
            loss = None
            # zip stops on range exhaustion without fetching an extra batch.
            for _, (X, y) in zip(range(steps_per_epoch), loader):
                X, y = X.to(device), y.to(device)
                optimizer.zero_grad()
                out = model(X)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
            if rank == 0 and loss is not None:
                # Last-batch loss on rank 0's shard only (digits 0-4).
                print(f"Epoch {epoch+1} complete. Loss: {loss.item():.4f}")
    finally:
        dist.destroy_process_group()

if __name__ == "__main__":
    import os

    # init_process_group without an init_method uses env:// rendezvous, which
    # requires MASTER_ADDR/MASTER_PORT. Provide local defaults so the script runs
    # when launched bare, while keeping any values the caller already exported.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12355')
    world_size = 2
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

Overwriting train_ddp_mnist.py


In [12]:
%%writefile train_ddp_mnist.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torchvision.models as models
import os

# Default rendezvous address for the gloo process group (env:// init). setdefault
# keeps any MASTER_ADDR/MASTER_PORT the launcher already exported instead of
# silently overwriting them.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '12355')

class MNISTResNet18(nn.Module):
    """ResNet-18 (no pretrained weights) reshaped for single-channel, 10-class input."""

    def __init__(self):
        super().__init__()
        resnet = models.resnet18(weights=None)
        in_features = resnet.fc.in_features
        # Swap the stem for a 1-input-channel convolution.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Swap the classifier for a 10-logit head.
        resnet.fc = nn.Linear(in_features, 10)
        self.model = resnet

    def forward(self, x):
        """Compute digit logits for input batch ``x``."""
        return self.model.forward(x) if False else self.model(x)

def load_digit_shard(digit):
    """Load the single-digit shard mnist_shards/mnist_<digit>.npz as a TensorDataset.

    Each process loads only its own digit's data from its own shard file.

    Returns:
        TensorDataset(X, y). X is float32 with a channel axis inserted at dim 1
        and pixel values divided by 255 (uint8 pixels map to [0, 1]). y is int64.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open; the
    # context manager closes it once both arrays have been read into memory.
    with np.load(f'mnist_shards/mnist_{digit}.npz') as shard:
        X = torch.tensor(shard['data'], dtype=torch.float32).unsqueeze(1) / 255.0
        y = torch.tensor(shard['targets'], dtype=torch.long)
    return TensorDataset(X, y)

def train(rank, world_size, max_steps_per_epoch=11):
    """Train a DDP-wrapped MNISTResNet18 where rank ``r`` sees only digit ``r``.

    Every rank trains on a single-class shard (load_digit_shard(rank)). The run
    is therefore deliberately non-IID: DDP averages gradients across ranks, but
    each local batch contains a single label.

    DDP all-reduces gradients on every backward pass, so all ranks must take the
    same number of steps per epoch. The digit shards differ in size, so the step
    count is agreed up front as the minimum over ranks of
    min(batches in shard, max_steps_per_epoch). This keeps ranks in lockstep
    even if the cap is raised or removed.

    Args:
        rank: process index supplied by mp.spawn; also selects the digit shard.
        world_size: total number of processes in the group.
        max_steps_per_epoch: cap on optimizer steps per epoch, or None for no cap.
            Defaults to 11, the number of steps the previous ``loadercnt > 10``
            check (evaluated after the step) allowed.
            NOTE(review): confirm whether 10 was the intended cap.
    """
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    print(f'Train on {rank} cpu')
    try:
        dataset = load_digit_shard(rank)
        loader = DataLoader(dataset, batch_size=128, shuffle=True)
        device = torch.device('cpu')
        # Place the module on its device before wrapping it in DDP.
        model = nn.parallel.DistributedDataParallel(MNISTResNet18().to(device))
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        local_steps = len(loader)
        if max_steps_per_epoch is not None:
            local_steps = min(local_steps, max_steps_per_epoch)
        # Every rank contributes its step budget; all receive the minimum.
        steps = torch.tensor([local_steps], dtype=torch.int64)
        dist.all_reduce(steps, op=dist.ReduceOp.MIN)
        steps_per_epoch = int(steps.item())

        for epoch in range(3):
            loss = None
            # zip stops on range exhaustion without fetching an extra batch.
            for _, (X, y) in zip(range(steps_per_epoch), loader):
                X, y = X.to(device), y.to(device)
                optimizer.zero_grad()
                out = model(X)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
            if rank == 0 and loss is not None:
                # Rank 0 trains on digit 0 only, so this is a single-class batch loss.
                print(f"Digit {rank} Epoch {epoch+1} complete. Loss: {loss.item():.4f}")
    finally:
        dist.destroy_process_group()

if __name__ == "__main__":
    # One worker process per digit class (0-9); each trains on its own shard.
    n_procs = 10
    mp.spawn(train, args=(n_procs,), nprocs=n_procs, join=True)

Overwriting train_ddp_mnist.py


In [13]:
# Run the DDP training as a standalone script (the file written by the
# %%writefile cell above) in a separate shell process. MASTER_ADDR/MASTER_PORT
# are exported into that process's environment for the env:// rendezvous used
# by init_process_group.
!MASTER_ADDR=127.0.0.1 MASTER_PORT=12355 python3 train_ddp_mnist.py

Train on 0 cpu
Train on 1 cpu
Train on 2 cpu
Train on 4 cpu
Train on 5 cpu
Train on 6 cpu
Train on 8 cpu
Train on 9 cpu
Train on 7 cpu
Train on 3 cpu
Digit 0 Epoch 1 complete. Loss: 0.8005
Digit 0 Epoch 2 complete. Loss: 0.1087
Digit 0 Epoch 3 complete. Loss: 0.0248
