# Self-Supervised Named Entity Recognition (SSLNER)

Project based on [this repository](https://github.com/ajitrajasekharan/unsupervised_NER) and [this medium blog post](https://towardsdatascience.com/ssl-could-avoid-supervised-learning-fd049a27cd1b). RoBERTa is used instead of BERT because of the improved robustness of the model from dynamic masking instead of static masking.

In [70]:
from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
import json
import urllib.request
from pprint import pprint
import kmeans_pytorch

In [71]:
model_name = 'bert-base-cased'

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# model = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [72]:
response = urllib.request.urlopen("https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json")
thing = response.read()

#pprint(thing)

In [73]:
response = urllib.request.urlopen("https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json")
response = urllib.request.urlopen(tokenizer.pretrained_vocab_files_map['vocab_file'][model_name])
try:
    vocab = dict(zip((tokens := response.read().decode('utf-8').splitlines()), tokenizer.convert_tokens_to_ids(tokens)))
except Exception as e:
    print(e)

#pprint(vocab)

In [74]:
entity_types = {'PERSON', 'ORG', 'LOC', 'DATE_TIME', 'WORK_OF_ART', 'MATERIAL', 'EVENT', 'MONEY', 'CARDINAL', 'ORDINAL', 'MISC'}

In [75]:
with open('ordinal.lst', 'r') as file:
    ordinals = file.read().splitlines()

ordinals = {ord for ord in ordinals if ord in vocab}

In [76]:
cardinals = {s for s in vocab if s.isdigit()}

with open('numbers.lst', 'r') as file:
    string_cardinals = file.read().splitlines()

cardinals = cardinals | {card for card in string_cardinals if card in vocab}

In [77]:
with open('person_first.lst', 'r') as file:
    first_names = file.read().splitlines()

names = {name for name in first_names if name in vocab} | {n for name in first_names if (n := name.lower()) in vocab}

Extracting contextual embeddings by summing over the lasy 4 layers

In [99]:
import torch
import torch.nn as nn

text = 'After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank.'
enc_text = tokenizer.encode_plus(text, return_tensors='pt')

model.eval()
with torch.no_grad():
    outputs = model(enc_text.input_ids, enc_text.attention_mask)
    hidden_states = outputs[2]

print(enc_text.input_ids.shape)
token_embeddings = torch.stack(hidden_states, dim=0)
token_embeddings = torch.squeeze(token_embeddings, dim=1)
print(token_embeddings.shape)
token_embeddings = token_embeddings.permute(1, 0, 2)
print(token_embeddings.shape)

token_vecs_sum = [torch.sum(token[-4:], dim=0) for token in token_embeddings]
print("bank vault   ", str(token_vecs_sum[5][:5]))
print("bank robber  ", str(token_vecs_sum[10][:5]))
print("river bank   ", str(token_vecs_sum[19][:5]))

cos = nn.CosineSimilarity(dim=0)
print(f'Bank Vault - Bank Robber: {cos(token_vecs_sum[6], token_vecs_sum[10])}')
print(f'Bank Vault - River Bank: {cos(token_vecs_sum[6], token_vecs_sum[19])}')
print(f'Bank Robber - River Bank: {cos(token_vecs_sum[10], token_vecs_sum[19])}')

torch.Size([1, 24])
torch.Size([13, 24, 768])
torch.Size([24, 13, 768])
bank vault    tensor([-1.1211, -1.5639,  0.5538, -0.5256,  0.1242])
bank robber   tensor([-2.0299, -0.9155, -4.2468,  4.2540,  2.1868])
river bank    tensor([ 3.1898, -0.4442, -1.9258, -0.4348,  2.5461])
Bank Vault - Bank Robber: 0.903982400894165
Bank Vault - River Bank: 0.7330647110939026
Bank Robber - River Bank: 0.6783679723739624


Extracting raw word embeddings

In [79]:
#print([module for module in model.modules()])
test_terms = ['stream', 'Stream']
print([t in vocab for t in test_terms])

similarity_test = tokenizer.convert_tokens_to_ids(test_terms)
print(similarity_test)

raw_word_embeddings = model.embeddings.word_embeddings.weight.detach()

cos = nn.CosineSimilarity(dim=0)
with torch.no_grad():
    tok_embeds = raw_word_embeddings[similarity_test]
    print(tok_embeds)
    print(cos(*tok_embeds))

print(len(raw_word_embeddings))

[True, True]
[5118, 22627]
tensor([[-0.0263, -0.0251, -0.0134,  ..., -0.0169, -0.0548, -0.0229],
        [-0.0144, -0.0418, -0.0290,  ...,  0.0302, -0.0696,  0.0051]])
tensor(0.5783)
28996


In [80]:
enc_tok = tokenizer.convert_tokens_to_ids('United')

dist = torch.norm(raw_word_embeddings - raw_word_embeddings[enc_tok], dim=1, p=None)
knn = dist.topk(20, largest=False)
print(knn)

print(list(zip([cos(raw_word_embeddings[enc_tok], raw_word_embeddings[k]) for k in knn.indices], tokenizer.convert_ids_to_tokens(knn.indices))))

torch.return_types.topk(
values=tensor([0.0000, 1.3220, 1.3394, 1.3422, 1.3484, 1.3511, 1.3578, 1.3616, 1.3660,
        1.3695, 1.3733, 1.3790, 1.3815, 1.3826, 1.3832, 1.3832, 1.3856, 1.3859,
        1.3880, 1.3886]),
indices=tensor([1244, 1646, 1392, 1237, 1203, 1570, 1456, 1970, 1993, 2921, 1305, 1109,
        2579, 1291, 3466, 3604, 2690, 1978, 2685, 1287]))
[(tensor(1.), 'United'), (tensor(0.2931), 'US'), (tensor(0.2538), 'City'), (tensor(0.2395), 'American'), (tensor(0.2826), 'New'), (tensor(0.2471), 'International'), (tensor(0.2403), 'North'), (tensor(0.2515), 'Central'), (tensor(0.2230), 'UK'), (tensor(0.1580), 'Johnson'), (tensor(0.2120), 'National'), (tensor(0.1072), 'The'), (tensor(0.2352), 'Northern'), (tensor(0.2352), 'World'), (tensor(0.1613), 'Harry'), (tensor(0.2574), 'FC'), (tensor(0.1618), 'Jones'), (tensor(0.1055), 'seven'), (tensor(0.2046), 'Southern'), (tensor(0.1081), 'John')]


## Filtering of RoBERTa vocabulary

Remove any punctuation strings, sub-word level tokens and single character tokens

In [81]:
import string
import re

filtered_vocab = {token: idx for token, idx in vocab.items() if not any(s in string.punctuation for s in token) and len(token) > 1}

print(f'Filtered Vocab: {len(filtered_vocab)}/{len(vocab)}')

Filtered Vocab: 21415/28996


Computing cosine similarity matrix for all filtered token embeddings

In [82]:
# Dictionary mapping incrementing indices to the vocab ids
index_idxs = {idx: k for idx, k in enumerate(filtered_vocab.values())}

filtered_vocab_embeds = raw_word_embeddings[list(filtered_vocab.values())]

# Calculate cosine similarity matrix 
with torch.no_grad():
    normalised = filtered_vocab_embeds / filtered_vocab_embeds.norm(dim=1)[:, None]
    similarity_matrix = torch.mm(normalised, normalised.transpose(0, 1))

similarity_matrix.fill_diagonal_(0)
print(similarity_matrix)

tensor([[ 0.0000,  0.4298,  0.4382,  ..., -0.1729, -0.1898, -0.2358],
        [ 0.4298,  0.0000,  0.4958,  ..., -0.1483, -0.1938, -0.1900],
        [ 0.4382,  0.4958,  0.0000,  ..., -0.1942, -0.2092, -0.1902],
        ...,
        [-0.1729, -0.1483, -0.1942,  ...,  0.0000,  0.2980,  0.2323],
        [-0.1898, -0.1938, -0.2092,  ...,  0.2980,  0.0000,  0.2660],
        [-0.2358, -0.1900, -0.1902,  ...,  0.2323,  0.2660,  0.0000]])


Implement graph based clustering. First split into graphs where similarity is above a threshold

In [83]:
threshold = 0.4

similarity_matrix[similarity_matrix < threshold] = 0
print(similarity_matrix)

tensor([[0.0000, 0.4298, 0.4382,  ..., 0.0000, 0.0000, 0.0000],
        [0.4298, 0.0000, 0.4958,  ..., 0.0000, 0.0000, 0.0000],
        [0.4382, 0.4958, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


[NEED TO FIND ALL FULLY CONNECTED SUBGRAPHS]

Then find node with highest average connection strength for each graph

In [84]:
# Calculate mean within graph, i.e. mean of all non-zero values
mean_connections = similarity_matrix.sum(dim=1) / (similarity_matrix != 0).sum(dim=1)
mean_connections = mean_connections.nan_to_num(0)

# Column vector of average weights within fully connected graphs
print(mean_connections)
print(torch.max(mean_connections))
print(torch.argmax(mean_connections))

tensor([0.4788, 0.5025, 0.4936,  ..., 0.5003, 0.4588, 0.4784])
tensor(0.8169)
tensor(5251)


In [85]:
stacked_means = mean_connections.repeat(similarity_matrix.size(0), 1)

pivots = torch.argmax(stacked_means * (similarity_matrix != 0), dim=1)
print(pivots)

tensor([   38,    11,    33,  ...,  4803,  7684, 12445])


In [86]:
print(len(pivots.unique()))
pivot_tokens = list(set(pivots))
pivot_vectors = raw_word_embeddings[[index_idxs[int(token)] for token in pivot_tokens]]

8771


In [87]:
grouped_by_pivot = {int(i): (pivots == i).nonzero().squeeze() for i in pivots.unique()}
pprint(grouped_by_pivot)

{0: tensor([  125,   202,   267,  ..., 21377, 21378, 21381]),
 1: tensor([ 3,  4,  8, 10, 11, 14, 16, 17]),
 2: tensor(33),
 3: tensor(546),
 4: tensor([26, 50, 59, 80]),
 7: tensor([ 23,  40, 961]),
 8: tensor(244),
 10: tensor([ 95, 111]),
 11: tensor([1, 9]),
 12: tensor(30),
 13: tensor(53),
 14: tensor(494),
 15: tensor(36),
 16: tensor(212),
 17: tensor(468),
 20: tensor([ 15,  31,  49, 103, 113]),
 22: tensor(595),
 23: tensor([ 5, 28]),
 25: tensor(2713),
 27: tensor(48),
 28: tensor(1130),
 32: tensor([188, 589]),
 33: tensor([   2,   29, 1633]),
 34: tensor(2955),
 35: tensor([ 19,  39, 281, 291]),
 37: tensor([ 207, 1883]),
 38: tensor([ 0, 12, 13, 18, 75]),
 39: tensor([  65, 1259]),
 40: tensor([  7,  21,  34, 377]),
 42: tensor([  45,   88, 2362]),
 46: tensor([1061, 1364]),
 49: tensor([ 20,  27, 298]),
 50: tensor(926),
 51: tensor([ 94, 122, 269]),
 52: tensor(4153),
 53: tensor(895),
 54: tensor([ 410, 1353]),
 55: tensor([  37,   93,  123,  196,  772, 1411]),
 56: te

In [88]:
original_ids_pivots = {index_idxs[int(k)]: [index_idxs[int(v_i)] for v_i in v.flatten().tolist()] for k, v in grouped_by_pivot.items()}

decoded_pivots = {tokenizer.convert_ids_to_tokens(k): tokenizer.convert_ids_to_tokens(v) for k, v in original_ids_pivots.items()}
pprint(decoded_pivots)

{'10th': ['1st',
          'fourth',
          '2nd',
          '19th',
          '3rd',
          'fifth',
          '20th',
          '4th',
          'sixth',
          '5th',
          '18th',
          '6th',
          '7th',
          '17th',
          'seventh',
          '16th',
          '8th',
          '12th',
          '13th',
          '11th',
          '9th',
          '15th',
          '14th',
          'eighth',
          '21st',
          'ninth',
          'Fourth',
          'tenth',
          'nineteenth',
          'Fifth',
          'twentieth',
          '25th',
          'Sixth',
          'Seventh',
          '50th',
          '30th',
          'twelfth',
          '24th',
          'eleventh',
          '22nd',
          '23rd',
          'eighteenth',
          '26th',
          '27th',
          '28th',
          'sixteenth',
          '29th',
          'Eighth',
          '40th',
          'thirteenth',
          '100th',
          'seventeenth',
          

In [89]:
i = 0

dist = torch.norm(filtered_vocab_embeds - pivot_vectors[i], dim=1, p=None)
knn = dist.topk(20, largest=False)
print(knn)

print()
print(tokenizer.convert_ids_to_tokens(index_idxs[int(pivot_tokens[i])]))
print(tokenizer.convert_ids_to_tokens([index_idxs[int(k)] for k in knn.indices]))

torch.return_types.topk(
values=tensor([0.0000, 0.8689, 0.9224, 1.2187, 1.2297, 1.2297, 1.2633, 1.2662, 1.2722,
        1.2759, 1.2766, 1.2784, 1.2798, 1.2812, 1.2825, 1.2827, 1.2926, 1.2926,
        1.2928, 1.2958]),
indices=tensor([ 4219,  6505,  8135, 13305,  3650,  2018,  7186, 21257, 14126,  9515,
         1868, 14133, 14770, 20911,  7064,  4425, 14584, 10083, 13542,  3057]))

ignored
['ignored', 'ignore', 'ignoring', 'neglected', 'rejected', 'refused', 'avoided', 'Ignoring', 'omitted', 'annoyed', 'grabbed', 'endured', 'scowled', 'mocked', 'glared', 'shouted', 'overlooked', 'utilized', 'irritated', 'regarded']


In [90]:
from transformers import pipeline

unmasker = pipeline('fill-mask', model='bert-base-cased')
unmasked = unmasker("Carravaggio is a [MASK].")

pprint(unmasked)

#pprint({res['token_str']: {k: v for k, v in decoded_pivots.items() if res['token_str'] in v} for res in unmasked})

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[{'score': 0.03785141184926033,
  'sequence': 'Carravaggio is a politician.',
  'token': 2931,
  'token_str': 'politician'},
 {'score': 0.03195589780807495,
  'sequence': 'Carravaggio is a village.',
  'token': 1491,
  'token_str': 'village'},
 {'score': 0.02933131717145443,
  'sequence': 'Carravaggio is a commune.',
  'token': 5188,
  'token_str': 'commune'},
 {'score': 0.025602901354432106,
  'sequence': 'Carravaggio is a painter.',
  'token': 5125,
  'token_str': 'painter'},
 {'score': 0.020912228152155876,
  'sequence': 'Carravaggio is a comune.',
  'token': 24382,
  'token_str': 'comune'}]


In [91]:
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("It was commissioned by the local confraternity dedicated to Saint Anne for the altar of their oratory.")

print(doc.spans)
print([(token.text, token.pos_) for token in doc if token.pos_ in ['PROPN', 'NOUN']])

unmasked = unmasker("oratory is a [MASK].")

pprint(unmasked)

{}
[('confraternity', 'NOUN'), ('Saint', 'PROPN'), ('Anne', 'PROPN'), ('altar', 'NOUN'), ('oratory', 'NOUN')]
[{'score': 0.09746652841567993,
  'sequence': 'oratory is a requirement.',
  'token': 8875,
  'token_str': 'requirement'},
 {'score': 0.05214102193713188,
  'sequence': 'oratory is a virtue.',
  'token': 13456,
  'token_str': 'virtue'},
 {'score': 0.040402546525001526,
  'sequence': 'oratory is a profession.',
  'token': 9545,
  'token_str': 'profession'},
 {'score': 0.026807719841599464,
  'sequence': 'oratory is a verb.',
  'token': 12464,
  'token_str': 'verb'},
 {'score': 0.02123214863240719,
  'sequence': 'oratory is a practice.',
  'token': 2415,
  'token_str': 'practice'}]


k-means clustering

In [92]:
from kmeans_pytorch import kmeans

cluster_ids, cluster_centres = kmeans(filtered_vocab_embeds, 11, distance='cosine', device=torch.device('cuda:0'))

running k-means on cuda:0..


[running kmeans]: 34it [00:05,  5.76it/s, center_shift=0.000078, iteration=34, tol=0.000100]


In [93]:
print(0 in cluster_ids)

assignments = {int(i): (cluster_ids == int(i)).nonzero().squeeze(dim=1) for i in cluster_ids.unique()}
print(assignments)

assignment_lens = {k: len(v) for k, v in assignments.items()}
print(assignment_lens)



True
{0: tensor([  167,   450,   460,  ..., 21377, 21381, 21410]), 1: tensor([  309,   346,   348,   396,   484,   495,   630,   743,   855,   873,
          883,   907,   928,   939,  1027,  1061,  1078,  1205,  1242,  1303,
         1308,  1373,  1386,  1449,  1525,  1550,  1553,  1647,  1732,  1762,
         1794,  1849,  1921,  1964,  2001,  2134,  2161,  2165,  2201,  2260,
         2351,  2359,  2451,  2502,  2523,  2535,  2649,  2650,  2674,  2733,
         2808,  2877,  2965,  2995,  3132,  3225,  3245,  3260,  3291,  3302,
         3307,  3343,  3389,  3453,  3465,  3491,  3500,  3532,  3586,  3659,
         3676,  3699,  3705,  3735,  3768,  3885,  3887,  3889,  3912,  3923,
         3953,  3965,  4097,  4121,  4151,  4184,  4216,  4218,  4301,  4326,
         4409,  4433,  4457,  4570,  4682,  4716,  4730,  4751,  4764,  4778,
         4790,  4817,  4831,  4846,  4890,  4941,  4951,  4954,  4981,  4994,
         5089,  5097,  5115,  5125,  5147,  5154,  5194,  5211,  5213,  