In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
import pandas as pd, numpy as np 
from travis_attack.config import Config
from travis_attack.insights import get_training_dfs, _prepare_df_concat
from travis_attack.utils import display_all
from IPython.core.debugger import set_trace
import logging 
logger = logging.getLogger("run")

In [None]:
cfg = Config()
# Run whose checkpoints are analysed in the rest of this notebook; change it to inspect another run.
run_name = "civilized-womprat-133"
# NOTE(review): plain string concatenation of the checkpoint root and the run name,
# so this assumes cfg.path_checkpoints already ends with a path separator.
path_run = "{}{}/".format(cfg.path_checkpoints, run_name)

## Looking at some examples 

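Below we look at a handful of examples (randomly sampled ones, ones whose label flipped, ones with the most unique paraphrases, and ones with the highest contradiction score) to get a feel for how the paraphrases evolve during training.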

In [None]:
def get_interesting_idx(df, n, random_state=None):
    """Select example ``idx`` values worth inspecting, grouped by category.

    Parameters
    ----------
    df : DataFrame with at least the columns ``idx``, ``label_flip``,
        ``idx_n_unique_pp`` and ``contradiction_score``. The same ``idx``
        may appear in several rows.
    n : maximum number of indices returned per category. When a category has
        fewer than ``n`` candidates, all of them are returned instead of
        raising ``ValueError``.
    random_state : optional seed forwarded to ``Series.sample`` so the random
        picks are reproducible. ``None`` (default) keeps unseeded sampling.

    Returns
    -------
    dict with keys ``random``, ``label_flips``, ``idx_n_unique_pp`` and
    ``high_contradiction``. Each value is a list of distinct ``idx`` values,
    except ``label_flips``, which is ``None`` when no row has a non-zero
    ``label_flip``.
    """
    def sample_unique(s, k):
        # Sample distinct values without replacement, capped at the population
        # size so small frames don't raise "Cannot take a larger sample ...".
        uniq = s.drop_duplicates()
        return uniq.sample(min(k, len(uniq)), random_state=random_state).tolist()

    def get_idx_with_top_column_values(cname, n=5, ascending=False):
        # An idx can carry several values of `cname` (one per row), so dedupe
        # the idx column *after* sorting: each idx appears once, at its best rank.
        return (df[['idx', cname]]
                .drop_duplicates()
                .sort_values(cname, ascending=ascending)['idx']
                .drop_duplicates()
                .head(n)
                .tolist())

    def sample_idx_with_label_flips(n=5):
        flipped = df[['idx', 'label_flip']].query("label_flip!=0")
        if len(flipped) == 0:
            print("No label flips detected")
            return None
        return sample_unique(flipped['idx'], n)

    idx_d = dict(
        random=sample_unique(df['idx'], n),
        label_flips=sample_idx_with_label_flips(n=n),
        idx_n_unique_pp=get_idx_with_top_column_values('idx_n_unique_pp', n=n, ascending=False),
        # idx_n_pp_changes=get_idx_with_top_column_values('idx_n_pp_changes', n=n, ascending=False),
        high_contradiction=get_idx_with_top_column_values('contradiction_score', n=n, ascending=False),
    )
    return idx_d

def print_stats(df, idx_d, key, i):
    """Print a report for the ``i``-th example index stored under ``idx_d[key]``.

    The report contains the original text and label, the number of unique
    paraphrases, a "timeline" table of each distinct paraphrase row with the
    epochs it occurred in, and the single row with the highest ``reward``.

    Parameters
    ----------
    df : DataFrame of paraphrase records containing ``idx``, ``orig_l``,
        ``orig_label``, ``epoch`` and every column listed in ``g_fields``.
    idx_d : dict mapping category name -> list of ``idx`` values, or ``None``
        for an empty category (the shape returned by ``get_interesting_idx``).
    key : category to report on.
    i : position within ``idx_d[key]``.

    Returns
    -------
    None. Output is printed / displayed. If the category is ``None`` or holds
    fewer than ``i + 1`` entries, only the header is printed (previously this
    raised ``IndexError``).
    """
    print("\n###############\n")
    print(key, i, "\n")
    idx_list = idx_d[key]
    if idx_list is None or i >= len(idx_list):
        return
    idx = idx_list[i]  # referenced as @idx inside the query string below

    # Setup
    df1 = df.query('idx==@idx')
    orig = pd.unique(df1['orig_l'])[0]
    print("Original:", orig)
    print("Original label", pd.unique(df1['orig_label'])[0])
    n_pp_unique = len(pd.unique(df1['pp_l']))

    # Showing a "timeline" of how the paraphrases change over the epochs:
    # one row per distinct combination of g_fields, with the list of epochs in
    # which it occurred, ordered by the earliest such epoch.
    # NOTE(review): groupby drops rows whose key columns contain NaN (pandas
    # default dropna=True); confirm these columns are never NaN, or pass
    # dropna=False, otherwise those rows silently vanish from the timeline.
    g_fields = ["pp_l", "pp_truelabel_probs", "vm_score", "sts_score", "pp_letter_diff",
                "contradiction_score", "reward", "label_flip", 'pp_logp',
                'ref_logp', 'kl_div', 'reward_with_kl', 'loss']
    g = df1.groupby(g_fields).agg({'epoch': lambda x: list(x)})
    g = g.sort_values(by='epoch', key=lambda col: col.map(lambda x: np.min(x)))
    print("Unique paraphrases:", n_pp_unique)
    print("How the paraphrases change:")
    display_all(g)

    # The single row with the highest reward for this example
    best_pp = df1.sort_values('reward', ascending=False).iloc[0]
    print("Best Paraphrase")
    display_all(best_pp.to_frame().T)
        
def print_interesting_text_stats(df, n):
    """Pick interesting example indices from ``df`` and print a report for each.

    Parameters
    ----------
    df : DataFrame passed unchanged to ``get_interesting_idx`` and ``print_stats``.
    n : maximum number of examples to report per category.

    Categories that yield fewer than ``n`` indices are reported only for the
    indices actually found, instead of raising ``IndexError``.
    """
    idx_d = get_interesting_idx(df, n)
    for key, idx_list in idx_d.items():
        # A None category (e.g. no label flips) gets a single call so its
        # header is still printed once rather than n times.
        n_found = 1 if idx_list is None else min(n, len(idx_list))
        for i in range(n_found):
            print_stats(df, idx_d, key, i)

In [None]:
split = 'training_step'
df_d = get_training_dfs(path_run, postprocessed=True)
# print_interesting_text_stats selects its own examples via get_interesting_idx,
# so no separate selection call is needed here.
print_interesting_text_stats(df_d[split], n=2)

ValueError: Cannot take a larger sample than population when 'replace=False'

## Looking at common removals and insertions 

In [None]:
def get_common_removals_and_insertions(df_concat):
    """Count how often each removed / inserted phrase occurs across unique paraphrases.

    Rows are first de-duplicated on (``data_split``, ``orig_l``, ``pp_l``), so
    each distinct (split, original, paraphrase) triple is counted once.

    Parameters
    ----------
    df_concat : DataFrame with columns ``data_split``, ``orig_l``, ``pp_l``
        and iterable-valued ``insertions`` / ``removals`` cells (each cell is
        iterated over and its items counted).

    Returns
    -------
    tuple (removal_counts, insertion_counts) of Series indexed by phrase,
    sorted by descending count (``Series.value_counts``).
    """
    key_cols = ['data_split', 'orig_l', 'pp_l']
    # Select with a positional boolean mask. The previous `.iloc[index labels]`
    # picked the wrong rows (or went out of bounds) whenever the index was not
    # a 0..n-1 RangeIndex, e.g. frames concatenated with repeated labels.
    is_first = (~df_concat.duplicated(subset=key_cols)).to_numpy()
    df_unique_pp = df_concat.loc[is_first, ['insertions', 'removals']]

    def count_items(cname):
        # Flatten the per-row lists, then count occurrences of each phrase.
        flat = [item for cell in df_unique_pp[cname] for item in cell]
        return pd.Series(flat, dtype=object).value_counts()

    return count_items('removals'), count_items('insertions')

In [None]:
df_concat = _prepare_df_concat(df_d)
removals, insertions = get_common_removals_and_insertions(df_concat)

# Show the 30 most frequent removed and inserted phrases.
for header, counts in (("\n#### REMOVALS ####\n", removals),
                       ("\n#### INSERTIONS ####\n", insertions)):
    print(header)
    print(counts.head(30))


#### REMOVALS ####

not                   7
this                  5
love                  3
film                  2
movie                 2
apple                 2
like                  1
hate                  1
I do not like this    1
this apple            1
do                    1
I do not like         1
I love                1
dtype: int64

#### INSERTIONS ####

.                                    21
n't                                   7
like                                  2
that                                  2
film.                                 2
- I like it                           1
thyme                                 1
really                                1
Not so much                           1
'.                                    1
:                                     1
''.                                   1
love                                  1
ie                                    1
cherry.                               1
enjoy                          

Here you can pick a specific phrase and see a sample of rows in which it appears in a given column (e.g. `removals` or `insertions`).

In [None]:
def investigate_phrase(phrase, cname, n, df=None):
    """Display up to ``n`` randomly sampled rows whose ``cname`` cell contains ``phrase``.

    Parameters
    ----------
    phrase : value to look for; matched with ``phrase in cell``, i.e. list
        membership for list cells and substring search for string cells.
    cname : column to search.
    n : maximum number of rows to display. Fewer are shown when fewer rows
        match (previously ``sample`` raised ``ValueError``).
    df : frame to search; defaults to the notebook-level ``df_concat``.

    Returns
    -------
    None. Prints a message instead of displaying when nothing matches.
    """
    if df is None:
        df = df_concat
    mask = [phrase in strs for strs in df[cname]]
    matches = df[mask]
    if matches.empty:
        print(f"No rows found where {cname!r} contains {phrase!r}")
        return
    display_all(matches.sample(min(n, len(matches))))

In [None]:
investigate_phrase(phrase='despite', cname='removals', n=4)

ValueError: a must be greater than 0 unless no samples are taken