In [1]:
import numpy as np
import struct
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torchvision
from torch.autograd import Variable
from torchvision import datasets, transforms
from util.util import mnist_noise
from copy import deepcopy
from scipy import spatial
import torch.cuda as cutorch

from trajectoryPlugin.plugin import API

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

def read_idx(filename):
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)

cuda:0


In [2]:
"""
CNN
"""
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
    
def accuracy(predict_y, test_y):
    score = 0
    for pred, acc in zip(predict_y, test_y):
        if pred == acc:
            score +=1
    return score / test_y.shape[0]

In [3]:
"""
MNIST DATA
"""
batch_size = 100

mnistdata = datasets.MNIST('../data', train=True, download=True,
             transform=transforms.Compose([
                 transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))
             ]))

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

datalen = len(mnistdata)
valid_index = np.random.choice(range(len(mnistdata)), size=5000, replace=False).tolist()
train_index = np.delete(range(len(mnistdata)), valid_index).tolist()
trainset = torch.utils.data.dataset.Subset(mnistdata, train_index)
validset = torch.utils.data.dataset.Subset(mnistdata, valid_index)

"""
Add Noise label to training data
"""
noise_idx = []
noise_idx = np.random.choice(range(len(trainset)), size=int(len(trainset)* 0.1), replace=False)
label = range(10)
for idx in noise_idx:
    true_label = trainset.dataset.targets[train_index[idx]]
    noise_label = [lab for lab in label if lab != true_label]
    trainset.dataset.targets[train_index[idx]] = int(np.random.choice(noise_label))
    

# suppose there are training set and validation set, trajectory API initializaiton
### currently, our API will take care of data part in training, see below

# we use torch dataset for initializaiton
api = API(num_cluster=3, device=device, iprint=2)
api.dataLoader(trainset, validset, batch_size=batch_size)

In [4]:
"""
Here is an example of standard NN training + trajectory reweighting.
"""


# model and its paramters
cnn = CNN()
cnn.to(device)
L2 = 0.0005
learning_rate = 0.001
num_iter = 10
#optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate, weight_decay=L2)
optimizer = torch.optim.SGD(cnn.parameters(), lr=learning_rate, momentum=0.9)
# standard training starts
epoch = 1
while epoch <= num_iter:
    print("="*20 + "epoch = {}".format(epoch) + "="*20)
    cnn.train()
    for step, (data, target, idx) in enumerate(api.train_loader): # api train_loader
        data, target = data.to(device), target.to(device)
        weight = api.weight_tensor[idx].to(device)
        output = cnn(data)
        loss = api.loss_func(output, target, weight) # api train_loader
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # record trajectory
    api.createTrajectory(cnn)
    
    # cluster trajectory + reweight data
    if epoch > 3:
        api.clusterTrajectory() # run gmm cluster
        api.reweightData(cnn, 1e6, noise_idx) # update train_loader
        
    cnn.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = cnn(data)
            loss += api.loss_func(output, target, None, 'sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print('Test Loss = {}, Test Accuracy = {}'.format(loss,accuracy))
    
    if torch.cuda.is_available():
        print("Memory ",str(cutorch.memory_allocated(0)) , ' ', str(cutorch.max_memory_allocated(0)) , ' ' , str(cutorch.get_device_properties(0).total_memory))
    
    epoch += 1

Test Loss = 0.3496535747528076, Test Accuracy = 94.61
Memory  39697920   149143040   6370295808
Test Loss = 0.27012772369384763, Test Accuracy = 96.42
Memory  39697920   149144064   6370295808
Test Loss = 0.2503749544143677, Test Accuracy = 97.24
Memory  39697920   149144064   6370295808
2019-04-05 01:35:12,159 - INFO - | - {0: 0, 'size': 10746, 'sim': '-0.9305', 'num_special': 5494, 'spe_ratio': '0.5113'}
2019-04-05 01:35:12,164 - INFO - | - {1: 1, 'size': 18202, 'sim': '0.9882', 'num_special': 5, 'spe_ratio': '0.0003'}
2019-04-05 01:35:12,172 - INFO - | - {2: 2, 'size': 26052, 'sim': '0.9413', 'num_special': 1, 'spe_ratio': '0.0000'}
Test Loss = 0.2223454174041748, Test Accuracy = 97.59
Memory  39697920   149463552   6370295808
2019-04-05 01:35:38,405 - INFO - | - {0: 0, 'size': 26159, 'sim': '0.9541', 'num_special': 1, 'spe_ratio': '0.0000'}
2019-04-05 01:35:38,410 - INFO - | - {1: 1, 'size': 17997, 'sim': '0.9894', 'num_special': 4, 'spe_ratio': '0.0002'}
2019-04-05 01:35:38,415 - 