In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import json
import jsonlines
from collections import defaultdict

### Error analysis with AMA on really small foundation models (FMs)

In [30]:
def get_case(question_from, extractions_from):
    """Label a metrics record by which of its sources is set.

    Both arguments use the string "none" to mean "not set".

    Returns:
        "question" when `question_from` is set (this wins even if
        `extractions_from` is set as well), "extraction" when only
        `extractions_from` is set, and "base" when neither is.
    """
    if question_from != "none":
        return "question"
    return "extraction" if extractions_from != "none" else "base"

In [48]:
# Human-readable description of how many examples were evaluated per task;
# printed next to the task name in the metrics report.
sample_size = {
    "super_glue_wsc": "full test set",
    "super_glue_cb": "full test set",
    "anli_r3": "n=150 (1200 in full test set)",
    "super_glue_rte": "n=150 (277 in full test set)"
}

#### Metrics

In [50]:
# NOTE(review): absolute, user-specific path. The trailing slash combined with
# the "/" in the f-strings that use it produces a doubled slash in the path.
data_dir = "/home/ahojel/"

# Per-model tables, indexed as table[task_name][case] -> record dict, where
# `case` is the label returned by get_case().
metrics_1b = defaultdict(dict)
metrics_6b = defaultdict(dict)
metrics_560m = defaultdict(dict)  # declared here but never filled by this cell

# Substring searched for in a record's `model_name` -> table receiving it.
# Insertion order is the order in which the substrings are tried (1.3B first).
tables_by_model = {
    "EleutherAI_gpt-neo-1.3B": metrics_1b,
    "EleutherAI_gpt-j-6B": metrics_6b,
}


def _ama_line(label, record):
    """Format one report line from a record's `decomposed` and `decomposed_avg` values."""
    return f"{label}: AMA: {record['decomposed']:.3f} (avg: {record['decomposed_avg']:.3f})"


with open(f"{data_dir}/ama/metrics.json") as f:
    # The file is parsed with jsonlines, i.e. one JSON object per line.
    for obj in jsonlines.Reader(f):
        task_name = obj['task_name']
        questions_from = obj["questions_from"]
        extractions_from = obj["extractions_from"]
        record = {
            "zero_shot": obj['zero_shot'],
            "few_shot": obj['few_shot'],
            "decomposed": obj['decomposed'],
            "decomposed_avg": obj['decomposed_by_boost_avg'],
            "questions_from": questions_from,
            "extractions_from": extractions_from,
        }
        case = get_case(questions_from, extractions_from)
        for model_key, table in tables_by_model.items():
            if model_key in obj['model_name']:
                # A later record with the same (task, case) replaces the earlier one.
                table[task_name][case] = record
                break
        # Records for any other model are ignored.

# The report is driven by the tasks present for the 1.3B model; the 6B numbers
# for the same task are looked up alongside.
for task, results in metrics_1b.items():
    print(f"{task}: {sample_size[task]}")
    print()
    print("Exploration of 1.3B Model")
    print(_ama_line("1.3B (original)", results["base"]))
    # Only WSC is reported with an extraction-only line.
    if task == "super_glue_wsc":
        print(_ama_line("1.3B (extraction from gpt-j-6B)", results["extraction"]))
    print(_ama_line("1.3B (question from gpt-j-6B)", results["question"]))
    print()
    print("Exploration of 6B Model")
    results_6b = metrics_6b[task]
    print(_ama_line("6B (original)", results_6b["base"]))
    if task == "super_glue_wsc":
        print(_ama_line("6B (extraction from gpt-neo-1.3B)", results_6b["extraction"]))
    print(_ama_line("6B (question from gpt-neo-1.3B)", results_6b["question"]))
    print()
    print("----------")

super_glue_wsc: full test set

Exploration of 1.3B Model
1.3B (original): AMA: 0.615 (avg: 0.622)
1.3B (extraction from gpt-j-6B): AMA: 0.673 (avg: 0.644)
1.3B (question from gpt-j-6B): AMA: 0.673 (avg: 0.647)

Exploration of 6B Model
6B (original): AMA: 0.779 (avg: 0.753)
6B (extraction from gpt-neo-1.3B): AMA: 0.731 (avg: 0.689)
6B (question from gpt-neo-1.3B): AMA: 0.654 (avg: 0.663)

----------
super_glue_cb: full test set

Exploration of 1.3B Model
1.3B (original): AMA: 0.625 (avg: 0.607)
1.3B (question from gpt-j-6B): AMA: 0.643 (avg: 0.613)

Exploration of 6B Model
6B (original): AMA: 0.821 (avg: 0.833)
6B (question from gpt-neo-1.3B): AMA: 0.804 (avg: 0.827)

----------
anli_r3: n=150 (1200 in full test set)

Exploration of 1.3B Model
1.3B (original): AMA: 0.293 (avg: 0.319)
1.3B (question from gpt-j-6B): AMA: 0.340 (avg: 0.339)

Exploration of 6B Model
6B (original): AMA: 0.380 (avg: 0.335)
6B (question from gpt-neo-1.3B): AMA: 0.340 (avg: 0.324)

----------
super_glue_rte: n=

In [40]:
# Display every stored gpt-j-6B record for super_glue_wsc, keyed by case.
metrics_6b["super_glue_wsc"]

{'extraction': {'zero_shot': -1.0,
  'few_shot': -1.0,
  'decomposed': 0.7307692307692307,
  'decomposed_avg': 0.6891025641025642,
  'questions_from': 'none',
  'extractions_from': 'gpt-neo-1.3B'},
 'question': {'zero_shot': -1.0,
  'few_shot': -1.0,
  'decomposed': 0.6538461538461539,
  'decomposed_avg': 0.6634615384615384,
  'questions_from': 'gpt-neo-1.3B',
  'extractions_from': 'gpt-neo-1.3B'},
 'base': {'zero_shot': -1.0,
  'few_shot': -1.0,
  'decomposed': 0.7788461538461539,
  'decomposed_avg': 0.7532051282051282,
  'questions_from': 'none',
  'extractions_from': 'none'}}

#### Raw predictions

In [8]:
# Task whose per-example prediction logs are loaded and compared below.
task_name = "super_glue_cb"
# NOTE(review): not read by any cell visible in this notebook; the loading
# cell hard-codes both model names in its file paths.
model_name = "EleutherAI_gpt-j-6B"

In [9]:
# (date, small_date) stamps that appear in each task's log file names:
# `date` goes into the gpt-j-6B file name, `small_date` into the gpt-neo-1.3B one.
# NOTE(review): the stamps look like MMDDYYYY run dates; confirm.
run_dates = {
    "amazon_products": ("10042022", "10102022"),
    "super_glue_wsc": ("10112022", "10112022"),
    "super_glue_cb": ("10112022", "10112022"),
    "webq": ("10112022", "10112022"),
    "super_glue_rte": ("10112022", "10102022"),
}
# Any task not listed above uses the fallback pair.
date, small_date = run_dates.get(task_name, ("10032022", "10102022"))

# Load the logs of both models for the chosen task.
with open(f"{data_dir}/ama_logs/ama_final_runs/{task_name}/EleutherAI_gpt-j-6B_decomposed_{date}.json") as f:
    data_big = json.load(f)
with open(f"{data_dir}/ama_logs/ama_final_runs/{task_name}/EleutherAI_gpt-neo-1.3B_decomposed_{small_date}.json") as f:
    data_small = json.load(f)

In [10]:
def _print_prompt_tail(prompts, is_wsc):
    """Print only the text after the last blank line of the relevant prompt(s).

    Args:
        prompts: list of prompt strings (the first entry of an example's
            'prompts' field). The second string is shown when there is one,
            otherwise the first.
        is_wsc: when true, the tail of the third prompt is printed as well,
            separated by an empty line.
    """
    shown = prompts[1] if len(prompts) > 1 else prompts[0]
    print(shown.split("\n\n")[-1])
    if is_wsc:
        print()
        print(prompts[2].split("\n\n")[-1])


# Index of the last example that is inspected.
MAX_EXAMPLE_INDEX = 50
is_wsc = task_name == "super_glue_wsc"

first = True
for i, (k, v) in enumerate(data_big.items()):
    small = data_small[k]
    pred_big = v['pred']
    pred_small = small['pred']
    gold = v['gold']

    # Only examples that at least one of the two models got wrong are shown.
    if gold != pred_big or gold != pred_small:
        big_prompts = v['prompts'][0]

        print("BIG MODEL:")
        if first:
            # The first reported example shows the prompts in full, so the
            # shared leading part of each prompt is visible once.
            print(big_prompts[0])
            print()
            if len(big_prompts) > 1:
                print(big_prompts[1])
            if is_wsc:
                print(big_prompts[2])
            first = False
        else:
            _print_prompt_tail(big_prompts, is_wsc)
        print(f"\nPREDS: {v['preds_boost']} --> {v['pred']}")

        print("\nSMALL MODEL:")
        # Choose the prompt from the small model's own prompt list; the
        # previous code decided based on the big model's prompt count.
        _print_prompt_tail(small['prompts'][0], is_wsc)
        print(f"\nPREDS: {small['preds_boost']} --> {small['pred']}")

        print(f"\nGOLDS: {gold}")
        print("\n --------------------------------------------- \n")

    # Checked for every example, not just the reported ones: previously the
    # limit was skipped whenever example 50 was predicted correctly by both
    # models, and the loop then ran over the whole dataset.
    if i >= MAX_EXAMPLE_INDEX:
        break

BIG MODEL:

Rewrite the statement as a yes/no question.

Statement: most of the light comes from the sun
Question: Does most of the light come from the sun?

Statement: the test was not
Question: Was the test hard?

Statement: it is a good idea to buy your parents gifts
Question: Is it a good idea to buy your parents gifts?

Statement: the balloon popped
Question: Did the balloon pop?

Statement: The father and son went camping to California.
Question: Did the father and son go camping?

Statement: there is a shipment
Question:


Provide the answer to the question from the passage.

Passage: When Judy and Jack went to school, they got in trouble with their teacher for being late. I didn't think it was very fair.
Question: Did she think it was fair?
Answer: No

Passage: If inflation is occurring, leading to higher prices for basic necessities such as gas by 2 dollars. Do you think that inflation is good for society?
Question: Is inflation good for society?
Answer: Maybe

Passage: Put yo

In [59]:
# Ad-hoc check that relies on `v` still holding the last example visited by
# the comparison loop above.
# NOTE(review): the recorded output lists labels such as 'world news' and
# 'business', which looks like it came from a run on a different task than
# super_glue_cb; TODO confirm or re-run.
print(v['preds_boost'])

['world news', 'business', 'business']
