In [2]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms

In [29]:
# hyper-parameters
input_features = 784
num_classes = 10
num_epoch = 5
batch_size = 100
learning_rate = 0.002
l1_weight = 0.001
l2_weight = 0.001

# load dataset
train_data = torchvision.datasets.MNIST(root='/data', train=True, download=True, transform=transforms.ToTensor())
test_data = torchvision.datasets.MNIST(root='/data', train=False, transform=transforms.ToTensor())

# initiate data loader
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size)

In [42]:
class LogisticRegression(nn.Module):
    def __init__(self, input_features, num_classes, num_epoch, learning_rate, l1_weight=0, l2_weight=0):
        super(LogisticRegression, self).__init__()
        self.model = nn.Linear(input_features, num_classes)
        self.input_features = input_features
        self.num_epoch = num_epoch
        self.learning_rate = learning_rate
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
    
    def train(self, train_loader, model_name='logistic_regression.model', output_log_freq=0):
        """
        Train the model with given train_loader. Save the model if model name specified.
        """
        total = len(train_loader)
        for e in range(self.num_epoch):
            for i, (instances, labels) in enumerate(train_loader):
                instances = instances.reshape(-1, self.input_features)
                # Forward
                output = self.model(instances)
                # Calculate loss
                params = torch.cat([x.view(-1) for x in self.model.parameters()])
                l1_loss = 0 if self.l1_weight == 0 else torch.norm(params, 1)
                l2_loss = 0 if self.l2_weight == 0 else torch.norm(params, 2)
                loss = self.criterion(output, labels) + self.l1_weight * l1_loss + self.l2_weight * l2_loss
                # Update weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if output_log_freq and (i + 1) % output_log_freq == 0:
                    print('Epoch %d/%d, trained %d/%d instances, Logloss: %.5f' % 
                          (e, self.num_epoch, i + 1, total, loss.item()))
        if model_name:
            torch.save(self.model.state_dict(), model_name)
            
    def predict(self, instances):
        """
        Predict the label with given training instance batch.
        """
        with torch.no_grad():
            instances = instances.reshape(-1, self.input_features)
            output = self.model(instances)  # tensor with dim [batch_size, 10] 
            return torch.max(output.data, 1)[1]  # idx of the max element for each instance indicates the class

In [43]:
# Train
lr = LogisticRegression(input_features, num_classes, num_epoch, learning_rate, l1_weight, l2_weight)
lr.train(train_loader, output_log_freq=100)
# Evaluate
correct, total = 0, 0
for images, labels in test_loader:
    total += labels.size(0)
    predicted = lr.predict(images)
    correct += (predicted == labels).sum()

print('Accuracy: %d/%d' % (correct, total))

Epoch 0/5, trained 100/600 instances, Logloss: 2.27502
Epoch 0/5, trained 200/600 instances, Logloss: 2.09551
Epoch 0/5, trained 300/600 instances, Logloss: 1.94368
Epoch 0/5, trained 400/600 instances, Logloss: 1.82699
Epoch 0/5, trained 500/600 instances, Logloss: 1.72551
Epoch 0/5, trained 600/600 instances, Logloss: 1.59623
Epoch 1/5, trained 100/600 instances, Logloss: 1.50132
Epoch 1/5, trained 200/600 instances, Logloss: 1.50830
Epoch 1/5, trained 300/600 instances, Logloss: 1.40051
Epoch 1/5, trained 400/600 instances, Logloss: 1.28609
Epoch 1/5, trained 500/600 instances, Logloss: 1.29996
Epoch 1/5, trained 600/600 instances, Logloss: 1.23473
Epoch 2/5, trained 100/600 instances, Logloss: 1.32620
Epoch 2/5, trained 200/600 instances, Logloss: 1.19099
Epoch 2/5, trained 300/600 instances, Logloss: 1.22588
Epoch 2/5, trained 400/600 instances, Logloss: 1.10004
Epoch 2/5, trained 500/600 instances, Logloss: 1.22927
Epoch 2/5, trained 600/600 instances, Logloss: 1.14389
Epoch 3/5,