In [2]:
import torch
import torch.onnx as onnx
import torchvision.models as models
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torchvision


In [3]:
# MNIST training split, converted to tensors and normalised with the
# mean/std constants (0.1307, 0.3081) passed to Normalize below.
# NOTE(review): '/files/' is an absolute path; consider a configurable data dir.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
training_data = torchvision.datasets.MNIST(
    root='/files/',
    train=True,
    download=True,
    transform=mnist_transform,
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# Training hyperparameters: defined before first use so the DataLoader
# (and anything below) reads them instead of repeating literals.
batch_size = 64
epochs = 3
# NOTE(review): `learning_rate` and `momentum` are not read by the optimizer
# cell, which hardcodes lr=0.001, momentum=0.9 — reconcile before relying on these.
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # NOTE(review): unused — train_loop logs every 100 batches by default.

# Reproducibility: seed the global torch RNG before the shuffled loader is
# built and before any model weights are initialised.
random_seed = 1
torch.backends.cudnn.enabled = False  # presumably for GPU determinism; no effect on CPU
torch.manual_seed(random_seed)

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

<torch._C.Generator at 0x1a7079efde0>

In [10]:
class Net(nn.Module):
    """Fully connected classifier producing 10 logits per sample.

    Each input is flattened and must yield 28*28 = 784 features; it then passes
    through two 512-unit ReLU hidden layers and a final 10-way linear layer.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 512
        self.flatten = nn.Flatten()
        # Layer order fixes the state_dict keys (linear_relu_stack.0/.2/.4).
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_units), nn.ReLU(),
            nn.Linear(hidden_units, hidden_units), nn.ReLU(),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Return the unnormalised class scores for batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [11]:
# Loss, model and optimizer for training.
# NOTE(review): lr/momentum below differ from `learning_rate` (0.01) and
# `momentum` (0.5) in the hyperparameter cell; these literals are what training uses.
loss_fn = nn.CrossEntropyLoss()
model = Net()
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

In [12]:
# Inspect what will be saved/restored: parameter names with their shapes,
# then the optimizer's state and hyperparameter groups.
print("Model's state_dict:")
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())
print("Optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)

Model's state_dict:
linear_relu_stack.0.weight 	 torch.Size([512, 784])
linear_relu_stack.0.bias 	 torch.Size([512])
linear_relu_stack.2.weight 	 torch.Size([512, 512])
linear_relu_stack.2.bias 	 torch.Size([512])
linear_relu_stack.4.weight 	 torch.Size([10, 512])
linear_relu_stack.4.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5]}]


In [13]:
def train_loop(dataloader, model, loss_fn, optimizer, log_every=100):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: a DataLoader yielding ``(data, target)`` batches; its
            ``dataset`` length is used for progress reporting.
        model: module mapping ``data`` to predictions; put into train mode.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        log_every: print the loss every this many batches (default 100);
            pass 0 or None to disable logging.
    """
    size = len(dataloader.dataset)
    model.train()  # training-mode behaviour for any dropout/batch-norm layers
    # Iterate the loader passed in (previously this read the global
    # `train_dataloader`, silently ignoring the argument).
    for batch, (data, target) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(data)
        loss = loss_fn(pred, target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and batch % log_every == 0:
            # Number of samples processed before this batch.
            current = batch * len(data)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

In [14]:
# Train for `epochs` passes over the data. `epochs` comes from the
# hyperparameter cell; it is not redefined here, so that cell stays the
# single place to change it.
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
print("Done!")

Epoch 1
-------------------------------
loss: 2.349210  [    0/60000]
loss: 1.909350  [ 6400/60000]
loss: 1.123892  [12800/60000]
loss: 0.687302  [19200/60000]
loss: 0.468531  [25600/60000]
loss: 0.470848  [32000/60000]
loss: 0.452021  [38400/60000]
loss: 0.457846  [44800/60000]
loss: 0.526327  [51200/60000]
loss: 0.425250  [57600/60000]
Epoch 2
-------------------------------
loss: 0.233128  [    0/60000]
loss: 0.236356  [ 6400/60000]
loss: 0.312516  [12800/60000]
loss: 0.370720  [19200/60000]
loss: 0.261536  [25600/60000]
loss: 0.433929  [32000/60000]
loss: 0.413369  [38400/60000]
loss: 0.379716  [44800/60000]
loss: 0.179866  [51200/60000]
loss: 0.208052  [57600/60000]
Epoch 3
-------------------------------
loss: 0.290081  [    0/60000]
loss: 0.215108  [ 6400/60000]
loss: 0.383265  [12800/60000]
loss: 0.229047  [19200/60000]
loss: 0.166082  [25600/60000]
loss: 0.265015  [32000/60000]
loss: 0.164164  [38400/60000]
loss: 0.217300  [44800/60000]
loss: 0.286389  [51200/60000]
loss: 0.27

In [15]:
# Persist the trained model's state_dict (parameters only) to disk.
trained_weights = model.state_dict()
torch.save(trained_weights, 'model_weights.pth')