Skip to content
Permalink
Browse files

Refactor core of package to remat package (#1)

  • Loading branch information
parasj committed Sep 19, 2019
1 parent 6879686 commit d0da9caf401efebd4cc047571970c649155e90f1
Showing with 599 additions and 1,249 deletions.
  1. +21 −0 .gitignore
  2. +25 −1 README.md
  3. +1 −1 {src → _deprecated_src}/collect_execution_pickles.py
  4. +1 −1 {src → _deprecated_src}/eval_runner.py
  5. +1 −1 {src → _deprecated_src}/evaluation/budget_sweep.py
  6. +3 −7 src/evaluation/test_execution.py → _deprecated_src/evaluation/eval_execution.py
  7. +2 −18 {src → _deprecated_src}/evaluation/maximize_batch_size.py
  8. 0 {src → _deprecated_src}/evaluation/solve_time_plot.py
  9. 0 {src → _deprecated_src}/evaluation/util/cost_model.py
  10. +2 −2 {src → _deprecated_src}/evaluation/util/evaluation_utils.py
  11. +6 −4 {src → _deprecated_src}/evaluation/util/solver_utils.py
  12. +3 −3 {src → _deprecated_src}/execute_one.py
  13. 0 {src → _deprecated_src}/get_shapes.py
  14. 0 {src → _deprecated_src}/global_version.py
  15. 0 {src → _deprecated_src}/integration/onnx/hooks.py
  16. 0 {src → _deprecated_src}/integration/onnx/process.py
  17. +3 −3 {src → _deprecated_src}/integration/tf2/TF2ExtractorParams.py
  18. +3 −4 {src → _deprecated_src}/integration/tf2/TF2Runner.py
  19. +4 −4 {src → _deprecated_src}/integration/tf2/extraction.py
  20. 0 {src → _deprecated_src}/integration/tf2/hooks.py
  21. 0 {src → _deprecated_src}/integration/tf2/misc.py
  22. 0 {src → _deprecated_src}/integration/tf2/process.py
  23. 0 {src → _deprecated_src}/integration/tf2/runtimes.py
  24. 0 {src → _deprecated_src}/profile_keras.py
  25. +6 −94 {src → _deprecated_src}/solvers/solver.py
  26. +9 −8 {src → _deprecated_src}/solvers/solver_ilp.py
  27. +9 −8 {src → _deprecated_src}/solvers/solver_ilp_maxbs.py
  28. +2 −2 {src → _deprecated_src}/utils/redis.py
  29. 0 {src → _deprecated_src}/utils/setup_logger.py
  30. +56 −139 src/utils/graph.py → remat/core/dfgraph.py
  31. +65 −0 remat/core/schedule.py
  32. +42 −9 src/solvers/util.py → remat/core/solvers/common.py
  33. +44 −23 {src → remat/core}/solvers/scheduler.py
  34. +38 −0 remat/core/solvers/strategy_checkpoint_all.py
  35. +25 −0 remat/core/solvers/strategy_checkpoint_last.py
  36. +52 −0 remat/core/solvers/strategy_chen.py
  37. 0 src/evaluation/util/solve_strategy.py → remat/core/solvers/strategy_enum.py
  38. +117 −0 remat/core/utils/debug_plotting.py
  39. +2 −1 {src → remat/core}/utils/timer.py
  40. 0 remat/tensorflow2/execution.py
  41. 0 remat/tensorflow2/extraction.py
  42. +0 −18 requirements.txt
  43. +0 −18 requirements_gpu.txt
  44. +0 −3 scripts/connect_redis.sh
  45. +0 −27 scripts/convert_griewank_pickles.py
  46. +0 −6 scripts/eval_scripts/c76.sh
  47. +0 −4 scripts/eval_scripts/c77.sh
  48. +0 −4 scripts/eval_scripts/c78.sh
  49. +0 −7 scripts/eval_scripts/havoc.sh
  50. +0 −93 scripts/gen_all_plot.sh
  51. +0 −44 scripts/git_sync.py
  52. +0 −18 scripts/git_sync_local.sh
  53. +0 −9 scripts/install.sh
  54. +0 −5 scripts/machines.txt
  55. +0 −28 scripts/multi_gpu_parallel_profile.sh
  56. +0 −80 scripts/plot_memory_wall.py
  57. +0 −51 scripts/profile_all.py
  58. +0 −22 scripts/profile_all.sh
  59. +0 −116 scripts/revolve.py
  60. +0 −12 scripts/run.sh
  61. +0 −7 scripts/run_c77.sh
  62. +0 −5 scripts/run_c78.sh
  63. +0 −19 scripts/run_e2e.sh
  64. +0 −9 scripts/run_havoc.sh
  65. +0 −16 scripts/run_ilpsolver1.sh
  66. +0 −20 scripts/run_ilpsolver2.sh
  67. +0 −21 scripts/run_ilpsolver3.sh
  68. +0 −38 scripts/single_gpu_parallel_profile.sh
  69. +28 −0 setup.py
  70. +0 −216 src/solvers/result.py
  71. +29 −0 tests/test_linear.py
@@ -12,6 +12,27 @@ env/
.idea/
.envrc

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Gurobi
*.lp
*.mps
@@ -1,2 +1,26 @@
# Optimal Checkpointing
# Optimal tensor rematerialization
`remat` is a package to compute schedules for rematerializing tensors in DFGraphs (tensor dataflow graphs).

# Installation
```bash
$ git clone https://github.com/parasj/tensor-remat.git
$ cd tensor-remat
$ pip install -e .
$ py.test
```

If you are evaluating on a GPU instance, you can install `tensorflow-gpu` as a dependency in order to enable GPU support.

# Project structure
```
.
├── experiments stores one-off plotting scripts for paper
├── remat main python package
│   ├── core
│   │   ├── graph.py DFGraph data structure
│   │   ├── schedule.py Schedule definition of concrete evaluation order
│   │   ├── solvers Package containing solver implementations
│   │   └── utils
│   └── tensorflow2 Tensorflow 2.0 integration (extraction and execution)
└── tests
```
@@ -5,7 +5,7 @@

import pandas as pd

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy


def extract_params():
@@ -11,7 +11,7 @@
from evaluation.budget_sweep import eval_budget_sweep
from evaluation.maximize_batch_size import eval_maximize_batch_size
from evaluation.solve_time_plot import eval_solve_time
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.extraction import MODEL_NAMES
from utils.redis import RedisCache
from utils.setup_logger import setup_logger
@@ -13,7 +13,7 @@

from evaluation.util.cost_model import CostModel
from evaluation.util.evaluation_utils import prefix_min_np, result_dict_to_dataframe, RSResultDict, get_futures
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from evaluation.util.solver_utils import remote_evaluation_iteration
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, pretty_platform_name, platform_memory, \
@@ -1,25 +1,21 @@
import functools
import os
from typing import Dict, Optional, List, Tuple, Iterable
from typing import Optional, List, Tuple
from tqdm import tqdm

import pandas
import ray
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np

from evaluation.util.evaluation_utils import result_dict_to_dataframe, RSResultDict
from evaluation.util.solver_utils import remote_evaluation_iteration
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, platform_memory
from integration.tf2.TF2Runner import TF2Runner
from integration.tf2.misc import categorical_cross_entropy, random_batch
from solvers.result import RSResult
from utils.redis import RedisCache
from utils.setup_logger import setup_logger
from utils.timer import Timer
from remat.core.utils.timer import Timer


def plot_solver_result(results: pandas.DataFrame, plot_file: str):
@@ -1,28 +1,12 @@
from __future__ import division

import functools
import math
import os
from typing import List, Dict, Iterable

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import ray
import seaborn as sns
from tqdm import tqdm

from evaluation.util.cost_model import CostModel
from evaluation.util.evaluation_utils import prefix_min_np, result_dict_to_dataframe, RSResultDict
from evaluation.util.solver_utils import remote_evaluation_iteration
from evaluation.util.solve_strategy import SolveStrategy

from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, pretty_model_name, pretty_platform_name, platform_memory, \
CHAIN_GRAPH_MODELS
from integration.tf2.misc import categorical_cross_entropy
from integration.tf2.extraction import get_keras_model
from solvers.result import RSResult
from solvers.solver_ilp_maxbs import MaxBatchILPSolver
from utils.redis import RedisCache
from utils.setup_logger import setup_logger

GB = 1000 * 1000 * 1000
File renamed without changes.
File renamed without changes.
@@ -1,13 +1,13 @@
import math
import sys
from typing import List, Dict, Iterable
from typing import List, Dict

import numpy as np
import pandas
import ray
from tqdm import tqdm

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from solvers.result import RSResult

RSResultDict = Dict[SolveStrategy, List[RSResult]]
@@ -4,12 +4,14 @@
import ray

import utils.redis
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from remat.core.solvers.strategy_checkpoint_all import solve_checkpoint_all
from remat.core.solvers.strategy_checkpoint_last import solve_checkpoint_last_node
from solvers.result import PartialRSResult, RSResult
from solvers.solver import CheckpointSolver
from utils.setup_logger import setup_logger
from utils.timer import Timer
from remat.core.utils.timer import Timer

RAY_OVERPROVISON_PCT = 0.75

@@ -24,9 +26,9 @@ def tf2_solve(solve_params: TF2ExtractorParams, solve_strategy: SolveStrategy, b
elif solve_strategy == SolveStrategy.CHEN_SQRTN_NOAP:
sol = CheckpointSolver.schedule_sqrtn_chen16(solve_params.g, use_actuation_points=False)
elif solve_strategy == SolveStrategy.CHECKPOINT_LAST_NODE:
sol = CheckpointSolver.schedule_checkpoint_last_node(solve_params.g)
sol = solve_checkpoint_last_node(solve_params.g)
elif solve_strategy == SolveStrategy.CHECKPOINT_ALL:
sol = CheckpointSolver.schedule_checkpoint_all(solve_params.g)
sol = solve_checkpoint_all(solve_params.g)
elif solve_strategy == SolveStrategy.CHECKPOINT_ALL_AP:
sol = CheckpointSolver.schedule_checkpoint_all_ap(solve_params.g)
elif solve_strategy == SolveStrategy.CHEN_GREEDY:
@@ -8,9 +8,9 @@

import dotenv

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.extraction import MODEL_NAMES
from evaluation.test_execution import execute_one
from evaluation.eval_execution import execute_one
from utils.setup_logger import setup_logger


@@ -91,7 +91,7 @@ def run_single_model(args):


def get_allstrat_ram(args):
from evaluation.test_execution import get_solution_to_evaluate
from evaluation.eval_execution import get_solution_to_evaluate

log_base = os.path.join("data", "get_allstrat_ram",
f"{args.platform}_{args.model_name}_{args.model_version}_{args.batch_size}_{args.input_shape}_{args.strategy}_{args.buffer_mem_mb}_gradless_eagerfalse")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -6,7 +6,7 @@
import tensorflow as tf

from integration.tf2.extraction import extract_graph_from_keras
from utils.graph import Graph
from remat.core.dfgraph import DFGraph
from utils.setup_logger import setup_logger


@@ -24,8 +24,8 @@ def __init__(self, keras_model: tf.keras.models.Model, batch_size: int = 1,
self.output_shape = list(keras_model.output_shape)
self.output_shape[0] = batch_size

self.g: Graph = extract_graph_from_keras(keras_model, batch_size=batch_size, loss_cpu_cost=loss_cpu_cost,
loss_ram_cost=loss_ram_cost, costs_np=costs_np)
self.g: DFGraph = extract_graph_from_keras(keras_model, batch_size=batch_size, loss_cpu_cost=loss_cpu_cost,
loss_ram_cost=loss_ram_cost, costs_np=costs_np)
self.logger.info(f"Extracted graph {keras_model.name} with {self.g.size} nodes")

if self.log_base is not None:
@@ -1,16 +1,15 @@
import itertools
import os

import tensorflow as tf

from integration.tf2.misc import categorical_cross_entropy
from solvers.scheduler import AllocateRegister, DeallocateRegister, OperatorEvaluation, Schedule
from utils.graph import Graph
from remat.core.schedule import OperatorEvaluation, AllocateRegister, DeallocateRegister, Schedule
from remat.core.dfgraph import DFGraph
from utils.setup_logger import setup_logger


class TF2Runner:
def __init__(self, keras_model: tf.keras.models.Model, g: Graph, schedule: Schedule,
def __init__(self, keras_model: tf.keras.models.Model, g: DFGraph, schedule: Schedule,
loss_fn=categorical_cross_entropy, eager: bool = True, log_base: str = None, debug=False, batch_size=None):
self.log_base = log_base
self.logger = setup_logger("TF2Runner", os.path.join(log_base, 'TF2Runner.log'))
@@ -5,7 +5,7 @@
import numpy as np
import tensorflow.compat.v2 as tf

from utils import graph
from remat.core import dfgraph
from utils.setup_logger import setup_logger

try:
@@ -199,6 +199,6 @@ def extract_graph_from_keras(mod: tf.keras.models.Model,
total_params = sum(count_params_keras(mod))
total_mem_params = total_params * MEMORY_MULTIPLIER

return graph.Graph(args=args, v=vfwd + [loss_node_idx] + vback, vfwd_map=vfwd_map,
vloss=loss_node_idx, cost_cpu=costs, cost_ram=mems, node_names=names,
cost_ram_parameters=total_mem_params)
return dfgraph.DFGraph(args=args, v=vfwd + [loss_node_idx] + vback, vfwd_map=vfwd_map,
vloss=loss_node_idx, cost_cpu=costs, cost_ram=mems, node_names=names,
cost_ram_parameters=total_mem_params)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -5,108 +5,20 @@

import numpy as np
import pandas as pd
import ray

from remat.core.dfgraph import DFGraph
from remat.core.solvers.common import setup_implied_s_backwards, gen_s_matrix_fixed_checkpoints, solve_r_opt
from solvers.solver_ilp import ILPSolver
from solvers.util import gen_s_matrix_fixed_checkpoints, setup_implied_s_backwards
from utils.graph import Graph
from utils.setup_logger import setup_logger

SOLVER_DTYPE = np.int


class CheckpointSolver:
@staticmethod
def solve_r_opt(G: Graph, S: np.ndarray):
"""Find the optimal recomputation pattern given caching decisions.
Given S, E = [(i, j)] where node j depends on the result of node i,
find R that minimizes cost, satisfies constraints. Assumes recomputation
costs are nonnegative.
NOTE: Does NOT check if memory limits are exceeded.
Enforcing R[t,i] != S[t,i] does not seem to be necessary.
"""
T = S.shape[0]
assert S.shape[1] == T

R = np.eye(T, dtype=S.dtype) # Enforce R_t,t = 1
# Enforce S_{t+1,v} <= S_{t,v} + R_{t,v},
# i.e. R_{t,v} >= S_{t+1,v} - S_{t,v}
S_diff = S[1:] - S[:-1]
R[:-1] = R[:-1] | (R[:-1] < S_diff)
# Create reverse adjacency list (child -> parents, i.e. node -> dependencies)
adj = [[] for v in range(T)]
for (u, v) in G.edge_list:
adj[v].append(u)
# Enforce R_{t,v} <= R_{t,u} + S_{t,u} for all (u, v) \in E
for t in range(T):
for v in range(t, -1, -1):
for u in adj[v]:
if R[t, v] > R[t, u] + S[t, u]:
R[t, u] = 1
return R

@staticmethod
def schedule_checkpoint_all(g: Graph):
"""Checkpoint only one node between stages"""
S = np.zeros((g.size, g.size), dtype=SOLVER_DTYPE)
S = gen_s_matrix_fixed_checkpoints(g, g.vfwd)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
def schedule_checkpoint_last_node(g: Graph):
"""Checkpoint only one node between stages"""
S = np.zeros((g.size, g.size), dtype=SOLVER_DTYPE)
np.fill_diagonal(S[1:], 1)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
def schedule_greedy_chen16(g: Graph, segment_mem_B: int, use_actuation_points: bool):
C = g.checkpoint_set if use_actuation_points else g.checkpoint_set_all
temp = 0
x = 0
# y = 0 # FIXME: y seems to have no function
checkpoints = set()
for v in g.topological_order_fwd:
temp += g.cost_ram[v]
if v in C and temp > segment_mem_B:
x += g.cost_ram[v]
# y = max(y, temp)
temp = 0
checkpoints.add(v)
S = gen_s_matrix_fixed_checkpoints(g, checkpoints)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
@ray.remote(num_cpus=1, num_return_vals=2)
def remote_schedule_greedy_chen16(g: Graph, segment_mem_B: int, use_actuation_points: bool):
return CheckpointSolver.schedule_greedy_chen16(g, segment_mem_B, use_actuation_points)

@staticmethod
def schedule_sqrtn_chen16(g: Graph, use_actuation_points: bool):
C = g.checkpoint_set if use_actuation_points else g.checkpoint_set_all
k = int(math.sqrt(len(C)))
checkpoints = [v for idx, v in enumerate(C) if (idx + 1) % k == 0]
S = gen_s_matrix_fixed_checkpoints(g, set(checkpoints))
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
@ray.remote(num_cpus=1, num_return_vals=2)
def remote_schedule_sqrtn_chen16(g: Graph, use_actuation_points: bool):
return CheckpointSolver.schedule_sqrtn_chen16(g, use_actuation_points)

@staticmethod
def schedule_checkpoint_all_ap(g: Graph):
S = gen_s_matrix_fixed_checkpoints(g, g.checkpoint_set)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S
def solve_r_opt(G: DFGraph, S: np.ndarray):
return solve_r_opt(G, S)

@staticmethod
def schedule_ilp_gurobi(g: Graph, budget: int, seed_s: np.ndarray = None, approx: bool = True, time_limit=None,
def schedule_ilp_gurobi(g: DFGraph, budget: int, seed_s: np.ndarray = None, approx: bool = True, time_limit=None,
log_file=None, print_to_console=True, model_file=None,
remote=False, eps_noise=0.01, solver_cores=1):
"""
@@ -147,7 +59,7 @@ def schedule_ilp_gurobi(g: Graph, budget: int, seed_s: np.ndarray = None, approx
return sol, return_vals

@staticmethod
def schedule_griewank(g: Graph, budget: int):
def schedule_griewank(g: DFGraph, budget: int):
S = np.zeros((g.size, g.size), dtype=np.int32)
S = setup_implied_s_backwards(g, S)
np.fill_diagonal(S[1:], 1)

0 comments on commit d0da9ca

Please sign in to comment.
You can’t perform that action at this time.