# Self-Supervised Named Entity Recognition (SSLNER)

Project based on [this repository](https://github.com/ajitrajasekharan/unsupervised_NER) and [this medium blog post](https://towardsdatascience.com/ssl-could-avoid-supervised-learning-fd049a27cd1b). RoBERTa is used instead of BERT because of the improved robustness of the model from dynamic masking instead of static masking.

In [372]:
from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
import json
import urllib.request
from pprint import pprint
import kmeans_pytorch

In [373]:
model_name = 'bert-base-cased'

# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# model = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained("bert-base-cased", output_hidden_states=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [374]:
response = urllib.request.urlopen("https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json")
thing = response.read()

pprint(thing)

(b'{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, ".": 4, "\xc4\xa0the": 5, ",'
 b'": 6, "\xc4\xa0to": 7, "\xc4\xa0and": 8, "\xc4\xa0of": 9, "\xc4\xa0a": 10, "'
 b'\xc4\xa0in": 11, "-": 12, "\xc4\xa0for": 13, "\xc4\xa0that": 14, "\xc4\xa0on'
 b'": 15, "\xc4\xa0is": 16, "\xc3\xa2\xc4\xa2": 17, "\'s": 18, "\xc4\xa0with'
 b'": 19, "\xc4\xa0The": 20, "\xc4\xa0was": 21, "\xc4\xa0\\"": 22, "\xc4\xa0'
 b'at": 23, "\xc4\xa0it": 24, "\xc4\xa0as": 25, "\xc4\xa0said": 26, "\xc4\xbb":'
 b' 27, "\xc4\xa0be": 28, "s": 29, "\xc4\xa0by": 30, "\xc4\xa0from": 31, '
 b'"\xc4\xa0are": 32, "\xc4\xa0have": 33, "\xc4\xa0has": 34, ":": 35, "\xc4\xa0'
 b'(": 36, "\xc4\xa0he": 37, "\xc4\xa0I": 38, "\xc4\xa0his": 39, "\xc4\xa0will"'
 b': 40, "\xc4\xa0an": 41, "\xc4\xa0this": 42, ")": 43, "\xc4\xa0'
 b'\xc3\xa2\xc4\xa2": 44, "\xc4\xa0not": 45, "\xc4\xbf": 46, "\xc4\xa0you": 4'
 b'7, "\xc4\xbe": 48, "\xc4\xa0their": 49, "\xc4\xa0or": 50, "\xc4\xa0they": 51'
 b', "\xc4\xa0we": 52, "\xc4\xa0but": 53, "\xc4\xa0who": 5

In [375]:
response = urllib.request.urlopen("https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json")
response = urllib.request.urlopen(tokenizer.pretrained_vocab_files_map['vocab_file'][model_name])
try:
    vocab = dict(zip((tokens := response.read().decode('utf-8').splitlines()), tokenizer.convert_tokens_to_ids(tokens)))
except Exception as e:
    print(e)

pprint(vocab)

{'!': 106,
 '"': 107,
 '#': 108,
 '##!': 28125,
 '##"': 28126,
 '###': 28127,
 '##$': 28128,
 '##%': 28129,
 '##&': 28130,
 "##'": 28131,
 '##(': 28132,
 '##)': 28133,
 '##*': 28134,
 '##+': 28135,
 '##,': 28136,
 '##-': 28137,
 '##.': 28138,
 '##/': 28139,
 '##0': 1568,
 '##00': 7629,
 '##01': 24400,
 '##0s': 13031,
 '##1': 1475,
 '##10': 10424,
 '##100': 20150,
 '##11': 14541,
 '##12': 11964,
 '##13': 17668,
 '##14': 17175,
 '##15': 16337,
 '##16': 16229,
 '##17': 16770,
 '##18': 15292,
 '##19': 16382,
 '##2': 1477,
 '##20': 10973,
 '##21': 18202,
 '##22': 20581,
 '##23': 22737,
 '##24': 19598,
 '##25': 17600,
 '##26': 25129,
 '##27': 24458,
 '##28': 24606,
 '##29': 26752,
 '##3': 1495,
 '##30': 13144,
 '##31': 22639,
 '##32': 17101,
 '##33': 23493,
 '##34': 23124,
 '##35': 19297,
 '##36': 22997,
 '##37': 26303,
 '##38': 23249,
 '##39': 24786,
 '##4': 1527,
 '##40': 12882,
 '##41': 25892,
 '##42': 23117,
 '##43': 25631,
 '##44': 25041,
 '##45': 21336,
 '##46': 23435,
 '##47': 24766,


In [376]:
entity_types = {'PERSON', 'ORG', 'LOC', 'DATE_TIME', 'WORK_OF_ART', 'MATERIAL', 'EVENT', 'MONEY', 'CARDINAL', 'ORDINAL', 'MISC'}

In [377]:
with open('ordinal.lst', 'r') as file:
    ordinals = file.read().splitlines()

ordinals = {ord for ord in ordinals if ord in vocab}

In [378]:
cardinals = {s for s in vocab if s.isdigit()}

with open('numbers.lst', 'r') as file:
    string_cardinals = file.read().splitlines()

cardinals = cardinals | {card for card in string_cardinals if card in vocab}

In [379]:
with open('person_first.lst', 'r') as file:
    first_names = file.read().splitlines()

names = {name for name in first_names if name in vocab} | {n for name in first_names if (n := name.lower()) in vocab}

Extracting contextual embeddings by summing over the lasy 4 layers

In [380]:
import torch
import torch.nn as nn

text = 'After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank.'
enc_text = tokenizer.encode_plus(text, return_tensors='pt')

model.eval()
with torch.no_grad():
    outputs = model(enc_text.input_ids, enc_text.attention_mask)
    hidden_states = outputs[2]


token_embeddings = torch.stack(hidden_states, dim=0)
token_embeddings = torch.squeeze(token_embeddings, dim=1)
token_embeddings = token_embeddings.permute(1, 0, 2)

token_vecs_sum = [torch.sum(token[-4:], dim=0) for token in token_embeddings]
print("bank vault   ", str(token_vecs_sum[6][:5]))
print("bank robber  ", str(token_vecs_sum[10][:5]))
print("river bank   ", str(token_vecs_sum[19][:5]))

cos = nn.CosineSimilarity(dim=0)
print(f'Bank Vault - Bank Robber: {cos(token_vecs_sum[6], token_vecs_sum[10])}')
print(f'Bank Vault - River Bank: {cos(token_vecs_sum[6], token_vecs_sum[19])}')
print(f'Bank Robber - River Bank: {cos(token_vecs_sum[10], token_vecs_sum[19])}')

bank vault    tensor([-4.1573,  1.7935, -2.6768,  3.3647,  1.6400])
bank robber   tensor([-2.0299, -0.9155, -4.2468,  4.2540,  2.1868])
river bank    tensor([ 3.1898, -0.4442, -1.9258, -0.4348,  2.5461])
Bank Vault - Bank Robber: 0.903982400894165
Bank Vault - River Bank: 0.7330647110939026
Bank Robber - River Bank: 0.6783679723739624


Extracting raw word embeddings

In [381]:
#print([module for module in model.modules()])
test_terms = ['stream', 'Stream']
print([t in vocab for t in test_terms])

similarity_test = tokenizer.convert_tokens_to_ids(test_terms)
print(similarity_test)

raw_word_embeddings = model.embeddings.word_embeddings.weight.detach()

cos = nn.CosineSimilarity(dim=0)
with torch.no_grad():
    tok_embeds = raw_word_embeddings[similarity_test]
    print(tok_embeds)
    print(cos(*tok_embeds))

print(len(raw_word_embeddings))

[True, True]
[5118, 22627]
tensor([[-0.0263, -0.0251, -0.0134,  ..., -0.0169, -0.0548, -0.0229],
        [-0.0144, -0.0418, -0.0290,  ...,  0.0302, -0.0696,  0.0051]])
tensor(0.5783)
28996


In [382]:
enc_tok = tokenizer.convert_tokens_to_ids('United')

dist = torch.norm(raw_word_embeddings - raw_word_embeddings[enc_tok], dim=1, p=None)
knn = dist.topk(20, largest=False)
print(knn)

print(list(zip([cos(raw_word_embeddings[enc_tok], raw_word_embeddings[k]) for k in knn.indices], tokenizer.convert_ids_to_tokens(knn.indices))))

torch.return_types.topk(
values=tensor([0.0000, 1.3220, 1.3394, 1.3422, 1.3484, 1.3511, 1.3578, 1.3616, 1.3660,
        1.3695, 1.3733, 1.3790, 1.3815, 1.3826, 1.3832, 1.3832, 1.3856, 1.3859,
        1.3880, 1.3886]),
indices=tensor([1244, 1646, 1392, 1237, 1203, 1570, 1456, 1970, 1993, 2921, 1305, 1109,
        2579, 1291, 3466, 3604, 2690, 1978, 2685, 1287]))
[(tensor(1.), 'United'), (tensor(0.2931), 'US'), (tensor(0.2538), 'City'), (tensor(0.2395), 'American'), (tensor(0.2826), 'New'), (tensor(0.2471), 'International'), (tensor(0.2403), 'North'), (tensor(0.2515), 'Central'), (tensor(0.2230), 'UK'), (tensor(0.1580), 'Johnson'), (tensor(0.2120), 'National'), (tensor(0.1072), 'The'), (tensor(0.2352), 'Northern'), (tensor(0.2352), 'World'), (tensor(0.1613), 'Harry'), (tensor(0.2574), 'FC'), (tensor(0.1618), 'Jones'), (tensor(0.1055), 'seven'), (tensor(0.2046), 'Southern'), (tensor(0.1081), 'John')]


## Filtering of RoBERTa vocabulary

Remove any punctuation strings, sub-word level tokens and single character tokens

In [383]:
import string
import re

filtered_vocab = {token: idx for token, idx in vocab.items() if not any(s in string.punctuation for s in token) and len(token) > 1}

print(f'Filtered Vocab: {len(filtered_vocab)}/{len(vocab)}')

Filtered Vocab: 21415/28996


Computing cosine similarity matrix for all filtered token embeddings

In [384]:
# Dictionary mapping incrementing indices to the vocab ids
index_idxs = {idx: k for idx, k in enumerate(filtered_vocab.values())}

filtered_vocab_embeds = raw_word_embeddings[list(filtered_vocab.values())]

# Calculate cosine similarity matrix 
with torch.no_grad():
    normalised = filtered_vocab_embeds / filtered_vocab_embeds.norm(dim=1)[:, None]
    similarity_matrix = torch.mm(normalised, normalised.transpose(0, 1))

similarity_matrix.fill_diagonal_(0)
print(similarity_matrix)

tensor([[ 0.0000,  0.4298,  0.4382,  ..., -0.1729, -0.1898, -0.2358],
        [ 0.4298,  0.0000,  0.4958,  ..., -0.1483, -0.1938, -0.1900],
        [ 0.4382,  0.4958,  0.0000,  ..., -0.1942, -0.2092, -0.1902],
        ...,
        [-0.1729, -0.1483, -0.1942,  ...,  0.0000,  0.2980,  0.2323],
        [-0.1898, -0.1938, -0.2092,  ...,  0.2980,  0.0000,  0.2660],
        [-0.2358, -0.1900, -0.1902,  ...,  0.2323,  0.2660,  0.0000]])


Implement graph based clustering. First split into graphs where similarity is above a threshold

In [385]:
threshold = 0.4

similarity_matrix[similarity_matrix < threshold] = 0
print(similarity_matrix)

tensor([[0.0000, 0.4298, 0.4382,  ..., 0.0000, 0.0000, 0.0000],
        [0.4298, 0.0000, 0.4958,  ..., 0.0000, 0.0000, 0.0000],
        [0.4382, 0.4958, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


[NEED TO FIND ALL FULLY CONNECTED SUBGRAPHS]

Then find node with highest average connection strength for each graph

In [386]:
# Calculate mean within graph, i.e. mean of all non-zero values
mean_connections = similarity_matrix.sum(dim=1) / (similarity_matrix != 0).sum(dim=1)
mean_connections = mean_connections.nan_to_num(0)

# Column vector of average weights within fully connected graphs
print(mean_connections)
print(torch.max(mean_connections))
print(torch.argmax(mean_connections))

tensor([0.4788, 0.5025, 0.4936,  ..., 0.5003, 0.4588, 0.4784])
tensor(0.8169)
tensor(5251)


In [387]:
stacked_means = mean_connections.repeat(similarity_matrix.size(0), 1)

pivots = torch.argmax(stacked_means * (similarity_matrix != 0), dim=1)
print(pivots)

tensor([   38,    11,    33,  ...,  4803,  7684, 12445])


In [388]:
print(len(pivots.unique()))
pivot_tokens = list(set(pivots))
pivot_vectors = raw_word_embeddings[[index_idxs[int(token)] for token in pivot_tokens]]

8771


In [389]:
grouped_by_pivot = {int(i): (pivots == i).nonzero().squeeze() for i in pivots.unique()}
pprint(grouped_by_pivot)

{0: tensor([  125,   202,   267,  ..., 21377, 21378, 21381]),
 1: tensor([ 3,  4,  8, 10, 11, 14, 16, 17]),
 2: tensor(33),
 3: tensor(546),
 4: tensor([26, 50, 59, 80]),
 7: tensor([ 23,  40, 961]),
 8: tensor(244),
 10: tensor([ 95, 111]),
 11: tensor([1, 9]),
 12: tensor(30),
 13: tensor(53),
 14: tensor(494),
 15: tensor(36),
 16: tensor(212),
 17: tensor(468),
 20: tensor([ 15,  31,  49, 103, 113]),
 22: tensor(595),
 23: tensor([ 5, 28]),
 25: tensor(2713),
 27: tensor(48),
 28: tensor(1130),
 32: tensor([188, 589]),
 33: tensor([   2,   29, 1633]),
 34: tensor(2955),
 35: tensor([ 19,  39, 281, 291]),
 37: tensor([ 207, 1883]),
 38: tensor([ 0, 12, 13, 18, 75]),
 39: tensor([  65, 1259]),
 40: tensor([  7,  21,  34, 377]),
 42: tensor([  45,   88, 2362]),
 46: tensor([1061, 1364]),
 49: tensor([ 20,  27, 298]),
 50: tensor(926),
 51: tensor([ 94, 122, 269]),
 52: tensor(4153),
 53: tensor(895),
 54: tensor([ 410, 1353]),
 55: tensor([  37,   93,  123,  196,  772, 1411]),
 56: te

In [394]:
original_ids_pivots = {index_idxs[int(k)]: [index_idxs[int(v_i)] for v_i in v.flatten().tolist()] for k, v in grouped_by_pivot.items()}

decoded_pivots = {tokenizer.convert_ids_to_tokens(k): tokenizer.convert_ids_to_tokens(v) for k, v in original_ids_pivots.items()}
pprint(decoded_pivots)

{'10th': ['1st',
          'fourth',
          '2nd',
          '19th',
          '3rd',
          'fifth',
          '20th',
          '4th',
          'sixth',
          '5th',
          '18th',
          '6th',
          '7th',
          '17th',
          'seventh',
          '16th',
          '8th',
          '12th',
          '13th',
          '11th',
          '9th',
          '15th',
          '14th',
          'eighth',
          '21st',
          'ninth',
          'Fourth',
          'tenth',
          'nineteenth',
          'Fifth',
          'twentieth',
          '25th',
          'Sixth',
          'Seventh',
          '50th',
          '30th',
          'twelfth',
          '24th',
          'eleventh',
          '22nd',
          '23rd',
          'eighteenth',
          '26th',
          '27th',
          '28th',
          'sixteenth',
          '29th',
          'Eighth',
          '40th',
          'thirteenth',
          '100th',
          'seventeenth',
          

In [391]:
i = 0

dist = torch.norm(filtered_vocab_embeds - pivot_vectors[i], dim=1, p=None)
knn = dist.topk(20, largest=False)
print(knn)

print()
print(tokenizer.convert_ids_to_tokens(index_idxs[int(pivot_tokens[i])]))
print(tokenizer.convert_ids_to_tokens([index_idxs[int(k)] for k in knn.indices]))

torch.return_types.topk(
values=tensor([0.0000, 1.0205, 1.2042, 1.2632, 1.2837, 1.2867, 1.3012, 1.3058, 1.3082,
        1.3172, 1.3287, 1.3305, 1.3372, 1.3387, 1.3394, 1.3451, 1.3564, 1.3584,
        1.3606, 1.3608]),
indices=tensor([13248, 20180, 13822, 11735, 11728, 18947,  5919,  9665,  8955,  7664,
        15550,  3685, 20662, 15282, 15005,  8836,  3909,  9914,  6813, 14785]))

papal
['papal', 'Papal', 'pope', 'ecclesiastical', 'theological', 'episcopal', 'imperial', 'Vatican', 'bishops', 'judicial', 'monastic', 'presidential', 'diocesan', 'theologian', 'archbishop', 'governmental', 'bishop', 'clergy', 'Giovanni', 'Habsburg']


[COME BACK AND FIND A GOOD WAY TO CALCULATE STD]

In [392]:
thing = similarity_matrix != 0
print(thing)
print((similarity_matrix[thing]).std(dim=1))
similarity_matrix.std(dim=1)

tensor([[False,  True,  True,  ..., False, False, False],
        [ True, False,  True,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

k-means clustering

In [None]:
from kmeans_pytorch import kmeans

cluster_ids, cluster_centres = kmeans(filtered_vocab_embeds, 11, distance='cosine', device=torch.device('cuda:0'))

running k-means on cuda:0..


[running kmeans]: 18it [00:06,  2.88it/s, center_shift=0.000079, iteration=18, tol=0.000100] 


In [None]:
print(0 in cluster_ids)

assignments = {int(i): (cluster_ids == int(i)).nonzero().squeeze(dim=1) for i in cluster_ids.unique()}
print(assignments)

assignment_lens = {k: len(v) for k, v in assignments.items()}
print(assignment_lens)



True
{0: tensor([    1,     2,     3,  ..., 48160, 48494, 49159]), 1: tensor([   80,   119,   143,  ..., 48343, 48511, 48672]), 2: tensor([  457,   763,   958,  ..., 49161, 49175, 49223]), 3: tensor([    0,    24,    25,  ..., 49017, 49019, 49081]), 4: tensor([   10,    71,    85,  ..., 49243, 49244, 49245]), 5: tensor([    9,    21,    23,  ..., 48516, 48815, 49184]), 6: tensor([  140,   144,   481,  ..., 49078, 49089, 49091]), 7: tensor([   12,    64,    74,  ..., 48911, 48982, 49117]), 8: tensor([  100,   135,   202,  ..., 49190, 49201, 49205]), 9: tensor([   18,    28,    30,  ..., 49066, 49068, 49071]), 10: tensor([   15,    48,    55,  ..., 48755, 48833, 48891])}
{0: 6227, 1: 3473, 2: 3623, 3: 5536, 4: 4904, 5: 3528, 6: 3338, 7: 5296, 8: 3246, 9: 3716, 10: 6359}
