# Synonym Filtering Model
**Phrase-BERT**: Improved Phrase Embeddings from BERT with an Application to Corpus Exploration (EMNLP 2021)

[[Paper]](https://arxiv.org/pdf/2109.06304.pdf) [[Hugging Face]](https://huggingface.co/whaleloops/phrase-bert)

In [None]:
!pip install -U sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer

# Load the Phrase-BERT checkpoint ('whaleloops/phrase-bert', see the Hugging Face
# link above). This `model` is passed to check_cos_sim() below to embed phrases.
model = SentenceTransformer('whaleloops/phrase-bert')

In [None]:
import torch
from torch import nn
def check_cos_sim(model, phrase_list, threshold):
    """Report phrase pairs that are too similar to serve as distinct options.

    The ground-truth phrase must be at ``phrase_list[0]`` and the distractor at
    ``phrase_list[1]``; every remaining entry is an additional option. The
    ground truth and the distractor are each compared (cosine similarity of
    their ``model.encode`` embeddings) with every phrase after them.
    Option-option pairs are not compared.

    Args:
        model: Object with an ``encode(list_of_str)`` method returning one
            embedding vector per phrase (e.g. a ``SentenceTransformer``).
        phrase_list: At least two phrases, ordered ``[gt, distractor, *options]``.
        threshold: Pairs whose cosine similarity is strictly greater than this
            value are reported (and a warning is printed for each).

    Returns:
        List of ``(i, j)`` index pairs into ``phrase_list`` (``i`` is 0 or 1,
        ``i < j``) whose similarity exceeds ``threshold``. Ground-truth pairs
        come first, then distractor pairs. An empty list means every compared
        pair is sufficiently distinct.

    Raises:
        ValueError: If fewer than two phrases are given.
    """
    if len(phrase_list) < 2:
        raise ValueError("phrase_list needs at least a ground truth and a distractor")

    # as_tensor accepts numpy rows or tensors without an extra copy/warning.
    embeddings = [torch.as_tensor(emb) for emb in model.encode(phrase_list)]
    cos_sim = nn.CosineSimilarity(dim=0)

    closeness = []
    # Anchor 0 = ground truth, anchor 1 = distractor. Each is compared only with
    # the phrases after it, so the gt-distractor pair is checked exactly once.
    for anchor in (0, 1):
        for other in range(anchor + 1, len(embeddings)):
            similarity = cos_sim(embeddings[anchor], embeddings[other])
            if similarity > threshold:
                closeness.append((anchor, other))
                print(f"[Warning] [{phrase_list[anchor]}] - [{phrase_list[other]}]: {similarity}; Current threshold is {threshold}")
    return closeness

# Read Raw Sentences

In [None]:
import json

# Load the Vicuna sentence lookup; only its "data" list is used downstream.
with open("/home/liu/temp_question/sentence_lookup_vicuna_2024-03-01.json", 'r') as lookup_file:
    vicuna_data = json.load(lookup_file)
    sent_data_vicuna = vicuna_data["data"]

In [None]:
# Distinct subjects in first-seen order. dict.fromkeys keeps insertion order and
# gives O(1) membership per sentence. A list-membership scan would be
# O(#sentences * #subjects).
subjs = list(dict.fromkeys(sent_dict['subj'] for sent_dict in sent_data_vicuna))
print(subjs)

# Base Sentence/Language Prior Selection 
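_Extended Reservoir Sampling_ with _k_ = 5: each subject is assigned one low-perplexity base sentence, and each distinct verb + location phrase can be shared by at most _k_ subjects.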

In [None]:
import random
from collections import deque

# Assign every subject one "base" sentence (used later as the distractor) while
# limiting how many subjects can share the same verb + location phrase.
# Each distinct `vp + ' ' + loc` string owns a reservoir of at most k sentences.
# Subjects are processed from a FIFO queue. For the current subject, its
# sentences are walked in ascending-ppl order, starting at
# global_subject_count[subject]:
#   * if the phrase's reservoir has room, the sentence is stored -> next subject;
#   * else it replaces a random slot with probability k / (times this phrase has
#     been seen), the classic reservoir-sampling acceptance rule. The evicted
#     subject is re-queued with its start index advanced by one -> next subject;
#   * else the next sentence of the same subject is tried.
# NOTE(review): `random` is not seeded here, so the selection differs per run.
# NOTE(review): if every sentence of a subject is rejected, the for-loop ends
#   without a `break` and that subject silently gets no base sentence.
# NOTE(review): a re-queued subject restarts at (previous start + 1), not after
#   the sentence it was evicted from. Sentences it already walked past may be
#   visited again, and their phrase counters incremented a second time.
subjects_queue = deque(subjs)
global_subject_count = {}  # subject -> index in its sorted list to resume from

for subj in subjs:
    global_subject_count[subj] = 0

base_sentences = {}  # unused here; rebuilt as a list in the next cell
reservoir_data = {}      # phrase -> list of stored sentence dicts (<= k)
reservoir_counters = {}  # phrase -> how many times the phrase has been offered
k = 5  # reservoir capacity
while subjects_queue:
    subject = subjects_queue.popleft()
    rolling_idx = global_subject_count[subject]

    # All sentences of this subject (full scan of the table on every pop).
    compare_batch = []
    for sent_dict in sent_data_vicuna:
            if sent_dict['subj'] == subject:
                compare_batch.append(sent_dict)        
    assert(len(compare_batch) == 2500)  # expected sentences per subject
    sorted_compare_list = sorted(compare_batch, key=lambda x: x['ppl'])
    # local_candidates = sorted_compare_list[0:5]

    for i in range(rolling_idx, len(sorted_compare_list)):
        item = sorted_compare_list[i]
        item_text = item['vp'] + ' ' + item['loc']  # e.g. [riding a bicycle in the countryside]
        # print(item_text)
        # First time this phrase is seen: create an empty reservoir and counter
        if item_text not in reservoir_data:
            reservoir_data[item_text] = []
            reservoir_counters[item_text] = 0
        
        # Count this offer of the phrase (denominator of the acceptance rule)
        reservoir_counters[item_text] += 1

        # Reservoir not full yet: always accept
        if len(reservoir_data[item_text]) < k:
            reservoir_data[item_text].append(item) 
            break # move to next subject, but still can be replaced later
        else:
            # Reservoir full: accept with probability k / count
            if random.random() < k / reservoir_counters[item_text]:
                # Evict a uniformly random slot
                replace_index = random.randint(0, k-1)
                to_replace_subj = reservoir_data[item_text][replace_index]['subj']
                global_subject_count[to_replace_subj] += 1
                subjects_queue.append(to_replace_subj) # evicted subject must find a new base sentence
                reservoir_data[item_text][replace_index] = item # replace
                break # move to next subject, but still can be replaced later
            # else, rejected: try this subject's next sentence

In [None]:
# Flatten the per-phrase reservoirs into a single list of base sentences,
# logging how many subjects ended up sharing each verb + location phrase.
base_sentences = []
for phrase_key, reservoir in reservoir_data.items():
    print(f"{[phrase_key]}, Count: {len(reservoir)}")
    base_sentences.extend(reservoir)
print(f"Total: {len(subjs)}, Constructed: {len(base_sentences)}")

In [None]:
# Persist the selected base sentences (path is relative to the working directory).
with open("base_sentences_0302.json", 'w') as out_file:
    out_file.write(json.dumps(base_sentences, indent=4))

# Question Construction

In [None]:
# Maps a 0-based option position to its multiple-choice letter (0 -> 'A' ... 4 -> 'E').
index_to_letter = dict(enumerate('ABCDE'))

In [None]:
import json
# Reload the base sentences from disk (machine-specific absolute path).
with open('/home/liu/temp_question/base_sentences_0302.json', 'r') as saved:
    base_sentences = json.load(saved)

## Verb Questions

### Pure verb phrases (*for tuning threshold*)

In [None]:
# verb phrase test case construction
# Pure verb-phrase groups, used only to tune the similarity threshold.
# For every base sentence, collect the 50 sentences that share its subject and
# location, sort them by ascending ppl, and record five verb phrases per group:
# [ground truth (4th-highest ppl), distractor (base sentence), 3 highest-ppl].
verb_candidate_groups = {}
ctr = 1
for sent_dict in base_sentences:
    subject = sent_dict["subj"]
    location = sent_dict["loc"]
    # Scan the Vicuna sentence table (the only sentence table this notebook loads).
    candidate_batch = [
        sentence for sentence in sent_data_vicuna
        if sentence['subj'] == subject and sentence['loc'] == location
    ]
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert len(sorted_batch) == 50

    # Ground truth first, distractor second (the order check_cos_sim expects).
    candidate_text_group = [sorted_batch[-4]['vp'], sent_dict['vp']]
    candidate_text_group.extend(item['vp'] for item in sorted_batch[-3:])

    verb_candidate_groups[ctr] = candidate_text_group
    ctr += 1

print(f"=====================\nTotal valid questions: {ctr-1}")
with open('verb_options.json', 'w') as f:
    json.dump(verb_candidate_groups, f)

### Verb candidate selection without ground-truth images

In [None]:
# verb phrase test case construction
# Verb-question option selection for subjects WITHOUT a ground-truth image.
# For each left-over base sentence (used as the distractor), take the 50
# sentences with the same subject and location, sorted by ascending ppl, and pick:
#   ground truth = a random one of ranks 40-43,
#   option1/2/3  = ranks 45, 47, 49 (one-time fallbacks 44, 46, 48).
# The five verb phrases are checked with check_cos_sim(). On a gt-distractor
# clash the gt is re-drawn from the unused ranks 40-43. Each option that clashes
# with the gt or distractor may be swapped for its fallback once. Groups that
# cannot be made conflict-free are discarded.
# NOTE(review): `left_over_bases` is built in a later cell (after the
#   "w/ ground-truth images" pass); that cell must run before this one.
# NOTE(review): `verb_candidate_groups` is reset both here and in the
#   ground-truth cell, and gids here restart at 1. Confirm they cannot collide
#   with the ground-truth qids when both outputs are combined.
verb_candidate_groups = {}
ctr = 1
for sent_dict in left_over_bases:
    subject = sent_dict["subj"]
    verb_phrase = sent_dict["vp"]
    location = sent_dict["loc"]
    candidate_batch = []
    for sentence in sent_data_vicuna:
        if sentence['subj'] == subject and sentence['loc'] == location:
            candidate_batch.append(sentence)
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert(len(sorted_batch) == 50)

    gt_index = random.randint(40, 43)  # inclusive range
    gt = sorted_batch[gt_index].copy()
    distractor = sent_dict.copy()
    option1 = sorted_batch[45].copy()
    option2 = sorted_batch[47].copy()
    option3 = sorted_batch[49].copy()

    # Each option slot may be swapped for its fallback at most once.
    chances = {'option1': 1, 'option2': 1, 'option3': 1}
    # Ranks still available if the gt has to be re-drawn.
    gt_indices = [40, 41, 42, 43]
    gt_indices.remove(gt_index)

    need_further_iteration = True
    # Order matters: check_cos_sim expects [gt, distractor, options...].
    candidate_sentences = [gt, distractor, option1, option2, option3]
    while need_further_iteration:
        candidate_verb_phrases = [item['vp'] for item in candidate_sentences]
        print(f"Group {ctr}: {candidate_verb_phrases}")
        # (i, j) index pairs that are too similar
        conflicts = check_cos_sim(model, candidate_verb_phrases, 0.73) # tuned threshold
        if len(conflicts) == 0:
            print(f"Group {ctr} Succeed.\n========")
            candidate_sentences[0]['tag'] = 'ground-truth'
            candidate_sentences[1]['tag'] = 'distractor'
            candidate_sentences[2]['tag'] = 'option1'
            candidate_sentences[3]['tag'] = 'option2'
            candidate_sentences[4]['tag'] = 'option3'
            for sent in candidate_sentences:
                sent['gid'] = ctr
            verb_candidate_groups[ctr] = candidate_sentences
            ctr += 1
            need_further_iteration = False
        else:
            resolved_all_conflicts = True
            # gt too close to the distractor -> the gt must be replaced
            gt_needs_to_change = True if (0, 1) in conflicts else False
            if gt_needs_to_change:
                if len(gt_indices) == 0:
                    print(f"Group {ctr} Discarded due to unresolved gt-dt conflicts.")
                    resolved_all_conflicts = False
                    need_further_iteration = False
                    break  # leaves the while loop; this group is dropped
                else:
                    new_gt_index = random.choice(gt_indices)
                    gt_indices.remove(new_gt_index)
                    gt = sorted_batch[new_gt_index].copy()
                    candidate_sentences[0] = gt
                    # Drop the old gt's pairs; the new gt is re-checked next pass.
                    conflicts = list(filter(lambda x: x[0] != 0, conflicts)) # options with gt are excluded

            # Option slots (2-4) that clash with the gt or the distractor.
            options_set = set()
            for conflict in conflicts:
                options_set.add(conflict[1])
            print(options_set)
            
            for option in options_set:                    
                if option == 2:  # Conflict with option1
                    if chances['option1'] > 0:
                        option1 = sorted_batch[44].copy()
                        chances['option1'] -= 1
                        candidate_sentences[2] = option1
                    else:
                        resolved_all_conflicts = False
                        break
                elif option == 3:  # Conflict with option2
                    if chances['option2'] > 0:
                        option2 = sorted_batch[46].copy()
                        chances['option2'] -= 1
                        candidate_sentences[3] = option2
                    else:
                        resolved_all_conflicts = False
                        break
                elif option == 4:  # Conflict with option3
                    if chances['option3'] > 0:
                        option3 = sorted_batch[48].copy()
                        chances['option3'] -= 1
                        candidate_sentences[4] = option3
                    else:
                        resolved_all_conflicts = False
                        break
            
            # Otherwise the while loop runs again to re-check the updated group.
            if not resolved_all_conflicts:
                print(f"Group {ctr} Discarded due to unresolved option conflicts.")
                # print(candidate_sentences)
                print(conflicts)
                print(chances)

                need_further_iteration = False
    
print(f"=====================\nTotal valid questions (Verb): {ctr-1}")

### Verb candidate selection with ground-truth images

In [None]:
# Ground-truth sentences that already have images: {subject: [whole text, qid]}.
with open('/home/liu/temp_question/ground_truth_sentences_2_24.json', 'r') as gt_source:
    ground_truth = json.load(gt_source)
ground_truth  # notebook display of the loaded mapping

In [None]:
# verb phrase test case construction
# legacy selection logic of options = sorted_batch[-3:] is applied here
# Verb-question option selection for subjects that already have a ground-truth
# image. The ground truth is the sentence recorded in `ground_truth`, the
# distractor is the base sentence, and the options are the three highest-ppl
# sentences sharing the subject and location. Subjects this pass cannot handle
# are collected in `left_over_subjs` for the "w/o ground-truth images" pass.
verb_candidate_groups = {}
# Reset here so the first append cannot raise NameError and re-running the
# cell does not accumulate duplicates.
left_over_subjs = []
ctr = 1
for sent_dict in base_sentences:
    subject = sent_dict["subj"]
    location = sent_dict["loc"]

    candidate_batch = [
        sentence for sentence in sent_data_vicuna
        if sentence['subj'] == subject and sentence['loc'] == location
    ]
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert len(sorted_batch) == 50

    try:
        gt = ground_truth[subject]  # {subject: [whole text, qid]}
    except KeyError:  # no ground-truth image for this subject
        left_over_subjs.append(subject)
        continue

    distractor = sent_dict.copy()
    distractor['tag'] = 'distractor'
    distractor['gid'] = gt[1]

    # Rank of the ground-truth sentence inside this subject/location batch.
    gt_index = next((i for i, item in enumerate(sorted_batch) if item['text'] == gt[0]), None)

    if gt_index is None:  # presumably the gt uses a different location
        print("Gt is not found. Possibly use a difference location than the distractor.")
        print(f"Gt: {gt[0]}")
        print(f"Dt: {distractor['text']}\n\n")
        left_over_subjs.append(subject)
        continue
    if gt_index > 46:
        # Ranks 47-49 are the options, so the gt would duplicate one of them.
        left_over_subjs.append(subject)
        print(f"Skipping this case because {gt[0]} is ranked {gt_index} and no room for other options.")
        continue

    gt_dict = sorted_batch[gt_index].copy()
    gt_dict['tag'] = 'ground-truth'
    gt_dict['gid'] = gt[1]

    # Ground truth first, distractor second: the order check_cos_sim expects.
    local_tmp = [gt_dict, distractor]
    for item in sorted_batch[-3:]:
        option = item.copy()
        option['tag'] = 'option'
        option['gid'] = gt[1]
        local_tmp.append(option)
    candidate_verb_phrases = [entry['vp'] for entry in local_tmp]

    print(f"Group {gt[1]}\n{candidate_verb_phrases}")
    # check_cos_sim returns the conflicting pairs; accept only when there are none.
    if not check_cos_sim(model, candidate_verb_phrases, 0.73): # tuned threshold
        verb_candidate_groups[gt[1]] = local_tmp
        print(f"Group {gt[1]} Success\n")
        ctr += 1
    else:
        print(f"Group {gt[1]} Failed because of similarity check\n")
        left_over_subjs.append(subject)

print(f"=====================\nTotal valid questions (Verb): {ctr-1}")

In [None]:
# Number of entries the ground-truth-image pass above appended to left_over_subjs.
len(left_over_subjs)

In [None]:
# Base sentences whose subject was left over by the ground-truth-image pass;
# these feed the "w/o ground-truth images" selection cell.
# A set gives O(1) membership instead of scanning the list for every sentence.
left_over_subject_set = set(left_over_subjs)
left_over_bases = [
    sent_dict for sent_dict in base_sentences
    if sent_dict['subj'] in left_over_subject_set
]
len(left_over_bases)

### Verb question formulation with templates

In [None]:
import random
# Turn each verb candidate group into one question record with multiple-choice,
# binary (yes / no / comparison) and open-ended variants.
verb_questions = []
for sent_list in verb_candidate_groups.values():
    question = {}
    ppl_vp = {}  # verb phrase -> ppl of the full sentence it came from
    tmp_vp_list = []
    # `gid` rather than `id`, so the builtin is not shadowed.
    gid, gt_subj, gt_be, gt_verb, gt_loc, dt_verb = None, None, None, None, None, None
    for sent in sent_list:
        if sent['tag'] == 'ground-truth':
            gid = sent['gid']
            gt_subj = sent['subj']
            gt_be = sent['be']
            gt_verb = sent['vp']
            gt_loc = sent['loc']
        elif sent['tag'] == 'distractor':
            dt_verb = sent['vp']
        ppl_vp[sent['vp']] = sent['ppl']
        tmp_vp_list.append(sent['vp'])

    # Every group needs a ground truth and a distractor. `any([...]) != None`
    # compares a bool with None and can never fail, so check each field.
    assert all(field is not None for field in (gid, gt_subj, gt_be, gt_verb, gt_loc, dt_verb))
    question['id'] = gid
    question['subj'] = gt_subj
    question['be'] = gt_be
    question['loc'] = gt_loc
    question['gt_vp'] = gt_verb
    question['dt_vp'] = dt_verb
    question['ppl'] = ppl_vp

    assert len(tmp_vp_list) == 5
    random.shuffle(tmp_vp_list)  # randomise the option order
    ans_idx = tmp_vp_list.index(gt_verb)
    ans_idx_dt = tmp_vp_list.index(dt_verb)

    # Highest-ppl verb phrase of the group, used for the comparison question.
    # NOTE(review): if two sentences share a vp, ppl_vp keeps only the last ppl.
    max_ppl_verb = max(ppl_vp, key=ppl_vp.get)

    # multiple choice question
    choice = f"What {gt_be} {gt_subj} doing {gt_loc}?\nA. {tmp_vp_list[0]}\nB. {tmp_vp_list[1]}\nC. {tmp_vp_list[2]}\nD. {tmp_vp_list[3]}\nE. {tmp_vp_list[4]}"
    # binary question - true
    binary_gt = f"Does the image show that {gt_subj} {gt_be} {gt_verb} {gt_loc}?"
    # binary question - false
    binary_dt = f"Does the image show that {gt_subj} {gt_be} {dt_verb} {gt_loc}?"
    # binary question - comparison
    binary_co = f"Does the image show that {gt_subj} {gt_be} {max_ppl_verb} {gt_loc}?"
    # open question
    open_question = f"What {gt_be} {gt_subj} doing {gt_loc}?"

    question['choice'] = choice
    question['choice answer'] = index_to_letter[ans_idx]
    question['choice distractor answer'] = index_to_letter[ans_idx_dt]
    question['binary-yes'] = binary_gt
    question['binary-yes answer'] = "Yes."
    question['binary-no'] = binary_dt
    question['binary-no answer'] = "No."
    question['binary-cp'] = binary_co
    question['binary-cp answer'] = "No."
    question['open'] = open_question
    question['open answer'] = gt_verb
    question['image prompt'] = f"Generate an image of {gt_subj} {gt_verb} {gt_loc}."

    verb_questions.append(question)

In [None]:
# Save the verb questions as JSON, then re-read that file to produce a
# human-readable text version for manual review.
with open('verb_questions_vicuna_0302_partB.json', 'w') as f:
    json.dump(verb_questions, f)

with open('verb_questions_vicuna_0302_partB-h.txt', 'w') as outfile:
    with open('verb_questions_vicuna_0302_partB.json', 'r') as infile:
        j_list = json.load(infile)
        for question in j_list:
            outfile.write(f"Question{question['id']}:\n\n{question['image prompt']}\n{question['choice']}\n")
            outfile.write(f"Answer: {question['choice answer']}\n")
            # Correctly spelled label, matching the location-question dump.
            outfile.write(f"Distractor: {question['choice distractor answer']}\n\n")
            outfile.write(f"{question['binary-yes']}\nAnswer: {question['binary-yes answer']}\n")
            outfile.write(f"{question['binary-no']}\nAnswer: {question['binary-no answer']}\n")
            outfile.write(f"{question['binary-cp']}\nAnswer: {question['binary-cp answer']}\n\n")
            outfile.write(f"{question['open']}\nAnswer: {question['open answer']}\n\n")
            outfile.write(f"PPL: {question['ppl']}\n\n")
            outfile.write("="*10)
            outfile.write('\n')

## Location Questions

### Pure location phrases (*for tuning threshold*)

In [None]:
# location phrase test case construction
# Pure location-phrase groups, used only to tune the similarity threshold.
# For every base sentence, collect the 50 sentences that share its subject and
# verb phrase, sort them by ascending ppl, and record five locations per group:
# [ground truth (4th-highest ppl), distractor (base sentence), 3 highest-ppl].
loc_candidate_groups = {}
ctr = 1

for sent_dict in base_sentences:
    subject = sent_dict["subj"]
    verb_phrase = sent_dict["vp"]
    # Scan the Vicuna sentence table (the only sentence table this notebook loads).
    candidate_batch = [
        sentence for sentence in sent_data_vicuna
        if sentence['subj'] == subject and sentence['vp'] == verb_phrase
    ]
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert len(sorted_batch) == 50

    # Ground truth first, distractor second (the order check_cos_sim expects).
    candidate_locs = [sorted_batch[-4]['loc'], sent_dict['loc']]
    candidate_locs.extend(item['loc'] for item in sorted_batch[-3:])

    loc_candidate_groups[ctr] = candidate_locs
    ctr += 1

print(f"=====================\nTotal valid questions: {ctr-1}")
with open('loc_options.json', 'w') as f:
    json.dump(loc_candidate_groups, f)

### Location candidate selection without ground-truth images

In [None]:
# location test case construction
# Location-question option selection WITHOUT ground-truth images.
# For each base sentence (used as the distractor), take the 50 sentences with
# the same subject and verb phrase, sorted by ascending ppl, and pick:
#   ground truth = a random one of ranks 40-43,
#   option1/2/3  = ranks 45, 47, 49 (one-time fallbacks 44, 46, 48).
# The five locations are checked with check_cos_sim(). On a gt-distractor
# clash the gt is re-drawn from the unused ranks 40-43. Each option that clashes
# with the gt or distractor may be swapped for its fallback once. Groups that
# cannot be made conflict-free are discarded.
# NOTE(review): unlike the verb counterpart (which iterates left_over_bases),
#   this iterates all base_sentences -- confirm that is intended.
location_candidate_groups = {}

ctr = 1
for sent_dict in base_sentences:
    subject = sent_dict["subj"]
    verb_phrase = sent_dict["vp"]
    location = sent_dict["loc"]
    candidate_batch = []
    for sentence in sent_data_vicuna:
        if sentence['subj'] == subject and sentence['vp'] == verb_phrase:
            candidate_batch.append(sentence)
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert(len(sorted_batch) == 50)

    gt_index = random.randint(40, 43)  # inclusive range
    gt = sorted_batch[gt_index].copy()
    distractor = sent_dict.copy()
    option1 = sorted_batch[45].copy()
    option2 = sorted_batch[47].copy()
    option3 = sorted_batch[49].copy()

    # Each option slot may be swapped for its fallback at most once.
    chances = {'option1': 1, 'option2': 1, 'option3': 1}
    # Ranks still available if the gt has to be re-drawn.
    gt_indices = [40, 41, 42, 43]
    gt_indices.remove(gt_index)

    need_further_iteration = True
    # Order matters: check_cos_sim expects [gt, distractor, options...].
    candidate_sentences = [gt, distractor, option1, option2, option3]
    while need_further_iteration:
        candidate_loc_phrases = [item['loc'] for item in candidate_sentences]
        print(f"Group {ctr}: {candidate_loc_phrases}")
        # (i, j) index pairs that are too similar
        conflicts = check_cos_sim(model, candidate_loc_phrases, 0.63) # tuned threshold
        if len(conflicts) == 0:
            print(f"Group {ctr} Succeed.\n========")
            candidate_sentences[0]['tag'] = 'ground-truth'
            candidate_sentences[1]['tag'] = 'distractor'
            candidate_sentences[2]['tag'] = 'option1'
            candidate_sentences[3]['tag'] = 'option2'
            candidate_sentences[4]['tag'] = 'option3'
            for sent in candidate_sentences:
                sent['gid'] = ctr
            location_candidate_groups[ctr] = candidate_sentences
            ctr += 1
            need_further_iteration = False
        else:
            resolved_all_conflicts = True
            # gt too close to the distractor -> the gt must be replaced
            gt_needs_to_change = True if (0, 1) in conflicts else False
            if gt_needs_to_change:
                if len(gt_indices) == 0:
                    print(f"Group {ctr} Discarded due to unresolved gt-dt conflicts.")
                    resolved_all_conflicts = False
                    need_further_iteration = False
                    break  # leaves the while loop; this group is dropped
                else:
                    new_gt_index = random.choice(gt_indices)
                    gt_indices.remove(new_gt_index)
                    gt = sorted_batch[new_gt_index].copy()
                    candidate_sentences[0] = gt
                    # Drop the old gt's pairs; the new gt is re-checked next pass.
                    conflicts = list(filter(lambda x: x[0] != 0, conflicts)) # options with gt are excluded

            # Option slots (2-4) that clash with the gt or the distractor.
            options_set = set()
            for conflict in conflicts:
                options_set.add(conflict[1])
            print(options_set)
            
            for option in options_set:                    
                if option == 2:  # Conflict with option1
                    if chances['option1'] > 0:
                        option1 = sorted_batch[44].copy()
                        chances['option1'] -= 1
                        candidate_sentences[2] = option1
                    else:
                        resolved_all_conflicts = False
                        break
                elif option == 3:  # Conflict with option2
                    if chances['option2'] > 0:
                        option2 = sorted_batch[46].copy()
                        chances['option2'] -= 1
                        candidate_sentences[3] = option2
                    else:
                        resolved_all_conflicts = False
                        break
                elif option == 4:  # Conflict with option3
                    if chances['option3'] > 0:
                        option3 = sorted_batch[48].copy()
                        chances['option3'] -= 1
                        candidate_sentences[4] = option3
                    else:
                        resolved_all_conflicts = False
                        break
            
            # Otherwise the while loop runs again to re-check the updated group.
            if not resolved_all_conflicts:
                print(f"Group {ctr} Discarded due to unresolved option conflicts.")
                # print(candidate_sentences)
                print(conflicts)
                print(chances)
                need_further_iteration = False

print(f"=====================\nTotal valid questions (location): {ctr-1}")

### Location candidate selection with ground-truth images

In [None]:
# Reload the ground-truth sentences ({subject: [whole text, qid]}) for the
# location pass; same file as in the verb section.
with open('/home/liu/temp_question/ground_truth_sentences_2_24.json', 'r') as gt_json:
    ground_truth = json.load(gt_json)
ground_truth  # show the mapping in the notebook output

In [None]:
# location phrase test case construction
# legacy selection logic of options = sorted_batch[-3:] is applied here
# Location-question option selection for subjects that already have a
# ground-truth image. The ground truth is the sentence recorded in
# `ground_truth`, the distractor is the base sentence, and the options are the
# three highest-ppl sentences sharing the subject and verb phrase.
# Accepted groups go into `location_candidate_groups`, which the location-question
# formulation cell reads. `verb_candidate_groups` is left untouched.
location_candidate_groups = {}
# Subjects this pass cannot handle; reset so re-running does not accumulate.
left_over_subjs = []
ctr = 1
for sent_dict in base_sentences:
    subject = sent_dict["subj"]
    verb_phrase = sent_dict["vp"]

    candidate_batch = [
        sentence for sentence in sent_data_vicuna
        if sentence['subj'] == subject and sentence['vp'] == verb_phrase
    ]
    sorted_batch = sorted(candidate_batch, key=lambda x: x['ppl'])
    assert len(sorted_batch) == 50

    try:
        gt = ground_truth[subject]  # {subject: [whole text, qid]}
    except KeyError:  # no ground-truth image for this subject
        left_over_subjs.append(subject)
        continue

    distractor = sent_dict.copy()
    distractor['tag'] = 'distractor'
    distractor['gid'] = gt[1]

    # Rank of the ground-truth sentence inside this subject/verb-phrase batch.
    gt_index = next((i for i, item in enumerate(sorted_batch) if item['text'] == gt[0]), None)

    if gt_index is None:  # presumably the gt uses a different verb phrase
        print("Gt is not found. Possibly use a difference verb phrase than the distractor.")
        print(f"Gt: {gt[0]}")
        print(f"Dt: {distractor['text']}\n\n")
        left_over_subjs.append(subject)
        continue
    if gt_index > 46:
        # Ranks 47-49 are the options, so the gt would duplicate one of them.
        left_over_subjs.append(subject)
        print(f"Skipping this case because {gt[0]} is ranked {gt_index} and no room for other options.")
        continue

    gt_dict = sorted_batch[gt_index].copy()
    gt_dict['tag'] = 'ground-truth'
    gt_dict['gid'] = gt[1]

    # Ground truth first, distractor second: the order check_cos_sim expects.
    local_tmp = [gt_dict, distractor]
    for item in sorted_batch[-3:]:
        option = item.copy()
        option['tag'] = 'option'
        option['gid'] = gt[1]
        local_tmp.append(option)
    # Compare the *locations*: every sentence in this batch has the same verb
    # phrase, so comparing 'vp' would always look maximally similar.
    candidate_loc_phrases = [entry['loc'] for entry in local_tmp]

    print(f"Group {gt[1]}\n{candidate_loc_phrases}")
    # check_cos_sim returns the conflicting pairs; accept only when there are none.
    if not check_cos_sim(model, candidate_loc_phrases, 0.63): # tuned threshold
        location_candidate_groups[gt[1]] = local_tmp
        print(f"Group {gt[1]} Success\n")
        ctr += 1
    else:
        print(f"Group {gt[1]} Failed because of similarity check\n")
        left_over_subjs.append(subject)

print(f"=====================\nTotal valid questions (location): {ctr-1}")

### Location question formulation with templates

In [None]:
# Turn each location candidate group into one question record with
# multiple-choice, binary (yes / no / comparison) and open-ended variants.
loc_questions = []
for sent_list in location_candidate_groups.values():
    question = {}
    ppl_loc = {}  # location -> ppl of the full sentence it came from
    tmp_loc_list = []
    # `gid` rather than `id`, so the builtin is not shadowed.
    gid, gt_subj, gt_be, gt_verb, gt_loc, dt_loc = None, None, None, None, None, None
    for sent in sent_list:
        if sent['tag'] == 'ground-truth':
            gid = sent['gid']
            gt_subj = sent['subj']
            gt_be = sent['be']
            gt_verb = sent['vp']
            gt_loc = sent['loc']
        elif sent['tag'] == 'distractor':
            dt_loc = sent['loc']
        ppl_loc[sent['loc']] = sent['ppl']
        tmp_loc_list.append(sent['loc'])

    # Every group needs a ground truth and a distractor. `any([...]) != None`
    # compares a bool with None and can never fail, so check each field.
    assert all(field is not None for field in (gid, gt_subj, gt_be, gt_verb, gt_loc, dt_loc))
    question['id'] = gid
    question['subj'] = gt_subj
    question['be'] = gt_be
    question['vp'] = gt_verb
    question['gt_loc'] = gt_loc
    question['dt_loc'] = dt_loc
    question['ppl'] = ppl_loc

    assert len(tmp_loc_list) == 5
    random.shuffle(tmp_loc_list)  # randomise the option order
    ans_idx = tmp_loc_list.index(gt_loc)
    ans_idx_dt = tmp_loc_list.index(dt_loc)

    # Highest-ppl location of the group, used for the comparison question.
    # NOTE(review): if two sentences share a loc, ppl_loc keeps only the last ppl.
    max_ppl_loc = max(ppl_loc, key=ppl_loc.get)

    choice = f"Where {gt_be} {gt_subj} {gt_verb}?\nA. {tmp_loc_list[0]}\nB. {tmp_loc_list[1]}\nC. {tmp_loc_list[2]}\nD. {tmp_loc_list[3]}\nE. {tmp_loc_list[4]}"
    # binary question - true
    binary_gt = f"Does the image show that {gt_subj} {gt_be} {gt_verb} {gt_loc}?"
    # binary question - false
    binary_dt = f"Does the image show that {gt_subj} {gt_be} {gt_verb} {dt_loc}?"
    # binary question - comparison
    binary_co = f"Does the image show that {gt_subj} {gt_be} {gt_verb} {max_ppl_loc}?"

    # open question
    open_question = f"Where {gt_be} {gt_subj} {gt_verb}?"

    question['choice'] = choice
    question['choice answer'] = index_to_letter[ans_idx]
    question['choice distractor answer'] = index_to_letter[ans_idx_dt]
    question['binary-yes'] = binary_gt
    question['binary-yes answer'] = "Yes."
    question['binary-no'] = binary_dt
    question['binary-no answer'] = "No."
    question['binary-cp'] = binary_co
    question['binary-cp answer'] = "No."
    question['open'] = open_question
    question['open answer'] = gt_loc
    question['image prompt'] = f"Generate an image of {gt_subj} {gt_verb} {gt_loc}."

    loc_questions.append(question)

In [None]:
# Save the location questions as JSON, then re-read that file to produce a
# human-readable text version for manual review.
with open('loc_questions_vicuna_0302_full.json', 'w') as json_out:
    json.dump(loc_questions, json_out)

with open('loc_questions_vicuna_0302_full-h.txt', 'w') as outfile:
    with open('loc_questions_vicuna_0302_full.json', 'r') as infile:
        saved_questions = json.load(infile)
    for q in saved_questions:
        outfile.writelines([
            f"Question{q['id']}:\n\n{q['image prompt']}\n{q['choice']}\n",
            f"Answer: {q['choice answer']}\n",
            f"Distractor: {q['choice distractor answer']}\n\n",
            f"{q['binary-yes']}\nAnswer: {q['binary-yes answer']}\n",
            f"{q['binary-no']}\nAnswer: {q['binary-no answer']}\n",
            f"{q['binary-cp']}\nAnswer: {q['binary-cp answer']}\n\n",
            f"{q['open']}\nAnswer: {q['open answer']}\n\n",
            f"PPL: {q['ppl']}\n\n",
            "="*10,
            '\n',
        ])