### Setup

In [None]:
# Show the attached GPU(s) together with driver and CUDA versions.
!nvidia-smi

Mon Aug 22 17:09:55 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### Utils

In [None]:
%%capture
# Install the project's requirements; %%capture suppresses the pip log.
# NOTE(review): absolute Colab-style path (/content/...) — adjust when running elsewhere.
%pip install -r /content/feedback/code/requirements.txt

In [None]:
import os
import gc
import glob
import json
import pandas as pd
import numpy as np

from sklearn.metrics import f1_score, log_loss
import matplotlib.pyplot as plt
from itertools import chain

from copy import deepcopy
from dataclasses import dataclass

import torch
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

import math
import shutil

from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from transformers.trainer_pt_utils import get_parameter_names

import pdb
import random
from collections import OrderedDict

from accelerate import Accelerator
from tqdm.auto import tqdm

import re


from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration


from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

from transformers.optimization import  Adafactor


In [None]:
# Run configuration. In the visible cells only the dataset paths, the model
# checkpoint and max_length are read; the training-style keys (batch size,
# epochs, warmup, ...) are carried along unused here.
config = {
    "debug": False,
    "seed": 453,
    "model_checkpoint": "t5-large",

    "batch_size": 2,
    "warmup_pct": 0.025,
    "num_epochs": 5,

    "gradient_checkpointing": False,
    "fp16": True,
    "grad_accumulation": 4,
    "eval_frequency": 2400,
    "max_length": 512,

    "model_dir": "../models/T5_generator",

    "project": "feedback-prize-ea",
    "run_name": "rb-t5-data-augmentation",
    "patience": 10,

    "fpe_dataset_dir": "../datasets/feedback-prize-effectiveness",
    "fold_path": "../datasets/processed/cv_map_topics_10_folds.parquet",
    "train_essay_fpe21_dir": "../datasets/processed/fpe_21_train_essays.parquet",
    "train_essay_fpe22_dir": "../datasets/processed/fpe_22_train_essays.parquet",
    "test_essay_fpe22_dir": "../datasets/processed/fpe_22_test_essays.parquet",
    "n_folds": 10,
}

In [None]:
# Discourse-level annotations of the competition train set.
df = pd.read_csv(os.path.join(config["fpe_dataset_dir"], "train.csv"))
# Per-essay fold assignment (presumably the 10-fold CV split), joined onto every discourse row.
fold_df = pd.read_parquet(config["fold_path"])
df = df.merge(fold_df, on="essay_id", how="left")

# Full essay texts, one row per essay.
essay_df = pd.read_parquet(config["train_essay_fpe22_dir"])

# Prompt / topic metadata; only the columns used downstream are kept.
topic_df = pd.read_csv("../datasets/processed/fpe_2021_topics.csv").loc[
    :, ["essay_id", "prompt", "topic_num"]
].copy()

In [None]:
def relaxed_search(text, substring, min_length=2, fraction=0.99999):
    """
    Returns substring's span from the given text with the certain precision.

    ``substring`` is looked up in ``text``; while it is not found, it is
    truncated to ``int(len(substring) * fraction)`` characters (always at
    least one character shorter than before) and searched again.

    :param text: text to search in
    :param substring: text to locate
    :param min_length: give up once the truncated substring would be shorter
        than this many characters
    :param fraction: share of the current substring kept on each retry
    :return: ``[start, end]`` of the first truncated prefix that occurs in
        ``text``, or ``[-1, 0]`` if no prefix of at least ``min_length``
        characters occurs
    """
    # Iterative on purpose: with the default fraction one character is dropped
    # per retry, so a recursive formulation exceeds Python's recursion limit on
    # discourses longer than ~1000 characters.
    while True:
        position = text.find(substring)
        if position != -1:
            return [position, position + len(substring)]

        # min(...) guarantees progress even for fraction >= 1, which would
        # otherwise retry the same substring forever.
        new_length = min(int(len(substring) * fraction), len(substring) - 1)
        if new_length < min_length:
            return [-1, 0]
        substring = substring[:new_length]


def build_span_map(discourse_list, essay_text):
    """Locate every discourse segment of one essay inside the essay text.

    Segments are processed in the given order. Each one is matched at the first
    exact (regex-escaped) occurrence found by ``re.finditer`` that ends after the
    previously matched segment, so repeated texts map to successive positions.
    Any occurrence without such a match falls back to ``relaxed_search`` over the
    whole essay.

    :param discourse_list: discourse texts of a single essay, in annotation order
    :param essay_text: full essay text
    :return: dict mapping each distinct discourse text to a list holding exactly
        one span per occurrence in ``discourse_list`` (same order), so that
        ``get_substring_span`` can consume them sequentially
    """
    reading_head = 0
    # One entry per element of discourse_list; None means "no exact match".
    spans = []
    for cur_discourse in discourse_list:
        found = None
        for match in re.finditer(re.escape(cur_discourse), essay_text):
            # Skip matches that lie completely before the previous segment.
            if match.end() > reading_head:
                found = match.span()
                reading_head = match.end()
                break
        spans.append(found)

    # Fill gaps per occurrence (not per distinct text): a text that appears twice
    # but matched only once still needs two spans, otherwise the consumer runs out.
    to_return = dict()
    for cur_discourse, span in zip(discourse_list, spans):
        if span is None:
            print("resorting to relaxed search...")
            span = relaxed_search(essay_text, cur_discourse)
        to_return.setdefault(cur_discourse, []).append(span)
    return to_return


def get_substring_span(texts, mapping):
    """Return the span of every text in ``texts``, looked up in ``mapping``.

    ``mapping`` maps each distinct text to a list of spans ordered like that
    text's occurrences in ``texts`` (the shape built by ``build_span_map``); the
    k-th occurrence of a text receives the k-th span of its list. ``mapping`` is
    only read, never modified, so the same mapping can be reused.

    :param texts: texts in the order their spans should be returned
    :param mapping: dict of text -> list of spans
    :return: list of spans aligned with ``texts``
    :raises KeyError: if a text has no entry in ``mapping``
    :raises IndexError: if a text occurs more often than it has spans
    """
    occurrences = {}
    result = []
    for text in texts:
        k = occurrences.get(text, 0)
        result.append(mapping[text][k])
        occurrences[text] = k + 1
    return result


def process_input_df(anno_df, notes_df):
    """Attach character offsets of each discourse segment to the annotation frame.

    Discourse texts are stripped, grouped per essay and located inside the
    corresponding essay text with ``build_span_map`` / ``get_substring_span``.

    :param anno_df: discourse annotations with ``discourse_id``, ``essay_id``,
        ``discourse_text``, ``discourse_type`` and either
        ``discourse_effectiveness`` or ``uid``
    :type anno_df: pd.DataFrame
    :param notes_df: essays with ``essay_id`` and ``essay_text`` columns
    :type notes_df: pd.DataFrame
    :return: the selected annotation columns plus ``discourse_start`` and
        ``discourse_end``; neither input frame is modified
    :rtype: pd.DataFrame
    """
    anno_df = anno_df.copy()

    # ------------------- Pre-process discourse text ------------------- #
    anno_df["discourse_text"] = anno_df["discourse_text"].str.strip()
    label_col = ("discourse_effectiveness"
                 if "discourse_effectiveness" in anno_df.columns else "uid")
    anno_df = anno_df[["discourse_id", "essay_id", "discourse_text",
                       "discourse_type", label_col]].copy()

    # One row per essay: its discourse ids/texts (annotation order) + essay text.
    per_essay = (
        anno_df.groupby("essay_id")[["discourse_id", "discourse_text"]]
        .agg(list)
        .reset_index()
        .merge(notes_df, on="essay_id", how="left")
    )

    # Iterate columns by name; positional row access such as x[0] on a
    # label-indexed Series is deprecated in pandas 2.
    all_discourse_ids = []
    all_spans = []
    for ids, texts, essay_text in zip(per_essay["discourse_id"],
                                      per_essay["discourse_text"],
                                      per_essay["essay_text"]):
        span_map = build_span_map(texts, essay_text)
        all_discourse_ids.extend(ids)
        all_spans.extend(get_substring_span(texts, span_map))

    span_df = pd.DataFrame({
        "discourse_id": all_discourse_ids,
        "discourse_start": [span[0] for span in all_spans],
        "discourse_end": [span[1] for span in all_spans],
    })
    return anno_df.merge(span_df, on="discourse_id", how="left")


In [None]:
# Locate discourse spans, then attach the essay text and the prompt/topic metadata.
df = (
    process_input_df(df, essay_df)
    .merge(essay_df, on="essay_id", how="left")
    .merge(topic_df, on="essay_id", how="left")
)

def get_model_input_text(prompt, left_context, right_context, discourse_type, discourse_effectiveness):
    """Build the T5 conditioning string for one discourse.

    The string is a header naming the requested effectiveness and discourse type,
    followed by the essay prompt and the text to the left and right of the
    discourse, with the sections separated by ``"|| \\n"``.
    """
    header = f"Generate text for {discourse_effectiveness} {discourse_type}"
    sections = (
        header,
        f"Prompt: {prompt}",
        f"Left Context: {left_context}",
        f"Right Context: {right_context}",
    )
    separator = "|| \n"
    return separator.join(sections)

def get_model_output_text(discourse_text):
    """Target sequence for the generator: the discourse text, passed through unchanged."""
    target = discourse_text
    return target


# Split each essay around its discourse span: text before the span is the left
# context, text after it is the right context. Columns are zipped by name instead of
# positional row access (x[0]), which is deprecated on label-indexed Series in pandas 2.
# NOTE(review): a failed relaxed_search yields start=-1, which slices to essay[:-1]
# here — confirm such rows are acceptable.
df["left_context"] = [
    essay[:start] for essay, start in zip(df["essay_text"], df["discourse_start"])
]
df["right_context"] = [
    essay[end:] for essay, end in zip(df["essay_text"], df["discourse_end"])
]
df["model_input"] = [
    get_model_input_text(prompt, left, right, d_type, effectiveness)
    for prompt, left, right, d_type, effectiveness in zip(
        df["prompt"], df["left_context"], df["right_context"],
        df["discourse_type"], df["discourse_effectiveness"],
    )
]

df["model_output"] = df["discourse_text"].map(get_model_output_text)

resorting to relaxed search...


In [None]:
# Spot-check one (input, target) pair; seeded with config["seed"] so the shown row is reproducible.
df[["model_input", "model_output"]].sample(1, random_state=config["seed"])

Unnamed: 0,model_input,model_output
29990,Generate text for Ineffective Evidence|| \nPro...,NASA has lots of it i believed the rich people...


In [None]:
model = T5ForConditionalGeneration.from_pretrained(config["model_checkpoint"])
tokenizer = T5Tokenizer.from_pretrained(config["model_checkpoint"], model_max_length=config["max_length"])

# Fine-tuned generator weights stored under config["model_dir"].
# map_location="cpu" loads the tensors on the CPU instead of the device they were
# saved from, so the checkpoint does not occupy GPU memory next to the model; the
# model is placed on its device afterwards (accelerator.prepare in a later cell).
# NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoints.
ckpt_path = os.path.join(config["model_dir"], "t5_generator_left_right_context.pth.tar")
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])

del ckpt
gc.collect()
torch.cuda.empty_cache()

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.75G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

In [None]:
# Hand the model to accelerate, which handles device placement for this process.
# NOTE(review): config["fp16"] is not passed to Accelerator(); mixed precision is
# only taken from accelerate's own config/environment — confirm this is intended.
accelerator = Accelerator()
model = accelerator.prepare(model)

In [None]:
# Switch to inference mode (train(False) is what eval() does; disables dropout).
model.train(False)
# Length cap read by generate_text for tokenization and generation.
# NOTE(review): config already sets max_length to 512; this only matters if the
# config cell is edited — confirm which value should win.
config.update(max_length=512)

In [None]:
def generate_text(input_text, num_sents, num_beams=15, temperature=2.0):
    """Sample ``num_sents`` candidate discourse texts from the T5 generator.

    Uses the notebook-level ``model``, ``tokenizer`` and ``config`` objects and
    decodes with beam sampling (``do_sample=True`` together with ``num_beams``).

    :param input_text: conditioning string (as built by ``get_model_input_text``)
    :param num_sents: number of sequences to return; transformers requires this
        to be at most ``num_beams``
    :param num_beams: beam width (default 15, the previously hard-coded value)
    :param temperature: sampling temperature (default 2.0, as before)
    :return: list of ``num_sents`` decoded strings
    """
    encoded = tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        return_tensors="pt",
        truncation=True,
        max_length=config["max_length"],
        padding=False,
    )

    # Put the inputs wherever the model's weights live (set by accelerate)
    # instead of assuming the default "cuda" device.
    device = next(model.parameters()).device

    beam_outputs = model.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        max_length=config["max_length"],
        early_stopping=True,
        num_beams=num_beams,
        num_return_sequences=num_sents,
        no_repeat_ngram_size=2,
        temperature=temperature,
        do_sample=True,
    )

    texts = tokenizer.batch_decode(
        beam_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return texts

In [None]:
def generate_aumentation(essay_id, input_df, num_sents=5):
    """Create ``num_sents`` augmented copies of one essay.

    Every discourse segment of the essay is replaced by a candidate produced by
    ``generate_text`` while the text between segments ("fillers") is kept verbatim.

    :param essay_id: id of the essay to augment; must occur in ``input_df``
    :param input_df: one row per discourse with ``essay_id``, ``essay_text``,
        ``discourse_start``, ``discourse_end``, ``model_input`` and the helper
        columns dropped below (``topic_num``, contexts, ``model_output``)
    :param num_sents: number of augmented essays (= candidates per discourse)
    :return: ``(result_df, essay_text, generated_essays)``: the essay's discourse
        rows plus ``t5_generated_{i}`` columns, the original essay text, and a
        list of ``num_sents`` augmented essays
    """
    example_df = (
        input_df[input_df["essay_id"] == essay_id]
        .sort_values(by="discourse_start")
        .reset_index(drop=True)
    )

    essay_text = example_df["essay_text"].values[0]
    starts = example_df["discourse_start"].tolist()
    ends = example_df["discourse_end"].tolist()

    # Text before the first span, between consecutive spans, and after the last one.
    # The last slice must run to the end of the essay (stop=None); a stop of -1
    # would silently drop the essay's final character.
    # NOTE(review): spans of [-1, 0] from a failed relaxed_search are not handled.
    fillers = [essay_text[:starts[0]]]
    fillers += [essay_text[end:next_start]
                for end, next_start in zip(ends, starts[1:] + [None])]

    # One list of num_sents candidates per discourse, in essay order.
    candidates = [generate_text(model_input, num_sents=num_sents)
                  for model_input in example_df["model_input"]]
    generated_df = pd.DataFrame(
        candidates, columns=[f"t5_generated_{i}" for i in range(num_sents)]
    )

    result_df = pd.concat([example_df, generated_df], axis=1)
    result_df = result_df.drop(columns=["essay_text", "topic_num", "left_context",
                                        "right_context", "model_input", "model_output"])

    generated_essays = []
    for column in generated_df.columns:
        pieces = [fillers[0]]
        for generated, filler in zip(generated_df[column], fillers[1:]):
            pieces.append(generated)
            pieces.append(filler)
        generated_essays.append("".join(pieces))
    return result_df, essay_text, generated_essays

In [None]:
# Distinct essay ids in order of first appearance, then shuffled in place.
# NOTE(review): the shuffle is unseeded (config["seed"] is not used), so the order
# differs between runs — presumably so parallel workers start on different essays; confirm.
all_essay_ids = df["essay_id"].drop_duplicates().tolist()
random.shuffle(all_essay_ids)

In [None]:
cache_df = pd.read_csv("../datasets/augmented_data/t5_essays.csv")
# Cached ids may carry a suffix after "_"; keep only the part before the first "_".
cache_df["essay_id"] = cache_df["essay_id"].map(lambda raw_id: str(raw_id).partition("_")[0])
done_essays = cache_df["essay_id"].unique().tolist()
len(done_essays)

2202

In [None]:
# Output directory for this worker (one CSV per essay + one JSON per augmented essay).
WORKER_DIR = "../datasets/augmented_data/worker_3"
# Augmented essays generated per original essay.
NUM_AUGMENTATIONS = 5

os.makedirs(WORKER_DIR, exist_ok=True)
done_essay_set = set(done_essays)  # O(1) membership instead of scanning a list
num_skipped = 0

for essay_id in tqdm(all_essay_ids):
    # Skip essays already in the merged cache, and essays this worker already
    # finished: the last JSON is written last, so its presence means the essay is
    # complete. This lets the cell be resumed after an interruption.
    last_json_path = os.path.join(
        WORKER_DIR, f"{essay_id}_augmented_{NUM_AUGMENTATIONS - 1}.json")
    if essay_id in done_essay_set or os.path.exists(last_json_path):
        num_skipped += 1
        continue

    result_df, original_essay, generated_essays = generate_aumentation(
        essay_id, df, num_sents=NUM_AUGMENTATIONS)

    result_df.to_csv(os.path.join(WORKER_DIR, f"df_{essay_id}.csv"), index=False)

    for idx, aug_text in enumerate(generated_essays):
        json_path = os.path.join(WORKER_DIR, f"{essay_id}_augmented_{idx}.json")
        with open(json_path, "w") as f:
            json.dump({f"{essay_id}": aug_text}, f)
    torch.cuda.empty_cache()

# One summary line instead of a "skipping ..." line per essay.
print(f"skipped {num_skipped} already-augmented essays")

  0%|          | 0/4191 [00:00<?, ?it/s]

skipping ED1EFE97C40F
skipping E570BB4A5B5B
skipping 60714BE8146A
skipping 0491C7BFA9B4
skipping 72ECF408E30D
skipping 34263BB26432
skipping FB2ED8D952A6
skipping 28E2327785FC
skipping 7A0F4648B341
skipping EF835E27D5A3
skipping 4E48D58E859A
skipping FBD21BB50633
skipping 9B0344F83C66
skipping 1136D95D28E7
skipping B8EAF80A1409
skipping 57CE6ED06513
skipping DC4E11A259D8
skipping 07A023BE2629
skipping 6621856784A9
skipping 653F5908B3BD
skipping 4D05FF3E63BE
skipping F19978B8C2A7
skipping B3CB635C4341
skipping 83C1CC6938F4
skipping 01AFC67DF935
skipping FD0C32EA0E5B
skipping DCDE7152B94E
skipping B3963F9DB39C
skipping 3F63FCA1BD42
skipping B237EA09433F
skipping F6F60CA6E983
skipping E37D480CD473
skipping 34B96B0F6EBD
skipping 4122443C74C6
skipping D34DFB9BF717
skipping E7C5E12F4E84
skipping 7F8D472D06A1
skipping 5E57B80BFC66
skipping 52579F4BBE4E
skipping E87BDAB97148
skipping A98E8EFFC8A9
skipping 13FB62D377E3
skipping 1BDC5F527974
skipping 9706F8E7D534
skipping DDBE5B47593D
skipping D