diff --git a/fds/fds.py b/fds/fds.py index 5dd0757..7b594a9 100644 --- a/fds/fds.py +++ b/fds/fds.py @@ -112,19 +112,20 @@ def continue_shadowing( time_dil = TimeDilation(run, u0, parameter, run_id, simultaneous_runs, interprocess) - V = time_dil.project(V) - v = time_dil.project(v) - - u0, V, v, J0, G, g = run_segment( - run, u0, V, v, parameter, i, steps_per_segment, - epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir, - compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job) - - J_hist.append(J0) - G_lss.append(G) - g_lss.append(g) + for i in range(lss.K_segments(), num_segments): + V = time_dil.project(V) + v = time_dil.project(v) - for i in range(lss.K_segments() + 1, num_segments + 1): + # extra outputs to compute + compute_outputs = [] + # run all segments + u0, V, v, J0, G, g = run_segment( + run, u0, V, v, parameter, i, steps_per_segment, + epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir, + compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job) + J_hist.append(J0) + G_lss.append(G) + g_lss.append(g) # time dilation contribution run_id = 'time_dilation_{0:02d}'.format(i) @@ -136,22 +137,10 @@ def continue_shadowing( G_dil.append(time_dil.contribution(V)) g_dil.append(time_dil.contribution(v)) - V = time_dil.project(V) - v = time_dil.project(v) - V, v = lss.checkpoint(V, v) - # extra outputs to compute - compute_outputs = [lss.Rs[-1], lss.bs[-1], G_dil[-1], g_dil[-1]] - - # run all segments - if i < num_segments: - u0, V, v, J0, G, g = run_segment( - run, u0, V, v, parameter, i, steps_per_segment, - epsilon, simultaneous_runs, interprocess, get_host_dir=get_host_dir, - compute_outputs=compute_outputs, spawn_compute_job=spawn_compute_job) - else: - run_compute(compute_outputs, spawn_compute_job=spawn_compute_job, interprocess=interprocess) + compute_outputs = [lss.Rs[-1], lss.bs[-1], G_dil[-1], g_dil[-1]] + run_compute(compute_outputs, spawn_compute_job=spawn_compute_job, interprocess=interprocess) for output in [lss.Rs, lss.bs, G_dil, g_dil]: output[-1] = output[-1].field @@ -160,14 +149,9 @@ def continue_shadowing( print(lss_gradient(checkpoint)) sys.stdout.flush() - if checkpoint_path and (i) % checkpoint_interval == 0: + if checkpoint_path and (i+1) % checkpoint_interval == 0: save_checkpoint(checkpoint_path, checkpoint) - - if i < num_segments: - J_hist.append(J0) - G_lss.append(G) - g_lss.append(g) - + if return_checkpoint: return checkpoint else: