In [None]:
import argparse
import json
import logging
import os
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision
from torchvision import datasets, transforms

In [None]:
# Module-level logger that sends DEBUG-and-above records to stdout so they
# appear in the notebook cell output.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# logging.getLogger() returns the same object on every call, so re-running this
# cell would otherwise attach one more handler each time and duplicate every
# log line. Only attach a handler when none is present yet.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler(sys.stdout))

# Define models
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The flatten size ``16 * 5 * 5`` used in ``forward`` corresponds to a
    3-channel 32x32 input (32 -> 28 -> 14 -> 10 -> 5 through the two unpadded
    5x5 convolutions and 2x2 poolings). The final layer returns 10 raw scores
    per sample; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions that share one 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Collapse channels and spatial dims into one feature vector per sample.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [None]:
# Define data augmentation
def _get_transforms():
    """Build the preprocessing pipeline applied to every image.

    The steps run in this order: random 32x32 crop after 4 pixels of padding,
    random horizontal flip, conversion to a tensor, and per-channel
    normalization with fixed mean / standard-deviation triples.
    """
    steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    return transforms.Compose(steps)

In [None]:
# Define data loader for training dataset
def _get_train_data_loader(batch_size, data_dir="./data", download=False):
    """Return a shuffling ``DataLoader`` over the CIFAR-10 training split.

    Args:
        batch_size: Number of samples per batch.
        data_dir: Directory handed to ``CIFAR10`` as its ``root`` argument.
            ``CIFAR10`` requires a root, so omitting it (as the previous
            version did) raises ``TypeError`` before any data is read.
        download: Forwarded to ``CIFAR10``. Defaults to ``False`` to keep the
            previous no-download behaviour.

    Returns:
        A ``torch.utils.data.DataLoader`` that yields shuffled batches
        transformed by ``_get_transforms()``.
    """
    train_set = torchvision.datasets.CIFAR10(
        root=data_dir,
        train=True,
        download=download,
        transform=_get_transforms(),
    )

    return torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
    )

# Define data loader for test dataset
def _get_test_data_loader(test_batch_size, data_dir="./data", download=False):
    """Return a deterministic ``DataLoader`` over the CIFAR-10 test split.

    Args:
        test_batch_size: Number of samples per evaluation batch.
        data_dir: Directory handed to ``CIFAR10`` as its ``root`` argument.
            ``CIFAR10`` requires a root, so omitting it raises ``TypeError``.
        download: Forwarded to ``CIFAR10``. Defaults to ``False`` to keep the
            previous no-download behaviour.

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates the test split in a
        fixed order.
    """
    # Evaluation must see the images as they are: random crops and flips belong
    # to training only and would make the reported metrics vary from run to
    # run. Keep just tensor conversion and the normalization constants that
    # _get_transforms() uses, so inputs stay on the scale the model was
    # trained on.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_set = torchvision.datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=download,
        transform=eval_transform,
    )

    # Summed loss and accuracy do not depend on sample order, so there is no
    # reason to shuffle (and consume random state) during evaluation.
    return torch.utils.data.DataLoader(
        test_set,
        batch_size=test_batch_size,
        shuffle=False,
    )

In [None]:
def save_model(model, model_dir):
    """Write the model's ``state_dict`` to ``<model_dir>/model.pth``.

    Args:
        model: The network to save. If it exposes a ``module`` attribute (as
            ``nn.DataParallel`` / ``DistributedDataParallel`` wrappers do), the
            wrapped network's weights are saved; otherwise the model's own.
        model_dir: Target directory; it is created when it does not exist.
    """
    logger.info("Saving the model.")
    if model_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, "model.pth")
    # A plain nn.Module has no `.module`, so accessing it unconditionally
    # raised AttributeError. Fall back to the model itself in that case.
    unwrapped = getattr(model, "module", model)
    torch.save(unwrapped.state_dict(), path)

In [None]:
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(
        "Test set: Average loss: {:.4f}, Accuracy: {:.2f}\n".format(
            test_loss, correct / len(test_loader.dataset)
        )
    )

In [None]:
# Run settings for this training cell. The loop below reads these names
# directly; the previous version read them from an `args` object that is never
# created in the notebook, so the cell failed with NameError.
seed = 0
batch_size = 64
test_batch_size = batch_size
lr = 0.001
momentum = 0.9
epochs = 10
log_interval = 100  # log the training loss every `log_interval` batches
model_dir = "."  # directory that receives model.pth


# Set the seed for generating random numbers
torch.manual_seed(seed)

train_loader = _get_train_data_loader(batch_size)
test_loader = _get_test_data_loader(test_batch_size)

# NOTE(review): "hpu" presumably needs the vendor's PyTorch bridge to be loaded
# before tensors can be moved to it; nothing in this notebook imports one — confirm.
device = torch.device("hpu")
model = Net().to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for epoch in range(1, epochs + 1):
    model.train()  # test() switches to eval mode, so re-enable training mode each epoch
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            logger.info(
                "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.sampler),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
    # Report held-out loss and accuracy after every epoch.
    test(model, test_loader, device)

# save_model requires the target directory as its second argument.
save_model(model, model_dir)

In [None]:
if __name__ == "__main__":
    # Declare the command-line options for a training run.
    # NOTE(review): only the parser construction is visible here; no
    # parse_args() call appears in these lines, so nothing reads the options
    # yet — confirm whether the rest of this cell is missing.
    parser = argparse.ArgumentParser()

    # PyTorch environments
    # Model selection (free-form string; the help text names "custom model" and "resnet18").
    parser.add_argument("--model-type",type=str,default='resnet18',
                        help="custom model or resnet18")
    # Batch sizes for the training and evaluation loaders.
    parser.add_argument("--batch-size",type=int,default=64,
                        help="input batch size for training (default: 64)")
    parser.add_argument("--test-batch-size",type=int,default=1000,
                        help="input batch size for testing (default: 1000)")
    # Optimisation settings.
    parser.add_argument("--epochs",type=int,default=10,
                        help="number of epochs to train (default: 10)")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="learning rate (default: 0.01)")
    parser.add_argument("--momentum", type=float, default=0.5,
                        help="SGD momentum (default: 0.5)")
    # Reproducibility and logging cadence.
    parser.add_argument("--seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--log-interval",type=int,default=100,
                        help="how many batches to wait before logging training status")
    # Distributed-training backend name.
    parser.add_argument("--backend",type=str,default='gloo',
                        help="backend for dist. training, this script only supports gloo")
