In [59]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../../')

from src.preprocess import *
import itertools
from tqdm import tqdm
import numpy as np


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
def correct_text(sentence_batches, checker, corrector, separator: str = " "):
    """
    Correct the grammar in a string of text using a text-classification and text-generation pipeline.

    Parameters:
    text (str): The text to be corrected.
    checker (transformers.pipeline.Pipeline): The text-classification pipeline to use for checking the grammar quality of the text.
    corrector (transformers.pipeline.Pipeline): The text-generation pipeline to use for correcting the text.
    separator (str, optional): The separator to use when joining the corrected text into a single string. Default is a space character.

    Returns:
    str: The corrected text.
    """
    # # Split the text into sentence batches
    # sentence_batches = split_text(text)

    # Initialize a list to store the corrected text
    corrected_text = []

    # Iterate through the sentence batches
    for batch in tqdm(
        sentence_batches, total=len(sentence_batches), desc="correcting text.."
    ):
        # Join the sentences in the batch into a single string
        raw_text = " ".join(batch)

        # Check the grammar quality of the text using the text-classification pipeline
        results = checker(raw_text)

        # Only correct the text if the results of the text-classification are not LABEL_1 or are LABEL_1 with a score below 0.9
        if results[0]["label"] != "LABEL_1" or (
            results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9
        ):
            # Correct the text using the text-generation pipeline
            corrected_batch = corrector(raw_text)
            corrected_text.append(corrected_batch[0]["generated_text"])
        else:
            corrected_text.append(raw_text)

    # Join the corrected text into a single string
    # corrected_text = separator.join(corrected_text)

    return corrected_text

In [73]:
def add_leuven_prompt(concept, feature, batches_with_prompts):
    '''This function adds the prompts to the batches.'''
    # prompt = "Help me write a prompt as a question from a concept and an attribute. \nConcept: {}\nAttribute: {}.\nPrompt: In one word Yes/No <mask> ?".format(concept, feature)
    # prompt = "Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[{}],[{}]\nOutput-<mask>".format(concept, feature)
    prompt = "Subject-:[{}]\nPredicate[{}]\nQuestion-<mask>".format(concept, feature)
    # prompt = "{} {}?".format(concept, feature)
    batches_with_prompts.append([[concept, feature, prompt, 0]])
    return

def make_leuven_prompts(batches):
    '''This function creates the prompts for the Leuven Norms experiment.'''
    batches_with_prompts = []
    Parallel(n_jobs=10, require='sharedmem')(delayed(add_leuven_prompt)(batch[0], batch[1], batches_with_prompts) for batch in batches)
    return batches_with_prompts


In [63]:
dataset_dir = '../data/leuven'
animal_leuven_norms, artifacts_leuven_norms = load_leuven_norms(dataset_dir)

In [74]:
batches = []
features = list(set(list(animal_leuven_norms.columns) + list(artifacts_leuven_norms.columns)))
concepts = list(set(list(animal_leuven_norms.index) + list(artifacts_leuven_norms.index)))
for concept, feature in itertools.product(concepts[:1], features[:1]):
    batches.append([concept, feature])
batches = make_leuven_prompts(batches)
batches = np.array(list(itertools.chain(*batches)))
prompts = batches[:,2]

In [43]:
from transformers import pipeline
checker = pipeline("text-classification", "textattack/roberta-base-CoLA")

# Initialize the text-generation pipeline
from transformers import pipeline
corrector = pipeline(
        "text2text-generation",
        "pszemraj/flan-t5-large-grammar-synthesis",
    )


Some weights of the model checkpoint at textattack/roberta-base-CoLA were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [75]:
corrected_text = correct_text(prompts, checker, corrector)
# pp.pprint(corrected_text)

correcting text..: 100%|██████████| 1/1 [00:03<00:00,  3.99s/it]


In [76]:
corrected_text

['So-and-so: [bo) Pleasure in _a_ and _ba_ as well as _c_ in a _ma_. Question and answer time  - i k ']

In [72]:
prompts

array(['Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[bow],[used_as_a_draught_animal]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[bow],[is_clothing]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[bow],[crawls]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[grenade],[used_as_a_draught_animal]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[grenade],[is_clothing]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[grenade],[crawls]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\nInput-:[piranha],[used_as_a_draught_animal]\nOutput-<mask>',
       'Input-[Dolphin], [has_two_eyes] \nOutput-Does a dolphine have two eyes?\