## Install Requirements

In [None]:
# Python dependencies for preprocessing, training (transformers + pytorch_lightning) and evaluation.
# NOTE(review): versions are unpinned, so a future run may resolve different releases; consider
# pinning them (and using %pip so the kernel's environment is targeted).
# NOTE(review): `from summarizer import Summarizer` (BART cell) may come from the
# bert-extractive-summarizer distribution rather than `summarizer` — verify.
!pip install bert_score click keybert matplotlib nltk numpy optuna-integration pandas plotly python_Levenshtein pytorch_lightning rouge sacrebleu sacremoses spacy scikit_learn simalign stanfordnlp summarizer torch torchfile tqdm transformers yattag

Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m814.5 kB/s[0m eta [36m0:00:00[0m
Collecting keybert
  Downloading keybert-0.8.4.tar.gz (29 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting optuna-integration
  Downloading optuna_integration-3.6.0-py3-none-any.whl (93 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.4/93.4 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Collecting python_Levenshtein
  Downloading python_Levenshtein-0.25.1-py3-none-any.whl (9.4 kB)
Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.2-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.9/801.9 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Collecting sacrebleu
  Downloading sacrebleu-2.4.2-py3-none-any.whl (106 kB)
[2K   

## EASSE Library

In [None]:
# EASSE: text-simplification evaluation library (provides easse.sari.corpus_sari used during validation).
!git clone https://github.com/feralvam/easse.git

Cloning into 'easse'...
remote: Enumerating objects: 1964, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 1964 (delta 118), reused 104 (delta 104), pack-reused 1819[K
Receiving objects: 100% (1964/1964), 33.15 MiB | 23.72 MiB/s, done.
Resolving deltas: 100% (1231/1231), done.


In [None]:
# Enter the clone so the editable install below targets it.
%cd /content/easse

/content/easse


In [None]:
# Install EASSE (and its dependencies) in editable mode from the current directory.
!pip install -e .

Obtaining file:///content/easse
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting tseval@ git+https://github.com/facebookresearch/text-simplification-evaluation.git@main (from easse==0.2.4)
  Cloning https://github.com/facebookresearch/text-simplification-evaluation.git (to revision main) to /tmp/pip-install-j7f2fxu2/tseval_75031b0f9eb441c695c8f3c2c89678f5
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/text-simplification-evaluation.git /tmp/pip-install-j7f2fxu2/tseval_75031b0f9eb441c695c8f3c2c89678f5
  Resolved https://github.com/facebookresearch/text-simplification-evaluation.git to commit dea8863683ea5946fd50184883c9be7a7339e821
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gitpython (from tseval@ git+https://github.com/facebookresearch/text-simplification-evaluation.git@main->easse==0.2.4)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [None]:
# Switch back to the Colab root directory.
%cd /content/

/content


In [None]:
import shutil
import os

# Flatten the EASSE checkout: move everything inside /content/easse/easse up into
# /content/easse, then delete the emptied inner directory.
# NOTE(review): presumably so `import easse...` resolves to /content/easse when the
# working directory is /content — confirm.
# Re-running is a no-op once the inner directory is gone (guarded by the exists() check).
root_dir = '/content/'
out_dir = os.path.join(root_dir, 'easse')
inner_easse_dir = os.path.join(root_dir, 'easse', 'easse')

if os.path.exists(out_dir):

    # Move the contents of the inner easse folder back to 'easse'
    if os.path.exists(inner_easse_dir):
        inner_easse_files = os.listdir(inner_easse_dir)
        for file in inner_easse_files:
            src = os.path.join(inner_easse_dir, file)
            dst = os.path.join(root_dir, 'easse', file)
            # shutil.move: if dst is an existing directory, src is moved *inside* it.
            shutil.move(src, dst)

        # os.rmdir only removes an empty directory (it raises OSError otherwise).
        os.rmdir(inner_easse_dir)

## Preprocessor

In [None]:
import optparse
import os
from pathlib import Path
import sys
from functools import lru_cache
from multiprocessing import Pool, Lock
from string import punctuation
import multiprocessing
import Levenshtein
import numpy as np
import spacy
import nltk
import shutil
import time
import pickle
import hashlib


nltk.download('stopwords')
from nltk.corpus import stopwords
import re
from sacremoses import MosesDetokenizer, MosesTokenizer


# Output locations. These are plain strings, not Path objects: wrap them with
# Path(...) before calling Path methods such as .mkdir() or .exists().
EXP_DIR = '/content/experiments'
DUMPS_DIR = '/content/dumps'

# Dataset name; its files live under /content/<dataset>/ (see get_data_filepath).
PSAT = 'PSAT'

LANGUAGES = ['complex', 'simple']
# Splits handled by Preprocessor.preprocess_dataset.
PHASES = ['train','valid']

# NOTE: this rebinds `stopwords` (the nltk.corpus module imported above) to the set of
# NLTK English stopwords; remove_stopwords() relies on the name holding this set.
stopwords = set(stopwords.words('english'))

#######################
@lru_cache(maxsize=None)
def get_tokenizer():
    """Return the shared English MosesTokenizer.

    Cached because tokenize() is called several times per sentence by the feature
    extractors; constructing a new MosesTokenizer on every call is needless work.
    """
    return MosesTokenizer(lang='en')

@lru_cache(maxsize=None)
def get_detokenizer():
    """Return the shared English MosesDetokenizer (built once, then reused)."""
    return MosesDetokenizer(lang='en')

def tokenize(sentence):
    """Split `sentence` into tokens using English Moses tokenization rules."""
    moses = get_tokenizer()
    return moses.tokenize(sentence)

def write_lines(lines, filepath):
    """Write every string in `lines` to `filepath`, one per line (newline-terminated).

    Missing parent directories are created; an existing file is overwritten.
    """
    target = Path(filepath)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w") as handle:
        handle.writelines(line + '\n' for line in lines)


def read_lines(filepath):
    """Return all lines of `filepath` as a list, trailing whitespace removed."""
    with Path(filepath).open('r') as handle:
        return [raw.rstrip() for raw in handle]


def yield_lines(filepath):
    """Lazily yield each line of `filepath` with trailing whitespace stripped."""
    with Path(filepath).open('r') as handle:
        yield from (raw.rstrip() for raw in handle)


def yield_sentence_pair_with_index(filepath1, filepath2):
    """Yield (line1, line2, index) from two parallel files.

    Lines are right-stripped, the index starts at 1, and iteration stops at the end of
    the shorter file.
    """
    with Path(filepath1).open('r') as left, Path(filepath2).open('r') as right:
        for position, (first, second) in enumerate(zip(left, right), start=1):
            yield first.rstrip(), second.rstrip(), position


def yield_sentence_pair(filepath1, filepath2):
    """Yield right-stripped (line1, line2) pairs until the shorter file is exhausted."""
    with Path(filepath1).open('r') as left, Path(filepath2).open('r') as right:
        yield from ((first.rstrip(), second.rstrip()) for first, second in zip(left, right))


def count_line(filepath):
    """Return the number of lines in `filepath` (a final line without '\\n' still counts)."""
    with Path(filepath).open("r") as handle:
        return sum(1 for _ in handle)


def load_dump(filepath):
    """Unpickle and return the object stored at `filepath`.

    The file is closed deterministically by the `with` block. The original version left
    the handle for the garbage collector to close.
    Security: unpickling can execute arbitrary code, so only load files this notebook wrote.
    """
    with open(filepath, 'rb') as fin:
        return pickle.load(fin)


def dump(obj, filepath):
    """Pickle `obj` to `filepath`, overwriting any existing file.

    The `with` block closes, and therefore flushes, the file before returning. Without
    it, a complete dump would depend on the garbage collector closing the handle.
    """
    with open(filepath, 'wb') as fout:
        pickle.dump(obj, fout)

def save_preprocessor(preprocessor):
    """Pickle `preprocessor` to DUMPS_DIR/preprocessor.pickle, creating DUMPS_DIR if needed."""
    # DUMPS_DIR is a plain string, so it must be wrapped in Path before calling mkdir().
    Path(DUMPS_DIR).mkdir(parents=True, exist_ok=True)
    PREPROCESSOR_DUMP_FILE = f'{DUMPS_DIR}/preprocessor.pickle'
    dump(preprocessor, PREPROCESSOR_DUMP_FILE)


def load_preprocessor():
    """Return the preprocessor pickled at DUMPS_DIR/preprocessor.pickle, or None if absent."""
    dump_file = f'{DUMPS_DIR}/preprocessor.pickle'
    if not os.path.exists(dump_file):
        return None
    return load_dump(dump_file)

def generate_hash(data):
    """Return the hex MD5 digest of str(data) (a cache key, not a security hash)."""
    return hashlib.md5(str(data).encode()).hexdigest()

def get_data_filepath(dataset, phase, type, i=None):
    """Return '/content/<dataset>/<dataset>.<phase>.<type>[.<i>]' as a string.

    `type` is the side of the pair ('complex' or 'simple'); `i`, when given (including 0),
    is appended as an extra '.<i>' suffix.
    """
    suffix = '' if i is None else f'.{i}'
    return f'/content/{dataset}/{dataset}.{phase}.{type}{suffix}'

#################################

def round(val):
    """Format `val` with two decimals and return it as a *string* (0.8 -> '0.80').

    NOTE: this shadows the built-in round() for the rest of the module; the control-token
    features depend on it returning a string.
    """
    two_decimals = '%.2f'
    return two_decimals % val


def safe_division(a, b):
    """Return a / b, or 0 when b is falsy (0, 0.0, None, ...)."""
    if not b:
        return 0
    return a / b


# def tokenize(sentence):
#     return sentence.split()

def is_punctuation(word):
    """True when every character of `word` is ASCII punctuation (also True for '')."""
    return all(char in punctuation for char in word)


def remove_punctuation(text):
    """Re-join the Moses tokens of `text` with single spaces, dropping all-punctuation tokens."""
    kept = [token for token in tokenize(text) if not is_punctuation(token)]
    return ' '.join(kept)


def remove_stopwords(text):
    """Re-join the Moses tokens of `text`, dropping those whose lowercase form is a stopword."""
    kept = []
    for token in tokenize(text):
        if token.lower() not in stopwords:
            kept.append(token)
    return ' '.join(kept)


def get_dependency_tree_depth(sentence):
    """Return the maximum dependency-tree height over the spaCy sentences of `sentence`.

    A root with no children has height 0; text with no sentences yields 0.
    """
    def tree_height(node):
        children = list(node.children)
        if not children:
            return 0
        return 1 + max(map(tree_height, children))

    parsed = spacy_process(sentence)
    return max((tree_height(sent.root) for sent in parsed.sents), default=0)

@lru_cache(maxsize=1)
def get_spacy_model():
    """Return the English spaCy pipeline 'en_core_web_sm', downloading it on first use.

    Cached: spacy.load() is expensive, and the dependency-depth features previously
    reloaded the whole pipeline for every sentence.
    The old `spacy.cli.link` call is dropped. spaCy v3 removed shortcut links (the command
    only warns), and the model is loaded by its package name anyway.
    """
    model = 'en_core_web_sm'
    if not spacy.util.is_package(model):
        spacy.cli.download(model)
    return spacy.load(model)


def spacy_process(text):
    """Run the English spaCy pipeline on str(text) and return the resulting Doc."""
    nlp = get_spacy_model()
    return nlp(str(text))


@lru_cache(maxsize=None)
def _load_word2rank_dump(model_filepath):
    """Unpickle the word -> rank mapping once per path; it is queried once per word."""
    return load_dump(model_filepath)


def get_word2rank(vocab_size=np.inf):
    """Return the word -> rank mapping pickled at DUMPS_DIR/<WORD_EMBEDDINGS_NAME>.pk.

    Returns None when the dump does not exist. `vocab_size` is accepted for backward
    compatibility but is not used.
    The original called .exists() on an f-string (AttributeError), so the path is now
    wrapped in Path. The loaded mapping is cached, so later changes to the file on disk
    are not picked up until _load_word2rank_dump.cache_clear() is called.
    NOTE(review): WORD_EMBEDDINGS_NAME is not defined in this chunk; confirm it is set
    before this is called.
    """
    model_filepath = Path(f'{DUMPS_DIR}/{WORD_EMBEDDINGS_NAME}.pk')
    if model_filepath.exists():
        return _load_word2rank_dump(str(model_filepath))
    return None


def get_normalized_rank(word):
    """Return log(1 + rank(word)) / log(1 + vocab_size); unknown words get rank = vocab_size.

    The mapping is fetched once instead of twice. The local `vocab_size` also no longer
    shadows the builtin max().
    """
    word2rank = get_word2rank()
    vocab_size = len(word2rank)
    rank = word2rank.get(word, vocab_size)
    return np.log(1 + rank) / np.log(1 + vocab_size)



def get_complexity_score2(sentence):
    """Mean normalized rank of the known, non-stopword, non-punctuation words of `sentence`.

    Returns 1.0 when no such word remains.
    """
    word2rank = get_word2rank()  # fetched once, not once per word
    words = tokenize(remove_stopwords(remove_punctuation(sentence)))
    words = [word for word in words if word in word2rank]  # remove unknown words
    if not words:
        return 1.0
    return np.array([get_normalized_rank(word) for word in words]).mean()

@lru_cache(maxsize=1)
def get_word_frequency():
    """Return the word -> frequency dict.

    The dict is built from WORD_FREQUENCY_FILEPATH (one "word count" pair per line,
    space-separated) and cached as a pickle under DUMPS_DIR.

    Fixes: DUMPS_DIR and the dump path are strings, so calling .mkdir() or .exists() on
    them raised AttributeError; both are now wrapped in Path. The result is cached in
    memory because get_normalized_inverse_frequency calls this once per word. Blank lines
    in the frequency file are skipped instead of raising IndexError.
    NOTE(review): WORD_FREQUENCY_FILEPATH (a Path, given the .stem access) is not defined
    in this chunk; confirm it is set before this is called.
    """
    model_filepath = Path(f'{DUMPS_DIR}/{WORD_FREQUENCY_FILEPATH.stem}.pk')
    if model_filepath.exists():
        return load_dump(model_filepath)

    Path(DUMPS_DIR).mkdir(parents=True, exist_ok=True)
    word_freq = {}
    for line in yield_lines(WORD_FREQUENCY_FILEPATH):
        if not line:
            continue
        chunks = line.split(' ')
        word_freq[chunks[0]] = int(chunks[1])
    dump(word_freq, model_filepath)
    return word_freq


def get_normalized_inverse_frequency(word):
    """Return 1 - log(1 + freq(word)) / log(1 + max_frequency); unknown words get freq 0."""
    # Largest count in the frequency table (per the original note, the count of "the").
    max_frequency = 153141437
    frequency = get_word_frequency().get(word, 0)
    return 1.0 - np.log(1 + frequency) / np.log(1 + max_frequency)


def get_complexity_score(sentence, operation_type = None):
    """Lexical complexity of `sentence` from normalized inverse word frequencies.

    Considers the non-stopword, non-punctuation tokens that appear in word2rank, and
    returns the mean of their scores when operation_type == 'mean', otherwise the max.
    Returns 1.0 when no word remains.
    NOTE(review): words are filtered by word2rank membership (case-sensitive) but scored
    by lowercase frequency; presumably intentional — confirm.
    """
    word2rank = get_word2rank()  # fetched once, not once per word
    words = tokenize(remove_stopwords(remove_punctuation(sentence)))
    words = [word for word in words if word in word2rank]  # remove unknown words
    if not words:
        return 1.0
    scores = np.array([get_normalized_inverse_frequency(word.lower()) for word in words])
    return scores.mean() if operation_type == 'mean' else scores.max()

class RatioFeature:
    """Base class for control-token features.

    A feature emits '<name>_<value>' tokens. During training the value comes from
    `feature_extractor(complex, simple)`; at inference the fixed `target_ratio` is used.
    The name is the initials of the CamelCase class name with 'RatioFeature' removed
    (e.g. DependencyTreeDepthRatioFeature -> 'DTD').
    """

    def __init__(self, feature_extractor, target_ratio=0.8):
        self.feature_extractor = feature_extractor
        self.target_ratio = target_ratio

    def encode_sentence(self, sentence):
        # Inference-time token: the requested ratio; the sentence itself is not inspected.
        return '{}_{}'.format(self.name, self.target_ratio)

    def encode_sentence_pair(self, complex_sentence, simple_sentence):
        measured = self.feature_extractor(complex_sentence, simple_sentence)
        return '{}_{}'.format(self.name, measured), simple_sentence

    def decode_sentence(self, encoded_sentence):
        return encoded_sentence

    @property
    def name(self):
        base = self.__class__.__name__.replace('RatioFeature', '')
        initials = ''.join(part[0] for part in re.findall('[A-Z][^A-Z]*', base) if part)
        return initials or base

### tokens features ###
class WordRatioFeature(RatioFeature):
    """Control token 'W_<ratio>': Moses token count of the simple sentence / of the complex one."""

    def __init__(self, *args, **kwargs):
        extractor = self.get_word_length_ratio
        super().__init__(extractor, *args, **kwargs)

    def get_word_length_ratio(self, complex_sentence, simple_sentence):
        simple_count = len(tokenize(simple_sentence))
        complex_count = len(tokenize(complex_sentence))
        return round(safe_division(simple_count, complex_count))


class CharRatioFeature(RatioFeature):
    """Control token 'C_<ratio>': character length of the simple sentence / of the complex one."""

    def __init__(self, *args, **kwargs):
        extractor = self.get_char_length_ratio
        super().__init__(extractor, *args, **kwargs)

    def get_char_length_ratio(self, complex_sentence, simple_sentence):
        ratio = safe_division(len(simple_sentence), len(complex_sentence))
        return round(ratio)


class LevenshteinRatioFeature(RatioFeature):
    """Control token 'L_<ratio>': Levenshtein similarity ratio between complex and simple."""

    def __init__(self, *args, **kwargs):
        extractor = self.get_levenshtein_ratio
        super().__init__(extractor, *args, **kwargs)

    def get_levenshtein_ratio(self, complex_sentence, simple_sentence):
        similarity = Levenshtein.ratio(complex_sentence, simple_sentence)
        return round(similarity)


class WordRankRatioFeature(RatioFeature):
    """Control token 'WR_<ratio>': lexical complexity of simple / complex, capped at 2."""

    def __init__(self, *args, **kwargs):
        super().__init__(self.get_word_rank_ratio, *args, **kwargs)

    def get_word_rank_ratio(self, complex_sentence, simple_sentence):
        return round(min(safe_division(self.get_lexical_complexity_score(simple_sentence),
                                       self.get_lexical_complexity_score(complex_sentence)), 2))

    def get_lexical_complexity_score(self, sentence):
        """75th percentile of log(1 + rank) over the known content words of `sentence`.

        Falls back to log(1 + vocab_size) when no known word remains.
        """
        word2rank = get_word2rank()  # fetched once here instead of once per word
        words = tokenize(remove_stopwords(remove_punctuation(sentence)))
        words = [word for word in words if word in word2rank]
        if len(words) == 0:
            return np.log(1 + len(word2rank))
        return np.quantile([self.get_rank(word) for word in words], 0.75)

    def get_rank(self, word):
        """Return log(1 + rank(word)); unknown words get rank = vocab_size."""
        word2rank = get_word2rank()
        rank = word2rank.get(word, len(word2rank))
        return np.log(1 + rank)


class DependencyTreeDepthRatioFeature(RatioFeature):
    """Control token 'DTD_<ratio>': max dependency-tree depth of simple / of complex."""

    def __init__(self, *args, **kwargs):
        extractor = self.get_dependency_tree_depth_ratio
        super().__init__(extractor, *args, **kwargs)

    def get_dependency_tree_depth_ratio(self, complex_sentence, simple_sentence):
        simple_depth = self.get_dependency_tree_depth(simple_sentence)
        complex_depth = self.get_dependency_tree_depth(complex_sentence)
        return round(safe_division(simple_depth, complex_depth))

    def get_dependency_tree_depth(self, sentence):
        """Maximum subtree height over the parsed sentences (0 when there are none)."""
        def subtree_depth(node):
            kids = list(node.children)
            if not kids:
                return 0
            return 1 + max(subtree_depth(kid) for kid in kids)

        doc = self.spacy_process(sentence)
        return max((subtree_depth(sent.root) for sent in doc.sents), default=0)

    def spacy_process(self, text):
        nlp = get_spacy_model()
        return nlp(text)

class Preprocessor:
    """Prepends control tokens (one per RatioFeature) to complex sentences.

    `features_kwargs` maps the name of a feature class defined in this module to the
    keyword arguments used to build it, e.g. {'CharRatioFeature': {'target_ratio': 0.8}}.
    With no features, sentences pass through unchanged and `hash` is "no_feature".
    """

    def __init__(self, features_kwargs=None):
        super().__init__()

        self.features = self.get_features(features_kwargs)
        if features_kwargs:
            # The kwargs are encoded to bytes here, and generate_hash str()-s them again.
            # Left as-is so the hash, which names the preprocessed-data directory, stays
            # stable for data that was already preprocessed.
            self.hash = generate_hash(str(features_kwargs).encode())
        else:
            self.hash = "no_feature"

    def get_class(self, class_name, *args, **kwargs):
        """Instantiate the class named `class_name` from this module's globals."""
        return globals()[class_name](*args, **kwargs)

    def get_features(self, feature_kwargs):
        """Build the feature objects described by `feature_kwargs` (None/empty -> [])."""
        # __init__ defaults features_kwargs to None, so guard before calling .items().
        if not feature_kwargs:
            return []
        return [self.get_class(feature_name, **kwargs)
                for feature_name, kwargs in feature_kwargs.items()]

    def encode_sentence(self, sentence):
        """Prefix `sentence` with each feature's target-ratio token (inference-time encoding)."""
        if not self.features:
            return sentence
        line = ''
        for feature in self.features:
            line += feature.encode_sentence(sentence) + ' '
        # NOTE: with the trailing space added above, this leaves two spaces before the
        # sentence. Kept because training inputs are encoded the same way
        # (encode_sentence_pair).
        line += ' ' + sentence
        return line.rstrip()

    def encode_sentence_pair(self, complex_sentence, simple_sentence):
        """Prefix `complex_sentence` with tokens holding the ratios measured against `simple_sentence`."""
        if not self.features:
            return complex_sentence
        line = ''
        for feature in self.features:
            processed_complex, _ = feature.encode_sentence_pair(complex_sentence, simple_sentence)
            line += processed_complex + ' '
        line += ' ' + complex_sentence
        return line.rstrip()

    def decode_sentence(self, encoded_sentence):
        """Apply every feature's decode_sentence in turn; identity when there are no features."""
        # Starting from the input means an empty feature list no longer raises
        # UnboundLocalError. Chaining means each feature's decoding is kept, not only the
        # last one's.
        decoded_sentence = encoded_sentence
        for feature in self.features:
            decoded_sentence = feature.decode_sentence(decoded_sentence)
        return decoded_sentence

    def encode_file(self, input_filepath, output_filepath):
        """Write encode_sentence(line) for every line of `input_filepath` to `output_filepath`."""
        with open(output_filepath, 'w') as f:
            for line in yield_lines(input_filepath):
                f.write(self.encode_sentence(line) + '\n')

    def decode_file(self, input_filepath, output_filepath):
        """Write decode_sentence(line) for every line of `input_filepath` to `output_filepath`."""
        with open(output_filepath, 'w') as f:
            for line in yield_lines(input_filepath):
                f.write(self.decode_sentence(line) + '\n')

    def process_encode_sentence_pair(self, sentences):
        """Encode one (complex, simple, index) tuple and print progress as index/line_count.

        Presumably the tuples come from yield_sentence_pair_with_index, so `index`
        starts at 1.
        """
        print(f"{sentences[2]}/{self.line_count}", sentences[0])  # sentences[2] is the pair index
        return self.encode_sentence_pair(sentences[0], sentences[1])

    def pool_encode_sentence_pair(self, args):
        """Worker entry point: signal progress by putting 1 on `queue`, then encode the pair."""
        complex_sent, simple_sent, queue = args
        queue.put(1)
        return self.encode_sentence_pair(complex_sent, simple_sent)

    def encode_file_pair(self, complex_filepath, simple_filepath):
        """Encode every (complex, simple) line pair of two parallel files; returns the encoded complex lines."""
        self.line_count = count_line(simple_filepath)
        processed_complex_sentences = []
        pairs = yield_sentence_pair(complex_filepath, simple_filepath)
        for i, (complex_sentence, simple_sentence) in enumerate(pairs, start=1):
            processed_complex_sentence = self.encode_sentence_pair(complex_sentence, simple_sentence)
            print(f"{i}/{self.line_count}", processed_complex_sentence)
            processed_complex_sentences.append(processed_complex_sentence)
        return processed_complex_sentences

    def get_preprocessed_filepath(self, dataset, phase, type):
        """Path of a preprocessed file inside self.preprocessed_data_dir (set by preprocess_dataset)."""
        filename = f'{dataset}.{phase}.{type}'
        return f'{self.preprocessed_data_dir}/{filename}'

    def preprocess_dataset(self, dataset):
        """Encode every phase in PHASES of `dataset`; skips phases already preprocessed.

        Output goes to PROCESSED_DATA_DIR/<hash>/<dataset>; the directory path is returned
        as a string.
        NOTE(review): PROCESSED_DATA_DIR is not defined in this chunk; confirm it is set.
        """
        self.preprocessed_data_dir = f'{PROCESSED_DATA_DIR}/{self.hash}/{dataset}'
        os.makedirs(self.preprocessed_data_dir, exist_ok=True)
        save_preprocessor(self)
        print(f'Preprocessing dataset: {dataset}')

        output_dir = Path(self.preprocessed_data_dir)
        for phase in PHASES:
            # get_data_filepath returns plain strings; wrap them in Path so that .name
            # and .exists() work (both previously raised AttributeError).
            complex_filepath = Path(get_data_filepath(dataset, phase, 'complex'))
            simple_filepath = Path(get_data_filepath(dataset, phase, 'simple'))

            complex_output_filepath = output_dir / complex_filepath.name
            simple_output_filepath = output_dir / simple_filepath.name
            if complex_output_filepath.exists() and simple_output_filepath.exists():
                continue

            print(f'Prepocessing files: {complex_filepath.name} {simple_filepath.name}')
            processed_complex_sentences = self.encode_file_pair(complex_filepath, simple_filepath)

            write_lines(processed_complex_sentences, complex_output_filepath)
            shutil.copy(simple_filepath, simple_output_filepath)

        print(f'Preprocessing dataset "{dataset}" is finished.')
        return self.preprocessed_data_dir

if __name__ == '__main__':
    # Control-token features: RatioFeature subclass name -> constructor kwargs
    # (WordRatioFeature is currently disabled).
    features_kwargs = {
        # 'WordRatioFeature': {'target_ratio': 0.8},
        'CharRatioFeature': {'target_ratio': 0.8},
        'LevenshteinRatioFeature': {'target_ratio': 0.8},
        'WordRankRatioFeature': {'target_ratio': 0.8},
        'DependencyTreeDepthRatioFeature': {'target_ratio': 0.8}
    }

    # Reload a previously pickled preprocessor; this is None when no dump exists yet.
    # NOTE(review): features_kwargs is not used in this cell. Presumably a later cell
    # builds Preprocessor(features_kwargs) when `preprocessor` is None; confirm.
    preprocessor = load_preprocessor()

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


## BART and BRIO Models

In [None]:
from functools import lru_cache
from gc import callbacks
from lib2to3.pgen2 import token
from pathlib import Path
from weakref import ref
import math
from pytorch_lightning.loggers import CSVLogger
from easse.sari import corpus_sari
from torch.nn import functional as F
import Levenshtein
import argparse
from argparse import ArgumentParser
import os
import logging
import random
import nltk
from summarizer import Summarizer

nltk.download('punkt')

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.trainer import seed_everything
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    BertTokenizer, BertForPreTraining,
    BartForConditionalGeneration, BartTokenizer,pipeline,BartTokenizerFast, BartModel, PreTrainedTokenizerFast,
    get_linear_schedule_with_warmup, AutoConfig, AutoModel
)

#BERT_Sum = Summarizer(model='distilbert-base-uncased')

class MetricsCallback(pl.Callback):
    """Collects a snapshot of trainer.callback_metrics after every validation run."""

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Store a copy. In PyTorch Lightning 2.x, callback_metrics returns the same dict
        # object, updated in place. Appending it directly would make every entry alias
        # the latest values.
        self.metrics.append(dict(trainer.callback_metrics))


class BartBaseLineFineTuned(pl.LightningModule):
    """LightningModule that fine-tunes a BART-family model (e.g. BART or BRIO) for simplification.

    `args` is the argparse namespace built from `add_model_specific_args` plus the
    caller's options (dataset, output_dir, seed, ...). Validation reports
    val_loss = 1 - SARI / 100, so lower is better.
    """

    def __init__(self, args):
        super(BartBaseLineFineTuned, self).__init__()
        self.args = args
        self.save_hyperparameters()

        self.model = BartForConditionalGeneration.from_pretrained(self.args.sum_model)
        self.model = self.model.to(self.args.device)
        self.tokenizer = BartTokenizer.from_pretrained(self.args.sum_model)

    def is_logger(self):
        """Return True on the global-rank-0 process."""
        return self.trainer.global_rank <= 0

    def forward(self, input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None, labels=None):
        """Run the wrapped seq2seq model and return its outputs object."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute and log the model's LM loss on one training batch."""
        labels = batch['target_ids']
        # Padding positions are set to -100 so the loss ignores them.
        labels[labels[:, :] == self.tokenizer.pad_token_id] = -100
        # No zero_grad() here. Gradients are cleared by Lightning at the start of each
        # accumulation window and by optimizer_step() after each step. Zeroing them here
        # discarded accumulated gradients whenever gradient_accumulation_steps > 1.
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=labels,
            decoder_attention_mask=batch["target_mask"]
        )
        # Both custom_loss settings currently use the model's built-in loss.
        loss = outputs.loss
        self.log('train_loss', loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return 1 - SARI/100 for this validation batch as val_loss."""
        loss = self.sari_validation_step(batch)
        print("Val_loss", loss)
        self.log('val_loss', loss, batch_size=self.args.valid_batch_size)
        return torch.tensor(loss, dtype=float)

    def sari_validation_step(self, batch):
        """Generate one simplification per source in `batch` and return 1 - corpus SARI / 100.

        NOTE(review): generation uses do_sample=True, so the score, and therefore the
        checkpoint ModelCheckpoint keeps, depends on the random state.
        """
        def generate(sentence):
            encoding = self.tokenizer(
                [sentence],
                max_length=512,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            input_ids = encoding['input_ids'].to(self.args.device)
            attention_mask = encoding['attention_mask'].to(self.args.device)

            beam_outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                max_length=256,
                num_beams=5,
                top_k=120,
                top_p=0.95,
                early_stopping=True,
                num_return_sequences=1
            ).to(self.args.device)

            return self.tokenizer.decode(beam_outputs[0], skip_special_tokens=True,
                                         clean_up_tokenization_spaces=True)

        pred_sents = [generate(source) for source in batch["source"]]
        # A single reference set: batch["targets"] holds one reference string per source.
        score = corpus_sari(batch["source"], pred_sents, [batch["targets"]])
        print("Sari score: ", score)
        return 1 - score / 100

    def configure_optimizers(self):
        """Build AdamW, with weight decay on all parameters except biases and LayerNorm weights.

        The linear warmup/decay scheduler is created in train_dataloader() (it needs the
        dataset size) and stepped manually in optimizer_step().
        """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(self, epoch=None, batch_idx=None, optimizer=None, optimizer_closure=None):
        """Step the optimizer (running Lightning's closure), clear gradients, advance the LR schedule."""
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()
        # self.lr_scheduler is created in train_dataloader().
        self.lr_scheduler.step()

    def save_core_model(self):
        """Save the HF model and tokenizer to /content/experiments/<model_name>core."""
        tmp = self.args.model_name + 'core'
        store_path = f'/content/experiments/{tmp}'
        self.model.save_pretrained(store_path)
        # Previously referenced the nonexistent `self.simplifier_tokenizer`.
        self.tokenizer.save_pretrained(store_path)

    def train_dataloader(self):
        """Build the shuffled training DataLoader and an LR scheduler sized to it.

        Relies on configure_optimizers() having already set self.opt.
        """
        train_dataset = TrainDataset(dataset=self.args.dataset,
                                     tokenizer=self.tokenizer,
                                     max_len=self.args.max_seq_length,
                                     sample_size=self.args.train_sample_size)

        dataloader = DataLoader(train_dataset,
                                batch_size=self.args.train_batch_size,
                                drop_last=True,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=2)
        # The int()/float() casts tolerate option values that arrive as strings.
        n_devices = max(1, int(self.args.n_gpu))
        steps_per_epoch = len(dataloader.dataset) // (int(self.args.train_batch_size) * n_devices)
        t_total = (steps_per_epoch
                   // int(self.args.gradient_accumulation_steps)
                   * float(self.args.num_train_epochs))
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=int(self.args.warmup_steps), num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        """Build the (unshuffled) validation DataLoader of raw source/target strings."""
        val_dataset = ValDataset(dataset=self.args.dataset,
                                 tokenizer=self.tokenizer,
                                 max_len=self.args.max_seq_length,
                                 sample_size=self.args.valid_sample_size)
        return DataLoader(val_dataset,
                          batch_size=self.args.valid_batch_size,
                          num_workers=2)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a child ArgumentParser (inheriting `parent_parser`) with model/training options.

        Every numeric option declares `type=`. Without it, argparse returns command-line
        overrides as strings, which broke the optimizer, scheduler and Trainer arguments.
        """
        def str2bool(value):
            # argparse's type=bool would treat the string 'False' as True.
            return str(value).strip().lower() in ('1', 'true', 'yes', 'y')

        p = ArgumentParser(parents=[parent_parser], add_help=False)
        # facebook/bart-base (BART) Yale-LILY/brio-cnndm-uncased (BRIO)
        p.add_argument('-Summarizer', '--sum_model', default='Yale-LILY/brio-cnndm-uncased')
        p.add_argument('-TrainBS', '--train_batch_size', type=int, default=6)
        p.add_argument('-ValidBS', '--valid_batch_size', type=int, default=6)
        p.add_argument('-lr', '--learning_rate', type=float, default=1e-5)
        p.add_argument('-MaxSeqLen', '--max_seq_length', type=int, default=256)
        p.add_argument('-AdamEps', '--adam_epsilon', type=float, default=1e-8)
        p.add_argument('-WeightDecay', '--weight_decay', type=float, default=0.0001)
        p.add_argument('-WarmupSteps', '--warmup_steps', type=int, default=5)
        p.add_argument('-NumEpoch', '--num_train_epochs', type=int, default=3)
        p.add_argument('-CosLoss', '--custom_loss', type=str2bool, default=False)
        p.add_argument('-GradAccuSteps', '--gradient_accumulation_steps', type=int, default=1)
        p.add_argument('-GPUs', '--n_gpu', type=int, default=torch.cuda.device_count())
        p.add_argument('-nbSVS', '--nb_sanity_val_steps', type=int, default=-1)
        p.add_argument('-TrainSampleSize', '--train_sample_size', type=float, default=1)
        p.add_argument('-ValidSampleSize', '--valid_sample_size', type=float, default=1)
        p.add_argument('-device', '--device', default='cuda')
        return p


logger = logging.getLogger(__name__)


class LoggingCallback(pl.Callback):
    """Prints and logs trainer.callback_metrics after validation and test runs (rank 0 only)."""

    def on_validation_end(self, trainer, pl_module):
        banner = "***** Validation results *****"
        logger.info(banner)
        print(banner)
        if not pl_module.is_logger():
            return
        metrics = trainer.callback_metrics
        for key in sorted(metrics):
            value = metrics[key]
            print(key, value)
            if key in ["log", "progress_bar"]:
                continue
            entry = "{} = {}\n".format(key, str(value))
            logger.info(entry)
            print(entry)

    def on_test_end(self, trainer, pl_module):
        banner = "***** Test results *****"
        logger.info(banner)
        print(banner)
        if not pl_module.is_logger():
            return
        metrics = trainer.callback_metrics
        # Also persist the results to <output_dir>/test_results.txt.
        results_path = os.path.join(pl_module.args.output_dir, "test_results.txt")
        with open(results_path, "w") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar"]:
                    continue
                entry = "{} = {}\n".format(key, str(metrics[key]))
                logger.info(entry)
                writer.write(entry)
                print(entry)


##### build dataset Loader #####
class TrainDataset(Dataset):
    """Parallel complex/simple training pairs, tokenized on the fly to fixed-length tensors."""

    def __init__(self, dataset, tokenizer, max_len=256, sample_size=1):
        self.sample_size = sample_size
        print("init TrainDataset ...")
        self.source_filepath = get_data_filepath(dataset, 'train', 'complex')
        self.target_filepath = get_data_filepath(dataset, 'train', 'simple')
        print(self.source_filepath)
        print("Initialized dataset done.....")

        self.max_len = max_len
        self.tokenizer = tokenizer

        self._load_data()

    def _load_data(self):
        self.inputs = read_lines(self.source_filepath)
        self.targets = read_lines(self.target_filepath)

    def __len__(self):
        # sample_size is a fraction of the corpus to expose (1 = everything).
        return int(len(self.inputs) * self.sample_size)

    def _encode(self, text):
        """Tokenize one string padded/truncated to max_len; returns squeezed (ids, mask)."""
        encoded = self.tokenizer(
            [text],
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors="pt"
        )
        return encoded["input_ids"].squeeze(), encoded["attention_mask"].squeeze()

    def __getitem__(self, index):
        source = self.inputs[index]
        target = self.targets[index]
        source_ids, src_mask = self._encode(source)
        target_ids, target_mask = self._encode(target)
        return {
            "source_ids": source_ids,
            "source_mask": src_mask,
            "target_ids": target_ids,
            "target_mask": target_mask,
            'sources': source,
            'targets': [target],
            'source': source,
            'target': target,
        }


class ValDataset(Dataset):
    """Validation pairs as raw strings: {'source': complex sentence, 'targets': reference}."""

    def __init__(self, dataset, tokenizer, max_len=256, sample_size=1):
        self.sample_size = sample_size
        ### WIKI-large dataset ###
        self.source_filepath = get_data_filepath(dataset, 'valid', 'complex')
        self.target_filepaths = get_data_filepath(dataset, 'valid', 'simple')
        print(self.source_filepath)

        self.max_len = max_len
        self.tokenizer = tokenizer

        self._build()

    def __len__(self):
        # sample_size is a fraction of the corpus to expose (1 = everything).
        return int(len(self.inputs) * self.sample_size)

    def __getitem__(self, index):
        return {"source": self.inputs[index], "targets": self.targets[index]}

    def _build(self):
        self.inputs = list(yield_lines(self.source_filepath))
        self.targets = list(yield_lines(self.target_filepaths))

def train(args):
    """Fine-tune ``BartBaseLineFineTuned`` with PyTorch Lightning and save the result.

    Args:
        args: namespace from ``parse_arguments`` extended with ``output_dir`` and
            ``dataset``. Read here: ``seed``, ``output_dir``,
            ``gradient_accumulation_steps``, ``num_train_epochs``, ``dataset``.

    Side effects (all under ``args.output_dir``):
        * the single best Lightning checkpoint by ``val_loss`` (``checkpoint-epoch=<N>.ckpt``);
        * CSV logs in ``logs/``;
        * the Hugging Face weights and tokenizer after the LAST epoch (``save_pretrained``).
    """
    seed_everything(args.seed)

    # Keep only the checkpoint with the lowest validation loss.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=args.output_dir,
        filename="checkpoint-{epoch}",
        monitor="val_loss",
        verbose=True,
        mode="min",
        save_top_k=1
    )
    bar_callback = pl.callbacks.TQDMProgressBar(refresh_rate=1)
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        accelerator='auto',
        max_epochs=args.num_train_epochs,
        callbacks=[LoggingCallback(), checkpoint_callback, bar_callback],
        logger=CSVLogger(f'{args.output_dir}/logs'),
        log_every_n_steps=9,  # the logged PSAT run had 9 optimizer steps per epoch
        num_sanity_val_steps=0,  # skip the sanity-check validation pass to save time
    )

    print("Initialize model")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using {device}')
    model = BartBaseLineFineTuned(args)

    model.args.dataset = args.dataset
    trainer = pl.Trainer(**train_params)

    print("Training model")
    trainer.fit(model)

    print("training finished")

    print("Saving model")
    # NOTE: these are the weights after the last epoch, which are not necessarily
    # the best ones kept by checkpoint_callback.
    model.model.save_pretrained(args.output_dir)
    # Save the tokenizer as well so output_dir can be loaded with from_pretrained on its own.
    model.tokenizer.save_pretrained(args.output_dir)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


## Load Dataset

In [None]:
# Extract the PSAT dataset archive once.
# Both the archive and the destination use absolute paths, so the result does
# not depend on the current working directory (the EASSE setup cell runs
# `%cd /content/easse`).
import os
import zipfile

if not os.path.exists('/content/PSAT'):
    if os.path.exists('/content/PSAT.zip'):
        with zipfile.ZipFile('/content/PSAT.zip') as archive:
            archive.extractall('/content/PSAT')
    else:
        print('Please upload the dataset zip file.')

Archive:  PSAT.zip
 extracting: PSAT/PSAT.train.complex  
 extracting: PSAT/PSAT.test.complex  
 extracting: PSAT/PSAT.valid.complex  
 extracting: PSAT/PSAT.train.simple  
 extracting: PSAT/PSAT.test.simple   
 extracting: PSAT/PSAT.valid.simple  


## Main Trainer

In [None]:
import torch
from pathlib import Path
import sys

PSAT = 'PSAT'  # dataset name; presumably mapped to /content/PSAT/PSAT.<split>.<complex|simple> by get_data_filepath
EXP_DIR = '/content/experiments'  # experiment output dir (get_experiment_dir hard-codes the same path)

import time
import json
import argparse

from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from optuna.integration import PyTorchLightningPruningCallback

def parse_arguments():
    """Return the hyper-parameter namespace used for training.

    Only ``--seed`` is declared here; every model option is registered by
    ``BartBaseLineFineTuned.add_model_specific_args``. Unrecognized
    command-line arguments (presumably the kernel's own flags when run in a
    notebook) are ignored rather than rejected.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='randomization seed')
    parser = BartBaseLineFineTuned.add_model_specific_args(parser)
    known_args, _unknown = parser.parse_known_args()
    return known_args

# Create experiment directory

def get_experiment_dir(create_dir=False, path='/content/experiments'):
    """Return the experiment output directory after wiping any previous contents.

    WARNING: an existing ``path`` is deleted recursively on every call, even
    when ``create_dir`` is False, so results of a previous run are lost.

    Args:
        create_dir: if True, (re)create ``path`` as an empty directory.
        path: directory to use. The default is '/content/experiments', the
            same value as the module-level ``EXP_DIR``.

    Returns:
        ``path`` unchanged.
    """
    # Imported locally: this cell does not import os/shutil itself, so don't
    # depend on some other cell having run first.
    import os
    import shutil

    if os.path.exists(path):
        shutil.rmtree(path)
    if create_dir:
        os.makedirs(path, exist_ok=True)
    return path

def log_params(filepath, kwargs):
    """Write ``kwargs`` to ``filepath`` as indented JSON, with every value converted by ``str``.

    Stringifying means values that are not JSON types (paths, devices, ...)
    never make the dump fail. An existing file is overwritten.
    """
    filepath = Path(filepath)
    kwargs_str = {key: str(kwargs[key]) for key in kwargs}
    # Context manager so the handle is flushed and closed deterministically.
    with filepath.open('w') as fh:
        json.dump(kwargs_str, fh, indent=4)


def run_training(args, dataset):
    """Create a fresh experiment directory, record the configuration, and train.

    Mutates ``args`` in place by adding ``output_dir`` and ``dataset``.
    """
    args.output_dir = get_experiment_dir(create_dir=True)
    # Set the dataset BEFORE logging so params.json records the configuration
    # that is actually trained on.
    args.dataset = dataset
    log_params(f"{args.output_dir}/params.json", vars(args))

    print("Dataset: ", args.dataset)
    train(args)

# Run configuration: hyper-parameters from the argument parser (its defaults
# when executed in the notebook); the training data is the PSAT dataset.
args = parse_arguments()
dataset = PSAT
print(args)

Namespace(seed=42, sum_model='Yale-LILY/brio-cnndm-uncased', train_batch_size=6, valid_batch_size=6, learning_rate=1e-05, max_seq_length=256, adam_epsilon=1e-08, weight_decay=0.0001, warmup_steps=5, num_train_epochs=3, custom_loss=False, gradient_accumulation_steps=1, n_gpu=1, nb_sanity_val_steps=-1, train_sample_size=1, valid_sample_size=1, device='cuda')


In [None]:
# Clears /content/experiments, then fine-tunes the model and writes its outputs there.
run_training(args=args, dataset=dataset)

INFO:lightning_fabric.utilities.seed:Seed set to 42


Dataset:  PSAT
Initialize model
Using cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Training model


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /content/experiments exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 406 M 
-------------------------------------------------------
406 M     Trainable params
0         Non-trainable params
406 M     Total params
1,625.162 Total estimated model params size (MB)


init TrainDataset ...
/content/PSAT/PSAT.train.complex
Initialized dataset done.....


  self.pid = os.fork()


/content/PSAT/PSAT.valid.complex


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Sari score:  37.57397796383342
Val_loss 0.6242602203616658
Sari score:  41.409848716171666
Val_loss 0.5859015128382834
Sari score:  38.91977351291177
Val_loss 0.6108022648708824


INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 9: 'val_loss' reached 0.61860 (best 0.61860), saving model to '/content/experiments/checkpoint-epoch=0.ckpt' as top 1


Sari score:  34.656834551807265
Val_loss 0.6534316544819274
***** Validation results *****
train_loss tensor(2.5249, device='cuda:0')
train_loss = tensor(2.5249, device='cuda:0')

val_loss tensor(0.6186, device='cuda:0')
val_loss = tensor(0.6186, device='cuda:0')



Validation: |          | 0/? [00:00<?, ?it/s]

Sari score:  44.01960420729633
Val_loss 0.5598039579270366
Sari score:  47.080934724211524
Val_loss 0.5291906527578847
Sari score:  42.88318682452272
Val_loss 0.5711681317547728


INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 18: 'val_loss' reached 0.55519 (best 0.55519), saving model to '/content/experiments/checkpoint-epoch=1.ckpt' as top 1


Sari score:  43.94168284926499
Val_loss 0.5605831715073502
***** Validation results *****
train_loss tensor(2.0337, device='cuda:0')
train_loss = tensor(2.0337, device='cuda:0')

val_loss tensor(0.5552, device='cuda:0')
val_loss = tensor(0.5552, device='cuda:0')



Validation: |          | 0/? [00:00<?, ?it/s]

Sari score:  46.42000862461784
Val_loss 0.5357999137538216
Sari score:  46.82505751623942
Val_loss 0.5317494248376058
Sari score:  43.13233355982076
Val_loss 0.5686766644017924


INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 27: 'val_loss' reached 0.55406 (best 0.55406), saving model to '/content/experiments/checkpoint-epoch=2.ckpt' as top 1


Sari score:  41.998960160978264
Val_loss 0.5800103983902174
***** Validation results *****
train_loss tensor(1.5704, device='cuda:0')
train_loss = tensor(1.5704, device='cuda:0')

val_loss tensor(0.5541, device='cuda:0')
val_loss = tensor(0.5541, device='cuda:0')



INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=3` reached.
Non-default generation parameters: {'max_length': 142, 'min_length': 56, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


training finished
Saving model


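## D-SARI

Sentence- and file-level D-SARI: SARI's keep/delete/add n-gram scores (orders 1–4), adjusted by output-length and sentence-count penalties.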

In [None]:
from collections import Counter
import sys
import nltk
import math

def ReadInFile(filename):
    """Return every line of ``filename`` with leading/trailing whitespace removed."""
    with open(filename) as handle:
        return [raw.strip() for raw in handle]

def D_SARIngram(sgrams, cgrams, rgramslist, numref):
    """KEEP / DELETE / ADD component scores for a single n-gram order.

    Args:
        sgrams: n-grams of the source (input) sentence.
        cgrams: n-grams of the candidate (system output).
        rgramslist: one list of n-grams per reference.
        numref: number of references. Source and candidate counts are
            multiplied by it so they are on the same scale as the reference
            counts, which are pooled over all references.

    Returns:
        ``(keepscore, delscore_precision, addscore)``. KEEP and ADD are F1
        scores of precision and recall; DELETE is returned as precision only.
    """
    # Pool the n-grams of all references into one multiset.
    rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]

    rgramcounter = Counter(rgramsall)

    sgramcounter = Counter(sgrams)

    sgramcounter_rep = Counter()

    # Scale source counts by numref to match the pooled reference counts.
    for sgram, scount in sgramcounter.items():
        sgramcounter_rep[sgram] = scount * numref

    cgramcounter = Counter(cgrams)

    cgramcounter_rep = Counter()

    # Same scaling for the candidate counts.
    for cgram, ccount in cgramcounter.items():
        cgramcounter_rep[cgram] = ccount * numref

    # KEEP
    # kept = n-grams in both source and candidate (multiset intersection = min counts);
    # good = kept n-grams also supported by the references;
    # all  = source n-grams the references keep.

    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep

    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter

    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    keeptmpscore1 = 0

    keeptmpscore2 = 0

    # Each well-kept n-gram contributes the fraction of its count the references support.
    for keepgram in keepgramcountergood_rep:
        keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]

        keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]

        # print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]

    keepscore_precision = 0

    if len(keepgramcounter_rep) > 0:
        keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)

    keepscore_recall = 0

    if len(keepgramcounterall_rep) > 0:
        keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)

    keepscore = 0

    # F1 of keep precision/recall (0 when both are 0).
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)

    # DELETION
    # deleted = source n-grams the candidate dropped (Counter subtraction keeps positive counts only);
    # good = deletions beyond what the references keep;
    # all  = source n-grams absent from (or in excess of) the references.

    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep

    delgramcountergood_rep = delgramcounter_rep - rgramcounter

    delgramcounterall_rep = sgramcounter_rep - rgramcounter

    deltmpscore1 = 0

    deltmpscore2 = 0

    for delgram in delgramcountergood_rep:
        deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]

        deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]

    delscore_precision = 0

    if len(delgramcounter_rep) > 0:
        delscore_precision = deltmpscore1 / len(delgramcounter_rep)

    delscore_recall = 0

    # NOTE(review): recall uses deltmpscore1 rather than deltmpscore2. This does not
    # affect the result: only delscore_precision is returned, so delscore_recall and
    # delscore below are unused.
    if len(delgramcounterall_rep) > 0:
        delscore_recall = deltmpscore1 / len(delgramcounterall_rep)

    delscore = 0

    if delscore_precision > 0 or delscore_recall > 0:
        delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)

    # ADDITION
    # Computed on sets (counts ignored):
    # added = candidate n-grams absent from the source;
    # good  = added n-grams that occur in some reference;
    # all   = reference n-grams absent from the source.

    addgramcounter = set(cgramcounter) - set(sgramcounter)

    addgramcountergood = set(addgramcounter) & set(rgramcounter)

    addgramcounterall = set(rgramcounter) - set(sgramcounter)

    addtmpscore = 0

    # Equivalent to len(addgramcountergood).
    for addgram in addgramcountergood:
        addtmpscore += 1

    addscore_precision = 0

    addscore_recall = 0

    if len(addgramcounter) > 0:
        addscore_precision = addtmpscore / len(addgramcounter)

    if len(addgramcounterall) > 0:
        addscore_recall = addtmpscore / len(addgramcounterall)

    addscore = 0

    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)

    # DELETE contributes precision only.
    return (keepscore, delscore_precision, addscore)

def count_length(ssent, csent, rsents):
    """Word counts, splitting on single spaces, of the source, the references and the output.

    Returns ``(input_length, reference_length, output_length)``, where
    ``reference_length`` is the mean reference word count truncated to an int.
    """
    def n_words(text):
        return len(text.split(" "))

    input_length = n_words(ssent)
    output_length = n_words(csent)
    mean_reference = sum(n_words(ref) for ref in rsents) / len(rsents)
    return input_length, int(mean_reference), output_length

def sentence_number(csent, rsents):
    """Sentence counts (via ``nltk.sent_tokenize``) of the references and the output.

    Returns ``(reference_sentence_number, output_sentence_number)``. The
    reference value is the mean over ``rsents``, truncated to an int.
    """
    output_count = len(nltk.sent_tokenize(csent))
    reference_total = sum(len(nltk.sent_tokenize(ref)) for ref in rsents)
    return int(reference_total / len(rsents)), output_count

def _ngrams(tokens, n):
    """Return the space-joined n-grams of ``tokens``, left to right (empty if fewer than n tokens)."""
    return [" ".join(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def D_SARIsent(ssent, csent, rsents):
    """D-SARI of one candidate ``csent`` for source ``ssent`` and references ``rsents``.

    The KEEP / DELETE / ADD components of ``D_SARIngram`` are averaged over
    n-gram orders 1-4. Tokens come from lower-casing and splitting on single
    spaces. The averages are then penalized:

    * ``LP_1`` (output shorter than the mean reference) scales ADD;
    * ``LP_2`` (output longer than the mean reference) scales KEEP and DELETE;
    * ``SLP`` (sentence-count mismatch) additionally scales KEEP.

    Returns:
        ``(finalscore, avgkeepscore, avgdelscore, avgaddscore)``, where
        ``finalscore`` is the mean of the three penalized components.
    """
    numref = len(rsents)
    max_order = 4

    s1grams = ssent.lower().split(" ")
    c1grams = csent.lower().split(" ")
    r1gramslist = [rsent.lower().split(" ") for rsent in rsents]

    keep_scores, del_scores, add_scores = [], [], []
    for n in range(1, max_order + 1):
        rgramslist = [_ngrams(r1grams, n) for r1grams in r1gramslist]
        keep, delete, add = D_SARIngram(_ngrams(s1grams, n), _ngrams(c1grams, n), rgramslist, numref)
        keep_scores.append(keep)
        del_scores.append(delete)
        add_scores.append(add)

    avgkeepscore = sum(keep_scores) / max_order
    avgdelscore = sum(del_scores) / max_order
    avgaddscore = sum(add_scores) / max_order

    input_length, reference_length, output_length = count_length(ssent, csent, rsents)
    reference_sentence_number, output_sentence_number = sentence_number(csent, rsents)

    # LP_1 penalizes outputs shorter than the mean reference. output_length >= 1
    # always, because str.split(" ") returns at least one item.
    if output_length >= reference_length:
        LP_1 = 1
    else:
        LP_1 = math.exp((output_length - reference_length) / output_length)

    # LP_2 penalizes outputs longer than the mean reference.
    if output_length > reference_length:
        LP_2 = math.exp((reference_length - output_length) / max(input_length - reference_length, 1))
    else:
        LP_2 = 1

    # SLP penalizes a sentence-count mismatch. The extra 1 in max() only matters
    # when both counts are 0 (e.g. empty output and empty reference), which
    # would otherwise divide by zero; the numerator is then 0, so SLP == 1.
    SLP = math.exp(-abs(reference_sentence_number - output_sentence_number)
                   / max(reference_sentence_number, output_sentence_number, 1))

    avgkeepscore = avgkeepscore * LP_2 * SLP
    avgaddscore = avgaddscore * LP_1
    avgdelscore = avgdelscore * LP_2

    finalscore = (avgkeepscore + avgdelscore + avgaddscore) / 3

    return finalscore, avgkeepscore, avgdelscore, avgaddscore

def D_SARI_file(ssent, csent, rsents):
    """Corpus D-SARI on a 0-100 scale: the mean sentence-level D-SARI over aligned lists.

    Args:
        ssent: source sentences.
        csent: system outputs, aligned line-by-line with ``ssent``.
        rsents: one reference per source sentence, aligned with ``ssent``.

    Raises:
        ValueError: if the lists are empty or have different lengths. Without
            this check, ``zip`` would drop the unmatched tail while the sum
            was still divided by ``len(ssent)``, silently under-reporting the score.
    """
    if not (len(ssent) == len(csent) == len(rsents)):
        raise ValueError(
            f"D_SARI_file needs aligned inputs; got {len(ssent)} sources, "
            f"{len(csent)} outputs and {len(rsents)} references"
        )
    if not ssent:
        raise ValueError("D_SARI_file needs at least one sentence")

    D_SARI = 0
    for st, ct, rt in zip(ssent, csent, rsents):
        D_SARI += D_SARIsent(st, ct, [rt])[0]
    return 100 * D_SARI / len(ssent)

In [None]:
import click
from easse.utils.resources import get_orig_sents, get_refs_sents

def get_sys_sents(test_set, sys_sents_path=None):
    """Return the system output sentences to evaluate.

    If ``sys_sents_path`` is given, the file is read with ``read_lines``.
    Otherwise all of stdin is read as UTF-8 and split into lines.
    ``test_set`` is not used here.
    """
    if sys_sents_path is None:
        with click.get_text_stream("stdin", encoding="utf-8") as stdin_stream:
            return stdin_stream.read().splitlines()
    return read_lines(sys_sents_path)


def get_orig_and_refs_sents(test_set, orig_sents_path=None, refs_sents_paths=None):
    """Return ``(orig_sents, refs_sents)`` for ``test_set``.

    For ``"custom"``, both paths are required. The single reference file is
    wrapped in a list, so ``refs_sents`` is a list of reference sets. Any
    other test set is delegated to EASSE's ``get_orig_sents`` / ``get_refs_sents``.
    """
    if test_set != "custom":
        return get_orig_sents(test_set), get_refs_sents(test_set)

    assert orig_sents_path is not None
    assert refs_sents_paths is not None
    return read_lines(orig_sents_path), [read_lines(refs_sents_paths)]

## Evaluation

In [None]:
from pathlib import Path
import sys
import os
import shutil

from easse.cli import evaluate_system_output
from easse.report import get_all_scores
from contextlib import contextmanager
import json
import torch
from easse.sari import corpus_sari
import time

@contextmanager
def log_stdout(filepath, mute_stdout=False):
    '''Context manager that tees ``sys.stdout`` into ``filepath``.

    Everything printed inside the ``with`` block is written to ``filepath``
    (truncated first, UTF-8). Unless ``mute_stdout`` is True, it is also
    written to the real stdout. ``sys.stdout`` is restored and the file is
    closed on exit, even if the block raises.
    '''

    class MultipleStreamsWriter:
        def __init__(self, streams):
            self.streams = streams

        def write(self, message):
            for stream in self.streams:
                stream.write(message)

        def flush(self):
            for stream in self.streams:
                stream.flush()

        def __getattr__(self, name):
            # Delegate anything else (isatty, encoding, fileno, ...) to the primary
            # stream, so code that inspects sys.stdout inside the block doesn't
            # hit AttributeError. Guard 'streams' to avoid recursion before __init__.
            if name == 'streams':
                raise AttributeError(name)
            return getattr(self.streams[0], name)

    save_stdout = sys.stdout
    log_file = open(filepath, 'w', encoding='utf-8')
    if mute_stdout:
        sys.stdout = MultipleStreamsWriter([log_file])  # Write to file only
    else:
        sys.stdout = MultipleStreamsWriter([save_stdout, log_file])  # Write to both stdout and file
    try:
        yield
    finally:
        sys.stdout = save_stdout
        log_file.close()

# Seed every RNG the notebook may draw from, for reproducible runs.
def set_seed(seed):
    """Seed Python ``random``, NumPy and PyTorch (CPU and, if present, all CUDA devices).

    Generation in this notebook samples (``do_sample=True``), so the seed
    determines the predictions. ``seed`` must be in ``[0, 2**32)`` for NumPy.
    """
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)
model_dir = None
_model_dirname = None
max_len = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')

# Untuned baseline alternative: load BartForConditionalGeneration and
# BartTokenizer from 'facebook/bart-base' instead of the checkpoint below.

#### Single model ####
# Load the fine-tuned checkpoint written by `train`. ModelCheckpoint keeps only
# the best epoch (save_top_k=1), and its number is not known in advance, so
# look it up instead of hard-coding 'checkpoint-epoch=2.ckpt'.
_checkpoints = sorted(Path('/content/experiments').glob('checkpoint-*.ckpt'),
                      key=lambda p: p.stat().st_mtime)
if not _checkpoints:
    raise FileNotFoundError("No 'checkpoint-*.ckpt' in /content/experiments; run the training cell first.")
checkpoint_path = _checkpoints[-1]  # most recently written one
print(f'Loading {checkpoint_path}')
Model = BartBaseLineFineTuned.load_from_checkpoint(str(checkpoint_path)).to(device)
model = Model.model.to(device)
tokenizer = Model.tokenizer
#### Single model ####

def generate_single(sentence, preprocessor = None, max_length=256, num_beams=2,
                    top_k=70, top_p=0.95, do_sample=True, **generate_kwargs):
    '''
    Simplify one input with the loaded single BART/BRIO model and return the decoded text.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. The input is
    truncated/padded to 512 tokens (NOTE(review): training used
    ``max_seq_length=256``; confirm the longer inference window is intended).
    With the defaults, decoding is beam sampling (``do_sample`` with
    ``num_beams`` beams, restricted by ``top_k``/``top_p``), so the output
    depends on the torch RNG state.

    Args:
        sentence: input text.
        preprocessor: unused; kept for call compatibility.
        max_length, num_beams, top_k, top_p, do_sample: forwarded to
            ``model.generate``; the defaults are the values this notebook uses.
        **generate_kwargs: extra ``model.generate`` options. They can also
            override the fixed ``early_stopping=True`` / ``num_return_sequences=1``.
            Options not given fall back to the checkpoint's generation config,
            which in the logged run included min_length=56, length_penalty=2.0
            and no_repeat_ngram_size=3.

    Returns:
        The first generated sequence, decoded without special tokens.
    '''
    encoding = tokenizer(sentence, max_length=512,
                         padding='max_length',
                         truncation=True,
                         return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_masks = encoding["attention_mask"].to(device)

    fixed_options = dict(early_stopping=True, num_return_sequences=1)
    fixed_options.update(generate_kwargs)
    beam_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        do_sample=do_sample,
        max_length=max_length,
        num_beams=num_beams,
        top_k=top_k,
        top_p=top_p,
        **fixed_options,
    )
    return tokenizer.decode(beam_outputs[0].tolist(), skip_special_tokens=True,
                            clean_up_tokenization_spaces=True)

def evaluate(orig_filepath, sys_filepath, ref_filepaths):
    """Corpus SARI (EASSE ``corpus_sari``) of the system file against a single reference file."""
    sources = read_lines(orig_filepath)
    references = [read_lines(ref_filepaths)]
    predictions = read_lines(sys_filepath)
    return corpus_sari(sources, predictions, references)

def back_translation(text, translator=None, pivot_lang='de', target_lang='en'):
    """Paraphrase ``text`` by translating it to ``pivot_lang`` and back to ``target_lang``.

    ``translator`` must provide ``translate(text, dest=...)`` returning an
    object with a ``.text`` attribute (presumably a googletrans ``Translator``).
    When omitted, a module-level ``translator`` is used; none is defined in
    this part of the notebook, so one must be created first.

    Raises:
        NameError: if no translator is passed and no global ``translator`` exists.
    """
    if translator is None:
        translator = globals().get('translator')
        if translator is None:
            raise NameError("back_translation needs a `translator`: pass one or define a global `translator`")
    pivot = translator.translate(text, dest=pivot_lang)
    return translator.translate(pivot.text, dest=target_lang).text


def simplify_file(complex_filepath, output_filepath, features_kwargs=None, model_dirname=None, post_processing=True):
    '''
    Write one simplified sentence (prediction) per line of ``complex_filepath`` to ``output_filepath``.

    Each input line is simplified with ``generate_single``. The output stays
    line-aligned with the input, which SARI / D-SARI rely on:
    * empty predictions are written as empty lines;
    * line breaks inside a prediction are replaced by spaces.
    ``features_kwargs`` and ``model_dirname`` are accepted but unused. If
    ``post_processing`` is True, the file is then rewritten by ``post_process``.
    '''
    complex_filepath = Path(complex_filepath)  # also accept str paths (.stem below needs a Path)
    total_lines = count_line(complex_filepath)
    print(complex_filepath)
    print(complex_filepath.stem)

    # `with` closes (and flushes) the file even if generation fails midway.
    with Path(output_filepath).open("w") as output_file:
        for n_line, complex_sent in enumerate(yield_lines(complex_filepath), start=1):
            output_sents = generate_single(complex_sent, preprocessor=None)
            print(f"{n_line}/{total_lines}", " : ", output_sents)

            if output_sents:
                # A newline inside a prediction would shift all following lines.
                output_file.write(" ".join(output_sents.splitlines()) + "\n")
            else:
                output_file.write("\n")

    if post_processing:
        post_process(output_filepath)

def post_process(filepath):
    """Rewrite ``filepath`` in place, replacing each ``''`` token with a plain double quote."""
    cleaned = [line.replace("''", '"') for line in yield_lines(filepath)]
    write_lines(cleaned, filepath)


def evaluate_on_PSAT(phase, features_kwargs=None,  model_dirname = None):
    """Generate predictions for the PSAT ``phase`` split and score them.

    Outputs:
    * predictions go to ``/content/outputs/<stem>.txt``; an existing file with
      one line per input is reused;
    * SARI, D-SARI, BLEU and FKGL go to
      ``/content/outputs/score_PSAT_<phase>.log.txt`` and are also echoed to stdout.

    If a non-empty score log already exists, it is printed and nothing is recomputed.
    """
    dataset = PSAT
    output_dir = '/content/outputs'
    os.makedirs(output_dir, exist_ok=True)

    output_score_filepath = f'{output_dir}/score_{dataset}_{phase}.log.txt'
    if os.path.exists(output_score_filepath) and count_line(output_score_filepath) > 0:
        print("Already exists: ", output_score_filepath)
        # read_lines presumably strips line endings, so rejoin with "\n" to keep the log readable.
        print("\n".join(read_lines(output_score_filepath)))
        return None

    start_time = time.time()
    complex_filepath = Path(get_data_filepath(dataset, phase, 'complex'))
    pred_filepath = f'{output_dir}/{complex_filepath.stem}.txt'
    ref_filepaths = get_data_filepath(dataset, phase, 'simple')

    if os.path.exists(pred_filepath) and count_line(pred_filepath) == count_line(complex_filepath):
        print("File is already processed.")
    else:
        simplify_file(complex_filepath, pred_filepath, features_kwargs, model_dirname)

    print("Evaluate: ", pred_filepath)

    with log_stdout(output_score_filepath):
        scores = evaluate_system_output(test_set='custom',
                                        sys_sents_path=str(pred_filepath),
                                        orig_sents_path=str(complex_filepath),
                                        refs_sents_paths=str(ref_filepaths),
                                        metrics=['bleu', 'sari', 'fkgl'])

        sys_sents = get_sys_sents(test_set='custom', sys_sents_path=str(pred_filepath))
        orig_sents, refs_sents = get_orig_and_refs_sents(test_set='custom',
                                                         orig_sents_path=str(complex_filepath),
                                                         refs_sents_paths=str(ref_filepaths))
        d_sari = D_SARI_file(orig_sents, sys_sents, refs_sents[0])

        print("SARI: {:.2f}\t D-SARI: {:.2f} \t BLEU: {:.2f} \t FKGL: {:.2f} ".format(
            scores['sari'], d_sari, scores['bleu'], scores['fkgl']))

        print("Execution time: --- %s seconds ---" % (time.time() - start_time))
    return None

# Predict on the PSAT test split and report SARI / D-SARI / BLEU / FKGL.
evaluate_on_PSAT('test', features_kwargs=None, model_dirname=None)

Using cuda
/content/PSAT/PSAT.test.complex
PSAT.test
1/33  :  How to Apply: You must submit an application. You must provide an application with the application fee. Official High School transcripts. Official GED passing transcripts. If you graduated from High School outside of the US, you must submit a copy of your High School diploma. Official Post-secondary transcripts from non-US institutions. Home School student must submit Home School Completion Affidavit. Official Home School transcripts must be signed by parent and notarized. Official post-secondary transcript for students who attended a non-U.S. institution must be sent electronically or in a sealed envelope directly from the school. You can submit these materials via the GED website. Official English language transcripts. You will only be required to submit English transcripts if they are in Hebrew, Arabic, or Asian language. You may submit these transcripts through the National Association of Credential Evaluation (NACES) If

## Download Outputs

In [None]:
# Bundle /content/outputs into /content/outputs.zip, with archive entries named "outputs/...".
# root_dir makes the archive independent of the current working directory;
# a relative `zip -r outputs.zip outputs` silently depended on it.
import os
import shutil

zip_file_path = '/content/outputs.zip'

if os.path.exists(zip_file_path):
    os.remove(zip_file_path)

# make_archive appends the ".zip" extension itself, so pass the name without it.
shutil.make_archive(os.path.splitext(zip_file_path)[0], 'zip',
                    root_dir='/content', base_dir='outputs')

  adding: outputs/ (stored 0%)
  adding: outputs/PSAT.test.txt (deflated 71%)
  adding: outputs/score_PSAT_test.log.txt (deflated 6%)


In [None]:
from google.colab import files

# Send the archive built in the previous cell to the browser as a download.
files.download('/content/outputs.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## References

*  [SimSum GitHub Code](https://github.com/epfml/easy-summary/tree/main)

