# Подготовка модели

In [None]:
import os

model_filename = 'XLMR_triple_1.torch'

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
# The matching FloatTensor / LongTensor constructors are imported alongside it.
if torch.cuda.is_available():
    from torch.cuda import FloatTensor, LongTensor
    DEVICE = torch.device('cuda')
else:
    from torch import FloatTensor, LongTensor
    DEVICE = torch.device('cpu')

# Seeds NumPy's global RNG only; PyTorch's own RNG is not seeded here.
np.random.seed(42)

In [None]:
from transformers import XLMRobertaModel, XLMRobertaTokenizer, XLMRobertaConfig

# Checkpoint name shared by the config, the tokenizer and the encoder weights.
model_version = 'xlm-roberta-large'

# Ask the encoder to return the hidden states of every layer.
config = XLMRobertaConfig.from_pretrained(model_version)
config.output_hidden_states = True
tokenizer = XLMRobertaTokenizer.from_pretrained(model_version, do_lower_case=False)
xlm_model = XLMRobertaModel.from_pretrained(model_version, config=config).to(DEVICE)

# Only the last `num_unfreezed` parameter tensors stay trainable; with 0 every
# parameter of the pretrained encoder is frozen.
num_unfreezed = 0
freezed_params_num = len(list(xlm_model.parameters())) - num_unfreezed
for i, param in enumerate(xlm_model.parameters()):
    param.requires_grad = i >= freezed_params_num

# Общие классы и методы

In [None]:
import pandas as pd
import math

class Batch:
    """Container for one mini-batch of triplets.

    The three constructor arguments are stored unchanged as the attributes
    ``anchor``, ``positive`` and ``negative``.
    """

    def __init__(self, anchor, positive, negative):
        self.anchor, self.positive, self.negative = anchor, positive, negative


class BatchIterator:
    """Iterate over a table of triplets in tokenized mini-batches.

    ``data`` is indexed with ``.iloc`` and must expose the columns
    ``'anchor'``, ``'positive'`` and ``'negative'``. Every iteration yields a
    ``Batch`` whose three fields are the results of ``tokenize`` applied to
    the corresponding column values. With ``shuffle=True`` the row order is
    re-drawn (via ``np.random.shuffle``) each time iteration starts.
    """

    def __init__(self, data, batch_size=128, shuffle=True):
        self._data = data
        self._num_samples = len(data)
        self._batch_size = batch_size
        self._shuffle = shuffle
        # The last batch may be smaller, hence the ceiling.
        self._batches_count = int(math.ceil(len(data) / batch_size))

    def __len__(self):
        """Number of batches produced by one full pass."""
        return self._batches_count

    def __iter__(self):
        return self._iterate_batches()

    def _iterate_batches(self):
        """Generator behind ``__iter__``: yields one ``Batch`` per slice of rows."""
        order = np.arange(self._num_samples)
        if self._shuffle:
            np.random.shuffle(order)

        step = self._batch_size
        for offset in range(0, self._num_samples, step):
            # Slicing past the end is clipped, so the final batch is just shorter.
            rows = self._data.iloc[order[offset:offset + step]]
            anchor = tokenize(rows['anchor'].values)
            positive = tokenize(rows['positive'].values)
            negative = tokenize(rows['negative'].values)
            yield Batch(anchor, positive, negative)

# Методы для обучения и тестирования бинарных классификаторов

In [None]:
import math
from tqdm import tqdm
# Empties the `locks` list on tqdm's shared lock object.
# NOTE(review): presumably a workaround for a progress-bar locking problem in
# the notebook environment — confirm it is still needed.
tqdm.get_lock().locks = []

def do_epoch(model, criterion, criterion_emb, data_iter, optimizer=None, name=None):
    """Run one pass over ``data_iter`` and return the mean per-batch loss.

    When ``optimizer`` is given the model is put in training mode, gradients
    are tracked and the optimizer is stepped after every batch; otherwise the
    model is put in evaluation mode and gradients are disabled.

    Args:
        model: module whose ``forward`` takes the converted batch and returns
            the anchor, positive and negative outputs.
        criterion: accepted for signature compatibility; not used here.
        criterion_emb: callable computing the loss from the three outputs.
        data_iter: sized iterable of batches.
        optimizer: optimizer to step, or ``None`` for an evaluation pass.
        name: prefix shown in the progress-bar description.

    Returns:
        Sum of the batch losses divided by ``len(data_iter)``.
    """
    training = optimizer is not None
    label = name or ''
    model.train(training)

    total_batches = len(data_iter)
    running_loss = 0

    with torch.autograd.set_grad_enabled(training), tqdm(total=total_batches) as bar:
        for batch in data_iter:
            anchor_out, positive_out, negative_out = model.forward(convert_batch(batch))
            loss = criterion_emb(anchor_out, positive_out, negative_out)

            batch_loss = loss.item()
            running_loss += batch_loss

            if optimizer:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            bar.update()
            bar.set_description('{:>5s} Loss = {:.5f}'.format(label, batch_loss))

        # Once the pass is over, show the epoch average instead of the last batch.
        bar.set_description('{:>5s} Loss = {:.5f}'.format(label, running_loss / total_batches))
        bar.refresh()

    return running_loss / total_batches


def fit(model, criterion, criterion_emb, optimizer, train_iter, epochs_count=1, val_iter=None):
    """Train ``model`` for ``epochs_count`` epochs, checkpointing after each one.

    Every epoch runs a training pass over ``train_iter`` and, when ``val_iter``
    is given, an evaluation pass over it. The model's ``state_dict`` is then
    written to the module-level ``model_filename``, overwriting the previous
    checkpoint. Nothing is returned.
    """
    for epoch_number in range(1, epochs_count + 1):
        prefix = '[{} / {}] '.format(epoch_number, epochs_count)
        do_epoch(model, criterion, criterion_emb, train_iter, optimizer, prefix + 'Train:')

        if val_iter is not None:
            do_epoch(model, criterion, criterion_emb, val_iter, None, prefix + '  Val:')

        # Save the model after every epoch.
        torch.save(model.state_dict(), model_filename)

In [None]:
from sklearn.metrics import precision_recall_fscore_support

# Returns precision, recall and F-score for given model and batch generator
def evaluate_model(model, test_iter):
    """Return (precision, recall, F-score) in percent over all of ``test_iter``.

    Predictions are the arg-max over the last axis of ``model.forward`` for
    each batch; the reference labels are read from ``batch.label``. Results of
    every batch are pooled before the binary metrics are computed.

    NOTE(review): this expects batches with a ``label`` attribute and a model
    whose forward returns a single tensor, which is not what ``Batch`` and
    ``TransformerEncoder`` in this notebook provide — confirm before use.

    Raises:
        ValueError: if ``test_iter`` yields no batches.
    """
    predictions, targets = [], []
    for batch in test_iter:
        inputs = convert_batch(batch)
        logits = model.forward(inputs)
        predictions.append(logits.argmax(-1).cpu().detach().numpy())
        targets.append(batch.label.cpu().detach().numpy())

    if not predictions:
        raise ValueError('test_iter yielded no batches')

    y_pred = np.concatenate(predictions)
    y_true = np.concatenate(targets)
    # scikit-learn's signature is (y_true, y_pred); swapping them exchanges
    # precision and recall.
    precision, recall, f_score, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    return precision * 100, recall * 100, f_score * 100

# Модель: сиамская сеть + triplet loss

In [None]:
def tokenize(sents):
    """Encode sentences with the module-level ``tokenizer`` into one tensor.

    Each sentence is encoded with special tokens and ``max_length=60`` with
    padding to that length requested; the resulting id lists are stacked into
    a single tensor and moved to ``DEVICE``.
    """
    with torch.no_grad():
        encoded = [
            tokenizer.encode(sent, add_special_tokens=True, pad_to_max_length=True, max_length=60)
            for sent in sents
        ]
        return torch.tensor(encoded).to(DEVICE)


def convert_batch(batch, do_tokenize=False):
    """Return the ``(anchor, positive, negative)`` triple of ``batch``.

    With ``do_tokenize=True`` each field is passed through ``tokenize`` first;
    by default the fields are returned as they are.
    """
    fields = (batch.anchor, batch.positive, batch.negative)
    if not do_tokenize:
        return fields
    return tuple(tokenize(field) for field in fields)

In [None]:
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

# Load the triplet table and drop every row that has a missing value.
data = pd.read_csv('triplet_en-ru.csv').dropna()

In [None]:
# 70 % of the rows are used for training, the remainder is held out.
train_data, test_data = train_test_split(data, train_size=0.7)
train_iter = BatchIterator(train_data, batch_size=128)
test_iter = BatchIterator(test_data, batch_size=128)

In [None]:
class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    For each triplet the loss is ``relu(d(a, p) - d(a, n) + margin)`` where
    ``d`` is the squared L2 distance summed over dimension 1.
    """

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the mean (or, with ``size_average=False``, the sum) of the per-triplet losses."""
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        per_triplet = F.relu(pos_dist - neg_dist + self.margin)
        if size_average:
            return per_triplet.mean()
        return per_triplet.sum()

In [None]:
def mean_pool_embedding(all_layer_outputs, masks):
    """Mask-weighted mean over dimension 1 for every tensor in ``all_layer_outputs``.

    ``masks`` is broadcast over the last dimension; positions where it is zero
    / False do not contribute, and the sum is divided by the number of
    non-zero mask entries in each row. Returns a list with one pooled tensor
    per input tensor.
    """
    weights = masks.unsqueeze(2).float()
    token_counts = masks.sum(dim=1).view(-1, 1).float()
    return [(layer * weights).sum(dim=1) / token_counts for layer in all_layer_outputs]


def cls_pool_embedding(all_layer_outputs):
    """Return, for every tensor in ``all_layer_outputs``, its slice at index 0 of dimension 1."""
    return [layer[:, 0, :] for layer in all_layer_outputs]

In [None]:
import sys, os
import logging

class TransformerEncoder(nn.Module):
    """Siamese sentence encoder on top of a pretrained transformer.

    A sentence embedding is the mask-weighted mean of one hidden layer of
    ``pretrained``. The LSTM / linear / dropout head is created (so that
    checkpoints keep the same parameters) but is only used if the commented
    lines in ``forward_`` are enabled.

    Args:
        pretrained: transformer called as ``pretrained(input_ids,
            attention_mask=...)`` and returning a 3-tuple whose last element
            is the sequence of per-layer hidden states.
        emb_dim: input size of the LSTM head.
        lstm_size: hidden size of each LSTM direction.
        out_emb_size: output size of the linear projection.
        pool_layer: index of the hidden layer that is mean-pooled
            (default 16, the value previously hard-coded).
    """

    def __init__(self, pretrained, emb_dim=1024, lstm_size=128, out_emb_size=1024, pool_layer=16):
        super().__init__()

        self.pretrained = pretrained
        self.pool_layer = pool_layer
        self.lstm = nn.LSTM(input_size=emb_dim, hidden_size=lstm_size, batch_first=True, bidirectional=True, num_layers=1)
        self.linear = nn.Linear(lstm_size * 2, out_emb_size)
        self.dropout = nn.Dropout(0.2)

    def forward_(self, input_ids):
        """Embed one batch of token ids; padding positions are excluded from the mean."""
        attn_mask = input_ids != tokenizer.pad_token_id
        _, _, all_layer_outputs = self.pretrained(input_ids, attention_mask=attn_mask)
        # Pool only the layer that is actually used instead of every hidden layer.
        selected_layer = all_layer_outputs[self.pool_layer]
        output = mean_pool_embedding([selected_layer], attn_mask)[0]
        # Uncomment to enable training
        # _, (ht, _) = self.lstm(selected_layer)
        # output = torch.cat([ht[0, :, :], ht[1, :, :]], dim=1)
        # output = self.dropout(output)
        # output = self.linear(output)
        return output

    def forward(self, inputs):
        """Embed an (anchor, positive, negative) triple with the shared encoder."""
        emb_anchor = self.forward_(inputs[0])
        emb_positive = self.forward_(inputs[1])
        emb_negative = self.forward_(inputs[2])
        return emb_anchor, emb_positive, emb_negative

In [None]:
%%time

# Smoke test: push one training batch through a freshly built encoder and
# print the shape of the anchor embeddings (the first element of the output).
batch = next(iter(train_iter))
model = TransformerEncoder(xlm_model).to(DEVICE)
inputs = convert_batch(batch)
logits = model(inputs)
print(logits[0].shape)

In [None]:
# Objects handed to fit(): the encoder, the two criteria and the optimizer.
model = TransformerEncoder(xlm_model).to(DEVICE)
# Passed through fit()/do_epoch() as `criterion`; do_epoch never calls it.
criterion = nn.CrossEntropyLoss().to(DEVICE)
criterion_emb = TripletLoss().to(DEVICE)
# Adam with its default hyper-parameters over all parameters of the model.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Uncomment to enable training
# fit(model, criterion, criterion_emb, optimizer, train_iter, epochs_count=30, val_iter=test_iter)

In [None]:
# Pair prediction. Returns the probability that the pair is a translation pair.
# NOTE(review): this passes a two-element list to model.forward and calls
# .argmax on the result, whereas TransformerEncoder.forward in this file reads
# inputs[2] and returns a tuple. It looks like a leftover from a pair
# classifier — confirm it is still usable before calling it.

def predict(model, source, target):
    source, target = tokenize(source), tokenize(target)
    logits = model.forward([source, target])
    predicted = logits.argmax(-1).cpu().detach().numpy()
    return logits, predicted

# Batch prediction. Returns sentence embeddings.

def get_emb(model, sents, is_source=True, do_tokenize=True):
    """Embed a batch of sentences and return the result as a NumPy array.

    Args:
        model: callable taking a three-element list of inputs and returning
            three outputs.
        sents: raw sentences (``do_tokenize=True``) or a sequence of tensors
            that are stacked as they are (``do_tokenize=False``).
        is_source: pick the first output when true, the second otherwise.
        do_tokenize: whether ``sents`` still has to go through ``tokenize``.

    The same input is fed into all three slots of the model.
    """
    inputs = tokenize(sents) if do_tokenize else torch.stack(sents)
    source, target, _ = model([inputs, inputs, inputs])
    chosen = source if is_source else target
    return chosen.cpu().numpy()

# Тестирование на BUCC и TTW

In [None]:
def score(gold, predicted):
    """Compare predicted items with the gold ones.

    Both arguments are de-duplicated via ``set``. Returns
    ``(precision, recall, f1, error_ids)`` where ``error_ids`` lists the gold
    items that were not predicted. When nothing is predicted correctly the
    three metrics are returned as the integer 0.
    """
    gold_set, predicted_set = set(gold), set(predicted)
    missed = list(gold_set.difference(predicted_set))
    hits = len(gold_set.intersection(predicted_set))
    if hits <= 0:
        return 0, 0, 0, missed

    precision = hits / len(predicted_set)
    recall = hits / len(gold_set)
    f1_score = 2 * precision * recall / (precision + recall)
    return precision, recall, f1_score, missed

In [None]:
# BUCC-based

import pandas as pd
import os

lang = 'ru'
lang_root = 'bucc2018/{}-en/'.format(lang)

# The gold file is a two-column TSV of (source id, target id) pairs.
gold_path = os.path.join(lang_root, '{}-en.training.gold'.format(lang))
gold = pd.read_csv(gold_path, sep='\t', names=['source', 'target'])
gold = set(zip(gold['source'], gold['target']))

In [None]:
import csv

lang = 'ru'

lang_root = 'bucc2018/{}-en/'.format(lang)
map_file_prefix = 'dict_bucc2018_{}-en'.format(lang)

# The sentence files are plain two-column TSV (id, text). Quote handling is
# switched off so that a quotation mark inside a sentence cannot be taken for
# a field delimiter and swallow the lines that follow it.
_bucc_read_options = dict(sep='\t', names=['id', 'text'], quoting=csv.QUOTE_NONE)

# The source-side file carries the language code as its suffix, so it is
# derived from `lang` rather than spelled out.
data_source_map = pd.read_csv(os.path.join(lang_root, '{0}-en.training.{0}'.format(lang)), **_bucc_read_options)
data_target_map = pd.read_csv(os.path.join(lang_root, '{}-en.training.en'.format(lang)), **_bucc_read_options)

# key = sentence id, value = sentence text
source_sents = data_source_map.set_index('id')['text'].to_dict()
target_sents = data_target_map.set_index('id')['text'].to_dict()

In [None]:
# Drop sentences from the Russian side that are actually English.

en_letters = list('abcdefghijklmnopqrstuvwxyz')
cleaned_source_sents = {}
for sent_id, sent in tqdm(source_sents.items()):
    # Keep a sentence only when lowercase Latin letters make up less than
    # half of its characters.
    latin_count = sum(sent.count(letter) for letter in en_letters)
    if latin_count < len(sent) / 2:
        cleaned_source_sents[sent_id] = sent
source_sents = cleaned_source_sents


# Remove gold pairs that refer to sentences which are no longer present.
print(len(gold))
cleaned_gold = [
    (source, target)
    for source, target in gold
    if source in source_sents and target in target_sents
]
gold = cleaned_gold
print(len(gold))

In [None]:
# TTW Test

import pickle

def load_data(name):
    """Unpickle and return the object stored in ``<name>.pkl``.

    Pickle can execute arbitrary code while loading; only use this on files
    produced by this notebook.
    """
    path = f'{name}.pkl'
    with open(path, 'rb') as stream:
        payload = pickle.load(stream)
    return payload

# Rebinds source_sents, target_sents and gold to the pickled TTW test set,
# replacing whatever an earlier cell assigned to these names.
source_sents = load_data('ttw_test_source_sents')
target_sents = load_data('ttw_test_target_sents')
gold = load_data('ttw_test_gold')

In [None]:
import pickle

map_file_prefix = 'ttw'


def load_map(name):
    """Unpickle ``<map_file_prefix>_<name>.pkl`` from the working directory.

    ``map_file_prefix`` is read from module scope at call time.
    """
    path = '{}_{}.pkl'.format(map_file_prefix, name)
    with open(path, 'rb') as stream:
        return pickle.load(stream)

# Rebinds source_sents, target_sents and gold to the pickles stored under the
# current map_file_prefix ('<prefix>_source.pkl', '<prefix>_target.pkl',
# '<prefix>_gold.pkl').
source_sents = load_map('source')
target_sents = load_map('target')
gold = load_map('gold')

In [None]:
import torch
# The retrieval code below builds GPU FAISS indexes, so fail early without CUDA.
assert torch.cuda.is_available()

from knn_cuda import KNN
from tqdm import tqdm
import numpy as np
import faiss
from collections import defaultdict

# FAISS itself must also see at least one GPU.
assert faiss.get_num_gpus() > 0

In [None]:
import pickle

def save_data(data, name):
    """Pickle ``data`` to ``<name>.pkl`` using the highest available protocol."""
    path = f'{name}.pkl'
    with open(path, 'wb') as stream:
        pickle.dump(data, stream, protocol=pickle.HIGHEST_PROTOCOL)

def load_data(name):
    """Unpickle and return the object stored in ``<name>.pkl``.

    Pickle can execute arbitrary code while loading; only use this on files
    produced by this notebook.
    """
    path = f'{name}.pkl'
    with open(path, 'rb') as stream:
        payload = pickle.load(stream)
    return payload

In [None]:
def knn(x, y, k, mem=5*1024*1024*1024):
    """Exact top-``k`` inner-product search of the rows of ``x`` against ``y``.

    Both matrices are processed in row chunks of ``mem // (dim * 4)`` rows.
    For every pair of chunks a flat inner-product FAISS index is built on all
    GPUs, searched, and the per-chunk hits are merged and re-ranked.

    Args:
        x: query matrix, one vector per row.
        y: database matrix with the same number of columns as ``x``.
        k: number of neighbours to keep per query; callers are expected to
            pass ``k <= y.shape[0]``.
        mem: byte budget that determines the chunk size.

    Returns:
        ``(sim, ind)``: float32 similarities and int64 row indices into ``y``,
        both of shape ``(x.shape[0], k)``, sorted by decreasing similarity.
    """
    dim = x.shape[1]
    chunk_rows = mem // (dim * 4)
    n_queries = x.shape[0]
    sim = np.zeros((n_queries, k), dtype=np.float32)
    ind = np.zeros((n_queries, k), dtype=np.int64)
    for xfrom in range(0, n_queries, chunk_rows):
        xto = min(xfrom + chunk_rows, n_queries)
        chunk_sims, chunk_inds = [], []
        for yfrom in range(0, y.shape[0], chunk_rows):
            yto = min(yfrom + chunk_rows, y.shape[0])
            print('{}-{}  ->  {}-{}'.format(xfrom, xto, yfrom, yto))
            index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatIP(dim))
            index.add(y[yfrom:yto])
            part_sim, part_ind = index.search(x[xfrom:xto], min(k, yto - yfrom))
            chunk_sims.append(part_sim)
            # Indices are local to the y-chunk; shift them back to global rows.
            chunk_inds.append(part_ind + yfrom)
            del index
        merged_sims = np.concatenate(chunk_sims, axis=1)
        merged_inds = np.concatenate(chunk_inds, axis=1)
        # Re-rank the merged hits and keep the k best in one vectorised step
        # instead of copying element by element in Python.
        top = np.argsort(-merged_sims, axis=1)[:, :k]
        sim[xfrom:xto] = np.take_along_axis(merged_sims, top, axis=1)
        ind[xfrom:xto] = np.take_along_axis(merged_inds, top, axis=1)
    return sim, ind


def get_embeddings(model, id2sent, batch_size=512):
    """Embed every sentence of ``id2sent`` and L2-normalise the result.

    The model is switched to evaluation mode and run without gradients over
    the dictionary values in chunks of ``batch_size``. Rows of the returned
    array follow the iteration order of ``id2sent``.
    """
    model.eval()
    ids = list(id2sent.keys())
    sent_list = list(id2sent.values())

    vectors = []
    with torch.no_grad():
        for start in tqdm(range(0, len(sent_list), batch_size)):
            # A slice that runs past the end is clipped automatically.
            vectors.extend(get_emb(model, sent_list[start:start + batch_size]))

    assert len(vectors) == len(ids) == len(sent_list)
    vectors = np.array(vectors)
    faiss.normalize_L2(vectors)  # normalises in place
    return vectors


def score_pair(x, y, fwd_mean, bwd_mean, margin, dist='cosine'):
    """Margin score of one candidate pair.

    The similarity is ``x . y`` for ``dist='cosine'`` and
    ``1 / (1 + ||x - y||^2)`` for any other value. It is passed to ``margin``
    together with the average of ``fwd_mean`` and ``bwd_mean``.
    """
    neighbourhood_mean = (fwd_mean + bwd_mean) / 2
    if dist != 'cosine':
        squared_l2 = ((x - y) ** 2).sum()
        return margin(1 / (1 + squared_l2), neighbourhood_mean)
    return margin(x.dot(y), neighbourhood_mean)


def score_candidates(x, y, candidate_inds, fwd_mean, bwd_mean, margin, dist='cosine'):
    """Score every (query, candidate) pair listed in ``candidate_inds``.

    ``candidate_inds[i, j]`` is the row of ``y`` proposed as the j-th
    candidate for row ``i`` of ``x``. Returns an array of the same shape
    holding ``score_pair`` for each such pair.
    """
    scores = np.zeros(candidate_inds.shape)
    n_rows, n_cols = scores.shape[0], scores.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            cand = candidate_inds[row, col]
            scores[row, col] = score_pair(x[row], y[cand], fwd_mean[row], bwd_mean[cand], margin, dist)
    return scores


def shift_embeddings(x, y):
    """Translate each set of vectors by the difference of the two set means.

    Returns ``(x - delta, y + delta)`` with
    ``delta = mean(x, axis=0) - mean(y, axis=0)``.
    """
    print(' - shift embeddings')
    offset = np.mean(x, axis=0) - np.mean(y, axis=0)
    return x - offset, y + offset


def _load_or_compute_embeddings(model, sents, cache_name, label, do_load, do_save, batch_size):
    """Return embeddings for ``sents``, reading or writing ``<cache_name>.pkl`` when allowed."""
    if do_load and os.path.exists(cache_name + '.pkl'):
        print('Loading {} embeddings...'.format(label))
        vectors = load_data(cache_name)
    else:
        print('Computing {} embeddings...'.format(label))
        vectors = get_embeddings(model, sents, batch_size=batch_size)
        if do_save:
            save_data(vectors, cache_name)
    print(vectors.shape)
    # A cached file built from a different sentence set would silently misalign ids.
    assert len(sents) == len(vectors)
    return vectors


def get_candidates(model, sources, targets, return_all=False, do_save=False, save_prefix='xlmr_',
                   n_candidates=10, batch_size=512, margin='ratio', threshold=0, retrieval='max', use_shift=True, do_load=False):
    """Mine translation pairs between two sentence collections.

    Sentences are embedded, nearest neighbours are searched in both
    directions, every candidate is scored with a margin criterion and pairs
    are selected with the chosen retrieval strategy.

    Args:
        model: encoder passed on to ``get_embeddings``.
        sources, targets: dicts mapping sentence id to sentence text.
        return_all: also return the full candidate lists per source id.
        do_save: pickle freshly computed embeddings under ``save_prefix``.
        save_prefix: prefix of the cache files ``<prefix>source_vectors.pkl``
            and ``<prefix>target_vectors.pkl``.
        n_candidates: neighbours retrieved per sentence (clamped to the size
            of the searched collection).
        batch_size: batch size used when computing embeddings.
        margin: ``'absolute'`` (similarity itself), ``'distance'``
            (similarity minus neighbourhood mean); any other value selects
            the ratio of the two.
        threshold: minimum score for a pair; applied by ``'max'`` retrieval only.
        retrieval: ``'intersection'`` keeps pairs that are each other's best
            candidate; ``'max'`` greedily takes the best-scoring pairs from
            both directions, using every sentence at most once. Any other
            value yields no pairs.
        use_shift: query with mean-shifted embeddings (see ``shift_embeddings``).
        do_load: reuse cached embeddings when the cache file exists.

    Returns:
        ``(predicted, distances)`` — the selected ``(source_id, target_id)``
        pairs and their scores, index-aligned. With ``return_all=True`` two
        dicts are appended: source id -> candidate target ids and
        source id -> the corresponding forward scores.
    """
    source_vectors = _load_or_compute_embeddings(
        model, sources, save_prefix + 'source_vectors', 'source', do_load, do_save, batch_size)
    target_vectors = _load_or_compute_embeddings(
        model, targets, save_prefix + 'target_vectors', 'target', do_load, do_save, batch_size)

    # Queries are either the shifted or the raw embeddings; the searched
    # database is always the raw embeddings of the other side.
    if use_shift:
        source_queries, target_queries = shift_embeddings(source_vectors, target_vectors)
    else:
        source_queries, target_queries = source_vectors, target_vectors

    print('Computing distances...')
    x2y_sim, x2y_ind = knn(source_queries, target_vectors, min(target_vectors.shape[0], n_candidates))
    x2y_mean = x2y_sim.mean(axis=1)

    print('Computing reverse distances...')
    y2x_sim, y2x_ind = knn(target_queries, source_vectors, min(source_vectors.shape[0], n_candidates))
    y2x_mean = y2x_sim.mean(axis=1)

    if margin == 'absolute':
        margin_fn = lambda a, b: a
    elif margin == 'distance':
        margin_fn = lambda a, b: a - b
    else:  # margin == 'ratio':
        margin_fn = lambda a, b: a / b

    print('Scoring candidates...')
    fwd_scores = score_candidates(source_queries, target_vectors, x2y_ind, x2y_mean, y2x_mean, margin_fn)
    bwd_scores = score_candidates(target_queries, source_vectors, y2x_ind, y2x_mean, x2y_mean, margin_fn)

    fwd_best = x2y_ind[np.arange(source_vectors.shape[0]), fwd_scores.argmax(axis=1)]
    bwd_best = y2x_ind[np.arange(target_vectors.shape[0]), bwd_scores.argmax(axis=1)]

    print('Retrieving results...')
    source_keys = list(sources.keys())
    target_keys = list(targets.keys())
    predicted = []
    distances = []
    if retrieval == 'intersection':
        for i, j in enumerate(fwd_best):
            if bwd_best[j] == i:
                predicted.append((source_keys[i], target_keys[j]))
                distances.append(fwd_scores[i].max())
    elif retrieval == 'max':
        # Forward and backward best matches are pooled as (source, target)
        # index pairs and visited from the highest score down.
        indices = np.stack((np.concatenate((np.arange(source_vectors.shape[0]), bwd_best)),
                            np.concatenate((fwd_best, np.arange(target_vectors.shape[0])))), axis=1)
        scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))
        seen_src, seen_trg = set(), set()
        for i in np.argsort(-scores):
            src_ind, trg_ind = indices[i]
            if src_ind not in seen_src and trg_ind not in seen_trg:
                seen_src.add(src_ind)
                seen_trg.add(trg_ind)
                if scores[i] > threshold:
                    predicted.append((source_keys[src_ind], target_keys[trg_ind]))
                    distances.append(scores[i])

    id2candidates_full = defaultdict(list)
    id2distances_full = defaultdict(list)
    for i, neighbours in enumerate(x2y_ind):
        source_key = source_keys[i]
        for rank, j in enumerate(neighbours):
            id2candidates_full[source_key].append(target_keys[j])
            id2distances_full[source_key].append(fwd_scores[i][rank])

    if return_all:
        return predicted, distances, id2candidates_full, id2distances_full
    return predicted, distances

In [None]:
%%time

# Mine pairs with 'intersection' retrieval, using the 'ttw_trained_' prefix for
# the embedding cache files (loaded when present, saved when computed).
# Despite its name, id2dist receives the list of scores aligned with `predicted`.
predicted, id2dist, id2candidates_full, id2distances_full = get_candidates(model, source_sents, target_sents, save_prefix='ttw_trained_',
                                                                           return_all=True, do_save=True, retrieval='intersection', do_load=True)

In [None]:
# Map every predicted pair to its score; the two lists are index-aligned.
candidate2score = dict(zip(predicted, id2dist))

In [None]:
def bucc_optimize(candidate2score, gold):
    """Find the score threshold that maximises F1 against ``gold``.

    Candidates are visited from the highest score down; after each one the
    F1 of "everything extracted so far" is computed. The best F1 is printed.

    Args:
        candidate2score: dict mapping a candidate pair to its score.
        gold: collection of correct pairs (its length is the recall denominator).

    Returns:
        The midpoint between the score of the last candidate at the best F1
        and the score of the next one; when the best F1 is reached on the
        very last candidate, that candidate's own score. 0 if no candidate
        is correct.
    """
    items = sorted(candidate2score.items(), key=lambda x: -x[1])
    ngold = len(gold)
    # Constant-time membership even when gold is given as a list.
    gold_lookup = set(gold)
    nextract = ncorrect = 0
    threshold = 0
    best_f1 = 0
    last_index = len(items) - 1
    for i, (pair, pair_score) in enumerate(tqdm(items)):
        nextract += 1
        if pair in gold_lookup:
            ncorrect += 1
        if ncorrect > 0:
            precision = ncorrect / nextract
            recall = ncorrect / ngold
            f1 = 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_f1 = f1
                if i < last_index:
                    threshold = (pair_score + items[i + 1][1]) / 2
                else:
                    # No next candidate to average with: keep everything.
                    threshold = pair_score
    print(best_f1)
    return threshold

In [None]:
# Tune the score threshold on the gold pairs and show it.
threshold = bucc_optimize(candidate2score, gold)
print(threshold)

In [None]:
# Keep only the pairs that reach the tuned threshold; print the size before and after.
print(len(predicted))
predicted = [
    pair
    for pair, pair_score in candidate2score.items()
    if pair_score >= threshold
]
print(len(predicted))

In [None]:
# Show the first 50 predicted pairs with their scores. The score is looked up
# by pair: positions in `predicted` stop matching `id2dist` once the list has
# been filtered by the threshold.
for source_id, target_id in predicted[:50]:
    pair_score = candidate2score[(source_id, target_id)]
    print(f'[{source_id}] {source_sents[source_id]}')
    print(f'[{target_id}, dist={pair_score:.6f}] {target_sents[target_id]}\n')

In [None]:
# Precision / recall / F1 of the predictions; error_ids holds the gold pairs
# that were not predicted.
pre, rec, f1, error_ids = score(gold, predicted)
print(f'Precision: {pre:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}')

In [None]:
# accuracy_n is not defined in the visible part of the notebook; it has to be
# provided by another cell.
for top_n in (10, 5, 1):
    print(f'Accuracy {top_n}: {accuracy_n(id2candidates_full, gold, top_n):.2f}')

In [None]:
# Inspect the errors: for up to 100 missed gold pairs print the source
# sentence, the gold target and every retrieved candidate with its score.

for err_source, err_target in error_ids[:100]:
    print('{}\nGold: [{}] {}'.format(source_sents[err_source], err_target, target_sents[err_target]))

    candidates = id2candidates_full[err_source]
    candidate_scores = id2distances_full[err_source]
    for target_id, candidate_score in zip(candidates, candidate_scores):
        print('dist={:.6f} [{}] {}'.format(candidate_score, target_id, target_sents[target_id]))
    print('\n')

In [None]:
# Sentence lengths of correctly and incorrectly predicted pairs.

source_lens_true = []
source_lens_false = []
target_lens_true = []
target_lens_false = []
# Build the lookup once so that each membership test is constant-time even
# when gold is a list.
gold_lookup = set(gold)
for source, target in tqdm(predicted):
    if (source, target) in gold_lookup:
        source_lens_true.append(len(source_sents[source]))
        target_lens_true.append(len(target_sents[target]))
    else:
        source_lens_false.append(len(source_sents[source]))
        target_lens_false.append(len(target_sents[target]))

print(f'True ru mean len: {np.mean(source_lens_true):.4f}')
print(f'True en mean len: {np.mean(target_lens_true):.4f}')
print(f'False ru mean len: {np.mean(source_lens_false):.4f}')
print(f'False en mean len: {np.mean(target_lens_false):.4f}')