In [1]:
import os
import random

import sentencepiece as spm
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import ReformerConfig, ReformerModelWithLMHead, ReformerTokenizer

# ---- training hyper-parameters ---------------------------------------------
BATCH_SIZE = 6
GRADIENT_ACCUMULATE_EVERY = 3
LEARNING_RATE = 1e-2
VALIDATE_EVERY = 20
# NOTE(review): NUM_BATCHES is not read by any cell shown here; None presumably means "no cap".
NUM_BATCHES = None

# ---- sequence length -------------------------------------------------------
# Written as a product to mirror the axial_pos_shape (64, 72) configured for the model.
SEQ_LEN = 64 * 72

In [2]:
import torch

# Encoding
def encode(list_of_strings, pad_to_max_length=True, pad_token_id=0):
    """Byte-level encode a batch of strings into padded id / attention-mask tensors.

    Every ``str`` is converted to UTF-8 bytes (``bytes`` inputs are used as-is) and each
    byte ``b`` becomes id ``b + 2``, so real bytes never produce ids 0 or 1. Rows are
    right-padded with ``pad_token_id`` up to the longest *byte* length in the batch.

    Args:
        list_of_strings: sequence of ``str`` or ``bytes``.
        pad_to_max_length: not read by this function; rows are always padded to the
            batch's longest byte length. Kept so existing call sites keep working.
        pad_token_id: id written into padded positions.

    Returns:
        ``(input_ids, attention_masks)``: two ``torch.long`` tensors of shape
        ``(len(list_of_strings), max_byte_length)``. The mask is 1 on real bytes and 0 on
        padding. An empty input yields two ``(0, 0)`` tensors.
    """
    # Convert to bytes *before* measuring: a non-ASCII character is several UTF-8 bytes,
    # so sizing the tensors from str lengths would make the row assignment below overflow.
    byte_strings = [s if isinstance(s, bytes) else str.encode(s) for s in list_of_strings]
    max_length = max((len(b) for b in byte_strings), default=0)

    # create empty (all-padding) tensors
    attention_masks = torch.zeros((len(byte_strings), max_length), dtype=torch.long)
    input_ids = torch.full((len(byte_strings), max_length), pad_token_id, dtype=torch.long)

    for idx, byte_string in enumerate(byte_strings):
        length = len(byte_string)
        # Shift by 2 to keep ids 0/1 free; decode() treats ids < 2 as non-characters.
        input_ids[idx, :length] = torch.tensor([b + 2 for b in byte_string], dtype=torch.long)
        attention_masks[idx, :length] = 1

    return input_ids, attention_masks

# Decoding
def decode(outputs_ids):
    """Invert ``encode``: turn rows of byte ids back into strings.

    Each id ``x`` in ``[2, 257]`` stands for the byte ``x - 2``, because ``encode`` adds 2 to
    every byte. Ids below 2 (padding / special ids) are dropped, and so are ids above 257,
    which cannot represent a byte. The collected bytes of each row are decoded as UTF-8, so
    multi-byte characters such as "é" round-trip. Malformed byte sequences (for example
    from model predictions) become U+FFFD instead of raising.

    Args:
        outputs_ids: tensor whose ``tolist()`` is a list of id lists, one per row.

    Returns:
        list of str, one per row.
    """
    decoded_outputs = []
    for output_ids in outputs_ids.tolist():
        # Rebuild the raw bytes first; mapping each id through chr() would split multi-byte
        # UTF-8 characters into unrelated Latin-1 code points.
        raw_bytes = bytes(x - 2 for x in output_ids if 2 <= x <= 257)
        decoded_outputs.append(raw_bytes.decode("utf-8", errors="replace"))
    return decoded_outputs

In [3]:
# Sanity check of the byte-level encoder: each UTF-8 byte is shifted by +2 ('A' = 65 -> 67).
encode(['ABCDEF'])

(tensor([[67, 68, 69, 70, 71, 72]]), tensor([[1, 1, 1, 1, 1, 1]]))

In [4]:
# Train a character-level SentencePiece model on the residue file. With
# model_prefix=sequence_tokenizer, the tokenizer below loads "sequence_tokenizer.model".
# vocab_size=28 is the size that later shows up in the model's word embedding (28 rows).
# NOTE(review): the backslash continuations are inside the string literal, so the
# indentation spaces become part of the flag string; the trainer appears to tolerate
# the extra whitespace, but keyword arguments would be less fragile — TODO confirm.
spm.SentencePieceTrainer.Train("--input=./data/tokenizer_training/AAresiduals.txt \
                                --vocab_size=28 \
                                --model_prefix=sequence_tokenizer \
                                --model_type=char \
                                --character_coverage=1.0")
# Wrap the trained SentencePiece model in a Hugging Face tokenizer; model_max_length=SEQ_LEN
# is the length later used when padding inputs to the maximum length.
tokenizer = ReformerTokenizer(vocab_file="sequence_tokenizer.model", do_lower_case=False, model_max_length=SEQ_LEN)

In [5]:
# Start from the enwik8 Reformer config and adapt it to this tokenizer and sequence length.
configuration = ReformerConfig.from_pretrained("google/reformer-enwik8")
# Axial position embeddings factor the sequence length into a 2-D grid. During training
# the product of the factors has to equal the input length, so check it against SEQ_LEN
# here instead of failing later inside the forward pass.
configuration.axial_pos_shape = (64, 72)
assert configuration.axial_pos_shape[0] * configuration.axial_pos_shape[1] == SEQ_LEN, (
    "axial_pos_shape must multiply out to SEQ_LEN"
)
configuration.max_position_embeddings = SEQ_LEN
configuration.vocab_size = tokenizer.vocab_size
# Persist the modified config and rebuild it from disk, so the model comes from the saved copy.
configuration.save_pretrained('model/config_enwik8_modified/')
configuration = ReformerConfig.from_pretrained('model/config_enwik8_modified/')
model = ReformerModelWithLMHead(configuration)

In [6]:
# Put the model in training mode before the forward passes below; train() returns the
# module itself, which is why the cell displays the full model repr.
model.train()

ReformerModelWithLMHead(
  (reformer): ReformerModel(
    (embeddings): ReformerEmbeddings(
      (word_embeddings): Embedding(28, 1024)
      (position_embeddings): AxialPositionEmbeddings(
        (weights): ParameterList(
            (0): Parameter containing: [torch.FloatTensor of size 64x1x256]
            (1): Parameter containing: [torch.FloatTensor of size 1x72x768]
        )
      )
    )
    (encoder): ReformerEncoder(
      (layers): ModuleList(
        (0): ReformerLayer(
          (attention): ReformerAttention(
            (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (self_attention): LocalSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=False)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (output): ReformerSelfOutput(
              (dense): Linear(in_features=1024,

In [12]:
# Maximum sequence length of the tokenizer. `max_len` is a deprecated alias; `model_max_length`
# is the attribute the tokenizer was constructed with (SEQ_LEN).
tokenizer.model_max_length

4608

In [13]:
# Toy input: tokenize a short string, pad it to the tokenizer's model_max_length and
# return it directly as a batch-of-one LongTensor.
input_id = tokenizer.encode(
    "ABCDEFGH",
    add_special_tokens=True,
    pad_to_max_length=True,
    return_tensors="pt",  # builds the (1, seq_len) tensor instead of torch.tensor(...).unsqueeze(0)
)


In [14]:
# Expect (1, SEQ_LEN): a batch of one, padded to the tokenizer's model_max_length.
input_id.shape

torch.Size([1, 4608])

In [16]:
# Forward pass with labels, so the outputs also include a loss (unpacked in the next cell).
# NOTE(review): every position is used as a label, including the padding that fills most
# of the SEQ_LEN slots, and no attention_mask is passed. If only real tokens should count
# toward the loss, confirm the tokenizer's pad id and mask those label positions
# (e.g. with -100).
outputs = model(input_id, labels = input_id)

In [17]:
# With labels supplied, the first two outputs are the loss and the per-position vocabulary scores.
loss, prediction_scores = outputs[:2]

In [20]:
# (batch, seq_len, vocab_size): here (1, SEQ_LEN, 28).
prediction_scores.shape

torch.Size([1, 4608, 28])

In [27]:
# The LM head scores the *next* token: logits at position t are trained against the label at
# t + 1 (the model shifts labels by one when it computes its loss). To see which targets are
# predicted correctly, align each prediction with the following input token.
predicted_next_ids = torch.argmax(prediction_scores, dim=-1)[:, :-1]
target_next_ids = input_id[:, 1:]
target_next_ids[predicted_next_ids == target_next_ids]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [22]:
# Inspect the padded token ids (the tensor repr elides the middle of the row).
input_id

tensor([[0, 3, 4,  ..., 0, 0, 0]])

In [30]:
# Scratch check: nine separate torch.randint draws with high=7 (values 0..6).
# NOTE(review): no seed is set anywhere in the cells shown, so this output changes on every run.
[torch.randint(7, (1,)).item() for i in range(9)]

[3, 0, 6, 1, 6, 2, 6, 2, 5]

In [3]:
class LineByLineTextDataset(Dataset):
    """One tokenized example per non-blank line of a UTF-8 text file.

    Modified from the Hugging Face ``LineByLineTextDataset``:
    https://github.com/huggingface/transformers/blob/cb3c2212c79d7ff0a4a4e84c3db48371ecc1c15d/src/transformers/data/datasets/language_modeling.py#L77
    """

    def __init__(self, tokenizer, file_path: str, block_size=None, max_lines=None):
        """Read ``file_path`` and tokenize each non-empty, non-whitespace line.

        Args:
            tokenizer: object providing ``batch_encode_plus``. It must also provide
                ``model_max_length`` when ``block_size`` is omitted.
            file_path: path to a UTF-8 text file, one example per line.
            block_size: passed as ``max_length`` to ``batch_encode_plus``, i.e. the
                per-example sequence length. Defaults to ``tokenizer.model_max_length``.
            max_lines: if given, keep only the first ``max_lines`` usable lines
                (useful for quick experiments).

        Raises:
            FileNotFoundError: if ``file_path`` is not an existing regular file.
        """
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")

        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if line and not line.isspace()]

        if max_lines is not None:
            lines = lines[:max_lines]

        # max_length bounds the number of tokens per example, so it has to be a sequence
        # length. tokenizer.vocab_size (the number of distinct tokens) is the wrong quantity.
        # NOTE(review): whether over-long lines are truncated to this length depends on the
        # installed transformers version's defaults — confirm.
        if block_size is None:
            block_size = tokenizer.model_max_length
        batch_encoding = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)
        self.examples = batch_encoding["input_ids"]

    def __len__(self):
        """Number of tokenized lines."""
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        """Token ids of line ``i`` as a 1-D ``torch.long`` tensor."""
        return torch.tensor(self.examples[i], dtype=torch.long)

In [18]:
# input_ids = torch.tensor(tokenizer.encode("ALKLAKALK", 
#                                           add_special_tokens=True, 
#                                           max_length=SEQ_LEN, 
#                                           pad_to_max_length=True)).unsqueeze(0)  # Batch size 1