In [1]:
from transformers import BertTokenizer, BertModel
from kmeans_pytorch import kmeans, kmeans_predict
from tqdm import tqdm
from pprint import pprint
import torch
import gc
from datasets import load_dataset
from nltk.tokenize.treebank import TreebankWordDetokenizer

In [2]:
model_name = 'bert-base-cased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Prepare Data

The CoNLL2003 dataset is first used to evaluate performance. The ground truth pos tags are used to label propernouns in the text to eliminate any error introduced by a pos tagger in the system.

In [3]:
original_tags = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# New non-BIO scheme
new_tags = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}
new_tags_string = {0: 'O', 1: 'PER', 2: 'ORG', 3: 'LOC', 4: 'MISC'}

dataset = load_dataset("conll2003")

for split in ['train', 'validation', 'test']:
    dataset[split] = dataset[split].map(lambda batch: {
        'text': TreebankWordDetokenizer().detokenize(batch['tokens']),
        'encoded_PROPNS': [ids for tok, pos_tag in zip(batch['tokens'], batch['pos_tags']) for ids in tokenizer.encode(tok)[1:-1] if pos_tag in [22, 23]],
        'ner_tags': [new_tags[tag] for tag in batch['ner_tags']]
        }).remove_columns(['id', 'chunk_tags'])

print(dataset['train'][0])

Found cached dataset conll2003 (/home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-21c45632c0c27d5d.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-1e1fee781bb923a7.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-16bc004b4297e5a5.arrow


{'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'ner_tags': [2, 0, 4, 0, 0, 0, 4, 0, 0], 'text': 'EU rejects German call to boycott British lamb.', 'encoded_PROPNS': [7270]}


In [4]:
texts = dataset['train']['text']
encoded_PROPNS = dataset['train']['encoded_PROPNS']

## Training

Define a method to extract the context vectors of the proper nouns

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_context_vectors(enc_sent, ids_t):    
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(enc_sent.input_ids.to(device), enc_sent.attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = (enc_sent.input_ids.T == ids_t).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(ids_t.size(0), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[target_indices[:, 1], target_indices[:, 0], -4:], dim=1)

    return embedding_aggregate

Extract all of the context vectors using the defined method on the trainins split

In [7]:
context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for sentence, encoded_ids in tqdm(zip(texts, encoded_PROPNS), desc='Collecting Contexts'):
    enc_sent = tokenizer.encode_plus(sentence, return_tensors='pt')
    ids_t = torch.tensor(encoded_ids)

    embedding_aggregate = get_context_vectors(enc_sent, ids_t)
    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(len([i for id in dataset['train']['encoded_PROPNS'] for i in id]))

Collecting Contexts: 14041it [02:36, 89.63it/s]

torch.Size([75289, 768])
66214





In [8]:
i = 2
enc_sent = tokenizer.encode_plus(texts[i], return_tensors='pt')
ids_t = torch.tensor(encoded_PROPNS[i])

print(enc_sent)
print(ids_t)

thing = (enc_sent.input_ids.T == ids_t).nonzero()
print(thing)

enc_sent.attention_mask[0, thing[:, 0]] = 0
print(enc_sent.attention_mask)

{'input_ids': tensor([[  101, 26660, 13329, 12649, 15928,  1820,   118,  4775,   118,  1659,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tensor([26660, 13329, 12649, 15928])
tensor([[1, 0],
        [2, 1],
        [3, 2],
        [4, 3]])
tensor([[1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]])


## Cluster

Cluster context vectors using kmeans with $k = 4$

In [9]:
cluster_ids, cluster_centres = kmeans(context_vectors, 4, distance='cosine', device=device)

print(context_vectors.shape)

running k-means on cuda..


[running kmeans]: 18it [00:05,  3.31it/s, center_shift=0.000000, iteration=18, tol=0.000100]  

torch.Size([75289, 768])





Manually inspect results

In [11]:
i = 106
text = dataset['train'][i]['text']
encoded_ids = dataset['train'][i]['encoded_PROPNS']

enc_sent = tokenizer.encode_plus(sentence, return_tensors='pt')
ids_t = torch.tensor(encoded_ids)

embedding_aggregate = get_context_vectors(enc_sent, ids_t)

if embedding_aggregate.size(0) == 1:
    embedding_aggregate = embedding_aggregate.repeat(2, 1)

clusters = kmeans_predict(embedding_aggregate, cluster_centres, device=device, distance='cosine')

encoded_sent = tokenizer.encode(text)
encoded_tokens = [(token, tokenizer.encode(token)[1:-1]) for token, pos_tag in zip(dataset['train'][i]['tokens'], dataset['train'][i]['pos_tags']) if pos_tag in [22, 23]]
clusters_list = clusters.tolist()

thing = [(enc_toks[0], [clusters_list.pop(0) for _ in range(len(enc_toks[1]))]) for enc_toks in encoded_tokens]
print(thing)

predicting on cuda..


IndexError: pop from empty list

## Evaluate

Collect cluster labels and find which assignment of labels obtains the highest f-1. This saves manually labelling them while testing but manually labelling the clusters would be necessary on an unlabelled dataset

In [12]:
test_texts = dataset['test']['text']
test_encoded_PROPNS = dataset['test']['encoded_PROPNS']
test_labels = dataset['test']['ner_tags']

In [5]:
training_examples = []
for i, example in tqdm(enumerate(dataset['train'])):
    target_mask = torch.tensor([0], dtype=torch.int8)
    encoded_string = torch.tensor([101], dtype=int)

    for token, pos_tag in zip(example['tokens'], example['pos_tags']):
        encoded_tok = tokenizer.encode(token, return_tensors='pt')[0, 1:-1]
        encoded_string = torch.cat((encoded_string, encoded_tok))

        target_mask = torch.cat((target_mask, torch.full_like(encoded_tok, 1 if pos_tag in [22, 23] else 0)))

    target_mask = torch.cat((target_mask, torch.tensor([0]))).unsqueeze(0)
    encoded_string = torch.cat((encoded_string, torch.tensor([102]))).unsqueeze(0)

    training_examples.append({'input_ids': encoded_string, 'target_mask': target_mask})

print(len(training_examples))

14041it [00:38, 361.34it/s]

14041





In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
def get_context_vectors_new(input_ids, attention_mask, target_mask):
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(input_ids.to(device), attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = target_mask.to(device).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(torch.sum(target_mask), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[torch.arange(0, torch.sum(target_mask)), target_indices[:, 1], -4:], dim=1)

    return embedding_aggregate

i = 3
attention_mask = torch.ones_like(training_examples[i]['input_ids'])
get_context_vectors_new(training_examples[i]['input_ids'], attention_mask, training_examples[i]['target_mask'])

tensor([[-0.3164, -0.4556,  1.3220,  ...,  4.5251,  0.2369,  1.0924],
        [-3.7310, -2.7809,  0.7833,  ...,  2.3378, -0.6178,  0.5139],
        [ 0.5765, -4.3769, -2.2516,  ...,  0.3305,  0.6382,  2.1370]],
       device='cuda:0')

In [18]:
context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(training_examples, desc='Collecting Contexts'):
    attention_mask = torch.ones_like(example['input_ids'])

    embedding_aggregate = get_context_vectors_new(example['input_ids'], attention_mask, example['target_mask'])
    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(len([i for id in dataset['train']['encoded_PROPNS'] for i in id]))

Collecting Contexts: 100%|██████████| 14041/14041 [02:34<00:00, 91.14it/s]

torch.Size([66214, 768])
66214





Identify cluster allocations

In [None]:
test_cluster_ids = kmeans_predict(test_context_vectors, cluster_centres, device=device)

predicting on cuda..
tensor([1, 1, 1,  ..., 2, 2, 2])


In [None]:
# word, tokenized, ner_tag

# word
tokenised = [[{'token': token, 'tokenised': tokenizer.encode(token)[1:-1], 'ner_tag': ner_tag, 'pos_tag': pos_tag} for token, pos_tag, ner_tag in zip(example['tokens'], example['pos_tags'], example['ner_tags']) if pos_tag in [22, 23]] for example in dataset['test']]
pprint(tokenised)
print(len(test_cluster_ids))
print(sum([len(x_i['tokenised']) for x in tokenised for x_i in x]))

[[{'ner_tag': 3,
   'pos_tag': 22,
   'token': 'JAPAN',
   'tokenised': [147, 12240, 14962]},
  {'ner_tag': 0,
   'pos_tag': 22,
   'token': 'LUCKY',
   'tokenised': [149, 21986, 2428, 3663]},
  {'ner_tag': 0, 'pos_tag': 22, 'token': 'WIN', 'tokenised': [160, 11607]},
  {'ner_tag': 1,
   'pos_tag': 22,
   'token': 'CHINA',
   'tokenised': [24890, 11607, 1592]}],
 [{'ner_tag': 1,
   'pos_tag': 22,
   'token': 'Nadim',
   'tokenised': [11896, 3309, 1306]},
  {'ner_tag': 1,
   'pos_tag': 22,
   'token': 'Ladki',
   'tokenised': [2001, 1181, 2293]}],
 [{'ner_tag': 3,
   'pos_tag': 22,
   'token': 'AL-AIN',
   'tokenised': [18589, 118, 19016, 2249]},
  {'ner_tag': 3, 'pos_tag': 22, 'token': 'United', 'tokenised': [1244]},
  {'ner_tag': 3, 'pos_tag': 22, 'token': 'Arab', 'tokenised': [4699]},
  {'ner_tag': 3, 'pos_tag': 23, 'token': 'Emirates', 'tokenised': [14832]}],
 [{'ner_tag': 3, 'pos_tag': 22, 'token': 'Japan', 'tokenised': [1999]},
  {'ner_tag': 4, 'pos_tag': 22, 'token': 'Cup', 'toke