In [1]:
from pathlib import Path
from transformers import BertTokenizer
from dataset import get_hypernyms_list_from_train, HypoDataset


import os

# Root folders can be overridden through environment variables so the notebook
# runs on other machines; the defaults are the original author's local paths.
data_path = Path(os.environ.get('HYPERNYM_DATA_DIR', '/home/hdd/data/hypernym/'))
corpus_path = data_path / 'corpus.news_dataset-sample.token.txt'
hypo_index_path = data_path / 'index.full.news_dataset-sample.json'
train_set_path = data_path / 'train.cased.json'
candidates_path = data_path / 'candidates.cased.tsv'
wordnet_path = Path(os.environ.get(
    'WORDNET_DATA_DIR',
    '/home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data'))

model_path = Path(os.environ.get(
    'RUBERT_MODEL_DIR',
    '/home/hdd/models/rubert_v2/rubert_cased_L-12_H-768_A-12_v2/'))
tokenizer_vocab_path = model_path / 'vocab.txt'

In [2]:
# Cased RuBERT vocabulary, so lower-casing must stay disabled.
tokenizer = BertTokenizer(vocab_file=tokenizer_vocab_path, do_lower_case=False)

In [3]:
# train_hype_list = get_hypernyms_list_from_train(train_set_path)

# ds = HypoDataset(tokenizer,
#                  corpus_path,
#                  hypo_index_path,
#                  train_set_path,
#                  train_hype_list)

# corpus = df.corpus
# index = ds.hypo_index[hypo]

In [4]:
# sample = ds.train_set[23]
# sample

In [5]:
# hypo = sample[0][0]
# hypo

In [3]:
from corpus_indexed import CorpusIndexed

# Load the hyponym occurrence index together with the corpus it points into.
corpus_indexed = CorpusIndexed(hypo_index_path)
index, corpus = corpus_indexed.idx, corpus_indexed.corpus

Loading index from /home/hdd/data/hypernym/index.full.news_dataset-sample.json.
Loading corpus from /home/hdd/data/hypernym/corpus.news_dataset-sample.token.txt.


### Explore predictions for a specific hyponym

In [56]:
hypo = "школьный учитель"  # hyponym whose contexts and predictions are inspected below

In [57]:
# Check whether the chosen hyponym has entries in the corpus index.
hypo in index

True

In [58]:
# List every indexed occurrence of `hypo`: running number and matched word span,
# followed by the full sentence.
count = 0
for count, (corp_id, start_idx, end_idx) in enumerate(index[hypo], start=1):
    text = corpus[corp_id]
    print(count, '\t', text.split()[start_idx:end_idx])
    print(text)

1 	 ['школьного', 'учителя']
Соколовская : Размер оклада школьного учителя должен составлять 70 % от всей заработной платы Размер оклада школьного учителя должен составлять 70 % от всей заработной платы .
2 	 ['школьного', 'учителя']
Соколовская : Размер оклада школьного учителя должен составлять 70 % от всей заработной платы Размер оклада школьного учителя должен составлять 70 % от всей заработной платы .
3 	 ['школьных', 'учителей']
При обсуждении вопросов заработной платы и нагрузки на школьных учителей необходимо в первую очередь обратиться к федеральному закону о труде , в котором говорится , что должностной оклад – это фиксированный размер оплаты труда работника за исполнение должностных обязанностей без учета компенсационных , стимулирующих и социальных выплат .
4 	 ['школьного', 'учителя']
Первыми Героями СССР были летчики М . В . Водопьянов , И . В . Доронин , Н . П . Каманин , А . В . Ляпидевский ( наш земляк родом из села Белая Глина , сын школьного учителя .
5 	 ['школьных'

In [59]:
from embedder import get_word_embeddings
from transformers import BertConfig, BertTokenizer, BertModel
import torch

In [60]:
# Load the RuBERT tokenizer and encoder; the input word-embedding matrix is kept
# separately because candidate hypernym phrases are embedded with it.
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
config = BertConfig.from_pretrained(model_path / 'bert_config.json')
model = BertModel.from_pretrained(str(model_path / 'ptrubert.pt'), config=config)
emb_mat = model.embeddings.word_embeddings.weight

In [61]:
from utils import get_wordnet_synsets, enrich_with_wordnet_relations, synsets2senses

# Only noun synset files (synsets.N*) are parsed in this exploratory section;
# synset relations are not loaded here.
wordnet_synsets = get_wordnet_synsets(wordnet_path.glob('synsets.N*'))

Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synsets.N.xml.


In [62]:
# 'synset': one vector per noun synset = mean of its senses' embeddings.
# 'sense':  one vector per distinct lower-cased noun synset ruthes_name.
LEVEL = 'sense'

if LEVEL == 'synset':
    hype_embs, hype_list = [], []
    for s_id, synset in wordnet_synsets.items():
        # ids not ending in 'N' are treated as non-noun synsets and skipped
        if s_id[-1] != 'N':
            print(f'skipping {s_id}')
            continue
        hype_list.append(synset['ruthes_name'].lower())
        senses = [sense['content'].lower() for sense in synset['senses']]
        hype_embs.append(get_word_embeddings(senses, emb_mat, tokenizer=tokenizer).mean(dim=0))
    hype_embs = torch.stack(hype_embs, dim=0)
elif LEVEL == 'sense':
    # embed each noun synset's name (lower-cased, de-duplicated, sorted)
    hype_list = sorted(set(synset['ruthes_name'].lower()
                           for s_id, synset in wordnet_synsets.items()
                           if s_id[-1] == 'N'))
    hype_embs = get_word_embeddings(hype_list, emb_mat, tokenizer=tokenizer)
else:
    # fail loudly instead of silently reusing hype_embs from an earlier run
    raise ValueError(f"Unknown LEVEL {LEVEL!r}; expected 'synset' or 'sense'")
print(f"Hypernym embeddings are of shape {hype_embs.shape}")

Hypernym embeddings are of shape torch.Size([29296, 768])


In [63]:
NUM_CONTEXTS = 1

# Encode up to NUM_CONTEXTS corpus sentences mentioning `hypo` as BERT inputs,
# with a float mask selecting the sub-tokens of the hyponym span.
token_idxs, hypo_masks = [], []
max_len = 0

for i, (corp_id, start_idx, end_idx) in enumerate(index[hypo]):
    text = corpus[corp_id]

    subtokens, mask = ['[CLS]'], [0.0]
    for n, token in enumerate(text.split()):
        current_subtokens = tokenizer.tokenize(token)
        subtokens.extend(current_subtokens)
        # mark every word piece of the words inside [start_idx, end_idx)
        mask_val = 1.0 if start_idx <= n < end_idx else 0.0
        mask.extend([mask_val] * len(current_subtokens))
    subtokens.append('[SEP]')
    print(sum(mask))  # number of sub-tokens covering the hyponym
    mask.append(0.0)

    token_idxs.append(tokenizer.convert_tokens_to_ids(subtokens))
    hypo_masks.append(mask)
    max_len = max(max_len, len(subtokens))
    if i >= NUM_CONTEXTS - 1:
        break

# right-pad every row to max_len (token ids padded with 0)
attn_masks = [[1.0] * len(idxs) + [0.0] * (max_len - len(idxs)) for idxs in token_idxs]
token_idxs = [idxs + [0] * (max_len - len(idxs)) for idxs in token_idxs]
hypo_masks = [mask + [0.0] * (max_len - len(mask)) for mask in hypo_masks]

# hypo_mask_t: [batch_size, seq_len]
hypo_mask_t = torch.tensor(hypo_masks)
# an empty span would make the mean below divide by zero (NaN)
assert (hypo_mask_t.sum(dim=1) > 0).all(), 'hyponym span produced no sub-tokens'

# Inference only, so no autograd graph is needed.
with torch.no_grad():
    # h: [batch_size, seq_len, hidden_size]; [0] works for tuple and ModelOutput returns
    h = model(torch.tensor(token_idxs), attention_mask=torch.tensor(attn_masks))[0]
print('h norm =', h.norm(dim=2))
# r: mean hidden state over the hyponym sub-tokens -> [batch_size, hidden_size]
r = (h * hypo_mask_t.unsqueeze(2)).sum(dim=1) / hypo_mask_t.sum(dim=1, keepdim=True)

2.0
h norm = tensor([[12.6992, 16.4481, 14.2615, 13.5348, 17.9549, 18.5590, 18.9328, 15.3944,
         17.7733, 18.5994, 18.4554, 18.4417, 18.6928, 18.7628, 19.0013, 17.8019,
         18.4879, 17.7364, 18.3655, 19.0120, 14.6933, 16.9564, 18.1769, 18.5649,
         18.6077, 18.7214, 18.2263, 18.9732, 17.5855, 18.7139, 11.4048, 11.7314]],
       grad_fn=<NormBackward3>)


In [64]:
METRIC = 'cosine'

# scores: [batch_size, vocab_size]
if METRIC == 'cosine':
    # hype_embs_norm: [vocab_size, hidden_size]
    hype_embs_norm = hype_embs / hype_embs.norm(dim=1, keepdim=True)
    # keepdim makes the context norms [batch_size, 1] so each row is divided by
    # its own norm; a bare [batch_size] vector would broadcast over the vocab
    # axis and fail (or be wrong) whenever batch_size > 1
    scores = r @ hype_embs_norm.T / r.norm(dim=1, keepdim=True)
elif METRIC == 'product':
    scores = r @ hype_embs.T
else:
    raise ValueError(f"Unknown METRIC {METRIC!r}; expected 'cosine' or 'product'")

scores = torch.log_softmax(scores, dim=1)
# scores: [vocab_size] -- log-probabilities averaged over contexts
scores = scores.mean(dim=0)

In [65]:
vals, indices = scores.topk(50)  # 50 best-scoring candidate hypernyms

In [66]:
# shape and overall norm of the pooled hyponym representation
r.size(), r.norm()

(torch.Size([1, 768]), tensor(15.1952, grad_fn=<NormBackward0>))

In [67]:
# Inspect one candidate row; the index was picked during debugging
# (TODO: document which hypernym it corresponds to).
idx = 6154
torch.norm(hype_embs[idx])

tensor(0.4623, grad_fn=<NormBackward0>)

In [68]:
torch.sum(r * hype_embs[idx])  # raw dot product between r and the inspected row

tensor(-0.5005, grad_fn=<SumBackward0>)

In [69]:
any(name == 'учитель' for name in hype_list)  # is the bare word a candidate?

False

In [70]:
# Top-ranked candidates: name, averaged log-probability, row index.
top_hits = zip(vals.tolist(), indices.tolist())
for score, i in top_hits:
    print(f'{hype_list[i]:25} {score:.3} {i:10}')

школьный учитель          -10.1      28423
школьник                  -10.2      28414
школьный урок             -10.2      28422
преодолеть                -10.2      19067
организованный преступник -10.2      15550
преступник                -10.2      19120
школьный предмет          -10.2      28421
школьное обучение         -10.2      28418
учитель физкультуры       -10.2      26714
карат                     -10.2       8950
преподаватель             -10.2      19073
алфавит                   -10.2        547
вельск                    -10.2       2657
эльф                      -10.2      28909
ушица                     -10.2      26731
школьное изложение        -10.2      28417
онега                     -10.2      15317
ядро процессора           -10.2      29203
покрыть поверхность       -10.2      18109
ассамблея                 -10.2        967
нагорье                   -10.2      13299
начальные классы          -10.2      13753
истина, правда            -10.2       8575
школьный за

### Compute BERT's hypernym-prediction metrics on the train set

In [5]:
# Build the training dataset; it also exposes the corpus and hyponym index.
train_hype_list = get_hypernyms_list_from_train(train_set_path)
ds = HypoDataset(tokenizer, corpus_path, hypo_index_path, train_set_path, train_hype_list)
corpus, index = ds.corpus, ds.hypo_index

0.40 hyponyms are not found in the index


In [14]:
from embedder import get_word_embeddings
from transformers import BertConfig, BertTokenizer, BertModel
import torch


# Load the RuBERT encoder and tokenizer for the metrics section.
config = BertConfig.from_pretrained(model_path / 'bert_config.json')
weights_path = str(model_path / 'ptrubert.pt')
model = BertModel.from_pretrained(weights_path, config=config)
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
# input word-piece embedding matrix, used to embed candidate hypernym phrases
emb_mat = model.embeddings.word_embeddings.weight

In [18]:
from utils import get_wordnet_synsets, enrich_with_wordnet_relations

# Parse synsets of every part of speech, then attach synset relations.
# The return value of enrich_with_wordnet_relations is unused, so it presumably
# updates wordnet_synsets in place (later cells read a 'hypernyms' key).
synset_files = wordnet_path.glob('synsets.*')
wordnet_synsets = get_wordnet_synsets(synset_files)
enrich_with_wordnet_relations(wordnet_synsets, wordnet_path.glob('synset_relations.*'))

Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synsets.A.xml.
Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synsets.N.xml.
Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synsets.V.xml.
Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synset_relations.N.xml.
Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synset_relations.A.xml.
Parsing /home/vimary/code-projects/dialog-2020-challenge/taxonomy-enrichment/data/synset_relations.V.xml.


In [25]:
import collections

# 'synset': one vector per noun synset = mean of its senses' embeddings.
# 'sense':  one vector per noun synset ruthes_name (lower-cased).
LEVEL = 'sense'

# Maps each lower-cased candidate name in hype_list to the ids of the synsets
# carrying that name (despite the variable name, keys are hypernym candidates).
hypo2synset = collections.defaultdict(list)
if LEVEL == 'synset':
    hype_embs, hype_list = [], []
    for s_id, synset in wordnet_synsets.items():
        # all parts of speech are loaded above; ids not ending in 'N' are skipped
        if s_id[-1] != 'N':
            print(f'skipping {s_id}')
            continue
        hype_list.append(synset['ruthes_name'].lower())
        hypo2synset[hype_list[-1]].append(s_id)
        senses = [sense['content'].lower() for sense in synset['senses']]
        hype_embs.append(get_word_embeddings(senses, emb_mat, tokenizer=tokenizer).mean(dim=0))
    hype_embs = torch.stack(hype_embs, dim=0)
elif LEVEL == 'sense':
    hype_list = []
    for s_id, synset in wordnet_synsets.items():
        if s_id[-1] != 'N':
            continue
        hype_list.append(synset['ruthes_name'].lower())
        hypo2synset[hype_list[-1]].append(s_id)
    hype_embs = get_word_embeddings(hype_list, emb_mat, tokenizer=tokenizer)
else:
    # fail loudly instead of silently reusing hype_embs from an earlier run
    raise ValueError(f"Unknown LEVEL {LEVEL!r}; expected 'synset' or 'sense'")
print(f"Hypernym embeddings are of shape {hype_embs.shape}")

Hypernym embeddings are of shape torch.Size([29296, 768])


In [64]:
def encode_hypo(hypo: str, num_contexts: int = 1, *, max_words: int = 250,
                max_subtokens: int = 512, hypo_index=None, corpus_texts=None,
                bert_tokenizer=None):
    """Encode up to `num_contexts` corpus sentences mentioning `hypo` for BERT.

    Each kept sentence is tokenized word by word, wrapped as [CLS] ... [SEP] and
    converted to ids. A parallel float mask marks the sub-tokens of the words in
    the hyponym span [start_idx, end_idx).

    Args:
        hypo: key into the hyponym index; its entries are
            (corpus_id, start_word_idx, end_word_idx) triples.
        num_contexts: maximum number of sentences to return. Only sentences
            that are actually kept count towards it.
        max_words: sentences with more whitespace-separated words are skipped.
        max_subtokens: sentences whose encoding (including [CLS]/[SEP]) is
            longer are skipped. 512 is the usual BERT position limit; TODO
            confirm against config.max_position_embeddings.
        hypo_index, corpus_texts, bert_tokenizer: optional overrides for the
            notebook globals `index`, `corpus` and `tokenizer`.

    Returns:
        (token_idxs, hypo_masks, attn_masks): lists of equal-length rows,
        right-padded with 0 / 0.0 / 0.0 to the longest kept sentence. All three
        are empty when no sentence was kept.

    Raises:
        Whatever `hypo_index[hypo]` raises for an unknown hyponym (KeyError
        for a dict).
    """
    if hypo_index is None:
        hypo_index = index
    if corpus_texts is None:
        corpus_texts = corpus
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer

    token_idxs, hypo_masks = [], []
    max_len = 0

    for corp_id, start_idx, end_idx in hypo_index[hypo]:
        # count kept sentences, not enumerated ones, so skipped ones don't eat the budget
        if len(token_idxs) >= num_contexts:
            break
        words = corpus_texts[corp_id].split()
        if len(words) > max_words:
            continue

        subtokens, mask = ['[CLS]'], [0.0]
        for n, word in enumerate(words):
            word_subtokens = bert_tokenizer.tokenize(word)
            subtokens.extend(word_subtokens)
            # every word piece of a word inside the span is marked
            mask_val = 1.0 if start_idx <= n < end_idx else 0.0
            mask.extend([mask_val] * len(word_subtokens))
        subtokens.append('[SEP]')
        mask.append(0.0)
        # word count alone does not bound the word-piece length
        if len(subtokens) > max_subtokens:
            continue

        token_idxs.append(bert_tokenizer.convert_tokens_to_ids(subtokens))
        hypo_masks.append(mask)
        max_len = max(max_len, len(subtokens))

    # pad to max_len
    attn_masks = [[1.0] * len(ids) + [0.0] * (max_len - len(ids)) for ids in token_idxs]
    token_idxs = [ids + [0] * (max_len - len(ids)) for ids in token_idxs]
    hypo_masks = [m + [0.0] * (max_len - len(m)) for m in hypo_masks]
    return token_idxs, hypo_masks, attn_masks

encode_hypo('эпилепсия', num_contexts=1)[0]  # padded token-id rows for one context

[[101,
  5405,
  9111,
  29924,
  22622,
  40623,
  128,
  1997,
  54560,
  851,
  34005,
  98701,
  2790,
  108930,
  132,
  102]]

In [None]:
import itertools


METRIC = 'cosine'
NUM_CONTEXTS = 1


def _rate(flags):
    """Fraction of truthy entries in `flags`; NaN for an empty list."""
    return sum(flags) / len(flags) if flags else float('nan')


# The candidate embeddings do not change inside the loop, so normalize them once.
if METRIC == 'cosine':
    with torch.no_grad():
        # hype_embs_norm: [vocab_size, hidden_size]
        hype_embs_norm = hype_embs / hype_embs.norm(dim=1, keepdim=True)
elif METRIC != 'product':
    raise ValueError(f"Unknown METRIC {METRIC!r}; expected 'cosine' or 'product'")

set_matches, set_matches_ext = [], []
n_failed = 0
for i, (hypo, hypes, hype_hypes) in enumerate(ds.dataset):
    if i % 100 == 1:
        print(_rate(set_matches), '\t', _rate(set_matches_ext))
    try:
        token_idxs, hypo_masks, attn_masks = encode_hypo(hypo, num_contexts=NUM_CONTEXTS)
        if not token_idxs:
            raise ValueError(f'no usable context for {hypo!r}')
        hypernyms = list(itertools.chain(*(hypes + hype_hypes)))

        with torch.no_grad():
            # h: [batch_size, seq_len, hidden_size]; [0] works for tuple and ModelOutput
            h = model(torch.tensor(token_idxs), attention_mask=torch.tensor(attn_masks))[0]
            # hypo_mask_t: [batch_size, seq_len]
            hypo_mask_t = torch.tensor(hypo_masks)
            # r: mean hidden state over the hyponym sub-tokens -> [batch_size, hidden_size]
            r = (h * hypo_mask_t.unsqueeze(2)).sum(dim=1) / hypo_mask_t.sum(dim=1, keepdim=True)

            # scores: [batch_size, vocab_size]
            if METRIC == 'cosine':
                # keepdim: divide each row by its own context norm
                scores = r @ hype_embs_norm.T / r.norm(dim=1, keepdim=True)
            else:
                scores = r @ hype_embs.T
            scores = torch.log_softmax(scores, dim=1)
            # scores: [vocab_size] -- log-probabilities averaged over contexts
            scores = scores.mean(dim=0).detach().cpu().numpy()

        # top-10 candidate names
        preds = sorted(zip(hype_list, scores), key=lambda x: x[1], reverse=True)[:10]
        # senses of the WordNet hypernyms of every synset behind a predicted name
        pred_hypers = [(sense['content'].lower(), sc) for name, sc in preds
                       for s_id in hypo2synset[name]
                       for h_s_id in wordnet_synsets[s_id].get('hypernyms', [])
                       for sense in wordnet_synsets[h_s_id['id']]['senses']]
        gold = set(hypernyms)
        set_matches.append(bool({p for p, _ in preds} & gold))
        set_matches_ext.append(bool({p for p, _ in preds + pred_hypers} & gold))
    except Exception as err:
        # Best effort: failing hyponyms (presumably mostly those missing from the
        # index, see the dataset-loading output) are counted and skipped.
        n_failed += 1
        print('warning: ', err)

print(_rate(set_matches))
print(_rate(set_matches_ext))
print(f'{n_failed} hyponyms skipped because of errors')

0.0 	 1.0
0.009900990099009901 	 0.31683168316831684
0.03980099502487562 	 0.40298507462686567
0.03986710963455149 	 0.38870431893687707
0.0399002493765586 	 0.40399002493765584
0.03992015968063872 	 0.4471057884231537
0.048252911813643926 	 0.4425956738768719
0.04992867332382311 	 0.44935805991440797
0.052434456928838954 	 0.4606741573033708
0.05549389567147614 	 0.46503884572697
0.057942057942057944 	 0.4675324675324675
0.055404178019981834 	 0.4641235240690282
0.05661948376353039 	 0.4612822647793505
0.05534204458109147 	 0.46887009992313605
0.055674518201284794 	 0.4725196288365453
0.05263157894736842 	 0.4776815456362425
0.05121798875702686 	 0.48282323547782635
0.052322163433274546 	 0.5008818342151675
0.053303720155469185 	 0.5047196002220988
0.052603892688058915 	 0.5023671751709626
0.050474762618690654 	 0.5077461269365318


In [66]:
# Final match rates, guarded against empty lists (e.g. if every hyponym failed).
# NOTE(review): the saved output below reports 2.58 for the extended rate, which
# a mean of booleans cannot exceed; it presumably comes from an older loop that
# stored counts. Re-run to refresh.
for flags in (set_matches, set_matches_ext):
    print(sum(flags) / len(flags) if flags else float('nan'))

0.05763473053892216
2.5773453093812377


In [None]:
import matplotlib.pyplot as plt

# set_matches_ext holds booleans; numpy's histogram fails on boolean arrays,
# so convert to ints before plotting.
fig, ax = plt.subplots()
ax.hist([int(m) for m in set_matches_ext], bins=10)
ax.set(title='Extended hypernym match per hyponym',
       xlabel='match (0 = miss, 1 = hit)',
       ylabel='number of hyponyms');

In [None]:
# last hyponym processed by the evaluation loop, with its predictions
(hypo, preds, pred_hypers)