In [16]:
import datasets
import re

DATASETS = ['pubmed_qa', 'writingprompts']


def strip_newlines(text):
    return ' '.join(text.split())


def load_pubmed():
    data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train[:150]')

    # combine question and long_answer, and label them as 0
    data = [(f'Question: {q} Answer:{a}', 0) for q, a in zip(data['question'], data['long_answer'])]

    return data


def process_prompt(prompt):
    tags = ['[ WP ]', '[ OT ]', '[ IP ]', '[ HP ]', '[ TT ]', '[ Punch ]', '[ FF ]', '[ CW ]', '[ EU ]']
    for tag in tags:
        prompt = prompt.replace(tag, '')
    return prompt

def remove_whitespace_before_punctuations(text):
    text = re.sub(r'\s([?.!,:;](?:\s|$))', r'\1', text)
    return text


def process_spaces(story):
    return story.replace(
        ' ,', ',').replace(
        ' .', '.').replace(
        ' ?', '?').replace(
        ' !', '!').replace(
        ' ;', ';').replace(
        ' \'', '\'').replace(
        ' ’ ', '\'').replace(
        ' :', ':').replace(
        '<newline>', '\n').replace(
        '`` ', '"').replace(
        ' \'\'', '"').replace(
        '\'\'', '"').replace(
        '.. ', '... ').replace(
        ' )', ')').replace(
        '( ', '(').replace(
        ' n\'t', 'n\'t').replace(
        ' i ', ' I ').replace(
        ' i\'', ' I\'').replace(
        '\\\'', '\'').replace(
        '\n ', '\n').strip()


def load_writingPrompts_dataset():
    writing_path = 'data/writingPrompts'

    with open(f'{writing_path}/valid.wp_source', 'r', encoding='utf-8') as f:
        prompts = f.readlines()[:178]
    with open(f'{writing_path}/valid.wp_target', 'r', encoding='utf-8') as f:
        stories = f.readlines()[:178]

    prompts = [process_prompt(prompt) for prompt in prompts]
    prompts = [remove_whitespace_before_punctuations(prompt) for prompt in prompts]
    prompts = [prompt.rstrip() for prompt in prompts]
    stories = [process_spaces(story) for story in stories]
    joined = ["Prompt: " + prompt + " Story: " + story for prompt, story in zip(prompts, stories)]
    filtered = [story for story in joined if 'nsfw' not in story and 'NSFW' not in story]

    # Label the stories as 0 to indicate they are human-generated
    data = [(story, 0) for story in filtered]
    return data


def load_data(dataset_name):
    if dataset_name == 'pubmed_qa':
        return load_pubmed()
    elif dataset_name == 'writingprompts':
        return load_writingPrompts_dataset()
    else:
        print(f"Dataset name {dataset_name} not recognized.")
        return None


def preprocess_data(dataset):
    if dataset in DATASETS:
        data = load_data(dataset)

    # remove duplicates from the data
    data = list(dict.fromkeys(data))  # deterministic, as opposed to set()

    # strip whitespace around each example
    data = [(x[0].strip(), x[1]) for x in data]

    # remove newlines from each example
    data = [(strip_newlines(q), a) for q, a in data]
    
    if dataset in ['pubmed_qa']:  
        print(f"Loaded and pre-processed {len(data)} answers from the dataset")  # debug print
    # try to keep only examples with > 250 words
    if dataset in ['writingprompts']:
        long_data = [(x, y) for x, y in data if len(x.split()) > 250]
        if len(long_data) > 0:
            data = long_data
        print(f"Loaded and pre-processed{len(data)} prompts/stories from the dataset")  # debug print
    return data

In [7]:
import spacy
from collections import Counter

# Load the SpaCy model
nlp = spacy.load('en_core_web_sm')

FUNCTION_WORDS = {'a', 'in', 'of', 'the'}

def count_pos_tags_and_special_elements(text):
    """
    Counts the frequency of POS tags, punctuation marks and function words in a given text.
    
    Args:
    text (str): The text for which to count POS tags and special elements.

    Returns:
    tuple: A tuple containing two dictionaries, where keys are POS tags and punctuation marks 
           and values are their corresponding count.
    """
    # Use SpaCy to parse the text
    doc = nlp(text)

    # Create a counter of POS tags
    pos_counts = Counter(token.pos_ for token in doc)

    # Create a counter of punctuation marks
    punctuation_counts = Counter(token.text for token in doc if token.pos_ == 'PUNCT')

    # Create a counter of function words
    function_word_counts = Counter(token.text for token in doc if token.lower_ in FUNCTION_WORDS)

    return dict(pos_counts), dict(punctuation_counts), dict(function_word_counts)

In [8]:
def load_and_count(dataset_name):
    # Load and preprocess the data
    data = preprocess_data(dataset_name)
    
    # Extract texts
    texts, labels = zip(*data)

    # Split questions and answers for pubmed_qa dataset
    if dataset_name == 'pubmed_qa':
        texts = [text.split("Answer:", 1)[1] for text in texts]
    
    # Calculate POS tag frequencies for the texts
    pos_frequencies, punctuation_frequencies, function_word_frequencies = zip(*[count_pos_tags_and_special_elements(text) for text in texts])

    # Then, sum the dictionaries to get the overall frequencies
    overall_pos_counts = Counter()
    for pos_freq in pos_frequencies:
        overall_pos_counts += Counter(pos_freq)

    overall_punctuation_counts = Counter()
    for punct_freq in punctuation_frequencies:
        overall_punctuation_counts += Counter(punct_freq)

    overall_function_word_counts = Counter()
    for function_word_freq in function_word_frequencies:
        overall_function_word_counts += Counter(function_word_freq)

    # Print the frequencies
    print(f"Frequency of adjectives: {overall_pos_counts['ADJ']}")
    print(f"Frequency of adverbs: {overall_pos_counts['ADV']}")
    print(f"Frequency of conjunctions: {overall_pos_counts['CCONJ']}")
    print(f"Frequency of nouns: {overall_pos_counts['NOUN']}")
    print(f"Frequency of numbers: {overall_pos_counts['NUM']}")
    print(f"Frequency of pronouns: {overall_pos_counts['PRON']}")
    print(f"Frequency of verbs: {overall_pos_counts['VERB']}")
    print(f"Frequency of commas: {overall_punctuation_counts[',']}")
    print(f"Frequency of fullstops: {overall_punctuation_counts['.']}")
    print(f"Frequency of special character '-': {overall_punctuation_counts['-']}")
    print(f"Frequency of function word 'a': {overall_function_word_counts['a']}")
    print(f"Frequency of function word 'in': {overall_function_word_counts['in']}")
    print(f"Frequency of function word 'of': {overall_function_word_counts['of']}")
    print(f"Frequency of function word 'the': {overall_function_word_counts['the']}")

In [17]:
load_and_count('pubmed_qa')

Found cached dataset pubmed_qa (C:/Users/atana/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)


Loaded and pre-processed 150 answers from the dataset
Frequency of adjectives: 914
Frequency of adverbs: 205
Frequency of conjunctions: 216
Frequency of nouns: 1998
Frequency of numbers: 40
Frequency of pronouns: 142
Frequency of verbs: 632
Frequency of commas: 202
Frequency of fullstops: 308
Frequency of special character '-': 82
Frequency of function word 'a': 124
Frequency of function word 'in': 169
Frequency of function word 'of': 283
Frequency of function word 'the': 253


In [18]:
load_and_count('writingprompts')

Loaded and pre-processed 150 prompts/stories from the dataset
Frequency of adjectives: 6442
Frequency of adverbs: 6118
Frequency of conjunctions: 3426
Frequency of nouns: 17808
Frequency of numbers: 798
Frequency of pronouns: 16190
Frequency of verbs: 15555
Frequency of commas: 5843
Frequency of fullstops: 7028
Frequency of special character '-': 328
Frequency of function word 'a': 2173
Frequency of function word 'in': 1208
Frequency of function word 'of': 1975
Frequency of function word 'the': 4629


In [68]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.nn import functional as F
from transformers import BertTokenizerFast
from statistics import mean


In [69]:
def load_model():
    """
    Load the model and tokenizer.
    Returns a model and tokenizer.
    """
    model_name = 'allenai/scibert_scivocab_uncased'
    model = AutoModelForMaskedLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer

def calculate_average_sentence_length(texts):
    """
    Calculate the average sentence length of a list of texts.

    Args:
    texts (list): The list of texts.

    Returns:
    float: The average sentence length.
    """
    # Initialize the tokenizer
    tokenizer = BertTokenizerFast.from_pretrained('allenai/scibert_scivocab_uncased')

    # Split the texts into sentences
    sentences = [sentence for text in texts for sentence in text.split('. ')]

    # Tokenize the sentences and calculate their length
    sentence_lengths = [len(tokenizer.tokenize(sentence)) for sentence in sentences]

    # Calculate and return the average sentence length
    return mean(sentence_lengths)


def calculate_perplexity(text, model, tokenizer):
    """
    Calculate the perplexity of a piece of text.
    """
    # tokenize the input, add special tokens and return tensors
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # if the text is too long, skip it
    if len(input_ids[0]) > 512:
        return None

    with torch.no_grad():
        output = model(input_ids, labels=input_ids)
    loss = output.loss
    return torch.exp(loss).item()  # perplexity is e^loss


def calculate_average_perplexities(dataset_name):
    # Load and preprocess the data
    data = preprocess_data(dataset_name)

    # Extract texts
    texts, labels = zip(*data)

    # Split questions and answers for pubmed_qa dataset
    if dataset_name == 'pubmed_qa':
        texts = [text.split("Answer:", 1)[1] for text in texts]
    elif dataset_name == 'writingprompts':
        texts = [text.split("Story:", 1)[1] for text in texts]

    # Load the model and tokenizer
    model, tokenizer = load_model()

    # Calculate the perplexity for each text
    perplexities = [calculate_perplexity(text, model, tokenizer) for text in texts]

    # Filter out None values
    perplexities = [p for p in perplexities if p is not None]
    
    
    # Calculate the average sentence length
    average_sentence_length = calculate_average_sentence_length(texts)
    print(f"Average sentence length: {average_sentence_length}")


    # Calculate and print the average text perplexity
    average_text_perplexity = sum(perplexities) / len(perplexities)
    print(f"Average text perplexity: {average_text_perplexity}")

    # Split the texts into sentences and calculate the perplexity for each sentence
    sentences = [sentence for text in texts for sentence in text.split('. ')]
    sentence_perplexities = [calculate_perplexity(sentence, model, tokenizer) for sentence in sentences]

    # Filter out None values
    sentence_perplexities = [p for p in sentence_perplexities if p is not None]

    # Calculate and print the average sentence perplexity
    average_sentence_perplexity = sum(sentence_perplexities) / len(sentence_perplexities)
    print(f"Average sentence perplexity: {average_sentence_perplexity}")

In [70]:
calculate_average_perplexities('pubmed_qa')

Found cached dataset pubmed_qa (C:/Users/atana/.cache/huggingface/datasets/pubmed_qa/pqa_labeled/1.0.0/dd4c39f031a958c7e782595fa4dd1b1330484e8bbadd4d9212e5046f27e68924)


Loaded and pre-processed 150 answers from the dataset


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Average sentence length: 23.99361022364217
Average text perplexity: 2.2276509324709575
Average sentence perplexity: 4.970833402853042
