## Import packages

In [1]:
import os
import sys
import cv2
import json
import tqdm
import numpy as np
import pandas as pd
from time import time
from numpy.lib.stride_tricks import as_strided

import tensorflow as tf
import lucid.optvis.render as render
import lucid.modelzoo.vision_models as models
from keras.applications.inception_v3 import preprocess_input

sys.path.insert(0, '..')
from InceptionV1 import InceptionV1

Using TensorFlow backend.


In [2]:
# Configure CUDA device selection through environment variables:
# devices are ordered by PCI bus ID and only device "5" is made visible.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "5",
})

## Load class info

In [3]:
# Read the tab-separated label table
input_path = '../../../data/imagenet-labels.txt'
df = pd.read_csv(input_path, sep='\t')

# Build {synset: tfrecord_label - 1}; the label column is cast to int and
# shifted down by one before being stored.
class_info = {
    synset: int(tf_label) - 1
    for synset, tf_label in zip(df['synset'], df['tfrecord_label'])
}

## Load model

In [4]:
# Instantiate the project-local InceptionV1 wrapper. The cells below only read
# model_wrapper.BLKS (block names) and call model_wrapper.get_prev_blk().
model_wrapper = InceptionV1()
# NOTE(review): load_model() is left disabled -- presumably the network itself
# is not needed for this bookkeeping; confirm before re-enabling.
# model_wrapper.load_model()

## Load I-matrix

In [5]:
# Directory containing one 'I_<blk>.json' file per block (read in the next cell).
dir_path = '../../../data/InceptionV1/summit/I-mat'

In [6]:
# I[blk] holds the parsed JSON content of '<dir_path>/I_<blk>.json'.
I = {}
for blk in model_wrapper.BLKS:
    # Echo the block name as a simple progress trace.
    print(blk)
    with open('{}/I_{}.json'.format(dir_path, blk), 'r') as i_mat_file:
        I[blk] = json.load(i_mat_file)

mixed3a
mixed3b_3x3
mixed3b_5x5
mixed3b
mixed4a_3x3
mixed4a_5x5
mixed4a
mixed4b_3x3
mixed4b_5x5
mixed4b
mixed4c_3x3
mixed4c_5x5
mixed4c
mixed4d_3x3
mixed4d_5x5
mixed4d
mixed4e_3x3
mixed4e_5x5
mixed4e
mixed5a_3x3
mixed5a_5x5
mixed5a
mixed5b_3x3
mixed5b_5x5
mixed5b


## Load nodes by class

In [7]:
node_dir_path = '../../../data/InceptionV1/graph/node'

# Every synset known to class_info, in the dict's iteration order.
synsets = list(class_info)

# node_by_class[synset] is the parsed content of 'node-<synset>.json'.
node_by_class = {}
for synset in synsets:
    with open('{}/node-{}.json'.format(node_dir_path, synset), 'r') as node_file:
        node_by_class[synset] = json.load(node_file)

## Find important connections per class

In [8]:
# Accumulate, per class, the connection counts between neuron groups of a
# block and neuron groups of the block that precedes it.
#
# Result layout:
#   important_connection[synset][blk][group_id][prev_group_id] -> summed count
# where the count is the sum of I[blk][c][neuron_i][prev_neuron_i] over every
# (neuron, prev_neuron) pair of the two groups that appears in the I-matrix.
important_connection = {}

# Progress denominator follows the actual number of classes rather than a
# hard-coded 1000, so the printed fraction is right for any label file.
total = len(synsets)
tic = time()

for synset_i, synset in enumerate(synsets):
    # Report elapsed time every 100 classes.
    if synset_i % 100 == 0:
        toc = time()
        print('%d/%d=%.2lf, %.2lf sec' % (synset_i, total, synset_i / total, toc - tic))

    c = class_info[synset]
    class_nodes = node_by_class[synset]
    class_conn = important_connection[synset] = {}

    # Blocks are visited in reverse order of model_wrapper.BLKS.
    for blk in model_wrapper.BLKS[::-1]:

        # 'mixed3a' gets no entry of its own.
        # NOTE(review): presumably because it has no preceding block -- confirm
        # against InceptionV1.get_prev_blk.
        if blk == 'mixed3a':
            continue

        blk_conn = class_conn[blk] = {}

        for group_id in class_nodes[blk]:
            group_conn = blk_conn[group_id] = {}
            neurons = class_nodes[blk][group_id]['group']

            for neuron in neurons:
                prev_blk = model_wrapper.get_prev_blk(neuron)
                # Neuron ids are split on '-'; the second field is the index.
                neuron_i = int(neuron.split('-')[1])
                # conn is keyed by the previous neuron's index as a string.
                conn = I[blk][c][neuron_i]

                prev_groups = class_nodes[prev_blk]
                for prev_group_id in prev_groups:
                    for prev_neuron in prev_groups[prev_group_id]['group']:
                        prev_neuron_i = prev_neuron.split('-')[1]
                        if prev_neuron_i in conn:
                            # Create the entry on first hit, then add the count.
                            group_conn[prev_group_id] = (
                                group_conn.get(prev_group_id, 0) + conn[prev_neuron_i]
                            )

0/1000=0.00, 0.00 sec
100/1000=0.10, 7.26 sec
200/1000=0.20, 14.56 sec
300/1000=0.30, 21.66 sec
400/1000=0.40, 28.82 sec
500/1000=0.50, 36.24 sec
600/1000=0.60, 43.36 sec
700/1000=0.70, 50.60 sec
800/1000=0.80, 57.83 sec
900/1000=0.90, 64.84 sec


In [9]:
# Turn every accumulated count into an average by dividing it by
# (size of current group) * (size of previous group).
# NOTE: the division happens in place, so re-running this cell without first
# re-running the accumulation cell divides the values a second time.
for synset, conn_by_blk in important_connection.items():
    class_nodes = node_by_class[synset]
    for blk, conn_by_group in conn_by_blk.items():
        for group_id, prev_conn in conn_by_group.items():
            curr_size = len(class_nodes[blk][group_id]['group'])
            for prev_group_key in prev_conn:
                # The previous block name is the second '-'-separated field
                # of the previous group's key.
                prev_blk = prev_group_key.split('-')[1]
                prev_size = len(class_nodes[prev_blk][prev_group_key]['group'])
                prev_conn[prev_group_key] /= curr_size * prev_size

In [10]:
# Save important connection
output_dir_path = '../../../data/InceptionV1/graph/edge'

for synset in synsets:
    file_path = '{}/edge-{}.json'.format(output_dir_path, synset)
    with open(file_path, 'w') as f:
        json.dump(important_connection[synset], f)