In [1]:
"""
=============================================================
MNIST Training with a Simple CNN on a CUDA GPU (CPU fallback)
=============================================================
This script demonstrates a simple CNN trained on MNIST using
PyTorch on a CUDA GPU, falling back to the CPU when no GPU is available.
=============================================================
"""

# =============================================================
# 1. Importing Libraries
# =============================================================
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader
import time

# =============================================================
# 2. Environment Setup and Device Configuration
# =============================================================

# Use the CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


# =============================================================
# 3. Define the CNN Model
# =============================================================
class SimpleCNN(nn.Module):
    """Small two-convolution CNN that maps an image batch to 10 class logits.

    Inputs are expected to be single-channel 28x28 images. Both 3x3
    convolutions use padding=1 and keep the spatial size. The single 2x2
    max-pool halves it to 14x14, and ``fc1`` is sized for the flattened
    64x14x14 feature map.
    """

    def __init__(self):
        super().__init__()
        # Registration order determines the order in which parameters are
        # initialised from the global RNG and the state_dict key order;
        # keep it stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 14 * 14, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape (batch, 10)."""
        features = self.relu(self.conv1(x))           # (batch, 32, 28, 28)
        features = self.relu(self.conv2(features))    # (batch, 64, 28, 28)
        features = self.pool(features)                # (batch, 64, 14, 14)
        flat = features.view(features.size(0), -1)    # (batch, 64*14*14)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)


# =============================================================
# 4. Data Loading and Preprocessing
# =============================================================
# Preprocessing: force 28x28 (the input size fc1 in SimpleCNN is built for),
# convert to a [0, 1] float tensor, then normalise. The (0.1307, 0.3081) pair
# is presumably the MNIST training-set mean/std -- confirm if the data changes.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# The training and evaluation loops copy batches with `.to(device, non_blocking=True)`.
# Those host->device copies can only run asynchronously when the source tensors
# live in pinned (page-locked) memory, so ask the DataLoader to pin batches
# whenever we are actually running on CUDA.
use_pin_memory = device.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=8, pin_memory=use_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                         num_workers=4, pin_memory=use_pin_memory)


# =============================================================
# 5. Training Setup
# =============================================================
num_epochs = 3

# Model and loss live on the selected device; Adam optimises every model parameter.
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# =============================================================
# 6. Training Loop
# =============================================================
print("\nStarting Training...\n")
# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the wall clock is adjusted.
start_time = time.perf_counter()
for epoch in range(num_epochs):
    model.train()  # enable dropout for training
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    # batch_idx counts batches processed so far (1-based), so the progress-bar
    # value below is the true running mean. len(pbar) is the constant total
    # batch count and cannot serve as that divisor.
    for batch_idx, (inputs, labels) in enumerate(pbar, start=1):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() copies the scalar loss to the host (and waits for the step to finish).
        running_loss += loss.item()
        pbar.set_postfix({'loss': running_loss / batch_idx})

    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {running_loss / len(train_loader):.4f}")

end_time = time.perf_counter()
print(f"\nTotal training time: {end_time - start_time:.2f} seconds")
print("\nTraining complete.")

# =============================================================
# 7. Evaluation on Test Data
# =============================================================
# Measure top-1 classification accuracy on the test split.
model.eval()  # turn dropout off for inference
correct, total = 0, 0

with torch.no_grad():  # no autograd graph is needed while evaluating
    for inputs, labels in test_loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"\nTest Accuracy: {accuracy:.2f}%")
print("=============================================================")

Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 554kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.09MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 20.8MB/s]



Starting Training...



Epoch 1/3: 100%|██████████| 469/469 [00:08<00:00, 57.12it/s, loss=0.169] 


Epoch [1/3] - Avg Loss: 0.1690


Epoch 2/3: 100%|██████████| 469/469 [00:02<00:00, 162.59it/s, loss=0.0527]


Epoch [2/3] - Avg Loss: 0.0528


Epoch 3/3: 100%|██████████| 469/469 [00:03<00:00, 154.20it/s, loss=0.0341] 

Epoch [3/3] - Avg Loss: 0.0342

Total training time: 14.14 seconds

Training complete.






Test Accuracy: 99.03%
