In [1]:
import tensorpack as tp
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import matplotlib.pyplot as plt

from helpers.rsr_run import Model
from helpers.rsr_run import create_dataflow
from helpers.rsr_run import net_fn_map
from helpers.rsr2015 import *
from tensorpack import *
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.tfutils.varmanip import *
import os
from datetime import datetime

%load_ext autoreload
%autoreload 1

def avg(w):
    """Return the arithmetic mean over every element of ``w``.

    Args:
        w: numpy array (or array-like) of any shape.

    Returns:
        The mean of all elements as a numpy scalar.

    Raises:
        ZeroDivisionError: if ``w`` contains no elements.
    """
    values = np.asarray(w)
    if values.size == 0:
        # Keep the historical failure mode rather than numpy's NaN + warning.
        raise ZeroDivisionError("avg() of an empty array")
    # ndarray.mean() divides in floating point even for integer data (the old
    # sum(...)/len(...) floor-divided integer arrays on Python 2) and avoids a
    # Python-level loop over every element.
    return values.mean()

def prune(x, prune_rate=0.10):
    """Return a magnitude-pruned copy of ``x``; ``x`` itself is not modified.

    The non-zero entries are ranked by absolute value and every entry whose
    magnitude is strictly below the value found at the ``prune_rate``
    quantile position is set to zero. Entries that are already zero do not
    take part in the ranking, so pruning an already-pruned tensor removes a
    further fraction of the weights that are still alive.

    Args:
        x: array-like of weights.
        prune_rate: fraction in [0, 1] of the non-zero entries to remove.
            Because the comparison is strict, ties at the threshold survive.

    Returns:
        A new numpy array with the same shape as ``x``.

    Raises:
        ValueError: if ``prune_rate`` is outside [0, 1].
    """
    if not 0 <= prune_rate <= 1:
        # A negative rate used to index the sorted magnitudes from the end and
        # silently prune the wrong amount.
        raise ValueError("prune_rate must be within [0, 1], got %r" % (prune_rate,))

    a = np.copy(x)
    magnitudes = np.sort(np.abs(a[np.nonzero(a)]))

    if magnitudes.size == 0:
        # Nothing left to prune (e.g. an all-zero bias); this used to raise
        # IndexError when looking up the threshold.
        return a

    split_idx = int(prune_rate * magnitudes.size)
    if split_idx >= magnitudes.size:
        # prune_rate == 1.0: there is no threshold element, drop everything.
        a[...] = 0
        return a

    thres = magnitudes[split_idx]
    a[np.abs(a) < thres] = 0
    return a

def is_extra_var(var):
    """Tell whether the variable name ``var`` matches a "leave alone" pattern.

    A name matches when any of the following holds:
      * it contains ``'Adam'``;
      * it is exactly ``'global_step'``;
      * it contains ``'EMA'``;
      * it starts with ``'beta'`` and ends with ``'_power'``;
      * it contains ``'bits_for_maxval_var'``.

    get_new_var_dict copies matching variables through without pruning and
    get_sparsity skips them.

    Args:
        var: variable name (string).

    Returns:
        bool: True if one of the patterns above matches.
    """
    if 'Adam' in var:
        return True
    if var == 'global_step':
        return True
    if 'EMA' in var:
        return True
    # Both conditions must hold together, e.g. 'beta1_power'.
    if var.startswith('beta') and var.endswith('_power'):
        return True
    return 'bits_for_maxval_var' in var

def get_new_var_dict(var_dict, prune_rate=0.1, verbose=False):
    """Build a copy of a checkpoint variable dict with its weights pruned.

    Every name in ``var_dict`` is sorted into one of these groups, tested in
    this order:
      * "extra" variables (see is_extra_var) -- copied through untouched;
      * names containing ``'bn'`` -- pruned;
      * names ending in ``'/W'`` or ``'/depthwise_weights'`` -- pruned;
      * names ending in ``'/b'`` or ``'/biases'`` -- pruned;
      * anything else -- reported and copied through untouched, so that the
        returned dict still covers every variable of the checkpoint (such
        variables used to be dropped from the result).

    Args:
        var_dict: mapping of variable name -> numpy array.
        prune_rate: fraction handed to prune() for every pruned variable.
        verbose: when true, print the name of each variable being pruned.

    Returns:
        Tuple ``(new_var_dict, kernel_vars, bn_vars, bias_vars)``: the new
        name -> array mapping followed by the sets of names that fell into
        the kernel, batch-norm and bias groups.
    """
    new_var_dict = {}
    bn_vars = set()
    kernel_vars = set()
    bias_vars = set()

    for var in var_dict:
        if is_extra_var(var):
            new_var_dict[var] = var_dict[var]
        elif 'bn' in var:
            bn_vars.add(var)
        elif var.endswith('/W') or var.endswith('/depthwise_weights'):
            kernel_vars.add(var)
        elif var.endswith('/b') or var.endswith('/biases'):
            bias_vars.add(var)
        else:
            print("Couldn't classify", var)
            # Keep the original value instead of silently losing the variable.
            new_var_dict[var] = var_dict[var]

    # All three groups are pruned identically; sorting keeps the verbose
    # listing stable from run to run.
    for group in (kernel_vars, bn_vars, bias_vars):
        for var in sorted(group):
            if verbose:
                print(var)
            new_var_dict[var] = prune(var_dict[var], prune_rate)

    return new_var_dict, kernel_vars, bn_vars, bias_vars

def get_sparsity(var_dict, interesting_vars=None):
    """Measure the fraction of exactly-zero entries across variables.

    Args:
        var_dict: mapping of variable name -> numpy array.
        interesting_vars: optional iterable of names to measure. When it is
            None or empty, every key of ``var_dict`` is considered. Names for
            which is_extra_var() is true are always skipped.

    Returns:
        Tuple ``(sparsity, zero_count, totals)`` where ``sparsity`` is the
        overall zero fraction as a float (0.0 when nothing was measured) and
        ``zero_count`` / ``totals`` are per-variable lists of the number of
        zero entries and the number of entries, in iteration order.
    """
    if not interesting_vars:
        interesting_vars = var_dict.keys()

    zero_count = []
    totals = []
    for name in interesting_vars:
        if is_extra_var(name):
            continue
        weight = np.asarray(var_dict[name])
        n_entries = int(weight.size)
        zero_count.append(n_entries - int(np.count_nonzero(weight)))
        totals.append(n_entries)

    n_total = sum(totals)
    if n_total == 0:
        # No measurable variables: avoid dividing by zero.
        return 0.0, zero_count, totals
    # Force float division: on Python 2 the integer counts would otherwise be
    # floor-divided, making the reported sparsity 0 for any partially dense model.
    return float(sum(zero_count)) / n_total, zero_count, totals

In [2]:
# ---- Experiment configuration ----
# Network selection; `bn` chooses which checkpoint is loaded below and is
# later passed to Model.
model_name = 'fcn2'
bn = True
reg = False

# Input data locations. `cachedir` is not referenced by the cells visible here.
datadir = '/data/sls/scratch/skoppula/kaldi-rsr/numpy/'
spkmap = '/data/sls/scratch/skoppula/backup-exps/rsr-experiments/create_rsr_data_cache/generator_full_dataset/spk_mappings.pickle'
cachedir = '/data/sls/scratch/skoppula/backup-exps/rsr-experiments/create_rsr_data_cache/trn_cache/context_50frms/'

# These three values are only used to build the output directory name here
# (presumably weight/activation bit widths -- TODO confirm).
w_bit = 32
a_bit = 32
end = True

# Pick the checkpoint that matches the batch-norm setting.
bn_ckpt_path = '/data/sls/u/meng/skanda/home/thesis/manfxpt/models/sentfiltNone_' + model_name + '_bnTrue_regTrue_noLRSchedule/checkpoint'
ckpt_path = bn_ckpt_path if bn else '/data/sls/u/meng/skanda/home/thesis/manfxpt/no_bn_models/' + model_name + '/checkpoint'

# Output/log directory, e.g. pruned_models/fcn2_32_32_True.
outdir = os.path.join('pruned_models', '_'.join(map(str, [model_name, w_bit, a_bit, end])))
print("Outputting to outdir", outdir)
# NOTE(review): action='k' should keep an already existing log directory
# according to tensorpack's logger documentation -- confirm for this version.
logger.set_logger_dir(outdir, action='k')

# Context size handed to create_dataflow and Model, and the speaker count
# derived from the speaker mapping file.
context = 50
n_spks = get_n_spks(spkmap)

('Outputting to outdir', 'pruned_models/fcn2_32_32_True')
[32m[0329 11:33:28 @logger.py:74][0m Argv: /data/sls/u/meng/skanda/home/envs/tf2cpu/lib/python2.7/site-packages/ipykernel_launcher.py -f /run/user/23571/jupyter/kernel-c259bc5e-774d-43a7-ab67-f24d7288948e.json


In [None]:
# Sweep over pruning rates: prune the checkpoint weights, restore them into a
# fresh predictor and measure the error on the validation data.
prune_rates = [0, 0.25, 0.75]
# Validation is cut short per pruning rate; this equals the previous behaviour
# of breaking once the batch index reached 700 (i.e. after 701 batches).
max_val_batches = 701
errors = []
sparsities = []
print(ckpt_path)

# The checkpoint is the same for every pruning rate and get_new_var_dict does
# not modify its input, so it is loaded once instead of once per rate.
var_dict = load_chkpt_vars(ckpt_path)

for prune_rate in prune_rates:

    # A fresh dataflow/generator per rate so that every rate is evaluated on
    # batches drawn from the start of the validation stream.
    val_dataflow, n_batches_val = create_dataflow('val', None, datadir, spkmap, None, context)
    val_generator = val_dataflow.get_data()

    print("On prune rate", prune_rate)
    new_var_dict, _, _, _ = get_new_var_dict(var_dict, prune_rate)
    # get_sparsity returns (overall ratio, per-variable zeros, per-variable sizes).
    sparsities.append(get_sparsity(new_var_dict))

    model = Model(n_spks, net_fn_map[model_name], bn=bn, reg=False, n_context=context)

    config = PredictConfig(
            model=model,
            session_init=DictRestore(new_var_dict),
            input_names=['input', 'label'],
            output_names=['utt-wrong']
    )
    predictor = OfflinePredictor(config)

    rc = tp.utils.stats.RatioCounter()
    n_eval_batches = min(n_batches_val, max_val_batches)
    # A dedicated index name: the old inner loop reused `i`, shadowing the
    # index of the enclosing loop.
    for batch_idx in range(n_eval_batches):
        x, y = next(val_generator)
        outputs, = predictor([x, y])
        rc.feed(outputs, 1)
        if batch_idx % 100 == 0:
            # Report against the number of batches actually evaluated.
            print("On", batch_idx, "of", n_eval_batches, "error:", rc.ratio)
    print("error", rc.ratio)
    # NOTE(review): rc.ratio is indexed here, so it is presumably array-valued
    # because `outputs` is an array -- confirm against the 'utt-wrong' tensor.
    errors.append(rc.ratio[0])
print(errors)
print(sparsities)

/data/sls/u/meng/skanda/home/thesis/manfxpt/models/sentfiltNone_fcn2_bnTrue_regTrue_noLRSchedule/checkpoint
('whole utterance size', 75290)
('val', False, 18822)
('On prune rate', 0)

('Adding activation tensors to summary:', [<tf.Tensor 'linear0/bn/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear0/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear1/bn/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear1/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear2/bn/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear2/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear3/bn/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'linear3/output:0' shape=(?, 504) dtype=float32>, <tf.Tensor 'last_linear/output:0' shape=(?, 255) dtype=float32>, <tf.Tensor 'output:0' shape=(?, 255) dtype=float32>])
[32m[0329 11:34:37 @rsr_run.py:129][0m Parameter count: {'mults': 1398600, 'weights': 1398855}
[32m[0329 11:34:37 @sessinit.py:206][0m Variables to re