In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import v2
import matplotlib.pyplot as plt

In [18]:
train_dataset = MNIST(root='./data', train=True, download=True)
print(f"Mean: {train_dataset.data.float().mean() / 255 :.4f}, Std: {train_dataset.data.float().std() / 255 :.4f}")

Mean: 0.1307, Std: 0.3081


In [7]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"

transform = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float32),
    v2.Normalize((0.1307,), (0.3081,)),
])

train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

simple_net = nn.Sequential(
    nn.Linear(784, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
simple_net.to(device)

optimizer = torch.optim.SGD(simple_net.parameters(), lr=6e-4, momentum=0.1)
predicted = []

num_epochs = 10
for epoch in range(num_epochs):
    for x, y in train_dataloader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        loss = F.cross_entropy(simple_net(x.view(-1, 784)), y)
        loss.backward()
        optimizer.step()


Mean: 0.13066047430038452, Std: 0.30810779333114624
40.347740173339844
24.37765121459961
16.33115577697754
15.093634605407715
13.008306503295898
8.981240272521973
8.369606018066406
7.229691505432129
6.735559463500977
5.913252830505371
5.736696243286133
5.391661167144775
5.295699119567871
3.7199862003326416
3.403721570968628
4.079446792602539
3.756119966506958
3.0613150596618652
2.8826122283935547
2.710071563720703
3.1743717193603516
3.089848756790161
1.7940576076507568
2.4353833198547363
1.654994010925293
3.312417984008789
1.715819239616394
1.7545311450958252
2.0808768272399902
2.2289843559265137
1.715102195739746
2.0075881481170654
1.9252630472183228
1.714141607284546
1.5232714414596558
1.7340434789657593
1.2624908685684204
1.7349687814712524
1.0181705951690674
1.4983348846435547
1.4755183458328247
1.3784185647964478
1.0657784938812256
2.1511497497558594
1.7339783906936646
0.7533106803894043
1.8691492080688477
1.6833820343017578
1.6166330575942993
1.1942081451416016
1.5218901634216309

In [14]:
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
test_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)


with torch.no_grad():
    for x, y in test_dataloader:
        x, y = x.to(device), y.to(device)
        # print(x_train.view(784))
        # print(F.softmax(simple_net(x_train.view(-1, 784)), dim=-1), y_train)
        # print(simple_net(x_train.view(-1, 784)).shape, y_train.shape)
        logits = simple_net(x.view(-1, 784))
        loss = F.cross_entropy(logits, y)
        print(loss.item(), logits.argmax().item(), y.item())
        # if loss.item() < 1:
        #     break

1.1920928244535389e-07 0 0
0.00038795097498223186 7 7
1.1920928244535389e-07 6 6
0.004955747164785862 8 8
0.12830333411693573 7 7
0.00010227633902104571 4 4
6.794906312279636e-06 1 1
2.3364747903542593e-05 6 6
0.0002836778585333377 5 5
-0.0 4 4
1.8596476365928538e-05 7 7
0.10055291652679443 3 3
0.022245166823267937 0 0
-0.0 0 0
1.6331539882230572e-05 8 8
8.701899787411094e-05 3 3
0.00012635385792236775 4 4
0.020606480538845062 7 7
0.004875912796705961 7 7
2.777537883957848e-05 8 8
8.809178689261898e-05 7 7
0.0001941730733960867 9 9
0.33869045972824097 7 7
0.013368846848607063 3 3
0.0001858300092862919 9 9
8.344646857949556e-07 1 1
4.494089080253616e-05 2 2
8.439661905867979e-05 1 1
0.0012756790965795517 7 7
0.00023016665363684297 5 5
0.0010982679668813944 5 5
-0.0 4 4
0.01916557177901268 5 5
0.04898957535624504 7 7
4.768370445162873e-07 1 1
1.1920928244535389e-07 9 9
3.814624506048858e-05 0 0
5.960462772236497e-07 7 7
-0.0 6 6
0.003937231842428446 7 7
-0.0 0 0
0.0006902219611220062 1 1

KeyboardInterrupt: 