In [6]:
import torch
from torch.nn.utils.rnn import pad_sequence


In [49]:
class CollateFnV2:
    """
    Collate callable that pads variable-length samples into one batch.

    Every sample is ``((encoder_in, decoder_in), decoder_out)`` where
    ``encoder_in`` is either a single tensor or a tuple of exactly two
    tensors.  Which of the two layouts is used is decided from the first
    sample of the batch.
    """

    def __init__(self, batch_first: bool, word_pad_idx: int, 
                 swipe_pad_idx: int = 0) -> None:
        """
        Arguments:
        ----------
        batch_first: forwarded to ``pad_sequence`` for every padded tensor.
        word_pad_idx: padding value for decoder input / output sequences.
        swipe_pad_idx: padding value for encoder input sequences.
        """
        self.word_pad_idx = word_pad_idx
        self.batch_first = batch_first
        self.swipe_pad_idx = swipe_pad_idx

    def _assert_encoder_in_type_and_shape(self, encoder_in_example):
        """Assert that a tuple encoder input holds exactly two tensors."""
        assert len(encoder_in_example) == 2
        for el in encoder_in_example:
            assert isinstance(el, torch.Tensor), \
                f"Expected torch.Tensor, got {type(el)}"

    def _is_encoder_input_tuple(self, batch):
        """
        Return True if the first sample's encoder input is a tuple of two
        tensors, False if it is a single tensor.

        Raises ValueError for any other type.
        """
        encoder_in_example = batch[0][0][0]
        if isinstance(encoder_in_example, tuple):
            self._assert_encoder_in_type_and_shape(encoder_in_example)
            return True
        elif isinstance(encoder_in_example, torch.Tensor):
            return False
        else:
            # Report the type of the offending object itself; batch[0][0]
            # is the (encoder_in, decoder_in) pair, not the encoder input.
            raise ValueError(
                f"Unknown type of encoder input {type(encoder_in_example)}")

    def __call__(self, batch: list):
        """
        Given a list where each row is
        ((encoder_in_sample, decoder_in_sample), decoder_out_sample)
        returns a tuple of two elements:
        1. (encoder_in, decoder_in, swipe_pad_mask, word_pad_mask)
        2. decoder_out

        Arguments:
        ----------
        batch: list of tuples:
            ((encoder_in, dec_in_char_seq), dec_out_char_seq),
            where encoder_in may be a tuple of two torch tensors
            (ex. ``(traj_feats, nearest_kb_tokens)``)
            or a single tensor (ex. ``nearest_kb_tokens``)

        Returns:
        --------
        transformer_in: (enc_in, dec_in, swipe_pad_mask, word_pad_mask)
            - enc_in: padded tensor, or a tuple of two padded tensors when
                the input was a tuple.  The sequence axis comes first
                unless ``batch_first`` is True.
            - dec_in: padded decoder input, same axis order as enc_in.
            - swipe_pad_mask: bool, (batch_size, max_curve_len); True on
                padded encoder positions.  Always batch first.
            - word_pad_mask: bool, (batch_size, max_dec_in_len); True where
                dec_in equals ``word_pad_idx``.  Always batch first.
        dec_out: padded decoder output, same axis order as dec_in.
        """
        is_encoder_input_tuple = self._is_encoder_input_tuple(batch)
        dec_in_no_pad = []
        dec_out_no_pad = []

        encoder_in_no_pad = ([], []) if is_encoder_input_tuple else []

        for row in batch:
            x_smpl, decoder_out_smpl = row
            encoder_in_smpl, decoder_in_smpl = x_smpl
            if is_encoder_input_tuple:
                for i in range(2):
                    encoder_in_no_pad[i].append(encoder_in_smpl[i])
            else:
                # Collect the current sample (previously an undefined local
                # was appended here, which crashed the single-tensor path).
                encoder_in_no_pad.append(encoder_in_smpl)

            dec_in_no_pad.append(decoder_in_smpl)
            dec_out_no_pad.append(decoder_out_smpl)

        if is_encoder_input_tuple:
            encoder_in = tuple(
                pad_sequence(encoder_in_no_pad_i, batch_first=self.batch_first,
                             padding_value=self.swipe_pad_idx)
                for encoder_in_no_pad_i in encoder_in_no_pad)
        else:
            encoder_in = pad_sequence(encoder_in_no_pad, batch_first=self.batch_first,
                                      padding_value=self.swipe_pad_idx)

        dec_out = pad_sequence(dec_out_no_pad, batch_first=self.batch_first,
                               padding_value=self.word_pad_idx)

        dec_in = pad_sequence(dec_in_no_pad, batch_first=self.batch_first,
                              padding_value=self.word_pad_idx)

        word_pad_mask = dec_in == self.word_pad_idx
        if not self.batch_first:
            word_pad_mask = word_pad_mask.T  # word_pad_mask is always batch first

        # The encoder mask is derived from the first encoder component only.
        encoder_in_el = encoder_in[0] if is_encoder_input_tuple else encoder_in
        max_curve_len = encoder_in_el.shape[1] if self.batch_first else encoder_in_el.shape[0]
        encoder_in_no_pad_el = encoder_in_no_pad[0] if is_encoder_input_tuple else encoder_in_no_pad
        encoder_lens = torch.tensor([len(x) for x in encoder_in_no_pad_el])

        # Take a matrix with len(encoder_lens) rows of the form
        # [0, 1, ..., max_curve_len - 1] and compare every element of the
        # i-th row with the length of the i-th trajectory.  The result is
        # True exactly at positions at or beyond that trajectory's length.
        # Shape: (batch_size, max_curve_len)
        encoder_pad_mask = torch.arange(max_curve_len).expand(
            len(encoder_lens), max_curve_len) >= encoder_lens.unsqueeze(1)

        transformer_in = (encoder_in, dec_in, encoder_pad_mask, word_pad_mask)

        return transformer_in, dec_out


In [8]:
class CollateFnV3:
    """
    Collate callable that pads variable-length samples into one batch.

    Every sample is ``((encoder_in, decoder_in), decoder_out)`` where
    ``encoder_in`` is either a single tensor or a tuple of exactly two
    tensors.  This variant splits the batch into columns with ``zip``
    instead of an explicit per-sample loop.
    """

    def __init__(self, batch_first: bool, word_pad_idx: int, swipe_pad_idx: int = 0) -> None:
        """
        Arguments:
        ----------
        batch_first: forwarded to ``pad_sequence`` for every padded tensor.
        word_pad_idx: padding value for decoder input / output sequences.
        swipe_pad_idx: padding value for encoder input sequences.
        """
        self.word_pad_idx = word_pad_idx
        self.batch_first = batch_first
        self.swipe_pad_idx = swipe_pad_idx

    def _assert_encoder_in_type_and_shape(self, encoder_in_example):
        """Assert that a tuple encoder input holds exactly two tensors."""
        assert len(encoder_in_example) == 2
        for el in encoder_in_example:
            assert isinstance(el, torch.Tensor), f"Expected torch.Tensor, got {type(el)}"

    def _is_encoder_input_tuple(self, batch):
        """
        Return True if the first sample's encoder input is a tuple of two
        tensors, False if it is a single tensor.

        Raises ValueError for any other type.
        """
        encoder_in_example = batch[0][0][0]
        if isinstance(encoder_in_example, tuple):
            self._assert_encoder_in_type_and_shape(encoder_in_example)
            return True
        elif isinstance(encoder_in_example, torch.Tensor):
            return False
        else:
            # Report the type of the offending object itself; batch[0][0]
            # is the (encoder_in, decoder_in) pair, not the encoder input.
            raise ValueError(f"Unknown type of encoder input {type(encoder_in_example)}")

    def __call__(self, batch: list):
        """
        Given a list where each row is
        ((encoder_in_sample, decoder_in_sample), decoder_out_sample)
        returns a tuple of two elements:
        1. (encoder_in, decoder_in, swipe_pad_mask, word_pad_mask)
        2. decoder_out

        Arguments:
        ----------
        batch: list of tuples:
            ((encoder_in, dec_in_char_seq), dec_out_char_seq),
            where encoder_in may be a tuple of two torch tensors
            (ex. ``(traj_feats, nearest_kb_tokens)``)
            or a single tensor (ex. ``nearest_kb_tokens``)

        Returns:
        --------
        transformer_in: (enc_in, dec_in, swipe_pad_mask, word_pad_mask)
            - enc_in: padded tensor, or a tuple of two padded tensors when
                the input was a tuple.  The sequence axis comes first
                unless ``batch_first`` is True.
            - dec_in: padded decoder input, same axis order as enc_in.
            - swipe_pad_mask: bool, (batch_size, max_curve_len); True on
                padded encoder positions.  Always batch first.
            - word_pad_mask: bool, (batch_size, max_dec_in_len); True where
                dec_in equals ``word_pad_idx``.  Always batch first.
        dec_out: padded decoder output, same axis order as dec_in.
        """
        is_encoder_input_tuple = self._is_encoder_input_tuple(batch)

        # Transpose the list of rows into three parallel columns.
        encoder_in_samples, dec_in_samples, dec_out_samples = zip(
            *((x_smpl[0], x_smpl[1], decoder_out_smpl)
              for x_smpl, decoder_out_smpl in batch)
        )

        if is_encoder_input_tuple:
            # Split the (a, b) pairs into one list per encoder component.
            encoder_in_samples_0, encoder_in_samples_1 = zip(*encoder_in_samples)
            encoder_in_no_pad = (
                list(encoder_in_samples_0),
                list(encoder_in_samples_1)
            )
            encoder_in = tuple(
                pad_sequence(enc_in_no_pad_i, batch_first=self.batch_first,
                             padding_value=self.swipe_pad_idx)
                for enc_in_no_pad_i in encoder_in_no_pad
            )
        else:
            # Hand pad_sequence a list here too, like every other call below.
            encoder_in_no_pad = list(encoder_in_samples)
            encoder_in = pad_sequence(encoder_in_no_pad, batch_first=self.batch_first,
                                      padding_value=self.swipe_pad_idx)

        dec_in = pad_sequence(list(dec_in_samples), batch_first=self.batch_first, padding_value=self.word_pad_idx)
        dec_out = pad_sequence(list(dec_out_samples), batch_first=self.batch_first, padding_value=self.word_pad_idx)

        word_pad_mask = dec_in == self.word_pad_idx
        if not self.batch_first:
            word_pad_mask = word_pad_mask.T  # word_pad_mask is always batch first

        # The encoder mask is derived from the first encoder component only.
        encoder_in_el = encoder_in[0] if is_encoder_input_tuple else encoder_in
        max_curve_len = encoder_in_el.shape[1] if self.batch_first else encoder_in_el.shape[0]
        encoder_in_no_pad_el = encoder_in_no_pad[0] if is_encoder_input_tuple else encoder_in_no_pad
        encoder_lens = torch.tensor([len(x) for x in encoder_in_no_pad_el])

        # Take a matrix with len(encoder_lens) rows of the form
        # [0, 1, ..., max_curve_len - 1] and compare every element of the
        # i-th row with the length of the i-th trajectory.  The result is
        # True exactly at positions at or beyond that trajectory's length.
        # Shape: (batch_size, max_curve_len)
        encoder_pad_mask = torch.arange(max_curve_len).expand(
            len(encoder_lens), max_curve_len) >= encoder_lens.unsqueeze(1)

        transformer_in = (encoder_in, dec_in, encoder_pad_mask, word_pad_mask)

        return transformer_in, dec_out

In [9]:
def generate_batch(is_encoder_in_tuple: bool,
                   batch_size = 7,
                    min_swipe_len = 8,
                    max_swipe_len = 16,
                    min_word_len = 1,
                    max_word_len = 10) -> list:
    """
    Build a synthetic, un-collated batch for exercising the collate classes.

    Every sequence is ``[0, 1, ..., length - 1]``, so all values are
    non-negative and a negative pad value can never collide with real data.

    Arguments:
    ----------
    is_encoder_in_tuple: if True the encoder input of every sample is a
        tuple of two equal tensors, otherwise a single tensor.
    batch_size: number of samples to generate.
    min_swipe_len, max_swipe_len: encoder length range; drawn with
        ``torch.randint``, so the upper bound is exclusive.
    min_word_len, max_word_len: decoder length range, upper bound exclusive.

    Returns:
    --------
    list of ``((encoder_in, decoder_in), decoder_out)`` tuples where
    ``decoder_in`` and ``decoder_out`` are distinct but equal tensors.
    """
    swipe_lens = torch.randint(min_swipe_len, max_swipe_len, (batch_size,)).tolist()
    word_lens = torch.randint(min_word_len, max_word_len, (batch_size,)).tolist()

    batch = []
    for swipe_len, word_len in zip(swipe_lens, word_lens):
        if is_encoder_in_tuple:
            encoder_in = tuple(torch.arange(swipe_len) for _ in range(2))
        else:
            # Use this sample's length (the whole length tensor used to be
            # passed to range() here, which raised for every batch).
            encoder_in = torch.arange(swipe_len)
        decoder_in, decoder_out = [torch.arange(word_len) for _ in range(2)]

        batch.append(((encoder_in, decoder_in), decoder_out))

    return batch



In [38]:
class CollateFnV4:
    """
    Collate callable that pads variable-length samples into one batch.

    Every sample is ``((encoder_in, decoder_in), decoder_out)`` where
    ``encoder_in`` is either a single tensor or a tuple of exactly two
    tensors.  This variant treats a single tensor as a 1-tuple internally
    so both layouts share one code path.
    """

    def __init__(self, batch_first: bool, word_pad_idx: int, 
                 swipe_pad_idx: int = 0) -> None:
        """
        Arguments:
        ----------
        batch_first: forwarded to ``pad_sequence`` for every padded tensor.
        word_pad_idx: padding value for decoder input / output sequences.
        swipe_pad_idx: padding value for encoder input sequences.
        """
        self.word_pad_idx = word_pad_idx
        self.batch_first = batch_first
        self.swipe_pad_idx = swipe_pad_idx

    def _assert_encoder_in_type_and_shape(self, encoder_in_example):
        """Assert that a tuple encoder input holds exactly two tensors."""
        assert len(encoder_in_example) == 2
        for el in encoder_in_example:
            assert isinstance(el, torch.Tensor), \
                f"Expected torch.Tensor, got {type(el)}"

    def _is_encoder_input_tuple(self, batch):
        """
        Return True if the first sample's encoder input is a tuple of two
        tensors, False if it is a single tensor.

        Raises ValueError for any other type.
        """
        encoder_in_example = batch[0][0][0]
        if isinstance(encoder_in_example, tuple):
            self._assert_encoder_in_type_and_shape(encoder_in_example)
            return True
        elif isinstance(encoder_in_example, torch.Tensor):
            return False
        else:
            # Report the type of the offending object itself; batch[0][0]
            # is the (encoder_in, decoder_in) pair, not the encoder input.
            raise ValueError(
                f"Unknown type of encoder input {type(encoder_in_example)}")

    def __call__(self, batch: list):
        """
        Given a list where each row is
        ((encoder_in_sample, decoder_in_sample), decoder_out_sample)
        returns a tuple of two elements:
        1. (encoder_in, decoder_in, swipe_pad_mask, word_pad_mask)
        2. decoder_out

        Arguments:
        ----------
        batch: list of tuples:
            ((encoder_in, dec_in_char_seq), dec_out_char_seq),
            where encoder_in may be a tuple of two torch tensors
            (ex. ``(traj_feats, nearest_kb_tokens)``)
            or a single tensor (ex. ``nearest_kb_tokens``)

        Returns:
        --------
        transformer_in: (enc_in, dec_in, swipe_pad_mask, word_pad_mask)
            - enc_in: padded tensor, or a tuple of two padded tensors when
                the input was a tuple.  The sequence axis comes first
                unless ``batch_first`` is True.
            - dec_in: padded decoder input, same axis order as enc_in.
            - swipe_pad_mask: bool, (batch_size, max_curve_len); True on
                padded encoder positions.  Always batch first.
            - word_pad_mask: bool, (batch_size, max_dec_in_len); True where
                dec_in equals ``word_pad_idx``.  Always batch first.
        dec_out: padded decoder output, same axis order as dec_in.
        """
        is_encoder_input_tuple = self._is_encoder_input_tuple(batch)
        dec_in_no_pad = []
        dec_out_no_pad = []

        # One list per encoder component; a single tensor is one component.
        encoder_in_no_pad = ([], []) if is_encoder_input_tuple else ([],)

        for row in batch:
            x_smpl, decoder_out_smpl = row
            encoder_in_smpl, decoder_in_smpl = x_smpl
            if isinstance(encoder_in_smpl, torch.Tensor):
                encoder_in_smpl = (encoder_in_smpl,)
            for enc_in_smpl_i, enc_in_no_pad_i in zip(encoder_in_smpl, encoder_in_no_pad):
                enc_in_no_pad_i.append(enc_in_smpl_i)

            dec_in_no_pad.append(decoder_in_smpl)
            dec_out_no_pad.append(decoder_out_smpl)

        encoder_in_padded = tuple(
            pad_sequence(encoder_in_no_pad_i, batch_first=self.batch_first,
                         padding_value=self.swipe_pad_idx)
            for encoder_in_no_pad_i in encoder_in_no_pad)

        dec_out = pad_sequence(dec_out_no_pad, batch_first=self.batch_first,
                               padding_value=self.word_pad_idx)

        dec_in = pad_sequence(dec_in_no_pad, batch_first=self.batch_first,
                              padding_value=self.word_pad_idx)

        word_pad_mask = dec_in == self.word_pad_idx
        if not self.batch_first:
            word_pad_mask = word_pad_mask.T  # word_pad_mask is always batch first

        # The encoder mask is derived from the first encoder component.
        # Index [0] in BOTH layouts: in the single-tensor case the samples
        # live in encoder_in_no_pad[0] as well (iterating the outer 1-tuple
        # used to yield the batch size instead of per-sample lengths).
        encoder_in_el = encoder_in_padded[0]
        max_curve_len = encoder_in_el.shape[1] if self.batch_first else encoder_in_el.shape[0]
        encoder_lens = torch.tensor([len(x) for x in encoder_in_no_pad[0]])

        # Take a matrix with len(encoder_lens) rows of the form
        # [0, 1, ..., max_curve_len - 1] and compare every element of the
        # i-th row with the length of the i-th trajectory.  The result is
        # True exactly at positions at or beyond that trajectory's length.
        # Shape: (batch_size, max_curve_len)
        encoder_pad_mask = torch.arange(max_curve_len).expand(
            len(encoder_lens), max_curve_len) >= encoder_lens.unsqueeze(1)

        # Unwrap the 1-tuple so single-tensor input yields a single tensor.
        encoder_in = encoder_in_padded if is_encoder_input_tuple else encoder_in_padded[0]

        transformer_in = (encoder_in, dec_in, encoder_pad_mask, word_pad_mask)

        return transformer_in, dec_out


In [30]:
# Build a synthetic batch whose encoder input is a 2-tuple of tensors and
# print the pieces of its first sample.
# NOTE(review): `batch` is a notebook-global read by later cells; only the
# tuple layout (is_encoder_in_tuple=True) is generated here.
batch = generate_batch(True)
(encoder_in, decoder_in), decoder_out = batch[0]

print(encoder_in, decoder_in, decoder_out, sep='\n')


(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]), tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]))
tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([0, 1, 2, 3, 4, 5, 6, 7])


In [31]:
# Check that the first component of the sample's encoder input is a tensor.
isinstance(encoder_in[0], torch.Tensor)

True

In [50]:
# Sanity check: CollateFnV2 with batch_first=False.
# -1 is used as the pad value for both swipes and words; this assumes `batch`
# holds only non-negative values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV2(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# encoder_in[0] is sequence-first here while the mask is batch-first, hence .T
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask.T)

In [51]:
# Sanity check: CollateFnV2 with batch_first=True.
# -1 is used as the pad value; this assumes `batch` holds only non-negative
# values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV2(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# Both tensors are batch-first, so they are compared without a transpose.
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask)

In [34]:
# Sanity check: CollateFnV3 with batch_first=False.
# -1 is used as the pad value; this assumes `batch` holds only non-negative
# values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV3(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# encoder_in[0] is sequence-first here while the mask is batch-first, hence .T
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask.T)

In [35]:
# Sanity check: CollateFnV3 with batch_first=True.
# -1 is used as the pad value; this assumes `batch` holds only non-negative
# values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV3(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# Both tensors are batch-first, so they are compared without a transpose.
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask)

In [39]:
# Sanity check: CollateFnV4 with batch_first=False.
# -1 is used as the pad value; this assumes `batch` holds only non-negative
# values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV4(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# encoder_in[0] is sequence-first here while the mask is batch-first, hence .T
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask.T)

In [40]:
# Sanity check: CollateFnV4 with batch_first=True.
# -1 is used as the pad value; this assumes `batch` holds only non-negative
# values, so `== -1` marks exactly the padded cells.
collate_fn = CollateFnV4(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

# Both tensors are batch-first, so they are compared without a transpose.
assert torch.equal(encoder_in[0]==-1 , encoder_pad_mask)

In [45]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV4, batch_first=True, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.

collate_fn = CollateFnV4(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)


630 µs ± 87.1 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [43]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV3, batch_first=True, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.
collate_fn = CollateFnV3(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

527 µs ± 26.3 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [54]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV2, batch_first=True, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.

collate_fn = CollateFnV2(batch_first=True, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)


510 µs ± 2.58 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [46]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV4, batch_first=False, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.

collate_fn = CollateFnV4(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

538 µs ± 11.2 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [47]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV3, batch_first=False, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.
collate_fn = CollateFnV3(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

533 µs ± 11.5 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [53]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV2, batch_first=False, on the notebook-global `batch`.
# The timed body includes constructing the collate object, not only the call.
collate_fn = CollateFnV2(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

554 µs ± 53 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [48]:
%%timeit -r 3 -n 10000
# Benchmark: CollateFnV2, batch_first=False, on the notebook-global `batch`.
# NOTE(review): same configuration as the other CollateFnV2 batch_first=False
# timing cell in this notebook; presumably a repeat measurement.
collate_fn = CollateFnV2(batch_first=False, word_pad_idx=-1, swipe_pad_idx = -1)

(encoder_in, dec_in, encoder_pad_mask, word_pad_mask), dec_out =  collate_fn(batch)

555 µs ± 21 µs per loop (mean ± std. dev. of 3 runs, 10,000 loops each)


In [83]:
# Test the correctness of collate_fn (it is called implicitly by DataLoader)
#
# NOTE(review): word_char_tokenizer, DataLoader and train_dataset are not
# defined in the visible part of this notebook; the recorded output of this
# cell is a NameError for word_char_tokenizer.
# NOTE(review): `collate_fn` is whatever an earlier cell left bound.  The shape
# asserts below expect sequence-first tensors padded with PAD_CHAR_TOKEN, so
# this presumably needs a collate object built with batch_first=False and
# word_pad_idx=PAD_CHAR_TOKEN -- confirm.

batch_size = 6


PAD_CHAR_TOKEN = word_char_tokenizer.char_to_idx["<pad>"]


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=0, collate_fn=collate_fn)


# Un-collated reference samples to compare the collated batch against.
dataset_els = [train_dataset[i] for i in range(batch_size)]
unproc_batch_x, unproc_batch_y = zip(*dataset_els)

batch_x, batch_y = next(iter(train_dataloader))


############### Check that batch_y is correct ###################
max_out_seq_len = max([len(y) for y in unproc_batch_y])

assert batch_y.shape == (max_out_seq_len, batch_size)


# Column i must be sample i followed only by padding.
for i in range(batch_size):
    assert (batch_y[:len(unproc_batch_y[i]), i] == unproc_batch_y[i]).all()
    assert (batch_y[len(unproc_batch_y[i]):, i] == PAD_CHAR_TOKEN).all()

print("batch_y is correct")



############### Check that batch_x is correct ###################
# NOTE(review): the two unpackings below assume each dataset x is a flat
# 3-tuple and that batch_x is a flat 5-tuple.  The CollateFnV2/V3/V4 classes
# in this notebook take ((encoder_in, dec_in), dec_out) rows and return a
# 4-tuple whose first element may itself be a tuple, so this looks written
# for a different collate function -- verify.
unproc_batch_traj_feats, unproc_batch_kb_tokens, unproc_batch_dec_in_char_seq = zip(*unproc_batch_x)

(traj_feats, kb_tokens, dec_in_char_seq, traj_pad_mask, word_pad_mask) = batch_x


# Each entity obtained above from unproc_batch_x is a tuple of length batch_size.
# For example, unproc_batch_traj_feats[i] = train_dataset[i][0][0]

N_TRAJ_FEATS = 6
max_curve_len = max([el.shape[0] for el in unproc_batch_traj_feats]) 

assert max_curve_len == max([el.shape[0] for el in unproc_batch_kb_tokens])

assert traj_feats.shape == (max_curve_len, batch_size, N_TRAJ_FEATS)
assert kb_tokens.shape == (max_curve_len, batch_size)
assert dec_in_char_seq.shape == (max_out_seq_len, batch_size)
assert traj_pad_mask.shape == (batch_size, max_curve_len)
assert word_pad_mask.shape == (batch_size, max_out_seq_len)


for i in range(batch_size):
    # The unpadded prefix of every column must equal the original sample.
    assert (traj_feats[:len(unproc_batch_traj_feats[i]), i] == unproc_batch_traj_feats[i]).all()
    assert (kb_tokens[:len(unproc_batch_kb_tokens[i]), i] == unproc_batch_kb_tokens[i]).all()

    assert (dec_in_char_seq[:len(unproc_batch_dec_in_char_seq[i]), i] == unproc_batch_dec_in_char_seq[i]).all()
    assert (dec_in_char_seq[len(unproc_batch_dec_in_char_seq[i]):, i] == PAD_CHAR_TOKEN).all()

    # Masks are batch-first: False on real positions, True on padding.
    assert (traj_pad_mask[i, :len(unproc_batch_traj_feats[i])] == False).all()
    assert (traj_pad_mask[i, len(unproc_batch_traj_feats[i]):] == True).all()
    
    assert (word_pad_mask[i, :len(unproc_batch_dec_in_char_seq[i])] == False).all()
    assert (word_pad_mask[i, len(unproc_batch_dec_in_char_seq[i]):] == True).all()

print("batch_x is correct")

NameError: name 'word_char_tokenizer' is not defined