In [2]:
import numpy as np
import torch
from torch import nn
from transformers import LongformerTokenizer, LongformerModel

In [2]:
# Pretrained Longformer tokenizer and bare encoder. LongformerModel has no LM head,
# so the "weights not used" warning below (all lm_head.*) is expected.
CHECKPOINT = "allenai/longformer-base-4096"
tokenizer = LongformerTokenizer.from_pretrained(CHECKPOINT)
model = LongformerModel.from_pretrained(CHECKPOINT)

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Sample input text: a paragraph describing Longformer (it reads like the paper's
# abstract). It tokenizes to 195 sub-word tokens in the cells below.
document = """Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer. Longformer's attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA."""

In [4]:
# Split the document into sub-word strings only: no ids and no <s>/</s> special
# tokens yet. In the output, a leading 'Ġ' marks a token that followed a space.
tokens = tokenizer.tokenize(document)

In [5]:
# Peek at both ends of the token list and report its length.
head_tokens, tail_tokens = tokens[:20], tokens[-20:]
n_tokens = len(tokens)
print(f"First 20 tokens = {head_tokens}")
print(f"Last  20 tokens = {tail_tokens}")
print(f"Total number of tokens = {n_tokens}")

First 20 tokens = ['Trans', 'former', '-', 'based', 'Ġmodels', 'Ġare', 'Ġunable', 'Ġto', 'Ġprocess', 'Ġlong', 'Ġsequences', 'Ġdue', 'Ġto', 'Ġtheir', 'Ġself', '-', 'att', 'ention', 'Ġoperation', ',']
Last  20 tokens = ['Ġand', 'Ġsets', 'Ġnew', 'Ġstate', '-', 'of', '-', 'the', '-', 'art', 'Ġresults', 'Ġon', 'ĠWiki', 'Hop', 'Ġand', 'ĠTri', 'via', 'Q', 'A', '.']
Total number of tokens = 195


In [8]:
# Inspect the tokenizer's special tokens and length limits.
print(f"special ids  = {tokenizer.all_special_ids}")
print(f"special tokens map = {tokenizer.special_tokens_map}")
print()

# One row per special token: (attribute prefix, optional human-readable prefix).
# Driving the prints from a table keeps every line in the same
# "<name> token id = ..., <name> token = ..." format; the hand-written version
# printed "pad_token" for pad only.
special_token_rows = [
    ("pad", ""),
    ("unk", ""),
    ("bos", "beginning-of-sentence (bos), "),
    ("eos", "end-of-sentence (eos), "),
    ("mask", ""),
    ("cls", ""),
    ("sep", ""),
]
for name, description in special_token_rows:
    token_id = getattr(tokenizer, f"{name}_token_id")
    token = getattr(tokenizer, f"{name}_token")
    print(f"{description}{name} token id = {token_id}, {name} token = {token}")
print()

# The single-sentence and sentence-pair limits are smaller than model_max_length
# because they leave room for the special tokens (4096 -> 4094 -> 4092 below).
print(f"model max length = {tokenizer.model_max_length}")
print(f"max length sentence = {tokenizer.max_len_single_sentence}")
print(f"max length sentence pair = {tokenizer.max_len_sentences_pair}")

special ids  = [0, 2, 3, 1, 50264]
special tokens map = {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}

pad token id = 1, pad_token = <pad>
unk token id = 3, unk token = <unk>
beginning-of-sentence (bos), bos token id = 0, bos token = <s>
end-of-sentence (eos), eos token id = 2, eos token = </s>
mask token id = 50264, mask token = <mask>
cls token id = 0, cls token = <s>
sep token id = 2, sep token = </s>

model max length = 4096
max length sentence = 4094
max length sentence pair = 4092


In [9]:
# Map each sub-word string to its vocabulary id. This is a 1:1 lookup, so unlike
# tokenizer(document) no <s>/</s> ids are added (195 tokens -> 195 ids).
token_ids = tokenizer.convert_tokens_to_ids(tokens)
n_token_ids = len(token_ids)
first_ids, last_ids = token_ids[:10], token_ids[-10:]
print(f"Number of token ids = {n_token_ids}")
print(f"First 10 token ids = {first_ids}")
print(f"Last 10 token ids = {last_ids}")

Number of token ids = 195
First 10 token ids = [19163, 22098, 12, 805, 3092, 32, 3276, 7, 609, 251]
Last 10 token ids = [775, 15, 45569, 30158, 8, 6892, 11409, 1864, 250, 4]


In [11]:
# The all-in-one tokenizer call. Its input_ids start with <s> (id 0), unlike the
# manual token_ids above, which start directly with the first word piece.
# NOTE(review): this displays the whole encoding; slice it (e.g. [:20]) to keep the
# saved output small.
tokenizer(document)

{'input_ids': [0, 19163, 22098, 12, 805, 3092, 32, 3276, 7, 609, 251, 26929, 528, 7, 49, 1403, 12, 2611, 19774, 2513, 6, 61, 21423, 15694, 338, 23050, 19, 5, 13931, 5933, 4, 598, 1100, 42, 22830, 6, 52, 6581, 5, 2597, 22098, 19, 41, 1503, 9562, 14, 21423, 24248, 23099, 19, 13931, 5933, 6, 442, 24, 1365, 7, 609, 2339, 9, 1583, 9, 22121, 50, 1181, 4, 2597, 22098, 18, 1503, 9562, 16, 10, 1874, 12, 179, 5010, 13, 5, 2526, 1403, 12, 2611, 19774, 8, 15678, 10, 400, 2931, 196, 1503, 19, 10, 3685, 7958, 720, 1503, 4, 3515, 2052, 173, 15, 251, 12, 46665, 7891, 268, 6, 52, 10516, 2597, 22098, 15, 2048, 12, 4483, 2777, 19039, 8, 3042, 194, 12, 1116, 12, 627, 12, 2013, 775, 15, 2788, 398, 8, 1177, 39224, 398, 4, 96, 5709, 7, 144, 2052, 173, 6, 52, 67, 11857, 9946, 2597, 22098, 8, 8746, 594, 4438, 24, 15, 10, 3143, 9, 18561, 8558, 4, 1541, 11857, 26492, 2597, 22098, 6566, 9980, 33334, 3830, 11126, 38495, 15, 251, 3780, 8558, 8, 3880, 92, 194, 12, 1116, 12, 627, 12, 2013, 775, 15, 45569, 30158, 8, 6

In [18]:
# Run Longformer on a batch of one sequence with two globally-attending positions.
# NOTE(review): token_ids came from convert_tokens_to_ids, so it lacks the <s>/</s>
# framing that tokenizer(document) adds. Use
# tokenizer.build_inputs_with_special_tokens(token_ids) if the model should see its
# pretraining format (sequence length would then be 197 rather than 195).
GLOBAL_ATTN_POSITIONS = (100, 150)  # arbitrary demo positions for global attention

token_ids_pt = torch.LongTensor(token_ids)
attn_mask = torch.ones(len(token_ids))          # 1 = real token (no padding here)
global_attn_mask = torch.zeros(len(token_ids))  # 1 = token uses global attention
for position in GLOBAL_ATTN_POSITIONS:
    global_attn_mask[position] = 1

# Inference only: no_grad avoids building an autograd graph over the whole sequence.
# Keyword arguments make each mask's role explicit instead of relying on the
# positional order of forward().
with torch.no_grad():
    output = model(
        input_ids=token_ids_pt.unsqueeze(0),
        attention_mask=attn_mask.unsqueeze(0),
        global_attention_mask=global_attn_mask.unsqueeze(0),
    )

In [19]:
# The forward pass returns a structured output object (class name shown below),
# not a plain tuple.
type(output)

transformers.models.longformer.modeling_longformer.LongformerBaseModelOutputWithPooling

In [21]:
# Per-token encoder features: (batch=1, seq_len=195, hidden=768) float32, per the output.
output.last_hidden_state.shape, output.last_hidden_state.dtype

(torch.Size([1, 195, 768]), torch.float32)

In [24]:
# Toy lookup: a 3-entry table of 10-dim vectors, indexed by a (5, 12) grid of ids in {0, 1, 2}.
embed = nn.Embedding(num_embeddings=3, embedding_dim=10)
embedding_ids = np.random.randint(0, 3, size=(5, 12))
embedding_ids_pt = torch.from_numpy(embedding_ids).int()  # int32 index tensor
embedding_out = embed(embedding_ids_pt)
print(embedding_out.shape, embedding_out.dtype)

torch.Size([5, 12, 10]) torch.float32


In [26]:
# Index of the largest component of each embedding vector (last axis), cast to int32.
embedding_argmax = embedding_out.argmax(dim=-1).to(torch.int32)
print(embedding_argmax.dtype, embedding_argmax.shape)

torch.int32 torch.Size([5, 12])


In [27]:
# Only 3 distinct ids exist, so at most 3 distinct argmax values can appear
# (6 and 9 in the run below).
embedding_argmax

tensor([[9, 9, 9, 6, 9, 6, 9, 9, 6, 9, 9, 9],
        [9, 9, 6, 9, 6, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 6, 9, 6, 6, 9, 9, 6, 9],
        [6, 9, 6, 9, 9, 9, 6, 6, 9, 9, 6, 9],
        [6, 6, 9, 6, 6, 6, 6, 9, 9, 6, 9, 6]], dtype=torch.int32)

In [29]:
# Synthetic token-classification batch for prototyping masked-loss bookkeeping:
# B sequences, L positions each, C classes.
B, L, C = 10, 12, 3
label_ids = torch.IntTensor(np.random.randint(0, C, size=(B, L)))

# Right-padded mask: each row gets a random number (0..L inclusive) of leading 1s.
attn_np = np.zeros((B, L), dtype=float)
for mask_row in attn_np:  # iterating yields row views, so the writes land in attn_np
    mask_row[:np.random.randint(0, L + 1)] = 1.
attn = torch.FloatTensor(attn_np)

logits = torch.FloatTensor(np.random.randn(B, L, C))

print(f"label_ids : shape={label_ids.shape} dtype={label_ids.dtype}")
print(f"attn      : shape={attn.shape} dtype={attn.dtype}")
print(f"logits    : shape={logits.shape} dtype={logits.dtype}")

label_ids : shape=torch.Size([10, 12]) dtype=torch.int32
attn      : shape=torch.Size([10, 12]) dtype=torch.float32
logits    : shape=torch.Size([10, 12, 3]) dtype=torch.float32


In [31]:
# Each row is a run of 1s followed by 0s (random valid length per sequence).
attn

tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [33]:
# Random class ids in [0, C) for every position, including masked-out ones.
label_ids

tensor([[0, 1, 0, 0, 1, 1, 0, 0, 2, 1, 2, 2],
        [2, 0, 2, 1, 2, 1, 1, 2, 1, 0, 1, 0],
        [0, 2, 0, 0, 0, 0, 1, 2, 1, 1, 0, 2],
        [2, 1, 2, 2, 0, 1, 0, 1, 2, 1, 1, 0],
        [2, 0, 0, 2, 2, 0, 0, 2, 1, 1, 2, 0],
        [2, 1, 2, 1, 2, 0, 2, 2, 2, 0, 2, 0],
        [1, 1, 1, 2, 0, 1, 0, 0, 1, 0, 2, 1],
        [2, 2, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0],
        [0, 1, 0, 2, 0, 0, 0, 0, 0, 2, 1, 0],
        [2, 0, 1, 2, 2, 1, 0, 2, 2, 2, 2, 2]], dtype=torch.int32)

In [34]:
# Boolean indexing with the (B, L) mask flattens to a 1-D tensor of the active labels.
masked_labels = label_ids[attn == 1.]
print(masked_labels)
print(masked_labels.shape)

tensor([0, 1, 0, 0, 1, 1, 0, 2, 0, 2, 0, 0, 0, 0, 2, 1, 2, 2, 0, 0, 2, 2, 0, 0,
        2, 1, 1, 2, 1, 2, 1, 2, 0, 1, 1, 1, 2, 0, 2, 2, 0, 1, 1, 0, 0, 1, 0, 1,
        0, 2, 0, 0, 0, 0, 0, 2, 2, 0, 1, 2, 2, 1, 0, 2, 2, 2, 2, 2],
       dtype=torch.int32)
torch.Size([68])


In [47]:
# Merge batch and position axes so each row of flat_logits pairs with one mask entry.
flat_logits = logits.flatten(0, 1)  # (B*L, C)
flat_attn = attn.flatten()          # (B*L,)
print(flat_logits.shape, flat_attn.shape)
print(flat_logits[flat_attn == 1.].shape)

torch.Size([120, 3]) torch.Size([120])
torch.Size([68, 3])


In [48]:
# Keep only unmasked positions. Indexing with the (B, L) boolean mask selects the same
# positions, in the same row-major order, as flattening first, so labels and logits
# stay aligned row for row.
# NOTE(review): active_labels is int32 (see output); nn.CrossEntropyLoss expects
# int64 class-index targets, so a .long() would be needed before computing a loss.
active_mask = attn == 1.
active_labels = label_ids[active_mask]
active_logits = logits[active_mask]
print(f"active labels : shape = {active_labels.shape}, "
      f"dtype = {active_labels.dtype}")
print(f"active logits : shape = {active_logits.shape}, "
      f"dtype = {active_logits.dtype}")

active labels : shape = torch.Size([68]), dtype = torch.int32
active logits : shape = torch.Size([68, 3]), dtype = torch.float32


In [51]:
# Per-class counts of the active (unmasked) labels.
# minlength=C guarantees exactly one count per class even when the highest class ids
# never occur among the active positions. Without it, the vector can come back shorter
# than C and no longer line up with the C logit columns or the class weights.
label_count = np.bincount(active_labels, minlength=C)
print(f"label distribution = {label_count}")

label distribution = [27 17 24]


In [59]:
# Smoothed inverse-frequency class weights, 1 / (1 + count): rarer classes weigh more,
# and a class with count 0 gets weight 1 instead of a division by zero.
# torch.tensor replaces the legacy torch.FloatTensor constructor, which only accepts
# device="cpu" and would raise if the device were switched to a GPU.
weights = torch.tensor(1 / (1 + label_count), dtype=torch.float32, device="cpu")
print(f"class weight = {weights}, device = {weights.device}")

class weight = tensor([0.0357, 0.0556, 0.0400]), device = cpu


In [60]:
# Random stand-ins for 10 sequences x 50 positions with 5 label classes,
# plus one integer identifier (0..99) per sequence.
n_sequences, seq_len, n_label_classes = 10, 50, 5
groundtruth = torch.IntTensor(np.random.randint(0, n_label_classes, size=(n_sequences, seq_len)))
predictions = torch.IntTensor(np.random.randint(0, n_label_classes, size=(n_sequences, seq_len)))
identifiers = torch.IntTensor(np.random.randint(0, 100, size=n_sequences))

# zip walks the leading (sequence) axis of all three tensors in lockstep.
for gt, pred, ident in zip(groundtruth, predictions, identifiers):
    print(f"groundtruth = {gt.dtype} {gt.shape}, "
          f"prediction = {pred.dtype} {pred.shape}, id = {ident}")

groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 40
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 28
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 59
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 64
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 69
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 7
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 36
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 42
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 93
groundtruth = torch.int32 torch.Size([50]), prediction = torch.int32 torch.Size([50]), id = 9


In [4]:
# A 1x1 random integer tensor in [0, 5), placed on the first GPU when one is available.
# Falling back to the CPU keeps the notebook runnable on machines without CUDA, where
# the hard-coded .to("cuda:0") raised.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.randint(0, 5, size=(1, 1), device=device)

In [5]:
# Display the tensor; its repr includes the device it lives on.
x

tensor([[4]], device='cuda:0')

In [7]:
# .item() converts a one-element tensor to a plain Python scalar (int here, per output).
type(x.item())

int