In this example:

* We define a simple neural network SimpleNN with one hidden layer.
* We load the MNIST dataset and define a DataLoader for batching and shuffling.
* We check if CUDA (GPU support) is available and move the model to the available device.
* If multiple GPUs are available, we wrap the model with nn.DataParallel to enable data parallelism.
* We define a loss function (CrossEntropyLoss) and an optimizer (SGD).
* We loop over the dataset for five epochs, performing a forward and backward pass on each mini-batch and updating the model parameters.
* When more than one GPU is available, nn.DataParallel automatically splits each input batch across the GPUs, which can shorten training time. For a network this small, the cost of scattering inputs and gathering outputs may offset much of the gain.
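
To make the batch splitting concrete, here is a minimal sketch of how one batch of 64 samples might be divided among GPUs. The helper `chunk_sizes` is hypothetical and not part of PyTorch. It assumes the batch is cut along the batch dimension by the same rule `torch.chunk` uses (pieces of `ceil(batch_size / num_devices)` samples, with a smaller final piece), which is my understanding of how nn.DataParallel scatters its inputs.

```python
import math


def chunk_sizes(batch_size, num_devices):
    """Hypothetical helper: number of samples each device receives for one batch.

    Assumption: the batch is cut into pieces of ceil(batch_size / num_devices)
    samples, the rule torch.chunk follows; the last piece may be smaller, and
    fewer pieces than devices may result when the batch is very small.
    """
    per_device = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(per_device, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


# Batch size 64 matches the DataLoader in this example.
for n in (1, 2, 3, 4):
    print(n, "GPU(s):", chunk_sizes(64, n))
```

Under this assumption, three GPUs would receive 22, 22, and 20 samples. The per-GPU workload shrinks as GPUs are added, so a larger batch size is often needed to keep every GPU busy.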

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define a simple neural network
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define transformations for the dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Check if CUDA (GPU support) is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model
model = SimpleNN().to(device)

# If multiple GPUs are available, wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs for training.")
    model = nn.DataParallel(model)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop
for epoch in range(5):  # Loop over the dataset multiple times
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')
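
The reporting condition in the training loop interacts with the dataset size in a way that is easy to miss. The following sketch works through the arithmetic, assuming the standard MNIST training split of 60,000 images together with the batch size of 64 and the reporting interval of 100 used above. The variable names are illustrative only.

```python
import math

num_samples = 60_000   # assumption: size of the standard MNIST training split
batch_size = 64        # as passed to the DataLoader above
log_every = 100        # the loop prints when i % 100 == 99

# The DataLoader keeps the final partial batch by default, so round up.
batches_per_epoch = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (batches_per_epoch - 1) * batch_size

# 1-based step numbers at which the loop prints a running-loss average.
logged_steps = [i + 1 for i in range(batches_per_epoch)
                if i % log_every == log_every - 1]

# Batches after the last report whose loss is accumulated but never printed.
unreported_batches = batches_per_epoch - logged_steps[-1]

print("batches per epoch:", batches_per_epoch)
print("size of last batch:", last_batch_size)
print("reported at steps:", logged_steps)
print("unreported trailing batches:", unreported_batches)
```

Under these assumptions each epoch has 938 batches, the last holding 32 samples, and the loss is reported at steps 100 through 900. The loss of the remaining 38 batches is accumulated but discarded when running_loss is reset at the start of the next epoch, which is why each epoch in the output ends at step 900.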

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

[1,   100] loss: 0.914
[1,   200] loss: 0.420
[1,   300] loss: 0.348
[1,   400] loss: 0.346
[1,   500] loss: 0.321
[1,   600] loss: 0.294
[1,   700] loss: 0.253
[1,   800] loss: 0.226
[1,   900] loss: 0.237
[2,   100] loss: 0.197
[2,   200] loss: 0.202
[2,   300] loss: 0.191
[2,   400] loss: 0.180
[2,   500] loss: 0.160
[2,   600] loss: 0.168
[2,   700] loss: 0.167
[2,   800] loss: 0.146
[2,   900] loss: 0.147
[3,   100] loss: 0.129
[3,   200] loss: 0.128
[3,   300] loss: 0.130
[3,   400] loss: 0.133
[3,   500] loss: 0.123
[3,   600] loss: 0.114
[3,   700] loss: 0.118
[3,   800] loss: 0.128
[3,   900] loss: 0.121
[4,   100] loss: 0.102
[4,   200] loss: 0.111
[4,   300] loss: 0.095
[4,   400] loss: 0.101
[4,   500] loss: 0.097
[4,   600] loss: 0.107
[4,   700] loss: 0.097
[4,   800] loss: 0.086
[4,   900] loss: 0.096
[5,   100] loss: 0.079
[5,   200] loss: 0.086
[5,   300] loss: 0.080
[5,   400] loss: 0.077
[5,  