Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor core of package to remat package #1

Merged
merged 22 commits into from Sep 19, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions .gitignore
Expand Up @@ -12,6 +12,27 @@ env/
.idea/
.envrc

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Gurobi
*.lp
*.mps
Expand Down
26 changes: 25 additions & 1 deletion README.md
@@ -1,2 +1,26 @@
# Optimal Checkpointing
# Optimal tensor rematerialization
`remat` is a package to compute schedules for rematerializing tensors in DFGraphs (tensor dataflow graphs).

# Installation
```bash
$ git clone https://github.com/parasj/tensor-remat.git
$ cd tensor-remat
$ pip install -e .
$ py.test
```

If you are evaluating on a GPU instance, you can install `tensorflow-gpu` as a dependency in order to enable GPU support.

# Project structure
```
.
├── experiments stores one-off plotting scripts for paper
├── remat main python package
│   ├── core
│   │   ├── graph.py DFGraph data structure
│   │   ├── schedule.py Schedule definition of concrete evaluation order
│   │   ├── solvers Package containing solver implementations
│   │   └── utils
│   └── tensorflow2 Tensorflow 2.0 integration (extraction and execution)
└── tests
```
143 changes: 53 additions & 90 deletions src/utils/graph.py → remat/core/graph.py
Expand Up @@ -26,7 +26,7 @@ def __init__(self, args: AdjList, v: Iterable[Vertex], vfwd_map: Dict[Vertex, Ve
self.args = defaultdict(list, args)
self.v = list(sorted(v))
self.size = len(self.v)
self.edge_list = Graph.adj_to_edge_list(self.args, reverse_edge=True)
self.edge_list = adj_to_edge_list(self.args, reverse_edge=True)
self.vfwd_map = vfwd_map
self.vfwd = list(sorted(vfwd_map.keys())) if vfwd_map else self.v
self.vloss = vloss
Expand Down Expand Up @@ -73,34 +73,14 @@ def cpu_gcd(self, *othervals):

return np.gcd.reduce(intvalues)

@staticmethod
def gen_linear_graph(forward_node_count, **kwargs):
"""
gen_linear_graph will generate linear-style graphs like VGG and AlexNet.
Method returns forward and backward graphs. Pass cost_ram and cost_cpu as kwargs.
:param forward_node_count: number of forward (not backward nodes)
:return: Graph object containing linear graph
"""
args = defaultdict(list)
vfwd_map = {}
loss_node_idx = forward_node_count
for i in range(forward_node_count * 2):
args[i + 1].append(i)
if i < forward_node_count:
corresponding_bwd = (forward_node_count * 2) - i
args[corresponding_bwd].append(i)
vfwd_map[i] = corresponding_bwd
v = list(vfwd_map.keys()) + list(vfwd_map.values()) + [loss_node_idx]
return Graph(args=args, v=v, vfwd_map=vfwd_map, vloss=loss_node_idx, **kwargs)

def write_graphviz(self, directory, format='pdf', quiet=True):
def write_graphviz(self, directory, format='pdf', quiet=True, name=""):
"""
Generate Graphviz-formatted edge list for visualization
:param directory: str -- where to write source and rendered graph
:param format: str -- file format for output
:param quiet: bool -- whether or not to print debug information
"""
dot = Digraph("!ExtractedGraph")
dot = Digraph("!ExtractedGraph" + str(name))
dot.attr('graph', rankdir='LR')
for u in self.vfwd:
with dot.subgraph() as s:
Expand Down Expand Up @@ -131,62 +111,6 @@ def write_graphviz(self, directory, format='pdf', quiet=True):
except TypeError:
dot.render(directory=directory, format=format)

def tensor_plot(self, sched, directory, tag=None, format='pdf', quiet=True):
import solvers.scheduler
dot = Digraph(f"!TensorPlot_{tag}", engine="dot")
if sched is None:
return
for op in sched:
if isinstance(op, solvers.scheduler.OperatorEvaluation):
if self.is_loss_node(op.id):
node_name = "Loss"
elif self.is_forward_node(op.id):
node_name = self.node_names.get(op.id)
node_name = node_name if node_name is None else f"{node_name} ({str(op.id)})"
elif self.is_backward_node(op.id):
fwd_node = self.backward_to_forward(op.id)
node_name = "Grad<{}> {} {}".format(self.node_names.get(fwd_node), fwd_node, op.id)
else:
raise ValueError("Unknown operation")
# dot.node("op{}".format(op.id), node_name, shape="diamond")
# dot.edge("op{}".format(op.id), "reg{}".format(op.out_register))
dot.node(f"reg{op.out_register}", f"Register {op.out_register} for {node_name}", shape="box")
for dep_op, dep_reg in op.arg_regs.items():
dot.edge("reg{}".format(dep_reg), "reg{}".format(op.out_register),
style="dashed", label=str(self.args[op.id].index(dep_op)))

try:
dot.render(directory=directory, format=format, quiet=quiet)
except TypeError:
dot.render(directory=directory, format=format)

@staticmethod
def edge_to_adj_list(E: EdgeList, convert_undirected=False):
"""Returns an (undirected / bidirectional) adjacency list"""
adj_list = defaultdict(set)
for (i, j) in list(E):
adj_list[i].add(j)
if convert_undirected:
adj_list[j].add(i)
return adj_list

@staticmethod
def adj_to_edge_list(E: AdjList, convert_undirected=False, reverse_edge=False):
"""Returns an edge list
:param E: AdjList -- input graph
:param convert_undirected: bool -- if true, add u -> v and v -> u to output graph
:param reverse_edge: bool -- if true, reverse edge direction prior to conversion
:return:
"""
edge_list = []
for u, deps in E.items():
for v in deps:
edge = (u, v) if not reverse_edge else (v, u)
edge_list.append(edge)
if convert_undirected:
edge_list.append(tuple(reversed(edge)))
return edge_list

@property
@lru_cache(maxsize=None)
def max_degree(self):
Expand All @@ -210,7 +134,7 @@ def dfs(graph, node, visited):
return visited

unvisited = set(V)
adj_list = Graph.edge_to_adj_list(E, convert_undirected=True)
adj_list = edge_to_adj_list(E, convert_undirected=True)
components = 0
while len(unvisited) > 0:
v = unvisited.pop()
Expand Down Expand Up @@ -243,15 +167,7 @@ def checkpoint_set_all(self) -> set:
@property
@lru_cache(maxsize=None)
def topological_order_fwd(self):
return self._topological_order(True)

@property
@lru_cache(maxsize=None)
def topological_order(self):
return self._topological_order(False)

def _topological_order(self, forward_only: bool):
E = self.edge_list_fwd if forward_only else self.edge_list
E = self.edge_list_fwd

def helper(adj_list_, v, visited_, stack_):
visited_[v] = True
Expand All @@ -260,7 +176,7 @@ def helper(adj_list_, v, visited_, stack_):
helper(adj_list_, i, visited_, stack_)
stack_.insert(0, v)

adj_list = Graph.edge_to_adj_list(E, convert_undirected=True)
adj_list = edge_to_adj_list(E, convert_undirected=True)
num_nodes = len(adj_list.keys())

visited = [False] * num_nodes
Expand Down Expand Up @@ -327,3 +243,50 @@ def dependency_order(self, node: int):
def max_degree_ram(self):
    """compute minimum memory needed for any single node (ie inputs and outputs)"""
    # RAM for a node is the sum of all its input tensors plus its own output
    # tensor; report the worst case over every forward node.
    return max(
        sum(self.cost_ram[u] for u in self.predecessors(v)) + self.cost_ram[v]
        for v in self.vfwd
    )


def gen_linear_graph(forward_node_count, **kwargs):
    """Generate a linear-chain graph in the style of VGG and AlexNet.

    Builds a forward chain 0 -> 1 -> ... -> 2 * forward_node_count where node
    `forward_node_count` is the loss node, and links each forward node i as an
    extra dependency of its mirrored backward node (2 * forward_node_count - i).
    Method returns forward and backward graphs. Pass cost_ram and cost_cpu as kwargs.
    :param forward_node_count: number of forward (not backward nodes)
    :return: Graph object containing linear graph
    """
    total_nodes = forward_node_count * 2
    loss_idx = forward_node_count
    deps = defaultdict(list)
    fwd_to_bwd = {}
    for node in range(total_nodes):
        # chain each node to its successor in evaluation order
        deps[node + 1].append(node)
        if node < forward_node_count:
            mirror = total_nodes - node
            deps[mirror].append(node)
            fwd_to_bwd[node] = mirror
    vertices = list(fwd_to_bwd.keys()) + list(fwd_to_bwd.values()) + [loss_idx]
    return Graph(args=deps, v=vertices, vfwd_map=fwd_to_bwd, vloss=loss_idx, **kwargs)


def edge_to_adj_list(E: EdgeList, convert_undirected=False):
    """Convert an edge list into an adjacency list of neighbor sets.

    :param E: EdgeList -- iterable of (source, target) edges
    :param convert_undirected: bool -- if true, also record the reverse edge,
        yielding an undirected/bidirectional adjacency list
    :return: defaultdict mapping each vertex to a set of adjacent vertices
    """
    adjacency = defaultdict(set)
    for source, target in E:
        adjacency[source].add(target)
        if convert_undirected:
            adjacency[target].add(source)
    return adjacency


def adj_to_edge_list(E: AdjList, convert_undirected=False, reverse_edge=False):
    """Flatten an adjacency list into an edge list.

    :param E: AdjList -- input graph
    :param convert_undirected: bool -- if true, add u -> v and v -> u to output graph
    :param reverse_edge: bool -- if true, reverse edge direction prior to conversion
    :return: list of (u, v) edge tuples
    """
    edges = []
    for node, deps in E.items():
        for dep in deps:
            pair = (dep, node) if reverse_edge else (node, dep)
            edges.append(pair)
            if convert_undirected:
                edges.append((pair[1], pair[0]))
    return edges
24 changes: 24 additions & 0 deletions remat/core/schedule.py
@@ -0,0 +1,24 @@
from typing import NamedTuple, Dict, List, Union


class OperatorEvaluation(NamedTuple):
    """Schedule entry: evaluate one operator, reading inputs from registers."""
    id: int  # id of the graph node being evaluated
    arg_regs: Dict[int, int]  # maps argument node id -> register holding its value
    out_register: int  # register that receives this operator's output
    operator_cost: int  # compute cost charged for this evaluation
    is_backwards: bool = False  # True when this evaluation is a backward (gradient) node
    update_aux_vars: bool = True  # will be true if this is the last time this node is evaluated


class AllocateRegister(NamedTuple):
    """Schedule entry: allocate a register to hold one operation's output."""
    register_id: int  # id of the register being allocated
    for_operation_id: int  # id of the operation whose output this register stores
    register_size: int  # size of the allocation (memory units)


class DeallocateRegister(NamedTuple):
    """Schedule entry: free a register that is no longer needed."""
    op_id: int  # id of the operation at which the deallocation occurs
    register_id: int  # id of the register being freed


# A schedule is an ordered list of register allocations/deallocations and
# operator evaluations, executed top to bottom.
Schedule = List[Union[OperatorEvaluation, AllocateRegister, DeallocateRegister]]
34 changes: 34 additions & 0 deletions remat/core/utils/graph_utils.py
@@ -0,0 +1,34 @@
from graphviz import Digraph

from remat.core.graph import Graph
from remat.core.schedule import Schedule, OperatorEvaluation


def tensor_plot(g: Graph, sched: Schedule, directory, tag=None, format='pdf', quiet=True):
    """Render a Graphviz plot of the registers/tensors produced by a schedule.

    Each OperatorEvaluation in the schedule becomes a register node; dashed
    edges connect each argument register to the output register, labeled with
    the argument's position in the operator's dependency list.

    :param g: Graph -- the dataflow graph the schedule evaluates
    :param sched: Schedule -- list of schedule operations (None is a no-op)
    :param directory: str -- where to write source and rendered graph
    :param tag: optional tag appended to the plot name
    :param format: str -- file format for output
    :param quiet: bool -- whether or not to print debug information
    """
    if sched is None:
        return
    dot = Digraph(f"!TensorPlot_{tag}", engine="dot")
    for op in sched:
        if not isinstance(op, OperatorEvaluation):
            continue
        if g.is_loss_node(op.id):
            node_name = "Loss"
        elif g.is_forward_node(op.id):
            name = g.node_names.get(op.id)
            # Fall back to the op id when no name is known; previously the
            # condition was inverted, so an unnamed node kept node_name=None
            # and rendered the literal label "Register N for None".
            node_name = f"{name} ({op.id})" if name is not None else str(op.id)
        elif g.is_backward_node(op.id):
            fwd_node = g.backward_to_forward(op.id)
            node_name = "Grad<{}> {} {}".format(g.node_names.get(fwd_node), fwd_node, op.id)
        else:
            raise ValueError("Unknown operation")
        dot.node(f"reg{op.out_register}", f"Register {op.out_register} for {node_name}", shape="box")
        for dep_op, dep_reg in op.arg_regs.items():
            dot.edge("reg{}".format(dep_reg), "reg{}".format(op.out_register),
                     style="dashed", label=str(g.args[op.id].index(dep_op)))

    try:
        dot.render(directory=directory, format=format, quiet=quiet)
    except TypeError:
        # Older graphviz releases do not accept the `quiet` kwarg.
        dot.render(directory=directory, format=format)

18 changes: 0 additions & 18 deletions requirements.txt

This file was deleted.

18 changes: 0 additions & 18 deletions requirements_gpu.txt

This file was deleted.

3 changes: 0 additions & 3 deletions scripts/connect_redis.sh

This file was deleted.

27 changes: 0 additions & 27 deletions scripts/convert_griewank_pickles.py

This file was deleted.

6 changes: 0 additions & 6 deletions scripts/eval_scripts/c76.sh

This file was deleted.

4 changes: 0 additions & 4 deletions scripts/eval_scripts/c77.sh

This file was deleted.