In [1]:
cd ..

/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing


In [2]:
# change to where ever you have your feature datasets saved
import os
#os.environ['HF_DATASETS_CACHE'] = '/Users/katherineharvey/sparse-probing-4/data/feature_datasets'''
os.environ['HF_DATASETS_CACHE'] = '/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads'

os.environ['TRANSFORMERS_OFFLINE'] = '1'

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch

from load import *
from probing_datasets import language_id

In [6]:
NATURAL_LANGS_UNABBREVIATED = {
    'bg': 'Bulgarian',
    'de': 'German',
    'es': 'Spanish',
    'fr': 'French',
    'lt': 'Lithuanian',
    'pl': 'Polish',
    'sk': 'Slovak',
    'da': 'Danish',
    'en': 'English',
    'fi': 'Finnish',
    'it': 'Italian',
    'nl': 'Dutch',
    'ro': 'Romanian',
    'sv': 'Swedish',
    'cs': 'Czech',
    'el': 'Greek',
    'et': 'Estonian',
    'hu': 'Hungarian',
    'lv': 'Lativian',
    'pt': 'Portuguese',
    'sl': 'Slovenian'
}

# Overview
We need nice tables for the paper which describe all of the models and datasets we used/made. You might read up a _little_ bit about how to make tables manually (see [here](https://www.overleaf.com/learn/latex/Tables)) but you will want to do as much as possible programatically (such that if we want to make any changes we can just rerun the cell). My recommendation to start is to put everything into a pandas dataframe and then run `df.to_latex()`. You can then copy it into the overleaf to see how it renders (though you can also render it in a markdown cell in jupyter). I expect chatGPT to be fairly helpful with this so be mindful how you might factor the problem into promptable chunks.

## Model Table
Here we want a table similar to the models table [here](https://raw.githubusercontent.com/EleutherAI/pythia/main/README.md). We use the 70m-6.9b models (which are not deduped). Only include the architecture params (ie. anything that starts with n_ or d_), as well as the number of total neurons (4 * d_model * n_layers).

In [7]:

model_df = pd.DataFrame({
    'Params': ['Pythia 70M','Pythia 160M', 'Pythia 410M', 'Pythia 1B','Pythia 1.4B', 'Pythia 2.8B', 'Pythia 6.9B'], 
    'n_layers': [6, 12, 24, 16, 24, 32, 32],
    'd_model': [512, 768, 1024, 2048, 2048, 2560, 4096],
    'n_heads': [8, 12, 16, 8, 16, 32, 32],
    'd_head': [64, 64, 64, 256, 128, 80, 128],
    'Total Number of Neurons': [12288, 36864, 98304, 131072, 196608, 327680, 524288]
})
model_df

Unnamed: 0,Params,n_layers,d_model,n_heads,d_head,Total Number of Neurons
0,Pythia 70M,6,512,8,64,12288
1,Pythia 160M,12,768,12,64,36864
2,Pythia 410M,24,1024,16,64,98304
3,Pythia 1B,16,2048,8,256,131072
4,Pythia 1.4B,24,2048,16,128,196608
5,Pythia 2.8B,32,2560,32,80,327680
6,Pythia 6.9B,32,4096,32,128,524288


In [8]:
model_latex_table = model_df.to_latex(index=False)
print(model_latex_table)

\begin{tabular}{lrrrrr}
\toprule
     Params &  n\_layers &  d\_model &  n\_heads &  d\_head &  Total Number of Neurons \\
\midrule
 Pythia 70M &         6 &      512 &        8 &      64 &                    12288 \\
Pythia 160M &        12 &      768 &       12 &      64 &                    36864 \\
Pythia 410M &        24 &     1024 &       16 &      64 &                    98304 \\
  Pythia 1B &        16 &     2048 &        8 &     256 &                   131072 \\
Pythia 1.4B &        24 &     2048 &       16 &     128 &                   196608 \\
Pythia 2.8B &        32 &     2560 &       32 &      80 &                   327680 \\
Pythia 6.9B &        32 &     4096 &       32 &     128 &                   524288 \\
\bottomrule
\end{tabular}



  model_latex_table = model_df.to_latex(index=False)


## Dataset Table
Here we want a table which describes all of the datasets we made/use (listed below, though these could change). You can get started just creating a table for these, but for a few of them (potentially just EWT actually) we may want additional granularity. The columns we want are:
* Dataset Name
* Number of rows 
* Number of columns (i.e. number of tokens per row)
* Number of total tokens (i.e. number of tokens which aren't 0 or 1)
* List of the actual features (ie, the names of all the languages)
* Original data source (e.g. "pile-europarl")
You should get all of the "number" items programmatically, that is by actually loading every dataset and computing these things programmatically. For the list of features, depending on the dataset you should be able to get it programatically by looking at the column names in the dataset which are of the form `<feature_name>|probe_indices` or get the unique values from a relevant column in the dataset (like `lang` below) but if there are tricky edge cases you can hardcode some of them. For the original data source, you can just hardcode it in a dictionary (and you can make up data for now and I will put in the correct values).

Again the goal is to programmatically construct a pandas dataframe with all this information, then generate a table, then potentially do something which makes it look nicer.


In [75]:
feature_datasets = [
    'programming_lang_id.pyth.512.-1',
    'natural_lang_id.pyth.512.-1',
    'text_features.pyth.256.10000',
    'counterfact.pyth.64.-1',
    'distribution_id.pyth.512.-1',
    'ewt.pyth.512.-1',
    'compound_words.pyth.24.-1',
    'latex.pyth.1024.-1'
]

dataset_sources = {
    'programming_lang_id.pyth.512.-1': 'pile-github',
    'natural_lang_id.pyth.512.-1': 'pile-europarl',
    'text_features.pyth.256.10000': 'pile-test-all',
    'counterfact.pyth.64.-1': 'pile-test-all',
    'distribution_id.pyth.512.-1': 'pile-test-all',
    'ewt.pyth.512.-1': 'EWT',
    'compound_words.pyth.24.-1': 'pile-test-all',
    'latex.pyth.1024.-1': 'pile-arxiv'
}



In [14]:
# small helper function to unabbreviate languages
def unabbreviate_languages(abbr_list):
    unabbr_list = [NATURAL_LANGS_UNABBREVIATED.get(abbr, abbr) for abbr in abbr_list]
    return ", ".join(unabbr_list)

In [106]:
def get_features(dataset):
    fds = load_feature_dataset(dataset)
    features = set()
    for feature in fds.features:
        if '|' in feature:
            features.add(feature.split('|')[0])
        elif feature in ('lang', 'relation_id', 'distribution', 'feature_name'):
            features.update(set(fds[feature]))
    
    features = [item for item in features if item not in ['all_tokens', 'tokens', 'meta', 'text']]
    return unabbreviate_languages(features)

In [79]:
get_features('programming_lang_id.pyth.512.-1')



'Python, XML, Java, C++, HTML, C, Go, PHP, JavaScript'

In [15]:
# creates a list of all features that start with a certain prefix
def filter_by_prefix(labels, prefix):
    return [label for label in labels if label.startswith(prefix) and label.endswith('probe_classes')]

In [16]:
def compute_avg_positives():
    average_pos = []
    average_neg = []
    ewt = load_feature_dataset('ewt.pyth.512.-1')
    upos = filter_by_prefix(ewt.features, 'upos')
    dep = filter_by_prefix(ewt.features, 'dep')
    other = list(filter(lambda x: not x.startswith("upos") and not x.startswith("dep"), ewt.features))
    for dataset in feature_datasets:
        if dataset == 'programming_lang_id.pyth.512.-1' or dataset == 'natural_lang_id.pyth.512.-1' or dataset == 'distribution_id.pyth.512.-1':
            average_pos.append(round(1/9, 2))
            average_neg.append(round(8/9, 2))
        elif dataset == 'text_features.pyth.256.10000' or dataset == 'latex.pyth.1024.-1':
            fds = load_feature_dataset(dataset)
            feature_names = [name for name in fds.features if '|probe_classes' in name]
            total_positive = 0
            total_negative = 0
            total_count = 0
            for feature in feature_names:
                concat_tensor = torch.cat(fds[feature])
                count_positive = (concat_tensor == 1).sum().item()
                total_positive += count_positive
                count_negative = (concat_tensor == -1).sum().item()
                total_negative += count_negative
                total_count += (concat_tensor == -1).sum().item() + count_positive
                avg_pos = total_positive / total_count
                avg_neg = total_negative / total_count
            average_pos.append(round(avg_pos, 2))
            average_neg.append(round(avg_neg, 2))
        elif dataset == 'counterfact.pyth.64.-1':
            average_pos.append(0)
            average_neg.append(0)
        elif dataset == 'ewt.pyth.512.-1':
            fds = load_feature_dataset(dataset)
            for features, feature_name in [(set(upos), 'upos'), (set(dep), 'dep'), (set(other), 'other')]:
                feature_names = [name for name in features if '|probe_classes' in name]
                total_positive = 0
                total_negative = 0 
                total_count = 0
                for feature in feature_names:
                    concat_tensor = torch.cat(fds[feature])
                    count_positive = (concat_tensor == 1).sum().item()
                    total_positive += count_positive
                    count_negative = (concat_tensor == -1).sum().item()
                    total_negative += count_negative
                    total_count += (concat_tensor == -1).sum().item() + count_positive
                avg_pos = total_positive / total_count
                avg_neg = total_negative / total_count
                average_pos.append(round(avg_pos, 1))
                average_neg.append(round(avg_neg, 1))
        elif dataset == 'compound_words.pyth.24.-1':
            fds = load_feature_dataset('compound_words.pyth.24.-1')
            avg_pos = 33559 / len(fds)
            avg_neg = 1 - avg_pos
            average_pos.append(round(avg_pos, 1))
            average_neg.append(round(avg_neg, 1))
    return average_pos, average_neg


In [17]:
compute_avg_positives()

([0.11, 0.11, 0.26, 0, 0.11, 0.2, 0.2, 0.2, 0.2, 0.26],
 [0.89, 0.89, 0.74, 0, 0.89, 0.8, 0.8, 0.8, 0.8, 0.74])

In [95]:
def get_dataset_info(feature_datasets, dataset_sources):
    shapes = []
    num_tokens_list = []
    sources = []
    col = []
    features_dict = {}
    avg_positives = []
    
    # info for all the datasets 
    for dataset in feature_datasets:
        _fds = load_feature_dataset(dataset)
        shapes.append(_fds.shape)
        col.append(_fds['tokens'].shape[1])
        num_tokens = (_fds['tokens'] > 1).numpy().sum()
        num_tokens_list.append(num_tokens)
        sources.append(dataset_sources[dataset])
        features_dict[dataset] = get_features(dataset)
        average_pos, average_neg = compute_avg_positives()

    # creating a new list to contain all the data bc I did all the ewt stuff separately 
    dataset_data = []

    # ewt 
    ewt = load_feature_dataset('ewt.pyth.512.-1')
    upos = filter_by_prefix(ewt.features, 'upos')
    dep = filter_by_prefix(ewt.features, 'dep')
    other = list(filter(lambda x: not x.startswith("upos") and not x.startswith("dep") and x.endswith('probe_indices'), ewt.features))
    # defined i so I can match with the correct value in average_list 
    i = 5
    for features, feature_name in [(set(upos), 'upos'), (set(dep), 'dep'), (set(other), 'other')]:
        i = i + 1
        features_list = set()
        for feature in features:
            features_list.add(feature.split('|')[0])
        dataset_data.append({
            'Dataset': f'ewt.pyth.512.-1 ({feature_name})',
            'Sequences': ewt.shape[0],
            'Context Length': ewt['tokens'].shape[1],
            'Non padding tokens': (ewt['tokens'] > 1).numpy().sum(),
            'Total Features': len(set(features)),
            'Average Class Balance': average_pos[i],
            #'Average Number of Negative Examples' : average_neg[i],
            'Features': unabbreviate_languages([item.split('|')[0] for item in features_list if item not in ['all_tokens', 'tokens', 'meta', 'text']]),
            'Source': dataset_sources[feature_datasets[5]]
        })
        
    # excluding values from ewt (probably a better way to do this) so the rest of the averages get mapped to the right dataset    
    del average_pos[5:8]
    del average_neg[5:8]
    # defined j so I can match with the correct value in average_list
    j = -1
    # appends everything but ewt to the larger dataset
    for i, dataset in enumerate(feature_datasets):
        if i == 5: continue 
        j = j + 1
        fds = load_feature_dataset(dataset)
        dataset_data.append({
            'Dataset': dataset,
            'Sequences': fds.shape[0],
            'Context Length': fds['tokens'].shape[1],
            'Non padding tokens': (fds['tokens'] > 1).numpy().sum(),
            'Total Features': len(get_features(dataset).split(', ')),
            'Average Class Balance': average_pos[j],
            #'Average Number of Negative Examples' : average_neg[j],
            'Features': get_features(dataset),
            'Source': dataset_sources[dataset]
        })

    dataset_df = pd.DataFrame(dataset_data, columns=['Dataset', 'Sequences', 'Context Length', 'Non padding tokens', 'Source', 'Average Class Balance', 'Total Features'])
    return dataset_df


In [96]:
dataset_df = get_dataset_info(feature_datasets, dataset_sources)



In [97]:
pd.set_option('max_colwidth', None)
dataset_df

Unnamed: 0,Dataset,Sequences,Context Length,Non padding tokens,Source,Average Class Balance,Total Features
0,ewt.pyth.512.-1 (upos),1438,512,281044,EWT,0.2,16
1,ewt.pyth.512.-1 (dep),1438,512,281044,EWT,0.2,29
2,ewt.pyth.512.-1 (other),1438,512,281044,EWT,0.2,22
3,programming_lang_id.pyth.512.-1,5397,512,2757867,pile-github,0.11,9
4,natural_lang_id.pyth.512.-1,28084,512,14350924,pile-europarl,0.11,9
5,text_features.pyth.256.10000,10000,256,2248714,pile-test-all,0.26,11
6,counterfact.pyth.64.-1,43820,64,403331,pile-test-all,0.0,34
7,distribution_id.pyth.512.-1,8413,512,4299043,pile-test-all,0.11,9
8,compound_words.pyth.24.-1,167959,24,4031016,pile-test-all,0.2,21
9,latex.pyth.1024.-1,4486,1024,4589178,pile-arxiv,0.26,12


In [98]:
dataset_latex_table = dataset_df.to_latex(index=False)
print(dataset_latex_table)

\begin{tabular}{lrrrlrr}
\toprule
                        Dataset &  Sequences &  Context Length &  Non padding tokens &        Source &  Average Class Balance &  Total Features \\
\midrule
         ewt.pyth.512.-1 (upos) &       1438 &             512 &              281044 &           EWT &                   0.20 &              16 \\
          ewt.pyth.512.-1 (dep) &       1438 &             512 &              281044 &           EWT &                   0.20 &              29 \\
        ewt.pyth.512.-1 (other) &       1438 &             512 &              281044 &           EWT &                   0.20 &              22 \\
programming\_lang\_id.pyth.512.-1 &       5397 &             512 &             2757867 &   pile-github &                   0.11 &               9 \\
    natural\_lang\_id.pyth.512.-1 &      28084 &             512 &            14350924 & pile-europarl &                   0.11 &               9 \\
   text\_features.pyth.256.10000 &      10000 &             256 &      

  dataset_latex_table = dataset_df.to_latex(index=False)


In [112]:
# lowkey code duplication, I j copied the other table but selected different columns and deleted the columns we didn't need for this table 
def get_dataset_features(feature_datasets):
    features_dict = {}
    
    for dataset in feature_datasets:
        _fds = load_feature_dataset(dataset)
        features_dict[dataset] = get_features(dataset)

    dataset_data = []

    ewt = load_feature_dataset('ewt.pyth.512.-1')
    upos = filter_by_prefix(ewt.features, 'upos')
    dep = filter_by_prefix(ewt.features, 'dep')
    other = list(filter(lambda x: not x.startswith("upos") and not x.startswith("dep") and x.endswith('probe_indices'), ewt.features))
    i = 5
    for features, feature_name in [(set(upos), 'upos'), (set(dep), 'dep'), (set(other), 'other')]:
        i = i + 1
        features_list = set()
        for feature in features:
            features_list.add(feature.split('|')[0])
        dataset_data.append({
            'Dataset': f'ewt.pyth.512.-1 ({feature_name})',
            'Features': unabbreviate_languages([item.split('|')[0] for item in features_list if item not in ['all_tokens', 'tokens', 'meta', 'text']])
        })
        
    j = -1
    for i, dataset in enumerate(feature_datasets):
        if i == 5: continue 
        j = j + 1
        fds = load_feature_dataset(dataset)
        dataset_data.append({
            'Dataset': dataset,
            'Features': get_features(dataset)
        })

    dataset_df = pd.DataFrame(dataset_data, columns=['Dataset', 'Features'])
    return dataset_df

In [113]:
feature_names = get_dataset_features(feature_datasets)

In [114]:
feature_names['Dataset'] = ['part of speech', 'dependencies', 'morphology', 'code language', 'natural language', 'text features', 'counterfact', 'datasubset', 'compound words', 'latex' ]

In [115]:
feature_names = feature_names.drop(6)

In [110]:
pd.set_option('max_colwidth', None)
feature_names

Unnamed: 0,Dataset,Features
0,part of speech,"upos_AUX, upos_ADP, upos_VERB, upos_ADJ, upos_X, upos_CCONJ, upos_PROPN, upos_NOUN, upos_INTJ, upos_SYM, upos_PRON, upos_DET, upos_SCONJ, upos_ADV, upos_PUNC, upos_NUM"
1,dependencies,"dep_aux:pass, dep_acl:relcl, dep_nsubj, dep_xcomp, dep_flat, dep_cc, dep_mark, dep_acl, dep_ccomp, dep_appos, dep_root, dep_nmod:poss, dep_aux, dep_amod, dep_nsubj:pass, dep_obj, dep_obl, dep_det, dep_advmod, dep_punct, dep_parataxis, dep_conj, dep_case, dep_list, dep_advcl, dep_cop, dep_compound, dep_nummod, dep_nmod"
2,morphology,"eos_True, Person_2, Gender_Fem, VerbForm_Inf, PronType_Dem, Gender_Masc, first_eos_True, Gender_Neut, VerbForm_Part, NumType_Card, PronType_Int, PronType_Prs, Person_3, Tense_Past, Number_Plur, PronType_Art, Voice_Pass, PronType_Rel, VerbForm_Ger, Mood_Imp, Person_1, VerbForm_Fin"
3,code language,"Python, XML, Java, C++, HTML, C, Go, PHP, JavaScript"
4,natural language,"Swedish, Portuguese, German, English, French, Spanish, Greek, Italian, Dutch"
5,text features,"leading_capital, no_leading_space_and_loweralpha, all_digits, is_not_ascii, has_leading_space, contains_all_whitespace, all_capitals, is_not_alphanumeric, contains_whitespace, contains_capital, contains_digit"
7,datasubset,"github, pubmed_abstracts, stack_exchange, wikipedia, freelaw, hackernews, arxiv, enron, uspto"
8,compound words,"mental-health, magnetic-field, trial-court, control-group, human-rights, north-america, clinical-trials, high-school, third-party, public-health, cell-lines, living-room, second-derivative, credit-card, social-media, prime-factors, federal-government, social-security, blood-pressure, gene-expression, side-effects"
9,latex,"is_superscript, is_inline_math, is_title, is_subscript, is_reference, is_denominator, is_author, is_numerator, is_display_math, is_math, is_abstract, is_frac"


In [117]:
feature_latex_table = feature_names.to_latex(index=False)
print(feature_latex_table)

\begin{tabular}{ll}
\toprule
         Dataset &                                                                                                                                                                                                                                                                                                                        Features \\
\midrule
  part of speech &                                                                                                                                                         upos\_AUX, upos\_ADP, upos\_VERB, upos\_ADJ, upos\_X, upos\_CCONJ, upos\_PROPN, upos\_NOUN, upos\_INTJ, upos\_SYM, upos\_PRON, upos\_DET, upos\_SCONJ, upos\_ADV, upos\_PUNC, upos\_NUM \\
    dependencies & dep\_aux:pass, dep\_acl:relcl, dep\_nsubj, dep\_xcomp, dep\_flat, dep\_cc, dep\_mark, dep\_acl, dep\_ccomp, dep\_appos, dep\_root, dep\_nmod:poss, dep\_aux, dep\_amod, dep\_nsubj:pass, dep\_obj, dep\_obl, dep\_det, dep\_advmod, dep\_punct, dep

  feature_latex_table = feature_names.to_latex(index=False)


In [7]:
ewt = load_feature_dataset('ewt.pyth.512.-1')
ewt



Dataset({
    features: ['Number', 'Mood', 'Tense', 'VerbForm', 'PronType', 'Person', 'NumType', 'Voice', 'Gender', 'eos', 'first_eos', 'upos', 'dep', 'head', 'within_compound_token_ix', 'max_compound_token_ix', 'tokens', 'doc_id', 'split', 'position', 'upos_NOUN|probe_indices', 'upos_NOUN|probe_classes', 'upos_PUNC|probe_indices', 'upos_PUNC|probe_classes', 'upos_ADP|probe_indices', 'upos_ADP|probe_classes', 'upos_NUM|probe_indices', 'upos_NUM|probe_classes', 'upos_SYM|probe_indices', 'upos_SYM|probe_classes', 'upos_SCONJ|probe_indices', 'upos_SCONJ|probe_classes', 'upos_ADJ|probe_indices', 'upos_ADJ|probe_classes', 'upos_DET|probe_indices', 'upos_DET|probe_classes', 'upos_CCONJ|probe_indices', 'upos_CCONJ|probe_classes', 'upos_PROPN|probe_indices', 'upos_PROPN|probe_classes', 'upos_PRON|probe_indices', 'upos_PRON|probe_classes', 'upos_X|probe_indices', 'upos_X|probe_classes', 'upos_ADV|probe_indices', 'upos_ADV|probe_classes', 'upos_INTJ|probe_indices', 'upos_INTJ|probe_classes', 'up

In [169]:
fds = ewt
features = []
for feature in fds.features:
    if feature.startswith('dep') and feature.endswith('probe_indices'):
        features.append(feature)

In [170]:
len(features)

59

In [131]:
len(filter_by_prefix(ewt.features, 'dep'))

59

In [176]:
list(filter(lambda x: not x.startswith("upos") and not x.startswith("dep") and x.endswith('probe_indices'), ewt.features))

['VerbForm_Fin|probe_indices',
 'VerbForm_Inf|probe_indices',
 'VerbForm_Ger|probe_indices',
 'VerbForm_Part|probe_indices',
 'PronType_Art|probe_indices',
 'PronType_Dem|probe_indices',
 'PronType_Prs|probe_indices',
 'PronType_Rel|probe_indices',
 'PronType_Int|probe_indices',
 'Person_1|probe_indices',
 'Person_2|probe_indices',
 'Person_3|probe_indices',
 'Gender_Masc|probe_indices',
 'Gender_Fem|probe_indices',
 'Gender_Neut|probe_indices',
 'Number_Plur|probe_indices',
 'Mood_Imp|probe_indices',
 'Tense_Past|probe_indices',
 'NumType_Card|probe_indices',
 'Voice_Pass|probe_indices',
 'eos_True|probe_indices',
 'first_eos_True|probe_indices']

In [606]:
dataset_df['Features'][1]

'Greek, German, French, Dutch, Spanish, Swedish, English, Portuguese, Italian'

In [None]:
np.unique(np.array(feature_dataset['lang']), return_counts=True)

In [679]:
load_feature_dataset('programming_lang_id.pyth.512.-1')['lang_prob']

tensor([0.8080, 1.0000, 1.0000,  ..., 0.9860, 1.0000, 0.9743])

In [130]:
load_feature_dataset('counterfact.pyth.64.-1')



1.0

In [75]:
feature_dataset = load_feature_dataset('counterfact.pyth.64.-1')

In [129]:
np.unique(np.array(feature_dataset['relation_id']), return_counts=True)[1]

array([1090, 1838, 1640,  700,  866, 1026, 1428, 1694,  558,  854, 1848,
       1512, 1750, 1822, 1156, 1558, 1238, 1632,  106, 1916, 1226, 1918,
        278, 1502, 1782,  952,  432, 1904, 1588,  326, 1808,  632, 1548,
       1692])

In [665]:
(fds['lang_prob']).mean().item()

0.9853899478912354

In [629]:
feature_names = [name for name in fds.features if '|probe_classes' in name]
feature_names

['is_frac|probe_classes',
 'is_numerator|probe_classes',
 'is_denominator|probe_classes',
 'is_title|probe_classes',
 'is_abstract|probe_classes',
 'is_author|probe_classes',
 'is_subscript|probe_classes',
 'is_superscript|probe_classes',
 'is_reference|probe_classes',
 'is_math|probe_classes',
 'is_inline_math|probe_classes',
 'is_display_math|probe_classes']

In [120]:
len(set(load_feature_dataset(feature_datasets[4])['relation_id']))

34

In [131]:
for dataset in feature_datasets:
    loaded_dataset = load_feature_dataset(dataset)
    print(loaded_dataset)

Dataset({
    features: ['text', 'meta', 'lang_prob', 'lang', 'all_tokens', 'tokens', 'class_ids', 'probe_indices', 'valid_indices'],
    num_rows: 5397
})
Dataset({
    features: ['lang', 'tokens', 'class_ids', 'probe_indices', 'valid_indices'],
    num_rows: 28084
})
Dataset({
    features: ['tokens', 'distribution', 'abs_pos', 'norm_abs_pos', 'rel_pos', 'norm_rel_pos', 'log_pos', 'index_mask'],
    num_rows: 10000
})
Dataset({
    features: ['text', 'meta', 'all_tokens', 'tokens', 'contains_digit|probe_indices', 'contains_digit|probe_classes', 'all_digits|probe_indices', 'all_digits|probe_classes', 'contains_capital|probe_indices', 'contains_capital|probe_classes', 'leading_capital|probe_indices', 'leading_capital|probe_classes', 'all_capitals|probe_indices', 'all_capitals|probe_classes', 'contains_whitespace|probe_indices', 'contains_whitespace|probe_classes', 'has_leading_space|probe_indices', 'has_leading_space|probe_classes', 'no_leading_space_and_loweralpha|probe_indices', 'no_

for dataset in feature_datasets:
    fds = load_feature_dataset(dataset)
    print(dataset, fds.shape)

In [346]:
set(load_feature_dataset(feature_datasets[5])['distribution'])

{'arxiv',
 'enron',
 'freelaw',
 'github',
 'hackernews',
 'pubmed_abstracts',
 'stack_exchange',
 'uspto',
 'wikipedia'}

In [None]:
fds = load_feature_dataset(feature_datasets[0])

In [62]:
fds

Dataset({
    features: ['lang', 'tokens', 'class_ids', 'probe_indices', 'valid_indices'],
    num_rows: 28084
})

In [None]:
fds['text']

In [12]:
set(fds['lang'])

{'C', 'C++', 'Go', 'HTML', 'Java', 'JavaScript', 'PHP', 'Python', 'XML'}

In [13]:
fds['tokens'].shape

torch.Size([5397, 512])

In [14]:
(fds['tokens'] > 1).numpy().sum()

2757867