### Mount Google Drive
### This cell is only needed on Google Colaboratory.

In [0]:
# Mount Google Drive at /gdrive (Colab only) and change into the DNNTopology
# folder; relative paths used later in the notebook (e.g. "./models/") are
# resolved against this working directory.
from google.colab import drive
drive.mount('/gdrive')
%cd /gdrive
%cd /gdrive/My\ Drive/DNNTopology

### Parameter settings
#### * Adjust these values to match the trained model *

In [0]:
# Architecture of the trained model: sizes of the two hidden Dense layers and
# of the output layer, plus the class labels it was trained on.
firstLayerSize = 300
secondLayerSize = 100
outputNeuron = 10
class_list = list(range(10))

# Indices into model.layers of the three weight layers (MNIST model).
layerNumber, layerNumber1, layerNumber2 = 0, 1, 2

# Indices for the CIFAR-10 model instead:
# layerNumber, layerNumber1, layerNumber2 = 13, 15, 16

# Model file, named after the architecture and the class list, e.g.
# ./models/firstLayerSize300secondLayerSize100outputNeuron10class[0, 1, 2, 3, 4, 5, 6, 7, 8, 9].h5
modelDir = "./models/"
model_name = (
    f"{modelDir}firstLayerSize{firstLayerSize}secondLayerSize{secondLayerSize}"
    f"outputNeuron{outputNeuron}class{class_list}.h5"
)

# Sample execution with the bundled 512-512 model instead:
# model_name = modelDir + "firstLayerSize512secondLayerSize512outputNeuron10class[0, 1, 2, 3, 4, 5, 6, 7, 8, 9].h5"
# firstLayerSize = secondLayerSize = 512
# layerNumber, layerNumber1, layerNumber2 = 13, 15, 16

# Directory the pickled simplex lists are written to.
simplexDir = "simplexes"

# Filtration numbers to process. Splitting this list across several
# processes lets the filtrations run in parallel.
filList = range(1, 65)

In [0]:
from keras import layers;
from keras import models;
import numpy as np
import copy
import itertools 
import pickle

### Load model

In [0]:
# Load the trained model without compiling it: this notebook only reads
# layer weights, so no optimizer/loss setup is needed.
model = models.load_model(
    model_name,
    compile=False,
)

In [0]:
# Print the model's layer table (in model.layers order) to confirm which
# indices layerNumber, layerNumber1 and layerNumber2 should point at.
model.summary()

### Normalize weight matrix

In [0]:
# Fetch (kernel, bias) for the three selected layers, in the order
# layerNumber, layerNumber1, layerNumber2.
(weight, bias), (weight1, bias1), (weight2, bias2) = (
    model.layers[index].get_weights()
    for index in (layerNumber, layerNumber1, layerNumber2)
)

In [0]:
# Displayed for inspection only. NOTE(review): `weight` (layer layerNumber)
# is not referenced by the normalization cells below; only weight1 and
# weight2 feed the relevance matrix.
weight.shape

In [0]:
# The normalization indexes weight1[i][j] with i < firstLayerSize and
# j < secondLayerSize, so this shape should be (firstLayerSize, secondLayerSize).
weight1.shape

In [0]:
# The normalization indexes weight2[i][j] with i < secondLayerSize and
# j < outputNeuron, so this shape should be (secondLayerSize, outputNeuron).
weight2.shape

In [0]:
# One node per neuron, laid out as: output layer, second hidden layer,
# first hidden layer. Start from the identity so each node has relevance 1
# to itself and 0 to every other node.
size = firstLayerSize + secondLayerSize + outputNeuron
relevance = np.eye(size)

In [0]:
# Relevance from second-hidden-layer neuron i (row i + outputNeuron) to output
# neuron j (column j): the positive part of weight2[i][j], divided by the sum
# of the positive parts in column j so each column sums to 1.
weight2Plus = weight2 * (weight2 > 0)  # negative weights count as 0; computed once
for j in range(0, outputNeuron):
    normalizeFactor = weight2Plus[:secondLayerSize, j].sum()
    if normalizeFactor == 0:
        # No positive incoming weight for this output neuron: leave the column
        # as initialized (zeros off the diagonal) instead of writing 0/0 = NaN.
        continue
    relevance[outputNeuron:outputNeuron + secondLayerSize, j] = (
        weight2Plus[:secondLayerSize, j] / normalizeFactor
    )

In [0]:
# Relevance from first-hidden-layer neuron i (row i + outputNeuron +
# secondLayerSize) to second-hidden-layer neuron j (column j + outputNeuron):
# the positive part of weight1[i][j], divided by the sum of the positive parts
# in column j so each column sums to 1.
weight1Plus = weight1 * (weight1 > 0)  # negative weights count as 0; computed once
firstLayerOffset = outputNeuron + secondLayerSize  # first row of the first hidden layer
for j in range(0, secondLayerSize):
    normalizeFactor = weight1Plus[:firstLayerSize, j].sum()
    if normalizeFactor == 0:
        # No positive incoming weight for this neuron: leave the column as
        # initialized (zeros off the diagonal) instead of writing 0/0 = NaN.
        continue
    relevance[firstLayerOffset:firstLayerOffset + firstLayerSize, j + outputNeuron] = (
        weight1Plus[:firstLayerSize, j] / normalizeFactor
    )

### Construct simplices

In [0]:
# getSimplex / registerSimplexOutput read the relevance matrix through the
# global name `matrix`. This is an alias, not a copy: later in-place changes
# to `relevance` are visible through `matrix` too.
matrix = relevance

In [0]:
def comb(sequence):
    """Return every non-empty subset of ``sequence`` as a list.

    Subsets are grouped by size from 1 to ``len(sequence)``. Within one size
    they follow ``itertools.combinations`` order, and each subset keeps the
    element order of ``sequence``.
    """
    return [
        list(subset)
        for size in range(1, len(sequence) + 1)
        for subset in itertools.combinations(sequence, size)
    ]

In [0]:
def getSimplex(matrix, pointSequence, threshold):
    """Return every simplex generated by a path whose relevance meets ``threshold``.

    The relevance of ``pointSequence = [p0, p1, ..., pk]`` is the product
    ``matrix[p0][p0] * matrix[p0][p1] * ... * matrix[p(k-1)][pk]``, multiplied
    left to right. If it is below ``threshold`` (or NaN), nothing is returned.
    Otherwise every non-empty subset of the path is a simplex. The path is then
    extended by each point ``i != pk`` with ``matrix[pk][i] > 0``, and the
    simplices of those extended paths are added recursively.

    Args:
        matrix: square 2-D indexable of edge weights, read as ``matrix[a][b]``.
        pointSequence: non-empty list of point indices (the current path).
        threshold: minimum path relevance for the path to be kept.

    Returns:
        Duplicate-free list of simplices. Each simplex is a list of point
        indices in path order. The list is sorted so the output is
        deterministic.
    """
    # Relevance of the path from its start point, multiplied in path order.
    relevance = 1.0
    previousPoint = pointSequence[0]
    for point in pointSequence:
        relevance = relevance * matrix[previousPoint][point]
        previousPoint = point
    # Written as `not >=` so that a NaN relevance is rejected.
    if not relevance >= threshold:
        return []

    # Every non-empty subset (face) of the current path.
    simplices = {
        subset
        for size in range(1, len(pointSequence) + 1)
        for subset in itertools.combinations(pointSequence, size)
    }
    # Extend the path along each outgoing edge and merge the results. A
    # recursive result is already closed under taking subsets, so it is merged
    # as-is rather than re-enumerating the subsets of every returned simplex.
    lastPoint = pointSequence[-1]
    for i in range(len(matrix)):
        if i != lastPoint and matrix[lastPoint][i] > 0:
            extended = getSimplex(matrix, pointSequence + [i], threshold)
            simplices.update(map(tuple, extended))
    return [list(simplex) for simplex in sorted(simplices)]

In [0]:
def registerSimplexOutput( filList ):
    """Compute the simplices for each filtration level and pickle them to disk.

    For each 1-based filtration number ``fil`` in ``filList``:

    1. Look up the relevance threshold ``r[fil - 1]``.
    2. Run ``getSimplex`` from every point of the module-level ``matrix``.
    3. Pickle the concatenated results to ``simplexDir + "/Simplex" + str(fil)``.

    Progress is printed as a comma-separated list of filtration numbers.

    Args:
        filList: iterable of filtration numbers in ``1..len(r)`` (1..64).
            Splitting this list across processes allows parallel execution.

    Raises:
        IndexError: if a filtration number is outside ``1..len(r)``.
    """
    import os

    matrixSize = len(matrix)
    # Thresholds for filtration 1..64, decreasing from 1.0 to 1.0e-7. They are
    # kept as literals so each threshold is exactly this float value.
    r = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2,
         1.0e-1, 0.9e-1, 0.8e-1, 0.7e-1, 0.6e-1, 0.5e-1, 0.4e-1, 0.3e-1, 0.2e-1,
         1.0e-2, 0.9e-2, 0.8e-2, 0.7e-2, 0.6e-2, 0.5e-2, 0.4e-2, 0.3e-2, 0.2e-2,
         1.0e-3, 0.9e-3, 0.8e-3, 0.7e-3, 0.6e-3, 0.5e-3, 0.4e-3, 0.3e-3, 0.2e-3,
         1.0e-4, 0.9e-4, 0.8e-4, 0.7e-4, 0.6e-4, 0.5e-4, 0.4e-4, 0.3e-4, 0.2e-4,
         1.0e-5, 0.9e-5, 0.8e-5, 0.7e-5, 0.6e-5, 0.5e-5, 0.4e-5, 0.3e-5, 0.2e-5,
         1.0e-6, 0.9e-6, 0.8e-6, 0.7e-6, 0.6e-6, 0.5e-6, 0.4e-6, 0.3e-6, 0.2e-6,
         1.0e-7]
    # Make sure the output directory exists before writing any pickle.
    os.makedirs(simplexDir, exist_ok=True)
    print("Filtration: ", end = "")
    for fil in filList:
        # fil is 1-based. Reject 0 and negatives explicitly, since r[fil - 1]
        # would otherwise silently read thresholds from the end of the list.
        if not 1 <= fil <= len(r):
            raise IndexError(
                "filtration number %r is outside 1..%d" % (fil, len(r)))
        number = r[fil - 1]
        filename =  simplexDir + "/Simplex" + str(fil)
        print( str(fil) + ", ", end="")

        saveSimplex = []
        for startPoint in range(0, matrixSize):
            saveSimplex.extend(getSimplex(matrix, [startPoint], number))
        # `with` guarantees the pickle is flushed and the file handle closed.
        with open(filename, 'wb') as saveFile:
            pickle.dump(saveSimplex, saveFile)

In [0]:
# Compute and pickle the simplex list for every filtration number in filList;
# one file per filtration is written under simplexDir.
registerSimplexOutput( filList )