In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Lambda

In [1]:
device = 'cuda'

In [3]:
transform_func = Compose([
        Resize((100, 100)),
        ToTensor(),
        Lambda(lambda x: x.repeat(3, 1, 1))
])

In [4]:
train_dataset = MNIST(root='data', train=True, transform=transform_func, download=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 12.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.22MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.08MB/s]


In [5]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [6]:
model = nn.Sequential(
    nn.Conv2d(3, 32, 3),
    nn.MaxPool2d(3),
    nn.Flatten(),
    nn.LazyLinear(10)
).to(device)

In [7]:
model

Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (2): Flatten(start_dim=1, end_dim=-1)
  (3): LazyLinear(in_features=0, out_features=10, bias=True)
)

In [8]:
optimizer = Adam(model.parameters(), lr=0.0001)

In [9]:
criterion = nn.CrossEntropyLoss()

In [10]:
def train(model, train_loader, optimizer, criterion, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            step_loss = loss.item()
            total_loss += step_loss
            print(f'Epoch {epoch + 1}/{num_epochs}, Step Loss: {step_loss}, Total Loss: {total_loss:.4f}')

In [11]:
train(model, train_loader, optimizer, criterion, 5)

Epoch 1/5, Step Loss: 2.3444550037384033, Total Loss: 2.3445
Epoch 1/5, Step Loss: 2.2937450408935547, Total Loss: 4.6382
Epoch 1/5, Step Loss: 2.2514564990997314, Total Loss: 6.8897
Epoch 1/5, Step Loss: 2.0929720401763916, Total Loss: 8.9826
Epoch 1/5, Step Loss: 2.057987928390503, Total Loss: 11.0406
Epoch 1/5, Step Loss: 2.015839099884033, Total Loss: 13.0565
Epoch 1/5, Step Loss: 1.9165098667144775, Total Loss: 14.9730
Epoch 1/5, Step Loss: 1.901453971862793, Total Loss: 16.8744
Epoch 1/5, Step Loss: 1.8035236597061157, Total Loss: 18.6779
Epoch 1/5, Step Loss: 1.7552063465118408, Total Loss: 20.4331
Epoch 1/5, Step Loss: 1.7303587198257446, Total Loss: 22.1635
Epoch 1/5, Step Loss: 1.5902764797210693, Total Loss: 23.7538
Epoch 1/5, Step Loss: 1.6815308332443237, Total Loss: 25.4353
Epoch 1/5, Step Loss: 1.4099233150482178, Total Loss: 26.8452
Epoch 1/5, Step Loss: 1.4625771045684814, Total Loss: 28.3078
Epoch 1/5, Step Loss: 1.428359031677246, Total Loss: 29.7362
Epoch 1/5, Step 