In [1]:
import numpy as np
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 

In [2]:
class LeNet(nn.Module):

    def __init__(self, n_classes, in_channel):
        super().__init__()
        self.n_classes = n_classes
        self.in_channel = in_channel

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channel, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            nn.ReLU()
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(),
            nn.Linear(in_features=84, out_features=self.n_classes),
            nn.Softmax()
        )

    def forward(self, X):
        X = self.feature_extractor(X)
        X = torch.flatten(X, 1)
        probs = self.classifier(X)
        return probs

In [3]:
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor()])

In [4]:
train_dataset = datasets.MNIST(root='mnist_data', 
                               train=True, 
                               transform=transform,
                               download=True)

valid_dataset = datasets.MNIST(root='mnist_data', 
                               train=False, 
                               transform=transform)

train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=32, 
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset, 
                          batch_size=32, 
                          shuffle=False)

In [5]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7fbafb42bcd0>

In [6]:
model = LeNet(n_classes=10, in_channel=1)

In [10]:
for param in model.parameters():
    print(param.numel())

150
6
2400
16
48000
120
10080
84
840
10


In [9]:
for X_train, y_train in train_loader:
    

SyntaxError: expected ':' (2522566847.py, line 1)