In [1]:
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
from razdel import sentenize
from transformers import BertTokenizer

from models.model_builder import AbsSummarizer

In [82]:
class BertData:
    """Turn a tokenized document into BERT input ids and segment ids.

    Sentences are joined with ``[SEP] [CLS]`` so every sentence is wrapped in
    its own ``[CLS] ... [SEP]`` pair, and segment ids alternate between 0 and 1
    from one sentence to the next.
    """

    def __init__(self, bert_model, lower, max_src_tokens, max_tgt_tokens):
        """
        Args:
            bert_model: name or path handed to ``BertTokenizer.from_pretrained``.
            lower: forwarded as ``do_lower_case``.
            max_src_tokens: upper bound on the length of the id sequence
                returned by ``preprocess`` (special tokens included).
            max_tgt_tokens: stored but not used by this class.
        """
        self.max_src_tokens = max_src_tokens
        self.max_tgt_tokens = max_tgt_tokens
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=lower, do_basic_tokenize=False)
        self.sep_token = '[SEP]'
        self.cls_token = '[CLS]'
        self.pad_token = '[PAD]'
        self.tgt_bos = '[unused1] '
        self.tgt_eos = ' [unused2]'
        self.tgt_sent_split = ' [unused3] '
        self.sep_vid = self.tokenizer.vocab[self.sep_token]
        self.cls_vid = self.tokenizer.vocab[self.cls_token]
        self.pad_vid = self.tokenizer.vocab[self.pad_token]

    def preprocess(self, src, tgt=None):
        """Encode a document given as a list of sentences (each a list of words).

        Args:
            src: list of sentences, each sentence a list of word strings.
            tgt: ignored; kept so existing call sites keep working.

        Returns:
            ``(src_indices, segments_ids)``: two lists of equal length. The
            first holds vocabulary ids starting with ``[CLS]`` and ending with
            ``[SEP]``; the second holds the alternating 0/1 segment ids.
        """
        src_txt = [' '.join(s) for s in src]
        text = ' {} {} '.format(self.sep_token, self.cls_token).join(src_txt)

        # Reserve two slots for the leading [CLS] and trailing [SEP] so the
        # final sequence never exceeds max_src_tokens. Truncating first and
        # adding the two tokens afterwards produced max_src_tokens + 2 ids.
        budget = max(self.max_src_tokens - 2, 0)
        src_tokens = [self.cls_token] + self.tokenizer.tokenize(text)[:budget] + [self.sep_token]
        src_indices = self.tokenizer.convert_tokens_to_ids(src_tokens)

        # Positions of every [SEP]; -1 is a sentinel so the first span starts at 0.
        boundaries = [-1] + [i for i, t in enumerate(src_indices) if t == self.sep_vid]
        segments_ids = []
        for n, (start, end) in enumerate(zip(boundaries, boundaries[1:])):
            # Each span ends on a [SEP]; even spans get segment 0, odd spans 1.
            segments_ids += (end - start) * [n % 2]

        return src_indices, segments_ids

In [126]:
def doc2bert(text):
    """Split raw text into sentences and encode it with the global ``bert_data``.

    Each sentence found by ``sentenize`` is lower-cased and split on
    whitespace before being handed to ``bert_data.preprocess``.

    Returns:
        dict with ``"src"`` (token ids) and ``"segs"`` (segment ids).
    """
    sentences = []
    for sentence in sentenize(text):
        sentences.append(sentence.text.lower().split())

    token_ids, segment_ids = bert_data.preprocess(sentences, '')
    return dict(src=token_ids, segs=segment_ids)

def doc2vec(text, model, mode='MeanSum'):
    """Embed a document with the BERT encoder of ``model``.

    Args:
        text: raw document text; it is encoded through ``doc2bert``.
        model: object exposing a ``bert(src, segs, mask)`` callable.
        mode: ``'FirstCLS'`` returns the encoder output at position 0 of the
            single batch element; ``'MeanSum'`` returns the mean of that
            element's outputs over its first axis.

    Returns:
        A tensor taken from ``output[0]`` as described for ``mode``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values. This is
            checked before the forward pass so a typo fails immediately.
    """
    if mode not in ('FirstCLS', 'MeanSum'):
        raise ValueError('Wrong mode')

    doc_bert = doc2bert(text)

    # Wrap in a list to build a batch of one document.
    src = torch.tensor([doc_bert['src']])
    segs = torch.tensor([doc_bert['segs']])
    # Id 0 is treated as padding (the model repr shows padding_idx=0).
    mask_src = ~(src == 0)

    # Inference only: no autograd graph is needed for extracting embeddings.
    with torch.no_grad():
        output = model.bert(src, segs, mask_src)

    if mode == 'FirstCLS':
        return output[0][0]
    return output[0].mean(0)

In [127]:
def _keep_storage(storage, location):
    """map_location hook: return each storage unchanged, ignoring ``location``."""
    return storage


# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load('../model_step_40000_telegram.pt', map_location=_keep_storage)

In [128]:
# Settings object passed to AbsSummarizer. A SimpleNamespace is a plain
# attribute container; the previous `lambda a: b` only worked because function
# objects happen to accept arbitrary attributes.
# NOTE(review): model_path is a machine-specific absolute path.
args = SimpleNamespace(
    model_path='/Users/leshanbog/Downloads/rubert_cased_L-12_H-768_A-12_pt',
    large=False,
    temp_dir='temp',
    finetune_bert=False,
    encoder='bert',
    max_pos=256,
    dec_layers=6,
    share_emb=False,
    dec_hidden_size=768,
    dec_heads=8,
    dec_ff_size=2048,
    dec_dropout=0.2,
    use_bert_emb=False,
)

# Positional arguments are (bert_model, lower, max_src_tokens, max_tgt_tokens).
# NOTE(review): lower=True is used with a directory named "cased" - confirm
# this matches how the checkpoint was trained.
bert_data = BertData(args.model_path, True, 512, 128)

In [129]:
# Build the summarizer from the settings object and the loaded checkpoint.
# 'cpu' is the second positional argument - presumably the target device;
# confirm against AbsSummarizer's signature.
model = AbsSummarizer(args, 'cpu', checkpoint)
# Switch to evaluation mode. As the last expression of the cell, the value
# returned by eval() is what the notebook displays below (the module repr).
model.eval()

AbsSummarizer(
  (bert): Bert(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(119547, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
              

### Clustering

In [130]:
import tqdm
from sklearn.cluster import AgglomerativeClustering
from collections import defaultdict
import csv
import json

In [131]:
def read_markup(file_name):
    """Lazily read a tab-separated file with a header row.

    Args:
        file_name: path to the TSV file.

    Yields:
        One dict per data row, mapping header names to the row's string values.
        An empty file yields nothing.

    Raises:
        AssertionError: if a data row has a different number of fields than
            the header.
    """
    # newline="" lets the csv module do its own line handling, so line breaks
    # inside quoted fields are preserved instead of being rewritten.
    with open(file_name, "r", newline="") as r:
        reader = csv.reader(r, delimiter='\t', quotechar='"')
        # A bare next() on an empty file would surface as RuntimeError from
        # inside the generator, so use a default and stop cleanly instead.
        header = next(reader, None)
        if header is None:
            return
        for row in reader:
            assert len(header) == len(row), \
                "row has {} fields, header has {}".format(len(row), len(header))
            yield dict(zip(header, row))
            

In [132]:
from sklearn.metrics import classification_report

def calc_metrics(gold_markup, url2label, url2record):
    """Print a classification report for pairwise "same cluster" predictions.

    A pair is predicted positive when both urls received the same cluster
    label. Pairs with a url missing from ``url2label`` or ``url2record`` are
    skipped and counted; the count is printed before the report.

    Args:
        gold_markup: dict mapping ``(first_url, second_url)`` to the gold
            target. It is not modified.
        url2label: dict mapping url to its cluster label.
        url2record: dict of known urls; used only to check membership.

    Returns:
        None. Output goes to stdout.
    """
    # Filter into a new dict rather than popping from the caller's dict.
    usable = {}
    not_found_count = 0
    for (first_url, second_url), target in gold_markup.items():
        found = all(
            url in url2label and url in url2record
            for url in (first_url, second_url)
        )
        if found:
            usable[(first_url, second_url)] = target
        else:
            not_found_count += 1
    print("Not found {} pairs from markup".format(not_found_count))

    targets = list(usable.values())
    predictions = [
        int(url2label[first_url] == url2label[second_url])
        for first_url, second_url in usable
    ]
    print(classification_report(targets, predictions))

In [133]:
# Input locations: the TSV read by read_markup() and the JSONL file with one
# record per line.
# NOTE(review): both are absolute paths on one machine; adjust before running
# elsewhere.
markup_path = '/Users/leshanbog/Documents/dataset/news_clustering/ru_pairs_raw_markup.tsv'
full_jsonl_path = '/Users/leshanbog/Documents/dataset/news_clustering/ru_tg_1101_0510.jsonl'

In [134]:
# (first_url, second_url) -> 1 when the pair's quality is "OK", else 0.
# A plain dict is used because the values are ints: with defaultdict(dict),
# looking up a missing pair would silently insert an empty dict.
markup = {}
for record in read_markup(markup_path):
    pair = (record["INPUT:first_url"], record["INPUT:second_url"])
    markup[pair] = int(record["OUTPUT:quality"] == "OK")


In [135]:
# Index every JSONL record by its url, and map each file name to its url.
# If a url or file name repeats, the later line wins.
url2record = {}
filename2url = {}
with open(full_jsonl_path, "r") as jsonl_file:
    for raw_line in jsonl_file:
        record = json.loads(raw_line)
        record_url = record["url"]
        url2record[record_url] = record
        filename2url[record["file_name"]] = record_url

In [142]:
# One 768-wide row per document (768 matches the layer width in the model repr).
n_docs = len(url2record)
embeds = np.zeros((n_docs, 768))

# Rows are filled in url2record's iteration order, so row k belongs to the
# k-th url of that dict.
for row, (url, record) in tqdm.tqdm(enumerate(url2record.items()), total=n_docs):
    title_and_body = " ".join((record["title"], record["text"]))
    # Lower-case and replace non-breaking spaces with plain spaces.
    normalized = title_and_body.lower().replace('\xa0', ' ')
    embeds[row] = doc2vec(normalized, model, mode='MeanSum')


  0%|          | 0/487867 [00:00<?, ?it/s][A


In [None]:
# Single-linkage agglomerative clustering on cosine distance. The number of
# clusters is left open and merging stops at the distance threshold.
# NOTE: scikit-learn 1.2 renamed the `affinity` constructor argument to
# `metric`; switch the keyword if running on a newer release.
clusterer = AgglomerativeClustering(n_clusters=None,
                                    distance_threshold=0.06,
                                    linkage="single",
                                    affinity="cosine")

# fit() takes only the data (and an ignored y); the distance is configured on
# the estimator above, so passing `affinity` here raised a TypeError.
clusterer.fit(embeds)
labels = clusterer.labels_

In [None]:
# Row k of `embeds` was filled from the k-th item of url2record, so the k-th
# label belongs to the k-th url in url2record's iteration order. The previous
# lookup table (`etfidx2url`) is not defined anywhere in this notebook.
url2label = dict(zip(url2record, labels))

In [None]:
# The gold pair labels are stored in `markup`; no variable named `gold_markup`
# is defined in this notebook.
calc_metrics(markup, url2label, url2record)