In [16]:
from collections import Counter
from collections import defaultdict
from functools import reduce
from operator import itemgetter
from pathlib import Path
from typing import List, Union
from random import randint, sample
import json

from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import IterableDataset

from transformers import BertModel, BertConfig, BertTokenizer
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

In [28]:
model_path = Path('/home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2/')
# model_path = Path('/home/mishanya/models/rubert_cased_L-12_H-768_A-12_v2/')
convert_tf_checkpoint_to_pytorch(model_path / 'bert_model.ckpt.index', 
                                 model_path / 'bert_config.json',
                                 model_path / 'ptrubert.pt')
# config = BertConfig.from_pretrained(model_path / 'bert_config.json')
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
# model = BertModel.from_pretrained(str(model_path / 'ptrubert.pt'), config=config)

Building PyTorch model from configuration: {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 119547
}

Save PyTorch model to /home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2/ptrubert.pt


INFO:transformers.modeling_bert:Converting TensorFlow checkpoint from /home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2/bert_model.ckpt.index
INFO:transformers.tokenization_utils:Model name '/home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased). Assuming '/home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2' is a path or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file /home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2/added_to

In [29]:
from random import randint


class HypoDataset(IterableDataset):
    def __init__(self,
                 tokenizer: BertTokenizer,
                 corpus_path: Union[str, Path],
                 hypo_index_path: Union[str, Path],
                 train_set_path: Union[str, Path]):
        self.tokenizer = tokenizer
        self.corpus = self._read_corpus(corpus_path)
        self.hypo_index = self._read_json(hypo_index_path)
        self.train_set = self._read_json(train_set_path)
    
    @staticmethod
    def _read_json(hypo_index_path: Union[str, Path]):
        with open(hypo_index_path, encoding='utf8') as handle:
            return json.load(handle)
        
    @staticmethod
    def _read_corpus(corpus_path: Union[str, Path]):
        with open(corpus_path, encoding='utf8') as handle:
            return handle.readlines()
    
    def __iter__(self):
        while True:
            train_ind = randint(0, len(self.train_set) - 1)
            hypos, hypes, hype_hypes = self.train_set[train_ind]
            hypos_in_index = [h for h in hypos if h in self.hypo_index]
            
            if not hypos_in_index:
                print(f'Empty index for hypos: {hypos}')
                continue
            if len(hypos) != len(hypos_in_index):
                print(f'Some hypos are lost. Original: {hypos},'
                      f' In index: {hypos_in_index }')
                
            hypo = sample(hypos_in_index, 1)[0]
            sent_idx, inner_hypo_idx = sample(self.hypo_index[hypo], 1)[0]
            sent_toks = self.corpus[sent_idx].split()
            sent_subtok_idxs = []
            subtokens_sent = []
            hypo_mask = []
            for n, tok in enumerate(sent_toks):
                subtokens = self.tokenizer.tokenize(tok)
                subtokens_sent.extend(subtokens)
                subtok_idxs = self.tokenizer.convert_tokens_to_ids(subtokens)
                sent_subtok_idxs.extend(subtok_idxs)
                mask_value = float(n == inner_hypo_idx)
                hypo_mask.extend([mask_value] * len(subtok_idxs))
            
            yield sent_subtok_idxs, subtokens_sent, hypo_mask
            
    def embed_hypernym(self, hypernyms: List[str]):
        raise NotImplementedError  


In [30]:
data_path = Path('/home/arcady/data/hypernym/')
# data_path = Path('/home/mishanya/data/hypernym/')
ds = HypoDataset(tokenizer,
                 data_path / 'tst_corpus.txt',
                 data_path / 'tst_index.json',
                 data_path / 'tst_train.json')



In [31]:
sti, st, m = next(iter(ds))

Some hypos are lost. Original: ['кот', 'кошара'], In index: ['кот']


In [32]:
st, m

(['А',
  'это',
  'странный',
  'пример',
  'в',
  'котором',
  'кот',
  '##ейка',
  'не',
  'участвует',
  '.'],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])

In [33]:
import torch.nn as nn
from typing import Union
from pathlib import Path
import torch


class BertHypoHype(nn.Module):
    def __init__(self,
                 model_path: Union[str, Path],
                 config_path: Union[str, Path]):
        super(BertHypoHype, self).__init__()
        config = BertConfig.from_pretrained(str(config_path))
        self.bert = BertModel.from_pretrained(str(model_path), config=config)
        
    def forward(self, indices_batch: torch.LongTensor, hypo_mask: torch.Tensor):
        h = self.bert(indices_batch)[0]
        m = torch.tensor(hypo_mask).unsqueeze(2)
        return torch.mean(h * m, 1)
        

In [34]:
from torch.utils.data import DataLoader
from pprint import pprint

def batch_collate(batch):
    indices, strings, mask = list(zip(*batch))
    batch_size = len(indices)
    max_len = max(len(idx) for idx in indices)
    padded_indices = torch.zeros(batch_size, max_len, dtype=torch.long)
    padded_mask = torch.zeros(batch_size, max_len, dtype=torch.float)
    for n, (sent_idxs, sent_mask) in enumerate(zip(indices, mask)):
        up_to = len(sent_idxs)
        sent_idxs = torch.tensor(sent_idxs)
        sent_mask = torch.tensor(sent_mask)
        padded_indices[n, :up_to] = sent_idxs
        padded_mask[n, :up_to] = sent_mask
    return padded_indices, padded_mask
    

dl = DataLoader(ds, batch_size=2, collate_fn=batch_collate)

In [35]:
idxs, mask = next(iter(dl))

Some hypos are lost. Original: ['кот', 'кошара'], In index: ['кот']
Some hypos are lost. Original: ['кот', 'кошара'], In index: ['кот']


In [36]:
model = BertHypoHype(model_path / 'ptrubert.pt',
                     model_path / 'bert_config.json')

INFO:transformers.configuration_utils:loading configuration file /home/arcady/data/models/rubert_cased_L-12_H-768_A-12_v2/bert_config.json
INFO:transformers.configuration_utils:Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 119547
}

INFO:transformers.modeling_utils:loading weights file /home/arcady/data/mo

In [37]:
model(idxs, mask)



tensor([[-0.0145,  0.0589, -0.0246,  ..., -0.0310, -0.0322, -0.0069],
        [-0.0131,  0.0444,  0.0253,  ..., -0.0762, -0.1393, -0.0675]],
       grad_fn=<MeanBackward1>)

In [38]:
list(model.state_dict())
# torch.optim.Adam

['bert.embeddings.word_embeddings.weight',
 'bert.embeddings.position_embeddings.weight',
 'bert.embeddings.token_type_embeddings.weight',
 'bert.embeddings.LayerNorm.weight',
 'bert.embeddings.LayerNorm.bias',
 'bert.encoder.layer.0.attention.self.query.weight',
 'bert.encoder.layer.0.attention.self.query.bias',
 'bert.encoder.layer.0.attention.self.key.weight',
 'bert.encoder.layer.0.attention.self.key.bias',
 'bert.encoder.layer.0.attention.self.value.weight',
 'bert.encoder.layer.0.attention.self.value.bias',
 'bert.encoder.layer.0.attention.output.dense.weight',
 'bert.encoder.layer.0.attention.output.dense.bias',
 'bert.encoder.layer.0.attention.output.LayerNorm.weight',
 'bert.encoder.layer.0.attention.output.LayerNorm.bias',
 'bert.encoder.layer.0.intermediate.dense.weight',
 'bert.encoder.layer.0.intermediate.dense.bias',
 'bert.encoder.layer.0.output.dense.weight',
 'bert.encoder.layer.0.output.dense.bias',
 'bert.encoder.layer.0.output.LayerNorm.weight',
 'bert.encoder.layer

In [39]:
torch.optim.Adam(model.bert.encoder.parameters(), lr=1e-5)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 1e-05
    weight_decay: 0
)

In [None]:
# BertTokenizerFast('vocab.txt', do_lower_case=False)

In [None]:
'/home/mishanya/data/hypernym/candidates.tsv'

tokenizer = BertTokenizer('vocab.txt', do_lower_case=False)
token_lengths = Counter()
prefix_index = defaultdict(set)
description_to_synset = {}

# with open('/home/arcady/data/hypernym/candidates.tsv') as handle:
with open('/home/mishanya/data/hypernym/candidates.tsv') as handle:
    for line in tqdm(handle):
        synset_id, description, tokens_str = line.split('\t')
        tokens = tokens_str.strip().lower().split(', ')
        subtokens = [tokenizer.tokenize(tok) for tok in tokens]
        token_lengths.update(len(st) for st in subtokens)
        description_to_synset[description] = tokens
        for st in subtokens:
            prefix_index[st[0]].add(description)

In [None]:
sorted_index = sorted(prefix_index.items(), key=lambda x: len(x[1]), reverse=True)

lens = [len(vals) for descript, vals in sorted_index]

plt.plot(lens)
plt.yscale('log')
plt.grid()

In [None]:
description_to_synset['БЫТОВКА ДЛЯ РАБОЧИХ']

In [None]:
jupytsorted_index[100]

In [None]:
tokenizer.tokenize('мвд')

In [None]:
import json

path = '/home/arcady/data/hypernym/train.json'
p = '/home/arcady/data/'
with open(path) as handle:
    data = json.load(handle)

with open(path, 'w', encoding='utf-8') as handle:
    json.dump(data, handle, indent=4, ensure_ascii=False)
    


In [None]:
from itertools import chain
from collections import Counter


Counter(len(q.split(' ')) for q in chain(*[item[0] for item in data]))