In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
# from torchsummary import summary

import os
import time
import datetime
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerPathCollection


class Net(nn.Module):
    """Small CNN for 1x28x28 images that exposes a 2-D embedding.

    Three stages of two 5x5 "same" convolutions (32, 64, 128 channels), each
    stage followed by 2x2 max pooling (28 -> 14 -> 7 -> 3), then a linear
    layer down to a 2-D feature and a linear classifier over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(1, 32, 5, 1, 2)  # in, out, kernel, stride, padding
        self.prelu1_1 = nn.PReLU()
        self.conv1_2 = nn.Conv2d(32, 32, 5, 1, 2)
        self.prelu1_2 = nn.PReLU()

        self.conv2_1 = nn.Conv2d(32, 64, 5, 1, 2)
        self.prelu2_1 = nn.PReLU()
        self.conv2_2 = nn.Conv2d(64, 64, 5, 1, 2)
        self.prelu2_2 = nn.PReLU()

        self.conv3_1 = nn.Conv2d(64, 128, 5, 1, 2)
        self.prelu3_1 = nn.PReLU()
        self.conv3_2 = nn.Conv2d(128, 128, 5, 1, 2)
        self.prelu3_2 = nn.PReLU()

        self.fc1 = nn.Linear(128 * 3 * 3, 2)
        self.prelu_fc1 = nn.PReLU()
        self.fc2 = nn.Linear(2, 10)

    def forward(self, x):
        """Run the network.

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        Returns:
            Tuple ``(features, logits)`` with shapes (batch, 2) and (batch, 10).

        Raises:
            RuntimeError: if the spatial size of ``x`` does not reduce to 3x3
                after the three pooling stages (shape mismatch in ``fc1``).
        """
        # Stage 1
        x = self.prelu1_1(self.conv1_1(x))  # 28
        x = self.prelu1_2(self.conv1_2(x))  # 28
        x = F.max_pool2d(x, 2, 2, 0)  # 14

        # Stage 2
        x = self.prelu2_1(self.conv2_1(x))  # 14
        x = self.prelu2_2(self.conv2_2(x))  # 14
        x = F.max_pool2d(x, 2, 2, 0)  # 7

        # Stage 3
        x = self.prelu3_1(self.conv3_1(x))  # 7
        x = self.prelu3_2(self.conv3_2(x))  # 7
        x = F.max_pool2d(x, 2, 2, 0)  # 3

        # Flatten per sample. Keeping the batch dimension fixed (instead of
        # view(-1, 128 * 3 * 3)) means a wrongly sized input fails loudly in
        # fc1 rather than being silently re-split into a different batch size.
        x = torch.flatten(x, 1)
        x = self.prelu_fc1(self.fc1(x))
        y = self.fc2(x)

        return x, y

In [2]:
# Pick the compute device. GPU index 3 is preferred; fall back to the first
# visible GPU, or to the CPU when CUDA is not available at all.
preferred_gpu = 3
use_cuda = torch.cuda.is_available()  # a real boolean; later cells branch on it
if use_cuda and torch.cuda.device_count() > preferred_gpu:
    device = 'cuda:{}'.format(preferred_gpu)
elif use_cuda:
    device = 'cuda:0'
else:
    device = 'cpu'

# NOTE(review): `lenet` is not referenced by the later cells visible here
# (training builds its own `model`); kept in case other cells rely on it.
lenet = Net().to(device=device)
# summary(lenet, input_size=(1, 28, 28))

In [3]:
# --- Data loading ---------------------------------------------------------
batch_size, test_batch_size = 128, 1000

# --- Optimisation ---------------------------------------------------------
epochs = 50
lr = 0.01                    # initial learning rate for both SGD optimizers
step_size, gamma = 20, 0.7   # StepLR: multiply lr by `gamma` every `step_size` epochs

# --- Loss -----------------------------------------------------------------
num_classes = 10
center_loss_weight = 1e-3    # weight applied to the center-loss term
loss_opt = 'softmax_center'  # 'softmax' or 'softmax_center'

# --- Logging / checkpointing ------------------------------------------------
log_interval = 10            # print training progress every N batches
save_model = True

class CenterLoss(nn.Module):
    """Center loss: half the squared distance of each feature to its class center.

    Args:
        num_classes: number of classes (one learnable center per class).
        feat_dim: dimensionality of the feature vectors.
        device: where to allocate the centers. ``None`` (default) keeps the
            previous behaviour when CUDA is available (the current CUDA
            device) and falls back to the CPU otherwise.
    """

    def __init__(self, num_classes=10, feat_dim=2, device=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Sample on the CPU first (as before) and then move, so the initial
        # centers for a given seed do not depend on the target device.
        self.centers = nn.Parameter(
            torch.randn(self.num_classes, self.feat_dim).to(device)
        )

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            Scalar tensor: mean over the batch of ``||x_i - c_{y_i}||^2 / 2``,
            with each per-sample value clamped to [1e-12, 1e+12].
        """
        center = self.centers[labels]
        dist = (x - center).pow(2).sum(dim=-1) / 2.0
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)

        return loss

def plot_features(features, labels, num_classes, epoch, prefix):
    """Scatter-plot 2-D features, one colour per class, and save as a PNG.

    The figure is written to ``<prefix>_epoch_<epoch>.png`` in the current
    working directory and then closed.

    Args:
        features: array of shape (num_instances, num_features); only the first
            two feature columns are plotted.
        labels: array of shape (num_instances) holding integer class ids.
        num_classes: number of classes; ids ``0 .. num_classes - 1`` are drawn.
        epoch: epoch number, used in the file name.
        prefix: file-name prefix.
    """
    fig, ax = plt.subplots()
    handles = []
    for label_idx in range(num_classes):
        mask = labels == label_idx
        handles.append(ax.scatter(
            features[mask, 0],
            features[mask, 1],
            # Default colour cycle C0..C9; colours repeat beyond 10 classes.
            c='C{}'.format(label_idx % 10),
            s=1,
        ))
    if handles:
        # Legend entries follow num_classes instead of a fixed '0'..'9' list;
        # update_prop enlarges the legend markers (the points use s=1).
        ax.legend(handles, [str(idx) for idx in range(num_classes)],
                  bbox_to_anchor=(1.04, 1), loc='upper left',
                  handler_map={type(handles[0]): HandlerPathCollection(update_func=update_prop)})
    save_name = prefix + '_epoch_' + str(epoch) + '.png'
    fig.savefig(save_name, bbox_inches='tight')
    plt.close(fig)
    

def update_prop(handle, orig, marker_size=12):
    """Legend-handle update hook: copy ``orig``'s properties, then resize.

    Args:
        handle: the legend handle to update (must provide ``update_from`` and
            ``set_sizes``).
        orig: the plotted artist the legend entry stands for.
        marker_size: marker size applied to the legend handle (default 12,
            the previously hard-coded value).
    """
    handle.update_from(orig)
    handle.set_sizes([marker_size])
    
    
def crossentropyloss_train(model, criterion_cross, device, train_loader, optimizer_cross, epoch, num_classes, log_interval):
    """Train ``model`` for one epoch with the classification loss only.

    After the epoch, the features collected from every batch are plotted via
    ``plot_features`` with prefix ``'softmax'``. Nothing is plotted when the
    loader yields no batches.

    Args:
        model: module returning ``(feature, output)`` for a batch.
        criterion_cross: loss called as ``criterion_cross(output, label)``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, label)`` batches; must support
            ``len()`` and expose ``.dataset`` (used for progress logging).
        optimizer_cross: optimizer updating the model parameters.
        epoch: epoch number (logging and plot file name).
        num_classes: number of classes, forwarded to ``plot_features``.
        log_interval: print progress every ``log_interval`` batches.
    """
    model.train()
    all_features, all_labels = [], []

    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        feature, output = model(data)
        loss = criterion_cross(output, label)

        optimizer_cross.zero_grad()
        loss.backward()
        optimizer_cross.step()

        # detach() replaces the deprecated `.data` access.
        all_features.append(feature.detach().cpu().numpy())
        all_labels.append(label.detach().cpu().numpy())

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    if not all_features:
        # np.concatenate([]) raises; an epoch without batches has nothing to plot.
        return

    all_features = np.concatenate(all_features, 0)
    all_labels = np.concatenate(all_labels, 0)
    plot_features(all_features, all_labels, num_classes, epoch, prefix='softmax')


def crosscenterloss_train(model, criterion_cross, criterion_center, device, train_loader, optimizer_cross, optimizer_center, epoch, num_classes, log_interval, weight=1):
    """Train ``model`` for one epoch with classification loss + weighted center loss.

    The total loss is ``criterion_cross(output, label) + weight *
    criterion_center(feature, label)``. Gradients of the center-loss
    parameters are divided by ``weight`` before ``optimizer_center`` steps, so
    the center update does not depend on ``weight``. With ``weight == 0`` the
    center gradients are already zero and the rescale is skipped.

    After the epoch, the collected features are plotted via ``plot_features``
    with prefix ``'softmax_center_lambda_<weight>'``. Nothing is plotted when
    the loader yields no batches.

    Args:
        model: module returning ``(feature, output)`` for a batch.
        criterion_cross: loss called as ``criterion_cross(output, label)``.
        criterion_center: module called as ``criterion_center(feature, label)``
            whose parameters are updated by ``optimizer_center``.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, label)`` batches; must support
            ``len()`` and expose ``.dataset`` (used for progress logging).
        optimizer_cross: optimizer updating the model parameters.
        optimizer_center: optimizer updating ``criterion_center``'s parameters.
        epoch: epoch number (logging and plot file name).
        num_classes: number of classes, forwarded to ``plot_features``.
        log_interval: print progress every ``log_interval`` batches.
        weight: multiplier of the center-loss term.
    """
    model.train()
    all_features, all_labels = [], []

    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        feature, output = model(data)
        loss_cross = criterion_cross(output, label)
        loss_center = criterion_center(feature, label) * weight  # weighted term (also what is logged)
        loss = loss_center + loss_cross

        optimizer_cross.zero_grad()
        optimizer_center.zero_grad()
        loss.backward()

        optimizer_cross.step()
        # Undo the weighting on the center gradients so that `weight` does not
        # affect how fast the centers move. Skipped for weight == 0, where
        # 1 / weight is undefined and the gradients are zero anyway.
        if weight != 0:
            for param in criterion_center.parameters():
                if param.grad is not None:
                    param.grad.mul_(1. / weight)
        optimizer_center.step()

        # detach() replaces the deprecated `.data` access.
        all_features.append(feature.detach().cpu().numpy())
        all_labels.append(label.detach().cpu().numpy())

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tCenter Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), loss_center.item()))

    if not all_features:
        # np.concatenate([]) raises; an epoch without batches has nothing to plot.
        return

    all_features = np.concatenate(all_features, 0)
    all_labels = np.concatenate(all_labels, 0)
    plot_features(all_features, all_labels, num_classes, epoch, prefix=f'softmax_center_lambda_{weight}')  # soft_center

    
def test(model, criterion, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            feature, output = model(data)
            # test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


# Training settings
# Shuffle the training set regardless of the device; evaluation order does
# not matter, so the test loader is left unshuffled.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size, 'shuffle': False}
if use_cuda:
    # Loader options that only make sense when feeding a GPU.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # NOTE(review): presumably the MNIST pixel mean / std — confirm.
    transforms.Normalize((0.1307,), (0.3081,))
    ])

dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)

# Make `device` the current CUDA device while building the losses/optimizers
# instead of repeating a literal GPU index. torch.cuda.device is a no-op for a
# negative index, which covers CPU runs.
_target_device = torch.device(device)
with torch.cuda.device(_target_device if _target_device.type == 'cuda' else -1):
    criterion_cross = nn.CrossEntropyLoss()
    optimizer_cross = optim.SGD(model.parameters(), lr=lr, weight_decay=5e-4, momentum=0.9)

    if loss_opt == 'softmax_center':
        # .to() is a no-op when the centers are already on the target device.
        criterion_center = CenterLoss(num_classes=num_classes, feat_dim=2).to(_target_device)
        optimizer_center = optim.SGD(criterion_center.parameters(), lr=lr)

# Only the model optimizer is scheduled; optimizer_center keeps a constant lr.
scheduler = StepLR(optimizer_cross, step_size=step_size, gamma=gamma)

In [4]:
# Fail fast on a misspelled option instead of silently skipping training and
# saving an untrained model.
if loss_opt not in ('softmax', 'softmax_center'):
    raise ValueError(
        "loss_opt must be 'softmax' or 'softmax_center', got {!r}".format(loss_opt))

start_time = time.time()

# Keep `device` as the current CUDA device for the whole run, derived from
# `device` rather than a literal GPU index. torch.cuda.device is a no-op for a
# negative index, which covers CPU runs.
_run_device = torch.device(device)
with torch.cuda.device(_run_device if _run_device.type == 'cuda' else -1):
    for epoch in range(1, epochs + 1):
        if loss_opt == 'softmax':
            crossentropyloss_train(model, criterion_cross, device, train_loader,
                optimizer_cross, epoch, num_classes, log_interval)
        else:  # 'softmax_center'
            crosscenterloss_train(model, criterion_cross, criterion_center, device, train_loader,
                optimizer_cross, optimizer_center, epoch, num_classes, log_interval, center_loss_weight)
        test(model, criterion_cross, device, test_loader)
        scheduler.step()

elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))

# Only the network weights are saved; the center-loss centers are not.
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")



Test set: Average loss: 0.0003, Accuracy: 9302/10000 (93%)


Test set: Average loss: 0.0002, Accuracy: 9700/10000 (97%)




Test set: Average loss: 0.0001, Accuracy: 9761/10000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9815/10000 (98%)




Test set: Average loss: 0.0001, Accuracy: 9846/10000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9867/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9831/10000 (98%)




Test set: Average loss: 0.0001, Accuracy: 9846/10000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9869/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9863/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9888/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9891/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9898/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9894/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9895/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9892/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9890/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9890/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9892/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9895/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9889/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9883/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9884/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9885/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9882/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9885/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9881/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9883/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9880/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9882/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9879/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9875/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9876/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9867/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9848/10000 (98%)




Test set: Average loss: 0.0001, Accuracy: 9864/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9886/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9888/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9891/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9884/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9880/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9882/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9884/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9884/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9879/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9880/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9879/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9881/10000 (99%)


Test set: Average loss: 0.0001, Accuracy: 9874/10000 (99%)




Test set: Average loss: 0.0001, Accuracy: 9881/10000 (99%)

Finished. Total elapsed time (h:m:s): 0:09:35
