This notebook illustrates how we:

- Query Macaw on everyday things
- Perform constraint reasoning on top of Macaw's outputs based on device + parts list


# Part 1: Prompt the model to get its mental model (MM) of a device

In [4]:
from utils import *

In [5]:
#make_sure_dir_exists("saved_logs")
#make_sure_dir_exists("macaw_query_beaker_input")

In [6]:
import csv
import json
import ast
from csp import *

Set up beaker. Let's use beaker-py to run experiments on beaker -- no more yaml scripts! :)

To learn more about beaker: https://beaker-docs.apps.allenai.org/

NOTE: Before you run the following cells, set `workspace_name` to the name of your own Beaker workspace.

In [7]:
#!pip install --upgrade beaker-py

In [8]:
#!pip install beaker-py
from beaker import *

# Placeholder value: replace "." with the name of your own Beaker workspace before running.
workspace_name = "."
# Create the Beaker client, using the workspace above as its default workspace.
beaker = Beaker.from_env(default_workspace=workspace_name)
# check who you are logged in as
print(beaker.account.whoami().name)

ModuleNotFoundError: No module named 'beaker'

Load annotated data

In [9]:
# Maps (everyday thing, turker) -> {"triplets": [(triplet, label), ...], "parts-list": [part, ...]}.
# Each key is one MM, i.e. one everyday thing as sketched by one turker.
et2triplets_ann = {}
with open("enriched_mms/full-ET-dataset.tsv", "r") as dataset:
    lines = csv.reader(dataset, delimiter="\t")
    next(lines, None)  # the first row is the header

    for line in lines:
        # Hyphens in the everyday-thing name are turned into spaces.
        et_turker = (line[0].replace("-", " "), line[1])
        mm_entry = et2triplets_ann.setdefault(et_turker, {"triplets": [], "parts-list": []})

        # Column 2 holds the triplet as a Python literal, column 3 its True/False label.
        triplet = ast.literal_eval(line[2])
        annotated_relation = (triplet, line[3])
        # The same annotated triplet must not appear twice within one MM.
        assert annotated_relation not in mm_entry["triplets"]
        mm_entry["triplets"].append(annotated_relation)

        # Record the two parts of the triplet, keeping first-seen order and no duplicates.
        for part in (triplet[0], triplet[2]):
            if part not in mm_entry["parts-list"]:
                mm_entry["parts-list"].append(part)

# Batch mode
## Generate queries to batch query macaw
Macaw's MM depends only on the everyday thing and its parts. Each query depends on the everyday thing and a triplet. We query Macaw with unique queries only (i.e. we never ask the same question twice).

### Note: creating this file takes a few minutes. You may comment out this section of the code after running it just once!

In [10]:
# Writes one JSON object per line (JSON Lines content, despite the ".json" suffix),
# one object for every unique (everyday thing, triplet) query.
with open("macaw_query_beaker_input/macaw_query_full-ET-dataset.json", "w") as query_outfile:
    # (everyday thing, triplet) -> how many times this query was generated
    seen = {}

    # The answer options are the same for every question, so build them once.
    mc_list = ("True", "False")
    mcoptions = " ".join(["(" + chr(i+65) + ") " + word for i, word in enumerate(mc_list)])

    for et_turker in tqdm(et2triplets_ann):
        et, turker = et_turker
        parts_list = et2triplets_ann[et_turker]["parts-list"]

        # permutation of list of parts
        perm = get_parts_perm(et, parts_list)

        # Use the space-separated name in both the query id and the question text.
        et = et.replace("-", " ")

        for entry in perm:
            for rln in all_relations_lst:
                # form statement
                triplet = (entry[0], rln, entry[1])
                statement = triplet2statement(triplet)

                et_triplet = (et, triplet)
                if et_triplet in seen:
                    # Already written for another turker's MM of the same thing: only count it.
                    seen[et_triplet] += 1
                    continue

                # form question
                compiled_qns = "Judge whether this statement is true or false: In {determiner} {device}, {statement}.".format(
                    determiner=get_determiner(et), device=et, statement=statement)
                query_record = {
                    "id": str(et_triplet) + "-originalTF",
                    "question": compiled_qns,
                    "mcoptions": mcoptions,
                    "angle": [["question", "mcoptions"], ["answer"]],
                    "explicit_outputs": mc_list,
                }
                query_outfile.write(json.dumps(query_record))
                query_outfile.write("\n")

#                 # Ignore negated statements for now (Macaw may not handle negation well)
#                 neg_compiled_qns = "Judge whether this statement is true or false: In a/an {device}, it is not the case that {statement}.".format( \
#                         device = et, statement=statement)
#                 query_outfile.write(json.dumps({"id": str(et_triplet) + "-negatedTF", "question" : neg_compiled_qns, "mcoptions": mcoptions,
#                     "angle":[["question","mcoptions"],["answer"]], "explicit_outputs": mc_list}))
#                 query_outfile.write("\n")

                seen[et_triplet] = 1


100%|██████████| 300/300 [10:33<00:00,  2.11s/it]


In [16]:
# Structure of `seen`: keys are (everyday thing, (part, relation, part)) tuples.
# Only a small sample is printed; `seen` holds one key per unique query, so printing
# every key would flood the notebook output.
for key in list(seen)[:15]:
    print(key)

('air conditioner', ('hot coils', 'part of', 'expansion valve'))
('air conditioner', ('hot coils', 'has part', 'expansion valve'))
('air conditioner', ('hot coils', 'inside', 'expansion valve'))
('air conditioner', ('hot coils', 'contains', 'expansion valve'))
('air conditioner', ('hot coils', 'in front of', 'expansion valve'))
('air conditioner', ('hot coils', 'behind', 'expansion valve'))
('air conditioner', ('hot coils', 'above', 'expansion valve'))
('air conditioner', ('hot coils', 'below', 'expansion valve'))
('air conditioner', ('hot coils', 'surrounds', 'expansion valve'))
('air conditioner', ('hot coils', 'surrounded by', 'expansion valve'))
('air conditioner', ('hot coils', 'requires', 'expansion valve'))
('air conditioner', ('hot coils', 'required by', 'expansion valve'))
('air conditioner', ('hot coils', 'connects', 'expansion valve'))
('air conditioner', ('hot coils', 'next to', 'expansion valve'))
('air conditioner', ('hot coils', 'part of', 'compressor'))
('air conditione

In [18]:
# Spot-check one entry: the value is how many times this (everyday thing, triplet) query was generated.
print(seen['tree', ('leaves', 'requires', 'twigs')])

1


In [11]:
# Each key of `seen` is one distinct query written to the query file.
print(len(seen), "unique queries")

108528 unique queries


108528 unique queries

## Upload dataset

(This only needs to be done once.)

In [None]:
# Create a Beaker dataset named "full-ET-dataset-2022Dec16" from the local directory
# "macaw_query_beaker_input/" (arguments are <dataset-name>, then <local-source>).
dataset = beaker.dataset.create("full-ET-dataset-2022Dec16", "macaw_query_beaker_input/", quiet=True, strip_paths=True)

## Run experiment

In [None]:
# Beaker image the prediction task runs in.
myimage = beaker.image.get("oyvindt/MultiAngleV50")
# Beaker dataset mounted as the model; swap in one of the alternatives below to use another model.
model_used = "oyvindt/maqaw_11B_v1"
#model_used = "oyvindt/maqaw_large_v1" 
#model_used = "oyvindt/maqaw_3B_v1" 
#model_used = "oyvindt/multi-angle-union2-ang1-large-hf" ## change the model here

qns_type = "full-ET-dataset"
# The model label is hard-coded in the experiment name: update it whenever model_used changes.
experiment_name = "Macaw-11B_zero-shot_Decision_on_" + qns_type + "_2022Dec16"## change the model label here

# NOTE(review): the account prefix "yulingg" is hard-coded; presumably it must be the account
# that uploaded the dataset -- confirm before running under a different user.
input_dataset_name = "yulingg/full-ET-dataset-2022Dec16"
# Name of the query file inside the uploaded dataset.
input_jsonl_file_name = "macaw_query_full-ET-dataset.json"

# Cluster the task is scheduled on; alternatives are kept below for convenience.
beaker_cluster_name = "ai2/general-cirrascale"
#beaker_cluster_name = "ai2/aristo-cirrascale"
#beaker_cluster_name = "ai2/mosaic-cirrascale"
#beaker_cluster_name = "ai2/allennlp-elanding-rtx8000"

In [None]:
# Command executed inside the container. The GPU count passed here ("--n_gpu")
# must agree with gpu_count in the task resources below.
predict_command = [
    "python", "multi_angle/run_maqaw.py",
    "--model_name_or_path", "/inputs/model",
    "--input_files", "/inputs/data/" + input_jsonl_file_name,
    "--output_dir", "/output",
    "--n_gpu", "2",
    "--add_metrics",
    "--eval_batch_size", "4",
    "--num_beams", "1",
]

# The model and the query file are both provided to the task as mounted Beaker datasets.
predict_mounts = [
    DataMount(mount_path="/inputs/model", source=DataSource(beaker=model_used)),
    DataMount(mount_path="/inputs/data", source=DataSource(beaker=input_dataset_name)),
]

predict_task = TaskSpec(
    name="predict",
    image=ImageSource(beaker=myimage.id),
    command=predict_command,
    datasets=predict_mounts,
    resources=TaskResources(gpu_count=2, memory="100GiB"),
    context=TaskContext(cluster=beaker_cluster_name),
    # A result path is required even if the task produces no output.
    result=ResultSpec(path='/output'),
)

# Single-task experiment: run the prediction job.
spec = ExperimentSpec(tasks=[predict_task])


In [None]:
# Submit the experiment described by `spec` under `experiment_name` in the chosen workspace.
experiment = beaker.experiment.create(
    experiment_name,
    spec,
    workspace=workspace_name,
)

## Process predictions from macaw

The following assumes you are storing the predictions as Macaw-11B_zero-shot_Decision_on_full-ET-dataset.jsonl inside macaw_query_beaker_output/, placed in the same directory as this notebook.

In [None]:
# Label of the model whose predictions are analysed below. It is used to build both the
# prediction file name and the outputs directory name, so it must match the stored file.
model_name = "Macaw-11B"
# model_name = "Macaw-3B"
# model_name = "Macaw-large"
# model_name = "UnifiedQA"
# model_name = "gpt3-text-davinci-003"

In [None]:
# !mkdir macaw_query_beaker_output

In [None]:
# Maps str((everyday thing, triplet)) -> {"answer": "True"/"False", "prob_True": float, "prob_False": float},
# built from the stored batch predictions (one JSON object per line).
et_triplet_2_probTF = {}
with open("macaw_query_beaker_output/" + model_name + "_zero-shot_Decision_on_full-ET-dataset.jsonl", "r") as predfile:
    # Stream the file line by line rather than loading every line into memory first.
    for prediction in predfile:
        json_pred = json.loads(prediction)

        # Query ids were written as str((et, triplet)) + "-originalTF"; split off the trailing tag.
        et_triplet_str, original_negated_label = json_pred['id'].rsplit("-", 1)

        # Negated statements are not generated for now (Macaw may not handle negation well),
        # so every id must occur exactly once.
        assert et_triplet_str not in et_triplet_2_probTF
        pred_entry = {"answer": None, "prob_True": 0, "prob_False": 0}
        et_triplet_2_probTF[et_triplet_str] = pred_entry

        # Get answer label
        pred_entry["answer"] = json_pred["output_slots_list"][0]["answer"]

        # Get raw scores of the two explicit options
        score_true = 0
        score_false = 0
        for output_choice in json_pred['explicit_outputs']:
            if output_choice['output_text'] == "True":
                score_true = output_choice['output_prob']
            elif output_choice['output_text'] == "False":
                score_false = output_choice['output_prob']

        # If the generated answer is neither "True" nor "False", fall back to the higher-scoring option.
        if pred_entry["answer"] not in ("True", "False") and score_true > 0 and score_false > 0:
            pred_entry["answer"] = str(score_true > score_false)

        assert pred_entry["answer"] in ("True", "False")

        # Normalise the two raw scores so that prob_True + prob_False == 1
        if score_true + score_false != 0.0:
            pred_entry["prob_True"] = score_true/(score_false + score_true)
            pred_entry["prob_False"] = 1 - pred_entry["prob_True"]
        else:
            # Neither option was scored: the probabilities keep their default of 0.
            # Report the affected query id (not a leftover variable from another cell).
            print(et_triplet_str, "Alert: true and false not in the options.")



In [None]:
print(model_name)
true_cnt = 0
for triplet_ans in et_triplet_2_probTF:
    # Keys are strings of the form str((everyday thing, triplet)). Parse the key back into
    # a tuple so the check below looks at the triplet's elements; indexing the string
    # directly would only inspect a single character.
    _, parsed_triplet = ast.literal_eval(triplet_ans)
    # Flag any triplet element that carries leading or trailing whitespace.
    for t in parsed_triplet:
        if t != t.strip():
            print(triplet_ans)
    if et_triplet_2_probTF[triplet_ans]['answer'] == 'True':
        true_cnt += 1

total_cnt = len(et_triplet_2_probTF)
# Avoid a ZeroDivisionError when no predictions were loaded.
true_pct = round((true_cnt/total_cnt) * 100, 2) if total_cnt else 0.0
print("% True tuples: {}/{} ({})".format(true_cnt, total_cnt, true_pct))

Macaw-11B
% True tuples: 62692/108528 (57.77)


In [None]:
def query_macaw_statements_batch_mode(device, perm):
    '''
    Collect the model's judgment for every (part, relation, part) triplet of a device.

    Input:  device - the everyday thing
            perm   - list of tuples for permutation of list of parts
    Output: triplet_ans_conf_lst - list of [triplet, ans, p_statement]
            neg_ans_conf_lst     - list of p_neg_statement, aligned with the list above

    Answers are read from the stored batch predictions in et_triplet_2_probTF;
    a triplet missing from that store is queried online instead.
    '''
    triplet_ans_conf_lst = []  # list of [triplet, ans, p_statement]
    neg_ans_conf_lst = []      # list of p_neg_statement

    for pair in perm:
        for relation in all_relations_lst:
            triplet = (pair[0], relation, pair[1])
            # Stored predictions are keyed by the string form of (device, triplet).
            lookup_key = str((device, triplet))

            if lookup_key in et_triplet_2_probTF:
                stored = et_triplet_2_probTF[lookup_key]
                ans, p_statement, p_neg_statement = (
                    stored["answer"], stored["prob_True"], stored["prob_False"])
            else:
                print("Need to query macaw online for", lookup_key)
                ans, p_statement, p_neg_statement = get_p_statement_and_p_neg_statement(
                    device, triplet2statement(triplet))

            triplet_ans_conf_lst.append([triplet, ans, p_statement])
            neg_ans_conf_lst.append(p_neg_statement)

    return triplet_ans_conf_lst, neg_ans_conf_lst

# query macaw on "everyday thing"
def run_query_macaw_everyday_thing_batch_mode(device, parts):
    '''
    Query the model about one everyday thing, given its list of parts.

    Returns a 3-tuple:
      - the judgments for all triplets, as produced by query_macaw_statements_batch_mode
      - the result of get_statements_that_macaw_believesT applied to those judgments
      - the list of negated-statement confidences
    '''
    part_pairs = get_parts_perm(device, parts)
    all_judgments, neg_confidences = query_macaw_statements_batch_mode(device, part_pairs)
    believed_true = get_statements_that_macaw_believesT(all_judgments)
    return all_judgments, believed_true, neg_confidences

The next 2 cells are just for doing a sanity check with a small example:

In [None]:
# Small example: query the stored predictions for an "egg" with five parts.
triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst =\
    run_query_macaw_everyday_thing_batch_mode("egg", ['yolk', 'egg white', 'shell membrane', 'shell', 'air cell'])

In [None]:
# Display the three results of the sanity check as the cell output.
triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst

([[('yolk', 'part of', 'egg white'), 'False', 0.004708560493744304],
  [('yolk', 'has part', 'egg white'), 'False', 0.34120868304130925],
  [('yolk', 'inside', 'egg white'), 'True', 0.9118826352336734],
  [('yolk', 'contains', 'egg white'), 'False', 0.12130310330302765],
  [('yolk', 'in front of', 'egg white'), 'False', 0.38251589606806186],
  [('yolk', 'behind', 'egg white'), 'False', 0.3635262971899514],
  [('yolk', 'above', 'egg white'), 'True', 0.7750988644836971],
  [('yolk', 'below', 'egg white'), 'True', 0.7171058899153786],
  [('yolk', 'surrounds', 'egg white'), 'True', 0.943189790718785],
  [('yolk', 'surrounded by', 'egg white'), 'True', 0.9868005561647796],
  [('yolk', 'requires', 'egg white'), 'False', 0.15301443122687497],
  [('yolk', 'required by', 'egg white'), 'False', 0.161191613819064],
  [('yolk', 'connects', 'egg white'), 'True', 0.5136137525947052],
  [('yolk', 'next to', 'egg white'), 'True', 0.7384337240733773],
  [('yolk', 'part of', 'shell membrane'), 'False', 

# Part 2: Constraint satisfaction


In [None]:
def imagine_a_device_with_csp(device, turker, outputs_dir, filter_threshold, parts=[]):
    '''
    Run the full pipeline for one (device, turker) MM: collect the model's judgments,
    apply constraint reasoning via run_maxsat, filter by confidence, draw the graphs,
    and cache everything on disk.

    Input:  device           - the everyday thing (lower-cased before use)
            turker           - identifier used in the cache and plot file names
            outputs_dir      - root output directory; expected to end with "/" because
                               sub-paths are built by plain string concatenation
            filter_threshold - threshold passed to filter_props
            parts            - list of parts of the device
                               NOTE(review): mutable default; it is only passed on here,
                               never modified in this function -- confirm the callee
                               does not mutate it.
    Output: dict with keys "macaw_predictions", "macaw_predictions_believe_true",
            "model_believe_true_props", "maxsat_selected_props", "filter_threshold",
            "model_believe_true_props_filtered", "maxsat_selected_props_filtered".

    Two cache levels are used: the final result dict (per device, turker and threshold)
    under Props/, and the raw model responses (per device and turker) under LMResponses/.
    Both are pickle files written by this function; only load them from trusted locations.
    '''
    device = device.lower()
    tag = "threshold" + str(filter_threshold)
    device_slug = device.replace(" ", "-")

    lm_query_dir = outputs_dir + "LMResponses/"  # saved model responses
    wcnf_dir = outputs_dir + "WCNF_format/"      # WCNF files kept for reference
    plots_dir = outputs_dir + "VizPlots/"        # generated graph images
    statements_dir = outputs_dir + "Props/"      # final results of each run
    for desired_dir in (outputs_dir, lm_query_dir, wcnf_dir, plots_dir, statements_dir):
        make_sure_dir_exists(desired_dir)

    results_pkl = device_slug + "_" + turker + "_" + tag + ".pkl"
    lm_pkl = device_slug + "-" + turker + ".pkl"

    # Fast path: this exact configuration has already been computed.
    if results_pkl in os.listdir(statements_dir):
        with open(statements_dir + results_pkl, 'rb') as f:
            all_result_dict = pickle.load(f)
        print("Read from file ...", len(all_result_dict["macaw_predictions"]), "triplets ...")
        return all_result_dict

    # Model responses do not depend on the threshold, so reuse them when available.
    if lm_pkl in os.listdir(lm_query_dir):
        with open(lm_query_dir + lm_pkl, 'rb') as f:
            triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst = pickle.load(f)
    else:
        triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst = \
            run_query_macaw_everyday_thing_batch_mode(device, parts)
        with open(lm_query_dir + lm_pkl, 'wb') as f:
            pickle.dump([triplet_ans_conf_lst, triplet_ans_conf_lst_true, neg_ans_conf_lst], f)

    # Constraint reasoning over the model's judgments.
    print("Running maxsat ...", len(triplet_ans_conf_lst), "triplets...")
    model_believe_true_props, maxsat_selected_props = run_maxsat(
        device, turker, wcnf_dir,
        triplet_ans_conf_lst, neg_ans_conf_lst, triplet_ans_conf_lst_true,
        use_only_model_true_props=False)

    # Keep only the propositions that pass the confidence filter.
    print("Filtering ...", len(model_believe_true_props), "triplets...", len(maxsat_selected_props), "triplets...")
    model_believe_true_props_filtered = filter_props(model_believe_true_props, filter_threshold)
    maxsat_selected_props_filtered = filter_props(maxsat_selected_props, filter_threshold)

    # Three graphs: what the model believes, what maxsat selected, and their overlap.
    print("Generating visualization ...", len(model_believe_true_props_filtered), "believed...", len(maxsat_selected_props_filtered), "selected")
    generate_graph_png(device, turker, model_believe_true_props_filtered, plots_dir, "model_believe_true_" + tag)
    generate_graph_png(device, turker, maxsat_selected_props_filtered, plots_dir, "maxsat_selected_" + tag)
    believed_selected = [
        prop for prop, _ in model_believe_true_props_filtered.items()
        if prop in maxsat_selected_props_filtered
    ]
    generate_graph_png(device, turker, believed_selected, plots_dir, "believed_selected_" + tag)

    # Bundle and cache the results of this run.
    all_result_dict = {
        "macaw_predictions": triplet_ans_conf_lst,
        "macaw_predictions_believe_true": triplet_ans_conf_lst_true,
        "model_believe_true_props": model_believe_true_props,
        "maxsat_selected_props": maxsat_selected_props,
        "filter_threshold": filter_threshold,
        "model_believe_true_props_filtered": model_believe_true_props_filtered,
        "maxsat_selected_props_filtered": maxsat_selected_props_filtered,
    }
    with open(statements_dir + results_pkl, 'wb') as f:
        pickle.dump(all_result_dict, f)
    print()
    return all_result_dict

    

In [None]:
# Root directory for all outputs of the constraint-satisfaction stage, named after the model.
# The trailing "/" matters: sub-directory paths are built by string concatenation.
outputs_dir = "0_" + model_name + "-ImagineADevice-CSP-Viz-full-ET-dataset/"
# # outputs_dir = "0_Macaw-3B-ImagineADevice-CSP-Viz-full-ET-dataset/"
# # outputs_dir = "0_Macaw-large-ImagineADevice-CSP-Viz-full-ET-dataset/"
# # outputs_dir = "0_UnifiedQA-ImagineADevice-CSP-Viz-full-ET-dataset/"
# Threshold handed to filter_props; also part of the cached result file names.
# NOTE(review): presumably a confidence on a 0-100 scale -- confirm against filter_props.
filter_threshold = 50 

In [None]:
# List of (everyday thing, turker) keys ordered by ascending number of parts,
# so the MMs with the fewest parts are processed first.
sorted_et2triplets_ann = sorted(et2triplets_ann, key=lambda et_turker: len(et2triplets_ann[et_turker]['parts-list']))

Be patient, we may encounter some formulae that will take longer to solve (like a few minutes)!

In [None]:
# Run the constraint-satisfaction pipeline on every MM, in the sorted order.
for mm_number, et_turker in enumerate(sorted_et2triplets_ann, start=1):
    print(et_turker, "MM #", mm_number)

    et, turker = et_turker
    parts_list = et2triplets_ann[et_turker]['parts-list']
    print(len(parts_list))  # number of parts in this MM
    all_result_dict = imagine_a_device_with_csp(et, turker, outputs_dir, filter_threshold, parts_list)

In [None]:
# Close the two log files.
# NOTE(review): neither name is defined in the visible notebook cells; presumably they come
# from the star imports of `utils` / `csp` -- confirm, otherwise this cell raises NameError.
macaw_getMM_logfile.close()
impose_contraints_logfile.close()