In [None]:
# Load the ToxiGen training data (250,951 training examples).
import os

from datasets import load_dataset

# An access token is needed to download this dataset. Instead of pasting it into
# the notebook, read it from the HF_TOKEN environment variable; when that is
# unset, fall back to True so the library uses the token cached by
# `huggingface-cli login`.
# NOTE(review): newer `datasets` releases rename `use_auth_token` to `token` --
# confirm against the installed version.
TG_data = load_dataset(
    "skg/toxigen-data",
    name="train",
    use_auth_token=os.environ.get("HF_TOKEN") or True,
)

In [None]:
# Among the 250,951 machine-generated statements, 17 rows carry the literal string
# 'prompt' in their 'prompt' field; collect their positions so they can be dropped.
# Only the 'prompt' column is scanned, which avoids materialising every full row.
skip_idx_list = [
    idx
    for idx, prompt_value in enumerate(TG_data['train']['prompt'])
    if prompt_value == 'prompt'
]

# Build the lookup set once (it was previously rebuilt for every candidate index)
# and hand `select` a concrete list of the row positions to keep.
skip_idx_set = set(skip_idx_list)
keep_idx_list = [
    idx for idx in range(len(TG_data['train']))
    if idx not in skip_idx_set
]

# 250,934 samples
TG_data['train'] = TG_data['train'].select(keep_idx_list)

In [None]:
# check if the process was conducted successfully
# Every remaining row must have a real prompt; fail loudly at the first placeholder.
assert all(row['prompt'] != 'prompt' for row in TG_data['train'])

In [None]:
# Function that anonymizes private information (email addresses, URLs, and user or channel mentions).
# We follow the implementation in https://github.com/dhfbk/hate-speech-artifacts.

import re
from html import unescape
import wordsegment as ws
ws.load()   # load the vocabulary for wordsegment

def clean_text(example):
    """Anonymize and normalize the 'generation' text of a single example.

    The text is HTML-unescaped and lowercased, then email addresses, URLs and
    user/channel mentions are replaced by the placeholders [EMAIL], [URL] and
    [USER]. Markdown-style links are reduced to their link text, hashtags are
    split into words with the module-level ``ws`` segmenter, and newlines, tabs,
    "[linebreak]" markers, no-break spaces and zero-width joiners become spaces.

    Args:
        example: mapping with a string under the key 'generation'.

    Returns:
        The same ``example`` object, with 'generation' overwritten by the
        cleaned, stripped text.
    """
    def _segment_hashtag(found):
        # Turn the matched hashtag into space-separated words
        return ' '.join(ws.segment(found.group(0)))

    # HTML entities first, then lowercase everything
    cleaned = unescape(example['generation']).lower()

    # Re-join the most common space-split URL schemes (noisy Stormfront data)
    for broken, fixed in (("http : //", "http://"), ("https : //", "https://")):
        cleaned = cleaned.replace(broken, fixed)

    # Ordered (pattern, replacement) pairs. The order is significant: emails are
    # masked before @-mentions, and markdown links are unwrapped before bare URLs
    # and /r/ mentions are masked.
    # URL handling is based on
    # https://github.com/dongpng/cad_naacl2021/blob/main/src/contextual_abuse_dataset.py
    substitutions = (
        # email addresses
        (r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", "[EMAIL]"),
        # markdown links -> keep only the link text
        (r"\[([^\[\]]+)\]\((https:\/\/(.*?))\)", r"\1"),
        (r"\[([^\[\]]+)\]\((\/message\/compose(.*?))\)", r"\1"),
        (r"\[([^\[\]]+)\]\((\/r\/(.*?))\)", r"\1"),
        # bare URLs (trailing parentheses/brackets are left out of the match)
        (r'http(s?):\/\/[^\r\n\t\f\v )\]\}]+', '[URL]'),
        (r'www\.\S+', '[URL]'),
        # /u/user and /r/subreddit (Reddit), @user (Twitter and Gab)
        (r"\/u\/\w+", "[USER]"),
        (r"\/r\/\w+", "[USER]"),
        (r"@[A-Za-z0-9_-]+", "[USER]"),
        # hashtags -> segmented words
        (r"#[A-Za-z0-9]+", _segment_hashtag),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)

    # Newlines, tabs, Cad-style "[linebreak]" markers, no-break spaces and
    # zero-width joiners all collapse to a plain space
    for odd_piece in ("\n", "\t", "[linebreak]", u'\xa0', u'\u200d'):
        cleaned = cleaned.replace(odd_piece, " ")

    example['generation'] = cleaned.strip()

    return example

In [None]:
# Run clean_text over every training row so that email addresses, URLs and
# user/channel mentions in the generated statements are anonymized.
total_dataset = TG_data['train'].map(function=clean_text)

In [None]:
# unique_example_list: the distinct examples that appear inside prompts.
# A prompt is split on the literal two-character sequence backslash + 'n'; the
# last piece is discarded and the first two characters of each remaining piece
# are cut off.
temp_example_list = [
    piece[2:]
    for row in total_dataset
    for piece in row['prompt'].split('\\n')[:-1]
]

# 522 unique examples
unique_example_list = sorted(set(temp_example_list))

In [None]:
# unique_example2index (example -> index); indices start at EXAMPLE_BASE_NUM
# and follow the sorted order of unique_example_list.
EXAMPLE_BASE_NUM = 1000
unique_example2index = {
    example_text: example_id
    for example_id, example_text in enumerate(unique_example_list, start=EXAMPLE_BASE_NUM)
}

In [None]:
# 23322 prompts (sets of examples), de-duplicated and sorted
distinct_prompts = set(total_dataset['prompt'])
unique_prompt_list = sorted(distinct_prompts)

In [None]:
# prompt2index (prompt -> index); indices start at PROMPT_BASE_NUM and follow
# the order of unique_prompt_list.
PROMPT_BASE_NUM = 100000
prompt2index = {
    whole_prompt: prompt_id
    for prompt_id, whole_prompt in enumerate(unique_prompt_list, start=PROMPT_BASE_NUM)
}

# prompt_index2example_index_list (index of prompt -> sorted indexes of the
# examples it contains). Each prompt is split on the literal backslash + 'n'
# sequence, the last piece is dropped, and the first two characters of every
# remaining piece are removed before the example lookup.
prompt_index2example_index_list = {
    prompt_id: sorted(
        unique_example2index[piece[2:]]
        for piece in whole_prompt.split('\\n')[:-1]
    )
    for whole_prompt, prompt_id in prompt2index.items()
}

In [None]:
import os
import pandas as pd
import random
random.seed(0)

# Build the full pre-training table: one row per sample, holding
#   sent0_label - index of the sample's prompt
#   sent0       - the machine-generated statement
#   sent1_label - index of the example chosen as the positive
#   sent1       - that example's text, drawn at random from the sample's prompt
temp_list = []
for one_sample in total_dataset:
    whole_prompt = one_sample['prompt']
    candidate_pos_list = [piece[2:] for piece in whole_prompt.split('\\n')[:-1]]
    assert len(candidate_pos_list) > 0
    selected_pos_prompt = random.choice(candidate_pos_list)
    temp_list.append([
        prompt2index[whole_prompt],
        one_sample['generation'],
        unique_example2index[selected_pos_prompt],
        selected_pos_prompt,
    ])

column_names = ['sent0_label', 'sent0', 'sent1_label', 'sent1']
df = pd.DataFrame(temp_list, columns=column_names)

# Write the table as CSV under data/, creating the folder if needed
os.makedirs('data', exist_ok=True)
df.to_csv('data/conprompt_pre-train_dataset.csv', index=False)

In [None]:
import pickle
# Persist the prompt-index -> example-indexes mapping next to the CSV
with open('data/prompt_index2example_index_list.pickle', mode='wb') as pickle_file:
    pickle.dump(prompt_index2example_index_list, pickle_file)