This notebook

In [1]:
import torch
from transformers import BertTokenizer, BertForMaskedLM
import numpy as np
import pickle
from tqdm.auto import tqdm, trange

import sys
import os
from importlib import reload
sys.path.insert(1, '/kaggle/input/dale2020-emnlp-condbert')

In [7]:
import condbert
reload(condbert)
from condbert import CondBertRewriter

In [9]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda:0')

### Load the model

In [10]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [11]:
model = BertForMaskedLM.from_pretrained(model_name)

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
model.to(device);

#### Load vocabularies for spans detection

In [14]:
neg_w_p = '/kaggle/input/dale-output/negative-words.txt'
pos_w_p = '/kaggle/input/dale-output/positive-words.txt'

In [16]:
with open(neg_w_p, "r") as f:
    ss = f.readlines()
negative_words = list(map(lambda x: x[:-1], ss))

with open(pos_w_p, "r") as f:
    s = f.readlines()
positive_words = list(map(lambda x: x[:-1], s))

In [17]:
import pickle
with open('/kaggle/input/dale-output/word2coef.pkl', 'rb') as f:
    word2coef = pickle.load(f)

In [18]:
token_toxicities = []
with open('/kaggle/input/dale-output/vocabularies/token_toxicities.txt', 'r') as f: # я забыла поменять название
    for line in f.readlines():
        token_toxicities.append(float(line))
token_toxicities = np.array(token_toxicities)
token_toxicities = np.maximum(0, np.log(1/(1/token_toxicities-1)))   # log odds ratio

# discourage meaningless tokens
for tok in ['.', ',', '-']:
    token_toxicities[tokenizer.encode(tok)][1] = 3

for tok in ['you']:
    token_toxicities[tokenizer.encode(tok)][1] = 0

### Applying the model

In [19]:
reload(condbert)
from condbert import CondBertRewriter

editor = CondBertRewriter(
    model=model,
    tokenizer=tokenizer,
    device=device,
    neg_words=negative_words,
    pos_words=positive_words,
    word2coef=word2coef,
    token_toxicities=token_toxicities,
)

In [20]:
print(editor.translate('You are an idiot!', prnt=False))

you are an idiot !


### Multiunit

In [21]:
editor = CondBertRewriter(
    model=model,
    tokenizer=tokenizer,
    device=device,
    neg_words=negative_words,
    pos_words=positive_words,
    word2coef=word2coef,
    token_toxicities=token_toxicities,
    predictor=None,
)

In [23]:
import masked_token_predictor_bert
reload(masked_token_predictor_bert)
from masked_token_predictor_bert import MaskedTokenPredictorBert

In [24]:
predictor = MaskedTokenPredictorBert(model, tokenizer, max_len=250, device=device, label=0, contrast_penalty=0.0)
editor.predictor = predictor

def adjust_logits(logits, label):
    return logits - editor.token_toxicities * 3

predictor.logits_postprocessor = adjust_logits

print(editor.replacement_loop('You are an idiot!', verbose=False))

you are an idiot !


------------------------


In [37]:
import pandas as pd
df = pd.read_csv('/kaggle/input/politenessdataset/politeness.csv')

In [38]:
df['lens'] = [len(i.split()) for i in df['txt'].tolist()]
df = df[df['lens'] < 30]
df = df.sort_values(by='score')
df = df[df['is_useful']==1] # чтобы исключить рандомные мусорные сообщения
nonpolite_txt = df['txt'].tolist()[:1000]

In [45]:
nonpolite_txt[:5]

['the daughter bowed her head and said , " lord , why on earth did i invite all these people ? to dinner ? " ? ?',
 'if ( the allegation of market manipulation ) is n\'t true , then why are they insisting on secrecy ? " ? ?',
 'i asked ken to ask knight the obvious question : why did he stop sending ken the daily ptr ?',
 'why he said it : another way to say , " its not my fault , its their fault " fact :',
 'there is widespread disagreement about why -- mirroring the argument over why gas prices spiked in the first place . ? ?']

In [46]:
with open('original_nonpolite.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(nonpolite_txt))

In [34]:
import choosers
reload(choosers)
from choosers import EmbeddingSimilarityChooser

# Reproduction

In [35]:
predictor = MaskedTokenPredictorBert(
    model, tokenizer, max_len=250, device=device, label=0, contrast_penalty=0.0, 
    confuse_bert_args=True, # this argument deteriorates quality but is used for backward compatibility
)
editor.predictor = predictor

def adjust_logits(logits, label=0):
    return logits - editor.token_toxicities * 10

predictor.logits_postprocessor = adjust_logits

cho = EmbeddingSimilarityChooser(sim_coef=100, tokenizer=tokenizer)

2023-03-23 16:47:48,975 https://flair.informatik.hu-berlin.de/resources/embeddings/token/glove.gensim.vectors.npy not found in cache, downloading to /tmp/tmpzmrzw1mb


100%|██████████| 153M/153M [00:16<00:00, 9.73MB/s]   

2023-03-23 16:48:06,535 copying /tmp/tmpzmrzw1mb to cache at /root/.flair/embeddings/glove.gensim.vectors.npy
2023-03-23 16:48:06,710 removing temp file /tmp/tmpzmrzw1mb





2023-03-23 16:48:08,104 https://flair.informatik.hu-berlin.de/resources/embeddings/token/glove.gensim not found in cache, downloading to /tmp/tmp_ow_jo6c


100%|██████████| 20.5M/20.5M [00:03<00:00, 5.74MB/s]

2023-03-23 16:48:13,193 copying /tmp/tmp_ow_jo6c to cache at /root/.flair/embeddings/glove.gensim
2023-03-23 16:48:13,220 removing temp file /tmp/tmp_ow_jo6c





In [48]:
redacted_out = []
for i, line in enumerate(tqdm(nonpolite_txt)):
    inp = line.strip()
    out = editor.replacement_loop(inp, verbose=False, chooser=cho, n_top=10, n_tokens=(1,2,3), n_units=1)
    redacted_out.append(out)
    
with open('redacted_out.txt', 'a', encoding='utf-8') as f:
    f.write('\n'.join(redacted_out))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [49]:
for i in range(20):
    print(f'orifinal: {nonpolite_txt[i]}')
    print(f'redacted: {redacted_out[i]}')
    print('-'*15)

orifinal: the daughter bowed her head and said , " lord , why on earth did i invite all these people ? to dinner ? " ? ?
redacted: the daughter i her head and said , " lord , . . the on the . i invite the . . these . . and ? to . ? " ? ?
---------------
orifinal: if ( the allegation of market manipulation ) is n't true , then why are they insisting on secrecy ? " ? ?
redacted: if ( the . the . of market or ) is n ' t true , then . . ' are they the on secrecy ? " ? ?
---------------
orifinal: i asked ken to ask knight the obvious question : why did he stop sending ken the daily ptr ?
redacted: i asked ken to ask him the obvious question : it did he it sending ken the . . the it ?
---------------
orifinal: why he said it : another way to say , " its not my fault , its their fault " fact :
redacted: so he said it : another way to say , " its not my it , its their it " it :
---------------
orifinal: there is widespread disagreement about why -- mirroring the argument over why gas prices sp