In [1]:
from transformers import BertTokenizer, BertModel
from kmeans_pytorch import kmeans, kmeans_predict
from tqdm import tqdm
from pprint import pprint
import torch
import numpy as np
import pandas as pd
import gc
from datasets import load_dataset
from collections import Counter, defaultdict

In [2]:
model_name = 'bert-base-cased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Preparation

The CoNLL2003 dataset is first used to evaluate performance. The ground truth ner tags are used only select named entities to allow the class allocation to be tested independently by assumung perfect extraction. Later work will be conducted into identifying which entities should be labelled. NER tags are mapped to a non-BIO based scheme to reduce the number of clusters.

In [3]:
original_tags = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# New non-BIO scheme
new_tags = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}
new_tags_string = {0: 'O', 1: 'PER', 2: 'ORG', 3: 'LOC', 4: 'MISC'}

dataset = load_dataset("conll2003")

def process_batch(example):
    target_mask = torch.tensor([0], dtype=torch.int8)
    input_ids = torch.tensor([101], dtype=int)

    for token, ner_tag in zip(example['tokens'], example['ner_tags']):
        encoded_tok = tokenizer.encode(token, return_tensors='pt')[0, 1:-1]

        input_ids = torch.cat((input_ids, encoded_tok))
        target_mask = torch.cat((target_mask, torch.full_like(encoded_tok, 1 if ner_tag != 0 else 0)))

    target_mask = torch.cat((target_mask, torch.tensor([0]))).unsqueeze(0)
    input_ids = torch.cat((input_ids, torch.tensor([102]))).unsqueeze(0)

    return {'input_ids': input_ids, 'target_mask': target_mask}

for split in ['train', 'validation', 'test']:
    dataset[split] = dataset[split].map(lambda batch: {
        **process_batch(batch),
        'ner_tags': [new_tags[tag] for tag in batch['ner_tags']]
        }).remove_columns(['id', 'chunk_tags'])

print(dataset['train'][0])

Found cached dataset conll2003 (/home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-e643d3196526a30f.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-4ae7629d9bae3654.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-d22db9ce0f6a945f.arrow


{'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'ner_tags': [2, 0, 4, 0, 0, 0, 4, 0, 0], 'input_ids': [[101, 7270, 22961, 1528, 1840, 1106, 21423, 1418, 2495, 12913, 119, 102]], 'target_mask': [[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]]}


## Training

Define a method to extract the context vectors of the extracted entities

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
def get_context_vectors(input_ids, attention_mask, target_mask):
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(input_ids.to(device), attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = target_mask.to(device).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(torch.sum(target_mask), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[torch.arange(0, torch.sum(target_mask)), target_indices[:, 1], -4:], dim=1)

    return embedding_aggregate

Extract all of the context vectors using the defined method on the training split

In [9]:
context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['train'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['train']['target_mask']))

Collecting Contexts: 100%|██████████| 14041/14041 [02:38<00:00, 88.77it/s]

torch.Size([65030, 768])
65030





## Cluster

Cluster context vectors using kmeans with $k = 4$

In [10]:
# Seeds for reproducing cluster centres
torch.random.seed = 42
torch.cuda.seed = 42
np.random.seed(42)

distance = 'cosine' # 'euclidean'

cluster_ids, cluster_centres = kmeans(context_vectors, 4, distance=distance, device=device)

running k-means on cuda..


[running kmeans]: 39it [00:10,  3.60it/s, center_shift=0.000000, iteration=39, tol=0.000100]  


Manually evaluate clusters

In [28]:
i = 3
example = dataset['train'][i]

input_ids = torch.tensor(example['input_ids'])
attention_mask = torch.ones_like(input_ids)
target_mask = torch.tensor(example['target_mask'])

test_context_vectors = get_context_vectors(input_ids, attention_mask, target_mask)

if len(test_context_vectors) == 1:
    test_context_vectors = test_context_vectors.repeat(2, 1)

test_cluster_ids = kmeans_predict(test_context_vectors, cluster_centres, distance=distance, device=device)
test_clusters_ids_list = test_cluster_ids.squeeze().tolist()

token_ids = {token: tokenizer.encode(token)[1:-1] for token, ner_tag in zip(example['tokens'], example['ner_tags']) if ner_tag != 0}
token_predictions = {token: [test_clusters_ids_list.pop(0) for _ in range(len(ids))] for token, ids in token_ids.items()}

print(token_predictions)

predicting on cuda..
{'European': [3], 'Commission': [3], 'German': [3], 'British': [3]}


## Evaluate

Calculate context vectors for all validation samples

In [11]:
val_context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['validation'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    val_context_vectors = torch.cat((val_context_vectors, embedding_aggregate))

print(val_context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['validation']['target_mask']))

Collecting Contexts: 100%|██████████| 3250/3250 [00:36<00:00, 88.15it/s]

torch.Size([16225, 768])
16225





Identify cluster allocations

In [12]:
val_cluster_ids = kmeans_predict(val_context_vectors, cluster_centres, distance=distance, device=device)

predicting on cuda..


Assign cluster ids to each token in the dataset and obtain a single cluster id for items with multiple, i.e. words composed of sub-word level tokens, by taking the mode or majority vote. Using `max(set(lst), key=lst.count))` for majority voting yields the most frequent list item or, in the case that there is a tie for most frequent, the item with the lowest value:

```python
a = [1, 1, 1, 0, 0]
b = [1, 1, 0, 0]
max(set(a), key=a.count)
# >> 1
max(set(b), key=b.count)
# >> 0
```
Another implementation may find that the head sub-word token is more informative.

Also collect the ground truth NER tag for each entity for comparison. `Note`: The cluster id numbers will not correspond to the NER tag numbers. They are two sets of labels that need to be mapped to eachother in some way.

Finally, find which mapping of cluster ids to NER ids yields the best f-1 performance. `Note`: This calculation does not consider entities that weren't tagged by using PROPN. This is to control for the initial entity extraction task which will be reviewed at another stage. This is to just evaluate whether this style of clustering is suitable for identifying entity classes.

In [17]:
def evaluate(cluster_ids):
    sample_predictions = []

    for example in tqdm(dataset['validation']):
        token_ids = {f'{i}-{token}': {
            'ids': tokenizer.encode(token)[1:-1],
            'ner_tag': ner_tag
        } for i, (token, ner_tag) in enumerate(zip(example['tokens'], example['ner_tags'])) if ner_tag != 0}

        token_predictions = {token: {
            'cluster_id': max(set((lst := [cluster_ids.pop(0) for _ in range(len(attributes['ids']))])), key=lst.count),
            'ner_tag': attributes['ner_tag']
        } for token, attributes in token_ids.items()}

        sample_predictions.append(token_predictions)

    mapping_freqs = Counter([(x_i['cluster_id'], x_i['ner_tag']) for x in sample_predictions for x_i in x.values()])
    p = defaultdict(dict)

    for k, v in mapping_freqs.items():
        p[k[0]].update({k[1]: v})

    df = pd.DataFrame.from_dict(data=p, orient='index').sort_index()
    df = df.reindex(sorted(df.columns), axis=1)

    eye = np.eye(len(df.columns), dtype=bool)
    total = np.sum(df.to_numpy().flatten())

    results = []
    for i in range(len(df.columns)):
        mapping = list(df[(df * np.roll(eye, i, axis=1)) != 0].stack().index)
        accuracy = np.sum((df * np.roll(eye, i, axis=1)).to_numpy().flatten()) / total * 100
        results.append({'mapping': mapping, 'accuracy': accuracy})

    return results

val_clusters_ids_list = val_cluster_ids.squeeze().tolist()
results = evaluate(val_clusters_ids_list)
pprint(results)

100%|██████████| 3250/3250 [00:01<00:00, 2159.95it/s]

[{'accuracy': 17.59851214692549, 'mapping': [(0, 1), (1, 2), (2, 3), (3, 4)]},
 {'accuracy': 5.579449029408346, 'mapping': [(0, 2), (1, 3), (2, 4), (3, 1)]},
 {'accuracy': 30.861327443914917, 'mapping': [(0, 3), (1, 4), (2, 1), (3, 2)]},
 {'accuracy': 45.96071137975125, 'mapping': [(0, 4), (1, 1), (2, 2), (3, 3)]}]





# Exploring Masking

Using a more testable implementation of the work in this notebook, explore variations on the original model to evaluate improvement. First import developed class

In [18]:
from context_vector_clustering import ContextClustering

Fit the class to the training examples

In [22]:
X_fit = ContextClustering(random_state=42).fit(dataset['train'], masked_targets=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Collecting Contexts: 100%|██████████| 14041/14041 [02:40<00:00, 87.51it/s]


running k-means on cuda..


[running kmeans]: 48it [00:12,  3.71it/s, center_shift=0.000000, iteration=48, tol=0.000100]  


## Evaluate

Predict the class assignment of the validation set

In [23]:
y_hat = X_fit.predict(dataset['validation'])
y_hat_list = y_hat.squeeze().tolist()

Collecting Contexts: 100%|██████████| 3250/3250 [00:36<00:00, 90.25it/s]

predicting on cuda..





Evaluate performance in the same way as before. Masking the targetted tokens performed very badly - about random.

In [24]:
results = evaluate(y_hat_list)
pprint(results)

100%|██████████| 3250/3250 [00:01<00:00, 2214.33it/s]

[{'accuracy': 18.935255143554574, 'mapping': [(0, 1), (1, 2), (2, 3), (3, 4)]},
 {'accuracy': 25.525979309543185, 'mapping': [(0, 2), (1, 3), (2, 4), (3, 1)]},
 {'accuracy': 28.12972218993374, 'mapping': [(0, 3), (1, 4), (2, 1), (3, 2)]},
 {'accuracy': 27.4090433569685, 'mapping': [(0, 4), (1, 1), (2, 2), (3, 3)]}]





In [32]:
def cat_n_layers(n):
    def aggregate(stacked_token_embeddings, target_indices):
        return torch.cat([*stacked_token_embeddings[torch.arange(0, target_indices.size(0)), target_indices[:, 1], -n:].permute(1, 0, 2)], dim=1)

    return aggregate

def sum_n_layers(n):
    def aggregate(stacked_token_embeddings, target_indices):
        print(stacked_token_embeddings[torch.arange(0, target_indices.size(0)), target_indices[:, 1], -n:].shape)
        print(torch.sum(stacked_token_embeddings[torch.arange(0, target_indices.size(0)), target_indices[:, 1], -n:], dim=1).shape)
        return torch.sum(stacked_token_embeddings[torch.arange(0, target_indices.size(0)), target_indices[:, 1], -n:], dim=1)

    return aggregate

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
def get_context_vectors(input_ids, attention_mask, target_mask):
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(input_ids.to(device), attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = target_mask.to(device).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(torch.sum(target_mask), 1, 1, 1)
    embedding_aggregate = cat_n_layers(4)(stacked_token_embeddings, target_indices)
    thing = sum_n_layers(4)(stacked_token_embeddings, target_indices)

    return embedding_aggregate

test_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['validation'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    #test_vectors = torch.cat((test_vectors, embedding_aggregate))

    break

Collecting Contexts:   0%|          | 0/3250 [00:00<?, ?it/s]

9
torch.Size([9, 4, 768])
torch.Size([9, 3072])
torch.Size([9, 4, 768])
torch.Size([9, 768])





In [52]:
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

b = a.repeat(2, 1, 1)
print(b)

indices = torch.tensor([[0, 0], [0, 1]])

torch.sum(b[torch.arange(0, indices.size(0)), indices[:, 1], -1:], dim=1)

tensor([[[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]],

        [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]])


tensor([3, 6])