
A tool to filter out all transformations related to a given node (#3208)

Summary:
Implement a tool that
- filters the compilation log for information about a specific node, and
- returns all transformations related to the given node.

Pull Request resolved: #3208

Test Plan:
Tested with a real network, googlenet_v1_slim, filtering on the node 'InceptionV1_InceptionV1_Mixed_5b_Branch_0_Conv2d_0a_1x1_BatchNorm_batchnorm_mul21'. The result is a dot graph containing all the related transformations:

`python3 ../glow/utils/log_parser.py -f googlenet_v1_slim/googlenet_v1_slim.onnx_compile.log`

`python3 ../glow/utils/compilation_filter.py --db-file compilation_log_db.sqlite --filter-target InceptionV1_InceptionV1_Mixed_5b_Branch_0_Conv2d_0a_1x1_BatchNorm_batchnorm_mul21`

`dot -Tpdf transformations.dot > trans.pdf`

<img width="490" alt="Screen Shot 2019-07-04 at 1 24 48 AM" src="https://user-images.githubusercontent.com/8838608/60651195-8decc500-9dfa-11e9-8686-ca6eb1830d62.png">

The yellow rectangle is the direct transformation that created/replaced the given node. The blue rectangles are all the related transformations.

=================================================================

Update: a Q&A that explains the design of this PR.

> What is the reason for using SQLite?

Using SQLite lets us rely on a well-optimized database engine for efficient queries, which will matter as the log gets bigger and as we add more kinds of queries in the future.

> What is the general flow?

Run `log_parser.py` to process the raw compilation log, create a database file, and store the processed data items into it (node transformations right now, but there can be more in the future). In detail, it creates a table `Log_Transformation (trans_id INTEGER, operation_type VARCHAR, node_name VARCHAR, node_kind VARCHAR, scope_name VARCHAR)` and stores all node transformations into this table.

For example, for a transformation with id 100 that is 'A lowered into B, C', the inserted rows would be `100, REMOVE, A, A_kind, lower`, `100, ADD, B, B_kind, lower`, `100, ADD, C, C_kind, lower`.
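For illustration, a minimal sqlite3 sketch of this flow, using the schema and example rows above (the real `log_parser.py` does more than this):

```python
import sqlite3

# Create the database file and the Log_Transformation table.
conn = sqlite3.connect("compilation_log_db.sqlite")
conn.execute("""CREATE TABLE IF NOT EXISTS Log_Transformation (
    trans_id INTEGER,
    operation_type VARCHAR,
    node_name VARCHAR,
    node_kind VARCHAR,
    scope_name VARCHAR)""")

# The example above: transformation 100 is 'A lowered into B, C'.
rows = [
    (100, "REMOVE", "A", "A_kind", "lower"),
    (100, "ADD", "B", "B_kind", "lower"),
    (100, "ADD", "C", "C_kind", "lower"),
]
conn.executemany(
    "INSERT INTO Log_Transformation VALUES (?, ?, ?, ?, ?)", rows)
conn.commit()
```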

Run `compilation_filter.py` to query all transformations related to a given node name: given a node, it queries the database to find all transformations that are directly or indirectly related to it.

>  Do you create a new DB file each time you parse/query the log as opposed to creating it once and then running multiple queries against it? What are the pros and cons?

I create a new DB file once, when parsing a new compilation log file with `log_parser.py`, and then execute queries on the created database file with `compilation_filter.py`. The pro is that I don't need to parse the raw log file every time I execute queries. The con is that whenever the log file is updated we need to recreate the database file.

> Do you insert all the information from the log into this DB or just some parts of it?

Right now, I only store node transformation info in the database; an example is given above. But we can add more tables that store other valuable information in the future.

> How do you handle queries for finding all ancestors of a given node? Do you issue multiple queries? Or maybe use some clever tricks like, e.g., the transitive_closure SQLite extension or WITH RECURSIVE?

I issue multiple queries. The steps are as follows (a Python sketch of this loop is given after the list):
1. Find all trans_ids that involve the provided node_name, and store them in a list:
`SELECT trans_id FROM Log_Transformation WHERE node_name = '{nodeName}' GROUP BY trans_id`
2. Find all node_names that occur in the trans_ids_list:
`SELECT node_name FROM Log_Transformation WHERE trans_id IN {trans_ids_list} GROUP BY node_name`
3. Find all trans_ids that involve any node_name in the node_name_list from step 2:
`SELECT trans_id FROM Log_Transformation WHERE node_name IN {node_name_list} GROUP BY trans_id`
4. Repeat steps 2 and 3 until the trans_ids_list doesn't change any more.
5. Return the trans_ids_list.
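For illustration, here is a minimal Python sketch of this fixed-point loop against the `Log_Transformation` table. It mirrors `find_all_related_transformation` from the diff below, but written iteratively and with parameterized queries; the function name and the empty-seed guard are my own:

```python
import sqlite3
from typing import List

def all_related_trans_ids(cursor: sqlite3.Cursor, node_name: str) -> List[int]:
    """Steps 1-5 above, written as an iterative fixed-point loop."""
    # Step 1: transformations that directly touch the given node.
    cursor.execute(
        "SELECT trans_id FROM Log_Transformation"
        " WHERE node_name = ? GROUP BY trans_id", (node_name,))
    trans_ids = sorted(row[0] for row in cursor.fetchall())

    while trans_ids:
        # Step 2: every node name occurring in the current transformations.
        marks = ", ".join("?" * len(trans_ids))
        cursor.execute(
            "SELECT node_name FROM Log_Transformation"
            f" WHERE trans_id IN ({marks}) GROUP BY node_name", trans_ids)
        node_names = [row[0] for row in cursor.fetchall()]

        # Step 3: every transformation touching any of those node names.
        marks = ", ".join("?" * len(node_names))
        cursor.execute(
            "SELECT trans_id FROM Log_Transformation"
            f" WHERE node_name IN ({marks}) GROUP BY trans_id", node_names)
        new_trans_ids = sorted(row[0] for row in cursor.fetchall())

        # Step 4: stop once the set of transformation IDs no longer grows.
        if new_trans_ids == trans_ids:
            break
        trans_ids = new_trans_ids

    # Step 5: the full set of related transformation IDs.
    return trans_ids
```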

Well, I did try a CTE, i.e. WITH RECURSIVE, but I found it really hard to express the same logic. Also, I'm not sure whether such a recursive query is actually translated into multiple SELECT clauses inside the DBMS, which might end up no more efficient than executing multiple queries.
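For the record, here is one possible way the closure might be written as a recursive CTE. This is an untested sketch (reusing the `cursor` and `node_name` from the sketch above), not what the PR ships:

```python
# Untested sketch: the whole closure as a single recursive CTE.
# 'related' grows by every transformation that shares a node_name with one
# already in the set; UNION deduplicates, so the recursion terminates.
cursor.execute("""
    WITH RECURSIVE related(trans_id) AS (
        SELECT trans_id FROM Log_Transformation WHERE node_name = ?
        UNION
        SELECT lt2.trans_id
        FROM related
        JOIN Log_Transformation lt1 ON lt1.trans_id = related.trans_id
        JOIN Log_Transformation lt2 ON lt2.node_name = lt1.node_name
    )
    SELECT trans_id FROM related
""", (node_name,))
trans_ids = [row[0] for row in cursor.fetchall()]
```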

> Do you remove the DB file at the end of processing?

No, it stays there. A DB file only gets removed when `log_parser.py` tries to create a new DB with the same name.

Differential Revision: D16171145

Pulled By: ZchiPitt

fbshipit-source-id: edea68b37d4ebd8546fea3fac7adad5a64cb1a0e
ZchiPitt authored and facebook-github-bot committed Jul 9, 2019
1 parent 4253762 commit 67b2454a6ff4200f40df7bfd1c26de9500116067
Showing with 415 additions and 20 deletions.
  1. +260 −0 utils/compilation_filter.py
  2. +155 −20 utils/log_parser.py
@@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sqlite3
from typing import (
    Dict,
    List,
)

# A list of all filtered transformations.
TRANS_LIST: List['Transformation'] = []

# Mapping between added nodes and the transformation that adds these nodes.
NODES_ADDING_MAP: Dict[str, 'Transformation'] = {}


class Transformation:
    """A class that represents a node transformation, e.g. lower, fold, etc.

    Public attributes:
        addedNodes_: List[str]. Nodes added by this transformation.
        removedNodes_: List[str]. Nodes removed by this transformation.
        ancestors_: List['Transformation']. The ancestor transformations of the
            current transformation.
        scopeName_: str. The scope of the current transformation.
        transID_: str. The internal transformation ID in the database.
        isDirectTrans_: bool. Whether this transformation directly
            created/replaced the node name that was passed to this script.
    """

    def __init__(self, transID: str):
        self.addedNodes_: List[str] = []
        self.removedNodes_: List[str] = []
        self.ancestors_: List['Transformation'] = []
        self.scopeName_: str = ''
        self.transID_: str = transID
        self.isDirectTrans_: bool = False

    def appendAddedNode(self, nodeName: str) -> None:
        """Append a node to the added nodes of this transformation."""

        self.addedNodes_.append(nodeName)

    def appendRemovedNode(self, nodeName: str) -> None:
        """Append a node to the removed nodes of this transformation."""

        self.removedNodes_.append(nodeName)

    def addAncestor(self, ancestor: 'Transformation') -> None:
        """Add an ancestor of this transformation."""

        self.ancestors_.append(ancestor)


class DottyPrinter:
    """A class for generating the dotty graph file."""

    def __init__(self, transList: List[Transformation]):
        self.transList_ = transList
        self.vertices_ = []
        self.edges_ = []

    def get_color(self, isDirectTrans: bool) -> str:
        """Returns the color for the given node."""

        if isDirectTrans:
            return "Yellow2"
        else:
            return "AliceBlue"

    def dump_label(self, tran: Transformation) -> str:
        """Returns the string for the label of the given transformation."""

        labelStr = rf"""{{ {{SCOPE:\l{tran.scopeName_} }}|{{ORIGINAL NODES:\l\l"""
        for rstr in tran.removedNodes_:
            labelStr += rf"""{rstr}\l\l"""
        labelStr += rf"}}| {{REPLACED BY:\l\l"
        for astr in tran.addedNodes_:
            labelStr += rf"""{astr}\l\l"""
        labelStr += f"}} }}"
        return labelStr

    def dump_node(self, tran: Transformation) -> None:
        """Generates the dotty information for the given transformation."""

        if not tran:
            return

        tranStr = f"""v{tran.transID_}[\n
        \tlabel = \"{self.dump_label(tran)}\"\n
        \tshape = \"record\"\n
        \tstyle=\"filled,rounded\"\n
        \tfillcolor={self.get_color(tran.isDirectTrans_)}\n
        penwidth = 2];\n"""
        self.vertices_.append(tranStr)

    def visit_nodes(self) -> None:
        """Visits all transformations and dumps the dotty information for each one."""

        for tran in self.transList_:
            self.dump_node(tran)

    def visit_edges(self) -> None:
        """Visits all edges and dumps the dotty information for each edge."""

        for tran in self.transList_:
            for anc in tran.ancestors_:
                edgeStr = f"v{anc.transID_} -> v{tran.transID_}"
                self.edges_.append(edgeStr)

    def dump_graph(self) -> None:
        """Visits the graph and generates the dotty information."""

        self.visit_nodes()
        self.visit_edges()
        with open("transformations.dot", "w") as f:
            f.write("digraph DAG {\n\trankdir=TB;\n")
            for v in self.vertices_:
                f.write(f"{v}\n")
            for e in self.edges_:
                f.write(f"{e};\n")
            f.write("}")


def dump_dotty_DAG():
    """A helper function to dump the dotty file."""

    dotty = DottyPrinter(TRANS_LIST)
    dotty.dump_graph()


def parse_args():
    """Parse the arguments of this script."""

    parser = argparse.ArgumentParser(
        description="Filter compilation and optimization.")
    parser.add_argument("--db-file")
    parser.add_argument("--filter-target")
    options = parser.parse_args()

    assert options.db_file and options.filter_target, \
        "Please specify db file and filter target."
    return options.db_file, options.filter_target


def init_db(sqliteFile: str) -> sqlite3.Connection:
    """Initialize a sqlite3 database connection."""

    assert os.path.isfile(sqliteFile)

    # Connect to the database file.
    return sqlite3.connect(sqliteFile)


def find_all_related_transformation(
        cursor: sqlite3.Cursor,
        transIDs: List[str]):
    """A recursive function that finds all transformations related to the
    given list of transformation IDs in the database.

    Args:
        cursor: sqlite3.Cursor. Cursor of the current sqlite3 database connection.
        transIDs: List[str]. A list of transformation IDs.
    """

    # All node names that appear in the given transformations.
    transQueryStr = "(" + ', '.join(transIDs) + ')'
    cursor.execute(f"""
        SELECT node_name
        FROM Log_Transformation
        WHERE trans_id in {transQueryStr}
        GROUP BY node_name
        """)
    rows = cursor.fetchall()
    nodesList = ["'" + r[0] + "'" for r in rows]

    # All transformations that touch any of those node names.
    transQueryStr = "(" + ', '.join(nodesList) + ')'
    cursor.execute(f"""
        SELECT trans_id
        FROM Log_Transformation
        WHERE node_name in {transQueryStr}
        GROUP BY trans_id
        """)
    rows = cursor.fetchall()
    newTransIDs = [str(r[0]) for r in rows]

    # Recurse until the set of transformation IDs reaches a fixed point.
    if sorted(newTransIDs) != sorted(transIDs):
        transIDs = find_all_related_transformation(cursor, newTransIDs)
    return transIDs


def filter_node_transformation(nodeName: str, conn: sqlite3.Connection):
    """Filters out all node transformations that are related to the given node.

    Args:
        nodeName: str. The node name that is passed to this script.
        conn: sqlite3.Connection. A sqlite3 database connection.
    """

    # Transformations that directly created/replaced the given node.
    cursor = conn.cursor()
    cursor.execute("""
        SELECT trans_id
        FROM Log_Transformation
        WHERE node_name = ?
        GROUP BY trans_id
        """, (nodeName,))
    rows = cursor.fetchall()

    directTransIDs = [str(r[0]) for r in rows]

    transIDs = find_all_related_transformation(cursor, directTransIDs)

    for tid in transIDs:
        cursor.execute("""
            SELECT *
            FROM Log_Transformation
            WHERE trans_id = ?
            """, (tid, ))
        rows = cursor.fetchall()
        if len(rows):
            tran = Transformation(tid)
            if tid in directTransIDs:
                tran.isDirectTrans_ = True
            TRANS_LIST.append(tran)
            tran.scopeName_ = rows[0][4].replace(
                "glow::", "").replace(
                "->", r" --\> ")
            for r in rows:
                opr_type, name, kind = r[1:4]
                if opr_type == 'ADD':
                    nodeKindAndName = kind + r" \l" + name
                    tran.appendAddedNode(nodeKindAndName)
                    NODES_ADDING_MAP[nodeKindAndName] = tran
                elif opr_type == 'REMOVE':
                    nodeKindAndName = kind + r" \l" + name
                    tran.appendRemovedNode(nodeKindAndName)
                    # Link this transformation to the one that added the node.
                    if nodeKindAndName in NODES_ADDING_MAP:
                        tran.addAncestor(NODES_ADDING_MAP[nodeKindAndName])

    dump_dotty_DAG()
    conn.commit()


def main():
    dbFile, filterTarget = parse_args()
    with init_db(dbFile) as conn:
        filter_node_transformation(filterTarget, conn)


if __name__ == "__main__":
    main()
