In [17]:

import os
from os.path import join, exists

import pandas as pd
import glob

from utils import load_splits, split_gen
from utils_child import child_models

import config

## Useful functions

In [33]:
def check_chi_cgv_present(this_df):
    """Print the speaker codes in ``this_df`` and assert that child and caregiver rows exist.

    Rows whose ``speaker_code`` is 'CHI' are counted as child rows; rows whose
    ``speaker_code`` is 'FAT' or 'MOT' are counted as caregiver (CGV) rows.
    Both counts are printed after the check.

    Args:
        this_df: DataFrame with a ``speaker_code`` column.

    Raises:
        AssertionError: if either the child count or the caregiver count is zero.
    """
    speaker_codes = this_df.speaker_code
    print(set(speaker_codes))

    # 7/23/21: counting by boolean-mask filtering, following
    # https://stackoverflow.com/questions/35277075/python-pandas-counting-the-occurrences-of-a-specific-value
    child_rows = this_df[speaker_codes == 'CHI']
    caregiver_rows = this_df[speaker_codes.isin({'FAT', 'MOT'})]
    chi_num = len(child_rows)
    cgv_num = len(caregiver_rows)

    # Both speaker groups have to appear at least once.
    assert not (chi_num <= 0 or cgv_num <= 0)
    print('Child tags:', chi_num, 'CGV tags:', cgv_num)

        
def load_marked_pooled_data(this_split_path):
    """Load the phase-marked data pool stored in a split folder.

    Reads ``data_pool_with_phases.pkl`` from ``this_split_path`` with
    ``pd.read_pickle`` and returns the unpickled object.
    """
    pool_path = join(this_split_path, 'data_pool_with_phases.pkl')
    pooled_data = pd.read_pickle(pool_path)
    return pooled_data

def check_disjoint_and_phase_written(split, name, base_dir):
    """
    Checks that the phase data indicated per entry corresponds to the written text in the file.
    By nature of the phase marks the data pool will be split disjointly.

    For each of the 'train' and 'val' phases, the sorted, whitespace-stripped lines of
    "<phase>.txt" in the split folder are compared with the sorted 'gloss_with_punct'
    values of the pool rows marked with that phase.

    Args:
        split: split type passed to split_gen.get_split_folder. The value 'child'
            additionally triggers the 'yyy' handling below.
        name: dataset name passed to split_gen.get_split_folder.
        base_dir: base directory passed to split_gen.get_split_folder.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if a phase text file does not match the pool rows for that phase,
            or (for split == 'child') if a 'yyy' row is marked with a phase other than 'val'.
    """

    this_split_loc = split_gen.get_split_folder(split, name, base_dir)
    this_pool_data = load_marked_pooled_data(this_split_loc)

    if split == 'child':
        # NOTE(review): `data_cleaning` is not imported in the notebook's visible import cell,
        # so this branch raises NameError unless the name is provided elsewhere -- confirm.
        # NOTE(review): the return value is discarded, so this presumably mutates the frame
        # in place -- verify against data_cleaning.drop_errors.
        data_cleaning.drop_errors(this_pool_data) # Don't consider the yyy, which are not written to the text files.

        # Comparing a Series with a string is elementwise; it has to be reduced with .all()
        # before being asserted, otherwise pandas raises ValueError (ambiguous truth value).
        yyy_phases = this_pool_data[this_pool_data.gloss == 'yyy'].phase
        assert (yyy_phases == 'val').all(), \
            f"Found 'yyy' entries outside of the val phase for: {split}, {name}"

    for phase in ['train', 'val']:
        phase_locs = this_pool_data[this_pool_data['phase'] == phase]
        with open(join(this_split_loc, f"{phase}.txt"), 'r') as f:
            # strip() gets rid of the trailing \n (and any other surrounding whitespace).
            from_text_text = sorted(line.strip() for line in f)

        from_df_text = sorted(phase_locs['gloss_with_punct'])

        # Sorted comparison: order within a phase is irrelevant, content and multiplicity are not.
        assert from_text_text == from_df_text, f'Failed to match phase data for: {split}, {name}, {phase}'

    print(f'Assert passed for {split}, {name}')
    return True

## Checks

In [37]:
# Phase labels iterated by the per-child separation check further down.
phases = ['train', 'val', 'eval']

# Child names as reported by child_models.
child_names = child_models.get_child_names()

# Pickled table from the eval directory; the checks below look rows up by `id`
# and read their `phase` / `phase_child` columns.
all_phono_path = join(config.eval_dir, 'pvd_all_tokens_phono_for_eval.pkl')
all_phono = pd.read_pickle(all_phono_path)

In [47]:

# Are the training/text files disjoint for the non-Providence splits?
# (This was already written so might as well)

# (split, name) pairs to verify, in the order they are reported.
non_providence_splits = [('all', 'all'), ('age', 'old'), ('age', 'young')]

for split_type, dataset_name in non_providence_splits:
    check_disjoint_and_phase_written(split_type, dataset_name, config.data_dir)

print('Asserts passed.')


Assert passed for all, all
Assert passed for age, old
Assert passed for age, young
Asserts passed.


In [35]:
# Make sure that [CHI], [CGV] are present in the model inputs:
# every (split, dataset) pair in config.childes_model_args is loaded and checked.

for split_type, dataset_name in config.childes_model_args:
    folder = split_gen.get_split_folder(split_type, dataset_name, config.data_dir)
    this_df = load_marked_pooled_data(folder)
    check_chi_cgv_present(this_df)


{'FAT', 'CHI', 'MOT'}
Child tags: 1640520 CGV tags: 2319432
{'FAT', 'CHI', 'MOT'}
Child tags: 985752 CGV tags: 1560943
{'FAT', 'CHI', 'MOT'}
Child tags: 618718 CGV tags: 639542


In [44]:
# Make sure that all of the eval and val data are separate for across_time_samples
# It's not disjoint!

all_time_samples = glob.glob(join(config.eval_dir, 'across_time_samples/*'))

# Classify each sample file by its file name only. Matching on the full path would also
# match a '_val' / '_eval' that happens to occur in a directory component of config.eval_dir,
# which would put every file into that group.
val_sample_paths = [path for path in all_time_samples if '_val' in os.path.basename(path)]
eval_sample_paths = [path for path in all_time_samples if '_eval' in os.path.basename(path)]

# pd.concat raises an unhelpful "No objects to concatenate" on an empty list; fail clearly instead.
assert val_sample_paths, 'No across_time_samples files with "_val" in their name were found'
assert eval_sample_paths, 'No across_time_samples files with "_eval" in their name were found'

val_ids = set(pd.concat([pd.read_csv(path) for path in val_sample_paths]).utterance_id)
eval_ids = set(pd.concat([pd.read_csv(path) for path in eval_sample_paths]).utterance_id)

# Phases recorded in all_phono for the sampled utterance ids.
val_phases = set(all_phono[all_phono.id.isin(val_ids)].phase)
eval_phases = set(all_phono[all_phono.id.isin(eval_ids)].phase)

print(val_phases)
print(eval_phases)

assert val_phases == {'val'}, f'val samples contain utterances from phases: {val_phases}'
assert eval_phases == {'eval'}, f'eval samples contain utterances from phases: {eval_phases}'


{'val', 'eval'}
{'val', 'eval'}


AssertionError: 

In [45]:
# Make sure that beta is on val only.

# One beta sample file per (split, dataset) pair, located in that pair's eval split folder.
beta_csv_paths = [
    join(split_gen.get_split_folder(s, d, config.eval_dir), 'success_utts_beta_5000_val.csv')
    for s, d in config.childes_model_args
]
beta_utts = pd.concat([pd.read_csv(csv_path) for csv_path in beta_csv_paths])
beta_ids = set(beta_utts.utterance_id)

# Phases recorded in all_phono for the beta utterance ids.
beta_rows = all_phono[all_phono.id.isin(beta_ids)]
beta_phases = set(beta_rows.phase)
print(beta_phases)

assert beta_phases == {'val'}


{'val', 'eval'}


AssertionError: 

In [59]:
# Make sure the child train/val/eval data is separate

for name in child_names:

    child_pool = all_phono[all_phono.target_child_name == name]

    # Utterance ids of this child, grouped by the phase encoded in `phase_child`.
    ids = {
        phase: set(child_pool[child_pool.phase_child == f"{name}_{phase}"].id)
        for phase in phases
    }

    # Intersection is symmetric, so each unordered pair of phases is checked once.
    for first, p1 in enumerate(phases):
        for p2 in phases[first + 1:]:
            assert not (ids[p1] & ids[p2])

    # Make sure the val/eval transcript is from the right partition:
    # every id of the child's phase must be among the sampled ids of that phase.
    sampled_by_phase = {'val': val_ids, 'eval': eval_ids}
    for phase, id_set in sampled_by_phase.items():
        assert ids[phase] <= id_set

print('Asserts passed.')
        

{16818177, 16818180, 16818182, 16818185, 16820234, 16818188, 16818191, 16818194, 16818198, 16818201, 16820250, 16818204, 16818207, 16818211, 16818214, 16818219, 16818222, 16818225, 16820275, 16818229, 16818232, 16820283, 16818236, 16818238, 16818242, 16818246, 16818249, 16818253, 16820301, 16818255, 16820312, 16818272, 16818277, 16818280, 16818283, 16820331, 16818287, 16818290, 16818293, 16818296, 16820346, 16818299, 16818302, 16818305, 16818308, 16818311, 16818313, 16818316, 16818319, 16820369, 16818322, 16818325, 16818328, 16818331, 16818335, 16818338, 16820386, 16818340, 16818344, 16818347, 16818350, 16818352, 16818355, 16820403, 16818358, 16818361, 16818364, 16818366, 16818369, 16818372, 16820421, 16818375, 16818378, 16818381, 16818384, 16818388, 16820438, 16818392, 16818395, 16818398, 16818401, 16818404, 16820454, 16818407, 16818410, 16818412, 16818415, 16818418, 16818421, 16818423, 16820471, 16818426, 16818429, 16818431, 16818434, 16818437, 16820486, 16818439, 16818442, 16818444,

AssertionError: 

In [51]:
# Check that sampling is actually restrictive (i.e. only successes and yyy are sampled.)

all_sampled_ids = val_ids.union(eval_ids, beta_ids)
sampled = all_phono[all_phono.id.isin(all_sampled_ids)]

# Per-row flags of the sampled utterances.
is_success = sampled.success_token
is_yyy = sampled.yyy_token

# Every sampled row has to be flagged as a success or as yyy.
assert all(is_success | is_yyy)

print('Passed')


Passed
