In [1]:
from __future__ import print_function

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.utils.data import DistributedSampler
from torchvision import datasets, transforms

In [2]:
class Net(nn.Module):
    """Small CNN classifier: two conv/pool stages followed by two linear layers.

    The first convolution takes single-channel images and the final layer
    emits 10 log-probabilities per sample. ``fc1`` expects 4 * 4 * 50
    features, which is what the two unpadded 5x5 convolutions and 2x2
    poolings produce for a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        # Both stages are conv -> ReLU -> 2x2 max-pool with stride 2.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)

        # Flatten to one row per sample, sized to what fc1 expects.
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [3]:
def train(args, model, device, train_loader, epoch, writer):
    """Run one training epoch with SGD and log the loss periodically.

    Args:
        args: Namespace providing ``lr``, ``momentum`` and ``log_interval``.
        model: Module to optimise; it is switched to training mode here.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute. If it has a
            ``sampler`` with a ``set_epoch`` method, that method is called
            with ``epoch`` before iterating.
        epoch: Epoch number, used in the log line and for the logging step.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Note:
        A new SGD optimizer is constructed on every call, so momentum
        buffers do not carry over from one epoch to the next.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # DistributedSampler seeds its shuffle from the epoch number; without this
    # call every epoch would iterate the data in exactly the same order.
    sampler = getattr(train_loader, "sampler", None)
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)

    for batch_idx, (data, target) in enumerate(train_loader):
        # Attach tensors to the device.
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            loss_value = loss.item()
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss_value,
                )
            )
            # Step index for the scalar log: batches counted across epochs.
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar("loss", loss_value, niter)

In [4]:
def test(model, device, test_loader, writer, epoch):
    model.eval()

    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Attach tensors to the device.
            data, target = data.to(device), target.to(device)

            output = model(data)
            # Get the index of the max log-probability.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    print("\naccuracy={:.4f}\n".format(float(correct) / len(test_loader.dataset)))
    writer.add_scalar("accuracy", float(correct) / len(test_loader.dataset), epoch)

In [5]:
def main(args=None):
    """Parse options, then train and evaluate ``Net`` on FashionMNIST under DDP.

    The function builds the argument parser, picks the device, initialises a
    ``torch.distributed`` process group (supplying single-process defaults
    when ``WORLD_SIZE`` is not set in the environment), wraps the model in
    ``DistributedDataParallel`` and runs ``train``/``test`` once per epoch.
    The process group it created and the summary writer are always released
    on exit, so the function can be called again in the same kernel.

    Args:
        args: Optional list of command-line style strings. When ``None`` the
            options are taken from the process arguments and unrecognised
            ones are ignored, so extra arguments passed by Jupyter do not
            break parsing.
    """
    parser = argparse.ArgumentParser(description="PyTorch FashionMNIST Example")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N", help="Batch size for training (default: 64)")
    parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N", help="Batch size for testing (default: 1000)")
    parser.add_argument("--epochs", type=int, default=50, metavar="N", help="Number of epochs (default: 50)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="Learning rate (default: 0.01)")
    parser.add_argument("--momentum", type=float, default=0.5, metavar="M", help="SGD momentum (default: 0.5)")
    parser.add_argument("--no-cuda", action="store_true", default=False, help="Disable CUDA")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="Random seed (default: 1)")
    parser.add_argument("--log-interval", type=int, default=10, metavar="N", help="Log interval")
    parser.add_argument("--save-model", action="store_true", default=False, help="Save model")
    parser.add_argument("--dir", default="logs", metavar="L", help="Log directory")
    parser.add_argument("--backend", type=str, choices=["gloo", "nccl", "mpi"], default="gloo", help="Distributed backend")

    # Handle arguments passed manually or via CLI
    if args is None:
        args, _ = parser.parse_known_args()  # Ignore Jupyter arguments
    else:
        args = parser.parse_args(args)  # Manually passed args

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using CUDA: {use_cuda}")
    print(f"Backend: {args.backend}")

    writer = SummaryWriter(args.dir)
    # Only tear down a process group that this call created.
    group_created = False
    try:
        torch.manual_seed(args.seed)

        model = Net().to(device)
        print(f"World Size: {os.environ.get('WORLD_SIZE', '1')}, Rank: {os.environ.get('RANK', '0')}")

        # No launcher-provided environment: fall back to a single local process.
        if "WORLD_SIZE" not in os.environ:
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "1234"

        print(f"World Size: {os.environ['WORLD_SIZE']}. Rank: {os.environ['RANK']}")

        dist.init_process_group(backend=args.backend)
        group_created = True
        model = nn.parallel.DistributedDataParallel(model)

        train_ds = datasets.FashionMNIST("./data", train=True, download=True, transform=transforms.ToTensor())
        test_ds = datasets.FashionMNIST("./data", train=False, download=True, transform=transforms.ToTensor())

        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, sampler=DistributedSampler(train_ds))
        test_loader = torch.utils.data.DataLoader(test_ds, batch_size=args.test_batch_size, sampler=DistributedSampler(test_ds))

        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, epoch, writer)
            test(model, device, test_loader, writer, epoch)

        # DDP keeps the replicas in sync, so one copy is enough; writing from
        # rank 0 only avoids several processes writing the same file. The
        # state dict is taken from the DDP wrapper, as before.
        if args.save_model and dist.get_rank() == 0:
            torch.save(model.state_dict(), "mnist_cnn.pt")
    finally:
        # Release the process group even on error/interrupt; otherwise the next
        # call in the same kernel fails when it tries to initialise it again.
        try:
            if group_created:
                dist.destroy_process_group()
        finally:
            writer.close()


In [6]:
# Command-line style overrides handed to main(); each entry is a flag/value pair.
args_list = [
    *("--batch-size", "32"),
    *("--test-batch-size", "1000"),
    *("--epochs", "1"),
    *("--lr", "0.005"),
    *("--momentum", "0.9"),
    *("--backend", "gloo"),
]

In [10]:
# A process group left over from an earlier run of this cell would make
# init_process_group inside main() fail, so tear it down first.
if dist.is_initialized():
    dist.destroy_process_group()

# Run training unconditionally; this call must not sit under the `if` above,
# or nothing would happen on a fresh kernel where no group exists yet.
main(args_list)

Using CUDA: True
Backend: gloo
World Size: 1, Rank: 0
World Size: 1. Rank: 0

accuracy=0.8274

