In [28]:
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from time import time

In [29]:
train_data = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True
)

test_data = datasets.MNIST(
    root = 'data',
    train = False,
    transform = ToTensor(),
    download = True,
)

train_data.data = train_data.data.type(torch.float32)
train_data.targets = train_data.targets.type(torch.float32)

test_data.data = test_data.data.type(torch.float32)
test_data.targets = test_data.targets.type(torch.float32)

In [30]:

device = torch.device('cuda')


def fix_input(x):
    return x/255.0

# train_loader = DataLoader(dataset=train_data, batch_size=10, shuffle=False, num_workers=1)

# for idx, (data, target) in enumerate(train_loader):
#     print(data[0][0].shape)

def conv(y):
    arr = [0]*10
    arr[int(y.item())] = 1.0
    return torch.tensor(arr)


class CustomDataSet:
    def __init__(self, data, targets, transform = None):
        self.data = data
        self.targets = targets
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        target = conv(self.targets[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample, target
    
customData = CustomDataSet(data=train_data.data, targets=train_data.targets, transform=fix_input)
train_loader = DataLoader(dataset=customData, batch_size=10, shuffle=True, num_workers=1)

class NeuralNet(nn.Module):

    def __init__(self, lr):
        # torch.manual_seed(1)
        super().__init__()
        self.c1 = nn.Conv2d(1, 10, 3)
        self.m1 = nn.MaxPool2d(2)
        self.c2 = nn.Conv2d(10, 20, 3)
        self.m2 = nn.MaxPool2d(2)
        self.f1 = nn.Linear(20 * 5 * 5, 50)  # Adjusted based on convolution and pooling operations
        self.f2 = nn.Linear(50, 10)

        self.weights = [self.c1.weight, self.f1.weight, self.f2.weight]
        self.biases = [self.c1.bias, self.f1.bias, self.f2.bias]

        self.lr = lr
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.weights + self.biases, lr = self.lr)

    def forward(self, x):
        x = torch.relu(self.c1(x))
        x = self.m1(x)
        x = torch.relu(self.c2(x))
        x = self.m2(x)
        x = x.view(-1, 20 * 5 * 5)  # Flatten the tensor for fully connected layers
        x = torch.relu(self.f1(x))
        x = self.f2(x)
        return x

    def conv(self, y):
        arr = [0]*10
        arr[int(y.item())] = 1.0
        return torch.tensor([arr], device=device)

    def predict(self, y):
        m = 0
        I = 0
        for i in range(10):
            if y[0][i] > m:
                I = i
                m = y[0][i]
        return I

    def train(self, training_loader, testing_data):
        testing_data.data, testing_data.targets = testing_data.data.to(device), testing_data.targets.to(device)
        for epoch in range(20):
            for idx, (data, targets) in enumerate(training_loader):
                data, targets = data.to(device), targets.to(device)
                self.optimizer.zero_grad()
                targets = targets.type(torch.float32)
                y_pred = self(data.view([10, 1, 28, 28]))
                y = targets
                l = self.loss(y_pred, y)
                l.backward()
                self.optimizer.step()
            correct = 0
            for i in range(len(testing_data.data)):
                y = self.forward(fix_input(testing_data.data[i]).view(1, 28, 28))
                z = self.predict(y)
                if z == testing_data.targets[i]:
                    correct += 1
            
            print(f"epoch {epoch}: {correct}/{len(testing_data.data)}")
        

n = NeuralNet(lr=0.01)
n.to(device)
n.train(train_loader, test_data)

epoch 0: 9375/10000
epoch 1: 9596/10000
epoch 2: 9627/10000
epoch 3: 9721/10000
epoch 4: 9766/10000
epoch 5: 9771/10000
epoch 6: 9774/10000
epoch 7: 9767/10000
epoch 8: 9803/10000
epoch 9: 9825/10000
epoch 10: 9833/10000
epoch 11: 9839/10000
epoch 12: 9830/10000
epoch 13: 9843/10000
epoch 14: 9835/10000
epoch 15: 9825/10000
epoch 16: 9836/10000
epoch 17: 9850/10000
epoch 18: 9834/10000
epoch 19: 9849/10000
