<a href="https://colab.research.google.com/github/serveer/Sentence_Simplification/blob/main/Project_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install the required packages

In [None]:
%%capture
!pip install transformers datasets evaluate rouge_score sacrebleu sacremoses wordfreq levenshtein

### Import and pre-process the split data

In [None]:
# mount to google drive
from google.colab import drive
drive.mount('/content/drive')

# unzip file
import tarfile
# The context manager closes the archive even if extraction fails part-way.
# NOTE(review): extractall() trusts the member paths stored in the archive;
# only run this on an archive you control.
with tarfile.open('/content/drive/Shareddrives/CIS 530 Project/Data/wikipedia_v2_splitted.tar.gz', "r:gz") as tar:
    tar.extractall()

# 1)load train/dev/test data
# 2) split the X and y
# 3.1) convert to lower case
# 3.2) drop the -LRB-/-RRB- parenthesis tokens and re-attach detached punctuation and clitics
# 3.3) remove the \n at the end of y
import pandas as pd
import numpy as np
# Ordered (old, new) substitutions applied to each lower-cased sentence.
# Order matters: every rule runs on the output of the previous one.
_TOKEN_REPLACEMENTS = (
    ("-lrb-", ""),
    ("-rrb-", ""),
    (" â", ""),
    (" ``", ""),
    (" .", "."),
    (" ?", "?"),
    (" !", "!"),
    (" ,", ","),
    (" ' ", "'"),
    (" n't", "n't"),
    (" 'm", "'m"),
    (" 's", "'s"),
    (" 've", "'ve"),
    (" 're", "'re"),
)


def _normalize_sentence(sentence):
    """Lower-case `sentence` and apply every rule in `_TOKEN_REPLACEMENTS` in order."""
    sentence = sentence.lower()
    for old, new in _TOKEN_REPLACEMENTS:
        sentence = sentence.replace(old, new)
    return sentence


def preprocess_data(lines, max_tokens=103):
    """Split tab-separated lines into normalized (complex, simple) sentence lists.

    Each line must contain exactly one tab: the complex sentence on the left
    and the simple sentence on the right. Both sides are lower-cased and
    cleaned with the same replacement rules; the right side is additionally
    stripped of surrounding whitespace (this removes the trailing newline).

    Args:
        lines: iterable of raw "complex<TAB>simple" strings.
        max_tokens: a pair is kept only if both cleaned sentences have at most
            this many space-separated tokens.

    Returns:
        (X_out, y_out): two parallel lists of cleaned sentences.

    Raises:
        ValueError: if a line does not contain exactly one tab.
    """
    X_out = []
    y_out = []
    for line in lines:
        X, y = line.split('\t')
        X = _normalize_sentence(X)
        y = _normalize_sentence(y.strip())
        if len(X.split(' ')) <= max_tokens and len(y.split(' ')) <= max_tokens:
            X_out.append(X)
            y_out.append(y)

    return X_out, y_out

def _read_lines(path):
    """Return every line of the text file at `path`, newlines included."""
    with open(path) as handle:
        return handle.readlines()


# Each split is read, cleaned and filtered the same way; the printed size is
# the number of sentence pairs that survived the length filter.
lines_train = _read_lines('train.txt')
X_train, y_train = preprocess_data(lines_train)
print('Train size: {}'.format(len(X_train)))

lines_dev = _read_lines('dev.txt')
X_dev, y_dev = preprocess_data(lines_dev)
print('Dev size: {}'.format(len(X_dev)))

lines_test = _read_lines('test.txt')
X_test, y_test = preprocess_data(lines_test)
print('Test size: {}'.format(len(X_test)))

Mounted at /content/drive
Train size: 134027
Dev size: 16758
Test size: 16754


In [None]:
def generate_stats(X, y):
    """Print length and vocabulary statistics for a parallel corpus.

    The two "average length" figures are mean character counts per sentence.
    Token counts come from splitting every sentence of `X + y` on single
    spaces.

    Args:
        X: list of input sentences (must be non-empty).
        y: list of output sentences (must be non-empty).
    """
    mean_input_chars = sum(map(len, X)) / len(X)
    mean_output_chars = sum(map(len, y)) / len(y)

    tokens = []
    for sentence in X + y:
        tokens.extend(sentence.split(' '))

    summary = 'average input length: {:.2f}, average output length: {:.2f}, unique token: {}, total token: {}'
    print(summary.format(mean_input_chars, mean_output_chars, len(set(tokens)), len(tokens)))

In [None]:
# randomly sample data from training, development, and test set
train_size = 4000
dev_size = 500
test_size = 500

def random_sampling(X, y, size, seed=42):
    """Draw `size` aligned pairs from `X` and `y` without replacement.

    A private `random.Random(seed)` generator is used, so the draw is
    reproducible without reseeding the process-wide `random` module.

    Args:
        X: list of inputs.
        y: list of targets, index-aligned with `X`.
        size: number of pairs to draw; must not exceed `len(X)`.
        seed: seed for the private generator.

    Returns:
        (X_sample, y_sample): two lists of length `size`, in sampled order.

    Raises:
        ValueError: if `size` is negative or larger than `len(X)`.
    """
    import random
    rng = random.Random(seed)
    idx = rng.sample(range(len(X)), size)
    return [X[i] for i in idx], [y[i] for i in idx]

# Shrink each split to the configured sample size.
# NOTE: the X_*/y_* names are rebound here, so re-running this cell samples
# from the already-sampled lists instead of the full splits loaded earlier.
X_train, y_train = random_sampling(X_train, y_train, train_size)
X_dev, y_dev = random_sampling(X_dev, y_dev, dev_size)
X_test, y_test = random_sampling(X_test, y_test, test_size)
# Statistics are reported for the sampled training pairs only.
generate_stats(X_train, y_train)

average input length: 130.57, average output length: 114.73, unique token: 22973, total token: 169797


### Simple Baseline

In [None]:
def print_metrics(metrics):
    """Print one line per metric in a `{name: value}` mapping.

    The "precisions" entry is expanded into one "<n>-gram precision" line per
    element. Float values, including float subclasses such as `numpy.float64`,
    are shown with two decimals; every other value is printed as-is.

    Args:
        metrics: mapping from metric name to value.
    """
    for name, value in metrics.items():
        if name == "precisions":
            for order, precision in enumerate(value, start=1):
                print("{}-gram precision: {:.2f}".format(order, precision))
        elif isinstance(value, float):
            print("{}: {:.2f}".format(name, value))
        else:
            print("{}: {}".format(name, value))

In [None]:
import pandas as pd
import evaluate
import re
from collections import defaultdict

lexicon_df = pd.read_csv('/content/drive/Shareddrives/CIS 530 Project/Data/words_substitute.csv',
                         header=None, names=['Complex', 'Simple'], encoding='unicode_escape')

# Map each complex word to its first listed substitute (the text before the
# first comma) with any (...) or [...] annotation removed. When a complex
# word appears more than once, the last row wins.
lexicon = {
    complex_word: re.sub(r"[\(\[].*?[\)\]]", "", str(simple).split(',')[0])
    for complex_word, simple in zip(lexicon_df['Complex'], lexicon_df['Simple'])
}

# Wikipedia Dataset
# Swap every space-separated token that has a lexicon entry. str.replace
# returns a new string, so the substituted sentence has to be stored; the
# previous loop discarded it and scored the untouched inputs.
predictions = [
    ' '.join(lexicon.get(word, word) for word in sent.split(' '))
    for sent in X_test
]
references = [[sent] for sent in y_test]

bleu = evaluate.load("bleu")
bleu_score = bleu.compute(predictions=predictions, references=references)
print_metrics(bleu_score)

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

bleu: 0.59
1-gram precision: 0.74
2-gram precision: 0.62
3-gram precision: 0.55
4-gram precision: 0.50
brevity_penalty: 1.00
length_ratio: 1.11
translation_length: 12225
reference_length: 11024


In [None]:
# Human Simplification with Sentence Fusion Dataset
human_sim = pd.read_csv('/content/drive/Shareddrives/CIS 530 Project/Data/human_simplification/test.csv')
human_sim = human_sim[human_sim.Source=='Human Simplification']

# Group every human simplification under its original sentence so each
# source ends up with a list of references.
test_hs = defaultdict(list)
for x_sent, y_sent in zip(human_sim.Original.to_list(), human_sim.Simplified.to_list()):
    test_hs[x_sent].append(y_sent)
X_test_hs = list(test_hs.keys())

# Swap every space-separated token that has a lexicon entry. str.replace
# returns a new string, so the substituted sentence has to be stored; the
# previous loop discarded it and scored the untouched inputs.
predictions = [
    ' '.join(lexicon.get(word, word) for word in sent.split(' '))
    for sent in X_test_hs
]
references = list(test_hs.values())

sari = evaluate.load("sari")
sari_score = sari.compute(sources=X_test_hs, predictions=predictions, references=references)
print_metrics(sari_score)

Downloading builder script:   0%|          | 0.00/12.1k [00:00<?, ?B/s]

sari: 54.85


### Complex Word Identification model

In [None]:
import torch
import nltk
nltk.download('brown')
nltk.download('averaged_perceptron_tagger')
from nltk.corpus import brown
from nltk import pos_tag
from keras_preprocessing.sequence import pad_sequences
from wordfreq import zipf_frequency

class cwi(object):
    """Complex-word identification followed by BERT-based lexical substitution.

    A sequence-labelling model flags the words of a sentence, a masked
    language model proposes replacements for the flagged words, and the
    proposal with the highest Zipf frequency is substituted into the text.
    """

    def __init__(self, model_cwi, tokenizer_bert, model_bert, word2index, device):
        """Store the models and lookup tables used by the other methods.

        Args:
            model_cwi: model exposing `predict(padded_ids)`; its output is
                reduced with `argmax(axis=2)`.
            tokenizer_bert: tokenizer paired with `model_bert`.
            model_bert: masked language model used to propose substitutes.
            word2index: mapping from word to the integer id fed to `model_cwi`.
            device: torch device on which BERT inputs are created.
        """
        self.model_cwi = model_cwi
        self.tokenizer_bert = tokenizer_bert
        self.model_bert = model_bert
        self.word2index = word2index
        self.device = device
        # Only adjectives, adverbs and verbs are eligible for substitution.
        self.POS_TAG = ['JJ', 'JJR', 'JJS',
                        'RB', 'RBR', 'RBS',
                        'VB', 'VBD','VBG', 'VBN', 'VBP', 'VBZ']

    @staticmethod
    def cleaner(word):
        """Remove URL-like spans, turn every non-letter into a space, lower-case and strip."""
        word = re.sub(r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*',
                        '', word, flags=re.MULTILINE)
        word = re.sub(r'[\W]', ' ', word)
        word = re.sub(r'[^a-zA-Z]', ' ', word)
        return word.lower().strip()

    def process_input(self, input_text):
        """Encode the cleaned text as a padded id sequence for the CWI model.

        Returns:
            (input_padded, index_list, n_known): the ids padded to length 103,
            the positions (in the cleaned, whitespace-split text) of words
            missing from `word2index`, and the number of known words.
        """
        input_text = self.cleaner(input_text)
        clean_text = []
        index_list = []
        input_token = []
        for i, word in enumerate(input_text.split()):
            if word in self.word2index:
                clean_text.append(word)
                input_token.append(self.word2index[word])
            else:
                index_list.append(i)
        input_padded = pad_sequences(maxlen=103, sequences=[input_token], padding="post", value=0)
        return input_padded, index_list, len(clean_text)

    @staticmethod
    def complete_missing_word(pred_binary, index_list, len_list):
        """Re-insert a 0 label at each out-of-vocabulary position.

        Args:
            pred_binary: per-token labels; only the first row is read.
            index_list: ascending positions of the words that were skipped.
            len_list: number of leading labels in the row that are real.

        Returns:
            A list of labels, one per word of the cleaned text.
        """
        list_cwi_predictions = list(pred_binary[0][:len_list])
        for i in index_list:
            list_cwi_predictions.insert(i, 0)
        return list_cwi_predictions

    def get_bert_candidates(self, input_text, list_cwi_predictions, num_predictions_displayed=10):
        """Ask BERT for substitutes of every flagged adjective, adverb or verb.

        Returns:
            A list of `(word, candidate_tokens)` tuples, one per word that was
            flagged and carries a tag listed in `POS_TAG`.
        """
        list_candidates_bert = []
        # NOTE(review): the labels were produced for the *cleaned* text, while
        # the words here come from the raw text; if cleaning changes the token
        # count (e.g. a hyphenated word), the pairing below is shifted. Confirm.
        for word, pred in zip(input_text.split(), list_cwi_predictions):
            if pred and (pos_tag([word])[0][1] in self.POS_TAG):
                # Every occurrence of `word` (as a substring) becomes [MASK];
                # only the first mask position is scored below.
                replace_word_mask = input_text.replace(word, '[MASK]')
                text = f'[CLS]{replace_word_mask} [SEP] {input_text} [SEP] '
                tokenized_text = self.tokenizer_bert.tokenize(text)
                masked_index = [i for i, x in enumerate(tokenized_text) if x=='[MASK]'][0]
                indexed_tokens = self.tokenizer_bert.convert_tokens_to_ids(tokenized_text)
                segments_ids = [0]*len(tokenized_text)
                tokens_tensor = torch.tensor([indexed_tokens], device=self.device)
                segments_tensors = torch.tensor([segments_ids], device=self.device)
                # Predict all tokens
                with torch.no_grad():
                    outputs = self.model_bert(tokens_tensor, token_type_ids=segments_tensors)
                    predictions = outputs[0][0][masked_index]
                predicted_ids = torch.argsort(predictions, descending=True)[:num_predictions_displayed]
                predicted_tokens = self.tokenizer_bert.convert_ids_to_tokens(list(predicted_ids))
                list_candidates_bert.append((word, predicted_tokens))
        return list_candidates_bert

    def replace(self, input_text):
        """Return `input_text` with each flagged word swapped for its most frequent candidate.

        Only purely alphabetic candidates are considered. The candidate with
        the highest Zipf frequency wins, and the first one wins on ties.
        """
        new_text = input_text
        input_padded, index_list, len_list = self.process_input(input_text)
        pred_cwi = self.model_cwi.predict(input_padded)
        pred_cwi_binary = np.argmax(pred_cwi, axis = 2)
        complete_cwi_predictions = self.complete_missing_word(pred_cwi_binary, index_list, len_list)
        bert_candidates = self.get_bert_candidates(input_text, complete_cwi_predictions)
        for word_to_replace, l_candidates in bert_candidates:
            tuples_word_zipf = [(w, zipf_frequency(w, 'en')) for w in l_candidates if w.isalpha()]
            if tuples_word_zipf:
                best_word = max(tuples_word_zipf, key=lambda pair: pair[1])[0]
                # Literal replacement: the raw token may contain regex
                # metacharacters (".", "+", "?", ...), so it must not be used
                # as a pattern.
                new_text = new_text.replace(word_to_replace, best_word)
        return new_text

def extend_lists(lists):
    """Flatten one level of nesting: concatenate the given iterables into one list."""
    return [item for inner in lists for item in inner]

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


### Extract and append feature tokens

In [None]:
import Levenshtein

def round_point5(num):
    """Round `num` to the nearest multiple of 0.05."""
    # Rounding the doubled value to one decimal gives a 0.05 grid after halving.
    doubled = round(num * 2, 1)
    return doubled / 2

def get_sim(sent, target):
    """Return the Levenshtein ratio between `sent` and `target`, bucketed by `round_point5`."""
    similarity = Levenshtein.ratio(sent, target)
    return round_point5(similarity)

def get_len_ratio(sent, target):
    """Return the character-length ratio target/source, bucketed by `round_point5`.

    Raises:
        ZeroDivisionError: if `sent` is empty.
    """
    ratio = len(target) / len(sent)
    return round_point5(ratio)

def add_parameter(sent, training, target=None, len_ratio=0, sim=0):
    """Prefix `sent` with two control tokens: "<similarity> <length ratio> <sent>".

    Args:
        sent: source sentence.
        training: when true, both tokens are computed from (`sent`, `target`);
            otherwise the supplied `sim` and `len_ratio` are used verbatim.
        target: reference sentence; only read when `training` is true.
        len_ratio: length-ratio token used when `training` is false.
        sim: similarity token used when `training` is false.

    Returns:
        The sentence with the two space-separated tokens prepended.
    """
    if training:
        sim_token = get_sim(sent, target)
        ratio_token = get_len_ratio(sent, target)
    else:
        sim_token, ratio_token = sim, len_ratio
    return '{} {} {}'.format(sim_token, ratio_token, sent)

### Convert complex words to simple words

In [None]:
%%capture
# import complex word identification model
from keras.models import load_model
model_cwi = load_model("/content/drive/Shareddrives/CIS 530 Project/Data/model_CWI.h5")

# download bert model
from transformers import BertTokenizer, BertForMaskedLM
tokenizer_bert = BertTokenizer.from_pretrained('bert-large-uncased')
model_bert = BertForMaskedLM.from_pretrained('bert-large-uncased')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_bert.to(device)
model_bert.eval()

# instantiate the complex word identification model
import json
with open('/content/drive/Shareddrives/CIS 530 Project/Data/word2index.json') as f:
    word2index = json.load(f)
identifier = cwi(model_cwi, tokenizer_bert, model_bert, word2index, device)

# replace complex words in the complex sentences
# NOTE: the X_* names are rebound in place, so this cell is not idempotent;
# re-running it substitutes and prefixes the sentences a second time.
X_train = [identifier.replace(sent) for sent in X_train]
X_dev = [identifier.replace(sent) for sent in X_dev]
X_test = [identifier.replace(sent) for sent in X_test]
X_test_hs = [identifier.replace(sent) for sent in X_test_hs]

# extract and add feature tokens
# Training pairs are zipped directly rather than indexed with `train_size`,
# which the training cell later rebinds to a number of batches.
X_train = [add_parameter(sent, training=True, target=target) for sent, target in zip(X_train, y_train)]
X_dev = [add_parameter(sent, training=False, len_ratio=0.95, sim=0.85) for sent in X_dev]
X_test = [add_parameter(sent, training=False, len_ratio=0.95, sim=0.85) for sent in X_test]
X_test_hs = [add_parameter(sent, training=False, len_ratio=0.95, sim=0.85) for sent in X_test_hs]

# create dataframe and save to csv
train_df = pd.DataFrame({'complex': X_train, 'simple': y_train})
dev_df = pd.DataFrame({'complex': X_dev, 'simple': y_dev})
test_df = pd.DataFrame({'complex': X_test, 'simple': y_test})
# `references` is the per-source list of human simplifications built in the
# human-simplification baseline cell.
test_hs_df = pd.DataFrame({'complex': X_test_hs, 'simple': references})
train_df.to_csv('train.csv', index=False)
dev_df.to_csv('dev.csv', index=False)
test_df.to_csv('test.csv', index=False)
test_hs_df.to_csv('test_hs.csv', index=False)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
n = 2  # zero-based index of the training pair to display (n = 2 shows the third pair)
banner = '-------------------------Train Sample {}-------------------------'
print(banner.format(n))
for label, text in (('Cleaned input: {}', X_train[n]), ('Cleaned output: {}', y_train[n])):
    print(label.format(text))

-------------------------Train Sample 2-------------------------
Cleaned input: 0.6 0.7 final fantasy v is a computer role-playing video game developed and published by square  now square enix  in 1992 as a part of the final fantasy series.
Cleaned output: final fantasy v is a fantasy role-playing video game. it was made by squaresoft, now called square enix.


### Fine-tuning

In [None]:
nltk.download('punkt')
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, \
                         Seq2SeqTrainingArguments, Seq2SeqTrainer, \
                         DataCollatorForSeq2Seq, AdamW, get_scheduler, \
                         SchedulerType
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
def tokenize_function(examples):
    """Tokenize a batch of complex/simple pairs for seq2seq training.

    Reads the module-level `tokenizer` and `ignore_pad_token_for_loss`.
    Sources and targets are both padded and truncated to 128 tokens.

    Args:
        examples: mapping with "complex" (sources) and "simple" (targets).

    Returns:
        The source encoding with an added "labels" entry holding the target
        ids; pad ids are replaced by -100 when `ignore_pad_token_for_loss` is
        set, so that padding is ignored by the loss.
    """
    source_texts = examples["complex"]
    target_texts = examples["simple"]

    model_inputs = tokenizer(source_texts, max_length=128, padding="max_length", truncation=True)
    target_encoding = tokenizer(text_target=target_texts, max_length=128, padding="max_length", truncation=True)

    label_ids = target_encoding["input_ids"]
    if ignore_pad_token_for_loss:
        pad_id = tokenizer.pad_token_id
        label_ids = [
            [-100 if token_id == pad_id else token_id for token_id in sequence]
            for sequence in label_ids
        ]

    model_inputs["labels"] = label_ids
    return model_inputs

def postprocess_text(preds, labels):
    """Strip each text and put every sentence on its own line.

    Sentences are detected with `nltk.sent_tokenize` and re-joined with
    newlines (the layout rougeLSum expects).

    Returns:
        (preds, labels): two new lists of reformatted strings.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    formatted_preds = [_one_sentence_per_line(pred) for pred in preds]
    formatted_labels = [_one_sentence_per_line(label) for label in labels]
    return formatted_preds, formatted_labels

In [None]:
# Run configuration
ignore_pad_token_for_loss = True
num_epochs = 3
batch_size = 16

# Load the four CSV splits written earlier and tokenize them in batches.
data_files = {"train": "train.csv", "dev": "dev.csv", "test": "test.csv", "test_hs": "test_hs.csv"}
raw_dataset = load_dataset("csv", data_files=data_files)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

tokenized_dataset = (
    raw_dataset
    .map(tokenize_function, batched=True)
    .remove_columns(["complex", "simple"])
)
tokenized_dataset.set_format("torch")

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


def _build_loader(split_name, **loader_kwargs):
    """Wrap one tokenized split in a DataLoader with the shared batch size and collator."""
    return DataLoader(tokenized_dataset[split_name],
                      batch_size=batch_size,
                      collate_fn=data_collator,
                      **loader_kwargs)


# Only the training split is shuffled; evaluation splits keep file order.
train_dataloader = _build_loader("train", shuffle=True)
dev_dataloader = _build_loader("dev")
test_dataloader = _build_loader("test")
test_hs_dataloader = _build_loader("test_hs")



  0%|          | 0/4 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?ba/s]



In [None]:
# Batch counts get their own names so the sample-size constants
# (train_size / dev_size / test_size) set earlier are not overwritten.
num_train_batches = len(train_dataloader)
num_dev_batches = len(dev_dataloader)
num_test_batches = len(test_dataloader)
test_hs_size = len(test_hs_dataloader)

optimizer = AdamW(model.parameters(), lr=1e-4)
num_training_steps = num_epochs * num_train_batches
lr_scheduler = get_scheduler("cosine_with_restarts",
                             optimizer=optimizer,
                             num_warmup_steps=0,
                             num_training_steps=num_training_steps)

model.to(device)


def _generate_and_decode(batch):
    """Beam-search decode one batch with `model`; return (predictions, references) as text.

    Label ids equal to -100 are mapped back to the pad id before decoding,
    because -100 cannot be decoded. Both lists go through `postprocess_text`.
    """
    with torch.no_grad():
        preds = model.generate(batch["input_ids"], num_beams=5, min_length=0, max_length=128)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # replace -100 in the labels as we can't decode them.
        labels = torch.where(batch["labels"] != -100, batch["labels"], tokenizer.pad_token_id)
    else:
        labels = batch["labels"]
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # some simple post-processing
    return postprocess_text(decoded_preds, decoded_labels)


best_evaluation_loss = float('inf')
progress_bar = tqdm(range(num_training_steps))
model.train()
for e in tqdm(range(num_epochs)):
    # training
    training_loss = 0
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        training_loss += loss.item()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
    training_loss /= num_train_batches

    # evaluation
    print()
    print("======== DEV SET, epoch {} ========".format(e+1))
    print()
    metric_dev = {"training_loss": training_loss}
    metric_bleu = evaluate.load("bleu")
    model.eval()
    evaluation_loss = 0
    for batch in dev_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # The forward pass is for reporting only, so no autograd graph is built.
        with torch.no_grad():
            evaluation_loss += model(**batch).loss.item()
        decoded_preds, decoded_labels = _generate_and_decode(batch)
        metric_bleu.add_batch(predictions=decoded_preds, references=decoded_labels)

    evaluation_loss /= num_dev_batches
    # Checkpoint whenever the mean dev loss improves.
    if evaluation_loss < best_evaluation_loss:
        model.save_pretrained("bart-ext-2")
        best_evaluation_loss = evaluation_loss
    metric_dev["evaluation_loss"] = evaluation_loss
    metric_dev.update(metric_bleu.compute())
    print_metrics(metric_dev)

    model.train()

print()
print("======== TEST SET ========")
print()
# NOTE: the test set is scored with the weights from the final epoch, which
# may differ from the best checkpoint saved above.
metric_test = {}
metric_bleu = evaluate.load("bleu")
model.eval()
test_sents = []
for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    decoded_preds, decoded_labels = _generate_and_decode(batch)
    test_sents.append([decoded_preds, decoded_labels])
    metric_bleu.add_batch(predictions=decoded_preds, references=decoded_labels)

metric_test.update(metric_bleu.compute())
print_metrics(metric_test)



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-0e81148f372263cb/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating dev split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-0e81148f372263cb/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]



  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.




training_loss: 1.228
evaluation_loss: 0.996
bleu: 0.592
1-gram precision: 0.743
2-gram precision: 0.620
3-gram precision: 0.546
4-gram precision: 0.488
brevity_penalty: 1.000
length_ratio: 1.066
translation_length: 11570
reference_length: 10851


training_loss: 0.847
evaluation_loss: 0.949
bleu: 0.602
1-gram precision: 0.754
2-gram precision: 0.631
3-gram precision: 0.555
4-gram precision: 0.496
brevity_penalty: 1.000
length_ratio: 1.048
translation_length: 11370
reference_length: 10851


training_loss: 0.651
evaluation_loss: 0.972
bleu: 0.593
1-gram precision: 0.748
2-gram precision: 0.623
3-gram precision: 0.547
4-gram precision: 0.486
brevity_penalty: 1.000
length_ratio: 1.054
translation_length: 11435
reference_length: 10851


bleu: 0.584
1-gram precision: 0.753
2-gram precision: 0.617
3-gram precision: 0.533
4-gram precision: 0.469
brevity_penalty: 1.000
length_ratio: 1.026
translation_length: 11307
reference_length: 11024


### SARI Test Results

In [None]:
def _load_seq2seq_checkpoint(checkpoint_dir):
    """Load a saved seq2seq checkpoint and move it to `device`."""
    loaded = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)
    loaded.to(device)
    return loaded


model_bart_ext1 = _load_seq2seq_checkpoint("/content/drive/Shareddrives/CIS 530 Project/Data/bart-ext-1")
print("Extension 1 model successfully loaded!")
model_bart_ext2 = _load_seq2seq_checkpoint("/content/drive/Shareddrives/CIS 530 Project/Data/bart-ext-2")
print("Extension 2 model successfully loaded!")

Extension 1 model successfully loaded!
Extension 2 model successfully loaded!


In [None]:
# Get SARI score of Extension 1 model on the Human Simplification test set
metric = {}
metric_sari = evaluate.load("sari")
model_bart_ext1.eval()
for i, batch in enumerate(test_hs_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        preds = model_bart_ext1.generate(batch["input_ids"], num_beams=5, min_length=0, max_length=128)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Some simple post-processing. The batch labels are not decoded here:
    # SARI takes its references from `references`, not from the batch.
    decoded_preds = postprocess_text(decoded_preds, [])[0]

    # The loader is not shuffled, so batch i covers this slice of the
    # parallel `X_test_hs` / `references` lists.
    # NOTE(review): `X_test_hs` was rewritten by the CWI-replacement and
    # feature-token cell, so these sources carry the leading control tokens
    # rather than the raw original sentences. Confirm this is intended.
    start, stop = batch_size * i, batch_size * (i + 1)
    metric_sari.add_batch(sources=X_test_hs[start:stop], predictions=decoded_preds, references=references[start:stop])

metric.update(metric_sari.compute())
print_metrics(metric)

sari: 50.07


In [None]:
# Get SARI score of Extension 2 model on the Human Simplification test set
metric = {}
metric_sari = evaluate.load("sari")
model_bart_ext2.eval()
for i, batch in enumerate(test_hs_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        preds = model_bart_ext2.generate(batch["input_ids"], num_beams=5, min_length=0, max_length=128)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Some simple post-processing. The batch labels are not decoded here:
    # SARI takes its references from `references`, not from the batch.
    decoded_preds = postprocess_text(decoded_preds, [])[0]

    # The loader is not shuffled, so batch i covers this slice of the
    # parallel `X_test_hs` / `references` lists.
    # NOTE(review): `X_test_hs` was rewritten by the CWI-replacement and
    # feature-token cell, so these sources carry the leading control tokens
    # rather than the raw original sentences. Confirm this is intended.
    start, stop = batch_size * i, batch_size * (i + 1)
    metric_sari.add_batch(sources=X_test_hs[start:stop], predictions=decoded_preds, references=references[start:stop])

metric.update(metric_sari.compute())
print_metrics(metric)

sari: 51.56


### Error Analysis

In [None]:
# Decode the Extension 2 model's predictions and the matching references for
# every test pair so they can be inspected side by side.
test_preds = []
test_labels = []
for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        preds = model_bart_ext2.generate(batch["input_ids"], num_beams=5, min_length=0, max_length=128)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = torch.where(batch["labels"]!=-100, batch["labels"], tokenizer.pad_token_id)
    else:
        # Without the -100 masking the label ids are decodable as they are.
        labels = batch["labels"]
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    test_preds.extend(decoded_preds)
    test_labels.extend(decoded_labels)

print("{} test samples predicted.".format(len(test_preds)))

500 test samples predicted.


In [None]:
n = 0  # zero-based index of the test pair to display
for template, text in (("Prediction: {}", test_preds[n]), ("Reference: {}", test_labels[n])):
    print(template.format(text))

Prediction: the straits of georgia and juan de fuca are now all also part of the salish sea, which includes puget sound as well.
Reference: you can go there by boat from vancouver  across the strait of georgia  or from the state of washington  across the strait of juan de fuca.


In [None]:
# Persist the aligned prediction/reference lists for offline error analysis.
error_analysis = {"prediction": test_preds, "reference": test_labels}
with open("error_analysis.json", "w") as f:
    json.dump(error_analysis, f)