Skip to content

Commit

Permalink
reduced code duplication in Model class
Browse files Browse the repository at this point in the history
__sum_up_blobs() and  __sum_up_blobs_speedup() are replaced by _sum_up_blobs() and __compute_start_stop()
  • Loading branch information
gregordecristoforo committed Jan 24, 2022
1 parent 213591a commit 55cf9cf
Showing 1 changed file with 28 additions and 41 deletions.
69 changes: 28 additions & 41 deletions blobmodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,8 @@ def make_realization(
shape=(self.__geometry.Ny, self.__geometry.Nx, self.__geometry.t.size)
)

if speed_up:
self.__sum_up_blobs_speedup(error, labels, label_border)
else:
self.__sum_up_blobs(labels, label_border)
for b in tqdm(self.__blobs, desc="Summing up Blobs"):
self.__sum_up_blobs(b, speed_up, error, labels, label_border)

ds = self.__create_xr_dataset(labels)

Expand Down Expand Up @@ -163,29 +161,27 @@ def __create_xr_dataset(self, labels) -> xr.Dataset:

return ds

def __sum_up_blobs(self, labels: bool, label_border: float) -> None:
for b in tqdm(self.__blobs, desc="Summing up Blobs"):
__single_blob = b.discretize_blob(
x=self.__geometry.x_matrix,
y=self.__geometry.y_matrix,
t=self.__geometry.t_matrix,
periodic_y=self.__geometry.periodic_y,
Ly=self.__geometry.Ly,
)
self.__density += __single_blob
if labels:
__max_amplitudes = np.max(__single_blob, axis=(0, 1))
__max_amplitudes[__max_amplitudes == 0] = np.inf
self.__labels_field[
__single_blob >= __max_amplitudes * label_border
] = 1

def __sum_up_blobs_speedup(
self, error: float, labels: bool, label_border: float
) -> None:
# speedup implemeted for exponential pulses
# can also be used for gaussian pulses since they converge faster than exponential pulses
for b in tqdm(self.__blobs, desc="Summing up Blobs"):
def __sum_up_blobs(
self, b: Blob, speed_up: bool, error: float, labels: bool, label_border: float
):
__start, __stop = self.__compute_start_stop(b, speed_up, error)
__single_blob = b.discretize_blob(
x=self.__geometry.x_matrix[:, :, __start:__stop],
y=self.__geometry.y_matrix[:, :, __start:__stop],
t=self.__geometry.t_matrix[:, :, __start:__stop],
periodic_y=self.__geometry.periodic_y,
Ly=self.__geometry.Ly,
)
self.__density[:, :, __start:__stop] += __single_blob
if labels:
__max_amplitudes = np.max(__single_blob, axis=(0, 1))
__max_amplitudes[__max_amplitudes == 0] = np.inf
self.__labels_field[:, :, __start:__stop][
__single_blob >= __max_amplitudes * label_border
] = 1

def __compute_start_stop(self, b: Blob, speed_up: bool, error: float):
if speed_up:
__start = int(b.t_init / self.__geometry.dt)
if b.v_x == 0:
__stop = self.__geometry.t.size
Expand All @@ -199,17 +195,8 @@ def __sum_up_blobs_speedup(
/ (b.v_x * self.__geometry.dt)
),
)
__single_blob = b.discretize_blob(
x=self.__geometry.x_matrix[:, :, __start:__stop],
y=self.__geometry.y_matrix[:, :, __start:__stop],
t=self.__geometry.t_matrix[:, :, __start:__stop],
periodic_y=self.__geometry.periodic_y,
Ly=self.__geometry.Ly,
)
self.__density[:, :, __start:__stop] += __single_blob
if labels:
__max_amplitudes = np.max(__single_blob, axis=(0, 1))
__max_amplitudes[__max_amplitudes == 0] = np.inf
__tmp = np.copy(self.__labels_field[:, :, __start:__stop])
__tmp[__single_blob >= __max_amplitudes * label_border] = 1
self.__labels_field[:, :, __start:__stop] = __tmp
else:
__start = 0
__stop = self.__geometry.t.size

return __start, __stop

0 comments on commit 55cf9cf

Please sign in to comment.