In [None]:
import json

import pandas as pd
from datasets import load_dataset


# Fetch the 'full_persona' configuration of the Twin-2K-500 dataset.
ds = load_dataset('LLM-Digital-Twin/Twin-2K-500', 'full_persona')

# Every record carries its persona as a JSON string in 'persona_json';
# decode all of them once so later cells can work on plain Python objects.
data = [json.loads(record['persona_json']) for record in ds['data']]
N_PERSON = len(data)
# Number of top-level blocks in the first persona.
N_BLOCKS = len(data[0])

### Inspect question blocks (number of questions per block)

In [None]:
# Report the name and size of every block of the first persona,
# followed by the overall number of questions.
first_person = data[0]
questions_per_block = []
for idx, block in enumerate(first_person):
    block_name = block['BlockName']
    n_questions = len(block['Questions'])
    print(f"Block {idx}: {block_name} | number of questions: {n_questions}")
    questions_per_block.append(n_questions)
cnt_questions = sum(questions_per_block)
print(f"Total number of questions: {cnt_questions}")

In [None]:
# Index every question of the first persona that comes with answer options.
# Keys are built as "BLOCK<block index>" + QuestionID + "_W1"; each value keeps
# the question text, its type and its list of options.
question_dict = {}

for block_id, block in enumerate(first_person):
    key_prefix = f'BLOCK{block_id}'
    for question in block['Questions']:
        qid = question['QuestionID']
        qtext = question['QuestionText']
        qtype = question['QuestionType']
        qoptions = question.get('Options')
        if qoptions is not None:
            # Questions without an 'Options' entry (or with it set to None) are left out.
            question_dict[key_prefix + qid + "_W1"] = dict(
                text=qtext,
                type=qtype,
                options=qoptions,
            )
print(f"Total number of questions with options: {len(question_dict)}")

### inspect questions (optional)

In [None]:
# Print the text, type and options of every indexed question,
# with a "=====" separator line after each one.
FIELDS_TO_SHOW = (("Text: ", 'text'), ("Type: ", 'type'), ("Options: ", 'options'))
for entry in question_dict.values():
    for label, field in FIELDS_TO_SHOW:
        print(label, entry[field])
    print("=====")

### Question keys and question strings

In [None]:
# BLOCK 0 is demographic information, so those questions are excluded here.
qkeys = [qkey for qkey in question_dict if not qkey.startswith('BLOCK0')]

# Sanity check: no key may appear twice.
assert len(set(qkeys)) == len(qkeys)

# Map each retained key to the text of its question.
qstrings = {qid: question_dict[qid]['text'] for qid in qkeys}

# Persist the key list and the key -> question text mapping as JSON.
for out_path, payload in (
    ('twin_question_keys.json', qkeys),
    ('twin_question_strings.json', qstrings),
):
    with open(out_path, 'w') as f:
        json.dump(payload, f, indent=4)

### Options string

In [None]:
def question_formatter(qtext, options):
    """Build one prompt string per answer option of a multiple-choice question.

    Every returned string has the same head: a "Question: <qtext>" line,
    one "<letter>. <option>" line per option (letters start at "A", option
    text is whitespace-stripped), a blank line and "Answer: ". The i-th
    string then ends with the i-th lettered option.

    Args:
        qtext: The question text.
        options: The option strings, in display order.

    Returns:
        A list with one formatted string per option. Leading and trailing
        whitespace of each complete string is stripped.
    """
    def lettered(position, option):
        # 0 -> "A", 1 -> "B", ... followed by the cleaned-up option text.
        return chr(ord('A') + position) + ". " + option.strip()

    listing = "".join(lettered(i, option) + "\n" for i, option in enumerate(options))
    prompt = f"Question: {qtext}\n{listing}\nAnswer: "
    return [(prompt + lettered(i, option)).strip() for i, option in enumerate(options)]

# qwoptions: "<qkey>_option_<n>" -> formatted prompt ending in option n
#            (non-demographic questions only).
# options_map: qkey -> {float(n): option text}, for every question.
qwoptions = {}
options_map = {}
for qkey, question in question_dict.items():
    is_demographic = qkey.startswith('BLOCK0')
    options = question['options']
    if is_demographic:
        # Demographic options: keep only the text before the first "(",
        # trimmed and lower-cased.
        options = [opt.split("(")[0].strip().lower() for opt in options]
        if qkey == 'BLOCK0QID11_W1':
            # Special case (census region): rename 'pacific' to 'west'.
            options = [opt if opt != 'pacific' else 'west' for opt in options]
    qformatted = question_formatter(question['text'], options)
    if not is_demographic:
        # Option numbering in the keys is 1-based.
        for position, formatted in enumerate(qformatted, start=1):
            qwoptions[f"{qkey}_option_{position}"] = formatted
    # Float keys are written by json.dump as strings such as "1.0".
    options_map[qkey] = {
        float(position): option for position, option in enumerate(options, start=1)
    }

for out_path, payload in (
    ('twin_option_strings.json', qwoptions),
    ('twin_options_map.json', options_map),
):
    with open(out_path, 'w') as f:
        json.dump(payload, f, indent=4)

### process data into tabular form

In [None]:
# Build one row per persona: for every question in question_dict, the 1-based
# position of the selected option (pd.NA when the persona has no answer for
# that key), plus a constant WEIGHT_W1 column. The table is written to
# twin_responses.csv with the row id as the "QKEY" index.
all_rows = []
row_ids = []
multiple_choice_cnt = 0  # answers with more than one selection (only the first is kept)
qkeys = list(question_dict.keys())  # all keys here, including the BLOCK0 questions

# `data` already holds the decoded persona_json of every record, so the JSON
# does not have to be parsed a second time.
for i, person in enumerate(data):
    answers_for_person = {}
    for block_id, block in enumerate(person):
        for q in block.get("Questions", []):
            qid = f'BLOCK{block_id}' + q.get("QuestionID") + "_W1"
            if qid not in question_dict:
                continue
            options = question_dict[qid]["options"]
            person_answer_ind = q['Answers'].get('SelectedByPosition')
            person_answer_text = q['Answers'].get('SelectedText')
            if person_answer_ind is None or person_answer_text is None:
                # Fail loudly instead of dropping into an interactive debugger,
                # which would stall a non-interactive "Run All".
                raise ValueError(
                    f"Person {i}, question {qid}: missing answer "
                    f"(SelectedByPosition={person_answer_ind!r}, "
                    f"SelectedText={person_answer_text!r})"
                )
            if isinstance(person_answer_ind, list):
                if len(person_answer_ind) > 1:
                    multiple_choice_cnt += 1
                if not person_answer_ind or not person_answer_text:
                    raise ValueError(
                        f"Person {i}, question {qid}: empty selection "
                        f"(SelectedByPosition={person_answer_ind!r}, "
                        f"SelectedText={person_answer_text!r})"
                    )
                # Keep only the first selection, also for multi-select answers.
                person_answer_ind = person_answer_ind[0]
                person_answer_text = person_answer_text[0]
            # The position must be a valid 1-based index into the options of the
            # first persona, and the option text there must match the answer text.
            if not (
                isinstance(person_answer_ind, int)
                and 1 <= person_answer_ind <= len(options)
                and options[person_answer_ind - 1] == person_answer_text
            ):
                raise ValueError(
                    f"Person {i}, question {qid}: answer does not match the options "
                    f"(SelectedByPosition={person_answer_ind!r}, "
                    f"SelectedText={person_answer_text!r}, options={options!r})"
                )
            answers_for_person[qid] = person_answer_ind

    row = {qid: answers_for_person.get(qid, pd.NA) for qid in qkeys}
    row['WEIGHT_W1'] = 1.0  # every persona gets the same weight
    all_rows.append(row)
    row_ids.append(10000001 + i)  # consecutive row ids starting at 10000001

if multiple_choice_cnt:
    # Make the truncation of multi-select answers visible in the output.
    print(f"Answers with multiple selections (only the first was kept): {multiple_choice_cnt}")

df = pd.DataFrame(all_rows, index=row_ids)
df.index.name = "QKEY"

df.to_csv("twin_responses.csv")

print(df.shape)

### (Optional) df inspection

In [None]:
# Show the responses DataFrame as this cell's output.
df

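### What is the random-guess accuracy?

For a question with $k$ options, a uniform random guess is correct with probability $1/k$. The cell below averages this value over all questions outside BLOCK0 (the demographic block).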

In [None]:
# Chance level per question: one over its number of options.
# BLOCK0 questions are skipped.
random_accs = [
    1.0 / len(entry['options'])
    for qkey, entry in question_dict.items()
    if not qkey.startswith("BLOCK0QID")
]

print("Average random accuracy (excluding BLOCK0 questions): ", sum(random_accs) / len(random_accs))