In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from utils import progress_bar
from torch.autograd import Variable
import numpy as np
from torch.nn.modules.module import _addindent

from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import MeanShift, estimate_bandwidth
import scipy.cluster.hierarchy as hcluster
import scipy.cluster.hierarchy as hac
import scipy.cluster.hierarchy as fclusterdata
import time
from sklearn.preprocessing import normalize


In [None]:
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarize a torch model by listing its direct child modules.

    Each child is rendered on its own line as ``(name): repr``. Children
    whose exact type is ``Sequential`` (or the legacy ``Container``) are
    expanded recursively with the same ``show_*`` options as the caller.

    Args:
        model: module whose ``_modules`` mapping is walked.
        show_weights: if true, append the shapes of each child's parameters.
        show_parameters: if true, append each child's parameter count; the
            count and the running total are also printed to stdout.

    Returns:
        The multi-line summary string.
    """
    # Container is a legacy class; resolve it defensively in case the
    # installed torch no longer provides it.
    container_types = tuple(
        cls for cls in (
            getattr(torch.nn.modules.container, 'Container', None),
            torch.nn.modules.container.Sequential,
        ) if cls is not None
    )

    tmpstr = model.__class__.__name__ + ' (\n'
    total = 0
    for key, module in model._modules.items():
        # Exact type match on purpose: subclasses keep their own repr.
        if type(module) in container_types:
            # Propagate the display flags so nested containers honour them.
            modstr = torch_summarize(module, show_weights, show_parameters)
        else:
            modstr = module.__repr__()
        modstr = _addindent(modstr, 2)

        params = sum(p.numel() for p in module.parameters())
        weights = tuple(tuple(p.size()) for p in module.parameters())

        tmpstr += '  (' + key + '): ' + modstr
        if show_weights:
            tmpstr += ', weights={}'.format(weights)
        if show_parameters:
            tmpstr += ', parameters={}'.format(params)
            total += params
            # Running total is only reported on stdout, not in the string.
            print(params)
            print('total is ', total)
        tmpstr += '\n'

    tmpstr = tmpstr + ')'
    return tmpstr

In [None]:
def train(epoch):
    """Run one training pass over ``trainloader`` and report progress.

    Relies on notebook-level globals: ``net``, ``trainloader``, ``use_cuda``,
    ``optimizer``, ``criterion`` and ``progress_bar``.

    Args:
        epoch: epoch number, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # .item() keeps `correct` a Python int, so the accuracy below is
        # true float division rather than integer tensor arithmetic.
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


In [None]:
def test(epoch, num_total, sum_total):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best.

    Relies on notebook-level globals: ``net``, ``testloader``, ``use_cuda``,
    ``criterion``, ``progress_bar`` and ``best_acc`` (which must already be
    defined before the first call).

    Args:
        epoch: epoch number stored in the checkpoint.
        num_total: running count of evaluated batches; incremented per batch.
        sum_total: running sum of the cumulative accuracy (in percent)
            observed after each batch.

    Returns:
        Tuple ``(best_acc, num_total, sum_total)`` with the updated values.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    # no_grad replaces the removed `volatile=True` flag: no autograd graph
    # is built during evaluation.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            # Keep `correct` a Python int so accuracies are plain floats.
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            num_total += 1
            sum_total += 100.*correct/total

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving weight')
        state = {
            # Assumes `net` is wrapped (e.g. DataParallel) whenever use_cuda
            # is set, so `.module` exists - TODO confirm against the setup cell.
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_pruned_retrain.t7')
        best_acc = acc

    return best_acc, num_total, sum_total


In [None]:
def cluster_weights_agglo(weight, threshold, average=True,cosine=True,euclidean=False,chebyshev=False,manhattan=False):
    """Cluster the columns of ``weight`` with complete-linkage clustering.

    The matrix is transposed so that every column of the input becomes one
    observation, each observation is L2-normalised, and the hierarchy is cut
    at distance ``1 - threshold``.

    Args:
        weight: 2-D array; its columns are the items being clustered
            (presumably one column per filter/neuron - confirm with caller).
        threshold: similarity threshold; converted to the distance cut-off
            ``1 - threshold``. The same conversion is applied for every metric.
        average: unused; kept for backward compatibility.
        cosine, euclidean, chebyshev, manhattan: metric selectors. The first
            truthy flag in this order decides the metric.

    Returns:
        Tuple ``(n_clusters, representatives)`` where ``representatives``
        holds one column index per cluster: the last member of each cluster
        in ``np.argsort`` order of the cluster labels.

    Raises:
        ValueError: if no metric flag is set.
    """
    # Resolve the metric up front; previously an unset metric surfaced as an
    # unbound-variable error further down.
    metric = None
    for flag, name in ((cosine, 'cosine'), (euclidean, 'euclidean'),
                       (chebyshev, 'chebyshev'), (manhattan, 'cityblock')):
        if flag:
            metric = name
            break
    if metric is None:
        raise ValueError(
            'no distance metric selected: set one of cosine, euclidean, '
            'chebyshev or manhattan')

    weight = weight.T
    if weight.shape[0] == 1:
        # A single observation cannot be linked; it is its own cluster.
        return 1, [0]

    weight = normalize(weight, norm='l2', axis=1)
    threshold = 1.0 - threshold   # Conversion to distance measure

    z = hac.linkage(weight, metric=metric, method='complete')
    labels = hac.fcluster(z, threshold, criterion="distance")

    # Group observation indices by label and keep the last index of each
    # group (groups are contiguous runs in the sorted label array).
    sort_idx = np.argsort(labels)
    sorted_labels = labels[sort_idx]
    group_ends = np.append(
        np.flatnonzero(sorted_labels[1:] != sorted_labels[:-1]),
        len(sorted_labels) - 1)
    first_ele = list(sort_idx[group_ends])
    return len(first_ele), first_ele