## Model Init

In [None]:
from IPython.display import clear_output

!pip3 install transformers==2.8.0
!wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1Sp3QLnEUSHMk3q81YJKAPiT-17EAKMVb' -O condBERT.zip
!unzip condBERT.zip

clear_output()

In [2]:
# Relative path of the fine-tuned checkpoint file; it is read with torch.load further down.
BERT_WEIGHTS = 'ru_cond_bert_geotrend/checkpoint-9000/pytorch_model.bin'

In [3]:
# Directory that holds negative-words.txt, positive-words.txt, word2coef.pkl and
# token_toxicities.txt, all of which are read later in this notebook.
VOCAB_DIRNAME = 'ru_vocabularies_geotrend' 

In [4]:
import os
import sys
from importlib import reload


def add_sys_path(p):
    """Print the absolute form of ``p`` and make sure it is listed in ``sys.path``.

    The path is appended only when it is not already present, so calling the
    function repeatedly with the same argument does not create duplicates.
    """
    resolved = os.path.abspath(p)
    print(resolved)
    if resolved in sys.path:
        return
    sys.path.append(resolved)

In [5]:
# Put the local 'multiword' directory on sys.path (presumably where the condbert,
# choosers and masked_token_predictor_bert modules imported next live — confirm).
add_sys_path('multiword')

/notebook/workspace/style_transfer/condBERT/multiword


In [6]:
from condbert import CondBertRewriter
from choosers import EmbeddingSimilarityChooser
from masked_token_predictor_bert import MaskedTokenPredictorBert

In [8]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
import numpy as np
import pickle
from tqdm.auto import tqdm, trange

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so that
# model.to(device) and torch.load(..., map_location=device) still work without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
# Tokenizer and masked-LM model are both created from the same pretrained checkpoint name.
model_name = 'Geotrend/bert-base-ru-cased'
tokenizer_ru = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

# Read the fine-tuned state dict onto `device`.
# NOTE(security): torch.load unpickles the file — only load checkpoints you trust.
model_dict = torch.load(BERT_WEIGHTS, map_location=device)

# You can experiment with the zero-shot setup or load the fine-tuned weights.
# As written, `model_dict` is NOT applied to `model` (zero-shot); uncomment the
# next line to load the fine-tuned weights.
# model.load_state_dict(model_dict, strict=False)

In [10]:
# Move the model to the selected device; the trailing ';' suppresses the cell's output.
model.to(device);

## Loading the pre-defined vocabulary toxicity weights

In [11]:
def _read_word_list(path):
    """Return the lines of the text file at ``path`` as a list of strings.

    Only the trailing newline is removed from each line. The previous ``x[:-1]``
    slice also chopped the last character of the final word whenever the file
    did not end with a newline.
    """
    # The vocabularies contain Cyrillic text; read them as UTF-8 explicitly
    # instead of relying on the platform default (assumes the shipped files are
    # UTF-8 — confirm if they were produced elsewhere).
    with open(path, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


negative_words = _read_word_list(VOCAB_DIRNAME + "/negative-words.txt")
positive_words = _read_word_list(VOCAB_DIRNAME + "/positive-words.txt")

In [12]:
import pickle
# Load the pickled object stored in word2coef.pkl; it is later handed to
# CondBertRewriter as `word2coef`.
# NOTE(security): pickle.load can execute arbitrary code — only open trusted files.
with open(VOCAB_DIRNAME + '/word2coef.pkl', 'rb') as f:
    word2coef = pickle.load(f)

In [13]:
# One float per line of token_toxicities.txt (presumably ordered by tokenizer
# vocabulary id, since the array is indexed by token id below — confirm).
with open(VOCAB_DIRNAME + '/token_toxicities.txt', 'r') as f:
    token_toxicities = np.array([float(line) for line in f])
# Convert each value p to max(0, log(p / (1 - p))), i.e. negative log-odds are clipped to 0.
token_toxicities = np.maximum(0, np.log(1/(1/token_toxicities-1)))   # log odds ratio

# discourage meaningless tokens
for tok in ['.', ',', '-']:
    # encode() returns [CLS, token, SEP] ids, so position 1 is the token itself.
    # Index the array with that single id: the previous form,
    # token_toxicities[encode(tok)][1] = 3, indexed with the whole id list first,
    # which yields a *copy*, so the assignment never reached token_toxicities.
    token_toxicities[tokenizer_ru.encode(tok)[1]] = 3

In [14]:
def adjust_logits(logits):
    """Subtract 100x the module-level ``token_toxicities`` from ``logits``.

    The function is handed to the predictor as its ``logits_postprocessor``.
    """
    toxicity_penalty = token_toxicities * 100
    return logits - toxicity_penalty

# Masked-token predictor built on the model/tokenizer above; adjust_logits is
# registered as its logits post-processor. The meaning of label=0 and
# contrast_penalty=0.0 is defined in masked_token_predictor_bert (not visible
# here) — check that module before changing them.
predictor = MaskedTokenPredictorBert(model, tokenizer_ru, max_len=250, device=device, label=0, contrast_penalty=0.0, logits_postprocessor=adjust_logits)

# Rewriter object that receives the model, tokenizer, both word lists, the
# word2coef table, the per-token toxicity array and the predictor created above.
editor = CondBertRewriter(
    model=model,
    tokenizer=tokenizer_ru,
    device=device,
    neg_words=negative_words,
    pos_words=positive_words,
    word2coef=word2coef,
    token_toxicities=token_toxicities,
    predictor=predictor,
)

In [15]:
import numpy as np

from flair.data import Sentence
from flair.embeddings import WordEmbeddings
import gensim

def cosine(v1, v2):
    """Return the cosine similarity of two 1-D vectors.

    Args:
        v1, v2: 1-D array-likes of equal length (NumPy arrays, lists, tuples).

    Returns:
        ``dot(v1, v2) / sqrt(|v1|^2 * |v2|^2 + 1e-10)``. The small constant keeps
        the division finite, so an all-zero vector yields 0 instead of NaN.
    """
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)
    # np.dot(v, v) is the squared norm computed in NumPy, replacing the
    # built-in sum(), which looped over the elements in Python on every call.
    return np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2) + 1e-10)

class RuEmbeddingSimilarityChooser:
    """Choose the replacement hypothesis with the best combined ranking.

    Every candidate is ranked by ``score + sim_coef * cosine(original, candidate)``,
    where each text is embedded as the mean of its word vectors.
    """

    def __init__(self, sim_coef=100, tokenizer=None, embedding_path='../ru_fasttext/model.model'):
        """Load the word vectors and store the ranking settings.

        Args:
            sim_coef: weight of the cosine-similarity term in the ranking key.
            tokenizer: optional object exposing ``convert_tokens_to_string``;
                when given, ``decode`` uses it to join token lists.
            embedding_path: file passed to ``gensim.models.KeyedVectors.load``.
                Defaults to the location that used to be hard-coded.
        """
        self.embedding = gensim.models.KeyedVectors.load(embedding_path)
        self.sim_coef = sim_coef
        self.tokenizer = tokenizer

    def embed(self, text):
        """Return the mean word vector of ``text``.

        The text is tokenized with flair's ``Sentence`` and every token is looked
        up in ``self.embedding``.
        NOTE(review): a text with no tokens gives ``np.mean`` an empty list,
        which produces NaN — callers should avoid empty texts.
        """
        words = [tok.text for tok in Sentence(text).tokens]
        vectors = [self.embedding[word] for word in words]
        return np.mean(vectors, axis=0)

    def decode(self, tokens):
        """Turn ``tokens`` into a plain string.

        A string is returned unchanged. A token sequence is joined by the
        tokenizer when one was supplied; otherwise it is joined with spaces and
        the ' ##' continuation markers are removed.
        """
        if isinstance(tokens, str):
            return tokens
        if self.tokenizer:
            return self.tokenizer.convert_tokens_to_string(tokens)
        return ' '.join(tokens).replace(' ##', '')

    def __call__(self, hypotheses, original=None, scores=None, **kwargs):
        """Return the best-ranked element of ``hypotheses``.

        Args:
            hypotheses: candidate replacements (strings or token sequences).
            original: the fragment being replaced. When ``None`` the similarity
                term is dropped and candidates are ranked by score alone.
            scores: one score per hypothesis. When ``None`` every score is
                treated as 0, so candidates are ranked by similarity alone.
            **kwargs: accepted and ignored.

        Returns:
            The hypothesis with the highest ``score + sim_coef * similarity``;
            on ties the earliest one wins.

        Raises:
            IndexError: if there is no hypothesis to choose from.
        """
        hypotheses = list(hypotheses)
        # The signature advertises None defaults; make them usable instead of
        # failing inside zip() / decode().
        if scores is None:
            scores = [0.0] * len(hypotheses)
        reference = None if original is None else self.embed(self.decode(original))

        def ranking_key(candidate):
            fill_words, score = candidate
            if reference is None:
                return score
            similarity = cosine(reference, self.embed(self.decode(fill_words)))
            return score + similarity * self.sim_coef

        candidates = list(zip(hypotheses, scores))
        if not candidates:
            # Same exception type the former `candidates[0]` lookup produced.
            raise IndexError('no hypotheses to choose from')
        # max() returns the first maximal element, matching the previous stable
        # descending sort followed by taking element 0, without sorting everything.
        return max(candidates, key=ranking_key)[0]

In [16]:
# Similarity-based chooser; sim_coef=10 overrides the class default of 100.
chooser = RuEmbeddingSimilarityChooser(sim_coef=10, tokenizer=tokenizer_ru)

## Inference example

In [20]:
text = 'Ты дурак и ничего не понимаешь. Что значит по-твоему построить дорогу?'

You can rewrite the text locally, replacing one word at a time:

In [21]:
# Word-level rewrite of `text`. prnt=True is forwarded to translate(); presumably it
# enables the extra printed line seen in the cell output — confirm in condbert.
print(editor.translate(text, prnt=True))

Ты дурак и ничего не понимаешь. Что значит по-твоему построить дорогу?
Ты дурак и ничего не понимаешь . Что значит по - твоему построить дорогу ?


You can also try generating multi-word substitutions:

In [22]:
# Multi-word rewrite: candidates are picked by `chooser`. n_tokens=(1,2,3) and
# n_top=10 are passed through to replacement_loop; presumably the replacement
# lengths to try and the number of candidates kept — confirm in condbert.
print(editor.replacement_loop(text, verbose=False, chooser=chooser, n_tokens=(1,2,3), n_top=10))

Ты знаешь и ничего не понимаешь . Что значит по - прежнему построить дорогу ?
