In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchmetrics
from torchmetrics.classification import Accuracy
from tqdm import tqdm

In [2]:
# Preprocessing applied to every image: convert it to a tensor, then shift and
# scale the single channel with (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
# MNIST train / test splits, downloaded into ./data on first use and passed
# through the shared `transform` pipeline.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Training batches are reshuffled every epoch; the test set is read in a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

In [4]:
class LSTMModel(nn.Module):
    """Sequence classifier: stacked LSTM followed by a linear layer.

    Each sample is a sequence of feature vectors of length ``input_size``.
    The top LSTM layer's output at the final time step is projected to
    ``num_classes`` logits.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=28, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)`` (the LSTM is
                constructed with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        # nn.LSTM starts from all-zero hidden and cell states when none are
        # supplied, and creates them with the input's device and dtype. Passing
        # hand-built float32 zeros (as before) cost a CPU allocation plus a
        # device copy per call and broke for non-float32 models.
        out, _ = self.lstm(x)

        # Classify from the last time step of the top layer only.
        return self.fc(out[:, -1, :])
    

In [6]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): no random seed is set in the visible cells, so weight
# initialisation and batch order (and therefore the final accuracy) vary
# between runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs=2
epoch_bar = tqdm(range(epochs), desc="Epoch")
for epoch in epoch_bar:
    model.train()
    # Accumulate the loss on the training device so that no host/device
    # synchronisation is forced on every step.
    running_loss = torch.zeros((), device=device)
    for images, labels in train_loader:
        # squeeze(1) removes the size-1 channel dimension so that the LSTM
        # (batch_first, input_size=28) receives a 3-D (batch, seq, feature) tensor.
        images = images.squeeze(1).to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()

    # Show the mean training loss of the finished epoch on the progress bar
    # instead of printing a line every few steps.
    mean_loss = (running_loss / len(train_loader)).item()
    epoch_bar.set_postfix(loss=f"{mean_loss:.4f}")

Epoch: 100%|██████████| 2/2 [00:13<00:00,  6.78s/it]


In [12]:
# Evaluate on the held-out test set: put the model in evaluation mode and
# accumulate multiclass accuracy across all batches.
model.eval()
accuracy = Accuracy(task='multiclass', num_classes=10).to(device)

# Gradients are not needed for evaluation.
with torch.no_grad():
    for images, labels in test_loader:
        # Same preprocessing as in training: drop the size-1 channel dimension.
        images = images.squeeze(1).to(device)
        labels = labels.to(device)

        outputs = model(images)
        # update() only accumulates batch statistics; the metric value is
        # produced once by compute() after the loop.
        accuracy.update(outputs, labels)

final_accuracy= accuracy.compute()

print(f"Accuracy of the model on the test images: {final_accuracy  * 100: 0.2f}")

Accuracy of the model on the test images:  98.08
