In [None]:
import re
import os
import torch
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import T5TokenizerFast, T5ForConditionalGeneration
from dataset_utils import EditDataset
from tqdm.notebook import tqdm
from collections import Counter
from torch import nn
from catalyst import dl
from langdetect import detect
from dataclasses import dataclass
from metrics_utils import *

DOCS_DIR = 'data/documents'
PAGES_DIR = 'data/pages'

In [None]:
dataset_ids = list(map(lambda x: int(x.split('.')[0]), os.listdir(DOCS_DIR)))
dataset_ids[:5]

In [None]:
db_dict = {'obj_id': [], 'old_text': [], 'new_text': [], 'comment': [], 'docs': [], 'diff': []}
for dataset_obj_id in tqdm(dataset_ids):
    with open(f"{PAGES_DIR}/{dataset_obj_id}.json", 'r', encoding='utf-8') as f:
        page_json = json.load(f)
    
    docs_text = ''
    added_docs = set()
    count_docs = 0
    with open(f"{DOCS_DIR}/{dataset_obj_id}.txt", 'r', encoding='utf-8') as f:
        docs_text_plain = f.read()
        docs = docs_text_plain.split('\nNEW_DOC\n')
        for doc_id, doc in enumerate(docs):
            if doc not in added_docs:
                added_docs.add(doc)
                count_docs += 1
                docs_text += f" DOC{count_docs}: {doc}"
            if count_docs > 2:
                break
                
    diff = '\n'.join(page_json['change_texts'][0][0])
    db_dict['diff'].append(diff)
    db_dict['obj_id'].append(dataset_obj_id)
    db_dict['old_text'].append(page_json['old_text'])
    db_dict['new_text'].append(page_json['new_text'])
    db_dict['comment'].append('Comment: ' + page_json['comment'])
    db_dict['docs'].append(docs_text.strip())

In [None]:
df = pd.DataFrame.from_dict(db_dict)
print(df.shape)

In [None]:
df2 = df.sample(80)

In [None]:
@dataclass
class Configuration:
    pass
    
CONFIG = Configuration()
CONFIG.seed = 888
CONFIG.src_max_len = 512
CONFIG.tgt_max_len = 512
CONFIG.pretrained = 't5-small'
CONFIG.batch_size = 4

tokenizer = T5TokenizerFast.from_pretrained(CONFIG.pretrained, model_max_length=1000)

In [None]:
ds_edit = EditDataset(df, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)
ds_val = EditDataset(df2, tokenizer, CONFIG, text_to_lower=True, comment_to_lower=True)

## Learning

In [None]:
CONFIG.final_src_max_len = CONFIG.src_max_len
CONFIG.final_tgt_max_len = CONFIG.tgt_max_len
loaders = {
    'train': torch.utils.data.DataLoader(ds_edit, 
                                         batch_size=CONFIG.batch_size,
                                         collate_fn=lambda x: EditDataset.collate_fn(x, tokenizer, CONFIG),
                                         num_workers=4, shuffle=True),
    'valid': torch.utils.data.DataLoader(ds_val, 
                                         batch_size=CONFIG.batch_size,
                                         collate_fn=lambda x: EditDataset.collate_fn(x, tokenizer, CONFIG),
                                         num_workers=4, shuffle=True)
}

In [None]:
class WikiEditModel(nn.Module):
    def __init__(self, pretrained):
        super(WikiEditModel, self).__init__()
        self.pretrained = pretrained
        
    def forward(self, x):
        src, tgt = x
        
        tgt[tgt == 0] == -100
        
        loss = self.pretrained(
            input_ids = src,
            attention_mask = (src != 0).float(),
            labels=tgt,
        ).loss
        return loss
    
    
class Criterion(nn.Module):
    def __init__(self):
        super(Criterion, self).__init__()
        
    def forward(self, pred, tgt):
        return pred

In [None]:
model = WikiEditModel(
    T5ForConditionalGeneration.from_pretrained(CONFIG.pretrained)
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
CONFIG.name = f'{CONFIG.pretrained}'
CONFIG.description = f'{CONFIG.name} p(comment, x_t+1 | x_t, doc)'
CONFIG.group = "wiki_edit"

In [None]:
!rm -rf data/models/"{CONFIG.description}"

In [None]:
CONFIG.n_epochs = 1000
CONFIG.beam_size = 5

In [None]:
runner = dl.SupervisedRunner()
wandb_logger = dl.WandbLogger(project="iterative_lm", 
                              name=CONFIG.name, 
                              group=CONFIG.group)

runner.train(
    loaders=loaders,
    model=model,
    criterion=Criterion(),
    optimizer=optimizer,                            
    num_epochs=CONFIG.n_epochs,
    callbacks=[
        ExactMatchCallback(beam_size=CONFIG.beam_size),
    ],
    loggers={'wandb': wandb_logger},
    logdir=f'data/models/{CONFIG.description}',
    valid_loader='valid',
    valid_metric='loss',
    verbose=True,
)

## Example

In [21]:
device = runner.engine.device
runner.model.eval()
with torch.no_grad():
    for i in [0, 1, 2, 3, 4]:
        src_, tgt_ = ds_edit[i]
        src_inp = torch.tensor(src_['input_ids']).view(1,-1).to(runner.engine.device)
        generated = runner.model.pretrained.generate(src_inp,
                                              attention_mask=(src_inp != 0).float().to(runner.engine.device),
                                              num_beams=CONFIG.beam_size,
                                              num_return_sequences=CONFIG.beam_size,
                                              max_length=1000
        )
        src_text = tokenizer.decode(src_['input_ids'], skip_special_tokens=True)
        tgt_text = tokenizer.decode(tgt_['input_ids'], skip_special_tokens=True)
        print(f'\n\n----------------------------\t QUERY {i}\t ----------------------------\n')
        print(f'Src query:\n {src_text}')
        print(f'\nTgt query:\n {tgt_text}')
        
        print(f'\n\n----------------------------\t GENERATED\t ----------------------------\n')

        for j in range(CONFIG.beam_size):
            to_gen = generated[j]
            gen_text = tokenizer.decode(to_gen, skip_special_tokens=True)
            print(f'{j}:\n{gen_text}\n')
        
        src_text = df.iloc[i]['old_text']
        tgt_text = df.iloc[i]['new_text']
        diff_text = df.iloc[i]['diff']
        print(f'\nX_T:\n{src_text}')
        print(f'\nX_T+1:\n{tgt_text}')
        print(f'\ndiff:\n{diff_text}')



----------------------------	 QUERY 0	 ----------------------------

Src query:
 TEXT what a load of drivel. deb 18:29, 6 feb 2005 (utc) * drivel, yes; true drivel, also yes. (note: i mean that this really did exist as an anti-david beckham soccer chant, not that it accurately describes their sexual practices.) but, that said, i don't think it's notable enough to warrant its own article. delete. bearcat 20:51, 6 feb 2005 (utc) *...and yes, it's nowhere as notable as any other member of :category:football songs and chants. delete. samaritan 23:10, 6 feb 2005 (utc) *delete. non-notable rjfjr 00:02, feb 7, 2005 (utc) DOCS doc1: victoria beckhams been causing something of a stir on her victoria beckham beauty instagram account of late, sharing a series of throwback spice girls photos and videos with the hashtag... doc2: former spice girl admits she never sang live. victoria beckham, aka posh spice, has admitted miming all her live performances with the spice girls - shocker. beckhams rel



----------------------------	 QUERY 2	 ----------------------------

Src query:
 TEXT the blood orange is a variety of orange (citrus sinensis) with crimson, blood-colored flesh. the fruit is smaller than an average orange; its skin is usually pitted, but can be smooth. the juice is sweet but somewhat bitter and less acidic than regular table oranges.blood orange origin the distinctive dark flesh color is due to the presence of anthocyanin, a pigment common to many flowers and fruit, but uncommon in citrus fruits. sometimes there is dark coloring on the exterior of the rind as well, depending on the variety of blood orange. the degree of coloration depends on to light, temperature and variety. the blood orange is a hybrid of ancient origin, possibly between the pomelo and the tangerine. it probably originated in sicily. DOCS doc1: blood oranges are highly acidic, and regular consumption can cause problems for those with acid refluxdisease. this can lead to regurgitation or heartburn(



----------------------------	 QUERY 4	 ----------------------------

Src query:
 TEXT apur sansar (, yr. the world of apu), also known as the world of apu, is the third and final part of the apu trilogy, about a boy named apu in early twentieth century bengal, directed by satyajit ray. released in 1959, the world of apu focuses on apu's adult life, and also introduces the actors soumitra chatterjee and sharmila tagore, who would go on to appear in many subsequent ray films. the film is based on the novel aparajito by bibhutibhushan bandopadhyay. the film won the national film award for best film and several international awards, including the sutherland award for best original and imaginative film and national board of review award for best foreign language film. it has frequently been listed among the greatest films of all time. DOCS doc1: the world of apu original title: apur sansar 1959 not rated 1 h 45 m imdb rating 8.5 /10 15k your rating rate play trailer 2:03 1 video 52 photos