In [29]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary

In [2]:
# Device configuration
# Use the first CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [3]:
# Hyper parameters
seed = 42              # fixed RNG seed so weight init and shuffling repeat across runs
num_epochs = 5         # full passes over the training set
num_classes = 10       # number of output classes of the network
batch_size = 100       # samples per mini-batch for both data loaders
learning_rate = 0.001  # step size passed to the optimizer

# Seed PyTorch's global generator before any model or data loader is created.
torch.manual_seed(seed)

In [4]:
# MNIST dataset
# Both splits are read from the same root directory and converted to tensors;
# only the training split is created with download=True.
mnist_root = '../../data/'

train_dataset = torchvision.datasets.MNIST(
    root=mnist_root, train=True, transform=transforms.ToTensor(), download=True)

test_dataset = torchvision.datasets.MNIST(
    root=mnist_root, train=False, transform=transforms.ToTensor())

# Data loader
# Training batches are reshuffled on each pass; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [20]:
# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    """Small CNN: two conv / batch-norm / ReLU / max-pool stages, then one linear layer.

    The linear layer expects 7 * 7 * 32 flattened features, which is what the
    two stride-2 poolings produce for 1 x 28 x 28 inputs.

    Args:
        num_classes: Number of output logits produced per sample.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one stage: 5x5 conv (padding 2) -> batch norm -> ReLU -> 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys stay the same.
        self.layer1 = self._conv_stage(1, 16)
        self.layer2 = self._conv_stage(16, 32)
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.layer2(self.layer1(x))
        # Collapse everything except the batch dimension before the linear layer.
        flat = features.reshape(features.size(0), -1)
        return self.fc(flat)

In [30]:
# Build the network, move it to the selected device, and print a per-layer
# summary for a single-channel 28x28 input.
input_shape = (1, 28, 28)
model = ConvNet(num_classes)
model = model.to(device)
summary(model, input_shape)

# Loss and optimizer
# NOTE: the misspelled name `criention` is kept because the training cell refers to it.
criention = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             416
       BatchNorm2d-2           [-1, 16, 28, 28]              32
              ReLU-3           [-1, 16, 28, 28]               0
         MaxPool2d-4           [-1, 16, 14, 14]               0
            Conv2d-5           [-1, 32, 14, 14]          12,832
       BatchNorm2d-6           [-1, 32, 14, 14]              64
              ReLU-7           [-1, 32, 14, 14]               0
         MaxPool2d-8             [-1, 32, 7, 7]               0
            Linear-9                   [-1, 10]          15,690
Total params: 29,034
Trainable params: 29,034
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.47
Params size (MB): 0.11
Estimated Total Size (MB): 0.58
---------------------------------------------

In [31]:
# Train the model
# Put the network in training mode explicitly: the evaluation cell calls
# model.eval(), so re-running this cell afterwards would otherwise train with
# BatchNorm using its moving statistics instead of mini-batch statistics.
model.train()
total_step = len(train_loader)  # number of mini-batches per epoch
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move the batch to the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criention(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps.
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            

Epoch [1/5], Step [100/600], Loss: 0.1360
Epoch [1/5], Step [200/600], Loss: 0.1000
Epoch [1/5], Step [300/600], Loss: 0.0588
Epoch [1/5], Step [400/600], Loss: 0.1170
Epoch [1/5], Step [500/600], Loss: 0.1064
Epoch [1/5], Step [600/600], Loss: 0.1082
Epoch [2/5], Step [100/600], Loss: 0.0128
Epoch [2/5], Step [200/600], Loss: 0.0213
Epoch [2/5], Step [300/600], Loss: 0.0614
Epoch [2/5], Step [400/600], Loss: 0.0246
Epoch [2/5], Step [500/600], Loss: 0.0125
Epoch [2/5], Step [600/600], Loss: 0.0186
Epoch [3/5], Step [100/600], Loss: 0.0704
Epoch [3/5], Step [200/600], Loss: 0.0461
Epoch [3/5], Step [300/600], Loss: 0.0236
Epoch [3/5], Step [400/600], Loss: 0.0026
Epoch [3/5], Step [500/600], Loss: 0.0559
Epoch [3/5], Step [600/600], Loss: 0.0479
Epoch [4/5], Step [100/600], Loss: 0.0056
Epoch [4/5], Step [200/600], Loss: 0.0615
Epoch [4/5], Step [300/600], Loss: 0.0074
Epoch [4/5], Step [400/600], Loss: 0.0170
Epoch [4/5], Step [500/600], Loss: 0.0219
Epoch [4/5], Step [600/600], Loss:

In [35]:
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
correct = 0  # number of test images whose predicted class equals the label
total = 0    # number of test images seen
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit in each row.
        # `.data` is not needed here because autograd is already disabled.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated rather than a hard-coded count.
print('Test Accuracy of the model on the {} test images: {} %'.format(total, 100 * correct / total))

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0,
        2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4,
        1, 7, 6, 9], device='cuda:0')
tensor([[-9.0860e+00, -7.3969e+00, -2.7927e-01, -1.9821e+00, -1.1090e+01,
         -7.3146e+00, -2.1404e+01,  1.2939e+01, -6.8441e+00, -2.2263e+00],
        [ 2.8193e-01, -1.9993e+00,  1.0282e+01, -8.7293e+00, -9.3289e+00,
         -1.4058e+01, -1.9737e+00, -1.0954e+01, -8.0345e+00, -1.1366e+01],
        [-6.3464e+00,  8.9002e+00, -4.7230e+00, -8.2244e+00, -4.9963e-01,
         -5.1547e+00, -5.1104e+00, -1.3334e+00, -1.8235e+00, -7.0242e+00],
        [ 1.3139e+01, -1.5165e+01, -4.4673e+00, -1.1412e+01, -7.8199e+00,
         -5.6679e+00,  1.3268e+00, -7.6111e+00, -5.8088e+00, -6.5205e+00],
        [-9.1498e+00, -9.3250e+00, -9.1591e+00

tensor([[-5.5596e+00, -6.3560e+00, -4.6928e+00, -1.0844e+01,  5.2942e+00,
         -9.7638e+00, -6.3875e+00, -3.7991e+00, -5.7906e+00, -5.9550e+00],
        [-1.4144e+01, -9.3501e+00,  9.8693e-01,  7.3044e-01, -5.9132e+00,
         -8.0970e+00, -2.3603e+01,  1.0300e+01, -8.4682e+00, -5.8884e+00],
        [-5.9704e+00,  7.9959e+00, -3.6791e+00, -5.8779e+00, -3.8855e-01,
         -6.1348e+00, -3.8217e+00, -3.5142e+00, -6.1400e-01, -6.4266e+00],
        [-3.9317e+00, -1.1561e+01,  9.6852e+00, -3.2295e+00, -1.1528e+01,
         -6.5754e+00, -7.8891e+00, -3.6502e+00, -5.8781e+00, -6.7826e+00],
        [-1.5241e+01, -9.5392e+00, -7.2552e+00, -1.2419e+01,  1.3797e+01,
         -6.9111e+00, -1.6305e+01, -3.6163e+00,  3.0141e-01, -5.0122e+00],
        [ 1.3289e+01, -2.1722e+01, -6.5917e+00, -1.3399e+01, -2.0877e+01,
         -2.3221e+00,  2.6287e+00, -1.2431e+01, -1.9240e+00, -3.8823e+00],
        [-2.1545e+00, -1.1520e-01,  1.2688e+01, -5.2901e+00, -9.8789e+00,
         -1.5278e+01, -1.1066e+0

tensor([[-7.7761e+00,  7.3948e+00, -8.7423e-01, -4.4603e+00, -1.5609e+00,
         -5.8223e+00, -8.2695e+00, -3.6721e+00, -9.7252e-01, -5.3080e+00],
        [ 1.2382e+01, -1.0127e+01, -5.7743e-01, -7.4555e+00, -7.4008e+00,
         -9.1779e+00, -5.4375e+00, -9.6106e+00, -1.0123e+01, -5.1839e+00],
        [-5.6138e+00, -2.7049e+00,  1.2226e+00, -2.1489e+00, -8.6144e+00,
         -8.3962e+00, -1.9192e+01,  7.3804e+00, -3.2448e+00, -4.5677e-01],
        [-1.1968e+01, -1.0641e+01,  2.7564e+00,  2.3309e+00, -7.9189e+00,
         -1.1242e+01, -2.5576e+01,  1.1161e+01, -8.7659e+00, -2.3024e+00],
        [ 1.2444e+01, -1.5347e+01,  1.5342e+00, -9.6758e+00, -1.1453e+01,
         -1.2249e+01, -5.1587e+00, -1.1910e+01, -5.1338e+00, -6.2507e+00],
        [-8.5519e+00, -2.6006e+00, -4.5039e-01,  6.5025e-01, -9.8192e+00,
         -8.9085e+00, -2.3281e+01,  1.2521e+01, -1.0644e+01, -1.4686e+00],
        [-1.4753e+01, -1.4077e+01, -9.9641e+00, -4.9319e+00, -2.2016e+00,
         -4.0600e+00, -2.4111e+0

tensor([[-1.1565e+01, -1.7937e+00, -3.3164e+00, -2.5791e+00, -7.7718e+00,
         -6.8653e+00, -2.3584e+01,  1.0230e+01, -4.1759e+00, -1.3769e+00],
        [-1.0142e+01, -8.6311e+00, -2.7417e+00, -6.2663e-01, -7.0465e+00,
         -5.1095e+00, -6.3408e+00, -9.0114e+00,  3.0058e+00, -7.2495e+00],
        [-1.3319e+01, -2.1800e+01, -1.7464e+01, -4.1862e+00, -1.5170e+01,
          1.7440e+01, -5.9543e+00, -1.2850e+01, -9.9657e+00, -4.4628e+00],
        [-1.8252e+01, -1.7603e+01, -7.9324e+00, -6.8099e+00, -1.9635e+00,
         -1.8314e+00, -1.4773e+01, -4.0857e+00, -3.8235e+00,  8.8540e+00],
        [-1.1023e+01, -2.7510e+00, -7.2008e-01,  6.0318e-01, -1.0480e+01,
         -9.2724e+00, -2.2973e+01,  1.2299e+01, -6.3518e+00, -1.8431e+00],
        [-1.2768e+01, -1.5856e+01, -1.0613e+01, -3.3194e+00, -2.3059e+00,
         -5.8184e+00, -1.8776e+01,  7.6543e-01, -7.8674e+00,  8.8270e+00],
        [-1.2806e+00, -1.4004e+01, -9.7952e+00, -1.3133e+01, -4.3082e+00,
         -5.3775e+00,  1.4809e+0

tensor([[-1.4441e+01, -1.9936e+01, -9.1077e+00,  1.1619e+01, -1.6575e+01,
          5.7769e+00, -1.4747e+01, -1.1121e+01, -2.4964e+00, -2.0086e+00],
        [-1.0757e+01, -1.1497e+01,  2.0932e+00,  9.5221e+00, -1.5093e+01,
         -5.0390e+00, -2.2046e+01,  3.4336e+00, -1.2687e+01, -7.8994e+00],
        [-4.3579e+00, -1.0255e+01, -7.3554e+00, -6.4266e+00, -4.1498e+00,
         -7.1931e-01,  1.0570e+01, -1.5380e+01, -2.0662e+00, -9.4492e+00],
        [-5.9369e+00,  8.2793e+00, -3.4061e+00, -6.0687e+00, -9.6588e-01,
         -4.8171e+00, -3.5614e+00, -4.3831e+00, -7.9749e-01, -7.9180e+00],
        [-1.4444e+00, -4.5646e+00,  9.6894e+00, -3.1671e+00, -1.4670e+01,
         -1.2739e+01, -7.7724e+00, -6.1379e+00, -2.7106e+00, -1.2330e+01],
        [-8.4324e+00, -1.2781e+01, -3.3323e+00,  1.1719e+01, -1.9460e+01,
         -3.4314e+00, -1.4978e+01, -1.2616e+01, -1.7793e+00, -3.1215e+00],
        [-9.8161e+00, -6.0429e+00, -4.4399e+00, -3.0321e+00, -6.0210e+00,
         -6.7583e+00, -1.9646e+0

tensor([[-1.2289e+01, -1.9681e+01, -1.1714e+01, -5.1884e+00, -1.7561e+01,
          1.2000e+01, -7.5605e+00, -1.6204e+01,  1.0593e+00, -2.2334e+00],
        [-5.8384e+00, -3.2580e+00,  1.2353e+01,  6.9055e-01, -1.6429e+01,
         -1.4646e+01, -1.1479e+01,  1.2032e+00, -2.1829e+00, -1.3605e+01],
        [-6.8940e+00, -1.4587e+01, -4.1027e+00, -7.2927e+00, -1.0479e+01,
         -5.6588e+00, -9.5868e+00, -9.6079e+00,  1.2581e+01, -3.8137e+00],
        [-1.4382e+01, -1.6076e+01, -1.4508e+01, -3.8801e+00, -1.1280e+01,
          1.4443e+01, -1.1723e+01, -1.2932e+01, -4.2123e-01, -2.4035e+00],
        [-4.1170e+00, -9.2425e+00, -2.3587e+00, -1.0909e+00, -9.2169e+00,
         -3.7001e+00, -2.0396e+01,  1.0169e+01, -8.8615e+00,  1.3610e+00],
        [-8.3144e+00, -1.0557e+01, -6.6305e+00,  8.3528e+00, -1.5071e+01,
          2.0993e+00, -1.5912e+01, -1.1944e+01, -4.4655e-01, -5.5839e+00],
        [ 1.0445e+01, -1.3539e+01,  3.9636e-01, -1.0483e+01, -1.0562e+01,
         -8.6191e+00, -5.8084e+0

tensor([[-1.1279e+01, -7.1283e+00, -5.0280e+00, -5.2747e+00, -9.2391e+00,
         -5.6851e+00, -1.0702e+01, -1.1873e+01,  1.2028e+01, -7.7264e+00],
        [-1.2331e+00, -1.6362e+01, -4.8213e+00, -1.4477e+01, -1.0523e+01,
          3.2714e-01,  1.1537e+01, -1.4149e+01, -3.4278e-01, -9.7641e+00],
        [-6.6543e+00,  8.5806e-01,  1.1956e+01, -2.6239e+00, -8.3219e+00,
         -1.4314e+01, -1.6095e+01,  1.7832e+00, -6.5514e+00, -1.1108e+01],
        [-1.3855e+01, -1.6595e+01, -5.2278e+00, -7.4110e+00, -1.8171e-02,
         -7.4157e+00, -2.0718e+01, -1.4067e+00,  2.4036e-01,  9.4670e+00],
        [-8.6526e+00, -1.0627e+01, -8.9758e+00, -2.7259e+00, -1.2518e+01,
          1.0243e+01, -8.5142e+00, -1.0913e+01, -6.1108e+00, -3.4856e+00],
        [-1.4918e+01, -4.5551e+00, -4.6977e+00, -4.4910e+00, -3.2490e+00,
         -4.3292e+00, -2.1758e+01,  1.1225e+01, -8.2984e+00, -3.8498e-01],
        [-7.7339e+00, -2.0408e+01, -1.1133e+01, -3.2319e+00, -1.5878e+01,
          1.3635e+01, -3.3023e+0

tensor([[-8.4376e+00, -1.4109e+01, -6.6986e+00,  1.6347e+00, -1.6506e+01,
          7.5025e+00, -8.4158e+00, -8.2591e+00, -4.3471e+00, -6.9291e+00],
        [-7.4499e+00, -7.6941e+00, -4.5447e+00, -4.2507e+00, -5.5317e+00,
         -5.5791e+00, -2.0002e+01,  9.9038e+00, -8.1029e+00,  2.8621e-01],
        [-6.2875e+00, -6.4486e+00, -7.6621e+00, -2.5193e-01, -8.3718e+00,
          7.0954e+00, -1.0539e+01, -7.9819e+00, -1.0314e+01, -3.6467e+00],
        [-1.1095e+01, -2.0391e+00, -1.9106e+00, -1.2292e+00, -4.9745e+00,
         -5.1200e+00, -1.8005e+01,  9.4322e+00, -7.9367e+00, -2.2081e+00],
        [-1.1758e+01, -1.1311e+01, -6.1801e+00,  1.1397e+01, -1.3676e+01,
         -1.9953e+00, -1.4011e+01, -8.5286e+00, -3.0260e+00, -2.9583e+00],
        [-9.0751e+00, -1.2289e+01,  1.3801e+01, -4.6669e+00, -4.8530e+00,
         -1.1698e+01, -1.2198e+01, -6.8081e+00, -5.8130e+00, -1.1363e+01],
        [-5.9692e-01, -6.1756e+00,  5.8131e+00, -3.6339e+00, -1.1707e+01,
         -1.0492e+01, -2.4393e+0

tensor([[-1.2648e+00,  2.9243e+00,  9.6615e+00, -4.4786e+00, -1.6994e+01,
         -1.0213e+01, -9.8518e+00, -5.8932e+00, -3.8050e+00, -1.1845e+01],
        [-5.3980e+00,  4.2641e+00, -4.9500e+00, -5.7506e+00, -1.3895e-01,
         -3.1067e+00,  2.3551e+00, -7.3315e+00, -2.2783e+00, -9.1359e+00],
        [-3.3301e+00, -1.0051e+01, -6.5261e+00, -1.5064e+01, -6.3626e-01,
         -3.2656e+00,  1.4635e+01, -1.2269e+01, -7.9347e+00, -1.2607e+01],
        [-3.6537e+00, -1.5875e+01, -1.0211e+01, -1.3258e+01, -5.4771e+00,
         -3.0804e+00,  1.6476e+01, -1.2603e+01, -3.2477e+00, -1.3532e+01],
        [-5.3771e-01, -8.8515e+00,  4.3289e-01, -1.4121e+01, -3.8197e+00,
         -3.3747e+00, -1.1277e+01,  7.1164e+00, -9.5340e+00, -1.1380e+01],
        [-4.7630e+00,  8.0847e+00, -3.2567e+00, -8.2817e+00, -3.1748e+00,
         -5.5045e+00, -3.2813e+00, -4.6809e+00, -2.3753e-01, -7.6986e+00],
        [-5.9193e+00,  6.0157e+00, -4.3094e+00, -8.7792e+00, -1.6284e+00,
         -3.4005e+00, -4.3065e+0

tensor([[-8.0530e+00, -2.9322e+00,  9.3580e+00,  2.5942e+00, -1.7752e+01,
         -1.0448e+01, -1.0682e+01, -7.6185e+00, -4.1895e+00, -1.1034e+01],
        [-1.3818e+01, -1.1427e+01, -7.9445e+00, -5.3793e+00,  1.5065e+00,
         -5.4574e+00, -1.9968e+01, -6.9303e-02, -2.0363e+00,  9.3302e+00],
        [-6.8778e+00, -1.5493e+01, -1.0995e+01, -2.3024e+00, -6.9242e+00,
         -2.6715e+00, -1.1332e+01,  6.3829e+00, -7.6633e+00, -9.6646e-01],
        [-1.5326e+01, -1.8393e+00,  1.1631e+01, -1.4952e+00, -7.8742e+00,
         -1.3837e+01, -1.2875e+01, -4.7370e+00, -4.3586e+00, -1.0597e+01],
        [-3.9438e+00,  8.2997e+00, -2.3666e+00, -5.7777e+00, -4.4958e-01,
         -7.7937e+00, -6.4166e+00, -1.7515e+00, -1.4530e+00, -6.9666e+00],
        [-4.8988e+00,  7.3838e+00, -1.1926e+00, -6.5864e+00, -1.7315e+00,
         -9.0207e+00, -6.8137e+00, -1.9162e+00, -7.8732e-01, -6.5994e+00],
        [-1.1470e+01, -8.0391e+00, -1.7645e+00,  1.4457e+01, -1.7450e+01,
         -3.2518e+00, -1.4493e+0

tensor([[-1.0877e+01, -1.4518e+01, -2.1032e+00,  1.6287e+01, -2.0196e+01,
         -2.6665e+00, -1.6480e+01, -5.4421e+00, -2.8802e+00, -4.1722e+00],
        [-2.1298e+01, -3.8303e+00, -5.5920e+00, -6.4367e+00,  3.6606e-01,
         -6.1639e+00, -1.8558e+01, -5.3663e+00,  4.1726e+00, -9.8025e+00],
        [-5.1636e+00,  7.8278e+00, -2.7403e+00, -8.4969e+00,  3.4384e-01,
         -6.8264e+00, -6.1991e+00, -1.4943e+00, -2.1333e+00, -6.3599e+00],
        [-5.6217e+00,  7.9275e+00, -5.0112e+00, -5.2992e+00, -1.2724e-02,
         -6.1078e+00, -6.3711e+00, -1.7119e+00, -1.8839e+00, -4.7220e+00],
        [ 6.0640e+00, -2.0765e+01, -7.5235e+00, -1.0175e+01, -9.0429e+00,
         -1.5729e+00,  1.2587e+00, -6.0616e+00, -6.2036e+00, -4.6891e+00],
        [-4.8679e+00, -7.8776e+00,  1.8245e-01,  8.9981e+00, -1.2932e+01,
         -4.8743e+00, -1.0821e+01, -8.7754e+00, -5.0619e+00, -5.1053e+00],
        [-6.3368e+00,  8.7737e+00, -2.5343e+00, -8.4424e+00, -2.4985e-01,
         -7.6789e+00, -6.1763e+0

tensor([[-9.6013e+00, -1.1139e+01, -1.2177e+01, -4.3782e+00,  4.9576e-01,
         -5.3755e+00, -1.3708e+01, -3.6934e+00, -5.1352e+00,  6.5971e+00],
        [-1.3970e+01, -7.6042e+00, -1.1159e+01, -9.4892e+00,  1.3211e+01,
         -8.0408e+00, -1.4575e+01, -5.1962e+00, -4.7358e+00, -1.4746e+00],
        [-9.8570e+00, -1.6250e+01, -1.3565e+01,  1.5282e+00, -7.6007e+00,
          1.2054e+01, -9.9768e+00, -1.3653e+01, -9.8707e-01, -5.7568e+00],
        [-1.3331e+01, -1.6110e+01, -5.2695e+00, -8.6508e+00, -8.5999e+00,
          6.2599e-02, -1.2941e+01, -9.7792e+00,  1.0992e+01, -5.5496e+00],
        [-1.3007e+01, -9.3378e+00, -1.4569e+01, -1.2703e+01,  1.3127e+01,
         -7.4673e+00, -1.2982e+01,  6.2518e-01, -8.5332e+00, -4.3200e+00],
        [-1.2169e+01, -2.3520e+00,  7.5275e+00, -4.2539e+00, -4.9467e+00,
         -4.1025e+00, -1.2340e+01, -4.5688e+00, -1.0875e+01, -1.2659e+01],
        [-1.3171e+01, -1.3575e+01, -1.1465e+01, -3.6369e+00,  1.1526e+00,
         -7.5938e+00, -1.8542e+0

tensor([[-5.2030e+00, -1.3770e+00, -5.5938e+00, -7.9821e+00, -9.0871e-01,
         -7.3394e+00, -1.6030e+01,  6.0765e+00, -6.0667e+00,  1.1061e+00],
        [-6.7813e+00, -8.0629e+00,  7.3585e-02, -3.2151e+00, -4.2677e+00,
         -8.7034e+00, -9.5507e+00, -1.0224e+01,  6.3822e+00, -9.7442e+00],
        [-1.6146e+01, -1.5015e+01, -7.6871e+00, -6.9204e+00,  8.7885e-01,
         -6.5318e+00, -2.3818e+01,  3.5419e-01, -2.7472e+00,  1.1877e+01],
        [ 1.1438e+01, -2.0003e+01, -5.6908e+00, -1.3334e+01, -1.6316e+01,
         -8.0210e-01,  1.8661e+00, -1.1864e+01, -4.4486e+00, -5.2228e+00],
        [-6.3686e+00,  8.7875e+00, -6.1855e+00, -6.0768e+00, -1.0454e-01,
         -5.0016e+00, -6.6212e+00, -2.4264e+00, -2.2445e+00, -6.3669e+00],
        [-9.3233e+00, -1.1558e+01,  1.2330e+01, -1.8446e+00, -1.3029e+01,
         -1.5934e+01, -1.4271e+01, -7.1910e+00, -3.8332e+00, -7.4749e+00],
        [-6.4051e+00, -1.2137e+01, -2.7460e+00,  1.1454e+01, -1.7389e+01,
         -1.9015e+00, -1.8979e+0

tensor([[-1.4808e+01, -7.2975e+00, -5.5393e+00, -1.2813e+01,  1.0863e+01,
         -6.1280e+00, -1.4352e+01, -1.2131e+00, -8.3041e+00, -6.0176e+00],
        [-4.5315e+00,  4.3111e+00, -5.6012e+00, -7.0208e+00, -2.6874e+00,
         -3.1885e+00, -2.8576e+00, -6.5565e+00, -7.8005e-01, -9.3455e+00],
        [-1.1333e+01, -8.9132e+00, -6.0798e+00, -1.9102e+01,  1.4869e+01,
         -6.5376e+00, -1.0161e+01, -1.5907e+00, -9.7380e+00, -9.3789e+00],
        [-1.4487e+01, -9.8945e+00, -3.5727e+00,  1.2510e+01, -1.3686e+01,
          3.0331e-01, -1.8610e+01, -6.6555e+00, -8.0732e+00, -9.4466e+00],
        [-1.5918e+00, -1.2734e+01, -8.7486e+00, -1.1180e+01, -4.9630e+00,
         -2.4826e+00,  1.5126e+01, -1.2381e+01, -4.0934e+00, -1.2637e+01],
        [-1.3346e+01, -2.4497e+00, -5.1827e+00,  6.9148e-01, -1.0440e+01,
         -5.3845e+00, -2.3888e+01,  1.2734e+01, -8.1857e+00, -3.5696e+00],
        [-4.8677e+00, -4.4128e+00,  1.3305e+01, -2.0979e+00, -1.4390e+01,
         -1.3630e+01, -1.0554e+0

tensor([[-12.7415, -15.7149,  -7.2531,  -1.0275,  -5.7363,   9.9332,  -7.8616,
         -13.4653,  -6.4881,  -4.4202],
        [ -3.7313,   8.4561,  -0.7868,  -7.3369,  -0.7046,  -8.2354,  -8.4512,
          -1.5018,  -2.1264,  -7.2674],
        [ 14.6138, -19.3750,   1.4400,  -9.2506, -20.0485,  -7.2260,  -9.4948,
          -3.8354,  -2.4913,  -7.1860],
        [ -3.3962, -20.1567,  -4.0679,  -7.4556,  -7.1682,  -4.9331, -13.5614,
          -2.7118,   3.0041,  -1.2479],
        [ -4.7549,   8.8433,  -3.0754,  -8.9126,  -1.2973,  -7.0528,  -4.8861,
          -1.9123,  -1.2527,  -7.8569],
        [ -1.1458,  -1.3985,  -2.9306,  -8.1472,  -0.8721,  -1.8228,   4.3150,
          -8.5631,  -4.9425,  -8.8743],
        [ -8.9157,  -8.9360,  -3.0889,  -2.5034,  -6.5180,  -4.5342, -17.5896,
          10.4897, -10.0007,  -3.3916],
        [ -2.5432,  -9.5728,   8.3890,  -4.0372,  -4.1380,  -8.4781, -11.3790,
          -5.3745,  -2.7167,  -8.6407],
        [-11.2801, -11.6435, -12.2660,  -3.8374,

tensor([[-9.9072e+00, -2.8370e+00,  1.1139e+01, -9.4544e-01, -1.1558e+01,
         -1.0618e+01, -1.4742e+01,  1.9558e+00, -7.4565e+00, -1.1506e+01],
        [-1.9056e+01, -6.6681e+00, -1.0146e+01, -9.1064e+00,  1.1889e+01,
         -3.4200e+00, -1.8027e+01,  1.2440e+00, -9.8715e+00, -2.6945e+00],
        [-1.8117e+01, -5.0238e+00, -1.0660e+01, -8.5223e+00,  1.0754e+01,
         -4.2316e+00, -1.3457e+01, -4.3789e+00, -4.9080e+00, -3.1853e+00],
        [-1.3001e+01, -1.1704e+01, -3.7891e+00,  1.5463e+01, -1.9175e+01,
         -3.5682e-01, -1.3448e+01, -4.6282e+00, -4.5717e+00, -4.0783e+00],
        [-1.5267e+01, -1.5941e+01, -4.3863e-02,  1.5617e+01, -1.8206e+01,
         -2.7700e+00, -1.5084e+01, -4.9483e+00, -4.5435e+00, -4.0047e+00],
        [-2.7786e+00, -9.5232e+00, -7.6828e+00, -1.1988e+01, -2.7719e+00,
         -3.9635e+00,  1.1041e+01, -1.1833e+01, -6.7170e+00, -9.7172e+00],
        [-1.2825e+01, -6.2007e+00, -7.4721e+00, -3.0617e+00,  1.3904e+00,
         -6.6067e+00, -1.8085e+0

tensor([[-1.0266e+01, -5.9866e+00, -5.8328e+00, -5.4712e+00, -6.7950e+00,
         -4.6583e+00, -1.1072e+01, -8.1854e+00,  8.9923e+00, -3.0663e+00],
        [-1.0712e+01, -9.8417e+00, -5.2503e+00,  1.2177e+01, -1.5459e+01,
          3.1157e-01, -1.6903e+01, -6.8434e+00, -6.5360e+00, -6.1077e+00],
        [-9.4549e+00, -8.9151e+00, -6.4876e+00, -8.2063e+00, -1.0259e+01,
         -1.4791e+00, -1.3320e+01, -1.1781e+01,  1.0254e+01, -6.7563e+00],
        [ 1.5516e-01, -1.1220e+01, -5.0257e+00, -1.5366e+01, -2.5829e+00,
         -1.7330e+00,  8.2558e+00, -1.3068e+01, -6.7868e+00, -7.7408e+00],
        [-8.4466e+00, -5.6660e+00,  2.3479e+00,  9.4211e-01, -1.4113e+01,
         -1.2354e+01, -1.9124e+01,  1.2066e+01, -5.8942e+00, -2.4328e-01],
        [ 6.8036e+00, -6.4576e+00, -1.2390e+00, -1.0723e+01,  1.3350e-01,
         -1.1386e+01, -1.6053e+00, -9.6123e+00, -4.7634e+00, -8.3668e+00],
        [ 1.4127e+01, -1.1651e+01, -5.4265e-01, -1.3688e+01, -8.4578e+00,
         -6.4902e+00,  2.5560e+0

tensor([[-1.9232e+00,  6.0570e+00, -1.1812e+00, -6.3966e+00, -2.7839e+00,
         -8.9472e+00, -3.6492e+00, -5.2188e+00, -1.8307e+00, -8.4453e+00],
        [ 2.3878e-01, -1.2537e+01, -2.4362e+00, -6.9688e+00, -8.7007e+00,
         -5.4773e+00, -6.4597e+00, -1.0071e+01,  6.9143e+00, -8.4670e+00],
        [-7.5342e+00, -5.9173e+00, -3.0055e+00, -2.7535e+00, -6.6406e+00,
         -6.3815e+00, -1.8768e+01,  5.9662e+00, -4.5079e+00, -1.5406e+00],
        [-7.5575e+00, -4.4108e+00, -1.1753e+00, -8.4360e-01, -8.4744e+00,
         -6.5636e+00, -1.8633e+01,  7.2070e+00, -5.5883e+00, -3.2479e+00],
        [-3.2992e+00, -1.2103e+01, -7.9278e+00, -9.0289e+00, -5.0960e+00,
         -1.6826e+00,  1.0078e+01, -1.3026e+01, -6.2795e+00, -1.0634e+01],
        [-5.2756e+00, -9.5111e+00, -3.9890e+00,  5.4473e+00, -9.3138e+00,
         -2.7772e+00, -1.3795e+01, -9.9254e+00, -5.1113e+00, -3.2142e+00],
        [-2.0330e+00, -1.0971e+01, -9.1085e+00, -8.3798e+00, -3.4658e+00,
         -1.6662e+00,  9.9728e+0

tensor([[-11.4972, -21.8755,  -5.3326,  -6.1951,  -8.6491,  -2.5806, -16.5220,
          -8.7899,  12.3053,  -4.0108],
        [-13.3267, -16.2065,  -6.9065,  -6.0967,  -8.3103,  -4.4653, -12.7039,
         -11.0748,  11.6295,  -5.3451],
        [ -6.8919,  -9.6525,   1.7511,   0.5539, -14.3335,  -5.4650, -21.9504,
          11.9698,  -8.3843,  -5.1491],
        [ -5.5350,  10.2432,  -2.9298,  -8.6911,  -1.2030,  -5.3102,  -5.7938,
          -2.6887,  -2.4058,  -8.5609],
        [ -3.3290, -12.3476,   2.1931,   2.3262, -15.7363,  -4.9952, -19.0572,
           8.0486,  -9.0758,  -4.2188],
        [ -6.5846,   9.3771,  -2.9566,  -6.4852,  -0.7365,  -4.4229,  -6.0855,
          -3.4739,  -3.0143,  -7.9625],
        [ -6.9773,  10.3685,  -2.3761,  -8.1981,  -0.9208,  -6.4553,  -5.1853,
          -2.5669,  -2.2809,  -9.1733],
        [ 12.6817, -13.6603,  -3.5269, -10.6356,  -9.8564,  -8.2516,   0.1169,
         -15.9258,  -6.5842,  -6.3889],
        [-13.7852, -13.1069,  -9.7841,  11.7592,

tensor([[-8.0024e+00, -5.9207e+00,  1.6561e+01, -6.3767e+00, -7.1190e+00,
         -1.3549e+01, -1.1941e+01, -1.1092e+00, -8.9405e+00, -1.4476e+01],
        [-1.2887e+01, -1.2575e+01, -6.1427e+00,  1.4984e+01, -1.7884e+01,
          5.9896e-01, -1.5721e+01, -8.7847e+00, -3.1103e+00, -3.4241e+00],
        [-1.1455e+01, -8.4992e+00, -8.5864e+00, -1.6539e+01,  1.5759e+01,
         -1.0128e+01, -1.1755e+01, -1.3314e+00, -1.1888e+01, -7.5624e+00],
        [-7.7478e+00, -1.8923e+01, -1.4004e+01, -2.5513e+00, -1.1612e+01,
          1.5423e+01, -6.1793e+00, -1.7511e+01, -4.5860e-01, -5.4379e+00],
        [-6.1482e-01, -1.0198e+01, -4.5912e+00, -9.4751e+00, -2.4817e+00,
         -7.2389e+00,  1.0650e+01, -1.1475e+01, -6.6699e+00, -1.1133e+01],
        [-8.5380e+00, -2.1237e+00,  1.8021e+00, -7.7807e-02, -9.4053e+00,
         -1.2102e+01, -2.1503e+01,  1.1563e+01, -6.9985e+00, -1.0389e+00],
        [-1.2107e+01, -1.1748e+00, -9.1732e+00, -9.9420e+00, -3.2832e+00,
         -3.3341e+00, -7.7227e+0

tensor([[-12.5918, -10.5408,  -3.7406,  -0.7793,  -9.7176,  -3.5453, -23.0686,
          15.0314, -12.3610,  -2.1034],
        [ -3.5935,   7.4395,  -2.2996,  -4.9456,  -1.7123,  -5.5442,  -4.0905,
          -3.1227,  -0.2971,  -7.4237],
        [-12.2994, -10.5524,  -4.1147,  -0.3366,  -9.2256,  -2.5773, -21.8255,
          12.4734,  -9.9366,  -1.3924],
        [ -0.6282, -14.6112,  10.6654,  -3.2117, -12.4727, -12.4564, -11.1213,
          -6.1478,  -2.0310,  -9.0436],
        [ -3.2425, -11.6760,  14.4316,  -2.6924, -10.7111, -12.5587, -11.6637,
          -4.3488,  -2.6764, -10.7585],
        [ -4.1372, -17.9730,  -0.2612,   9.8569, -16.4592,  -5.4001,  -9.1358,
          -6.4989,  -2.6898,  -6.3401],
        [ -4.7775, -15.5623,  -9.6605, -13.2871,  -4.0007,  -2.1396,  15.0080,
         -13.1099,  -5.3068, -11.9263],
        [ -8.5620, -18.6144,  -2.7504,  -1.3442, -16.8230,  -5.6694, -12.6835,
         -15.1373,  14.4313,  -1.6071],
        [ -1.9520, -11.9378,   0.5487,   5.9110,

tensor([[ 7.3309e+00, -1.6100e+01, -1.3214e+00, -1.0507e+01, -7.3490e+00,
         -1.1292e+01, -6.3199e+00, -9.6663e+00, -3.3046e+00, -2.0865e+00],
        [ 1.4869e+01, -1.3723e+01,  1.5317e-01, -1.5328e+01, -1.3509e+01,
         -7.4453e+00,  2.2554e-01, -1.0510e+01, -2.1896e+00, -6.3573e+00],
        [-8.1260e+00,  1.0262e+01, -5.7046e+00, -9.1803e+00,  6.9052e-01,
         -5.8454e+00, -5.3599e+00, -1.8801e+00, -2.2289e+00, -6.6038e+00],
        [-5.2940e+00, -2.8694e+00,  1.3236e+01,  2.9708e+00, -1.9586e+01,
         -1.5262e+01, -1.3557e+01, -5.3770e+00, -3.9310e-01, -9.5519e+00],
        [-9.5631e+00, -1.2347e+01, -5.3046e+00,  1.4526e+01, -1.6456e+01,
          1.7186e+00, -1.3171e+01, -9.3844e+00, -4.6362e+00, -5.8969e+00],
        [-1.4602e+01, -6.3299e+00, -1.3315e+01, -1.4939e+01,  1.5567e+01,
         -5.2836e+00, -1.2724e+01,  4.9154e-01, -9.9792e+00, -7.5596e+00],
        [-6.7091e+00, -2.3044e+00, -6.8786e-01, -1.5989e-03, -8.3050e+00,
         -8.2862e+00, -2.2723e+0

In [None]:
# Save the model checkpoint
# Only the model's state_dict is written, not the module object itself.
checkpoint_path = 'model.ckpt'
torch.save(model.state_dict(), checkpoint_path)