Skip to content

Commit

Permalink
testing checkpoint fix with multiple run compute in shadowing loop
Browse files (browse the repository at this point in the history)
  • Loading branch information
Chaitanya Talnikar committed Jan 29, 2017
1 parent 43e51db commit e0ca93a
Showing 1 changed file with 17 additions and 33 deletions.
50 changes: 17 additions & 33 deletions fds/fds.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,20 @@ def continue_shadowing(
time_dil = TimeDilation(run, u0, parameter, run_id,
simultaneous_runs, interprocess)

V = time_dil.project(V)
v = time_dil.project(v)

u0, V, v, J0, G, g = run_segment(
run, u0, V, v, parameter, i, steps_per_segment,
epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir,
compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job)

J_hist.append(J0)
G_lss.append(G)
g_lss.append(g)
for i in range(lss.K_segments(), num_segments):
V = time_dil.project(V)
v = time_dil.project(v)

for i in range(lss.K_segments() + 1, num_segments + 1):
# extra outputs to compute
compute_outputs = []
# run all segments
u0, V, v, J0, G, g = run_segment(
run, u0, V, v, parameter, i, steps_per_segment,
epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir,
compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job)
J_hist.append(J0)
G_lss.append(G)
g_lss.append(g)

# time dilation contribution
run_id = 'time_dilation_{0:02d}'.format(i)
Expand All @@ -136,22 +137,10 @@ def continue_shadowing(
G_dil.append(time_dil.contribution(V))
g_dil.append(time_dil.contribution(v))

V = time_dil.project(V)
v = time_dil.project(v)

V, v = lss.checkpoint(V, v)
# extra outputs to compute
compute_outputs = [lss.Rs[-1], lss.bs[-1], G_dil[-1], g_dil[-1]]

# run all segments
if i < num_segments:
u0, V, v, J0, G, g = run_segment(
run, u0, V, v, parameter, i, steps_per_segment,
epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir,
compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job)
else:
run_compute(compute_outputs, spawn_compute_job=spawn_compute_job, interprocess=interprocess)

compute_outputs = [lss.Rs[-1], lss.bs[-1], G_dil[-1], g_dil[-1]]
run_compute(compute_outputs, spawn_compute_job=spawn_compute_job, interprocess=interprocess)
for output in [lss.Rs, lss.bs, G_dil, g_dil]:
output[-1] = output[-1].field

Expand All @@ -160,14 +149,9 @@ def continue_shadowing(
print(lss_gradient(checkpoint))
sys.stdout.flush()

if checkpoint_path and (i) % checkpoint_interval == 0:
if checkpoint_path and (i+1) % checkpoint_interval == 0:
save_checkpoint(checkpoint_path, checkpoint)

if i < num_segments:
J_hist.append(J0)
G_lss.append(G)
g_lss.append(g)


if return_checkpoint:
return checkpoint
else:
Expand Down

0 comments on commit e0ca93a

Please sign in to comment.