In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import argparse, os, sys,time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser(description='MLP')
parser.add_argument('-device', default='cuda:0', help='device')
parser.add_argument('-b', default=200, type=int, help='batch size')
parser.add_argument('-epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-data-dir', default='/data/tanghao/datasets/', type=str, help='root dir of dataset')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-2, type=float, help='learning rate')

args = parser.parse_args(args=[])
print(args)

os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

Namespace(b=200, data_dir='/data/tanghao/datasets/', device='cuda:0', epochs=10, j=4, lr=0.01, momentum=0.9, opt='adam')


In [3]:
train_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=args.b,
    shuffle=True,
    drop_last=True,
    num_workers=args.j,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=args.b,
    shuffle=False,
    drop_last=False,
    num_workers=args.j,
    pin_memory=True
)

In [4]:
net=nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 200),
    nn.ReLU(),
    nn.Linear(200, 10),
    nn.Softmax(dim=1)
).to(args.device)

In [17]:
# modle=nn.Sequential(
#     nn.ReLU()
# ).to(args.device)
# for item in modle.parameters():
#     print(item)

# modle=nn.Sequential(
#     nn.Flatten(),
#     nn.Linear(784, 200, bias=False),
# ).to(args.device)
# for item in modle.parameters():
#     print(item.shape)

torch.Size([200, 784])


In [14]:
criteon=nn.CrossEntropyLoss().to(args.device)
optimizer=torch.optim.Adam(net.parameters(),lr=args.lr)
for item in net.parameters():
    print(item.shape)

torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 200])
torch.Size([10])


In [6]:
for epoch in range(args.epochs):
    net.train()
    train_loss = 0
    for i, (x, y) in enumerate(train_data_loader):
        x, y = x.to(args.device), y.to(args.device)

        y_hat = net(x)
        loss = criteon(y_hat, y)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    net.eval()
    test_loss = 0
    correct = 0
    for i, (x, y) in enumerate(test_data_loader):
        x, y = x.to(args.device), y.to(args.device)
        # print(x.shape, y.shape)
        y_hat = net(x)
        loss = criteon(y_hat, y)
        test_loss += loss.item()
        correct += (torch.argmax(y_hat, dim=1) == y).sum().item()

    print('epoch: {}, train loss: {}, test loss: {}, test acc: {}'.format(
        epoch, train_loss / len(train_data_loader),test_loss / len(test_data_loader), correct / len(test_dataset)))


epoch: 0, train loss: 1.5651118020216623, test loss: 1.521777045726776, test acc: 0.9405
epoch: 1, train loss: 1.5140315488974254, test loss: 1.5111462116241454, test acc: 0.9509
epoch: 2, train loss: 1.5046754709879557, test loss: 1.5038097500801086, test acc: 0.9566
epoch: 3, train loss: 1.498742919365565, test loss: 1.4996220993995666, test acc: 0.9614
epoch: 4, train loss: 1.4964753480752309, test loss: 1.4997141933441163, test acc: 0.9611
epoch: 5, train loss: 1.4937808867295583, test loss: 1.4975913095474243, test acc: 0.964
epoch: 6, train loss: 1.4928985182444254, test loss: 1.5058342504501343, test acc: 0.9551
epoch: 7, train loss: 1.4911694077650706, test loss: 1.4941966867446899, test acc: 0.9671
epoch: 8, train loss: 1.4915913927555084, test loss: 1.4972417855262756, test acc: 0.9635
epoch: 9, train loss: 1.4914813260237376, test loss: 1.501608681678772, test acc: 0.9594
