In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DefaultLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(DefaultLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import random
import numpy as np

def set_seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

class SimpleNet(nn.Module):
    def __init__(self, linear_module):
        super(SimpleNet, self).__init__()
        self.linear = linear_module(784, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)

def train_and_evaluate(linear_module):
    model = SimpleNet(linear_module)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    for epoch in range(10):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = correct / len(test_loader.dataset)
        print(f'Epoch {epoch + 1}: Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}')

set_seed()
print("Training with nn.Linear:")
train_and_evaluate(nn.Linear)

set_seed()
print("Training with DefaultLinear:")
train_and_evaluate(DefaultLinear)

Training with nn.Linear:
Epoch 1: Test Loss: 0.0055, Accuracy: 0.8971
Epoch 2: Test Loss: 0.0048, Accuracy: 0.9102
Epoch 3: Test Loss: 0.0047, Accuracy: 0.9144
Epoch 4: Test Loss: 0.0047, Accuracy: 0.9135
Epoch 5: Test Loss: 0.0046, Accuracy: 0.9163
Epoch 6: Test Loss: 0.0044, Accuracy: 0.9223
Epoch 7: Test Loss: 0.0047, Accuracy: 0.9123
Epoch 8: Test Loss: 0.0049, Accuracy: 0.9091
Epoch 9: Test Loss: 0.0050, Accuracy: 0.9095
Epoch 10: Test Loss: 0.0047, Accuracy: 0.9183
Training with DefaultLinear:
Epoch 1: Test Loss: 0.0055, Accuracy: 0.8971
Epoch 2: Test Loss: 0.0048, Accuracy: 0.9102
Epoch 3: Test Loss: 0.0047, Accuracy: 0.9144
Epoch 4: Test Loss: 0.0047, Accuracy: 0.9135
Epoch 5: Test Loss: 0.0046, Accuracy: 0.9163
Epoch 6: Test Loss: 0.0044, Accuracy: 0.9223
Epoch 7: Test Loss: 0.0047, Accuracy: 0.9123
Epoch 8: Test Loss: 0.0049, Accuracy: 0.9091
Epoch 9: Test Loss: 0.0050, Accuracy: 0.9095
Epoch 10: Test Loss: 0.0047, Accuracy: 0.9183
