In [1]:
!nvidia-smi

Wed Dec  8 18:45:01 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  Off  | 00000000:01:01.0 Off |                    0 |
| N/A   31C    P0    26W / 250W |      0MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [2]:
# !pip install torch transformers datasets nltk rouge_score gensim

In [3]:
from collections import defaultdict

import datasets
import numpy as np
import torch
from gensim.parsing.preprocessing import STOPWORDS
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [4]:
### Load the data

In [5]:
def load_splits_from_bucket_dir(base_path):
    """Build the mapping of dataset name to its split-file paths.

    Args:
        base_path: Prefix under which every dataset has its own sub-directory.

    Returns:
        A plain dict ``{dataset_name: {split_name: path}}`` in which every path
        has the form ``{base_path}/{dataset_name}/{split_name}.json.gz``.
    """
    corpora = ("ami", "cc_news", "cnn_dailymail", "icsi", "mediasumm", "spotify", "ted", "xsum")
    partitions = ("train", "valid", "test")
    return {
        name: {split: f"{base_path}/{name}/{split}.json.gz" for split in partitions}
        for name in corpora
    }

In [6]:
def _load_dataset_from_split(dataset_name, splits):
    """Load one dataset's JSON files with a fixed ``document`` / ``summary`` schema.

    Args:
        dataset_name: Forwarded as the second positional argument of
            ``datasets.load_dataset``.
        splits: Forwarded as ``data_files`` (callers in this notebook pass a
            mapping of split name to file path).

    Returns:
        The object returned by ``datasets.load_dataset``.
    """

    def string_sequence():
        # Both columns share the same feature type: a sequence of string values.
        return datasets.Sequence(feature=datasets.Value(dtype='string', id=None), length=-1, id=None)

    schema = datasets.features.Features({
        'document': string_sequence(),
        'summary': string_sequence(),
    })
    return datasets.load_dataset("json", dataset_name, data_files=splits, features=schema)


In [7]:
def load_dataset_from_splits(data_splits):
    """Load every dataset listed in ``data_splits`` and merge them per split.

    Args:
        data_splits: Mapping of dataset name to that dataset's split-file
            mapping (the shape returned by ``load_splits_from_bucket_dir``).

    Returns:
        A ``datasets.DatasetDict`` with ``train``, ``valid`` and ``test``
        entries, each being the concatenation of that split over all datasets.
    """
    split_names = ("train", "valid", "test")
    collected = {split_name: [] for split_name in split_names}

    for dataset, split in data_splits.items():
        print(f"Loading {dataset} {split}")
        loaded = _load_dataset_from_split(dataset, split)
        for split_name in split_names:
            collected[split_name].append(loaded[split_name])

    merged = {
        split_name: datasets.concatenate_datasets(parts)
        for split_name, parts in collected.items()
    }
    return datasets.DatasetDict(**merged)

In [8]:
### Preprocess the data

In [9]:
# Number of worker processes, passed as `num_proc` to the dataset `.map` / `.filter` calls in this notebook.
NUM_PROC = 6


def count_sentences(examples):
    """Return the number of items in every document and summary of a batch.

    Each entry of ``examples["document"]`` / ``examples["summary"]`` is a
    sequence; its length is reported as the sentence count.
    """

    def lengths(column):
        return list(map(len, examples[column]))

    return {
        "document_sentence_count": lengths("document"),
        "summary_sentence_count": lengths("summary"),
    }


def count_words(examples):
    """Return the whitespace-separated word total of every document and summary.

    The words of all items belonging to one document (or summary) are summed.
    """

    def total_words(items):
        total = 0
        for item in items:
            total += len(item.split())
        return total

    return {
        "document_word_count": [total_words(document) for document in examples["document"]],
        "summary_word_count": [total_words(summary) for summary in examples["summary"]],
    }


def count_chars(examples):
    """Return the non-whitespace character total of every document and summary.

    Each item is split on whitespace and the lengths of the resulting tokens
    are added up, so whitespace characters are not counted.
    """

    def total_chars(items):
        return sum(len(token) for item in items for token in item.split())

    return {
        "document_char_count": [total_chars(document) for document in examples["document"]],
        "summary_char_count": [total_chars(summary) for summary in examples["summary"]],
    }


def sentence_density(examples):
    """Return sentences-per-word ratios for every document and summary.

    Reads the ``*_sentence_count`` and ``*_word_count`` columns. A word count
    of zero is replaced by 1 in the denominator to avoid dividing by zero.
    """

    def ratios(prefix):
        sentence_counts = examples[f"{prefix}_sentence_count"]
        word_counts = examples[f"{prefix}_word_count"]
        return [sentences / (words or 1) for sentences, words in zip(sentence_counts, word_counts)]

    return {
        "document_sentence_density": ratios("document"),
        "summary_sentence_density": ratios("summary"),
    }


def _count_stopwords(text, stopwords=STOPWORDS):
    ''' Count the stopword occurrences in a sequence of strings
        Input:
            - text: iterable of strings; the items are joined with newlines and
              the result is split on whitespace to obtain the words
            - stopwords: container tested with `in` against each lower-cased word
        Output:
            - int, number of words whose lower-cased form is in stopwords
    '''
    words = "\n".join(text).split()
    return sum(1 for word in words if word.lower() in stopwords)


def count_stopwords(examples):
    """Return the stopword total of every document and summary in a batch.

    Delegates the per-entry counting to ``_count_stopwords`` with its default
    stopword collection.
    """
    document_counts = list(map(_count_stopwords, examples["document"]))
    summary_counts = list(map(_count_stopwords, examples["summary"]))
    return {
        "document_stopword_count": document_counts,
        "summary_stopword_count": summary_counts,
    }


def count_word_density(examples):
    """Return the share of non-stopword words for every document and summary.

    Computed as ``(word_count - stopword_count) / word_count``; a word count of
    zero is replaced by 1 in the denominator to avoid dividing by zero.
    """

    def densities(prefix):
        word_counts = examples[f"{prefix}_word_count"]
        stopword_counts = examples[f"{prefix}_stopword_count"]
        return [(words - stops) / (words or 1) for words, stops in zip(word_counts, stopword_counts)]

    return {
        "document_word_density": densities("document"),
        "summary_word_density": densities("summary"),
    }


def add_info_to_dataset_dict(dataset_dict):
    """Add the derived statistic columns to ``dataset_dict``.

    The steps run in a fixed order because the density steps read the count
    columns written by earlier steps. Each step is a batched ``.map`` call
    using ``NUM_PROC`` worker processes.
    """
    steps = (
        count_sentences,
        count_words,
        count_chars,
        sentence_density,
        count_stopwords,
        count_word_density,
    )
    enriched = dataset_dict
    for step in steps:
        enriched = enriched.map(step, batched=True, num_proc=NUM_PROC)
    return enriched


In [10]:
def filter_by_density(example):
    """Keep an example only when its document word density is above 0.1."""
    density = example["document_word_density"]
    return density > 0.1


def filter_lower_counts(example):
    """Keep examples with more than 100 document words and more than 10 summary words."""
    document_long_enough = example["document_word_count"] > 100
    # The summary count is only looked up when the document check passed.
    return document_long_enough and example["summary_word_count"] > 10


def filter_upper_counts(example):
    """Keep examples with fewer than 16500 document words and fewer than 500 summary words."""
    document_short_enough = example["document_word_count"] < 16500
    # The summary count is only looked up when the document check passed.
    return document_short_enough and example["summary_word_count"] < 500


def apply_filters(dataset):
    """Apply the density, lower-bound and upper-bound filters, in that order.

    Every filter runs through ``.filter`` with ``NUM_PROC`` worker processes.
    """
    remaining = dataset
    for predicate in (filter_by_density, filter_lower_counts, filter_upper_counts):
        remaining = remaining.filter(predicate, num_proc=NUM_PROC)
    return remaining


def clean_up_dataset(dataset):
    """Add the statistic columns to ``dataset`` and then filter on them."""
    return apply_filters(add_info_to_dataset_dict(dataset))



In [11]:
from transformers.file_utils import PaddingStrategy
from typing import Optional, Union, Dict
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from transformers import LEDTokenizer


class LEDTokenizerFixed(LEDTokenizer):
    """``LEDTokenizer`` whose ``_pad`` also pads ``global_attention_mask``.

    When an encoded example carries a ``global_attention_mask`` entry, it is
    extended to the same length as the primary model input, so every
    per-token list of the example has the same length after padding.
    """

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad one encoded example (on the left or right) up to the target length.

        Args:
            encoded_inputs: Dictionary of tokenized inputs (`List[int]`) for a single example.
            max_length: Target length. Ignored for `PaddingStrategy.LONGEST`, where the length of the
                example itself is used.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: use the example's own length as the target
                - PaddingStrategy.MAX_LENGTH: pad to `max_length`
                - PaddingStrategy.DO_NOT_PAD: do not pad (default)
                The side is taken from `self.padding_side`:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) If set, the target length is rounded up to the next multiple of
                this value.
            return_attention_mask: (optional) Whether to create/pad `attention_mask`. When `None`, it is
                enabled iff "attention_mask" is listed in `self.model_input_names`.

        Returns:
            The same `encoded_inputs` object, updated in place.

        Raises:
            ValueError: if padding is needed and `self.padding_side` is neither "right" nor "left".
        """
        # Load from model defaults
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        input_name = self.model_input_names[0]
        required_input = encoded_inputs[input_name]

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if return_attention_mask and "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * len(required_input)

        if not needs_to_be_padded:
            return encoded_inputs

        if self.padding_side not in ("right", "left"):
            raise ValueError("Invalid padding strategy:" + str(self.padding_side))

        difference = max_length - len(required_input)
        pad_on_right = self.padding_side == "right"

        def padded(values, fill_value):
            # Extend `values` by `difference` copies of `fill_value` on the configured side.
            filler = [fill_value] * difference
            return values + filler if pad_on_right else filler + values

        if return_attention_mask:
            encoded_inputs["attention_mask"] = padded(encoded_inputs["attention_mask"], 0)
        # Pad the global attention mask only when the example has one (the key is absent when the
        # caller did not supply it), and do so for both padding sides and regardless of
        # `return_attention_mask`, so it always ends up as long as the padded input.
        # NOTE(review): the fill value 0 is kept from the original code — confirm it is the intended
        # value for padded positions.
        if "global_attention_mask" in encoded_inputs:
            encoded_inputs["global_attention_mask"] = padded(encoded_inputs["global_attention_mask"], 0)
        if "token_type_ids" in encoded_inputs:
            encoded_inputs["token_type_ids"] = padded(encoded_inputs["token_type_ids"], self.pad_token_type_id)
        if "special_tokens_mask" in encoded_inputs:
            encoded_inputs["special_tokens_mask"] = padded(encoded_inputs["special_tokens_mask"], 1)
        encoded_inputs[input_name] = padded(required_input, self.pad_token_id)

        return encoded_inputs


In [13]:
# Checkpoint directory; the tokenizer is loaded from it here and the model is loaded from the
# same path further down.
model_name = "../models/led-16k/checkpoint-1000"
tokenizer = LEDTokenizerFixed.from_pretrained(model_name)

In [14]:

def tokenize_dataset(examples, max_source_length=4096, max_target_length=512):
    """Tokenize a batch of documents and summaries for seq2seq training.

    Args:
        examples: Batch with ``document`` and ``summary`` columns, each entry
            being a sequence of strings that is joined with newlines.
        max_source_length: Truncation length for the tokenized documents.
        max_target_length: Truncation length for the tokenized summaries.

    Returns:
        The tokenizer output for the documents, extended with
        ``global_attention_mask`` (1 on the first token, 0 elsewhere) and
        ``labels`` (summary token ids, pad ids replaced by -100).
    """
    sources = ["\n".join(document) for document in examples["document"]]
    targets = ["\n".join(summary) for summary in examples["summary"]]
    model_inputs = tokenizer(sources, max_length=max_source_length, padding=False, truncation=True)

    # put global attention on <s> token (the first position); all other positions stay 0
    model_inputs["global_attention_mask"] = [
        [1 if position == 0 else 0 for position in range(len(input_ids))]
        for input_ids in model_inputs["input_ids"]
    ]

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding=False, truncation=True)

    # Replace any tokenizer.pad_token_id in the labels by -100 so that padding is ignored in the loss.
    model_inputs["labels"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]
    return model_inputs


def preprocess_dataset(dataset):
    """Sort by ``document_word_count``, tokenize in batches, then shuffle.

    Tokenization runs through ``.map`` with ``NUM_PROC`` worker processes.
    """
    ordered = dataset.sort('document_word_count')
    tokenized = ordered.map(tokenize_dataset, batched=True, num_proc=NUM_PROC)
    return tokenized.shuffle()



In [15]:
# Load the ROUGE metric through the `datasets` metric loader.
rouge = datasets.load_metric("rouge")

In [16]:
import nltk


# compute Rouge score during validation
def postprocess_text(preds, labels):
    """Strip every text and put each of its sentences on its own line.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings.

    Returns:
        A ``(preds, labels)`` tuple of lists holding the reformatted strings.
    """

    def one_sentence_per_line(text):
        # rougeLSum expects newline after each sentence
        return "\n".join(nltk.sent_tokenize(text.strip()))

    return (
        [one_sentence_per_line(pred) for pred in preds],
        [one_sentence_per_line(label) for label in labels],
    )


def compute_metrics(eval_preds):
    """Compute ROUGE F-measures and the mean generated length for an evaluation pass.

    Args:
        eval_preds: Pair ``(preds, labels)`` of token-id arrays; when ``preds``
            is a tuple only its first element is used.

    Returns:
        Dict with one entry per ROUGE key (mid F-measure scaled by 100) plus
        ``gen_len`` (mean count of non-pad tokens per prediction), all rounded
        to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # Use the module-level `rouge` metric object; the previous code referred to `metric`,
    # a name that is never defined in this notebook.
    scores = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in scores.items()}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [17]:
# RAW_SUMMARIZATION_DATASETS_BUCKET_DIR = os.environ["RAW_SUMMARIZATION_DATASETS_BUCKET_DIR"]
# data_splits = load_splits_from_bucket_dir(RAW_SUMMARIZATION_DATASETS_BUCKET_DIR)
# dataset_dict = load_dataset_from_splits(data_splits)
# dataset_dict = clean_up_dataset(dataset_dict)
# dataset_dict  = preprocess_dataset(dataset_dict)
# dataset_dict.save_to_disk("../datasets/interim/summarization_dataset_big")

In [18]:
# Load the seq2seq model from the same checkpoint directory as the tokenizer, passing
# `use_cache=False` to `from_pretrained`, then switch on gradient checkpointing.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_cache=False)
model.gradient_checkpointing_enable()

# Generation-related settings stored on the model config.
# NOTE(review): presumably these act as defaults when the trainer generates during evaluation
# (`predict_with_generate=True`) — confirm.
model.config.num_beams = 4
model.config.max_length = 512
model.config.min_length = 100
model.config.length_penalty = 2.0
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3


In [19]:

# Load the dataset from disk (the same path the commented-out `save_to_disk` call above writes to).
dataset_dict = datasets.load_from_disk("../datasets/interim/summarization_dataset_big")

In [20]:
# dataset_dict.set_format("numpy", columns=["input_ids", "attention_mask", "global_attention_mask", "labels"])

In [21]:
# Raw-text columns and the derived statistic columns; they are dropped from `dataset_dict` below.
# NOTE(review): this cell rebinds `dataset_dict`, so re-running it without reloading the dataset
# would request removal of columns that are already gone — confirm the resulting behaviour.
removal_cols = ["document",
                "document_char_count",
                "document_sentence_count",
                "document_sentence_density",
                "document_stopword_count",
                "document_word_count",
                "document_word_density",
                "summary",
                "summary_char_count",
                "summary_sentence_count",
                "summary_sentence_density",
                "summary_stopword_count",
                "summary_word_count",
                "summary_word_density"
                ]

dataset_dict = dataset_dict.remove_columns(removal_cols)

In [23]:
# Per-device batch size, used for both training and evaluation below.
batch_size = 12
# Batch collator built from the tokenizer and model. Labels are padded with -100, the same value
# `tokenize_dataset` substitutes for pad tokens in the labels.
# NOTE(review): confirm how `max_length` and `pad_to_multiple_of=1024` interact with the collator's
# default padding strategy.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    max_length=tokenizer.model_max_length,
    pad_to_multiple_of=1024,
    label_pad_token_id=-100, )

# Training configuration; checkpoints are written under "../models/led-16k".
# NOTE(review): `eval_steps=1` is set together with `evaluation_strategy="epoch"` — confirm which
# of the two is meant to drive evaluation.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_drop_last=True,
    group_by_length=True,
    fp16=True,
    output_dir="../models/led-16k",
    logging_steps=5,
    eval_steps=1,
    save_steps=100,
    save_total_limit=2,
    gradient_accumulation_steps=4,
    num_train_epochs=10,
)

# instantiate trainer
# Evaluation uses only the first 50 examples of the "valid" split.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["valid"].select(range(50)),
)


PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using amp fp16 backend


In [24]:
# train_dataloader = trainer.get_train_dataloader()
# loader = iter(train_dataloader)
# batch = next(loader)
# for k,v in batch.items():
#     print(k,reversed(v[0])[:10], v.shape)

In [25]:
# for k,v in dataset_dict["train"][0].items():
#     print(k,v.shape)

In [None]:
# start training
# Empty PyTorch's CUDA cache first, then run the trainer's training loop.
torch.cuda.empty_cache()
trainer.train()

In [None]:
# Generate a summary for one test example (row index 1).
# Put the model in evaluation mode so generation does not run with training-mode behaviour.
model.eval()
# Slice the row first so only this example is read, instead of materialising three whole columns.
sample = dataset_dict["test"][1:2]
# Build long-typed tensors on whatever device the model currently lives on.
inputs = {
    name: torch.tensor(sample[name], dtype=torch.long).to(model.device)
    for name in ("input_ids", "attention_mask", "global_attention_mask")
}
with torch.no_grad():
    predicted_ids = model.generate(**inputs, max_length=512, num_beams=4, early_stopping=True).to("cpu")

In [None]:
# Load the dataset from disk again under a new name; presumably because the text columns were
# removed from `dataset_dict` earlier.
dataset_other = datasets.load_from_disk("../datasets/interim/summarization_dataset_big")
# Index the row before the column so only example 1 is read rather than the whole column.
"\n".join(dataset_other["test"][1]["document"])

In [None]:

# Reference summary of test example 1; the row is indexed before the column so only that
# example is read.
"\n".join(dataset_other["test"][1]["summary"])

In [None]:
# Decode the generated token ids to text without special tokens and show the first sequence.
tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, )[0]