In [1]:
text = ['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n',
        'MDRLDFGGGE\n']

In [7]:
from datasets import Dataset
import sentencepiece as spm
import torch
train_dataset = Dataset.from_dict({'sequences': text})
train_dataset

Dataset({
    features: ['sequences'],
    num_rows: 4
})

In [8]:
from torch.utils.data import Dataset

class FastaBatchedDataset(Dataset):
    def __init__(self, sequence_strs, tokenizer):
        self.sequence_strs = list(sequence_strs['sequences'])
        self.tokenizer = tokenizer
        self.tokenized_sequences = [self.tokenizer.encode(s) for s in self.sequence_strs]

    # Implement the __len__ method
    def __len__(self):
        return len(self.tokenized_sequences)

    # Implement the __getitem__ method
    def __getitem__(self, idx):
        return self.tokenized_sequences[idx]
def collate_fn(data, max_sequence_length, extra_toks_per_seq=1):
    sizes = [(len(tokens), tokens) for tokens in data]
    sizes.sort(key=lambda x: x[0])

    batches = []
    buf = []
    current_buf_len = 0

    def _flush_current_buf():
        nonlocal current_buf_len, buf
        if len(buf) == 0:
            return
        batches.extend(buf)
        buf = []
        current_buf_len = 0

    for sz, tokens in sizes:
        sz += extra_toks_per_seq
        if current_buf_len + sz > max_sequence_length:
            _flush_current_buf()
        buf.append(torch.tensor(tokens))
        current_buf_len += sz

    _flush_current_buf()
    return torch.stack(batches)
from torch.utils.data import DataLoader

desired_max_sequence_length = 60
train_dataset = {"sequences": ["A" * 10, "A" * 20, "A" * 60, "A" * 100]}
tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path+"protein_8k.model")
dataset = FastaBatchedDataset(train_dataset, tokenizer)

# Wrap the dataset in DataLoader
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda data: collate_fn(data, desired_max_sequence_length))

# Iterate over batches
for batch in loader:
    # Your training loop here
    print(batch)


RuntimeError: stack expects each tensor to be equal size, but got [3] at entry 0 and [5] at entry 1

In [49]:
class FastaBatchedDataset(object):
    """inspired by esm2, but instead of sorting the original sequences,
    we should really sorting based on tokenized sequences
    """
    def __init__(self,sequence_strs, tokenizer, max_sequence_length):
        self.sequence_strs = sequence_strs['sequences']
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.tokenized_sequence = [self.tokenizer.encode(s) for s in self.sequence_strs]
        self.tokenized_batches = self.get_tokenized_batches()

    def get_batch_indices(self, extra_toks_per_seq=1):
        print(self.sequence_strs)
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        print(sizes)
        batches = []
        buf = []
        current_buf_len = 0

        def _flush_current_buf():
            nonlocal current_buf_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            current_buf_len = 0
            #print('my batches is:')
            #print(batches)

        for sz, i in sizes:
            # considering the extra bos
            sz += extra_toks_per_seq
            # check accumulative seq length in the buffer
            if current_buf_len + sz > self.max_sequence_length:
                _flush_current_buf()
            buf.append(i)
            current_buf_len += sz
            #print('my buffer is:')
            #print(buf)

        _flush_current_buf()
        return batches

    @staticmethod
    def sample_windows(sequence, max_sequence_length, num_intervals):
        import random
        random.seed(42)
        """based on Phil's implement
        Returns: a list of sampled windows.
        """
        sampled_windows = []
        # the beginning window
        sampled_windows.append(sequence[:max_sequence_length-1])

        if num_intervals > 2:
            max_start_index = len(sequence) - max_sequence_length
            if max_start_index < 0:
                raise ValueError("max_sequence_length greater than length of sequence")
            for _ in range(num_intervals):
                # Randomly select start index of the window
                start_index = random.randint(0, max_start_index)
                # Extract window
                # Considering the added special token bos
                window = sequence[start_index:start_index + max_sequence_length-1]
                # Append the window to the list of sampled windows
                sampled_windows.append(window)

        # the end window
        sampled_windows.append(sequence[-(max_sequence_length-1):])
        return sampled_windows

    def get_tokenized_batches(self):
        batch_indices = self.get_batch_indices(self.max_sequence_length)
        print(batch_indices)
        tokenized_batches = []

        for batch_index in batch_indices:
            batch_sequences = [self.tokenized_sequence[i] for i in batch_index]
            tokenized_batch = []
            for token_list in batch_sequences:
                if len(token_list) >= self.max_sequence_length-1:
                    # considering the added special token bos
                    sampled_windows = self.sample_windows(token_list, self.max_sequence_length, len(token_list)//self.max_sequence_length)
                    for sample in sampled_windows:
                        sample.insert(0, self.tokenizer.bos_id())
                        tokenized_batch.append(sample)
                else:
                    token_list.insert(0, self.tokenizer.bos_id())
                    tokenized_batch.append(token_list)
            tokenized_batches.append(tokenized_batch)
        return tokenized_batches

    def __len__(self):
        return len(self.tokenized_batches)

    def __getitem__(self, idx):
        return self.tokenized_batches[idx]

import sentencepiece as spm
tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path+"protein_8k.model")
dataset = FastaBatchedDataset(train_dataset, tokenizer, 60).get_tokenized_batches()

['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n', 'MDRLDFGGGE\n']
[(11, 3), (34, 2), (93, 0), (100, 1)]
[[3], [2], [0], [1]]
['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n', 'MDRLDFGGGE\n']
[(11, 3), (34, 2), (93, 0), (100, 1)]
[[3], [2], [0], [1]]


In [17]:
dataset

[[[1, 1, 854, 1642, 653, 91]],
 [[1, 1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076]],
 [[1,
   1,
   820,
   814,
   3992,
   37,
   1313,
   7791,
   245,
   7689,
   701,
   4163,
   6182,
   1302,
   4844,
   7660,
   1104,
   3455,
   2232,
   4290,
   2886,
   2351,
   2844,
   3768,
   4240,
   2298,
   3576,
   1302,
   828,
   114,
   2232,
   43,
   6101,
   1894]],
 [[1,
   1,
   854,
   1642,
   653,
   1279,
   639,
   86,
   1195,
   1214,
   6817,
   1950,
   3101,
   7696,
   1783,
   6728,
   3777,
   2132,
   2949,
   2596,
   1868,
   2403,
   3399,
   760,
   1197,
   614,
   2013,
   2199,
   2572,
   688,
   2203,
   1363,
   890,
   1369,
   5521]]]

In [50]:
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset)

In [51]:
for batch_id, batch in enumerate(dataloader):
    print(batch_id, batch)

0 [[tensor([1]), tensor([1]), tensor([854]), tensor([1642]), tensor([653]), tensor([91])]]
1 [[tensor([1]), tensor([1]), tensor([854]), tensor([1642]), tensor([653]), tensor([1279]), tensor([639]), tensor([86]), tensor([1195]), tensor([1214]), tensor([6817]), tensor([49]), tensor([141]), tensor([2076])]]
2 [[tensor([1]), tensor([1]), tensor([820]), tensor([814]), tensor([3992]), tensor([37]), tensor([1313]), tensor([7791]), tensor([245]), tensor([7689]), tensor([701]), tensor([4163]), tensor([6182]), tensor([1302]), tensor([4844]), tensor([7660]), tensor([1104]), tensor([3455]), tensor([2232]), tensor([4290]), tensor([2886]), tensor([2351]), tensor([2844]), tensor([3768]), tensor([4240]), tensor([2298]), tensor([3576]), tensor([1302]), tensor([828]), tensor([114]), tensor([2232]), tensor([43]), tensor([6101]), tensor([1894])]]
3 [[tensor([1]), tensor([1]), tensor([854]), tensor([1642]), tensor([653]), tensor([1279]), tensor([639]), tensor([86]), tensor([1195]), tensor([1214]), tensor([

In [1]:
from transformers import PreTrainedTokenizer
import sentencepiece as spm
tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path+"protein_8k.model")

In [2]:
# Check the tokens corresponding to specific ids
ids_to_check = [0, 1, 2]

for id in ids_to_check:
    token = tokenizer.IdToPiece(id)
    print(f"ID {id} corresponds to token: {token}")

ID 0 corresponds to token: <unk>
ID 1 corresponds to token: <s>
ID 2 corresponds to token: </s>


In [200]:
import random
random.seed(42)
def sample_windows(sequence, max_sequence_length, num_intervals):
    """based on Phil's implement
    Returns: a list of sampled windows.
    """
    sampled_windows = []
    # the beginning window
    sampled_windows.append(sequence[:max_sequence_length-1])

    if num_intervals > 2:
        max_start_index = len(sequence) - max_sequence_length
        if max_start_index < 0:
            raise ValueError("max_sequence_length greater than length of sequence")
        for _ in range(num_intervals):
            # Randomly select start index of the window
            start_index = random.randint(0, max_start_index)
            # Extract window
            # Considering the added special token bos
            window = sequence[start_index:start_index + max_sequence_length-1]
            # Append the window to the list of sampled windows
            sampled_windows.append(window)

    # the end window
    sampled_windows.append(sequence[-(max_sequence_length-1):])
    return sampled_windows

In [201]:
tokenizer.bos_id()

1

In [202]:
import torch
def tokenize_batch(batch, tokenizer, max_sequence_length):
    proteins = [sample['sequences'] for sample in batch]
    data_tokens = tokenizer.encode(proteins)
    bos = tokenizer.bos_id()
    data_tokens_sampled = []
    for token_list in data_tokens:
        length = len(token_list)
        print(token_list)
        if length >= max_sequence_length-1:
            # considering the added special token bos
            sampled_windows = sample_windows(token_list, max_sequence_length, length//max_sequence_length)
            for sample in sampled_windows:
                sample.insert(0, bos)
                data_tokens_sampled.append(sample)
        else:
            token_list.insert(0, bos)
            data_tokens_sampled.append(token_list)
            #data_tokens_sampled.append(token_list)

    return data_tokens_sampled

#tokenized_dataset = train_dataset.map(tokenize_batch, batched=True)

In [203]:
from torch.utils.data import DataLoader
from functools import partial
dataloader = DataLoader(train_dataset,
                     batch_size=1,
                     shuffle=False,
                     drop_last=False,
                     collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_sequence_length = 10))

In [204]:
for batch_id, batch in enumerate(dataloader):
    print(batch_id)
    for v in batch:
        print(v)

[820, 814, 3992, 37, 1313, 7791, 245, 7689, 701, 4163, 6182, 1302, 4844, 7660, 1104, 3455, 2232, 4290, 2886, 2351, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894]
0
[1, 820, 814, 3992, 37, 1313, 7791, 245, 7689, 701]
[1, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232]
[1, 37, 1313, 7791, 245, 7689, 701, 4163, 6182, 1302]
[1, 820, 814, 3992, 37, 1313, 7791, 245, 7689, 701]
[1, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894]
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 1950, 3101, 7696, 1783, 6728, 3777, 2132, 2949, 2596, 1868, 2403, 3399, 760, 1197, 614, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521]
1
[1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817]
[1, 614, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369]
[1, 6817, 1950, 3101, 7696, 1783, 6728, 3777, 2132, 2949]
[1, 1214, 6817, 1950, 3101, 7696, 1783, 6728, 3777, 2132]
[1, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521]
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076]
2
[1, 854,

In [20]:
text = ['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n',
        'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n',
        'MDRLDFGGGE\n']
from datasets import Dataset
import sentencepiece as spm
import torch

train_dataset = Dataset.from_dict({'sequences': text})
train_dataset

Dataset({
    features: ['sequences'],
    num_rows: 4
})

In [80]:
class FastaBatchedDataset(object):
    """inspired by esm2, but instead of sorting the original sequences,
    we should really sorting based on tokenized sequences
    """
    def __init__(self,sequence_strs, tokenizer, max_sequence_length):
        self.sequence_strs = sequence_strs['sequences']
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        raw_tokenized_sequences = [self.tokenizer.encode(s) for s in self.sequence_strs]
        self.tokenized_sequences = []
        for tokens in raw_tokenized_sequences:
            if len(tokens) >= self.max_sequence_length-1:
                # considering the added special token bos
                sampled_windows = self.sample_windows(tokens, self.max_sequence_length)
                for sample in sampled_windows:
                    sample.insert(0, self.tokenizer.bos_id())
                    self.tokenized_sequences.append(sample)
            else:
                tokens.insert(0, self.tokenizer.bos_id())
                self.tokenized_sequences.append(tokens)

    @staticmethod
    def sample_windows(sequence, max_sequence_length, extra_toks_per_seq=1):
        import random
        random.seed(42)
        """based on Phil's implement
        Returns: a list of sampled windows.
        """
        sampled_windows = []
        # the beginning window
        sampled_windows.append(sequence[:max_sequence_length-extra_toks_per_seq])
        # calculate the num of random slices needed, remove head and tail
        num_slices_required = (len(sequence) // max_sequence_length) - 2
        max_start_index = len(sequence) - max_sequence_length

        if max_start_index < 0:
            raise ValueError("max_sequence_length greater than length of sequence")

        if num_slices_required > 0:
            for _ in range(num_slices_required):
                # Randomly select start index of the window
                start_index = random.randint(0, max_start_index)
                # Extract window
                # Considering the added special token bos
                window = sequence[start_index:start_index + max_sequence_length-extra_toks_per_seq]
                # Append the window to the list of sampled windows
                sampled_windows.append(window)

        # the end window
        sampled_windows.append(sequence[-(max_sequence_length-extra_toks_per_seq):])
        return sampled_windows

    def get_batch_indices(self, extra_toks_per_seq=1):
        print(self.sequence_strs)
        sizes = [(len(tokens), i) for i, tokens in enumerate(self.tokenized_sequences)]
        sizes.sort()
        print(sizes)
        batches = []
        buf = []
        current_buf_len = 0

        def _flush_current_buf():
            nonlocal current_buf_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            current_buf_len = 0
            #print('my batches is:')
            #print(batches)

        for sz, i in sizes:
            # considering the extra bos
            sz += extra_toks_per_seq
            # check accumulative seq length in the buffer
            if current_buf_len + sz > self.max_sequence_length:
                _flush_current_buf()
            buf.append(i)
            current_buf_len += sz
            #print('my buffer is:')
            #print(buf)

        _flush_current_buf()
        return batches

    def __len__(self):
        return len(self.tokenized_sequences)

    def __getitem__(self, idx):
        return self.tokenized_sequences[idx]

In [81]:
from torch.nn.utils.rnn import pad_sequence
class BatchConverter(object):
    """add padding, create labels for GPT-alike training, used as collate_fn, need processed batch indices
    processed (labels + tensor) batch.
    """

    def __init__(self, tokenizer):
        self.pad_token_id = tokenizer.pad_id()

    def __call__(self, batches):
        data_tokens = [torch.tensor(token_list) for token_list in batches]
        data_tokens_padded = pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)

        # Create attention masks
        attention_masks = (data_tokens_padded != self.pad_token_id).long()

        # skip label==-100 during training so that these tokens won't be used in loss calculation
        labels = data_tokens_padded.clone()
        labels[data_tokens_padded == self.pad_token_id] = -100

        return {
            'input_ids': data_tokens_padded,
            'attention_mask': attention_masks,
            'labels': labels
        }

In [82]:
import sentencepiece as spm
tokenizer_path = '/data/rozen/home/e0833634/lama/protllama/batch_script/'
tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path+"protein_8k.model")

In [83]:
dataset = FastaBatchedDataset(train_dataset, tokenizer, 10)
batches = dataset.get_batch_indices()
dataloader = torch.utils.data.DataLoader(dataset, collate_fn=BatchConverter(tokenizer),
                                          batch_sampler=batches, pin_memory=True)


['GLTNAFIASAPAREVRYDGVITPANANYRFMGGDKGGSLTVGSHLTGSNMVTIGPMGVVVFTNNNDYTGNTFIMGGGTLQLGSNTAWGSLPN\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFPLSPAQLGIWYAQHLDPQVPITIAQYVDLHGALDVEVLERASIDASHELGSGFLRIVERDGEPLQYV\n', 'MDRLDFGGNGEAGSEVAPVPVSGQPVSSEQLFP\n', 'MDRLDFGGGE\n']
[(5, 8), (10, 0), (10, 1), (10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7)]


In [84]:
batches

[[8], [0], [1], [2], [3], [4], [5], [6], [7]]

In [85]:
for batch_id, batch in enumerate(dataloader):
    print(batch_id, batch)

0 {'input_ids': tensor([[   1,  854, 1642,  653,   91]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'labels': tensor([[   1,  854, 1642,  653,   91]])}
1 {'input_ids': tensor([[   1,  820,  814, 3992,   37, 1313, 7791,  245, 7689,  701]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1,  820,  814, 3992,   37, 1313, 7791,  245, 7689,  701]])}
2 {'input_ids': tensor([[   1, 2844, 3768, 4240, 2298, 3576, 1302,  828,  114, 2232]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1, 2844, 3768, 4240, 2298, 3576, 1302,  828,  114, 2232]])}
3 {'input_ids': tensor([[   1, 2298, 3576, 1302,  828,  114, 2232,   43, 6101, 1894]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[   1, 2298, 3576, 1302,  828,  114, 2232,   43, 6101, 1894]])}
4 {'input_ids': tensor([[   1,  854, 1642,  653, 1279,  639,   86, 1195, 1214, 6817]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels

In [79]:
# verify via simple functions
def sample_windows(sequence, max_sequence_length):
    import random
    random.seed(42)
    """based on Phil's implement
    Returns: a list of sampled windows.
    """
    sampled_windows = []
    # the beginning window
    sampled_windows.append(sequence[:max_sequence_length - 1])
     # Number of random slices needed
    num_slices_required = (len(sequence) // max_sequence_length) - 2
    max_start_index = len(sequence) - max_sequence_length

    if max_start_index < 0:
        raise ValueError("max_sequence_length greater than length of sequence")

    if num_slices_required > 0:
        for _ in range(num_slices_required):
            # Randomly select start index of the window
            start_index = random.randint(0, max_start_index)
            # Extract window
            # Considering the added special token bos
            window = sequence[start_index:start_index + max_sequence_length - 1]
            # Append the window to the list of sampled windows
            sampled_windows.append(window)

    # the end window
    sampled_windows.append(sequence[-(max_sequence_length - 1):])
    return sampled_windows

import torch

def tokenize_batch(batch, tokenizer, max_sequence_length):
    proteins = [sample['sequences'] for sample in batch]
    data_tokens = tokenizer.encode(proteins)
    bos = tokenizer.bos_id()
    data_tokens_sampled = []
    for token_list in data_tokens:
        length = len(token_list)
        print(token_list, len(token_list))
        if length >= max_sequence_length - 1:
            # considering the added special token bos
            sampled_windows = sample_windows(token_list, max_sequence_length)
            for sample in sampled_windows:
                sample.insert(0, bos)
                data_tokens_sampled.append(sample)
        else:
            token_list.insert(0, bos)
            data_tokens_sampled.append(token_list)
            #data_tokens_sampled.append(token_list)

    return data_tokens_sampled


from torch.utils.data import DataLoader
from functools import partial

dataloader = DataLoader(train_dataset,
                        batch_size=2,
                        shuffle=False,
                        drop_last=False,
                        collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_sequence_length=10))
for batch_id, batch in enumerate(dataloader):
    print(batch_id)
    for v in batch:
        print(v)

[820, 814, 3992, 37, 1313, 7791, 245, 7689, 701, 4163, 6182, 1302, 4844, 7660, 1104, 3455, 2232, 4290, 2886, 2351, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894] 32
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 1950, 3101, 7696, 1783, 6728, 3777, 2132, 2949, 2596, 1868, 2403, 3399, 760, 1197, 614, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521] 33
0
[1, 820, 814, 3992, 37, 1313, 7791, 245, 7689, 701]
[1, 2844, 3768, 4240, 2298, 3576, 1302, 828, 114, 2232]
[1, 2298, 3576, 1302, 828, 114, 2232, 43, 6101, 1894]
[1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817]
[1, 3399, 760, 1197, 614, 2013, 2199, 2572, 688, 2203]
[1, 2013, 2199, 2572, 688, 2203, 1363, 890, 1369, 5521]
[854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076] 12
[854, 1642, 653, 91] 4
1
[1, 854, 1642, 653, 1279, 639, 86, 1195, 1214, 6817]
[1, 1279, 639, 86, 1195, 1214, 6817, 49, 141, 2076]
[1, 854, 1642, 653, 91]


In [55]:
import pickle
with open('/data/rozen/home/e0833634/lama/protllama/original_lama.pkl', 'rb') as f:
    loaded_data_lama = pickle.load(f)
#name = 'togethercomputer/RedPajama-Data-1T-Sample'
#test_sample = loaded_data_lama['train'][:2]
#test_sample_dict = DatasetDict({"train": test_sample})['train']

{'text': ["\\section{Introduction}\n\\label{sec:intro}\n\n\\emph{Gender diversity}, or more often its lack thereof, among participants to\nsoftware development activities has been thoroughly studied in recent years. In\nparticular, the presence of, effects of, and countermeasures for \\emph{gender\n  bias} in Free/Open Source Software (FOSS) have received a lot of attention\nover the past decade~\\cite{david2008fossdevs, qiu2010kdewomen,\n  nafus2012patches, kuechler2012genderfoss, vasilescu2014gender,\n  oneil2016debiansurvey, robles2016womeninfoss, terrell2017gender,\n  zacchiroli2021gender}.  \\emph{Geographic diversity} is on the other hand the\nkind of diversity that stems from participants in some global activity coming\nfrom different world regions and cultures.\n\nGeographic diversity in FOSS has received relatively little attention in scholarly\nworks. In particular, while seminal survey-based and\npoint-in-time medium-scale studies of the geographic origins of FOSS\ncontribut

test_sample = loaded_data_lama['train'][:2]
test_sample

with open("/data/rozen/home/e0833634/lama/protllama/notebooks/text_sample.pkl", "wb") as f:
    pickle.dump(test_sample, f)

In [None]:
"""below is the original Llama tokenization handling, load your text sample with pickle first"""

In [43]:
import pickle

In [39]:
from transformers.models.llama.tokenization_llama import LlamaTokenizer

2023-09-29 12:40:54.760319: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [76]:
t = LlamaTokenizer(vocab_file=tokenizer_path+'protein_8k.vocab')

RuntimeError: Internal: /sentencepiece/python/bundled/sentencepiece/src/sentencepiece_processor.cc(848) [model_proto->ParseFromArray(serialized.data(), serialized.size())] 

In [58]:
tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
tokenizer.pad_token = tokenizer.unk_token


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


train_set = loaded_data_lama['train']
train_set

In [68]:
def tokenize_batch(batch):
    texts = [sample['text'] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=1024)
    data['labels'] = data['input_ids'].clone()
    return data

dataloader = DataLoader(train_set, batch_size=2, shuffle=False,
                        drop_last=True,
                        collate_fn=partial(tokenize_batch))

with open("/data/rozen/home/e0833634/lama/protllama/notebooks/llama_dataloader.pkl", "wb") as f:
    pickle.dump(dataloader, f)

In [74]:
with open("/data/rozen/home/e0833634/lama/protllama/notebooks/llama_dataloader.pkl", "rb") as f:
    dataloader = pickle.load(f)

In [75]:
import numpy as np
for step, batch in enumerate(dataloader):
    if step == 1:
        break
    for k, v in batch.items():
        print(k, v)
        print(np.shape(v))

input_ids tensor([[    1,   320,  2042,  ...,   424, 23460, 29889],
        [    1,   320,  2042,  ...,   393,   445, 14581]])
torch.Size([2, 1024])
attention_mask tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])
torch.Size([2, 1024])
labels tensor([[    1,   320,  2042,  ...,   424, 23460, 29889],
        [    1,   320,  2042,  ...,   393,   445, 14581]])
torch.Size([2, 1024])
