# Tokenization Demonstration

To work with UniRel and similar models, we need to understand how their data is tokenized.

## Imports

In [1]:
from pprint import pprint
from transformers import BertTokenizerFast
from dataprocess.data_processor import UniRelDataProcessor
from dataprocess.dataset import UniRelSpanDataset

## Tokenizer

We load `bert-base-cased` with arguments similar to those used in the UniRel code.

In [2]:
tokenizer = BertTokenizerFast.from_pretrained(
    "bert-base-cased",
    do_basic_tokenize=False,
)

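When we call the tokenizer on a string, it returns an encoding whose `input_ids` field is a list of integer token IDs, alongside an `attention_mask` and `token_type_ids` of the same length.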

In [3]:
text = "The quick brown fox jumps over the lazy dog."
encoding = tokenizer.encode_plus(
    text,
    max_length=16,
    padding="max_length",
    truncation=True,
)
pprint(encoding)

{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
 'input_ids': [101,
               1109,
               3613,
               3058,
               17594,
               15457,
               1166,
               1103,
               16688,
               3676,
               119,
               102,
               0,
               0,
               0,
               0],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


In [4]:
decoding = tokenizer.decode(encoding["input_ids"])
print(decoding)

[CLS] The quick brown fox jumps over the lazy dog. [SEP] [PAD] [PAD] [PAD] [PAD]


The string is surrounded by the special `[CLS]` and `[SEP]` tokens, and the remaining space is filled with the `[PAD]` padding token. Comparing against the `input_ids` of the encoding, we can conclude that the IDs `0`, `101`, and `102` correspond to `[PAD]`, `[CLS]`, and `[SEP]`, respectively:

In [5]:
assert tokenizer.decode([0]) == "[PAD]"
assert tokenizer.decode([101]) == "[CLS]"
assert tokenizer.decode([102]) == "[SEP]"
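
The relation between padding and `attention_mask` can be cross-checked without a tokenizer, using only the values printed for cell 3. The sketch below copies those two lists verbatim; `PAD_ID` and `derived_mask` are local names chosen for illustration.

```python
# Values copied from the output of cell 3 above.
input_ids = [101, 1109, 3613, 3058, 17594, 15457, 1166, 1103,
             16688, 3676, 119, 102, 0, 0, 0, 0]
attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]

PAD_ID = 0  # ID of [PAD], as asserted in cell 5

# In this example the mask is 1 exactly where the token is not padding.
derived_mask = [int(token_id != PAD_ID) for token_id in input_ids]
assert derived_mask == attention_mask

# 12 real tokens ([CLS], 10 word pieces, [SEP]) and 4 padding tokens.
n_real = sum(attention_mask)
n_pad = len(input_ids) - n_real
print(n_real, n_pad)  # 12 4
```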

If the text is longer than the allotted `max_length`, it is cut short but still surrounded by `[CLS]` and `[SEP]` tokens:

In [6]:
truncated = tokenizer.encode(text, max_length=8, truncation=True)
print(tokenizer.decode(truncated))
assert len(truncated) == 8

[CLS] The quick brown fox jumps over [SEP]
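
The truncation behaviour seen here can be mimicked in plain Python. This is a simplified sketch for a single sequence truncated on the right, not the tokenizer's actual implementation; `truncate_with_specials` is a hypothetical helper, and the IDs are the ones printed for cell 3.

```python
CLS_ID, SEP_ID = 101, 102  # special-token IDs asserted in cell 5


def truncate_with_specials(token_ids, max_length):
    """Keep the leading tokens that fit, then wrap them in [CLS] ... [SEP]."""
    budget = max_length - 2  # two slots are reserved for [CLS] and [SEP]
    if budget < 0:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    return [CLS_ID] + token_ids[:budget] + [SEP_ID]


# Word-piece IDs of the example sentence without special tokens (from cell 3).
text_ids = [1109, 3613, 3058, 17594, 15457, 1166, 1103, 16688, 3676, 119]

truncated_ids = truncate_with_specials(text_ids, max_length=8)
print(truncated_ids)  # [101, 1109, 3613, 3058, 17594, 15457, 1166, 102]
```

The six surviving word pieces correspond to "The quick brown fox jumps over", matching the decoded output above.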


## UniRel Data

The UniRel `Dataset` yields the tokens of the text concatenated with the tokens of the relations.

In [7]:
text_len = 100
data_processor = UniRelDataProcessor(
    root="data",
    tokenizer=tokenizer,
    dataset_name="nyt",
)
dataset = UniRelSpanDataset(
    data_processor.get_train_sample(n_samples=1, token_len=text_len),
    data_processor,
    tokenizer,
    mode='test',
    ignore_label=-100,
    model_type='bert',
    ngram_dict=None,
    max_length=text_len+2,
    predict=True,
    eval_type="test"
)
example = dataset[0]
pprint(example)

100%|██████████| 1/1 [00:00<00:00, 672.70it/s]

37
more than 100: 0
more than 150: 0
{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1]),
 'head_label': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]),
 'input_ids': tensor([  101,  3559, 15278, 18082,  2249,  9960,  2349, 11185,  2038, 21715,
         5541,   132,  1145,  1120,  6523,  1181,  1531,   117,  3883,  7807,
         1513,   118,  1113,   118,  6236,   117,   151,   119,   162,   119,
          117,  1351,   122,   118




In [8]:
print(tokenizer.decode(example["input_ids"]))

[CLS] Massachusetts ASTON MAGNA Great Barrington ; also at Bard College, Annandale - on - Hudson, N. Y., July 1 - Aug. [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] advisors founders industry holding founded shareholder company country administrative capital contains neighbor death geographic people children ethnicity nationality lived birthplace profession religion location teams


### Sequence Fields

Of the many fields in the dictionary, `attention_mask`, `input_ids`, and `token_type_ids` are 1-dimensional and have the same length. They correspond to the text concatenated with the relations.

The relation tokens (one word representing each relation) are always aligned to the end of the sequence. If the text is shorter than the allotted length, padding tokens are added between the text and the relations, as in the sample above. If the text is longer, it is truncated.

Notably, the `[CLS]` token is present, but the `[SEP]` token is missing from the encoding (it is replaced with a `[PAD]` token). Thus, even if the text is truncated, there is at least one `[PAD]` token between the text and the relations. A value of `0` in `attention_mask` or `input_ids` indicates a `[PAD]` token. Overall, we can write the structure as

```
[CLS] (text) [PAD]+ (relations)
```
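
The three regions can be read off the `attention_mask` printed for the sample. The sketch below rebuilds that mask from its run lengths (counted by hand from the output of cell 7, so treat them as a transcription rather than dataset output) and groups it into runs with `itertools.groupby`.

```python
from itertools import groupby

# Run lengths transcribed from the attention_mask in the output of cell 7:
# 36 ones ([CLS] + text), 66 zeros ([PAD]), 24 ones (relations).
attention_mask = [1] * 36 + [0] * 66 + [1] * 24

# Collapse consecutive equal values into (value, run_length) pairs.
runs = [(value, len(list(group))) for value, group in groupby(attention_mask)]
print(runs)  # [(1, 36), (0, 66), (1, 24)]
```

The first run covers `[CLS]` and the text, the second the padding, and the third the relation tokens.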

`token_type_ids` is `1` for relation tokens and `0` for everything else. Thus, to get the length of the text segment, including `[CLS]` and padding (or equivalently, the index of the first relation token), we can run the following:

In [9]:
example["token_type_ids"].argmax().item()

102

This is the same value as `token_len_batch`, so we can conclude that this field stores the length of the text segment in the encoding. It is also equal to `text_len + 2`. That is, the encoding can contain up to `text_len` non-special text tokens, plus a leading `[CLS]` and a trailing `[PAD]`.
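
To make the arithmetic concrete, here is a minimal pure-Python sketch of the layout as observed in the sample. It is not the code UniRel uses: `build_layout` is a hypothetical helper, and the text and relation IDs are placeholders rather than real vocabulary entries. Only the counts (35 text tokens, 24 relations, `text_len=100`) are taken from the sample.

```python
CLS_ID, PAD_ID = 101, 0  # special-token IDs asserted earlier in this notebook


def build_layout(text_ids, relation_ids, text_len):
    """Hypothetical helper: arrange IDs as [CLS] (text) [PAD]+ (relations)."""
    kept = text_ids[:text_len]        # truncate over-long text
    n_pad = text_len + 1 - len(kept)  # at least one [PAD] stands in for [SEP]
    segment = [CLS_ID] + kept + [PAD_ID] * n_pad  # always text_len + 2 long
    input_ids = segment + relation_ids
    attention_mask = [int(i != PAD_ID) for i in segment] + [1] * len(relation_ids)
    token_type_ids = [0] * len(segment) + [1] * len(relation_ids)
    return input_ids, attention_mask, token_type_ids


# Placeholder IDs (assumptions), sized like the sample: 35 text tokens, 24 relations.
text_ids = list(range(2000, 2035))
relation_ids = list(range(1000, 1024))
input_ids, attention_mask, token_type_ids = build_layout(
    text_ids, relation_ids, text_len=100
)

# Index of the first 1 in a 0/1 list, playing the role of argmax() in cell 9.
boundary = token_type_ids.index(1)
print(len(input_ids), boundary, input_ids.count(PAD_ID))  # 126 102 66
```

With text longer than `text_len`, the same helper keeps exactly one `[PAD]` before the relations, so the boundary stays at `text_len + 2`.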

### Matrix Fields

The other fields are `head_label`, `span_label`, and `tail_label`, corresponding to the three ground-truth interaction matrices. Each is a square matrix whose side equals the sequence length:

In [10]:
sequence_length, = example["input_ids"].shape
assert example["head_label"].shape \
    == example["span_label"].shape \
    == example["tail_label"].shape \
    == (sequence_length, sequence_length)
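
Because the sequence is a text segment followed by relation tokens, any square matrix over it splits into four blocks at the boundary found earlier. The sketch below only illustrates that indexing on an all-zero nested list. `split_blocks` and the block names are hypothetical, and which blocks actually carry labels in UniRel is not shown in this section.

```python
def split_blocks(matrix, boundary):
    """Split a square matrix (list of rows) into four blocks at `boundary`."""
    top, bottom = matrix[:boundary], matrix[boundary:]
    return {
        "text_text": [row[:boundary] for row in top],
        "text_relation": [row[boundary:] for row in top],
        "relation_text": [row[:boundary] for row in bottom],
        "relation_relation": [row[boundary:] for row in bottom],
    }


# Sizes from the sample: 102 text-segment positions and 24 relation tokens.
boundary, n_relations = 102, 24
size = boundary + n_relations
matrix = [[0] * size for _ in range(size)]  # stand-in for a label matrix

blocks = split_blocks(matrix, boundary)
shapes = {name: (len(block), len(block[0])) for name, block in blocks.items()}
print(shapes)
```

With tensors, the same split would presumably be written with slicing such as `example["head_label"][:boundary, boundary:]`.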