In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

# Fetch MNIST into ./data (downloaded if absent); only digits 0 and 8 are used below.
# NOTE(review): the raw `.data` tensors are indexed directly further down, which
# presumably bypasses `transform` (hence the manual /255 scaling) - confirm.
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
mnist_test = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Positions of the samples labelled 0 or 8 in each split
train_is_0_or_8 = (mnist_train.targets == 0) | (mnist_train.targets == 8)
test_is_0_or_8 = (mnist_test.targets == 0) | (mnist_test.targets == 8)
train_indices = np.where(train_is_0_or_8)[0]
test_indices = np.where(test_is_0_or_8)[0]

# Pixels: select the kept rows, rescale to [0, 1], flatten each image to 784 values
train_data = (mnist_train.data[train_indices].float() / 255.0).view(-1, 28*28)
test_data = (mnist_test.data[test_indices].float() / 255.0).view(-1, 28*28)

# Labels: column vector of floats, 0.0 for digit '0' and 1.0 for digit '8'
train_targets = (mnist_train.targets[train_indices] == 8).float().view(-1, 1)
test_targets = (mnist_test.targets[test_indices] == 8).float().view(-1, 1)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:08<00:00, 1216849.57it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 230279.36it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:01<00:00, 945847.39it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [2]:
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    """Fully connected network with one hidden layer and a single sigmoid output.

    Both the hidden layer and the output unit use a sigmoid activation, so the
    network emits one value in [0, 1] per input row.

    Args:
        in_features: Length of each flattened input vector (default 28*28).
        hidden_features: Number of hidden units (default 300).
    """

    def __init__(self, in_features=28*28, hidden_features=300):
        super().__init__()
        # Layers are created in the same order as before (fc1, then fc2) so the
        # default initialisation consumes random numbers identically.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, 1)

    def forward(self, x):
        """Return sigmoid(fc2(sigmoid(fc1(x)))); the last dimension of the result is 1."""
        hidden = torch.sigmoid(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))

# One independent network per loss function (hinge, cross-entropy, exponential)
net_hinge, net_ce, net_exp = SimpleNN(), SimpleNN(), SimpleNN()

# Re-draw every Linear layer's weights from a zero-mean Gaussian with
# std = 1 / sqrt(fan_in + fan_out), and reset the biases to zero
for net in (net_hinge, net_ce, net_exp):
    linear_layers = [module for module in net.modules() if isinstance(module, nn.Linear)]
    for layer in linear_layers:
        weight_std = 1/np.sqrt(layer.in_features + layer.out_features)
        nn.init.normal_(layer.weight, mean=0, std=weight_std)
        nn.init.zeros_(layer.bias)


In [3]:
# Hinge loss for labels given as 0/1
def hinge_loss(output, target):
    """Mean hinge loss max(0, 1 - y * output) with y = 2 * target - 1.

    `target` holds 0/1 labels, which are remapped to -1/+1 before the margin
    is formed. The result is averaged over all elements.

    NOTE(review): if `output` is confined to (0, 1) (e.g. a sigmoid score), the
    margin never reaches 1, so the clamp never clips - confirm this is intended.
    """
    signed_target = 2 * target - 1
    margin = output * signed_target
    return (1 - margin).clamp(min=0).mean()

# Exponential loss for labels given as 0/1
def exponential_loss(output, target):
    """Mean of exp(-output * y) with y = 2 * target - 1.

    `target` holds 0/1 labels, which are remapped to -1/+1 so that a large
    output is rewarded for label 1 and penalised for label 0. The result is
    averaged over all elements.
    """
    signed_target = 2 * target - 1
    exponent = -output * signed_target
    return exponent.exp().mean()

# Binary cross-entropy comes straight from PyTorch; it expects probabilities
# in [0, 1] and averages over all elements ('mean' is the library default).
criterion_ce = nn.BCELoss(reduction='mean')


In [4]:
def train_network(net, criterion, optimizer, train_data, train_targets, epochs=10):
    """Train `net` with full-batch gradient steps and report the loss per epoch.

    Every epoch runs one forward pass over the whole of `train_data`, one
    backward pass and one optimizer step (there is no mini-batching).

    Args:
        net: Module to train; it is updated in place through `optimizer`.
        criterion: Callable `criterion(outputs, targets)` returning a scalar loss tensor.
        optimizer: Optimizer already bound to the parameters of `net`.
        train_data: Inputs passed to `net` as a single batch.
        train_targets: Targets passed to `criterion` alongside the outputs.
        epochs: Number of full-batch updates to perform (default 10).

    Returns:
        list[float]: The loss measured in each epoch, taken before that epoch's
        parameter update. Empty when `epochs` is 0. Callers that ignore the
        return value behave as before.
    """
    # Make sure train-only layers (dropout, batch norm, ...) are in training mode.
    net.train()
    loss_history = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(net(train_data), train_targets)
        loss.backward()
        optimizer.step()
        # Store a plain float so the history does not keep the autograd graph alive.
        loss_history.append(loss.item())
    return loss_history

# One Adam optimizer per network, all using the same step size
LEARNING_RATE = 0.001
optimizer_hinge = optim.Adam(net_hinge.parameters(), lr=LEARNING_RATE)
optimizer_ce = optim.Adam(net_ce.parameters(), lr=LEARNING_RATE)
optimizer_exp = optim.Adam(net_exp.parameters(), lr=LEARNING_RATE)




In [24]:
# Fit each network with its own loss and optimizer, in the order
# hinge -> cross-entropy -> exponential.
# NOTE(review): the execution count jumps from 4 to 24 at this cell, so it may
# have been run more than once; the networks and optimizers are not re-created
# here, so every re-run continues training from the current weights.
for model, loss_fn, model_optimizer in (
    (net_hinge, hinge_loss, optimizer_hinge),
    (net_ce, criterion_ce, optimizer_ce),
    (net_exp, exponential_loss, optimizer_exp),
):
    train_network(model, loss_fn, model_optimizer, train_data, train_targets)

In [25]:
# Score the test images with each trained network and report the error rates
def _hard_labels(scores):
    """Decision rule: 1.0 (digit 8) where the score exceeds 0.5, otherwise 0.0 (digit 0)."""
    return scores.gt(0.5).float()


def _misclassified_fraction(predictions):
    """Fraction of `predictions` that differ from `test_targets`."""
    return predictions.ne(test_targets).float().mean()


# Forward passes only - no gradients are needed for evaluation
with torch.no_grad():
    outputs_hinge, outputs_ce, outputs_exp = [
        model(test_data) for model in (net_hinge, net_ce, net_exp)
    ]

# Hard 0/1 predictions for each network
y_pred_hinge = _hard_labels(outputs_hinge)
y_pred_ce = _hard_labels(outputs_ce)
y_pred_exp = _hard_labels(outputs_exp)

# Share of test samples each network gets wrong
error_rate_hinge = _misclassified_fraction(y_pred_hinge)
error_rate_ce = _misclassified_fraction(y_pred_ce)
error_rate_exp = _misclassified_fraction(y_pred_exp)

print(f'Hinge Loss NN Error Rate: {error_rate_hinge.item()}')
print(f'Cross-Entropy NN Error Rate: {error_rate_ce.item()}')
print(f'Exponential Loss NN Error Rate: {error_rate_exp.item()}')


Hinge Loss NN Error Rate: 0.014329580590128899
Cross-Entropy NN Error Rate: 0.009723643772304058
Exponential Loss NN Error Rate: 0.019959058612585068
