In [5]:
%matplotlib inline
from sklearn.neighbors import NearestNeighbors, KDTree

from mcmatch import cluster
from mcmatch.db.pg_database import PgFunDB
from mcmatch.util import extract_funname, signature_to_fname_heuristic
from mcmatch.feature import all_features, textlength_features, counter_sum_features, counter_features, relative_counter_features, relative_counter_sum_features
from mcmatch.feature import cyclo_feature
from mcmatch.feature.aggregator import FeatureAggregator
from sklearn import preprocessing
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
import pprint
import timeit

In [6]:
def run_test(mtr_n, tfm, norm, doublewrap):
    """Evaluate one (feature set, transform, norm) combination and emit one CSV row.

    Builds a ``cluster.DistanceInfo`` from the module-level ``fdb`` and
    ``train_set``, tests it against the ``'t-dietlibc'`` repository, and
    summarises the result of ``get_partition_sizes`` for every test entry.

    Parameters:
        mtr_n: key into the module-level ``mtm`` dict; the values of the
            selected entry are handed to ``FeatureAggregator``.
        tfm: forwarded as ``transform=`` to ``cluster.DistanceInfo``.
        norm: forwarded as ``norm=`` to ``cluster.DistanceInfo``.
        doublewrap: callable invoked exactly once with the finished row.

    The row is ``mtr_n;tfm;norm;good;bad;other;len(ks);mean;median;std``
    where ``ks`` holds, per test entry, the smallest ``el[0]`` returned by
    ``get_partition_sizes``.  When ``ks`` is empty the mean is written as the
    literal ``NaN`` (median and std are whatever numpy yields for an empty
    list).

    NOTE(review): the meaning of ``el[0]`` / ``el[2]`` is defined by
    ``get_partition_sizes`` in ``mcmatch.cluster``; the driver's CSV header
    labels the three counters better/worse than average -- confirm there.
    """
    line = '%s;%r;%s;' % (mtr_n, tfm, norm)
    used_features = FeatureAggregator(mtm[mtr_n].values())

    di = cluster.DistanceInfo(fdb, used_features,
                              training_repositories=train_set,
                              transform=tfm, norm=norm)
    pairwise_d, testset_infos = di.test(fdb, in_repositories=['t-dietlibc'])
    training_infos = di.get_trainingset_infos()

    em = cluster.DistanceInfo.make_equivalence_map(
        testset_infos, training_infos,
        key=lambda z: signature_to_fname_heuristic(z[1]))

    good, bad, other = 0, 0, 0
    ks = []
    # Index-based on purpose: row i of pairwise_d is paired with em[i].
    for i in range(len(em)):
        closests = []
        res = cluster.DistanceInfo.get_partition_sizes(pairwise_d[i], None, em[i])
        for el in res:
            closests.append(el[0])
            if el[0] < el[2]:
                good += 1
            elif el[0] > el[2]:
                bad += 1
            else:
                other += 1
        # Entries for which get_partition_sizes returned nothing do not
        # contribute to ks.
        if closests:
            ks.append(min(closests))

    # Guard the division explicitly instead of catching ZeroDivisionError;
    # the emitted text is the same as before in both cases.
    mean_k = 1.0 * sum(ks) / len(ks) if ks else "NaN"
    line += ";".join(map(str, [good, bad, other, len(ks), mean_k,
                               np.median(ks), np.std(ks)]))
    doublewrap(line)

    # Explicitly drop the large intermediates before returning.
    del di.train_data
    del di.trainingset_idx_to_ftid
    del pairwise_d, testset_infos, training_infos, em, di

In [7]:
# Database handle shared by every run_test() invocation.
fdb = PgFunDB()

# Feature sets to evaluate, keyed by the label written to the first CSV
# column.  Uncomment entries to include them in the sweep.
mtm = {
    'all' : all_features,
#    'textlength' : textlength_features,
#    'counter': counter_features,
#    'counter_sum': counter_sum_features,
#    'rel_counter': relative_counter_features,
#    'rel_counter_sum' : relative_counter_sum_features,
#    'cyclo' : cyclo_feature,
#    'textlen+rel_counter_sum' : dict(
#          list(textlength_features.items())
#        + list(relative_counter_sum_features.items()))
}

#for m in counter_features:
#    mtm['counter_' + m] = {'m': counter_features[m]}
#for m in relative_counter_features:
#    mtm['rel_counter_' + m] = {'m' : relative_counter_features[m]}

# Values forwarded as norm= / transform= to cluster.DistanceInfo.
norms = ['cityblock', 'euclidean', 'cosine']

transform_modes = [0, 1, 2, 4]

# Repositories used as the training set.
train_set = ['t-glibc', 'musl-1.1.6']


def doublewrap(line):
    """Emit one result line.

    Currently prints to stdout only; the writes to the module-level
    ``outfile`` are disabled.
    """
    #outfile.write(line + "\n")
    #outfile.flush()
    print(line)


# Scope the output file with a context manager so the handle is closed even
# if a run fails part-way through the sweep.  The file is still created /
# truncated exactly as before.
with open("shell-dist.out.csv", "w") as outfile:
    # The header spelling ("tranformation_mode") is kept as-is so it keeps
    # matching previously recorded output.
    doublewrap("feature;tranformation_mode;norm;better_than_avg;worse_than_avg;other;num_fns_in_both_sets;mean_k;median_k;stddev_k")
    for mtr_n, norm, tfm in product(mtm, norms, transform_modes):
        # After each result row, print the wall-clock seconds that run took.
        print(timeit.timeit(lambda: run_test(mtr_n, tfm, norm, doublewrap), number=1))

feature;tranformation_mode;norm;better_than_avg;worse_than_avg;other;num_fns_in_both_sets;mean_k;median_k;stddev_k
all;0;cityblock;2828;1308;0;419;3442.92840095;1034.0;4689.96208234
36.6887021065
all;1;cityblock;2825;1311;0;419;3041.56563246;862.0;4553.559855
34.2806880474
all;2;cityblock;2805;1331;0;419;3361.98806683;950.0;4638.26693854
35.9522540569
all;4;cityblock;2805;1331;0;419;3362.00238663;950.0;4638.25769151
52.1849730015
all;0;euclidean;2820;1316;0;419;3429.52505967;1061.0;4643.23611034
34.0597109795
all;1;euclidean;2477;1659;0;419;3209.19570406;856.0;4651.61717663
33.917402029
all;2;euclidean;2820;1316;0;419;3430.5202864;1061.0;4642.86369704
41.9563589096
all;4;euclidean;2820;1316;0;419;3430.32458234;1061.0;4642.85082569
51.5741598606
all;0;cosine;2665;1471;0;419;3071.94033413;1173.0;4065.5129111
33.7857909203
all;1;cosine;3075;1061;0;419;2718.71837709;759.0;3902.76758749
33.8562469482
all;2;cosine;2815;1321;0;419;3415.58472554;1184.0;4617.38177051
40.0717999935
all;4;cosine;