In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import argparse
import numpy as np
import json, time, sys
import copy

from networks import *

from trajectoryPlugin.plugin import API

In [2]:
def train_fn(model, device, optimizer, api, reweight=False):
    model.train()
    for batch_idx, (data, target, weight) in enumerate(api.train_loader):
        data, target, weight = data.to(device), target.to(device), weight.to(device)
        optimizer.zero_grad()
        output = model(data)
        if reweight:
            loss = api.loss_func(output, target, weight, 'mean')
        else:
            loss = api.loss_func(output, target, None, 'mean')
        loss.backward()
        optimizer.step()

def forward_fn(model, device, api, forward_type, test_loader=None):
    model.eval()
    loss = 0
    correct = 0
    if forward_type == 'test':
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss += api.loss_func(output, target, None, 'sum').item() # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)

    elif forward_type == 'train':
        with torch.no_grad():
            for batch_idx, (data, target, weight) in enumerate(api.train_loader):
                data, target, weight = data.to(device), target.to(device), weight.to(device)
                output = model(data)
                loss += api.loss_func(output, target, weight, 'sum').item() # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        loss /= len(api.train_loader.dataset)
        accuracy = 100. * correct / len(api.train_loader.dataset)

    elif forward_type == 'validation':
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(api.valid_loader): 
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss += api.loss_func(output, target, None, 'sum').item() # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        loss /= len(api.valid_loader.dataset)
        accuracy = 100. * correct / len(api.valid_loader.dataset)

    return loss, accuracy

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mnistdata = datasets.MNIST('../data', train=True, download=True,
             transform=transforms.Compose([
                 transforms.ToTensor(),
                 transforms.Normalize((0.1307,), (0.3081,))
             ]))

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=100, shuffle=True)

valid_index = np.random.choice(range(len(mnistdata)), size=5000, replace=False).tolist()
train_index = np.delete(range(len(mnistdata)), valid_index).tolist()
trainset = torch.utils.data.dataset.Subset(mnistdata, train_index)
validset = torch.utils.data.dataset.Subset(mnistdata, valid_index)

In [4]:
model = ConvNet()
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

api = API(num_cluster=10, device=device, update_rate=1.0, iprint=2)
api.dataLoader(trainset, validset, batch_size=100)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [5]:
for epoch in range(1, 4):

    scheduler.step()
    train_fn(model, device, optimizer, api, False)
    api.createTrajectory(model)

    loss, accuracy = forward_fn(model, device, api, 'train')


    loss, accuracy = forward_fn(model, device, api, 'validation')


    loss, accuracy = forward_fn(model, device, api, 'test', test_loader)


    api.generateTrainLoader()
    print(api.train_loader.dataset[0])

((tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, 

((tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, 

((tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, 

In [6]:
api.trajectoryBins()
api.clusterBins()
api.reweightData(model, [])
api.generateTrainLoader()

2019-05-10 11:09:55,791 - INFO - | - {0: 0, 'size': 32961, 'sim': 0.06556938588619232}
2019-05-10 11:09:55,793 - INFO - | - {1: 1, 'size': 633, 'sim': 0.8496793508529663}
2019-05-10 11:09:55,799 - INFO - | - {2: 2, 'size': 1184, 'sim': 0.3025215268135071}
2019-05-10 11:09:55,805 - INFO - | - {3: 3, 'size': 7635, 'sim': 0.15567390620708466}
2019-05-10 11:09:55,806 - INFO - | - {4: 4, 'size': 2308, 'sim': 0.17941972613334656}
2019-05-10 11:09:55,810 - INFO - | - {5: 5, 'size': 3645, 'sim': 0.34148311614990234}
2019-05-10 11:09:55,810 - INFO - | - {6: 6, 'size': 966, 'sim': 0.7932683229446411}
2019-05-10 11:09:55,812 - INFO - | - {7: 7, 'size': 2674, 'sim': 0.055486708879470825}
2019-05-10 11:09:55,814 - INFO - | - {8: 8, 'size': 1211, 'sim': 0.3208021819591522}
2019-05-10 11:09:55,816 - INFO - | - {9: 9, 'size': 1783, 'sim': 0.41284430027008057}


In [7]:
print(api.weight_tensor[0])
print(api.train_loader.dataset[0])

tensor(1.0306)
((tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.