In [20]:
from datasets import Dataset
import sentencepiece as spm
from torch.nn.utils.rnn import pad_sequence
import torch
from torch.utils.data import DataLoader
from functools import partial
text = ['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n',
        'MDRLDFGGGE\n']
train_dataset = Dataset.from_dict({'sequences': text})
train_dataset

Dataset({
    features: ['sequences'],
    num_rows: 4
})

In [80]:
class FastaBatchedDataset(object):
    """inspired by esm2, but instead of sorting the original sequences,
    we should really sorting based on tokenized sequences
    """
    def __init__(self,sequence_strs, tokenizer, max_sequence_length):
        self.sequence_strs = sequence_strs['sequences']
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        raw_tokenized_sequences = [self.tokenizer.encode(s) for s in self.sequence_strs]
        self.tokenized_sequences = []
        for tokens in raw_tokenized_sequences:
            if len(tokens) >= self.max_sequence_length-1:
                # considering the added special token bos
                sampled_windows = self.sample_windows(tokens, self.max_sequence_length)
                for sample in sampled_windows:
                    sample.insert(0, self.tokenizer.bos_id())
                    self.tokenized_sequences.append(sample)
            else:
                tokens.insert(0, self.tokenizer.bos_id())
                self.tokenized_sequences.append(tokens)

    @staticmethod
    def sample_windows(sequence, max_sequence_length, extra_toks_per_seq=1):
        import random
        random.seed(42)
        """based on Phil's implement
        Returns: a list of sampled windows.
        """
        sampled_windows = []
        # the beginning window
        sampled_windows.append(sequence[:max_sequence_length-extra_toks_per_seq])
        # calculate the num of random slices needed, remove head and tail
        num_slices_required = (len(sequence) // max_sequence_length) - 2
        max_start_index = len(sequence) - max_sequence_length

        if max_start_index < 0:
            raise ValueError("max_sequence_length greater than length of sequence")

        if num_slices_required > 0:
            for _ in range(num_slices_required):
                # Randomly select start index of the window
                start_index = random.randint(0, max_start_index)
                # Extract window
                # Considering the added special token bos
                window = sequence[start_index:start_index + max_sequence_length-extra_toks_per_seq]
                # Append the window to the list of sampled windows
                sampled_windows.append(window)

        # the end window
        sampled_windows.append(sequence[-(max_sequence_length-extra_toks_per_seq):])
        return sampled_windows

    def get_batch_indices(self, extra_toks_per_seq=1):
        print(self.sequence_strs)
        sizes = [(len(tokens), i) for i, tokens in enumerate(self.tokenized_sequences)]
        sizes.sort()
        print(sizes)
        batches = []
        buf = []
        current_buf_len = 0

        def _flush_current_buf():
            nonlocal current_buf_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            current_buf_len = 0
            #print('my batches is:')
            #print(batches)

        for sz, i in sizes:
            # considering the extra bos
            sz += extra_toks_per_seq
            # check accumulative seq length in the buffer
            if current_buf_len + sz > self.max_sequence_length:
                _flush_current_buf()
            buf.append(i)
            current_buf_len += sz
            #print('my buffer is:')
            #print(buf)

        _flush_current_buf()
        return batches

    def __len__(self):
        return len(self.tokenized_sequences)

    def __getitem__(self, idx):
        return self.tokenized_sequences[idx]

In [81]:
class BatchConverter(object):
    """add padding, create labels for GPT-alike training, used as collate_fn, need processed batch indices
    processed (labels + tensor) batch.
    """

    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_id()

    def __call__(self, batches):
        data_tokens = [torch.tensor(token_list) for token_list in batches]
        data_tokens_padded = pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)

        # Create attention masks
        attention_masks = (data_tokens_padded != self.pad_token_id).long()

        # skip label==-100 during training so that these tokens won't be used in loss calculation
        labels = data_tokens_padded.clone()
        labels[data_tokens_padded == self.pad_token_id] = -100

        return {
            'input_ids': data_tokens_padded,
            'attention_mask': attention_masks,
            'labels': labels
        }

In [82]:
tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path+"protein_8k.model")

In [83]:
dataset = FastaBatchedDataset(train_dataset, tokenizer, 10)
batches = dataset.get_batch_indices()
dataloader = torch.utils.data.DataLoader(dataset, collate_fn=BatchConverter(tokenizer),
                                          batch_sampler=batches, pin_memory=True)


['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n', 'MDRLDFGGGE\n']
[(5, 8), (10, 0), (10, 1), (10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7)]


In [84]:
batches

[[8], [0], [1], [2], [3], [4], [5], [6], [7]]

In [76]:
for batch_id, batch in enumerate(dataloader):
    print(batch_id, batch)

0 {'input_ids': tensor([[   1,  854, 1642,  653,   91]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'labels': tensor([[   1,  854, 1642,  653,   91]])}
1 {'input_ids': tensor([[   1,  820,  814, 3992,   37, 1313, 7791,  245, 7689,  701]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1,  820,  814, 3992,   37, 1313, 7791,  245, 7689,  701]])}
2 {'input_ids': tensor([[   1, 2844, 3768, 4240, 2298, 3576, 1302,  828,  114, 2232]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1, 2844, 3768, 4240, 2298, 3576, 1302,  828,  114, 2232]])}
3 {'input_ids': tensor([[   1,   37, 1313, 7791,  245, 7689,  701, 4163, 6182, 1302]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1,   37, 1313, 7791,  245, 7689,  701, 4163, 6182, 1302]])}
4 {'input_ids': tensor([[   1,  820,  814, 3992,   37, 1313, 7791,  245, 7689,  701]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels

In [79]:
# verify via simple functions
def sample_windows(sequence, max_sequence_length):
    import random
    random.seed(42)
    """based on Phil's implement
    Returns: a list of sampled windows.
    """
    sampled_windows = []
    # the beginning window
    sampled_windows.append(sequence[:max_sequence_length - 1])
     # Number of random slices needed
    num_slices_required = (len(sequence) // max_sequence_length) - 2
    max_start_index = len(sequence) - max_sequence_length

    if max_start_index < 0:
        raise ValueError("max_sequence_length greater than length of sequence")

    if num_slices_required > 0:
        for _ in range(num_slices_required):
            # Randomly select start index of the window
            start_index = random.randint(0, max_start_index)
            # Extract window
            # Considering the added special token bos
            window = sequence[start_index:start_index + max_sequence_length - 1]
            # Append the window to the list of sampled windows
            sampled_windows.append(window)

    # the end window
    sampled_windows.append(sequence[-(max_sequence_length - 1):])
    return sampled_windows


def tokenize_batch(batch, tokenizer, max_sequence_length):
    proteins = [sample['sequences'] for sample in batch]
    data_tokens = tokenizer.encode(proteins)
    bos = tokenizer.bos_id()
    data_tokens_sampled = []
    for token_list in data_tokens:
        length = len(token_list)
        print(token_list, len(token_list))
        if length >= max_sequence_length - 1:
            # considering the added special token bos
            sampled_windows = sample_windows(token_list, max_sequence_length)
            for sample in sampled_windows:
                sample.insert(0, bos)
                data_tokens_sampled.append(sample)
        else:
            token_list.insert(0, bos)
            data_tokens_sampled.append(token_list)
            #data_tokens_sampled.append(token_list)

    return data_tokens_sampled

dataloader = DataLoader(train_dataset,
                        batch_size=2,
                        shuffle=False,
                        drop_last=False,
                        collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_sequence_length=10))
for batch_id, batch in enumerate(dataloader):
    print(batch_id)
    for v in batch:
        print(v)

[820, 814, 3992, 37, 1313, 7791, 245, 7689, 701, 4163, 6182, 1302, 4844, 7660, 1104, 3455, 2232, 4290, 2886, 2351, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894] 32
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 1950, 3101, 7696, 1783, 6728, 3777, 2132, 2949, 2596, 1868, 2403, 3399, 760, 1197, 614, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521] 33
0
[1, 820, 814, 3992, 37, 1313, 7791, 245, 7689, 701]
[1, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232]
[1, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894]
[1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817]
[1, 3399, 760, 1197, 614, 2013, 2199, 2572, 688, 2203]
[1, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521]
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076] 12
[854, 1642, 653, 91] 4
1
[1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817]
[1, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076]
[1, 854, 1642, 653, 91]
