Skip to content

Commit

Permalink
ExternSprintDataset, fix/cleanup for Python3
Browse files: browse the repository at this point in the history
  • Loading branch information
albertz committed May 13, 2024
1 parent aa41c4a commit ece6efc
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions returnn/datasets/sprint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -815,7 +815,7 @@ def _start_child(self, epoch, get_dim_only=False):

try:
init_signal, (input_dim, output_dim, num_segments) = self._read_next_raw()
assert init_signal == b"init"
assert init_signal == "init"
assert isinstance(input_dim, int) and isinstance(output_dim, int)
# Ignore num_segments. It can be totally different than the real number of sequences.
self.set_dimensions(input_dim, output_dim)
Expand Down Expand Up @@ -922,9 +922,7 @@ def _read_next_raw(self):
:return: (data_type, args)
:rtype: (str, object)
"""
# encoding is for converting Python2 strings to Python3.
# Cannot use utf8 because Numpy will also encode the data as strings and there we need it as bytes.
data_type, args = util.read_pickled_object(self.pipe_c2p[0], encoding="bytes")
data_type, args = util.read_pickled_object(self.pipe_c2p[0])
return data_type, args

def _join_child(self, wait=True, expected_exit_status=None):
Expand Down Expand Up @@ -974,7 +972,7 @@ def _reader_thread_proc(self, child_pid, epoch):
if self.python_exit or not self.child_pid:
break

if data_type == b"data":
if data_type == "data":
seq_count += 1
segment_name, features, targets = args
if segment_name is not None:
Expand All @@ -987,7 +985,7 @@ def _reader_thread_proc(self, child_pid, epoch):
numpy_copy_and_set_unused(targets),
segment_name=segment_name,
)
elif data_type == b"exit":
elif data_type == "exit":
have_seen_the_whole = True
break
else:
Expand Down Expand Up @@ -1148,15 +1146,15 @@ def read(self, name):
"""
res = self.sprint_cache.read(name, typ=self.type)
if self.type == "align":
for (t, a, s, w) in res:
for t, a, s, w in res:
assert w == 1, "soft alignment not supported"
label_seq = numpy.array(
[self.allophone_labeling.get_label_idx(a, s) for (t, a, s, w) in res], dtype=self.dtype
)
assert label_seq.shape == (len(res),)
return label_seq
elif self.type == "align_raw":
for (t, a, s, w) in res:
for t, a, s, w in res:
assert w == 1, "soft alignment not supported"
label_seq = numpy.array(
[self.allophone_labeling.state_tying_by_allo_state_idx[a] for (t, a, s, w) in res], dtype=self.dtype
Expand Down

0 comments on commit ece6efc

Please sign in to comment.