In [14]:
from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
from kmeans_pytorch import kmeans, kmeans_predict
import json
from pprint import pprint
import spacy
from tqdm import tqdm
import torch
import gc
from datasets import load_dataset
from nltk.tokenize.treebank import TreebankWordDetokenizer

In [2]:
model_name = 'bert-base-cased'

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# model = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
dataset = load_dataset("conll2003")

for split in ['train', 'validation', 'test']:
    dataset[split] = dataset[split].map(lambda batch: {
        'text': TreebankWordDetokenizer().detokenize(batch['tokens']),
        'encoded_PROPNS': [ids for tok, pos_tag in zip(batch['tokens'], batch['pos_tags']) for ids in tokenizer.encode(tok)[1:-1] if pos_tag in [22, 23]]
        }).remove_columns(['id', 'chunk_tags'])

print(dataset['train'][0])

Found cached dataset conll2003 (/home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-219dc2dcd710c718.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-8ec9299742aa5129.arrow
Loading cached processed dataset at /home/william/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98/cache-14e3409507fc4dc2.arrow


{'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0], 'text': 'EU rejects German call to boycott British lamb.', 'encoded_PROPNS': [7270]}


In [4]:
texts = dataset['train']['text']
encoded_PROPNS = dataset['train']['encoded_PROPNS']

In [35]:
def get_context_vectors(sentence, encoded_ids):
    enc_sent = tokenizer.encode_plus(sentence, return_tensors='pt')
    ids_t = torch.tensor(encoded_ids)
    
    with torch.no_grad():
        outputs = model(enc_sent.input_ids.to(device), enc_sent.attention_mask.to(device))
        hidden_states = outputs[2]

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = (enc_sent.input_ids.T == ids_t).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(ids_t.size(0), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[target_indices[:, 1], target_indices[:, 0], -4:], dim=1)

    print(embedding_aggregate.shape)

    return embedding_aggregate

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.eval()
context_vectors = torch.empty(0, 768).to(device)

model.to(device)

for sentence, encoded_ids in tqdm(zip(texts, encoded_PROPNS)):

    enc_sent = tokenizer.encode_plus(sentence, return_tensors='pt')
    ids_t = torch.tensor(encoded_ids)
    
    with torch.no_grad():
        outputs = model(enc_sent.input_ids.to(device), enc_sent.attention_mask.to(device))
        hidden_states = outputs[2]

    # Extract model embeddings layer activations
    token_embeddings = torch.stack(hidden_states, dim=0)

    # Remove batches dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=1)

    # Swap layer and token dimensions
    token_embeddings = token_embeddings.permute(1, 0, 2)
    
    # Identify indices within encoded text to calculate context embeddings
    target_indices = (enc_sent.input_ids.T == ids_t).nonzero()

    # Use the sum of the last 4 embedding layers as an aggregation of context for the selected indices
    stacked_token_embeddings = token_embeddings.repeat(ids_t.size(0), 1, 1, 1)
    embedding_aggregate = torch.sum(stacked_token_embeddings[target_indices[:, 1], target_indices[:, 0], -4:], dim=1)

    context_vectors = torch.cat((context_vectors, embedding_aggregate))

print(context_vectors.shape)
print(len([i for id in dataset['train']['encoded_PROPNS'] for i in id]))

14041it [02:34, 91.16it/s]

torch.Size([75289, 768])
66214





## Cluster

Cluster context vectors using kmeans with $k = 4$

In [70]:
cluster_ids, cluster_centres = kmeans(context_vectors, 4, distance='cosine', device=device)

print(context_vectors.shape)

running k-means on cuda..


[running kmeans]: 49it [00:15,  3.25it/s, center_shift=0.000062, iteration=49, tol=0.000100]  

torch.Size([75289, 768])





Manually inspect results

In [77]:
i = 106
text = dataset['train'][i]['text']
encoded_ids = dataset['train'][i]['encoded_PROPNS']

embedding_aggregate = get_context_vectors(text, encoded_ids)

if embedding_aggregate.size(0) == 1:
    embedding_aggregate = embedding_aggregate.repeat(2, 1)

clusters = kmeans_predict(embedding_aggregate, cluster_centres, device=device, distance='cosine')

encoded_sent = tokenizer.encode(text)
encoded_tokens = [(token, tokenizer.encode(token)[1:-1]) for token, pos_tag in zip(dataset['train'][i]['tokens'], dataset['train'][i]['pos_tags']) if pos_tag in [22, 23]]
clusters_list = clusters.tolist()

thing = [(enc_toks[0], [clusters_list.pop(0) for _ in range(len(enc_toks[1]))]) for enc_toks in encoded_tokens]
print(thing)

torch.Size([13, 768])
predicting on cuda..
[('President', [1]), ('Hafez', [3, 3, 3]), ('Israel', [1]), ('Foreign', [1]), ('Minister', [1]), ('David', [1]), ('Levy', [3]), ('Israel', [3]), ('Radio', [1])]
