Refactor core of package to remat package #1

Merged
merged 22 commits on Sep 19, 2019
21 changes: 21 additions & 0 deletions .gitignore
@@ -12,6 +12,27 @@ env/
.idea/
.envrc

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Gurobi
*.lp
*.mps
26 changes: 25 additions & 1 deletion README.md
@@ -1,2 +1,26 @@
# Optimal Checkpointing
# Optimal tensor rematerialization
`remat` is a package to compute schedules for rematerializing tensors in DFGraphs (tensor dataflow graphs).

# Installation
```bash
$ git clone https://github.com/parasj/tensor-remat.git
$ cd tensor-remat
$ pip install -e .
$ py.test
```

If you are evaluating on a GPU instance, install `tensorflow-gpu` as an additional dependency to enable GPU support.
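
For example, on a CUDA-capable machine (a minimal sketch; pick the `tensorflow-gpu` build that matches your CUDA/cuDNN versions):

```bash
$ pip install tensorflow-gpu
```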

# Project structure
```
.
├── experiments stores one-off plotting scripts for the paper
├── remat main python package
│   ├── core
│   │   ├── graph.py DFGraph data structure
│   │   ├── schedule.py Schedule definition of concrete evaluation order
│   │   ├── solvers Package containing solver implementations
│   │   └── utils
│   └── tensorflow2 Tensorflow 2.0 integration (extraction and execution)
└── tests
```
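
A rough end-to-end sketch based on the module paths visible in this PR; the `remat.tensorflow2` extraction import and the exact shape of the solver return values are assumptions, not confirmed by this diff:

```python
import tensorflow as tf

from remat.core.dfgraph import DFGraph
from remat.core.solvers.strategy_checkpoint_all import solve_checkpoint_all
from remat.core.solvers.strategy_checkpoint_last import solve_checkpoint_last_node
from remat.tensorflow2.extraction import extract_graph_from_keras  # assumed module path

# Extract a tensor dataflow graph (DFGraph) from a Keras model.
model = tf.keras.applications.VGG16()
g: DFGraph = extract_graph_from_keras(model, batch_size=32)

# Two baseline strategies: keep every forward activation, or only the most recent node.
sol_all = solve_checkpoint_all(g)
sol_last = solve_checkpoint_last_node(g)
```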
@@ -5,7 +5,7 @@

import pandas as pd

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy


def extract_params():
2 changes: 1 addition & 1 deletion src/eval_runner.py → _deprecated_src/eval_runner.py
@@ -11,7 +11,7 @@
from evaluation.budget_sweep import eval_budget_sweep
from evaluation.maximize_batch_size import eval_maximize_batch_size
from evaluation.solve_time_plot import eval_solve_time
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.extraction import MODEL_NAMES
from utils.redis import RedisCache
from utils.setup_logger import setup_logger
@@ -13,7 +13,7 @@

from evaluation.util.cost_model import CostModel
from evaluation.util.evaluation_utils import prefix_min_np, result_dict_to_dataframe, RSResultDict, get_futures
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from evaluation.util.solver_utils import remote_evaluation_iteration
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, pretty_platform_name, platform_memory, \
@@ -1,25 +1,21 @@
import functools
import os
from typing import Dict, Optional, List, Tuple, Iterable
from typing import Optional, List, Tuple
from tqdm import tqdm

import pandas
import ray
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np

from evaluation.util.evaluation_utils import result_dict_to_dataframe, RSResultDict
from evaluation.util.solver_utils import remote_evaluation_iteration
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, platform_memory
from integration.tf2.TF2Runner import TF2Runner
from integration.tf2.misc import categorical_cross_entropy, random_batch
from solvers.result import RSResult
from utils.redis import RedisCache
from utils.setup_logger import setup_logger
from utils.timer import Timer
from remat.core.utils.timer import Timer


def plot_solver_result(results: pandas.DataFrame, plot_file: str):
@@ -1,28 +1,12 @@
from __future__ import division

import functools
import math
import os
from typing import List, Dict, Iterable

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import ray
import seaborn as sns
from tqdm import tqdm

from evaluation.util.cost_model import CostModel
from evaluation.util.evaluation_utils import prefix_min_np, result_dict_to_dataframe, RSResultDict
from evaluation.util.solver_utils import remote_evaluation_iteration
from evaluation.util.solve_strategy import SolveStrategy

from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from integration.tf2.extraction import get_keras_model, pretty_model_name, pretty_platform_name, platform_memory, \
CHAIN_GRAPH_MODELS
from integration.tf2.misc import categorical_cross_entropy
from integration.tf2.extraction import get_keras_model
from solvers.result import RSResult
from solvers.solver_ilp_maxbs import MaxBatchILPSolver
from utils.redis import RedisCache
from utils.setup_logger import setup_logger

GB = 1000 * 1000 * 1000
@@ -1,13 +1,13 @@
import math
import sys
from typing import List, Dict, Iterable
from typing import List, Dict

import numpy as np
import pandas
import ray
from tqdm import tqdm

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from solvers.result import RSResult

RSResultDict = Dict[SolveStrategy, List[RSResult]]
@@ -4,12 +4,14 @@
import ray

import utils.redis
from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.TF2ExtractorParams import TF2ExtractorParams
from remat.core.solvers.strategy_checkpoint_all import solve_checkpoint_all
from remat.core.solvers.strategy_checkpoint_last import solve_checkpoint_last_node
from solvers.result import PartialRSResult, RSResult
from solvers.solver import CheckpointSolver
from utils.setup_logger import setup_logger
from utils.timer import Timer
from remat.core.utils.timer import Timer

RAY_OVERPROVISON_PCT = 0.75

@@ -24,9 +26,9 @@ def tf2_solve(solve_params: TF2ExtractorParams, solve_strategy: SolveStrategy, b
elif solve_strategy == SolveStrategy.CHEN_SQRTN_NOAP:
sol = CheckpointSolver.schedule_sqrtn_chen16(solve_params.g, use_actuation_points=False)
elif solve_strategy == SolveStrategy.CHECKPOINT_LAST_NODE:
sol = CheckpointSolver.schedule_checkpoint_last_node(solve_params.g)
sol = solve_checkpoint_last_node(solve_params.g)
elif solve_strategy == SolveStrategy.CHECKPOINT_ALL:
sol = CheckpointSolver.schedule_checkpoint_all(solve_params.g)
sol = solve_checkpoint_all(solve_params.g)
elif solve_strategy == SolveStrategy.CHECKPOINT_ALL_AP:
sol = CheckpointSolver.schedule_checkpoint_all_ap(solve_params.g)
elif solve_strategy == SolveStrategy.CHEN_GREEDY:
6 changes: 3 additions & 3 deletions src/execute_one.py → _deprecated_src/execute_one.py
@@ -8,9 +8,9 @@

import dotenv

from evaluation.util.solve_strategy import SolveStrategy
from remat.core.solvers.strategy_enum import SolveStrategy
from integration.tf2.extraction import MODEL_NAMES
from evaluation.test_execution import execute_one
from evaluation.eval_execution import execute_one
from utils.setup_logger import setup_logger


@@ -91,7 +91,7 @@ def run_single_model(args):


def get_allstrat_ram(args):
from evaluation.test_execution import get_solution_to_evaluate
from evaluation.eval_execution import get_solution_to_evaluate

log_base = os.path.join("data", "get_allstrat_ram",
f"{args.platform}_{args.model_name}_{args.model_version}_{args.batch_size}_{args.input_shape}_{args.strategy}_{args.buffer_mem_mb}_gradless_eagerfalse")
File renamed without changes.
File renamed without changes.
@@ -6,7 +6,7 @@
import tensorflow as tf

from integration.tf2.extraction import extract_graph_from_keras
from utils.graph import Graph
from remat.core.dfgraph import DFGraph
from utils.setup_logger import setup_logger


@@ -24,8 +24,8 @@ def __init__(self, keras_model: tf.keras.models.Model, batch_size: int = 1,
self.output_shape = list(keras_model.output_shape)
self.output_shape[0] = batch_size

self.g: Graph = extract_graph_from_keras(keras_model, batch_size=batch_size, loss_cpu_cost=loss_cpu_cost,
loss_ram_cost=loss_ram_cost, costs_np=costs_np)
self.g: DFGraph = extract_graph_from_keras(keras_model, batch_size=batch_size, loss_cpu_cost=loss_cpu_cost,
loss_ram_cost=loss_ram_cost, costs_np=costs_np)
self.logger.info(f"Extracted graph {keras_model.name} with {self.g.size} nodes")

if self.log_base is not None:
@@ -1,16 +1,15 @@
import itertools
import os

import tensorflow as tf

from integration.tf2.misc import categorical_cross_entropy
from solvers.scheduler import AllocateRegister, DeallocateRegister, OperatorEvaluation, Schedule
from utils.graph import Graph
from remat.core.schedule import OperatorEvaluation, AllocateRegister, DeallocateRegister, Schedule
from remat.core.dfgraph import DFGraph
from utils.setup_logger import setup_logger


class TF2Runner:
def __init__(self, keras_model: tf.keras.models.Model, g: Graph, schedule: Schedule,
def __init__(self, keras_model: tf.keras.models.Model, g: DFGraph, schedule: Schedule,
loss_fn=categorical_cross_entropy, eager: bool = True, log_base: str = None, debug=False, batch_size=None):
self.log_base = log_base
self.logger = setup_logger("TF2Runner", os.path.join(log_base, 'TF2Runner.log'))
@@ -5,7 +5,7 @@
import numpy as np
import tensorflow.compat.v2 as tf

from utils import graph
from remat.core import dfgraph
from utils.setup_logger import setup_logger

try:
@@ -199,6 +199,6 @@ def extract_graph_from_keras(mod: tf.keras.models.Model,
total_params = sum(count_params_keras(mod))
total_mem_params = total_params * MEMORY_MULTIPLIER

return graph.Graph(args=args, v=vfwd + [loss_node_idx] + vback, vfwd_map=vfwd_map,
vloss=loss_node_idx, cost_cpu=costs, cost_ram=mems, node_names=names,
cost_ram_parameters=total_mem_params)
return dfgraph.DFGraph(args=args, v=vfwd + [loss_node_idx] + vback, vfwd_map=vfwd_map,
vloss=loss_node_idx, cost_cpu=costs, cost_ram=mems, node_names=names,
cost_ram_parameters=total_mem_params)
File renamed without changes.
File renamed without changes.
File renamed without changes.
100 changes: 6 additions & 94 deletions src/solvers/solver.py → _deprecated_src/solvers/solver.py
@@ -5,108 +5,20 @@

import numpy as np
import pandas as pd
import ray

from remat.core.dfgraph import DFGraph
from remat.core.solvers.common import setup_implied_s_backwards, gen_s_matrix_fixed_checkpoints, solve_r_opt
from solvers.solver_ilp import ILPSolver
from solvers.util import gen_s_matrix_fixed_checkpoints, setup_implied_s_backwards
from utils.graph import Graph
from utils.setup_logger import setup_logger

SOLVER_DTYPE = np.int


class CheckpointSolver:
@staticmethod
def solve_r_opt(G: Graph, S: np.ndarray):
"""Find the optimal recomputation pattern given caching decisions.
Given S, E = [(i, j)] where node j depends on the result of node i,
find R that minimizes cost, satisfies constraints. Assumes recomputation
costs are nonnegative.

NOTE: Does NOT check if memory limits are exceeded.
Enforcing R[t,i] != S[t,i] does not seem to be necessary.
"""
T = S.shape[0]
assert S.shape[1] == T

R = np.eye(T, dtype=S.dtype) # Enforce R_t,t = 1
# Enforce S_{t+1,v} <= S_{t,v} + R_{t,v},
# i.e. R_{t,v} >= S_{t+1,v} - S_{t,v}
S_diff = S[1:] - S[:-1]
R[:-1] = R[:-1] | (R[:-1] < S_diff)
# Create reverse adjacency list (child -> parents, i.e. node -> dependencies)
adj = [[] for v in range(T)]
for (u, v) in G.edge_list:
adj[v].append(u)
# Enforce R_{t,v} <= R_{t,u} + S_{t,u} for all (u, v) \in E
for t in range(T):
for v in range(t, -1, -1):
for u in adj[v]:
if R[t, v] > R[t, u] + S[t, u]:
R[t, u] = 1
return R

@staticmethod
def schedule_checkpoint_all(g: Graph):
"""Checkpoint only one node between stages"""
S = np.zeros((g.size, g.size), dtype=SOLVER_DTYPE)
S = gen_s_matrix_fixed_checkpoints(g, g.vfwd)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
def schedule_checkpoint_last_node(g: Graph):
"""Checkpoint only one node between stages"""
S = np.zeros((g.size, g.size), dtype=SOLVER_DTYPE)
np.fill_diagonal(S[1:], 1)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
def schedule_greedy_chen16(g: Graph, segment_mem_B: int, use_actuation_points: bool):
C = g.checkpoint_set if use_actuation_points else g.checkpoint_set_all
temp = 0
x = 0
# y = 0 # FIXME: y seems to have no function
checkpoints = set()
for v in g.topological_order_fwd:
temp += g.cost_ram[v]
if v in C and temp > segment_mem_B:
x += g.cost_ram[v]
# y = max(y, temp)
temp = 0
checkpoints.add(v)
S = gen_s_matrix_fixed_checkpoints(g, checkpoints)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
@ray.remote(num_cpus=1, num_return_vals=2)
def remote_schedule_greedy_chen16(g: Graph, segment_mem_B: int, use_actuation_points: bool):
return CheckpointSolver.schedule_greedy_chen16(g, segment_mem_B, use_actuation_points)

@staticmethod
def schedule_sqrtn_chen16(g: Graph, use_actuation_points: bool):
C = g.checkpoint_set if use_actuation_points else g.checkpoint_set_all
k = int(math.sqrt(len(C)))
checkpoints = [v for idx, v in enumerate(C) if (idx + 1) % k == 0]
S = gen_s_matrix_fixed_checkpoints(g, set(checkpoints))
R = CheckpointSolver.solve_r_opt(g, S)
return R, S

@staticmethod
@ray.remote(num_cpus=1, num_return_vals=2)
def remote_schedule_sqrtn_chen16(g: Graph, use_actuation_points: bool):
return CheckpointSolver.schedule_sqrtn_chen16(g, use_actuation_points)

@staticmethod
def schedule_checkpoint_all_ap(g: Graph):
S = gen_s_matrix_fixed_checkpoints(g, g.checkpoint_set)
R = CheckpointSolver.solve_r_opt(g, S)
return R, S
def solve_r_opt(G: DFGraph, S: np.ndarray):
return solve_r_opt(G, S)

@staticmethod
def schedule_ilp_gurobi(g: Graph, budget: int, seed_s: np.ndarray = None, approx: bool = True, time_limit=None,
def schedule_ilp_gurobi(g: DFGraph, budget: int, seed_s: np.ndarray = None, approx: bool = True, time_limit=None,
log_file=None, print_to_console=True, model_file=None,
remote=False, eps_noise=0.01, solver_cores=1):
"""
@@ -147,7 +59,7 @@ def schedule_ilp_gurobi(g: Graph, budget: int, seed_s: np.ndarray = None, approx
return sol, return_vals

@staticmethod
def schedule_griewank(g: Graph, budget: int):
def schedule_griewank(g: DFGraph, budget: int):
S = np.zeros((g.size, g.size), dtype=np.int32)
S = setup_implied_s_backwards(g, S)
np.fill_diagonal(S[1:], 1)
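
For reference, a minimal sketch of how the checkpoint-all strategy removed above looks against the new `remat.core.solvers` helpers; the `(R, S)` return shape and the `g.vfwd` attribute are carried over from the deleted code and assumed to survive the move:

```python
from remat.core.dfgraph import DFGraph
from remat.core.solvers.common import gen_s_matrix_fixed_checkpoints, solve_r_opt


def schedule_checkpoint_all(g: DFGraph):
    """Checkpoint every forward-pass node, then recover the cheapest recomputation plan."""
    # Fix the set of stored nodes S to all forward nodes, then let solve_r_opt
    # fill in the recomputations R implied by that checkpointing choice.
    S = gen_s_matrix_fixed_checkpoints(g, g.vfwd)
    R = solve_r_opt(g, S)
    return R, S
```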