Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

What should I do if I want to change the model? #8

Closed
shxiangyan opened this issue May 6, 2020 · 2 comments
Closed

What should I do if I want to change the model? #8

shxiangyan opened this issue May 6, 2020 · 2 comments

Comments

@shxiangyan
Copy link

No description provided.

@xiaohk
Copy link
Member

xiaohk commented May 6, 2020

Hello, this is related to #2.

Currently, CNN Explainer only supports the Tiny-VGG architecture described in our manuscript. If you want to use a different CNN architecture, you will need to modify the code. If you want to use another pre-trained Tiny-VGG model, see the following functions in CNN Explainer, which show how we use TensorFlow.js to load the model:

// Usage: first load the converted model JSON, then build the CNN node/link
// structure for the chosen input image using that model.
// NOTE(review): PUBLIC_URL looks like a placeholder for the site's base
// path, and selectedImage is presumably an image filename — confirm.
model = await loadTrainedModel('PUBLIC_URL/assets/data/model.json');
cnn = await constructCNN(`PUBLIC_URL/assets/img/${selectedImage}`, model);

/**
 * Load a trained layers model through tf.js.
 *
 * This is a thin pass-through: whatever `tf.loadLayersModel` returns for
 * the given file is handed back to the caller untouched.
 *
 * @param {string} modelFile Filename of the model json file converted
 *  through tensorflowjs.py.
 * @returns {*} The value returned by `tf.loadLayersModel(modelFile)`.
 */
export const loadTrainedModel = (modelFile) => tf.loadLayersModel(modelFile);

/**
 * Pick the nodeType constant for a layer by looking for a known keyword in
 * the layer's name. Keywords are checked in a fixed order, so the first
 * match wins.
 *
 * @param {string} layerName Name of the tf.js layer.
 * @returns {*} The matching nodeType value, or undefined (after logging)
 *  when no keyword matches.
 */
const getNodeTypeFromLayerName = (layerName) => {
  if (layerName.includes('conv')) {
    return nodeType.CONV;
  }
  if (layerName.includes('pool')) {
    return nodeType.POOL;
  }
  if (layerName.includes('relu')) {
    return nodeType.RELU;
  }
  if (layerName.includes('output')) {
    return nodeType.FC;
  }
  if (layerName.includes('flatten')) {
    return nodeType.FLATTEN;
  }
  console.log('Find unknown type');
  return undefined;
};

/**
 * Create a Link from `source` to `dest` and register it on both ends.
 *
 * @param {Node} source Node in the previous layer.
 * @param {Node} dest Node in the current layer.
 * @param {*} weight Value stored on the link (kernel slice, scalar weight,
 *  [row, col] pair for flatten links, or null).
 */
const connectNodes = (source, dest, weight) => {
  const link = new Link(source, dest, weight);
  source.outputLinks.push(link);
  dest.inputLinks.push(link);
};

/**
 * Construct a CNN with given extracted outputs from every layer.
 *
 * The result is an array of layers; entry 0 is the input layer (one node per
 * image channel) and entry l + 1 holds the nodes built for model.layers[l].
 *
 * @param {Tensor[]} allOutputs Output tensor of every layer. After
 *  squeeze().arraySync(), allOutputs[i][j] is the output for layer i node j.
 * @param {Model} model Loaded tf.js model.
 * @param {Tensor} inputImageTensor Loaded input image tensor.
 * @returns {Node[][]} Layers of linked nodes, input layer first.
 */
const constructCNNFromOutputs = (allOutputs, model, inputImageTensor) => {
  const cnn = [];

  // Add the first layer (input layer). The image is transposed so that each
  // channel becomes the output of one input node.
  const inputShape = model.layers[0].batchInputShape.slice(1);
  const inputImageArray = inputImageTensor.transpose([2, 0, 1]).arraySync();
  const inputLayer = [];
  for (let i = 0; i < inputShape[2]; i++) {
    inputLayer.push(
      new Node('input', i, nodeType.INPUT, 0, inputImageArray[i]));
  }
  cnn.push(inputLayer);

  for (let l = 0; l < model.layers.length; l++) {
    const layer = model.layers[l];
    // Get the current output as nested arrays
    const outputs = allOutputs[l].squeeze().arraySync();
    // One layer is pushed per iteration, so the last entry of cnn is always
    // the layer that feeds the one being built.
    const preLayer = cnn[cnn.length - 1];
    const curLayerType = getNodeTypeFromLayerName(layer.name);
    const curLayerNodes = [];

    // Construct this layer based on its layer type
    switch (curLayerType) {
      case nodeType.CONV: {
        const biases = layer.bias.val.arraySync();
        // The new order is [output_depth, input_depth, height, width]
        const weights = layer.kernel.val.transpose([3, 2, 0, 1]).arraySync();
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, biases[i],
            outputs[i]);
          // CONV layers have weights in links: every node is connected to
          // all nodes of the previous layer.
          for (let j = 0; j < preLayer.length; j++) {
            connectNodes(preLayer[j], node, weights[i][j]);
          }
          curLayerNodes.push(node);
        }
        break;
      }
      case nodeType.FC: {
        const biases = layer.bias.val.arraySync();
        // The new order is [output_depth, input_depth]
        const weights = layer.kernel.val.transpose([1, 0]).arraySync();
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, biases[i],
            outputs[i]);
          // FC layers have weights in links: every node is connected to all
          // nodes of the previous layer. The node output is read from
          // `outputs`, so the raw pre-activation value (logit) is recomputed
          // here and stored separately.
          let curLogit = 0;
          for (let j = 0; j < preLayer.length; j++) {
            const preNode = preLayer[j];
            connectNodes(preNode, node, weights[i][j]);
            curLogit += preNode.output * weights[i][j];
          }
          curLogit += biases[i];
          node.logit = curLogit;
          curLayerNodes.push(node);
        }
        // The logits above needed the previous layer in TF order; now that
        // they are computed, reorder the previous layer by real index.
        preLayer.sort((a, b) => a.realIndex - b.realIndex);
        break;
      }
      case nodeType.RELU:
      case nodeType.POOL: {
        // RELU and POOL have no bias nor weight
        const bias = 0;
        const weight = null;
        for (let i = 0; i < outputs.length; i++) {
          const node = new Node(layer.name, i, curLayerType, bias, outputs[i]);
          // Links are one-to-one with the previous layer
          connectNodes(preLayer[i], node, weight);
          curLayerNodes.push(node);
        }
        break;
      }
      case nodeType.FLATTEN: {
        // Flatten layer has no bias nor weights.
        const bias = 0;
        if (!(outputs.length > 0)) {
          break;
        }
        // Geometry of the previous layer is the same for every flatten node,
        // so it is read once. Height and width are tracked separately so
        // that non-square feature maps are indexed correctly; when the
        // previous output is not 2D the width falls back to its length.
        const preNodeNum = preLayer.length;
        const firstMap = preLayer[0].output;
        const preNodeHeight = firstMap.length;
        const preNodeWidth = Array.isArray(firstMap[0]) ?
          firstMap[0].length : preNodeHeight;
        const preNodeSize = preNodeHeight * preNodeWidth;

        for (let i = 0; i < outputs.length; i++) {
          // Flat index i is decoded as row -> column -> channel, with the
          // channel varying fastest.
          const preNodeIndex = i % preNodeNum;
          const position = Math.floor(i / preNodeNum);
          const preNodeRow = Math.floor(position / preNodeWidth);
          const preNodeCol = position % preNodeWidth;

          const node = new Node(layer.name, i, curLayerType, bias,
            outputs[i]);
          // TF uses the (i) index for computation, but the display order is
          // channel -> row -> column. Nodes are sorted by this real index
          // after the logits are computed in the output layer.
          node.realIndex = preNodeIndex * preNodeSize +
            preNodeRow * preNodeWidth + preNodeCol;
          // Links are multiple-to-one. The link weight is a dummy value that
          // stores the source entry in the previous node as (row, column).
          connectNodes(preLayer[preNodeIndex], node,
            [preNodeRow, preNodeCol]);
          curLayerNodes.push(node);
        }
        // Keep the flatten layer in TF index order for now
        curLayerNodes.sort((a, b) => a.index - b.index);
        break;
      }
      default:
        console.error('Encounter unknown layer type');
        break;
    }

    // Add current layer to the NN
    cnn.push(curLayerNodes);
  }

  return cnn;
};

/**
 * Build the CNN node structure for one input image and a loaded model.
 *
 * The image is loaded, wrapped into a batch of one, and pushed through the
 * model layer by layer so that every intermediate output can be captured.
 *
 * @param {string} inputImageFile filename of input image.
 * @param {Model} model Loaded tf.js model.
 * @returns {Promise<*>} Resolves to whatever constructCNNFromOutputs builds
 *  from the captured outputs, the model and the input image tensor.
 */
export const constructCNN = async (inputImageFile, model) => {
  const inputImageTensor = await getInputImageArray(inputImageFile, true);

  // The model expects a batch, so wrap the single image into one.
  let layerInput = tf.stack([inputImageTensor]);
  const layerOutputs = [];

  // Apply the layers sequentially, feeding each layer's result into the
  // next one and recording every intermediate result along the way.
  for (const layer of model.layers) {
    const activation = layer.apply(layerInput);
    // The batch holds a single element, so squeeze() is applied before
    // recording. Rank-3 results are additionally reordered into CHW.
    const squeezed = activation.squeeze();
    layerOutputs.push(
      squeezed.shape.length === 3 ? squeezed.transpose([2, 0, 1]) : squeezed);
    // The next layer consumes the un-squeezed result.
    layerInput = activation;
  }

  return constructCNNFromOutputs(layerOutputs, model, inputImageTensor);
};

Let us know if you have more questions. I will close the issue for now :P

@EricCousineau-TRI
Copy link

EricCousineau-TRI commented May 9, 2020

As an FYI, docs for converting to TensorFlow.js models:
https://www.tensorflow.org/js/tutorials#convert_pretrained_models_to_tensorflowjs
Relates #12.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants