In [1]:
import torch
from torch import optim
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

In [20]:
class KSparseDropout(nn.Module):
    """Keep only the ``k`` largest entries of each row of the input; zero the rest.

    The selection is applied in both training and eval mode (there is no
    ``self.training`` branch).

    Args:
        p: fraction of columns to keep when ``k`` is not given; the number
            kept is ``int(p * x.shape[1])``.
        k: explicit number of entries to keep per row. ``None`` (or 0) means
            "derive it from ``p`` for each input".
        infer_ratio: accepted for API compatibility but currently unused.
            NOTE(review): presumably meant to rescale ``k`` at inference
            time -- confirm before relying on it.
    """

    def __init__(self, p=0.5, k=None, infer_ratio=1.0):
        super().__init__()
        self.p = p
        self.k = k
        self.infer_ratio = infer_ratio

    def forward(self, x):
        """Return a new tensor equal to ``x`` with all but the top-``k``
        entries along dim 1 set to zero (``x`` itself is not modified)."""
        # Resolve k per call instead of writing it back to self.k: caching it
        # would silently reuse the first input's width for every later input.
        k = self.k if self.k else int(self.p * x.shape[1])

        topk, indices = torch.topk(x, k, dim=1)
        # zeros_like matches x's dtype and device; a fresh CPU float32 buffer
        # fails to scatter CUDA or float64 values into it.
        return torch.zeros_like(x).scatter(1, indices, topk)
    
    
class KSparseDropout2d(nn.Module):
    """Channel-wise k-sparsity for 4-D feature maps.

    In training mode, for each sample the ``k`` channels with the largest
    spatial sum (over dims 2 and 3) are kept and every other channel is
    zeroed. In eval mode the input is returned unchanged.

    Args:
        p: fraction of channels to keep when ``k`` is not given; the number
            kept is ``int(p * x.shape[1])``.
        k: explicit number of channels to keep. ``None`` (or 0) means
            "derive it from ``p`` for each input".
        infer_ratio: accepted for API compatibility but currently unused.
            NOTE(review): eval mode keeps every channel with no rescaling, so
            activations are larger at inference than during training;
            ``infer_ratio`` was presumably meant to address that -- confirm.
    """

    def __init__(self, p=0.5, k=None, infer_ratio=1.0):
        super().__init__()
        self.p = p
        self.k = k
        self.infer_ratio = infer_ratio

    def forward(self, x):
        """Return ``x`` (assumed ``(batch, channels, H, W)``) with all but the
        top-``k`` channels of each sample zeroed, in training mode only.

        The result is a new tensor; the caller's ``x`` is not modified.
        """
        if not self.training:
            return x

        # Resolve k per call rather than caching it on self.
        k = self.k if self.k else int(self.p * x.shape[1])

        # Per-channel score: total activation over the spatial dimensions.
        activation = x.sum(dim=(2, 3))
        _, indices = torch.topk(activation, k, dim=1)

        # One scatter builds the (batch, channels) keep-mask; this replaces the
        # O(batch * channels) Python loop of tensor membership tests.
        keep = torch.zeros_like(activation).scatter(1, indices, 1.0).bool()
        # masked_fill is out-of-place, so the input tensor is left untouched.
        return x.masked_fill(~keep[:, :, None, None], 0.0)

In [3]:
# Sanity check of KSparseDropout on a random 2x10 batch: with the default
# p=0.5, int(0.5 * 10) = 5 entries per row should survive.
torch.manual_seed(0)  # make the printed example reproducible

ks_dropout = KSparseDropout()

# Keep the dense input under its own name so it can be compared with the result.
x_dense = torch.rand(2, 10)
print(x_dense)

x_sparse = ks_dropout(x_dense)
print(x_sparse)

# Every row keeps exactly 5 non-zero entries.
assert ((x_sparse != 0).sum(dim=1) == 5).all()

tensor([[0.1093, 0.2895, 0.7088, 0.8486, 0.5823, 0.8954, 0.7018, 0.7308, 0.7518,
         0.9636],
        [0.4222, 0.4250, 0.1825, 0.4292, 0.9604, 0.5741, 0.1266, 0.2242, 0.2949,
         0.6249]])
tensor([[0.0000, 0.0000, 0.0000, 0.8486, 0.0000, 0.8954, 0.0000, 0.7308, 0.7518,
         0.9636],
        [0.0000, 0.4250, 0.0000, 0.4292, 0.9604, 0.5741, 0.0000, 0.0000, 0.0000,
         0.6249]])


In [21]:
class LeNet5(nn.Module):
    """LeNet-5-style MNIST classifier with channel dropout after each conv.

    Three stages of conv -> ReLU -> Dropout2d(0.5) -> MaxPool(2) feed a single
    linear layer. For 1x28x28 inputs the spatial size goes
    28 -> 24 -> 12 -> 8 -> 4 -> 2 -> 1, so the flattened features have exactly
    120 entries, matching ``self.fc``. Returns log-probabilities of shape
    ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 6 channels, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout2d(0.5)
        self.maxpool1 = nn.MaxPool2d(2)

        # Stage 2: 6 -> 16 channels, 5x5 kernel.
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout2d(0.5)
        self.maxpool2 = nn.MaxPool2d(2)

        # Stage 3: 16 -> 120 channels, 3x3 kernel.
        self.conv3 = nn.Conv2d(16, 120, 3)
        self.relu3 = nn.ReLU(inplace=True)
        self.dropout3 = nn.Dropout2d(0.5)
        self.maxpool3 = nn.MaxPool2d(2)

        # Classifier head.
        self.fc = nn.Linear(120, 10)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        stages = (
            (self.conv1, self.relu1, self.dropout1, self.maxpool1),
            (self.conv2, self.relu2, self.dropout2, self.maxpool2),
            (self.conv3, self.relu3, self.dropout3, self.maxpool3),
        )
        for stage in stages:
            for layer in stage:
                x = layer(x)

        # (batch, C, H, W) -> (batch, C * H * W)
        features = torch.flatten(x, start_dim=1)
        return self.log_softmax(self.fc(features))
    
class LeNet5K(nn.Module):
    """LeNet5 variant that uses KSparseDropout2d in place of ReLU + Dropout2d.

    Three stages of conv -> KSparseDropout2d() -> MaxPool(2) feed a single
    linear layer. Conv/linear sizes match ``LeNet5``, so 1x28x28 inputs flatten
    to 120 features. Returns log-probabilities of shape ``(batch, 10)``.

    NOTE(review): unlike LeNet5 there is no ReLU after the convolutions, so
    the only nonlinearities are the top-k channel selection (training mode
    only) and max-pooling -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 6 channels, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.kdrop1 = KSparseDropout2d()
        self.maxpool1 = nn.MaxPool2d(2)

        # Stage 2: 6 -> 16 channels, 5x5 kernel.
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.kdrop2 = KSparseDropout2d()
        self.maxpool2 = nn.MaxPool2d(2)

        # Stage 3: 16 -> 120 channels, 3x3 kernel.
        self.conv3 = nn.Conv2d(16, 120, 3)
        self.kdrop3 = KSparseDropout2d()
        self.maxpool3 = nn.MaxPool2d(2)

        # Classifier head.
        self.fc = nn.Linear(120, 10)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        stages = (
            (self.conv1, self.kdrop1, self.maxpool1),
            (self.conv2, self.kdrop2, self.maxpool2),
            (self.conv3, self.kdrop3, self.maxpool3),
        )
        for stage in stages:
            for layer in stage:
                x = layer(x)

        # (batch, C, H, W) -> (batch, C * H * W)
        logits = self.fc(torch.flatten(x, start_dim=1))
        return self.log_softmax(logits)

In [24]:
# MNIST data loaders. ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,))
# then maps them to [-1, 1] (one mean/std pair because MNIST is single-channel).
batch_size = 32

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# download=True skips the download when the files already exist; it just keeps
# this line from depending on the train-set line having run first.
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
# Evaluation gains nothing from shuffling; a fixed order keeps runs comparable.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Train LeNet5K (swap in LeNet5() for the plain-dropout baseline), then
# evaluate it on the held-out MNIST test set.
torch.manual_seed(0)  # reproducible weight init and data shuffling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = LeNet5().to(device)
model = LeNet5K().to(device)

# Build the optimizer only after the model is on its final device.
optimizer = optim.Adam(model.parameters())

num_epochs = 10
num_train_rows = len(train_loader.dataset)

loss_func = nn.NLLLoss()  # the models output log-probabilities

model.train()
for epoch in range(1, num_epochs + 1):
    seen = 0
    for batch_idx, (data, label) in enumerate(train_loader, 1):
        data, label = data.to(device), label.to(device)
        seen += len(data)  # exact count, including a short final batch

        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, label)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            # This is the current *training* batch loss, not a validation loss.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, num_train_rows,
                100. * seen / num_train_rows, loss.item()))

model.eval()

# Sum (rather than average) each batch's loss so that dividing by the sample
# count gives the true per-sample mean, even when the last batch is smaller.
eval_loss_func = nn.NLLLoss(reduction='sum')

correct = 0
num_rows = 0
test_loss = 0.0

with torch.no_grad():
    # Iterate the *test* loader; iterating train_loader here would report
    # training-set accuracy as the "Test set" result.
    for data, label in test_loader:
        data, label = data.to(device), label.to(device)
        num_rows += data.size(0)

        output = model(data)
        test_loss += eval_loss_func(output, label).item()
        pred = output.argmax(dim=1)
        correct += (pred == label).sum().item()

test_loss /= num_rows

print("Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
    test_loss, correct, num_rows, (100. * correct) / num_rows))
        
        
    







