## Dataset overlap analysis

This notebook computes hit statistics, accuracy on hits vs. non-hits, and provides a tool for visual inspection of hits.

In [1]:
%pylab inline

import os
import pandas as pd
import json
import glob
from collections import Counter, defaultdict
pd.set_option('display.max_colwidth', None)

Populating the interactive namespace from numpy and matplotlib


In [2]:
def hit_stats(hits):
    example_ctr = Counter()
    for hit in hits:
        example_ctr[hit['test_id']] += 1
    unique_docs = set([hit['text'] for hit in hits])
    stats = {
        'unique_test_examples_with_hits': len(example_ctr),
        'unique_docs_with_hits': len(unique_docs),
        'num_hits': len(hits),
        'test_example_hit_counts': list(example_ctr.most_common())
    }
    return stats

## Hit statistics

In [4]:
output_directory = '../output/llemma'

output_files = [
    # MATH
    'open-web-math_open-web-math-v1.2_MATH_input_hits_30.json',
    'open-web-math_open-web-math-v1.2_MATH_output_hits_30.json',
    'mathstack_MATH_input_hits_30.json',
    'mathstack_MATH_output_hits_30.json',
    # GSM8k
    'open-web-math_open-web-math-v1.2_gsm8k_input_hits_30.json',
    'open-web-math_open-web-math-v1.2_gsm8k_output_hits_30.json',
    'mathstack_gsm8k_input_hits_30.json',
    'mathstack_gsm8k_output_hits_30.json',
    # Model generations
    'open-web-math-v1.2_llemma7b_MATH_generations_hits_30.json'
]

df_data = []
for output_file in output_files:
    hits = [json.loads(x) for x in open(os.path.join(output_directory, output_file), 'r').readlines()]
    stats = hit_stats(hits)
    
    df_data.append({
        'file': output_file.split('/')[-1],
        'unique_test_examples_with_hits': stats['unique_test_examples_with_hits'],
        'unique_docs_with_hits': stats['unique_docs_with_hits'],
    })

df = pd.DataFrame(df_data)
df

Unnamed: 0,file,unique_test_examples_with_hits,unique_docs_with_hits
0,open-web-math_open-web-math-v1.2_MATH_input_hits_30.json,348,717
1,open-web-math_open-web-math-v1.2_MATH_output_hits_30.json,34,46
2,mathstack_MATH_input_hits_30.json,3,3
3,mathstack_MATH_output_hits_30.json,1,1
4,open-web-math_open-web-math-v1.2_gsm8k_input_hits_30.json,2,3
5,open-web-math_open-web-math-v1.2_gsm8k_output_hits_30.json,0,0
6,mathstack_gsm8k_input_hits_30.json,0,0
7,mathstack_gsm8k_output_hits_30.json,0,0
8,open-web-math-v1.2_llemma7b_MATH_generations_hits_30.json,13,437


## Evaluation on hits vs. non-hits

In [6]:
hits_file = 'open-web-math_open-web-math-v1.2_MATH_input_hits_30.json'
hits = [json.loads(x) for x in open(os.path.join(output_directory, hits_file), 'r').readlines()]

hits_file = 'open-web-math_open-web-math-v1.2_MATH_output_hits_30.json'
hits += [json.loads(x) for x in open(os.path.join(output_directory, hits_file), 'r').readlines()]

output_json = 'llemma_34b_minerva_math_maj1.json'
outputs = json.load(open(os.path.join(output_directory, output_json)))

hit_inputs = set([hit['input'] for hit in hits])
print("Unique hit inputs: %d" % len(hit_inputs))

Unique hit inputs: 367


In [7]:
stats = defaultdict(float)

levels = set()
for task, cache in outputs['cache'].items():    
    for item in cache:
        assert item['acc'] in {0, 1}        
        levels.add(item['level'])

        # Accuracy for hit problems
        if item['problem'] in hit_inputs:
            stats['level%s_hit_acc' % item['level']] += item['acc']
            stats['level%s_hit_n' % item['level']] += 1
            
            stats['hit_acc'] += item['acc']
            stats['hit_n'] += 1
            
        # Accuracy for non-hit problems
        else:
            stats['level%s_nonhit_acc' % item['level']] += item['acc']
            stats['level%s_nonhit_n' % item['level']] += 1
            
            stats['nonhit_acc'] += item['acc']
            stats['nonhit_n'] += 1

# Normalize
for level in levels:
    stats['level%s_hit_acc' % level] = stats['level%s_hit_acc' % level] / stats['level%s_hit_n' % level]
    stats['level%s_nonhit_acc' % level] = stats['level%s_nonhit_acc' % level] / stats['level%s_nonhit_n' % level]

df_data = []
for level in sorted(levels):
    df_data.append({
        'level': level, 
        'hit_acc': stats['level%s_hit_acc' % level],
        'nonhit_acc': stats['level%s_nonhit_acc' % level],
        'n_hits': stats['level%s_hit_n' % level],
    })

pd.options.display.float_format = '{:.2%}'.format
df = pd.DataFrame(df_data).round(4)
df_ = df.style.format({
    'hit_acc': '{:.2%}'.format,
    'nonhit_acc': '{:.2%}'.format,
    'n_hits': '{:}'.format,
})
df_

Unnamed: 0,level,hit_acc,nonhit_acc,n_hits
0,Level 1,72.73%,61.50%,11.0
1,Level 2,35.71%,40.18%,28.0
2,Level 3,30.36%,26.88%,56.0
3,Level 4,14.89%,16.61%,94.0
4,Level 5,6.08%,6.39%,181.0


## Visual inspection



In [8]:
output_file = os.path.join(output_directory, 'open-web-math_open-web-math-v1.2_MATH_input_hits_30.json')

In [9]:
import ipywidgets as widgets
from ipywidgets import HBox, VBox
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, Markdown, Latex

hits = [json.loads(x) for x in open(output_file, 'r').readlines()]
stats = hit_stats(hits)

def match(seq1, seq2):
    idx = seq1.find(seq2)
    # NOTE: this visualization is occasionally unreliable, e.g. when the span
    # occurs inside of latex. Therefore, be sure to double check the raw sequence.
    return seq1[:idx] + '<span style="color: red;">' + seq1[idx:idx+len(seq2)] + '</span>' + seq1[idx+len(seq2):]
    
def template(hit):
    md = """### Hit
%s

### URL
%s

### Text
%s



#### Test id: %d

#### Input
%s

#### Output
%s
""" % (hit['hits'][1]['ngram'], 
       hit['url'], 
       match(hit['text'], hit['hits'][1]['ngram']), 
       hit['id'], hit['input'], hit['output'])
    return md


items = hits

@widgets.interact(idx=(0, len(items)-1))
def f(idx=0):
    num_hits = [x for x in stats['test_example_hit_counts'] if x[0] == hits[idx]['id']][0]
    print(num_hits)
    display(Markdown(template(hits[idx])))    

interactive(children=(IntSlider(value=0, description='idx', max=716), Output()), _dom_classes=('widget-interac…