In [1]:
from transformers import BertTokenizer, BertModel
from kmeans_pytorch import kmeans, kmeans_predict
from tqdm import tqdm
from pprint import pprint
import torch
import gc
from datasets import load_dataset

In [2]:
model_name = 'bert-base-cased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Prepare Data

The CoNLL2003 dataset is first used to evaluate performance. The ground truth pos tags are used to label propernouns in the text to eliminate any error introduced by a pos tagger in the system.

In [30]:
original_tags = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# New non-BIO scheme
new_tags = {0: 0, 1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}
new_tags_string = {0: 'O', 1: 'PER', 2: 'ORG', 3: 'LOC', 4: 'MISC'}

dataset = load_dataset("conll2003")

def process_batch(example):
    target_mask = torch.tensor([0], dtype=torch.int8)
    input_ids = torch.tensor([101], dtype=int)

    for token, pos_tag in zip(example['tokens'], example['pos_tags']):
        encoded_tok = tokenizer.encode(token, return_tensors='pt')[0, 1:-1]

        input_ids = torch.cat((input_ids, encoded_tok))
        target_mask = torch.cat((target_mask, torch.full_like(encoded_tok, 1 if pos_tag in [22, 23] else 0)))

    target_mask = torch.cat((target_mask, torch.tensor([0]))).unsqueeze(0)
    input_ids = torch.cat((input_ids, torch.tensor([102]))).unsqueeze(0)

    return {'input_ids': input_ids, 'target_mask': target_mask}

for split in ['train', 'validation', 'test']:
    dataset[split] = dataset[split].map(lambda batch: {
        **process_batch(batch),
        'ner_tags': [new_tags[tag] for tag in batch['ner_tags']]
        }).remove_columns(['id', 'chunk_tags'])

print(dataset['train'][0])

Found cached dataset conll2003 (/home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/14041 [00:00<?, ?ex/s]

  0%|          | 0/3250 [00:00<?, ?ex/s]

  0%|          | 0/3453 [00:00<?, ?ex/s]

{'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'ner_tags': [2, 0, 4, 0, 0, 0, 4, 0, 0], 'input_ids': [[101, 7270, 22961, 1528, 1840, 1106, 21423, 1418, 2495, 12913, 119, 102]], 'target_mask': [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}


## Training

Define a method to extract the context vectors of the proper nouns

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
def get_context_vectors(input_ids, attention_mask, target_mask):
    # Ignore gradient to improve performance
    with torch.no_grad():
        outputs = model(input_ids.to(device), attention_mask.to(device))
        hidden_states = outputs.hidden_states

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = target_mask.to(device).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(torch.sum(target_mask), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[torch.arange(0, torch.sum(target_mask)), target_indices[:, 1], -4:], dim=1)

    return embedding_aggregate

Extract all of the context vectors using the defined method on the trainins split

In [7]:
context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['train'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['train']['target_mask']))

Collecting Contexts: 100%|██████████| 14041/14041 [02:34<00:00, 90.97it/s]

torch.Size([66214, 768])





TypeError: sum(): argument 'input' (position 1) must be Tensor, not list

## Cluster

Cluster context vectors using kmeans with $k = 4$

In [24]:
cluster_ids, cluster_centres = kmeans(context_vectors, 4, distance='cosine', device=device)

print(context_vectors.shape)

running k-means on cuda..


[running kmeans]: 49it [00:13,  3.56it/s, center_shift=0.000055, iteration=49, tol=0.000100]  

torch.Size([66214, 768])





In [55]:
i = 3
example = dataset['train'][i]

input_ids = torch.tensor(example['input_ids'])
attention_mask = torch.ones_like(input_ids)
target_mask = torch.tensor(example['target_mask'])

test_context_vectors = get_context_vectors(input_ids, attention_mask, target_mask)

if len(test_context_vectors) == 1:
    test_context_vectors = test_context_vectors.repeat(2, 1)

test_cluster_ids = kmeans_predict(test_context_vectors, cluster_centres, device=device)
test_clusters_ids_list = test_cluster_ids.squeeze().tolist()

token_ids = {token: tokenizer.encode(token)[1:-1] for token, pos_tag in zip(example['tokens'], example['pos_tags']) if pos_tag in [22, 23]}
token_predictions = {token: [test_clusters_ids_list.pop(0) for _ in range(len(ids))] for token, ids in token_ids.items()}

print(token_predictions)

predicting on cuda..
{'European': [1], 'Commission': [2], 'Thursday': [1]}


## Evaluate

Collect cluster labels and find which assignment of labels obtains the highest f-1. This saves manually labelling them while testing but manually labelling the clusters would be necessary on an unlabelled dataset

In [67]:
val_context_vectors = torch.empty(0, 768).to(device)

model.eval()
model.to(device)

for example in tqdm(dataset['validation'], desc='Collecting Contexts'):
    input_ids = torch.tensor(example['input_ids'])
    attention_mask = torch.ones_like(input_ids)
    target_mask = torch.tensor(example['target_mask'])

    embedding_aggregate = get_context_vectors(input_ids, attention_mask, target_mask)
    val_context_vectors = torch.cat((val_context_vectors, embedding_aggregate))

print(val_context_vectors.shape)
print(sum(sum(x[0]) for x in dataset['validation']['target_mask']))

Collecting Contexts: 100%|██████████| 3250/3250 [00:35<00:00, 90.44it/s]

torch.Size([16592, 768])
16592





Identify cluster allocations

In [68]:
val_cluster_ids = kmeans_predict(val_context_vectors, cluster_centres, device=device)

predicting on cuda..


In [70]:
val_clusters_ids_list = val_cluster_ids.squeeze().tolist()
sample_predictions = []

for example in tqdm(dataset['validation']):
    token_ids = {token: tokenizer.encode(token)[1:-1] for token, pos_tag in zip(example['tokens'], example['pos_tags']) if pos_tag in [22, 23]}
    token_predictions = {token: [val_clusters_ids_list.pop(0) for _ in range(len(ids))] for token, ids in token_ids.items()}

    sample_predictions.append(token_predictions)

pprint(sample_predictions)

100%|██████████| 3250/3250 [00:01<00:00, 1882.29it/s]


[{'AFTER': [0, 0, 0],
  'AT': [0],
  'CRICKET': [0, 0, 0, 0],
  'INNINGS': [0, 0, 0, 0],
  'LEICESTERSHIRE': [0, 0, 0, 0, 0, 0, 0, 0, 0],
  'TAKE': [0, 0, 0],
  'TOP': [0, 0]},
 {'LONDON': [0, 0, 0, 0]},
 {'Friday': [1],
  'Indian': [1],
  'Leicestershire': [1],
  'Phil': [2],
  'Simmons': [2],
  'Somerset': [1],
  'West': [2]},
 {'Derbyshire': [1],
  'Essex': [2],
  'Kent': [1],
  'Nottinghamshire': [1],
  'Surrey': [1]},
 {'Andy': [2],
  'Caddick': [3, 3, 3],
  'England': [1],
  'Grace': [2],
  'Leicestershire': [1],
  'Road': [2],
  'Somerset': [1]},
 {'Simmons': [2], 'Somerset': [2]},
 {'Essex': [2],
  'Headingley': [3, 3, 2],
  'Hussain': [2],
  'Nasser': [3, 3],
  'Peter': [2],
  'Yorkshire': [1]},
 {'England': [1], 'Essex': [2]},
 {'Yorkshire': [2]},
 {'Chris': [2],
  'England': [1],
  'Friday': [1],
  'Lewis': [2],
  'Surrey': [1],
  'Thursday': [1],
  'Warwickshire': [1]},
 {'Butcher': [2], 'England': [1], 'Mark': [2], 'Surrey': [1]},
 {'Worcestershire': [1]},
 {'Adams': [2],
