In [1]:
import torch
from torchsummary import summary
from datasets import load_dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
device = 0 if torch.cuda.is_available() else "cpu"

In [3]:
def calculate_loss(input, target):
    # loss = F.cross_entropy(input, target)
    loss = F.nll_loss(input, target)
    return loss

In [10]:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.bn2 = nn.BatchNorm2d(16)
        self.bn3 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(288, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x = nn.AvgPool2d(2, stride=2)(x)
        x = F.relu(self.bn2(self.conv2(x)), inplace=True)
        x = nn.AvgPool2d(2, stride=2)(x)
        x = F.relu(self.bn3(self.conv3(x)), inplace=True)
        x = nn.AvgPool2d(2, stride=2)(x)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x, dim=1)


# class Net(nn.Module):
    
#     def __init__(self):
#         super(Net, self).__init__()
#         self.fc1 = nn.Linear(784,512)
#         self.fc2 = nn.Linear(512,256)
#         self.fc3 = nn.Linear(256,128)
#         self.fc4 = nn.Linear(128,10)
        
#     def forward(self, x):
#         x = x.view(-1,784)
#         x = F.relu(nn.LayerNorm(512)(self.fc1(x)))
#         x = F.relu(nn.LayerNorm(256)(self.fc2(x)))
#         x = F.relu(nn.LayerNorm(128)(self.fc3(x)))
#         x = self.fc4(x)
#         return F.softmax(x, dim=1)

In [5]:
from torch.utils.data import DataLoader
from torchvision.transforms import transforms, Compose

In [6]:
transform = Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
])

dataset = load_dataset("mnist")

def transform_(examples):
    examples['pixel_values'] = [transform(image.convert('L')) for image in examples['image']]
    del examples['image']
    return examples

dataset_with_transformed = dataset.with_transform(transform_)
train_data_loader = DataLoader(dataset_with_transformed['train'], batch_size=256, shuffle=True)
test_data_loader = DataLoader(dataset_with_transformed['test'], batch_size=16, shuffle=False)

Found cached dataset mnist (/home/ygq/.cache/huggingface/datasets/mnist/mnist/1.0.0/fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
from torchvision import datasets
trainsform_ = Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])

def get_train_dataloader():
    dataset = datasets.MNIST(root='../data/', train=True, transform=trainsform_, download=True)
    train_data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    return train_data_loader

def get_test_dataloader():
    dataset = datasets.MNIST(root='../data/', train=False, transform=trainsform_, download=True)
    test_data_loader = DataLoader(dataset, batch_size=1000, shuffle=False)
    return test_data_loader


In [19]:
def train_loop(epoch, model):
    model.train()
    train_data_loader = get_train_dataloader()
    print("Our model is training...")
    for batch_idx, (data, target) in enumerate(train_data_loader):
        optimizer.zero_grad()
        # data = batch['pixel_values'].to(device)
        # target = batch['label'].to(device)
        predict = model(data.cuda())
        loss = calculate_loss(predict, target.cuda())
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f"epoch {epoch}, batch_idx {batch_idx}, loss is {loss.item()}")
            
            
@torch.no_grad()
def test_loop(epoch, model):
    model.eval()
    test_data_loader = get_test_dataloader()
    print("Our model is testing...")
    counts = 0
    correct_counts = 0
    for batch_idx, (data, target) in enumerate(test_data_loader):
        # data = batch['pixel_values'].to(device)
        # target = batch['label'].to(device)
        predict = torch.argmax(model(data.cuda()), dim=1)
        counts += data.shape[0]
        correct_counts += sum(torch.where(predict == target.cuda(), 1, 0)).item()
    print(f"{epoch} accuracy is {correct_counts / counts}")

In [20]:
if __name__ == "__main__":
    epochs = 20
    model = Net()
    model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(epochs):
        train_loop(epoch, model)
        test_loop(epoch, model)


Our model is training...
epoch 0, batch_idx 0, loss is -0.09854584187269211
epoch 0, batch_idx 100, loss is -0.15113712847232819
epoch 0, batch_idx 200, loss is -0.3827395737171173
epoch 0, batch_idx 300, loss is -0.5021599531173706
epoch 0, batch_idx 400, loss is -0.6136916875839233
epoch 0, batch_idx 500, loss is -0.60197913646698
epoch 0, batch_idx 600, loss is -0.6772567629814148
epoch 0, batch_idx 700, loss is -0.7940143346786499
epoch 0, batch_idx 800, loss is -0.817430853843689
epoch 0, batch_idx 900, loss is -0.7458800673484802
epoch 0, batch_idx 1000, loss is -0.9080413579940796
epoch 0, batch_idx 1100, loss is -0.8226488828659058
epoch 0, batch_idx 1200, loss is -0.8863486051559448
epoch 0, batch_idx 1300, loss is -0.8287241458892822
epoch 0, batch_idx 1400, loss is -0.9047046303749084
epoch 0, batch_idx 1500, loss is -0.8200705051422119
epoch 0, batch_idx 1600, loss is -0.8367782235145569
epoch 0, batch_idx 1700, loss is -0.9089874625205994
epoch 0, batch_idx 1800, loss is -