In [22]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader



In [23]:
class CustomDataset(Dataset):
    def __init__(self,data,labels):
        self.data = data
        self.labels = labels
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):
        return self.data[idx],self.labels[idx]

In [24]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((.1307,),(.3801,))])

In [25]:
train_dataset = datasets.MNIST('data',train=True,download=True, transform=transform)
test_dataset = datasets.MNIST('data',train=False,transform=transform)

In [26]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,num_workers=4)
test_loader = DataLoader(test_dataset,batch_size=32,shuffle=True,num_workers=4)

In [27]:
for images, labels in train_loader:
    print("Batch shape:" , images.shape)
    print("Labels shape", labels.shape)
    break

Batch shape: torch.Size([32, 1, 28, 28])
Labels shape torch.Size([32])


In [37]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1,32,3)
        self.pool = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(32*13*13,10)
        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.pool(self.relu(self.conv1(x)))
        x = x.view(-1,32*13*13)
        x = self.fc1(x)
        return x



model = ConvNet()

In [38]:
sample = next(iter(train_loader))[0][:1]

output = model(sample)
print("input shape:", sample.shape)
print("Output shape", output.shape)
print("Predicted digit:", output.argmax().item())

input shape: torch.Size([1, 1, 28, 28])
Output shape torch.Size([1, 10])
Predicted digit: 5


In [39]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

In [41]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_losses = []
    train_accs = []
    running_loss = 0.0
    running_acc = 0.0
    
    # Your existing training loop here
    for batch_idx, (data,labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        acc = accuracy(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_acc += acc
    
    # Save epoch metrics
    avg_loss = running_loss/len(train_loader)
    avg_acc = running_acc/len(train_loader)
    train_losses.append(avg_loss)
    train_accs.append(avg_acc)
    
    print(f'Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {avg_acc:.4f}')

Epoch 1: Loss = 0.1617, Accuracy = 0.9525
Epoch 2: Loss = 0.0683, Accuracy = 0.9791
Epoch 3: Loss = 0.0517, Accuracy = 0.9838
Epoch 4: Loss = 0.0403, Accuracy = 0.9876
Epoch 5: Loss = 0.0321, Accuracy = 0.9896
Epoch 6: Loss = 0.0272, Accuracy = 0.9914
Epoch 7: Loss = 0.0201, Accuracy = 0.9934
Epoch 8: Loss = 0.0169, Accuracy = 0.9946
Epoch 9: Loss = 0.0137, Accuracy = 0.9956
Epoch 10: Loss = 0.0107, Accuracy = 0.9967


In [42]:
torch.save(model.state_dict(), 'mnist_model.pth')
model.load_state_dict(torch.load('mnist_model.pth'))
model.eval()

# Test on new data
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total:.2f}%') 


  model.load_state_dict(torch.load('mnist_model.pth'))


Accuracy on test set: 98.31%
