In [None]:
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from transformers import pipeline
import pickle
import os

In [None]:
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_API_KEY = os.environ.get("HF_API_KEY")

In [None]:
# Hugging Face Hub id of the checkpoint used for both the model and the tokenizer.
base_model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

In [None]:
# Default sampling settings that generate_batch forwards to the pipeline call.
generation_args_default = dict(
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.8,
    top_k=30,
    top_p=0.95,
    num_return_sequences=1,
)

# Default placeholder values for the system and user prompt templates (both empty).
replacements_default = {role: {} for role in ('system', 'user')}

In [None]:
# Quantization settings handed to from_pretrained; 8-bit loading is explicitly turned off.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=False,
)

In [None]:
# Report whether PyTorch can see a CUDA device.
print("GPU available" if torch.cuda.is_available() else "GPU not available")

In [None]:
# Load the causal LM named by base_model_name in bfloat16, with device_map="auto"
# choosing the weight placement and HF_API_KEY used as the Hub token.
# NOTE(review): trust_remote_code=True is presumably not required for this checkpoint — confirm before keeping it.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name,
                                                  device_map="auto",
                                                  torch_dtype=torch.bfloat16,
                                                  trust_remote_code=True,
                                                  quantization_config=quantization_config,
                                                  token=HF_API_KEY
                                                  )

In [None]:
# Load the tokenizer that belongs to the same checkpoint (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=True, trust_remote_code=True, token=HF_API_KEY)
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
pipe = pipeline("text-generation", model=base_model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto")

In [None]:
def generate_batch(pipe, data_loader, message_template, generation_args=generation_args_default, replacements=replacements_default, keys=['source_snt']):
    """Generate one model completion for every item yielded by ``data_loader``.

    For each item the system and user templates are filled in with
    ``str.format_map``, rendered through the tokenizer's chat template and sent
    to ``pipe``. Only the newly generated text (the output with the prompt
    prefix cut off) is kept.

    Args:
        pipe: Text-generation pipeline; its ``tokenizer`` attribute is used to
            build the prompt and the terminator token ids.
        data_loader: Iterable of batches. ``batch[keys[0]]`` must be the
            sequence of input texts of that batch.
        message_template: Two-element sequence; ``[0]['content']`` is the
            system template and ``[1]['content']`` the user template.
        generation_args: Mapping providing ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_k``, ``top_p`` and ``num_return_sequences``.
        replacements: Mapping with optional ``'system'`` and ``'user'`` dicts
            of placeholder values, plus an optional ``'args'`` dict mapping a
            placeholder name to a sequence holding one value per item. Those
            sequences are indexed by the item's position across the whole
            ``data_loader``, not by its position inside the current batch.
            The mapping is only read, never modified.
        keys: Only ``keys[0]`` is used. It names both the batch field holding
            the input text and the placeholder that receives it.

    Returns:
        list[str]: One generated text per item, in iteration order. Items
        whose processing raised an exception yield ``''`` (the error is
        printed), so the result always lines up with the input rows.
    """
    text_key = keys[0]
    system_values = replacements.get('system', {})
    user_defaults = replacements.get('user', {})
    per_item_args = replacements.get('args', {})

    results = []
    # Running index over all batches, so per-item arguments stay aligned with
    # their row when the loader yields more than one batch.
    position = 0
    for batch in data_loader:
        for i, abstract in enumerate(batch[text_key]):
            try:
                # Work on a copy: the caller's dict (possibly the shared
                # module-level default) must not accumulate values between calls.
                user_values = dict(user_defaults)
                user_values[text_key] = abstract
                for name, values in per_item_args.items():
                    user_values[name] = values[position]

                messages = [
                    {"role": "system", "content": message_template[0]['content'].format_map(system_values)},
                    {"role": "user", "content": message_template[1]['content'].format_map(user_values)}
                ]
                prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

                # Stop on either the regular EOS token or the end-of-turn marker.
                terminators = [
                    pipe.tokenizer.eos_token_id,
                    pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")
                ]

                output = pipe(
                    prompt,
                    max_new_tokens=generation_args['max_new_tokens'],
                    do_sample=generation_args['do_sample'],
                    temperature=generation_args['temperature'],
                    top_k=generation_args['top_k'],
                    top_p=generation_args['top_p'],
                    num_return_sequences=generation_args['num_return_sequences'],
                    eos_token_id=terminators
                )
                # The returned text starts with the prompt; keep only what follows it.
                results.append(output[0]['generated_text'][len(prompt):])
            except Exception as e:
                # Best effort: report and keep a placeholder so row alignment is preserved.
                # `i` is the item's position inside the current batch.
                print(f"Error: {e} for abstract {i}")
                results.append('')
            position += 1
    return results


In [None]:
# Prompt pair for plain sentence simplification: element 0 is the system message,
# element 1 the user template whose {source_snt} placeholder is filled per item.
message = [
    {
        "role": "system",
        "content": "As a text simplification assistant, your task is to convert complex scientific sentences into simpler, easier-to-understand language. Focus on reducing vocabulary complexity and simplifying syntax without losing the sentence's original intent and accuracy. Return only the simplified sentence, without any additional information."
    },
    {
        "role": "user",
        "content": "Simplify the following sentence from a scientific abstract: {source_snt}. Ensure the simplification is clear, avoids technical jargon, and maintains the original meaning. Simplified Sentence:"
    }
]

# Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

class TextSimplificationDataset(Dataset):
    """Map-style dataset that serves the rows of a pandas DataFrame as dicts."""

    def __init__(self, dataframe):
        # Keep a reference to the frame itself; no copy is made here.
        self.dataframe = dataframe

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        row_count = len(self.dataframe)
        return row_count

    def __getitem__(self, idx):
        """Return the row at integer position ``idx`` as a ``{column: value}`` dict."""
        row = self.dataframe.iloc[idx]
        return row.to_dict()

# load test data

In [None]:
# download the data from https://simpletext-project.com/
# The path is relative to the notebook's working directory.
path = 'data/task3/test/simpletext_task3_test_qrels_distinct.json'
test = pd.read_json(path)

# Sanity check: dimensions and first rows of the loaded frame.
print(test.shape)
test.head()

In [None]:
# Single-row dataset and loader for quick experiments with the prompts.
dataset = TextSimplificationDataset(test.head(1))
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Number of batches the loader will yield.
print(len(data_loader))

In [None]:
# Smoke test: simplify the single sentence in data_loader with the base prompt.
outputs = generate_batch(pipe, data_loader, message)
print(outputs)
# Store the result on the row labelled 0 — assumes the frame's index contains the label 0; TODO confirm.
test.loc[0, 'simplified_llama3'] = outputs[0]

In [None]:
# Preview the frame, which now carries the 'simplified_llama3' column.
test.head()

# explain difficult words

In [None]:
# One-row dataset/loader built from the current state of `test`
# (i.e. after 'simplified_llama3' was written, when cells are run in order).
test_dataset = TextSimplificationDataset(test.head(1))
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Prompt pair asking the model to pick the terms of {source_snt} that need explaining.
# NOTE(review): the system text reads "provide and empty return" (presumably "an empty
# return"); left untouched here because the string is sent to the model verbatim.
identification_message = [{
    "role": "system",
    "content": "Identify up to five terms in the following scientific sentence that require explanation to enhance understanding for a general reader. Focus on selecting highly technical or highly specialized terms that are integral to the sentence's meaning. If there is nothing to explain provide and empty return. Return only the identified terms, without any additional information."
},
{
    "role": "user",
    "content": "Decide which terms require explanation in the context of this sentence: {source_snt}. Identified terms:"
}]

In [None]:
# Ask the model for the difficult terms; the result is a list with one raw reply per item.
difficult_words = generate_batch(pipe, test_data_loader, identification_message)
print(difficult_words)

In [None]:
import ast

# The model reply is free text and only sometimes a valid Python list literal;
# generate_batch also returns '' when generation failed. Parse defensively
# instead of letting literal_eval raise.
raw_terms = difficult_words[0]
try:
    parsed_terms = ast.literal_eval(raw_terms)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    parsed_terms = None

if isinstance(parsed_terms, (list, tuple)):
    difficult_words = [str(term) for term in parsed_terms]
else:
    # Fallback for non-list replies: split on commas/newlines and strip
    # surrounding brackets, quotes and whitespace from every piece.
    candidate_terms = (piece.strip(" \t'\"[]") for piece in raw_terms.replace('\n', ',').split(','))
    difficult_words = [term for term in candidate_terms if term]
print(difficult_words)

In [None]:
# Collapse the list of terms into one comma-separated string for the {terms} placeholder.
difficult_words = ', '.join(difficult_words)
print(difficult_words)

In [None]:
# Prompt pair asking for short definitions of {terms} in the context of {source_snt}.
definition_message = [{
    "role": "system",
    "content": "Provide a short, one or two sentence explanation for each of the difficult terms identified. Ensure the definitions are concise and contextualized within the scope of the sentence. Return only the definition for each of the terms, without any additional text or information."
},
{
    "role": "user",
    "content": "Provide explanations for these terms: \"{terms}\" in the context of this sentence: {source_snt}. Definitions:"
}]

In [None]:
# Placeholder values for the definition prompt. 'args' maps a placeholder name to a
# list with one value per item; here the list has a single entry for the single item.
replacements_definitions = {
    'system': {},
    'user': {},
    'args': {'terms': [difficult_words]}
}

In [None]:
# Generate the definitions for the identified terms (one reply per item).
definitions = generate_batch(pipe, test_data_loader, definition_message, replacements=replacements_definitions)
print(definitions)

In [None]:
# Prompt pair asking for a simplification of {source_snt} that embeds {definitions}.
simplification_message = [{
    "role": "system",
    "content": "Given the explanations provided for the identified terms, simplify the original sentence. Incorporate the definitions to make the sentence clearer and more accessible while maintaining its original meaning. Write a coherent sentence embedding the definition. Return only the simplified sentence, without any additional text or information."
},
{
    "role": "user",
    "content": "Simplify this sentence incorporating the provided definitions into a coherent text. Definitions: \"{definitions}\", original sentence: {source_snt}. Simplified Sentence incorporating the definitions:"
}]

In [None]:
# Placeholder values for the definition-aware simplification prompt;
# `definitions` is the per-item list produced by the previous generate_batch call.
replacement_simplification = {
    'system': {},
    'user': {},
    'args': {'definitions': definitions}
}

In [None]:
# Simplify the sentence again, this time with the generated definitions in the prompt.
simplifications = generate_batch(pipe, test_data_loader, simplification_message, replacements=replacement_simplification)
print(simplifications)

# distort output

In [None]:
# Prompt pair asking the model to inject grammatical errors and disfluencies
# into the already simplified sentence ({simplified_llama3}).
distortion_message = [
    {
        "role": "system",
        "content": "As a text manipulation assistant, your task is to modify simplified scientific sentences by introducing grammatical errors and disfluencies. The goal is to subtly alter the syntax and insert errors without completely distorting the overall meaning of the text. Return only the altered sentence, without any additional information."
    },
    {
        "role": "user",
        "content": "Modify the following simplified sentence from a scientific abstract to include grammatical errors and disfluency: {simplified_llama3}. Altered Sentence:"
    }
]

In [None]:
# Separate, initially empty placeholder mapping for the distortion prompt
# (it needs no values beyond the input text itself).
replacement_distortion = {
    'system': {},
    'user': {}
}

In [None]:
# Debug aid: show which fields each batch exposes before they are used as prompt inputs.
for batch in test_data_loader:
    print(batch.keys())

In [None]:
# Distort the simplified sentence: keys=['simplified_llama3'] makes that column
# the input text and the placeholder that is filled in the user template.
distorted_simplifications = generate_batch(pipe, test_data_loader, distortion_message, keys=['simplified_llama3'], replacements=replacement_distortion)
print(distorted_simplifications)

# Put it all together

In [None]:
def _parse_terms(reply):
    """Turn one raw model reply into a list of term strings.

    A reply that is a Python list/tuple literal is evaluated with
    ``ast.literal_eval``. Any other reply (free text, a bare string, or the
    ``''`` that generate_batch returns on failure) is split on commas and
    newlines, with surrounding brackets, quotes and whitespace stripped.
    """
    try:
        parsed = ast.literal_eval(reply)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return [str(term) for term in parsed]
    pieces = (piece.strip(" \t'\"[]") for piece in reply.replace('\n', ',').split(','))
    return [piece for piece in pieces if piece]


# Raw replies, one per item.
difficult_words = generate_batch(pipe, data_loader, identification_message)
print(difficult_words)
# Parsed into one list of terms per item, without raising on non-literal replies.
difficult_words = [_parse_terms(words) for words in difficult_words]
print(difficult_words)
# Flattened to one comma-separated string per item, ready for the {terms} placeholder.
difficult_words = [', '.join(words) for words in difficult_words]
print(difficult_words)

In [None]:
def create_dataframe(data, batch_size=64):
    """Run the whole prompt chain over ``data`` and store every stage as a column.

    The frame is modified in place and also returned. The dataset wraps
    ``data`` itself, so a column written by one stage is visible to the loader
    in the stages that follow; the statement order therefore matters.

    Columns written, in order: ``simplified_llama3``, ``difficult_words``,
    ``definitions``, ``simplified_llama3_definitions``,
    ``distorted_simplified_llama3``.

    Args:
        data: DataFrame with a ``source_snt`` column holding the input texts.
        batch_size: Batch size passed to the DataLoader.

    Returns:
        The same ``data`` object with the five result columns added.
    """
    loader = DataLoader(TextSimplificationDataset(data), batch_size=batch_size, shuffle=False)

    # Stage 1: plain simplification of each source text.
    data.loc[:, 'simplified_llama3'] = generate_batch(pipe, loader, message)

    # Stage 2: terms that need explaining, with any square brackets removed from the reply.
    raw_term_replies = generate_batch(pipe, loader, identification_message)
    terms = []
    for reply in raw_term_replies:
        terms.append(reply.replace('[', '').replace(']', ''))
    data.loc[:, 'difficult_words'] = terms

    # Stage 3: definitions for those terms (one terms string per item).
    definition_values = {'system': {}, 'user': {}, 'args': {'terms': terms}}
    term_definitions = generate_batch(pipe, loader, definition_message, replacements=definition_values)
    data.loc[:, 'definitions'] = term_definitions

    # Stage 4: simplification that embeds the definitions.
    simplification_values = {'system': {}, 'user': {}, 'args': {'definitions': term_definitions}}
    data.loc[:, 'simplified_llama3_definitions'] = generate_batch(
        pipe, loader, simplification_message, replacements=simplification_values
    )

    # Stage 5: distorted variant of the stage-1 simplification.
    data.loc[:, 'distorted_simplified_llama3'] = generate_batch(
        pipe, loader, distortion_message, keys=['simplified_llama3']
    )

    return data

In [None]:
# Output location for the fully processed test set.
path = 'data/task3/test/simpletext_task3_test_qrels_distinct_all_results_llama3.json'
# Runs every stage with the default batch size; `test` is modified in place and returned.
test = create_dataframe(test)
test.to_json(path)

# Sanity check: dimensions and first rows of the result.
print(test.shape)
test.head()

# 2024 dataset

In [None]:
# download the data from https://simpletext-project.com/
# Directory prefix (relative to the working directory) holding the 2024 test files.
path = 'data/llama3/task 3-2024/task 3/test/'

# Sentence-level and abstract-level source files.
test_snt = pd.read_json(path + 'simpletext_task3_2024_test_snt_source.json')
test_abs = pd.read_json(path + 'simpletext_task3_2024_test_abs_source.json')

In [None]:
# Dimensions and first rows of the sentence-level frame.
print(test_snt.shape)
test_snt.head()

In [None]:
# Count missing values (None / NaN) per column.
print(test_snt.isnull().sum())
# isna() is the same check under its other name; kept so the cell still prints three results.
print(test_snt.isna().sum())
# Count empty strings in 'source_snt'. The null checks above do not catch these,
# and an empty sentence would be sent to the model as a blank prompt input.
print((test_snt['source_snt'] == '').sum())

In [None]:
# Process all sentences; a batch size equal to the frame length means the loader yields one batch.
# `test_snt` itself is modified in place and returned.
data_snt_2024 = create_dataframe(test_snt, batch_size=len(test_snt))

In [None]:
# Preview the processed sentence-level results.
data_snt_2024.head()

In [None]:
# Save the processed sentence-level results as JSON next to the source files.
path = 'data/llama3/task 3-2024/task 3/test/simpletext_task3_2024_test_snt_source_all_results_llama3.json'
data_snt_2024.to_json(path)

# Abstracts

In [None]:
# Dimensions and first rows of the abstract-level frame.
print(test_abs.shape)
test_abs.head()

In [None]:
# Rename column abs_source to source_snt (in place), so the prompts and
# create_dataframe, which read 'source_snt', work unchanged on abstracts.
test_abs.rename(columns={'abs_source': 'source_snt'}, inplace=True)

In [None]:
# Find the length of the longest text in 'source_snt', counted in
# whitespace-separated words (not tokenizer tokens).
max_len = test_abs['source_snt'].apply(lambda x: len(x.split())).max()
print(max_len)

In [None]:
# Process all abstracts in a single batch; `test_abs` is modified in place and returned.
data_abs = create_dataframe(test_abs, batch_size=len(test_abs))

In [None]:
# Save the processed abstract-level results as JSON next to the source files.
path = 'data/llama3/task 3-2024/task 3/test/simpletext_task3_2024_test_abs_source_all_results_llama3.json'
data_abs.to_json(path)

In [None]:
# Preview the processed abstract-level results.
data_abs.head()