In [1]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
import numpy as np
import random

np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 512
num_epochs = 15

train_dataset = FashionMNIST('./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

In [48]:
class SwiGLU(nn.Module):
    def __init__(self):
        super(SwiGLU, self).__init__()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        a, b = x.chunk(2, dim=-1)
        return a * self.sigmoid(b)
    


In [58]:
class MLP(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(input_dims, hidden_dims*2)
        self.bn1 = nn.BatchNorm1d(hidden_dims*2)
        self.layer2 = nn.Linear(hidden_dims, hidden_dims*2)
        self.bn2 = nn.BatchNorm1d(hidden_dims*2)
        self.layer3 = nn.Linear(hidden_dims, hidden_dims*2)
        self.bn3 = nn.BatchNorm1d(hidden_dims*2)
        self.output = nn.Linear(hidden_dims, output_dims)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                ## using he initialization
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                nn.init.constant_(m.bias, 0.0)


    def forward(self, x):
        x = nn.Flatten()(x)
        x = self.layer1(x)
        x = self.bn1(x)
        x = SwiGLU()(x)
        identity = x # skip connection
        x = self.layer2(x)
        x = self.bn2(x)
        x = SwiGLU()(x)
        x = self.layer3(x)
        x = self.bn3(x)
        x = SwiGLU()(x)
        x = x + identity
        out = self.output(x)

        return out

In [59]:
model = MLP(input_dims=784, hidden_dims=128, output_dims=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())


In [60]:
for epoch in range(num_epochs):    
    t_loss = 0
    t_acc = 0
    cnt = 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        
        t_loss += loss.item()
        t_acc += (torch.argmax(outputs, 1) == y).sum().item()
        cnt += len(y)

    t_loss /= len(train_loader)
    t_acc /= cnt
    print(f"Epoch {epoch+1}/{num_epochs}, Train_Loss: {t_loss:.4f}, Train_Acc: {t_acc:.4f}")

Epoch 1/15, Train_Loss: 0.5069, Train_Acc: 0.8222
Epoch 2/15, Train_Loss: 0.3406, Train_Acc: 0.8766
Epoch 3/15, Train_Loss: 0.2944, Train_Acc: 0.8920
Epoch 4/15, Train_Loss: 0.2615, Train_Acc: 0.9051
Epoch 5/15, Train_Loss: 0.2384, Train_Acc: 0.9118
Epoch 6/15, Train_Loss: 0.2174, Train_Acc: 0.9199
Epoch 7/15, Train_Loss: 0.2002, Train_Acc: 0.9267
Epoch 8/15, Train_Loss: 0.1837, Train_Acc: 0.9320
Epoch 9/15, Train_Loss: 0.1707, Train_Acc: 0.9371
Epoch 10/15, Train_Loss: 0.1551, Train_Acc: 0.9433
Epoch 11/15, Train_Loss: 0.1465, Train_Acc: 0.9456
Epoch 12/15, Train_Loss: 0.1326, Train_Acc: 0.9523
Epoch 13/15, Train_Loss: 0.1235, Train_Acc: 0.9543
Epoch 14/15, Train_Loss: 0.1113, Train_Acc: 0.9595
Epoch 15/15, Train_Loss: 0.1053, Train_Acc: 0.9621
