In [None]:
from tokenizers import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer
from tokenizers.trainers import UnigramTrainer
from tokenizers.processors import BertProcessing
from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
import datasets
import pandas as pd
from datasets import load_from_disk
from pathlib import Path

import wandb

In [None]:
# Open the Weights & Biases run that the rest of the notebook logs against.
run = wandb.init(
    project='protobert',
    job_type="tokenizer_train",
)

In [None]:
# Register the dataset artifact as an input of this run, then download it
# and keep the local location as a Path for later path arithmetic.
data_at = run.use_artifact('uniref_1m:latest')
download_location = data_at.download()
dataset_dir = Path(download_location)

In [None]:
# The saved dataset sits in the 'uniref_1m' sub-directory of the download.
uniref_path = dataset_dir / 'uniref_1m'
sample_dataset = load_from_disk(uniref_path)

In [None]:
%%time
tokenizer = SentencePieceBPETokenizer()
tokenizer.train_from_iterator(sample_dataset["text"], vocab_size=1000, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

In [None]:
# List the working directory with pathlib instead of the bare `ls` shell
# alias: the bare form depends on IPython's automagic being enabled and is
# not valid Python when the notebook is exported to a script.
cwd_entries = sorted(entry.name for entry in Path.cwd().iterdir())
cwd_entries

In [None]:
# Serialize the trained tokenizer to one file, then reload it through the
# transformers fast-tokenizer wrapper so it can be saved with save_pretrained.
tokenizer.save('proteins-tmp')
tokenizer = PreTrainedTokenizerFast(tokenizer_file='proteins-tmp')

# BertProcessing takes its arguments as (sep, cls): "</s>" (id 2) is the
# separator closing a sequence and "<s>" (id 0) is the classifier token
# opening it.
tokenizer._tokenizer.post_processor = BertProcessing(
    ("</s>", 2),
    ("<s>", 0),
)

# Special-token attributes must agree with the post-processor above:
# cls is "<s>" and sep is "</s>". The previous assignment had the two
# swapped, so cls_token_id / sep_token_id disagreed with the encoded output.
tokenizer.mask_token = "<mask>"
tokenizer.cls_token = "<s>"
tokenizer.sep_token = "</s>"
tokenizer.pad_token = "<pad>"
tokenizer.unk_token = "<unk>"

# Write the tokenizer files into the 'proteins-base' directory.
tokenizer.save_pretrained('proteins-base')

In [None]:
# Smoke test: encode a short sequence and inspect how it was split.
o = tokenizer('ASDFAFDGADFGADFGHAG')

# Show the full round-trip decode. Previously this value was computed and
# silently dropped because it was not the last expression of the cell.
print(tokenizer.decode(o['input_ids']))

# Then show each id next to the text it decodes to.
for token_id in o['input_ids']:
    print(f'{token_id}: {tokenizer.decode(token_id)}')

In [None]:
# Create the artifact that will hold the saved tokenizer files.
tok_at = wandb.Artifact(
    name='uniref_1m_tokenizer',
    type="tokenizer",
)

In [None]:
# Attach the saved 'proteins-base' directory to the artifact under the
# name 'uniref_1m_tokenizer'.
tok_at.add_dir(
    local_path='proteins-base',
    name='uniref_1m_tokenizer',
)

In [None]:
# Log the tokenizer artifact against the active run.
run.log_artifact(tok_at)

In [None]:
# Close out the run now that the artifact has been logged.
run.finish()