Skip to content

Commit

Permalink
Add saving and restoring from checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
pmandera committed Sep 6, 2017
1 parent 74e791b commit 0d2692e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ build/
*.c
*.so

checkpoints/
plots/
28 changes: 24 additions & 4 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def plot_grid(grid, fout):
help='Directory to save plots.')
parser.add_argument('--plot_dir', default='./plots/',
help='Directory to save plots.')

parser.add_argument('--checkpoint',
help='Start from checkpoint, not from empty grid.')
parser.add_argument('--checkpoint_every', type=float,
help='Directory to save checkpoints.')
parser.add_argument('--checkpoint_dir', default='./checkpoints/',
help='Directory to save checkpoint.')

parser.add_argument('--verbose', action='store_true',
help='Inform about progress.')

Expand All @@ -46,14 +54,26 @@ def plot_grid(grid, fout):
else:
plot_every = None

grid_center = grid_size/2
if args.checkpoint_every is not None:
checkpoint_every = int(args.checkpoint_every)
else:
checkpoint_every = None

if args.checkpoint is None:
sandpile = Sandpile(grid_size, grid_size)
else:
sandpile = Sandpile.load(args.checkpoint)

sandpile = Sandpile(grid_size, grid_size)
x_grid_center = sandpile.x_size/2
y_grid_center = sandpile.y_size/2

def plot_sandpile(sandpile):
plot_grid(sandpile.grid,
args.plot_dir + '/sandpile-{:012d}.png'.format(
sandpile.n_dropped))

sandpile.drop_sand(grid_center, grid_center, n_steps, verbose=args.verbose,
report_every=plot_every, report_func=plot_sandpile)
sandpile.drop_sand(x_grid_center, x_grid_center, n_steps,
verbose=args.verbose,
report_every=plot_every, report_func=plot_sandpile,
checkpoint_every=checkpoint_every,
checkpoint_dir=args.checkpoint_dir)
28 changes: 24 additions & 4 deletions sandpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from __future__ import print_function

import os
import gzip
import cPickle as pickle

import numpy as np

try:
Expand Down Expand Up @@ -66,9 +70,19 @@ def __init__(self, x_size, y_size, store_avalanche_sizes=False):
self.grid = np.zeros((x_size, y_size))

def drop_sand(self, x_loc=None, y_loc=None, n=1, verbose=False,
report_every=None, report_func=None):
report_every=None, report_func=None,
checkpoint_every=None, checkpoint_dir=None):
for step in range(n):

if report_every is not None and step % report_every == 0:
report_func(self)

if checkpoint_every is not None and step % checkpoint_every == 0:
path = os.path.join(
checkpoint_dir,
'sandpile-{:012d}.pckl.gz'.format(self.n_dropped))
self.save(path)

self.grid[x_loc, y_loc] += 1
self.n_dropped += 1

Expand All @@ -77,9 +91,6 @@ def drop_sand(self, x_loc=None, y_loc=None, n=1, verbose=False,
if self.store_avalanche_sizes:
self.avalanche_sizes.append(avalanche_size)

if report_every is not None and step % report_every == 0:
report_func(self)

if verbose:
progress(step, n, self.n_dropped, avalanche_size)

Expand All @@ -91,3 +102,12 @@ def grid_size(self):

def grains_per_dot(self):
return self.grid.sum()/self.grid_size()

def save(self, fname):
with gzip.open(fname, 'w') as fout:
pickle.dump(self, fout)

@staticmethod
def load(fname):
with gzip.open(fname, 'rb') as fout:
return pickle.load(fout)

0 comments on commit 0d2692e

Please sign in to comment.