In [65]:
import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self, in_channels=1, out_features=10):
        super(SimpleNet, self).__init__()
        self.conv_layers = nn.Sequential(

            nn.Conv2d(in_channels=in_channels, out_channels=20, kernel_size=3),
            nn.ReLU(inplace=False),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=20, out_channels=25, kernel_size=3),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=25, out_channels=30, kernel_size=3),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=30, out_channels=35, kernel_size=3)
        )

        self.fcs = nn.Sequential(
            nn.Linear(in_features=1260, out_features=out_features)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.shape[0], -1)
        x = self.fcs(x)
        return x 

In [66]:
from torchvision.datasets import MNIST
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from time import sleep
from torchvision.transforms import ToTensor, Normalize, Compose

In [67]:
train_data = MNIST(root='./mnist_train.pt', train=True, download=True, transform=Compose([ToTensor(), Normalize(0, 1)]))
test_data = MNIST(root='./mnist_test.pt', train=False, download=True, transform=Compose([ToTensor(), Normalize(0, 1)]))

In [68]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
model = SimpleNet(in_channels=1, out_features=10).to(device)
optimizer = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()
BATCH_SIZE = 200
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)


train_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    num_workers=1
)

In [69]:
for epoch in range(1, 71):
    with tqdm(train_loader, unit="batch") as tepoch:
        for data, target in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            predictions = output.argmax(dim=1, keepdim=True).squeeze()
            loss = loss_fn(output, target)
            correct = (predictions == target).sum().item()
            accuracy = correct / BATCH_SIZE
            
            loss.backward()
            optimizer.step()

            tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
            sleep(0.1)
        scheduler.step()    

Epoch 1: 100%|██████████| 300/300 [01:01<00:00,  4.86batch/s, accuracy=97.5, loss=0.156] 
Epoch 2: 100%|██████████| 300/300 [01:03<00:00,  4.73batch/s, accuracy=98.5, loss=0.129] 
Epoch 3: 100%|██████████| 300/300 [01:03<00:00,  4.75batch/s, accuracy=99, loss=0.118]   
Epoch 4: 100%|██████████| 300/300 [01:04<00:00,  4.64batch/s, accuracy=99, loss=0.107]   
Epoch 5: 100%|██████████| 300/300 [01:02<00:00,  4.83batch/s, accuracy=99, loss=0.101]    
Epoch 6: 100%|██████████| 300/300 [01:00<00:00,  4.96batch/s, accuracy=99.5, loss=0.0973] 
Epoch 7: 100%|██████████| 300/300 [01:04<00:00,  4.68batch/s, accuracy=99.5, loss=0.0957] 
Epoch 8: 100%|██████████| 300/300 [01:02<00:00,  4.81batch/s, accuracy=99.5, loss=0.0936] 
Epoch 9: 100%|██████████| 300/300 [01:02<00:00,  4.79batch/s, accuracy=99.5, loss=0.0915] 
Epoch 10: 100%|██████████| 300/300 [01:01<00:00,  4.87batch/s, accuracy=99.5, loss=0.0906] 
Epoch 11: 100%|██████████| 300/300 [01:02<00:00,  4.81batch/s, accuracy=99.5, loss=0.0903] 
E

In [70]:
test_loader = DataLoader(
    dataset=test_data
)

In [71]:
total = 0
correct = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs.data, dim=1)
        total += labels.size(0)
        correct += (predictions == labels).sum().item()

print('Accuracy == {} %'.format(100 * correct / total))

Accuracy == 99.22 %


In [72]:
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

33600
