In [None]:
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
import os
import re
import json
from tqdm import tqdm
from itertools import chain

## Common Functions

In [None]:
def load_all_files(base_path, num_chunks=4):
    """Load and concatenate the parquet shards stored under ``<base_path>/raw``.

    For each ``i`` in ``range(num_chunks)`` every ``*.parquet`` file in
    ``<base_path>/raw/chunk_<i>/`` is loaded with ``load_dataset("parquet", ...)``
    and its ``'train'`` split is kept; the per-chunk datasets are then joined
    with ``concatenate_datasets``.

    Args:
        base_path: Directory that contains the ``raw`` sub-directory.
        num_chunks: Number of ``chunk_<i>`` sub-directories to read. Defaults
            to 4, the value that used to be hard-coded.

    Returns:
        The concatenation of all chunk datasets.

    Raises:
        ValueError: if ``num_chunks`` is less than 1 (nothing to concatenate).
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    root_dir = os.path.join(base_path, "raw")
    print("Loading all files from", root_dir)

    dataframes = []
    for i in range(num_chunks):
        chunk_glob = os.path.join(root_dir, f"chunk_{i}", "*.parquet")
        hf_df = load_dataset(
            "parquet",
            data_files=chunk_glob,
            streaming=False,
        )['train']
        dataframes.append(hf_df)

    print("Completed: loading all files from", root_dir)
    print("Concatenating all dataframes...")
    return concatenate_datasets(dataframes)
    
def extract_details_csqa(row):
    """Split one CSQA generation row into its prompt fields and raw output.

    ``row['user']`` is the generation prompt. The text after the header
    "Generate 3 developmentally appropriate skill-based instruction-response
    pairs based on the following input:" / "- Text:" and before the trailing
    "Instructions:" section is parsed as the source text (``context``)
    followed by the labelled fields "- Age Group:", "- Stage:", "- Skill:",
    "- Sub-skill:", "- Goal:" and "- Indicator:". ``row['answer']`` is passed
    through unchanged as ``output``.

    Labels are searched left to right, each starting after the previous one,
    so label-like text inside an earlier field is not picked up. A missing
    label yields '' for its field. A missing header or "Instructions:" marker
    falls back to the start/end of the prompt. Previously the -1 from
    ``str.find`` was used as a slice end, which dropped the last character of
    the indicator.

    Returns:
        dict with keys context, age_group, stage, skill, sub_skill, goal,
        indicator and output.
    """
    output_gen = row['answer']
    seed_data = row['user']
    header = ('Generate 3 developmentally appropriate skill-based '
              'instruction-response pairs based on the following input:'
              '\n\n- Text:')
    # The input block ends where the "Instructions:" section begins; accept
    # both the 4-space-indented form and the plain "\n\nInstructions:\n- " form.
    end_markers = (
        'Instructions:\n    - Consider the developmental stage',
        '\n\nInstructions:\n- ',
    )
    labels = (
        ('age_group', '- Age Group:'),
        ('stage', '- Stage:'),
        ('skill', '- Skill:'),
        ('sub_skill', '- Sub-skill:'),
        ('goal', '- Goal:'),
        ('indicator', '- Indicator:'),
    )

    start = seed_data.find(header)
    start = 0 if start == -1 else start + len(header)
    end_hits = [pos for pos in (seed_data.find(marker, start) for marker in end_markers)
                if pos != -1]
    end = min(end_hits) if end_hits else len(seed_data)
    body = seed_data[start:end].strip()

    # Each field runs from the end of its label to the start of the next
    # label that is found (or to the end of the input block).
    fields = {}
    current_key, value_start = 'context', 0
    for key, label in labels:
        pos = body.find(label, value_start)
        if pos == -1:
            fields[key] = ''
            continue
        fields[current_key] = body[value_start:pos].strip()
        current_key, value_start = key, pos + len(label)
    fields[current_key] = body[value_start:].strip()

    return {
        'context': fields['context'],
        'age_group': fields['age_group'],
        'stage': fields['stage'],
        'skill': fields['skill'],
        'sub_skill': fields['sub_skill'],
        'goal': fields['goal'],
        'indicator': fields['indicator'],
        'output': output_gen,
    }
    
def clean_text(text: str) -> str:
    """Normalize a question/answer string scraped from raw model output.

    Non-string inputs are returned unchanged. Strings are stripped; curly
    double quotes become '"'; literal backslash-n sequences become spaces;
    braces, brackets and backslashes are removed; whitespace runs collapse to
    a single space; and trailing '"', ',' and ';' characters are removed.
    """
    if not isinstance(text, str):
        return text
    text = text.strip()
    text = re.sub(r'[“”]', '"', text)  # Normalize fancy quotes
    # Replace escaped newlines BEFORE backslashes are stripped below;
    # in the other order "a\\nb" became "anb" and this never matched.
    text = re.sub(r'\\n', ' ', text)
    text = re.sub(r'[{}\[\]\\]+', '', text)  # Remove leftover brackets/backslashes
    text = re.sub(r'\s+', ' ', text)  # Collapse multiple spaces
    text = text.rstrip('",;')  # No-op unless the text ends with one of these
    return text.strip()

def extract_all_csqa(text):
    """Regex fallback that pulls instruction/response pairs out of raw output.

    Scans ``text['output']`` for ``"instruction": "...", "response": "..."``
    sequences. Returns ``{"qa": [{"question": ..., "answer": ...}, ...],
    "qa_num": n}``, with both strings passed through ``clean_text``.

    String bodies may contain JSON escapes such as a backslash-quote. The
    previous ``[^"]+`` pattern silently skipped any pair whose text held an
    escaped quote.
    """
    raw = text['output']
    # A string body is a run of non-quote/non-backslash characters or
    # backslash escapes, so an escaped quote does not end the match.
    string_body = r'((?:[^"\\]|\\.)+)'
    pattern = (
        r'"instruction"\s*:\s*"' + string_body
        + r'"\s*,\s*"response"\s*:\s*"' + string_body + r'"'
    )
    qa = [
        {"question": clean_text(q), "answer": clean_text(a)}
        for q, a in re.findall(pattern, raw, re.DOTALL)
    ]
    return {"qa": qa, "qa_num": len(qa)}

def extract_csqa(text):
    """Extract instruction/response pairs from one CSQA model output.

    Takes the outermost ``{...}`` span of ``text['output']`` and parses it as
    JSON. It tries the span verbatim first, then with doubled braces
    (``{{`` / ``}}``) collapsed to single ones. Pairs are read from the
    ``"skill_based_pairs"`` list and returned as
    ``{"qa": [{"question": ..., "answer": ...}, ...], "qa_num": n}``.

    Falls back to the regex parser ``extract_all_csqa`` when there is no brace
    span, when neither JSON variant parses, or when the JSON yields no pairs.
    """
    def _as_text(value):
        # JSON null / numbers would crash a bare .strip(); missing keys -> "".
        return "" if value is None else str(value).strip()

    raw = text['output']
    match = re.search(r'\{.*\}', raw, re.DOTALL)
    if not match:
        return extract_all_csqa(text)
    json_str = match.group()
    # Parse as-is first: collapsing "}}" unconditionally corrupts valid nested
    # objects such as {"a": {"b": 1}}.
    candidates = (json_str, json_str.replace('{{', '{').replace('}}', '}'))
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        pairs = data.get("skill_based_pairs") if isinstance(data, dict) else None
        if not isinstance(pairs, list):
            pairs = []
        qa_pairs = [
            {"question": _as_text(pair.get("instruction")),
             "answer": _as_text(pair.get("response"))}
            for pair in pairs if isinstance(pair, dict)
        ]
        if qa_pairs:
            return {"qa": qa_pairs, "qa_num": len(qa_pairs)}
    return extract_all_csqa(text)
    

In [None]:
def extract_details_cqa(row):
    """Split one CQA generation row into its prompt fields and raw output.

    ``row['user']`` is expected to contain the header "Generate 5
    developmentally appropriate reading comprehension question-answer pairs
    based on the following input:" / "- Text: ". That is followed by the text,
    a "- Age Group: " line, a "- Stage: " line and then a blank line plus
    "Instructions:". ``row['answer']`` is passed through unchanged as
    ``output``.

    Markers are searched left to right, each after the previous one. A
    missing header starts parsing at 0, and a missing label yields ''. A
    missing "Instructions:" section lets the last field run to the end of the
    prompt. Previously the -1 from ``str.find`` dropped its final character.

    Returns:
        dict with keys context, age_group, stage and output.
    """
    output_gen = row['answer']
    seed_data = row['user']
    header = ('Generate 5 developmentally appropriate reading comprehension '
              'question-answer pairs based on the following input:\n\n- Text: ')
    labels = (
        ('age_group', '\n- Age Group: '),
        ('stage', '\n- Stage: '),
    )
    instructions_marker = '\n\nInstructions:\n- '

    start = seed_data.find(header)
    start = 0 if start == -1 else start + len(header)
    end = seed_data.find(instructions_marker, start)
    # Not stripped here: the labels begin with '\n', and each field is
    # stripped individually below.
    body = seed_data[start:] if end == -1 else seed_data[start:end]

    fields = {}
    current_key, value_start = 'context', 0
    for key, label in labels:
        pos = body.find(label, value_start)
        if pos == -1:
            fields[key] = ''
            continue
        fields[current_key] = body[value_start:pos].strip()
        current_key, value_start = key, pos + len(label)
    fields[current_key] = body[value_start:].strip()

    return {
        "context": fields['context'],
        "age_group": fields['age_group'],
        "stage": fields['stage'],
        "output": output_gen,
    }

def extract_details_csqa(row):
    """Split one CSQA generation row into its prompt fields and raw output.

    NOTE: this cell redefines extract_details_csqa from the Common Functions
    cell. In top-to-bottom execution this later definition is the one the
    CSQA section's ``df.map(extract_details_csqa, remove_columns=...)`` calls.
    That section then reads ``'output'`` in ``extract_csqa``, so this version
    must return the full field set. It previously returned only ``context``,
    which broke Restart & Run All.

    ``row['user']`` is the generation prompt. The text after the header
    "Generate 3 developmentally appropriate skill-based instruction-response
    pairs based on the following input:" / "- Text:" and before the trailing
    "Instructions:" section is parsed as the source text (``context``)
    followed by the labelled fields "- Age Group:", "- Stage:", "- Skill:",
    "- Sub-skill:", "- Goal:" and "- Indicator:". ``row['answer']`` is passed
    through unchanged as ``output``. A missing label yields '' for its field.

    Returns:
        dict with keys context, age_group, stage, skill, sub_skill, goal,
        indicator and output.
    """
    output_gen = row['answer']
    seed_data = row['user']
    header = ('Generate 3 developmentally appropriate skill-based '
              'instruction-response pairs based on the following input:'
              '\n\n- Text:')
    # The input block ends where the "Instructions:" section begins; accept
    # both the 4-space-indented form and the plain "\n\nInstructions:\n- " form.
    end_markers = (
        'Instructions:\n    - Consider the developmental stage',
        '\n\nInstructions:\n- ',
    )
    labels = (
        ('age_group', '- Age Group:'),
        ('stage', '- Stage:'),
        ('skill', '- Skill:'),
        ('sub_skill', '- Sub-skill:'),
        ('goal', '- Goal:'),
        ('indicator', '- Indicator:'),
    )

    start = seed_data.find(header)
    start = 0 if start == -1 else start + len(header)
    end_hits = [pos for pos in (seed_data.find(marker, start) for marker in end_markers)
                if pos != -1]
    end = min(end_hits) if end_hits else len(seed_data)
    body = seed_data[start:end].strip()

    # Each field runs from the end of its label to the start of the next
    # label that is found (or to the end of the input block).
    fields = {}
    current_key, value_start = 'context', 0
    for key, label in labels:
        pos = body.find(label, value_start)
        if pos == -1:
            fields[key] = ''
            continue
        fields[current_key] = body[value_start:pos].strip()
        current_key, value_start = key, pos + len(label)
    fields[current_key] = body[value_start:].strip()

    return {
        'context': fields['context'],
        'age_group': fields['age_group'],
        'stage': fields['stage'],
        'skill': fields['skill'],
        'sub_skill': fields['sub_skill'],
        'goal': fields['goal'],
        'indicator': fields['indicator'],
        'output': output_gen,
    }

def extract_all_cqa(text):
    """Regex fallback that pulls question/answer pairs out of raw output.

    Scans ``text['output']`` for ``"question": "...", "answer": "..."``
    sequences. Returns ``{"qa": [{"question": ..., "answer": ...}, ...],
    "qa_num": n}``, with both strings passed through ``clean_text``.

    String bodies may contain JSON escapes such as a backslash-quote. The
    previous ``[^"]+`` pattern silently skipped any pair whose text held an
    escaped quote.
    """
    raw = text['output']
    # A string body is a run of non-quote/non-backslash characters or
    # backslash escapes, so an escaped quote does not end the match.
    string_body = r'((?:[^"\\]|\\.)+)'
    pattern = (
        r'"question"\s*:\s*"' + string_body
        + r'"\s*,\s*"answer"\s*:\s*"' + string_body + r'"'
    )
    qa = [
        {"question": clean_text(q), "answer": clean_text(a)}
        for q, a in re.findall(pattern, raw, re.DOTALL)
    ]
    return {"qa": qa, "qa_num": len(qa)}

def extract_cqa(text):
    """Extract question/answer pairs from one CQA model output.

    Takes the outermost ``{...}`` span of ``text['output']`` and parses it as
    JSON. It tries the span verbatim first, then with doubled braces
    (``{{`` / ``}}``) collapsed to single ones. Pairs are read from the
    ``"question_answer_pairs"`` list and returned as
    ``{"qa": [{"question": ..., "answer": ...}, ...], "qa_num": n}``.

    Falls back to the regex parser ``extract_all_cqa`` when there is no brace
    span, when neither JSON variant parses, or when the JSON yields no pairs.
    """
    def _as_text(value):
        # JSON null / numbers would crash a bare .strip(); missing keys -> "".
        return "" if value is None else str(value).strip()

    raw = text['output']
    match = re.search(r'\{.*\}', raw, re.DOTALL)
    if not match:
        return extract_all_cqa(text)
    json_str = match.group()
    # Parse as-is first: collapsing "}}" unconditionally corrupts valid nested
    # objects such as {"a": {"b": 1}}.
    candidates = (json_str, json_str.replace('{{', '{').replace('}}', '}'))
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        pairs = data.get("question_answer_pairs") if isinstance(data, dict) else None
        if not isinstance(pairs, list):
            pairs = []
        qa_pairs = [
            {"question": _as_text(pair.get("question")),
             "answer": _as_text(pair.get("answer"))}
            for pair in pairs if isinstance(pair, dict)
        ]
        if qa_pairs:
            return {"qa": qa_pairs, "qa_num": len(qa_pairs)}
    return extract_all_cqa(text)

## CSQA

In [None]:
# Stage / data-type selection for the CSQA run. The stages root defaults to the
# original job-specific scratch path; set CURLL_STAGES_DIR to point elsewhere.
data_type = "csqa"
stage = 0
print(f"Processing {data_type} data for stage {stage}")
stages_root = os.environ.get(
    "CURLL_STAGES_DIR",
    "/scratch/azureml/cr/j/a29d46b9ba574a9abc8a7dea98bdf971/cap/data-capability/wd/INPUT_asdf/CurLL_data/stages",
)
base_path = f"{stages_root}/stage{stage}/{data_type}"
df = load_all_files(base_path)
cdata_path = f"Pavankalyan/stage{stage}_context"

In [None]:
# Parse each raw prompt/answer row into structured CSQA fields; the original
# columns are dropped so only the parsed fields remain.
df = df.map(
    extract_details_csqa,
    num_proc=80,
    desc="Extracting details from CSQA data",
    remove_columns=df.column_names,
)

In [None]:
# Add "qa" / "qa_num" columns parsed from each row's model output.
df = df.map(
    extract_csqa,
    num_proc=80,
    desc="Extracting question-answer pairs from CSQA data",
)

In [None]:
# Split rows by whether any QA pair was recovered; keep only the columns the
# join below needs, then load the context dataset to join against.
df_bad = df.filter(lambda x: x['qa_num'] < 1, desc="Filtering bad samples", num_proc=60)
df_good = df.filter(lambda x: x['qa_num'] >= 1, desc="Filtering good samples", num_proc=60)
df_good = df_good.select_columns(['context', 'qa'])
cdata = load_dataset(cdata_path, split='train')
# Last expression so the notebook displays the bad/good counts (mid-cell it was
# evaluated and silently discarded).
len(df_bad), len(df_good)


In [None]:
def normalize_context(text):
    """Join key for matching contexts: stripped, lower-cased, whitespace-collapsed.

    Kept separate from ``clean_text`` so this cell does not overwrite the
    QA-cleaning ``clean_text`` that the regex extraction helpers call.
    """
    return re.sub(r'\s+', ' ', text.strip().lower())

print("Building mapping from df_good...")
context_to_qa = {}
for ctx, qa in tqdm(zip(df_good['context'], df_good['qa']), total=len(df_good)):
    # setdefault keeps only the first QA list seen for a given context.
    context_to_qa.setdefault(normalize_context(ctx), qa)

def add_qa_column(batch):
    """Attach the matched CSQA list (or None) to every row of a cdata batch."""
    batch['csqa'] = [
        context_to_qa.get(normalize_context(output))
        for output in batch['output']
    ]
    return batch

print("Mapping QA into cdata...")
cdata_augmented = cdata.map(
    add_qa_column,
    batched=True,
    batch_size=10000,
    num_proc=80,  # adjust based on your CPU
    desc="Joining datasets"
)

cdata_no_qa = cdata_augmented.filter(lambda x: x['csqa'] is None, num_proc=80)
cdata_final = cdata_augmented.filter(lambda x: x['csqa'] is not None, num_proc=80)
# Repo name follows `stage` (was hard-coded "stage0"). The token is read from
# the environment instead of a hard-coded string; None leaves huggingface_hub
# to its default token lookup.
cdata_final.push_to_hub(
    f"Pavankalyan/stage{stage}_csqa",
    split="train",
    token=os.environ.get("HF_TOKEN"),
)

## CQA

In [None]:
# Stage / data-type selection for the CQA run. The stages root defaults to the
# original job-specific scratch path; set CURLL_STAGES_DIR to point elsewhere.
data_type = "cqa"
stage = 0
print(f"Processing {data_type} data for stage {stage}")
stages_root = os.environ.get(
    "CURLL_STAGES_DIR",
    "/scratch/azureml/cr/j/a29d46b9ba574a9abc8a7dea98bdf971/cap/data-capability/wd/INPUT_asdf/CurLL_data/stages",
)
base_path = f"{stages_root}/stage{stage}/{data_type}"
df = load_all_files(base_path)

In [None]:
# Parse each raw prompt/answer row into structured CQA fields; the original
# columns are dropped so only the parsed fields remain.
df = df.map(
    extract_details_cqa,
    remove_columns=df.column_names,
    desc="Extracting details from CQA data",  # was mislabelled "CSQA"
    num_proc=80
)

In [None]:
# Add "qa" / "qa_num" columns parsed from each row's model output.
df = df.map(
    extract_cqa,
    desc="Extracting question-answer pairs from CQA data",  # was mislabelled "CSQA"
    num_proc=80
)

In [None]:
# Keep only rows with at least one recovered QA pair, then load the
# stage's "_csqa" dataset as the join target.
df_bad = df.filter(lambda row: row['qa_num'] < 1, num_proc=60, desc="Filtering bad samples")
df_good = df.filter(lambda row: row['qa_num'] >= 1, num_proc=60, desc="Filtering good samples")
df_good = df_good.select_columns(['context', 'qa'])
cdata_path = f"Pavankalyan/stage{stage}_csqa"
cdata = load_dataset(cdata_path, split='train')

In [None]:
def normalize_context(text):
    """Join key for matching contexts: stripped, lower-cased, whitespace-collapsed.

    Kept separate from ``clean_text`` so this cell does not overwrite the
    QA-cleaning ``clean_text`` that the regex extraction helpers call.
    """
    return re.sub(r'\s+', ' ', text.strip().lower())

print("Building mapping from df_good...")
context_to_qa = {}
for ctx, qa in tqdm(zip(df_good['context'], df_good['qa']), total=len(df_good)):
    # setdefault keeps only the first QA list seen for a given context.
    context_to_qa.setdefault(normalize_context(ctx), qa)


def add_qa_column(batch):
    """Attach the matched CQA list (or None) to every row of a cdata batch."""
    batch['cqa'] = [
        context_to_qa.get(normalize_context(output))
        for output in batch['output']
    ]
    return batch

In [None]:
print("Mapping QA into cdata...")
# Batched join: each batch of cdata rows gets a "cqa" column looked up by context.
cdata_augmented = cdata.map(
    add_qa_column,
    desc="Joining datasets",
    batched=True,
    batch_size=10000,
    num_proc=80,  # adjust based on your CPU
)

In [None]:
cdata_no_qa = cdata_augmented.filter(lambda x: x['cqa'] is None, num_proc=80)
cdata_final = cdata_augmented.filter(lambda x: x['cqa'] is not None, num_proc=80)
# Repo name follows `stage` (was hard-coded "stage0"). The token is read from
# the environment instead of a hard-coded string; None leaves huggingface_hub
# to its default token lookup.
cdata_final.push_to_hub(
    f"Pavankalyan/stage{stage}_c_all",
    split="train",
    token=os.environ.get("HF_TOKEN"),
)