In [None]:
import torch
from torchvision import datasets
from torchvision.transforms import Lambda, ToTensor
from torch.utils.data import DataLoader


def _one_hot_target(label):
  """Encode an integer class label as a length-10 float one-hot vector.

  Used as the MNIST ``target_transform``; ``label`` is assumed to be an int in
  0..9 (an out-of-range index makes ``scatter_`` raise).
  """
  encoded = torch.zeros(10, dtype=torch.float)
  # Write 1 at position `label` along dim 0; every other entry stays 0.
  return encoded.scatter_(0, torch.tensor(label), value=1)


one_hot_encoding = Lambda(_one_hot_target)

# Both MNIST splits share every option except `train`, so build them with one
# helper: images go through ToTensor, labels through the one-hot transform.
def _load_mnist(train):
  """Download (if needed) and return the MNIST train (True) or test (False) split."""
  return datasets.MNIST(
    "data",
    train=train,
    download=True,
    transform=ToTensor(),
    target_transform=one_hot_encoding,
  )


data_train = _load_mnist(train=True)
data_test = _load_mnist(train=False)

# Shuffle only the training data. The evaluation order stays fixed so that
# batch composition (and thus the reported test metrics) is reproducible.
loader_train = DataLoader(dataset=data_train, batch_size=32, shuffle=True)
loader_test = DataLoader(dataset=data_test, batch_size=32, shuffle=False)

In [None]:
import matplotlib.pyplot as plt


# Show one training sample with its label. squeeze() drops the singleton
# channel dimension so imshow receives a 2-D image; the target is one-hot, so
# argmax recovers the class index.
sample_image, sample_target = data_train[2]
fig, ax = plt.subplots()
ax.imshow(sample_image.squeeze(), cmap="gray_r")
ax.set_title(f"Training sample #2 - label {sample_target.argmax().item()}")
ax.axis("off");

In [None]:
from torch import nn


class Model(nn.Module):
  """Fully connected MNIST classifier: 784 -> 128 -> 32 -> 10.

  ``forward`` returns raw logits (unnormalized class scores). The notebook
  trains with ``nn.CrossEntropyLoss``, which applies log-softmax internally;
  a trailing Softmax layer would normalize twice and squash the gradients.
  Use ``torch.softmax(logits, dim=1)`` when class probabilities are needed.
  """

  def __init__(self):
    super().__init__()
    # Collapses every dimension after the batch dimension into one vector.
    self._flatten = nn.Flatten()
    self._model = nn.Sequential(
      nn.Linear(28*28, 128),
      nn.ReLU(),
      nn.Dropout(0.5),  # active only in train mode; model.eval() disables it
      nn.Linear(128, 32),
      nn.Sigmoid(),
      nn.Linear(32, 10),  # logits -- intentionally no Softmax (see docstring)
    )

  def forward(self, x):
    """Map a batch of 28x28 images to logits of shape (batch, 10)."""
    return self._model(self._flatten(x))
  

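# NOTE: np.random.seed / tf.random.set_seed do not affect PyTorch and are not
# used here; this notebook's randomness is seeded with torch.manual_seed(666).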

In [None]:
# Seed before constructing the model so the initial weights -- and therefore
# every result below -- are reproducible from a fresh kernel.
torch.manual_seed(666)
model = Model()
print(model)

In [None]:
from torch import nn, optim


# Adam with its default hyperparameters over all model weights, and
# cross-entropy between the model's 10 outputs and the one-hot targets.
optimizer = optim.Adam(params=model.parameters())
loss_function = nn.CrossEntropyLoss()

def train():
  """Run one epoch of optimization over ``loader_train``.

  Uses the module-level ``model``, ``loss_function`` and ``optimizer`` and
  prints the current batch loss every 100 batches.
  """
  # Put training-mode layers (Dropout) back on; an evaluation pass may have
  # left the model in eval mode.
  model.train()
  for batch, (X, y) in enumerate(loader_train):
    pred = model(X)
    loss = loss_function(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch % 100 == 0:
      print(f"Batch {batch} -> loss: {loss.item()}")


def test():
  """Evaluate the module-level ``model`` on ``loader_test`` and print the metrics.

  Returns:
    (test_loss, accuracy) as Python floats: the per-sample mean loss and the
    fraction of samples whose argmax prediction matches the one-hot target.
  """
  was_training = model.training
  # Dropout must be off while measuring; restore the caller's mode afterwards
  # so a following train() call is unaffected.
  model.eval()
  total_loss = 0.0
  correct = 0
  try:
    with torch.no_grad():
      for X, y in loader_test:
        pred = model(X)
        # loss_function (default 'mean' reduction) returns a batch mean; weight
        # it by batch size so a short final batch doesn't skew the average.
        total_loss += loss_function(pred, y).item() * len(X)
        correct += (pred.argmax(dim=1) == y.argmax(dim=1)).sum().item()
  finally:
    model.train(was_training)
  n_samples = len(loader_test.dataset)
  test_loss = total_loss / n_samples
  accuracy = correct / n_samples
  print(f"Test -> Loss: {test_loss}, accuracy: {accuracy}")
  return test_loss, accuracy


In [None]:
# Seed the global RNG so batch shuffling and Dropout masks are reproducible,
# then alternate one training epoch with one evaluation pass.
torch.manual_seed(666)

num_epochs = 20
for epoch in range(1, num_epochs + 1):
  print(f"Epoch: {epoch}")
  train()
  test()


In [None]:
# Final evaluation of the trained model on the test split.
test();