In [1]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

# hacky way to import functions from analysis/compute_properties/helpers.py
sys.path.insert(1, '../compute_properties/')
from helpers import unpipe, pipe, rm_non_alphanum_from_str

# ignore warning
import warnings
warnings.filterwarnings('ignore')

# Convert sample-level to turn-level dataframes

Below is the list of columns the generated turn-level dataframe will have.

| Column | Description | Type |
| ------ | ----------- | ---- |
| **MAIN DATA** |||
| dia_id | ID of original corpus dialogue. | int |
| sample_id | Sample index. | int |
| gen_id | Model response index. | int |
| curr_turn_idx_in_orig | Response utterance index in original dialogue. | int |
| dist_from_prev_turn | Distance (of response) from previous utterance in context. | int |
| prev_turn_txt | Text of previous utterance. | str |
| curr_turn_txt | Text of current (response) utterance. | str |
| **LABELS** |||
| speaker_curr | Speaker label of current turn. | str |
| speaker_prev | Speaker label of previous turn. | str |
| is_same_speaker | Whether current and previous utterances were produced by the same speaker. | bool |
| **OVERLAPS** | | |
| vocab_overlap | VO. | float |
| constr_overlap | CO. | float |
| constr_overlap_h | CO with human response. | float |
| vocab_overlap_h | VO with human response. | float |
| abs_diff_m_h_co | Absolute difference of model and human CO. | float |
| abs_diff_m_h_vo | Absolute difference of model and human VO. | float |
| **ATTRIBUTIONS** |||
| attrib_to_prev | Attribution to previous utterance. | float |
| attrib_to_curr | Attribution to current utterance (response). | float |
| attrib_to_curr_comp | Attribution to current utterance while comprehending human response. | float |
| attrib_to_prev_comp | Attribution to previous utterance while comprehending human response. | float |
| mean_attrib | Mean attribution to utterances in the context. | float |
| mean_attrib_comp | Mean attribution to utterances in the context (while comprehending). | float |
| s_attrib_to_curr | Attribution to current utt. speaker label. | float |
| s_attrib_to_curr_comp | Comprehension attribution to current utt. speaker label. | float |
| s_attrib_to_prev | Attrib. to previous utt. speaker label. | float |
| s_attrib_to_prev_comp | Comprehension attrib. to previous utt. speaker label. | float |
| s_mean_attrib | Mean attribution to speaker labels in the context. | float |
| s_mean_attrib_comp | Mean attribution to speaker labels in the context (while comprehending). | float |
| **CONSTRUCTIONS** |||
| constrs | List of constructions. | pipe-separated strs |
| constr_lens | Lengths of constructions (in words). | pipe-separated ints |
| constr_freqs | Occurence frequencies of constructions. | pipe-separated ints |
| avg_constr_len | Mean construction length (in words). | float |
| pmi | PMI of constructions. | pipe-separated floats |
| pmi_avg | Mean PMI. | float |
| repeated_words | List of repeated words between the two utterances. | pipe-separated strs |
| non_rep_constrs | List of non-repeated constructions. | |
| n_words_not_in_ctx | Number of words in the response that are not in the context. | |
| **GENERATION QUALITY** |||
| bert_precision | BERTScore precision. | float |
| bert_recall | BERTScore recall. | float |
| bert_f1 | BERTScore F1. | float |
| bleu | BLEU score. | float |
| bleu_brevity_penalty | BLEU brevity penalty. | float |
| bleu_length_ratio | BLEU length ratio. | float |
| bleu_translation_length | BLEU translation length. | float |
| bleu_reference_length | BLEU reference length. | float |
| **PERPLEXITIES** |||
| ppl_gpt2_model_trg_dectx | PPL (perplexity) of GPT-2 on current utterance (without context). | float |
| ppl_gpt2_human_trg_dectx | PPL of GPT-2 on the human-produced response (without context). | float |
| abs_diff_ppl_gpt2_dectx | Absolute difference between the above two columns. | float |
| ppl_gpt2_model_trg_ctx | PPL of GPT-2 on current utterance (with context). | float |
| ppl_gpt2_human_trg_ctx | PPL of GPT-2 on human response (with context). | float |
| abs_diff_ppl_gpt2_ctx | Absolute difference between the above two columns. | float |
| ppl_pythia_model_trg_dectx | See above. Model is Pythia-1.4b instead of GPT-2 | float |
| ppl_pythia_human_trg_dectx | ^ | float |
| abs_diff_ppl_pythia_dectx | ^ | float |
| ppl_pythia_model_trg_ctx | ^ | float |
| ppl_pythia_human_trg_ctx | ^ | float |
| abs_diff_ppl_pythia_ctx | ^ | float |

In [2]:
def convert_sample_2_turn_level(
        sample_level_file: str,
        turn_level_file: str = None,
        n_model_responses: int = 5,
        save_every: int = 50,
) -> pd.DataFrame:
    """
    Convert a sample-level dataframe to a turn-level dataframe.

    :param sample_level_file: str
        Path to the sample-level dataframe (input).
    :param turn_level_file: str
        Path to the turn-level dataframe (output).
        If not specified, the created dataframe will be returned.
    :return: pd.DataFrame
        Turn-level dataframe.
    """
    samples = pd.read_csv(
        sample_level_file,
        sep='\t',
        header=0,
        index_col='sample_index',
    )

    ppl_models = set([
        col.split('_')[-1].strip().lower()
        for col in samples.columns
        if col.startswith('ppl_')
    ])

    turn_cols = [
        # ---------------------
        # MAIN COLS
        # ---------------------
        # 'dia_id',
        'sample_id',
        'gen_id',
        'curr_turn_idx_in_orig',
        'dist_from_prev_turn',
        'prev_turn_txt',
        'curr_turn_txt',
        'human_resp_txt',  # NEW
        # ---------------------
        # LABELS
        # ---------------------
        'speaker_curr',
        'speaker_prev',
        'is_same_speaker',
        # ---------------------
        # OVERLAPS
        # ---------------------
        'vocab_overlap',
        'constr_overlap',
        'constr_overlap_h',  # CO with human response
        'vocab_overlap_h'  # VO with human response
        'abs_diff_m_h_co',  # |CO - CO_h|
        'abs_diff_m_h_vo',  # |VO - VO_h|
        # ---------------------
        # ATTRIBUTIONS
        # ---------------------
        'attrib_to_prev',
        'attrib_to_curr',
        'attrib_to_curr_comp',  # attrib_to_curr while comprehending human response
        'attrib_to_prev_comp',
        'mean_attrib',
        'mean_attrib_comp',
        's_attrib_to_curr'  # attrib to [s]peaker label
        's_attrib_to_curr_comp',
        's_attrib_to_prev',
        's_attrib_to_prev_comp',
        's_mean_attrib',
        's_mean_attrib_comp',
        # ---------------------
        # CONSTRUCTIONS
        # ---------------------
        'constrs',
        'constr_lens',
        'constr_freqs',
        'avg_constr_len',
        'pmi',
        'pmi_avg',
        'repeated_words',
        # 'non_rep_constrs',  # ???
        'n_words_not_in_ctx',
        # ---------------------
        # GENERATION QUALITY
        # ---------------------
        # 'bert_precision',
        # 'bert_recall',
        'bert_f1',
        'bleu',
        'bleu_brevity_penalty',
        'bleu_length_ratio',
        # 'bleu_translation_length',
        # 'bleu_reference_length',
        'mauve',  # NEW, corpus-level
    ]
    # ---------------------
    # PERPLEXITIES
    # ---------------------
    if 'gpt2' in ppl_models:
        turn_cols += [
            # GPT-2:
            'ppl_gpt2_model_trg_dectx',
            'ppl_gpt2_human_trg_dectx',
            'abs_diff_ppl_gpt2_dectx',  # |ppl_gpt2_model_trg_dectx - ppl_gpt2_human_trg_dectx|
            'ppl_gpt2_model_trg_ctx',
            'ppl_gpt2_human_trg_ctx',
            'abs_diff_ppl_gpt2_ctx',  # |ppl_gpt2_model_trg_ctx - ppl_gpt2_human_trg_ctx|
        ]
    if 'pythia' in ppl_models:
        turn_cols += [
            # Pythia-1.4b:
            'ppl_pythia_model_trg_dectx',
            'ppl_pythia_human_trg_dectx',
            'abs_diff_ppl_pythia_dectx',  # |ppl_pythia_model_trg_dectx - ppl_pythia_human_trg_dectx|
            'ppl_pythia_model_trg_ctx',
            'ppl_pythia_human_trg_ctx',
            'abs_diff_ppl_pythia_ctx',  # |ppl_pythia_model_trg_ctx - ppl_pythia_human_trg_ctx|
        ]
    turns = pd.DataFrame(columns=turn_cols)

    def save_turn_level_to_disk():
        turns.to_csv(
            turn_level_file, 
            sep='\t', 
            header=True, 
            index=True,
            index_label='index',
        )

    """
    - Each row in the sample-level dataframe will correspond to
      n_utts_context * n_model_responses rows in the turn-level dataframe.
      That is 9 * 5 = 45 rows if the context length is 10 and the number of
      model responses is 5.
    - We pair each response with all the utterance of the context.
    """
    for sample_index, sample in tqdm(
        samples.iterrows(), 
        total=len(samples)
    ):
        # prep excerpt utterances
        context = sample['prompt'].strip()
        context = [utt.strip() for utt in context.split('\\n')]
        context = [{'label': utt[:1].strip(), 'utt': utt[2:].strip()} for utt in context]
        n_utts_ctx = len(context)
        # prep human response
        human_response = sample['human_response'].strip()
        human_response = {'label': human_response[:1].strip(), 'utt': human_response[2:].strip()}
        response_label = human_response['label']
        # prep model responses
        model_responses = []
        for gen_idx in range(n_model_responses):
            try:
                model_resp = sample[f'model_response_r{gen_idx}'].replace('\\n', '').strip()
            except AttributeError:
                # empty model response
                model_resp = ''
            model_responses.append({
                'label': response_label,
                'utt': model_resp
            })
        # get words set in context
        words_ctx = [
            rm_non_alphanum_from_str(utt['utt'].strip().lower()).strip().split()
            for utt in context
        ]
        words_ctx = [word for utt in words_ctx for word in utt]  # flatten
        words_ctx = set(words_ctx)

        # get overlap scores and quality scores
        vos_human = unpipe(sample['vo_human'])
        cos_human = unpipe(sample['co_human'])
        mauve_score = sample['mauve_score']
        bleu_score = sample['bleu_score']
        bleu_bp = sample['bleu_brevity_penalty']
        bleu_lr = sample['bleu_length_ratio']
        # get comprehension attribution scores
        attribs_comp = unpipe(sample['comprehend_attribs'])[::-1]
        attribs_tag_comp = unpipe(sample['comprehend_attribs_tag'])[::-1]
        attrib_curr_comp = attribs_comp[-1] if len(attribs_comp) > 0 else np.nan
        attrib_curr_tag_comp = attribs_tag_comp[-1] if len(attribs_tag_comp) > 0 else np.nan
        attribs_comp = attribs_comp[:-1]  # remove response
        attribs_tag_comp = attribs_tag_comp[:-1]  # remove response
        mean_attrib_comp = np.nanmean(attribs_comp)
        mean_attrib_tag_comp = np.nanmean(attribs_tag_comp)

        # iterate over model responses
        for gen_idx, model_response in enumerate(model_responses):

            # get current utterance and scores
            curr_utt = model_response['utt']
            vos = unpipe(sample[f"vo_r{gen_idx}"])
            cos = unpipe(sample[f"co_r{gen_idx}"])
            bert_f1 = sample[f"bert_f1_r{gen_idx}"]
            # get attribution scores.
            # reverse attribution scores list
            # the original list contains scores where the first element
            # is the score for the response and the last element is the
            # score for the first utterance in the context
            attribs = unpipe(sample[f"attribs_r{gen_idx}"])[::-1]
            attribs_tag = unpipe(sample[f"attribs_tag_r{gen_idx}"])[::-1]
            attrib_curr = attribs[-1] if len(attribs) > 0 else np.nan
            attrib_curr_tag = attribs_tag[-1] if len(attribs_tag) > 0 else np.nan
            attribs = attribs[:-1]  # remove response
            attribs_tag = attribs_tag[:-1]  # remove response
            mean_attrib = np.nanmean(attribs)
            mean_attrib_tag = np.nanmean(attribs_tag)
            # get pmi
            pmi = sample[f"pmi_r{gen_idx}"]
            pmi_avg = np.nanmean(unpipe(pmi))
            
            # iterate over context utterances
            for turn_idx, turn in enumerate(context):

                # get previous utterance and repeated words
                prev_utt = turn['utt']
                repeated_words = pipe(
                    set(curr_utt.strip().split()) \
                        .intersection(set(prev_utt.split()))
                )
                # get constructions
                constrs = unpipe(sample[f"constrs_model_r{gen_idx}_turn{turn_idx}"])
                constr_lens = ''
                if len(constrs) > 0:
                    constr_lens = [
                        len(constr.strip().split())
                        for constr in constrs
                    ]
                    constr_lens = pipe(constr_lens)
                constr_freqs = sample[f"constr_freqs_model_r{gen_idx}_turn{turn_idx}"]
                avg_constr_len = np.nanmean(constr_lens) if not isinstance(constr_lens, str) else np.nan

                # determine number of words not in context
                words_curr = set(rm_non_alphanum_from_str(curr_utt.strip().lower()).strip().split())
                word_not_in_ctx = words_curr - words_ctx
                n_words_not_in_ctx = len(word_not_in_ctx)
                
                # create new row in the turn-level dataframe
                new_row = dict(
                    # main cols:
                    sample_id=sample_index,
                    gen_id=gen_idx,
                    curr_turn_idx_in_orig=sample['first_utt_idx_in_diag'] + n_utts_ctx,
                    dist_from_prev_turn=n_utts_ctx - turn_idx + 1,  # 0: same, 10: first utt in ctx
                    prev_turn_txt=prev_utt,
                    curr_turn_txt=curr_utt,
                    human_resp_txt=human_response['utt'],
                    # labels:
                    speaker_curr=model_response['label'],
                    speaker_prev=turn['label'],
                    is_same_speaker=int(model_response['label'] == turn['label']),
                    # overlaps:
                    vocab_overlap=vos[turn_idx] if len(vos) > 0 else np.nan,
                    constr_overlap=cos[turn_idx] if len(cos) > 0 else np.nan,
                    constr_overlap_h=cos_human[turn_idx] if len(cos_human) > 0 else np.nan,
                    vocab_overlap_h=vos_human[turn_idx] if len(vos_human) > 0 else np.nan,
                    abs_diff_m_h_co=np.abs(cos[turn_idx] - cos_human[turn_idx]) if len(cos) > 0 else np.nan,
                    abs_diff_m_h_vo=np.abs(vos[turn_idx] - vos_human[turn_idx]) if len(vos) > 0 else np.nan,
                    # attributions:
                    attrib_to_prev=attribs[turn_idx] if len(attribs) > 0 else np.nan,
                    attrib_to_curr=attrib_curr,
                    attrib_to_curr_comp=attrib_curr_comp,
                    attrib_to_prev_comp=attribs_comp[turn_idx] if len(attribs_comp) > 0 else np.nan,
                    mean_attrib=mean_attrib,
                    mean_attrib_comp=mean_attrib_comp,
                    s_attrib_to_curr=attrib_curr_tag,
                    s_attrib_to_curr_comp=attrib_curr_tag_comp,
                    s_attrib_to_prev=attribs_tag[turn_idx] if len(attribs_tag) > 0 else np.nan,
                    s_attrib_to_prev_comp=attribs_tag_comp[turn_idx] if len(attribs_tag_comp) > 0 else np.nan,
                    s_mean_attrib=mean_attrib_tag,
                    s_mean_attrib_comp=mean_attrib_tag_comp,
                    # constructions:
                    constrs=constrs,
                    constr_lens=constr_lens,
                    constr_freqs=constr_freqs,
                    avg_constr_len=avg_constr_len,
                    pmi=pmi,
                    pmi_avg=pmi_avg,
                    repeated_words=repeated_words,
                    # non_rep_constrs=None,
                    n_words_not_in_ctx=n_words_not_in_ctx,
                    # gen quality:
                    # bert_precision=None,
                    # bert_recall=None,
                    bert_f1=bert_f1,
                    bleu=bleu_score,
                    bleu_brevity_penalty=bleu_bp,
                    bleu_length_ratio=bleu_lr,
                    # bleu_translation_length=None,
                    # bleu_reference_length=None,
                    mauve=mauve_score,  # TODO: currently this is corpus-level, break into corpus/generation/model/model_type
                )
                # perplexities:
                if 'gpt2' in ppl_models:
                    new_row.update(dict(
                        # GPT-2:
                        ppl_gpt2_model_trg_dectx=sample[f"ppl_model_response_r{gen_idx}_gpt2"],
                        ppl_gpt2_human_trg_dectx=sample[f"ppl_human_response_gpt2"],
                        abs_diff_ppl_gpt2_dectx=np.abs(
                            sample[f"ppl_model_response_r{gen_idx}_gpt2"] - 
                            sample[f"ppl_human_response_gpt2"]
                        ),
                        ppl_gpt2_model_trg_ctx=sample[f"ppl_model_response_r{gen_idx}_context_gpt2"],
                        ppl_gpt2_human_trg_ctx=sample[f"ppl_human_response_context_gpt2"],
                        abs_diff_ppl_gpt2_ctx=np.abs(
                            sample[f"ppl_model_response_r{gen_idx}_context_gpt2"] - 
                            sample[f"ppl_human_response_context_gpt2"]
                        ),
                    ))
                if 'pythia' in ppl_models:
                    new_row.update(dict(
                        # Pythia-1.4b:
                        ppl_pythia_model_trg_dectx=sample[f"ppl_model_response_r{gen_idx}_pythia"],
                        ppl_pythia_human_trg_dectx=sample[f"ppl_human_response_pythia"],
                        abs_diff_ppl_pythia_dectx=np.abs(
                            sample[f"ppl_model_response_r{gen_idx}_pythia"] - 
                            sample[f"ppl_human_response_pythia"]
                        ),
                        ppl_pythia_model_trg_ctx=sample[f"ppl_model_response_r{gen_idx}_context_pythia"],
                        ppl_pythia_human_trg_ctx=sample[f"ppl_human_response_context_pythia"],
                        abs_diff_ppl_pythia_ctx=np.abs(
                            sample[f"ppl_model_response_r{gen_idx}_context_pythia"] - 
                            sample[f"ppl_human_response_context_pythia"]
                        ),
                    ))

                # add new row to the turn-level dataframe
                turns = pd.concat([turns, pd.DataFrame([new_row])], ignore_index=True)
        
        if turn_level_file is not None and sample_index % save_every == 0:
            save_turn_level_to_disk()

    # save turn-level dataframe
    if turn_level_file is not None:
        save_turn_level_to_disk()
    return turns


In [3]:
tl = convert_sample_2_turn_level(
    sample_level_file = '../../data/samples_mini_gpt2_genq_constr_ppl_ol_pmi.tsv',
    turn_level_file   = '../../data/turns_mini_gpt2_genq_constr_ppl_ol_pmi.tsv',
    save_every=5,
)

100%|██████████| 10/10 [00:01<00:00,  5.26it/s]


In [4]:
tl.head()

Unnamed: 0,sample_id,gen_id,curr_turn_idx_in_orig,dist_from_prev_turn,prev_turn_txt,curr_turn_txt,human_resp_txt,speaker_curr,speaker_prev,is_same_speaker,...,ppl_gpt2_model_trg_dectx,ppl_gpt2_human_trg_dectx,abs_diff_ppl_gpt2_dectx,ppl_gpt2_model_trg_ctx,ppl_gpt2_human_trg_ctx,abs_diff_ppl_gpt2_ctx,vocab_overlap_h,abs_diff_m_h_co,s_attrib_to_curr,s_attrib_to_curr_comp
0,978,0,32,10,"Oh, for sure.. Oh, yeah.","Uh.. but, you know,. a pet. I mean, I sometime...",Yeah.,A,B,0,...,3.838915,4.977974,1.139059,2.833378,2.862912,0.029535,0.5,0.434783,-0.234607,-0.670036
1,978,0,32,9,"So, I've had this dog now for, for sixteen yea...","Uh.. but, you know,. a pet. I mean, I sometime...",Yeah.,A,A,1,...,3.838915,4.977974,1.139059,2.833378,2.862912,0.029535,0.5,0.130435,-0.234607,-0.670036
2,978,0,32,8,"Yes,. well. My, my mother has one that's, uh,....","Uh.. but, you know,. a pet. I mean, I sometime...",Yeah.,A,B,0,...,3.838915,4.977974,1.139059,2.833378,2.862912,0.029535,0.5,0.434783,-0.234607,-0.670036
3,978,0,32,7,Uh-huh.,"Uh.. but, you know,. a pet. I mean, I sometime...",Yeah.,A,A,1,...,3.838915,4.977974,1.139059,2.833378,2.862912,0.029535,0.5,0.0,-0.234607,-0.670036
4,978,0,32,6,"But, uh, it was, it, it's it's really made suc...","Uh.. but, you know,. a pet. I mean, I sometime...",Yeah.,A,B,0,...,3.838915,4.977974,1.139059,2.833378,2.862912,0.029535,0.5,0.130435,-0.234607,-0.670036


In [5]:
tl.shape

(450, 49)