In [1]:
from transformers import BertTokenizer, BertModel
from kmeans_pytorch import kmeans, kmeans_predict
from tqdm import tqdm
from pprint import pprint
import torch
import gc
from datasets import load_dataset

In [2]:
model_name = 'bert-base-cased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Preparation

The CoNLL2003 dataset is first used to evaluate performance. The ground truth ner tags are used only select named entities to allow the class allocation to be tested independently by assumung perfect extraction. Later work will be conducted into identifying which entities should be labelled. NER tags are mapped to a non-BIO based scheme to reduce the number of clusters.

In [48]:
original_tags = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# New non-BIO scheme
new_tags = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}
new_tags_string = {0: 'O', 1: 'PER', 2: 'ORG', 3: 'LOC', 4: 'MISC'}

dataset = load_dataset("conll2003")

def process_batch(example):
    target_mask = torch.tensor([0], dtype=torch.int8)
    input_ids = torch.tensor([101], dtype=int)

    for token, ner_tag in zip(example['tokens'], example['ner_tags']):
        encoded_tok = tokenizer.encode(token, return_tensors='pt')[0, 1:-1]

        input_ids = torch.cat((input_ids, encoded_tok))
        target_mask = torch.cat((target_mask, torch.full_like(encoded_tok, 1 if ner_tag != 0 else 0)))

    target_mask = torch.cat((target_mask, torch.tensor([0]))).unsqueeze(0)
    input_ids = torch.cat((input_ids, torch.tensor([102]))).unsqueeze(0)

    return {'input_ids': input_ids, 'target_mask': target_mask}

for split in ['train', 'validation', 'test']:
    dataset[split] = dataset[split].map(lambda batch: {
        **process_batch(batch),
        'ner_tags': [new_tags[tag] for tag in batch['ner_tags']]
        }).remove_columns(['id', 'chunk_tags'])

print(dataset['train'][0])

Found cached dataset conll2003 (/home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/14041 [00:00<?, ?ex/s]

  0%|          | 0/3250 [00:00<?, ?ex/s]

  0%|          | 0/3453 [00:00<?, ?ex/s]

{'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'ner_tags': [2, 0, 4, 0, 0, 0, 4, 0, 0], 'input_ids': [[101, 7270, 22961, 1528, 1840, 1106, 21423, 1418, 2495, 12913, 119, 102]], 'target_mask': [[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]]}


## Training

Define a method to extract the context vectors of the extracted entities

In [49]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
def get_context_vectors(input_ids, attention_mask, target_mask):
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(input_ids.to(device), attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = target_mask.to(device).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(torch.sum(target_mask), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[torch.arange(0, torch.sum(target_mask)), target_indices[:, 1], -4:], dim=1)

    return embedding_aggregate

Extract all of the context vectors using the defined method on the training split

In [50]:
context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['train'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['train']['target_mask']))

Collecting Contexts: 100%|██████████| 14041/14041 [02:38<00:00, 88.71it/s]


torch.Size([65030, 768])
65030


## Cluster

Cluster context vectors using kmeans with $k = 4$

In [51]:
cluster_ids, cluster_centres = kmeans(context_vectors, 4, distance='cosine', device=device)

print(context_vectors.shape)

running k-means on cuda..


[running kmeans]: 31it [00:08,  3.56it/s, center_shift=0.000071, iteration=31, tol=0.000100]  

torch.Size([65030, 768])





Manually evaluate clusters

In [52]:
i = 3
example = dataset['train'][i]

input_ids = torch.tensor(example['input_ids'])
attention_mask = torch.ones_like(input_ids)
target_mask = torch.tensor(example['target_mask'])

test_context_vectors = get_context_vectors(input_ids, attention_mask, target_mask)

if len(test_context_vectors) == 1:
    test_context_vectors = test_context_vectors.repeat(2, 1)

test_cluster_ids = kmeans_predict(test_context_vectors, cluster_centres, device=device)
test_clusters_ids_list = test_cluster_ids.squeeze().tolist()

token_ids = {token: tokenizer.encode(token)[1:-1] for token, ner_tag in zip(example['tokens'], example['ner_tags']) if ner_tag != 0}
token_predictions = {token: [test_clusters_ids_list.pop(0) for _ in range(len(ids))] for token, ids in token_ids.items()}

print(token_predictions)

predicting on cuda..
{'European': [0], 'Commission': [1], 'German': [0], 'British': [1]}


## Evaluate

Calculate context vectors for all validation samples

In [53]:
val_context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['validation'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    val_context_vectors = torch.cat((val_context_vectors, embedding_aggregate))

print(val_context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['validation']['target_mask']))

Collecting Contexts: 100%|██████████| 3250/3250 [00:37<00:00, 87.82it/s]

torch.Size([16225, 768])
16225





Identify cluster allocations

In [54]:
val_cluster_ids = kmeans_predict(val_context_vectors, cluster_centres, device=device)

predicting on cuda..


Assign cluster ids to each token in the dataset and obtain a single cluster id for items with multiple, i.e. words composed of sub-word level tokens, by taking the mode or majority vote. Using `max(set(lst), key=lst.count))` for majority voting yields the most frequent list item or, in the case that there is a tie for most frequent, the item with the lowest value:

```python
a = [1, 1, 1, 0, 0]
b = [1, 1, 0, 0]
max(set(a), key=a.count)
# >> 1
max(set(b), key=b.count)
# >> 0
```
Another implementation may find that the head sub-word token is more informative.

Also collect the ground truth NER tag for each entity for comparison. `Note`: The cluster id numbers will not correspond to the NER tag numbers. They are two sets of labels that need to be mapped to eachother in some way.

In [55]:
val_clusters_ids_list = val_cluster_ids.squeeze().tolist()
sample_predictions = []

for example in tqdm(dataset['validation']):
    token_ids = {f'{i}-{token}': {
        'ids': tokenizer.encode(token)[1:-1],
        'ner_tag': ner_tag
     } for i, (token, ner_tag) in enumerate(zip(example['tokens'], example['ner_tags'])) if ner_tag != 0}

    token_predictions = {token: {
        'cluster_id': max(set((lst := [val_clusters_ids_list.pop(0) for _ in range(len(attributes['ids']))])), key=lst.count),
        'ner_tag': attributes['ner_tag']
     } for token, attributes in token_ids.items()}

    sample_predictions.append(token_predictions)

pprint(sample_predictions[:3])

100%|██████████| 3250/3250 [00:01<00:00, 2188.48it/s]

[{'2-LEICESTERSHIRE': {'cluster_id': 3, 'ner_tag': 2}},
 {'0-LONDON': {'cluster_id': 3, 'ner_tag': 3}},
 {'0-West': {'cluster_id': 0, 'ner_tag': 4},
  '1-Indian': {'cluster_id': 0, 'ner_tag': 4},
  '12-Leicestershire': {'cluster_id': 1, 'ner_tag': 2},
  '14-Somerset': {'cluster_id': 1, 'ner_tag': 2},
  '3-Phil': {'cluster_id': 1, 'ner_tag': 1},
  '4-Simmons': {'cluster_id': 1, 'ner_tag': 1}}]





Find which mapping of cluster ids to NER ids yields the best f-1 performance. `Note`: This calculation does not consider entities that weren't tagged by using PROPN. This is to control for the initial entity extraction task which will be reviewed at another stage. This is to just evaluate whether this style of clustering is suitable for identifying entity classes.

In [85]:
from collections import Counter, deque, defaultdict

cluster_counts = Counter([x_i['cluster_id'] for x in sample_predictions for x_i in x.values()])
ner_counts = Counter([x_i['ner_tag'] for x in sample_predictions for x_i in x.values()])

print(f'Cluster Distribution: {cluster_counts}')
print(f'NER Tag Distribution: {ner_counts}')

mapping_freqs = Counter([(x_i['cluster_id'], x_i['ner_tag']) for x in sample_predictions for x_i in x.values()])
p = defaultdict(dict)

for k, v in mapping_freqs.items():
    p[k[0]].update({k[1]: v})

print(p)

print(mapping_freqs)

a = [0, 1, 2, 3]
b = deque([1, 2, 3, 4])

for i in range(len(a)):
    b.rotate(i)
    mapping = {a_i: b_i for a_i, b_i in zip(a, b)}

    print(p[0][4] / sum([p[0][x] for x in p[0].keys() if x != p[0][4]]) * 100)
    



    


Cluster Distribution: Counter({1: 4159, 0: 2216, 2: 1520, 3: 708})
NER Tag Distribution: Counter({1: 3149, 3: 2094, 2: 2092, 4: 1268})
defaultdict(<class 'dict'>, {3: {2: 189, 3: 346, 4: 142, 1: 31}, 0: {4: 629, 3: 1366, 2: 215, 1: 6}, 1: {1: 1848, 2: 1567, 3: 280, 4: 464}, 2: {1: 1264, 3: 102, 2: 121, 4: 33}})
Counter({(1, 1): 1848, (1, 2): 1567, (0, 3): 1366, (2, 1): 1264, (0, 4): 629, (1, 4): 464, (3, 3): 346, (1, 3): 280, (0, 2): 215, (3, 2): 189, (3, 4): 142, (2, 2): 121, (2, 3): 102, (2, 4): 33, (3, 1): 31, (0, 1): 6})
28.384476534296027
28.384476534296027
28.384476534296027
28.384476534296027
