
## Fine-Tune GPT-J to Summarize

* References: OpenAI, CarperAI, Hugging Face (HF)


In [8]:

## !pip install evaluate
## !pip install rouge_score


In [9]:


import json

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset

import random

import evaluate
import numpy as np

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)



## Parameters


In [10]:

## Set up the metric: load the ROUGE scorer through `evaluate`.
## `compute_metrics` calls `rouge.compute(...)` on decoded predictions and labels.

rouge = evaluate.load("rouge")



In [11]:

## Training configuration. Except for `max_input_length` (passed to the dataset
## constructors as `max_length`), every value is forwarded to `TrainingArguments`.
output_dir                  = "gptj-supervised-summarize-checkpoint"  ## checkpoint dir; also the final `save_model` target
train_batch_size            = 16        ## -> per_device_train_batch_size
gradient_accumulation_steps = 1
learning_rate               = 1e-5
eval_batch_size             = 1         ## -> per_device_eval_batch_size
eval_steps                  = 500       ## evaluate every N steps
max_input_length            = 550       ## tokenizer max_length (truncate / pad to this many tokens)
save_steps                  = 1000      ## checkpoint every N steps
num_train_epochs            = 1         ## 5  (NOTE(review): presumably the full-run value; confirm)

## Seeds only Python's built-in `random` module; `set_seed` covers numpy and torch as well.
random.seed(42)



## Utility Functions


In [12]:

def get_dataset_from_jsonl(jsonl_file, return_summary=True):
    """Load TL;DR samples from a JSON Lines file and render them as prompt strings.

    Each non-blank line of ``jsonl_file`` must be a JSON object with the keys
    ``"subreddit"``, ``"title"``, ``"post"`` and ``"summary"``.

    Args:
        jsonl_file: Path of the JSONL file to read (decoded as UTF-8).
        return_summary: If True, the summary is appended after ``"TL;DR: "`` in
            every returned string. If False, the strings end with ``"TL;DR: "``
            and the summaries are returned separately.

    Returns:
        A list of prompt strings when ``return_summary`` is True; otherwise a
        ``(post_list, summary_list)`` tuple of two equally long lists.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default,
    # and skip blank lines (e.g. a trailing newline) which json.loads rejects.
    with open(jsonl_file, "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    post_list = []
    summary_list = []
    for d in dataset:
        # Shared prefix of both output formats.
        prompt = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: "
        if return_summary:
            post_list.append(f"{prompt}{d['summary']}")
        else:
            post_list.append(prompt)
            summary_list.append(d["summary"])

    if not return_summary:
        return post_list, summary_list
    return post_list


In [13]:

def set_seed(seed_val=42):
    """Apply one seed to every random number generator used in this notebook.

    Seeds, in order: Python's ``random``, NumPy's global generator, PyTorch's
    default generator, and PyTorch's CUDA generators on all devices.

    Args:
        seed_val: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_val)


In [14]:

def compute_metrics(eval_preds):
    """Decode predicted and reference token ids and score them with ROUGE.

    Args:
        eval_preds: Object exposing ``predictions`` and ``label_ids`` arrays of
            token ids (as handed over by the ``Trainer`` during evaluation).

    Returns:
        Whatever ``rouge.compute`` returns for the decoded prediction and
        reference strings.
    """
    pad_id = tokenizer.pad_token_id

    pred_ids   = np.asarray(eval_preds.predictions)
    labels_ids = np.asarray(eval_preds.label_ids)

    # -100 is the ignore index used for masked / padded positions; it is not a
    # valid vocabulary id, so swap it for the pad token before decoding.
    pred_ids   = np.where(pred_ids != -100, pred_ids, pad_id)
    labels_ids = np.where(labels_ids != -100, labels_ids, pad_id)

    pred_str   = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str  = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    return rouge.compute(predictions=pred_str, references=label_str)


In [15]:

def preprocess_logits_for_metrics(logits, labels):
    """Collapse raw model outputs to predicted token ids before metrics run.

    Args:
        logits: Either a logits tensor or a tuple whose first element is the
            logits tensor.
        labels: Accepted but not used.

    Returns:
        The index of the largest logit along the last dimension.
    """
    token_logits = logits[0] if isinstance(logits, tuple) else logits
    return token_logits.argmax(dim=-1)



## Classes


In [16]:

class TLDRDataset(Dataset):
    """Dataset of ``prompt + label`` strings that are tokenized on access."""

    def __init__(self, train_path, tokenizer, split, max_length=550):
        """Load ``split`` of ``train_path`` and keep the concatenated texts.

        Args:
            train_path: Dataset identifier/path handed to ``load_dataset``.
            tokenizer: Callable used in ``__getitem__`` to encode one text.
            split: Split name; when it contains ``"valid"`` only the first
                2000 samples are kept.
            max_length: Length every sample is truncated / padded to.
        """
        records = load_dataset(train_path, split=split)
        texts = [record["prompt"] + record["label"] for record in records]
        self.post_list = texts[0:2000] if "valid" in split else texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Created empty; nothing in this class fills them.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize text ``idx`` and return input ids, attention mask and labels."""
        encoded = self.tokenizer(
            self.post_list[idx],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        token_ids = torch.tensor(encoded["input_ids"])
        mask = torch.tensor(encoded["attention_mask"])
        # "labels" is the very same tensor object as "input_ids".
        return dict(input_ids=token_ids, attention_mask=mask, labels=token_ids)


In [17]:


class ComparisonDataset(Dataset):
    """Pairs of candidate summaries for the same post, chosen summary first.

    Every item holds two rendered prompts for one post: index 0 is built from
    the summary selected by ``sample["choice"]`` and index 1 from the other one.
    """

    def __init__(self, comparison_path, tokenizer, max_length=550):
        """Read comparison records from a JSON Lines file.

        Each non-blank line must be a JSON object with ``"info"`` (holding
        ``"subreddit"``, ``"title"`` and ``"post"``), ``"summaries"`` (two
        entries, each with ``"text"``) and ``"choice"``.

        Args:
            comparison_path: Path of the JSONL file (decoded as UTF-8).
            tokenizer: Callable used in ``__getitem__`` to encode the pair.
            max_length: Length every text is truncated / padded to.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks one of the required keys.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform default,
        # and skip blank lines (e.g. a trailing newline) which json.loads rejects.
        with open(comparison_path, "r", encoding="utf-8") as f:
            dataset = [json.loads(line) for line in f if line.strip()]

        self.tokenizer = tokenizer
        self.post_list = []
        self.summaries_0 = []
        self.summaries_1 = []
        self.labels = []
        self.max_length = max_length

        def make_text(post, summarize):
            return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summarize}"

        for sample in dataset:
            info = sample["info"]
            self.post_list.append(info["post"])
            chosen = sample["summaries"][0]["text"]
            other = sample["summaries"][1]["text"]
            # `choice` says which raw summary was preferred; reorder so the
            # preferred one always ends up in `summaries_0`.
            if sample["choice"] != 0:
                chosen, other = other, chosen
            self.summaries_0.append(make_text(info, chosen))
            self.summaries_1.append(make_text(info, other))
            # The label is therefore always 0 (the first of the pair).
            self.labels.append(0)

    def __len__(self):
        """Return the number of comparison pairs."""
        return len(self.post_list)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` together and return its ids and attention mask.

        Both texts are encoded in one tokenizer call, so each returned tensor
        has the two summaries stacked along its first dimension.
        """
        summ0 = self.summaries_0[idx]
        summ1 = self.summaries_1[idx]
        encodings_dict = self.tokenizer(
            [summ0, summ1],
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
        )
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attention_mask = torch.tensor(encodings_dict["attention_mask"])
        return {"input_ids": input_ids, "attention_mask": attention_mask}



In [18]:

class AllSummDataset(Dataset):
    """Dataset of ``"Summarize: <text>. TL;DR: <summary>"`` strings from a parquet file."""

    def __init__(self, train_path, tokenizer, split, max_length=1024):
        """Read the parquet file and render one training string per row.

        Args:
            train_path: Path of a parquet file with ``"text"`` and
                ``"summary"`` columns.
            tokenizer: Callable used in ``__getitem__`` to encode one text.
            split: When equal to ``"valid"``, a random sample of at most 5000
                rows is used instead of the whole frame.
            max_length: Length every sample is truncated / padded to.
        """
        df = pd.read_parquet(train_path)
        if split == "valid":
            # Cap the sample size at the number of rows: sampling more rows than
            # exist (without replacement) raises ValueError.
            df = df.sample(n=min(5000, len(df)))
        # Column-wise iteration avoids building a Series per row as iterrows does.
        self.summarizes = [
            f"Summarize: {text}. TL;DR: {summary}"
            for text, summary in zip(df["text"], df["summary"])
        ]
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Created empty; nothing in this class fills them.
        self.input_ids = []
        self.attn_masks = []

    def __len__(self):
        """Return the number of rendered strings."""
        return len(self.summarizes)

    def __getitem__(self, idx):
        """Tokenize string ``idx`` and return input ids, attention mask and labels."""
        txt = self.summarizes[idx]
        encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
        input_ids = torch.tensor(encodings_dict["input_ids"])
        attn_masks = torch.tensor(encodings_dict["attention_mask"])

        # "labels" is the very same tensor object as "input_ids".
        return {
            "input_ids": input_ids,
            "attention_mask": attn_masks,
            "labels": input_ids,
        }



## Tokenizer


In [19]:

# Load the GPT-J tokenizer and reuse its end-of-sequence token as the padding
# token; the datasets call the tokenizer with padding="max_length".
# (Presumably the checkpoint's tokenizer defines no pad token of its own — confirm.)
tokenizer              = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token    = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id



## Model


In [None]:

# Load the pretrained GPT-J causal language model with `use_cache=False`.
# (Presumably disabled because training uses gradient checkpointing — confirm.)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False)

# Resize the token embeddings to the tokenizer's vocabulary size.
model.resize_token_embeddings( len(tokenizer) )

# NOTE(review): `end_token_id` does not look like a standard config field;
# presumably `eos_token_id` was intended — confirm before relying on it.
model.config.end_token_id = tokenizer.eos_token_id
# Use the model's EOS id for padding, mirroring the tokenizer setup.
model.config.pad_token_id = model.config.eos_token_id


Downloading config.json:   0%|          | 0.00/930 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]


## Set up the datasets


In [None]:

# Dataset identifier handed to `load_dataset` inside `TLDRDataset`.
data_path = "CarperAI/openai_summarize_tldr"


In [None]:

    
# Training split: every sample is truncated / padded to `max_input_length` tokens.
train_dataset = TLDRDataset(
    data_path,
    tokenizer,
    "train",
    max_length=max_input_length,
)


In [None]:


# Validation split: because the split name contains "valid", `TLDRDataset`
# keeps only the first 2000 samples.
dev_dataset = TLDRDataset(
    data_path,
    tokenizer,
    "valid",
    max_length=max_input_length,
)



## Prepare the trainer and start training


In [None]:

# Trainer configuration built from the hyperparameters defined earlier in the notebook.
training_args = TrainingArguments(
    output_dir                  = output_dir,
    evaluation_strategy         = "steps",        # evaluate every `eval_steps` steps
    eval_accumulation_steps     = 1,
    learning_rate               = learning_rate,
    per_device_train_batch_size = train_batch_size,
    per_device_eval_batch_size  = eval_batch_size,
    gradient_checkpointing      = True,
    # `half_precision_backend` takes a backend name, not a boolean; "auto" is
    # the documented default and lets the Trainer pick the backend.
    half_precision_backend      = "auto",
    fp16                        = True,
    adam_beta1                  = 0.9,
    adam_beta2                  = 0.95,
    gradient_accumulation_steps = gradient_accumulation_steps,
    num_train_epochs            = num_train_epochs,
    warmup_steps                = 1,
    eval_steps                  = eval_steps,
    save_steps                  = save_steps,
    load_best_model_at_end      = True,
    logging_steps               = 50,
    deepspeed                   = "./ds_config_gptj.json",   # DeepSpeed config file, path relative to the working directory
    ## no_cuda                       = True,
)


In [None]:
   
# Wire model, datasets and metric hooks into the Trainer.
# `preprocess_logits_for_metrics` reduces logits to argmax token ids, and
# `compute_metrics` decodes those ids and scores them with ROUGE.
trainer = Trainer(
    model                         = model,
    args                          = training_args,
    train_dataset                 = train_dataset,
    eval_dataset                  = dev_dataset,
    compute_metrics               = compute_metrics,
    data_collator                 = default_data_collator,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
)
    

In [None]:

# Run fine-tuning with the configuration in `training_args`.
trainer.train()


In [None]:

# Save the trained model into `output_dir`.
trainer.save_model(output_dir)
