In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import networkx as nx
import copy
import warnings
warnings.filterwarnings('ignore')

In [2]:
from loader import *
from syngem_utils import *

In [3]:
torch.manual_seed(42)
np.random.seed(42)

In [360]:
gem_model_2_21 = load_gemini_model("FC", 20, 21)
gem_model_2_42 = load_gemini_model("FC", 20, 42)
gem_model_2_63 = load_gemini_model("FC", 20, 63)

In [361]:
syn_model_2_21 = load_synflow_model("FC", 20, 21)
syn_model_2_42 = load_synflow_model("FC", 20, 42)
syn_model_2_63 = load_synflow_model("FC", 20, 63)

In [368]:
syn_fil_2_21 = get_filters(syn_model_2_21)
syn_fil_2_42 = get_filters(syn_model_2_42)
syn_fil_2_63 = get_filters(syn_model_2_63)

In [369]:
rnd_model_2_21 = load_random_model("FC", 20, 21)
rnd_fil_2_21   = get_filters(rnd_model_2_21)

In [370]:
gem_fil_2_21 = get_filters(gem_model_2_21)
gem_fil_2_42 = get_filters(gem_model_2_42)
gem_fil_2_63 = get_filters(gem_model_2_63)

In [17]:
# check sparsity could be optimized using np.argwhere 

In [None]:
# the weights average around ~ in unit 1 in layer 1
# the distance is around ~
# there are ~ many weights on avergae left in the units 

In [83]:
# based on best unit matching do these analysis 

In [353]:
np.concatenate((np.array([6,7,8]), np.array([8,9])))

array([6, 7, 8, 8, 9])

In [354]:
def get_weight_positions(model):
    
    positions_model = []
    for layer in model:
        positions_layer = []
        for unit in layer:
            #print(unit)
            weight_idxs = np.argwhere(unit.flatten())
            #print(weight_idxs.flatten())
            positions_layer = np.concatenate((positions_layer, weight_idxs.flatten()))
        
        positions_model.append(positions_layer)   
    
    return positions_model

In [355]:
c = get_weight_positions(rnd_fil_2_21)

array([248., 255., 275., 280., 292., 293., 295., 297., 299., 309., 310.,
       321., 323., 332., 336., 343., 344., 348., 350., 351., 363., 371.,
       379., 383., 385., 391., 392., 398., 401., 409., 419., 420., 421.,
       422., 429., 430., 431., 432., 433., 450., 455., 459., 462., 464.,
       475., 477., 486., 487., 488., 495., 498., 502., 503., 506., 509.,
       530., 534., 536., 538., 539., 544., 545., 551., 556., 560., 569.,
       576., 581., 599., 600., 603., 607., 610., 612., 618., 622., 625.,
       639., 642., 655., 656., 663., 665., 671., 678., 681., 688., 693.,
       695., 699., 702., 714., 716., 731., 733., 735., 738., 747., 749.,
       750., 752., 755., 759., 766., 769., 776., 779., 782.,   3.,   8.,
        12.,  24.,  25.,  26.,  28.,  32.,  37.,  43.,  45.,  47.,  49.,
        72.,  74.,  78.,  84.,  92.,  98., 101., 103., 112., 128., 129.,
       131., 140., 145., 154., 156., 158., 162., 164., 168., 171., 176.,
       179., 188., 190., 192., 193., 201., 207.])

In [154]:
pos, mid, c = get_weight_positions(gem_fil_2_21[0][23], np.mean)
pos, mid, c

(None, None, None)

In [157]:
pos, mid, c = get_weight_positions(rnd_fil_2_21[0][24], np.mean)
pos, mid, c

(308.45, 337.0, 11)

In [402]:
def get_weight_distance(unit):
    
    weight_idxs = np.argwhere(unit)
    
    weight_dist = 0
    for idx in range(len(weight_idxs)):
        if len(weight_idxs) == 1:
            return np.nan
        try:
            weight_dist += abs(weight_idxs[idx] - weight_idxs[idx + 1])
        except:
            return int(weight_dist / (len(weight_idxs) - 1))
    return np.nan

In [399]:
for i in range(20):
    test = get_weight_distance(gem_fil_2_21[0][i])
    print(test)

2
2
80
2
1
2
43
2
9
39
4
2
6
2
4
4
74
106
64
39


In [397]:
gem_fil_2_21[0][17]

array([-0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , -0.        , -0.        , -0.        , -0.        ,
        0.        , -0.        , -0.        , -0.        ,  0.        ,
       -0.        , -0.        , -0.        ,  0.        ,  0.        ,
       -0.        ,  0.        ,  0.        , -0.        ,  0.        ,
       -0.        ,  0.        ,  0.        ,  0.        , -0.        ,
        0.        ,  0.        , -0.        , -0.        ,  0.        ,
       -0.        ,  0.        , -0.        , -0.        ,  0.        ,
       -0.        , -0.        , -0.        , -0.        ,  0.        ,
        0.        , -0.        ,  0.        ,  0.        , -0.        ,
       -0.        , -0.        , -0.        ,  0.        ,  0.        ,
       -0.        ,  0.        , -0.        ,  0.        , -0.        ,
       -0.        ,  0.        , -0.        ,  0.        , -0.        ,
       -0.04536363,  0.        , -0.        ,  0.        ,  0.  

In [199]:
get_weight_distance(rnd_fil_2_21[0][0])

39

In [181]:
for i in range(20):
    test = get_weight_distance(rnd_fil_2_21[0][i])
    print(test)

39
36
42
55
44
50
42
32
54
44
48
40
64
51
37
42
39
32
65
41


In [202]:
def get_model_weight_distances(model):
    distances_model = []
    for layer in model:
        distances_layer = []
        for unit in layer:
            dist = get_weight_distance(unit)
            distances_layer.append(dist)
        distances_model.append(distances_layer)
    
    return distances_model

In [403]:
test_gem = get_model_weight_distances(gem_fil_2_21)
test_gem_2 = get_model_weight_distances(gem_fil_2_21)


In [388]:
kk = test_gem + test_gem_2

In [390]:
len(kk)

12

In [386]:
for i in range(6):
    print(np.mean(np.nan_to_num(test_gem[i])))

22.67
5.14
4.92
5.69
4.41
2.7


In [279]:
syn_gem = get_model_weight_distances(syn_fil_2_21)
for i in range(6):
    print(np.mean(np.nan_to_num(syn_gem[i])))

3.66
2.0
2.0
2.0
2.0
2.0


In [280]:
test = get_model_weight_distances(rnd_fil_2_21)
np.mean(np.nan_to_num(test[5]))

5.0

In [281]:
for i in range(6):
    print(np.mean(np.nan_to_num(test[i])))

5.47
5.19
5.05
5.12
5.19
5.0


In [None]:
# this suggests that smart pruning leads to 4x closer cluster of weights 
# all this similar to the idea of NNSTD where position of weights is compared, but not high distances
# not neccesarily punished if they are common to both networks that are getting compared
# --> makes sense, the 

In [404]:
np.unique(np.array([1,2,3,4,5,5,5]))

array([1, 2, 3, 4, 5])

In [426]:
def count_clusters(model):
     
    clusters_model = []
    for layer in model:
        clusters_layer = []
        for unit in layer:
            clusters_unit = []
            weight_idxs = np.argwhere(unit.flatten())
            count = 0
            for i in range(len(weight_idxs)):
                try:
                    if int(weight_idxs[i] + 1) == int(weight_idxs[i + 1]):
                        clusters_unit.append(count)
                    else:
                        count += 1
                        clusters_unit.append(count)
                        
                except:
                    pass
                        
            clusters = len(np.unique(np.array(clusters_unit)))
            clusters_layer.append(clusters)
        
        clusters_model.append(clusters_layer)   
    
    return clusters_model

In [427]:
lol = count_clusters(syn_fil_2_21)

In [447]:
lol[1]

[19,
 21,
 24,
 20,
 21,
 21,
 22,
 20,
 19,
 23,
 22,
 19,
 25,
 18,
 22,
 17,
 21,
 20,
 24,
 20,
 19,
 28,
 26,
 19,
 19,
 21,
 23,
 23,
 20,
 24,
 23,
 25,
 20,
 23,
 21,
 17,
 23,
 20,
 20,
 23,
 24,
 18,
 23,
 23,
 16,
 24,
 19,
 21,
 23,
 25,
 20,
 21,
 23,
 21,
 22,
 19,
 22,
 21,
 22,
 23,
 22,
 19,
 23,
 21,
 21,
 18,
 24,
 26,
 21,
 21,
 20,
 23,
 18,
 22,
 22,
 21,
 18,
 20,
 20,
 21,
 21,
 26,
 20,
 18,
 24,
 20,
 24,
 22,
 23,
 25,
 24,
 25,
 22,
 21,
 24,
 25,
 17,
 22,
 23,
 23]

In [430]:
lul = count_clusters(rnd_fil_2_21)

In [451]:
lul[0]

[133,
 141,
 154,
 123,
 131,
 135,
 133,
 122,
 122,
 117,
 123,
 140,
 133,
 131,
 122,
 125,
 125,
 122,
 132,
 124,
 132,
 115,
 134,
 131,
 121,
 124,
 118,
 116,
 134,
 122,
 120,
 127,
 117,
 110,
 118,
 135,
 125,
 119,
 132,
 121,
 143,
 123,
 120,
 121,
 141,
 116,
 126,
 121,
 136,
 112,
 128,
 118,
 123,
 135,
 134,
 112,
 112,
 134,
 113,
 142,
 128,
 109,
 121,
 117,
 141,
 111,
 118,
 135,
 111,
 126,
 132,
 129,
 137,
 115,
 132,
 127,
 120,
 128,
 136,
 144,
 117,
 113,
 123,
 128,
 128,
 118,
 117,
 130,
 125,
 131,
 139,
 117,
 132,
 127,
 122,
 112,
 127,
 120,
 135,
 111]

In [449]:
lel = count_clusters(gem_fil_2_21)

In [453]:
lel[2]

[3,
 23,
 0,
 20,
 3,
 11,
 2,
 22,
 24,
 20,
 17,
 0,
 1,
 0,
 21,
 15,
 25,
 24,
 2,
 0,
 24,
 26,
 20,
 9,
 19,
 0,
 2,
 23,
 0,
 17,
 22,
 23,
 16,
 6,
 27,
 23,
 8,
 3,
 20,
 0,
 0,
 0,
 12,
 1,
 12,
 22,
 28,
 12,
 0,
 0,
 23,
 20,
 22,
 23,
 0,
 18,
 0,
 2,
 1,
 19,
 20,
 21,
 23,
 18,
 26,
 10,
 24,
 2,
 1,
 0,
 3,
 14,
 16,
 22,
 0,
 9,
 0,
 1,
 21,
 0,
 2,
 26,
 18,
 21,
 0,
 27,
 19,
 22,
 0,
 0,
 19,
 7,
 24,
 27,
 23,
 22,
 0,
 22,
 11,
 0]

In [422]:
np.argwhere(syn_fil_2_21[0][2].flatten())

array([[  6],
       [ 20],
       [ 21],
       [ 22],
       [ 39],
       [ 43],
       [ 49],
       [ 53],
       [ 56],
       [ 59],
       [ 72],
       [ 80],
       [ 84],
       [ 94],
       [105],
       [110],
       [111],
       [127],
       [131],
       [132],
       [140],
       [162],
       [180],
       [188],
       [189],
       [204],
       [222],
       [227],
       [231],
       [232],
       [236],
       [246],
       [249],
       [255],
       [259],
       [261],
       [264],
       [267],
       [279],
       [289],
       [302],
       [320],
       [331],
       [336],
       [343],
       [357],
       [370],
       [376],
       [387],
       [397],
       [410],
       [435],
       [441],
       [445],
       [447],
       [449],
       [460],
       [477],
       [479],
       [480],
       [481],
       [483],
       [488],
       [490],
       [491],
       [501],
       [502],
       [519],
       [527],
       [532],
       [537],
      