In [244]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

import os

In [245]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# hyper-parameters
input_features = 784  # one flattened 28x28 MNIST image
num_classes = 10
num_epoch = 1
batch_size = 100
learning_rate = 0.002
l1_weight = 0.001
l2_weight = 0.001
seed = 0
data_root = '/tmp/data'

# Fix the RNG so weight initialisation and DataLoader shuffling are reproducible across runs.
torch.manual_seed(seed)

# load dataset (download=True on both splits so neither depends on the other having run first)
train_data = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=transforms.ToTensor())
test_data = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=transforms.ToTensor())

# initiate data loader
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size)

cuda:0


In [246]:
# Define a logistic regression
class LogisticRegression(nn.Module):
    """
    Softmax (multinomial logistic) regression: one linear layer mapping a flattened input
    to one logit per class. Pair it with nn.CrossEntropyLoss for training.

    Training hyper-parameters (epoch count, learning rate, L1/L2 penalty weights) are stored
    on the instance so a generic training loop can read them from the model.
    """

    def __init__(self, input_features, num_classes, num_epoch, learning_rate, l1_weight=0, l2_weight=0):
        super(LogisticRegression, self).__init__()
        self.model = nn.Linear(input_features, num_classes)
        self.input_features = input_features
        self.num_classes = num_classes
        self.num_epoch = num_epoch
        self.learning_rate = learning_rate
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight

    def forward(self, instances):
        """
        Compute class logits for a batch of instances.

        instances: tensor whose elements can be regrouped into rows of `input_features`
            values (e.g. a batch of images); it is flattened accordingly and moved to the
            device holding this layer's weights, so CPU input works with a GPU model.
        Returns: tensor of shape [batch_size, num_classes] with unnormalised scores; take
            argmax over dim 1 for the predicted class, or softmax for probabilities.
        """
        # Follow the parameters rather than a global device so the module works wherever it lives.
        weight_device = self.model.weight.device
        instances = instances.reshape(-1, self.input_features).to(weight_device)
        return self.model(instances)
        

In [247]:
def train(model, criterion, optimizer, train_loader, model_name='logistic_regression.model', num_epoch=5, output_log_freq=0):
    """
    Train `model` on `train_loader`, with optional L1/L2 penalties on all of its parameters.

    If `model_name` names an existing file, its state dict is loaded first, so repeated
    calls resume from the previous checkpoint. After training the state dict is saved to
    `model_name`; pass None or '' to skip both loading and saving.

    Args:
        model: module exposing `input_features`; optional `l1_weight` / `l2_weight`
            attributes (default 0) enable the corresponding penalty independently.
        criterion: callable(output, labels) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        train_loader: iterable of (instances, labels) batches that supports len().
        model_name: checkpoint path, or None/'' to disable checkpointing.
        num_epoch: number of passes over `train_loader`.
        output_log_freq: print the loss every this many batches (0 disables logging).
    """
    total = len(train_loader)
    l1_weight = getattr(model, "l1_weight", 0)
    l2_weight = getattr(model, "l2_weight", 0)
    # Batches (and a loaded checkpoint) go wherever the model's parameters already live.
    device = next(model.parameters()).device
    if model_name and os.path.exists(model_name):
        print("load model")
        model.load_state_dict(torch.load(model_name, map_location=device))

    for e in range(num_epoch):
        for i, (instances, labels) in enumerate(train_loader):
            instances = instances.reshape(-1, model.input_features).to(device)
            labels = labels.to(device)
            # Forward
            output = model(instances)
            # Calculate loss
            loss = criterion(output, labels)
            # Each penalty applies whenever its own weight is positive.
            if l1_weight > 0 or l2_weight > 0:
                params = torch.cat([x.view(-1) for x in model.parameters()])
                if l1_weight > 0:
                    loss = loss + l1_weight * torch.norm(params, 1)
                if l2_weight > 0:
                    # NOTE: unsquared L2 norm (not the squared norm used by weight decay).
                    loss = loss + l2_weight * torch.norm(params, 2)
            # Update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if output_log_freq and (i + 1) % output_log_freq == 0:
                print('Epoch %d/%d, trained %d/%d batches, Logloss: %.5f' %
                      (e + 1, num_epoch, i + 1, total, loss.item()))
    if model_name:
        print("write model")
        torch.save(model.state_dict(), model_name)

In [248]:
# Train LR
lr = LogisticRegression(input_features, num_classes, num_epoch, learning_rate, l1_weight, l2_weight).to(device)

# train the model
criterion = nn.CrossEntropyLoss()
lr_optimizer = torch.optim.SGD(lr.parameters(), lr=lr.learning_rate)

# Use the configured epoch count instead of a hard-coded literal.
train(lr, criterion, lr_optimizer, train_loader, num_epoch=lr.num_epoch, output_log_freq=100)


# Evaluate without building autograd graphs; count correct predictions as a plain int.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        total += labels.size(0)
        predicted = lr(images)  # call the module (not .forward) so hooks run
        correct += (predicted.argmax(dim=1) == labels.to(device)).sum().item()

print('Accuracy: %d/%d' % (correct, total))

load model
Epoch 0/1, trained 100/600 instances, Logloss: 0.82282
Epoch 0/1, trained 200/600 instances, Logloss: 0.77052
Epoch 0/1, trained 300/600 instances, Logloss: 0.71091
Epoch 0/1, trained 400/600 instances, Logloss: 0.74702
Epoch 0/1, trained 500/600 instances, Logloss: 0.67888
Epoch 0/1, trained 600/600 instances, Logloss: 0.81556
write model
Accuracy: 8805/10000


In [257]:
# Simple CNN model
class Net(nn.Module):
    """
    Small LeNet-style CNN: two (conv 3x3 -> ReLU -> 2x2 max-pool) stages, then three
    fully connected layers ending in `num_classes` logits.

    fc1 expects 400 = 16 * 5 * 5 features, which is what a 1x28x28 input produces after
    the two stages (28 -> 26 -> 13 -> 11 -> 5 spatially).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Feature extractor: 1 -> 6 -> 16 channels, 3x3 kernels, no padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
        # Classifier head: 400 -> 120 -> 60 -> num_classes.
        self.fc1 = nn.Linear(in_features=400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.fc3 = nn.Linear(in_features=60, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape [batch, num_classes]."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2)
        # Flatten every dimension except the batch dimension.
        x = x.view(-1, self.num_flat_features(x))
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Number of elements per sample: the product of all dimensions after the first."""
        return x.size()[1:].numel()
    
# Inspect the net structure: module summary, then every state-dict entry with its shape,
# then the shapes of the trainable parameters in registration order.
net = Net(num_classes)
print(net)
for name, tensor in net.state_dict().items():
    print(name, "\t", tensor.size())
print([p.shape for p in net.parameters()])

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (fc3): Linear(in_features=60, out_features=10, bias=True)
)
conv1.weight 	 torch.Size([6, 1, 3, 3])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 3, 3])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([60, 120])
fc2.bias 	 torch.Size([60])
fc3.weight 	 torch.Size([10, 60])
fc3.bias 	 torch.Size([10])
[torch.Size([6, 1, 3, 3]), torch.Size([6]), torch.Size([16, 6, 3, 3]), torch.Size([16]), torch.Size([120, 400]), torch.Size([120]), torch.Size([60, 120]), torch.Size([60]), torch.Size([10, 60]), torch.Size([10])]


In [250]:
# Train MNIST
net_criterion = nn.CrossEntropyLoss()
# Move the network to the training device before constructing its optimizer, as the
# PyTorch optimizer docs recommend (train_net also moves it to `device`).
net = net.to(device)
# The CNN uses a step 10x smaller than the logistic regression's learning_rate.
net_optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate / 10)


def train_net(model, criterion, optimizer, train_loader, model_name='net.model', num_epoch=5, output_log_freq=0,
              target_device=None):
    """
    Train `model` on `train_loader`, optionally resuming from and saving to a checkpoint.

    If `model_name` names an existing file, its state dict is loaded before training, so
    repeated calls continue from the last saved weights. After training the state dict is
    written to `model_name`; pass None or '' to disable both loading and saving.

    Args:
        model: module mapping a batch of instances to class scores; it is moved (in place,
            per nn.Module.to) to the training device.
        criterion: callable(output, labels) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        train_loader: iterable of (instances, labels) batches that supports len().
        model_name: checkpoint path, or None/'' to disable checkpointing.
        num_epoch: number of passes over `train_loader`.
        output_log_freq: print the loss every this many batches (0 disables logging).
        target_device: device to train on; defaults to the notebook-level `device`.
    """
    # Only consult the global `device` when the caller did not pick one explicitly.
    run_device = target_device if target_device is not None else device
    total = len(train_loader)
    # Guard against None/'' before touching the filesystem (os.path.exists(None) raises TypeError).
    if model_name and os.path.exists(model_name):
        print("load model")
        model.load_state_dict(torch.load(model_name, map_location=run_device))
    model = model.to(run_device)

    for e in range(num_epoch):
        for i, (instances, labels) in enumerate(train_loader):
            instances = instances.to(run_device)
            labels = labels.to(run_device)
            # Forward
            output = model(instances)
            # Calculate loss
            loss = criterion(output, labels)
            # Update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if output_log_freq and (i + 1) % output_log_freq == 0:
                print('Epoch %d/%d, trained %d/%d batches, Logloss: %.5f' %
                      (e + 1, num_epoch, i + 1, total, loss.item()))
    if model_name:
        torch.save(model.state_dict(), model_name)
                
train_net(
    net,
    net_criterion,
    net_optimizer,
    train_loader,
    model_name="net.model",  # loaded first if it already exists, then overwritten
    num_epoch=5,
    output_log_freq=100,
)

load model
Epoch 0/5, trained 100/600 instances, Logloss: 0.09263
Epoch 0/5, trained 200/600 instances, Logloss: 0.05441
Epoch 0/5, trained 300/600 instances, Logloss: 0.06278
Epoch 0/5, trained 400/600 instances, Logloss: 0.04827
Epoch 0/5, trained 500/600 instances, Logloss: 0.02453
Epoch 0/5, trained 600/600 instances, Logloss: 0.05710
Epoch 1/5, trained 100/600 instances, Logloss: 0.12094
Epoch 1/5, trained 200/600 instances, Logloss: 0.07699
Epoch 1/5, trained 300/600 instances, Logloss: 0.10024
Epoch 1/5, trained 400/600 instances, Logloss: 0.13632
Epoch 1/5, trained 500/600 instances, Logloss: 0.06573
Epoch 1/5, trained 600/600 instances, Logloss: 0.06864
Epoch 2/5, trained 100/600 instances, Logloss: 0.03251
Epoch 2/5, trained 200/600 instances, Logloss: 0.03139
Epoch 2/5, trained 300/600 instances, Logloss: 0.07524
Epoch 2/5, trained 400/600 instances, Logloss: 0.02139
Epoch 2/5, trained 500/600 instances, Logloss: 0.03631
Epoch 2/5, trained 600/600 instances, Logloss: 0.05876

In [251]:
# Evaluate
# Evaluate without building autograd graphs; count correct predictions as a plain int.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        total += labels.size(0)
        predicted = net(images.to(device))  # logits, shape [batch, num_classes]
        correct += (predicted.argmax(dim=1) == labels.to(device)).sum().item()

print('Accuracy: %d/%d' % (correct, total))

Accuracy: 9768/10000


In [276]:
# Scratch demo: 5 rows of random logits over 10 classes, then their row-wise softmax.
output = torch.randn(5, 10)
print(output)
# Normalise along dim 1 so every row becomes a probability distribution.
prob = output.softmax(dim=1)
print(prob)

# .item() unwraps a one-element tensor into a plain Python number.
print(torch.tensor([1]).item())

tensor([[ 1.1797,  0.6996,  0.1432, -1.7373,  1.5378,  0.5873,  2.2067, -0.9361,
         -1.9889, -0.5158],
        [-1.3159,  0.2482, -0.4371,  0.4883,  1.0038, -1.2690,  0.9729,  1.5214,
         -0.8506, -0.8907],
        [ 1.9925, -0.9771, -0.2337, -0.3260,  0.1440, -0.0028,  0.5881, -0.7125,
         -0.9722,  0.6373],
        [ 0.4294,  0.1880, -0.2478,  0.1346, -0.2038, -2.7212,  0.7447, -0.2624,
          0.7781,  0.9440],
        [-1.3857,  1.0920, -0.2592, -0.9410, -0.6755, -0.0675,  1.1213,  0.6549,
          0.0898, -0.7927]])
tensor([[0.1399, 0.0865, 0.0496, 0.0076, 0.2001, 0.0773, 0.3906, 0.0169, 0.0059,
         0.0257],
        [0.0180, 0.0860, 0.0434, 0.1094, 0.1832, 0.0189, 0.1776, 0.3074, 0.0287,
         0.0275],
        [0.4602, 0.0236, 0.0497, 0.0453, 0.0725, 0.0626, 0.1130, 0.0308, 0.0237,
         0.1187],
        [0.1166, 0.0916, 0.0593, 0.0869, 0.0619, 0.0050, 0.1599, 0.0584, 0.1653,
         0.1951],
        [0.0202, 0.2408, 0.0623, 0.0315, 0.0411, 0.0755, 0