In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
## Imports and environment variables 
import os
import torch
import wandb
from travis_attack.utils import set_seed, set_session_options, setup_logging, setup_parser, resume_wandb_run, display_all
from travis_attack.config import Config
from travis_attack.models import prepare_models, get_optimizer
from travis_attack.data import ProcessedDataset
from travis_attack.trainer import Trainer
from travis_attack.insights import (postprocess_df, create_and_log_wandb_postrun_plots, get_training_dfs)
from fastcore.basics import in_jupyter

import logging 
logger = logging.getLogger("run")

In [None]:
# Export the project's notebooks with nbdev (the captured output below lists each converted notebook).
from nbdev.export import notebook2script
notebook2script()

# Convert run.ipynb to a plain python script, dropping cells tagged 'hide' and all markdown cells.
!jupyter nbconvert \
    --TagRemovePreprocessor.enabled=True \
    --TagRemovePreprocessor.remove_cell_tags="['hide']" \
    --TemplateExporter.exclude_markdown=True \
    --to python "run.ipynb"

Converted 00_utils.ipynb.
Converted 02_tests.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 25_insights.ipynb.
Converted baselines.ipynb.
Converted baselines_analysis.ipynb.
Converted index.ipynb.
Converted pp_eval_baselines.ipynb.
Converted run.ipynb.
Converted show_examples.ipynb.
Converted test_pp_model.ipynb.
[NbConvertApp] Converting notebook run.ipynb to python


In [None]:
# Start from the default configuration.
cfg = Config()

# Outside Jupyter (i.e. run from the command line) any supplied `--` option overrides the default.
if not in_jupyter():
    parser = setup_parser()
    newargs = vars(parser.parse_args())
    for k, v in newargs.items():
        if v is None:
            continue                 # option was not supplied; keep the default
        if k in cfg.pp.keys():
            cfg.pp[k] = v            # key belongs to the nested `cfg.pp` dict
        else:
            setattr(cfg, k, v)       # anything else is set as an attribute on cfg

if cfg.use_small_ds:
    cfg = cfg.small_ds()

set_seed(cfg.seed)
set_session_options()
setup_logging(cfg, disable_other_loggers=True)

# prepare_models also returns a cfg object, which replaces the one passed in.
(vm_tokenizer, vm_model,
 pp_tokenizer, pp_model, ref_pp_model,
 sts_model,
 nli_tokenizer, nli_model,
 cfg) = prepare_models(cfg)
optimizer = get_optimizer(cfg, pp_model)
ds = ProcessedDataset(cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model,
                      load_processed_from_file=False)

# Turn off wandb for this interactive session.
cfg.wandb['mode'] = 'disabled'

In [None]:
# Build the Trainer from the models, optimizer and dataset prepared above.
# Its private scoring helpers (trainer._get_vm_scores etc.) are reused by the manual eval loop further down.
trainer = Trainer(cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, ref_pp_model, sts_model, nli_tokenizer, nli_model, optimizer,
                  ds, initial_eval=False, use_cpu=False)


True

In [None]:
from travis_attack.utils import unpack_nested_lists_in_df
# `merge_dicts` is called at the end of every epoch below but was never imported in this notebook,
# so a fresh kernel would fail with NameError.
# NOTE(review): assumed to be exported by travis_attack.utils -- confirm.
from travis_attack.utils import merge_dicts
from datasets import Dataset
from sentence_transformers.util import pytorch_cos_sim
import numpy as np, pandas as pd 
from wandb.data_types import Histogram

## Params
split='valid'

## Pre-epoch setup 
pp_model = pp_model.to(cfg.device) # (not needed in trainer )
# Data containers and data loaders
eval_epoch_df_d = dict(train=[], valid=[], test=[]) # each eval epoch appended to here 


## Scoring helpers. None of them depend on per-epoch state, so they are defined once
## here instead of being re-defined on every pass through the loop.
def add_vm_scores_eval(batch): 
    """Batched map fn: add every entry returned by `trainer._get_vm_scores` to `batch` as python lists."""
    output = trainer._get_vm_scores(pp_l=batch['pp_l'], labels=torch.tensor(batch['label'], device = cfg.device), 
                                    orig_truelabel_probs=torch.tensor(batch['orig_truelabel_probs'], device=cfg.device))
    for k, v in output.items(): batch[k] = v.cpu().tolist() 
    return batch

def add_pp_letter_diff(batch): 
    """Batched map fn: add every entry returned by `trainer._get_pp_letter_diff` to `batch` as python lists."""
    output = trainer._get_pp_letter_diff(pp_l=batch['pp_l'], orig_n_letters=batch['orig_n_letters'])
    for k, v in output.items(): batch[k] = v.tolist() 
    return batch

def add_contradiction_score(batch): 
    """Batched map fn: add a `contradiction_scores` column from `trainer._get_contradiction_scores(orig, pp)`."""
    batch['contradiction_scores'] = trainer._get_contradiction_scores(orig_l=batch['orig'], pp_l=batch['pp_l']).cpu().tolist()
    return batch

def add_sts_scores_eval(row):  
    """Row-wise fn: score all paraphrases of one original against that original's stored sts embedding."""
    return trainer._get_sts_scores_one_to_many(row['pp_l'], row['orig_sts_embeddings'])[0]

def add_reward(batch): 
    """Batched map fn: add a `reward` column computed by `trainer._get_reward` from the four score columns."""
    batch['reward'] = trainer._get_reward(vm_scores=batch['vm_scores'], sts_scores=batch['sts_scores'],
              pp_letter_diff=batch['pp_letter_diff'], contradiction_scores=batch['contradiction_scores']).cpu().tolist()
    return batch

def add_is_valid_pp(example): 
    """Per-example map fn: add `is_valid_pp` from `trainer._is_valid_pp` (`*1` presumably turns a bool into 0/1)."""
    example['is_valid_pp'] = trainer._is_valid_pp(sts_score=example['sts_scores'],
         pp_letter_diff=example['pp_letter_diff'], contradiction_score=example['contradiction_scores'])*1
    return example 

def add_is_adv_example(batch): 
    """Batched map fn: `is_adv_example` is the elementwise product of `is_valid_pp` and `label_flip`."""
    batch['is_adv_example'] = (np.array(batch['is_valid_pp']) * np.array(batch['label_flip'])).tolist()
    return batch


for epoch in range(3): 
    eval_batch_results = list()  # each eval batch appended to here, list of dicts
    dl_key = "train_eval" if split == "train" else split
    dl_raw = ds.dld_raw[dl_key]
    dl_tkn = ds.dld_tkn[dl_key]
    ## Loop through batches in eval set
    for eval_batch_num, (data, raw) in enumerate(zip(dl_tkn, dl_raw)):
        pp_output = pp_model.generate(
                            input_ids=data['input_ids'].to(cfg.device), attention_mask=data['attention_mask'].to(cfg.device), 
                            **cfg.eval_gen_params,   remove_invalid_values=False, 
                            pad_token_id = pp_tokenizer.pad_token_id,eos_token_id = pp_tokenizer.eos_token_id)
        pp_l = pp_tokenizer.batch_decode(pp_output, skip_special_tokens=True)
        # Regroup the flat list of decoded paraphrases into one list of cfg.n_eval_seq per original
        pp_l_nested = [pp_l[i:i+cfg.n_eval_seq] for i in range(0, len(pp_l), cfg.n_eval_seq)]
        # The original evaluated this check and discarded the result; assert so a mismatch is actually caught
        assert all(len(l) == cfg.n_eval_seq for l in pp_l_nested), \
            "expected cfg.n_eval_seq paraphrases for every original example"
        eval_batch_results.append({'idx': raw['idx'], 'orig': raw['text'], 'pp_l':pp_l_nested, 'orig_n_letters': data['n_letters'].tolist(), 
                              'label': raw['label'], 'orig_truelabel_probs': data['orig_truelabel_probs'].tolist(), 'orig_sts_embeddings': data['orig_sts_embeddings'] })

    ## Convert eval batches to dataframes and create paraphrase identifier `pp_idx`
    df = pd.DataFrame(eval_batch_results)    
    df = df.apply(pd.Series.explode).reset_index(drop=True)  # This dataframe has one row per original example
    # Ids look like "orig_<idx>-epoch_<epoch>-pp_<i>", with i counting from 1; reads the current loop `epoch`
    def get_pp_idx(row): return ["orig_" + str(row['idx']) + "-epoch_" + str(epoch) +  "-pp_" +  str(pp_i) for pp_i in range(1, len(row['pp_l'])+1)]
    df['pp_idx'] = df.apply(get_pp_idx, axis=1)

    ## Create separate dataframe for sts scores and expand original dataframe
    # .copy() so that adding `sts_scores` below writes to an independent frame (avoids SettingWithCopyWarning)
    df_sts = df[['pp_idx', 'pp_l', 'orig_sts_embeddings']].copy()
    df1 = df.drop(columns='orig_sts_embeddings')
    scalar_cols = [o for o in df1.columns if o not in ['pp_l', 'pp_idx']]
    df_expanded = unpack_nested_lists_in_df(df1, scalar_cols=scalar_cols) # This dataframe has one row per paraphrase

    ## Add vm_scores, sts_scores, pp_letter_diff, contradiction scores
    ds_expanded = Dataset.from_pandas(df_expanded)
    ds_expanded = ds_expanded.map(add_vm_scores_eval,        batched=True)
    ds_expanded = ds_expanded.map(add_pp_letter_diff,        batched=True)
    ds_expanded = ds_expanded.map(add_contradiction_score,   batched=True)
    df_sts['sts_scores'] = df_sts.apply(add_sts_scores_eval, axis=1)

    ## Merge together results 
    df_sts = df_sts.drop(columns = ['pp_l','orig_sts_embeddings'])
    df_sts_expanded = df_sts.apply(pd.Series.explode).reset_index(drop=True)
    ds_expanded = Dataset.from_pandas(ds_expanded.to_pandas().merge(df_sts_expanded, how='left', on='pp_idx').reset_index(drop=True))

    ## Calculate rewards and identify adversarial examples 
    ds_expanded = ds_expanded.map(add_reward,   batched=True)
    ds_expanded = ds_expanded.map(add_is_valid_pp,   batched=False)
    ds_expanded = ds_expanded.map(add_is_adv_example,   batched=True)

    ## Calculate summary statistics
    df_expanded = ds_expanded.to_pandas()
    eval_metric_cols = ['label_flip', 'is_valid_pp', 'is_adv_example', 'reward', 'vm_scores', 'sts_scores',  'pp_letter_diff', 'contradiction_scores']
    agg_metrics = ['mean','std']  # not going to use the median 
    # avg across each orig 
    df_grp_stats = df_expanded[['idx'] + eval_metric_cols].groupby('idx').agg(agg_metrics)
    df_grp_stats.columns = ["-".join(a) for a in df_grp_stats.columns.to_flat_index()]
    # avg across whole dataset (constant group key puts every row in a single group)
    df_overall_stats = df_expanded[eval_metric_cols].groupby(lambda _ : True).agg(agg_metrics).reset_index(drop=True)
    df_overall_stats.columns = ["-".join(a) + "-" + split for a in df_overall_stats.columns.to_flat_index()]
    df_overall_metrics = df_overall_stats.iloc[0].to_dict()   ## WANDB this 
    # proportion of originals with at least one adversarial paraphrase
    df_overall_metrics['any_adv_example_proportion' + "-" + split] = np.mean((df_grp_stats['is_adv_example-mean'] > 0 ) * 1)
    # add epoch key
    df_expanded['epoch'] = epoch
    df_overall_metrics['epoch'] = epoch

    ## Log results to wandb 
    # NOTE: wandb_eval_d is only assembled here; nothing in this cell sends it to wandb
    wandb_eval_d = dict()
    mean_only = ['label_flip', 'is_valid_pp', 'is_adv_example']
    mean_and_std = ['reward', 'vm_scores', 'sts_scores', 'pp_letter_diff', 'contradiction_scores']
    for k in mean_only + mean_and_std: 
        name = k + "-mean"
        wandb_eval_d[name + "-"+ split + "-hist"] = Histogram(df_grp_stats[name].tolist())
    for k in mean_and_std:
        name = k + "-std"
        wandb_eval_d[name + "-" + split + "-hist"] = Histogram(df_grp_stats[name].tolist())
    wandb_eval_d = merge_dicts(df_overall_metrics, wandb_eval_d)

    ## Save paraphrase-level dataframe 
    eval_epoch_df_d[split].append(df_expanded)
    
# One concatenated dataframe per split; splits that were not evaluated keep an empty list
eval_final_dfs = dict()
for k in ['train', 'valid', 'test']:   eval_final_dfs[k] =  pd.concat(eval_epoch_df_d[k]) if eval_epoch_df_d[k] else []
    
    

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




{'label_flip-mean-valid': 0.22916666666666666,
 'label_flip-std-valid': 0.42474439539379405,
 'is_valid_pp-mean-valid': 0.6041666666666666,
 'is_valid_pp-std-valid': 0.4942039949782704,
 'is_adv_example-mean-valid': 0.14583333333333334,
 'is_adv_example-std-valid': 0.35667395763741655,
 'reward-mean-valid': 0.5016760751605034,
 'reward-std-valid': 0.9725948553559205,
 'vm_scores-mean-valid': 0.09351677012940247,
 'vm_scores-std-valid': 0.20581416741302028,
 'sts_scores-mean-valid': 0.8149712784215808,
 'sts_scores-std-valid': 0.1742561535377078,
 'pp_letter_diff-mean-valid': 4.354166666666667,
 'pp_letter_diff-std-valid': 7.990658331287597,
 'contradiction_scores-mean-valid': 0.17747944022024362,
 'contradiction_scores-std-valid': 0.313157452842603,
 'any_adv_example_proportion-valid': 0.5,
 'epoch': 2}

{'num_return_sequences': 8,
 'max_length': 48,
 'do_sample': True,
 'num_beams': 1,
 'top_p': 0.95,
 'temperature': 0.8,
 'length_penalty': 1,
 'label_flip-mean-valid': 0.22916666666666666,
 'label_flip-std-valid': 0.42474439539379405,
 'is_valid_pp-mean-valid': 0.6041666666666666,
 'is_valid_pp-std-valid': 0.4942039949782704,
 'is_adv_example-mean-valid': 0.14583333333333334,
 'is_adv_example-std-valid': 0.35667395763741655,
 'reward-mean-valid': 0.5016760751605034,
 'reward-std-valid': 0.9725948553559205,
 'vm_scores-mean-valid': 0.09351677012940247,
 'vm_scores-std-valid': 0.20581416741302028,
 'sts_scores-mean-valid': 0.8149712784215808,
 'sts_scores-std-valid': 0.1742561535377078,
 'pp_letter_diff-mean-valid': 4.354166666666667,
 'pp_letter_diff-std-valid': 7.990658331287597,
 'contradiction_scores-mean-valid': 0.17747944022024362,
 'contradiction_scores-std-valid': 0.313157452842603,
 'any_adv_example_proportion-valid': 0.5,
 'epoch': 2}

In [None]:
### TODO next time: integrate in this bit to save ref model results
# Config attributes to store next to the metrics; any key missing from cfg is silently skipped.
ref_model_keys = [
    'pp_name', 'dataset_name', 'sts_name', 'nli_name', 'vm_name',
    'seed', 'use_small_ds',
    'reward_fn', 'reward_clip_max', 'reward_clip_min', 'reward_base', 'reward_vm_multiplier',
    'sts_threshold', 'contradiction_threshold', 'pp_letter_diff_threshold',
    'max_pp_length', 'n_eval_seq', 'eval_decode_method', 'orig_max_length',
]
d = vars(cfg)
ref_model_d = {key: d[key] for key in ref_model_keys if key in d}
# Merge order: selected config fields, then eval generation params, then overall metrics of the last epoch.
results = merge_dicts(ref_model_d, merge_dicts(cfg.eval_gen_params, df_overall_metrics))

In [None]:
results

{'pp_name': 'prithivida/parrot_paraphraser_on_T5',
 'dataset_name': 'rotten_tomatoes',
 'sts_name': 'sentence-transformers/paraphrase-MiniLM-L12-v2',
 'nli_name': 'howey/electra-small-mnli',
 'vm_name': 'textattack/distilbert-base-uncased-rotten-tomatoes',
 'seed': 420,
 'use_small_ds': True,
 'reward_fn': 'reward_fn_contradiction_and_letter_diff',
 'reward_clip_max': 3,
 'reward_clip_min': 0,
 'reward_base': 0,
 'reward_vm_multiplier': 12,
 'sts_threshold': 0.7,
 'contradiction_threshold': 0.2,
 'pp_letter_diff_threshold': 30,
 'max_pp_length': 48,
 'n_eval_seq': 8,
 'eval_decode_method': 'sampling',
 'orig_max_length': 32,
 'num_return_sequences': 8,
 'max_length': 48,
 'do_sample': True,
 'num_beams': 1,
 'top_p': 0.95,
 'temperature': 0.8,
 'length_penalty': 1,
 'label_flip-mean-valid': 0.22916666666666666,
 'label_flip-std-valid': 0.42474439539379405,
 'is_valid_pp-mean-valid': 0.6041666666666666,
 'is_valid_pp-std-valid': 0.4942039949782704,
 'is_adv_example-mean-valid': 0.145833

In [None]:
cfg.path_run

In [None]:
vars(cfg).keys()

dict_keys(['pp_name', 'dataset_name', 'sts_name', 'nli_name', 'vm_name', 'seed', 'use_small_ds', 'sampling_strategy', 'lr', 'reward_fn', 'reward_clip_max', 'reward_clip_min', 'reward_base', 'reward_vm_multiplier', 'sts_threshold', 'contradiction_threshold', 'pp_letter_diff_threshold', 'reward_penalty_type', 'kl_coef', 'ref_logp_coef', 'max_pp_length', 'pp', 'n_eval_seq', 'eval_decode_method', 'eval_gen_params', 'orig_max_length', 'pin_memory', 'zero_grad_with_none', 'pad_token_embeddings', 'embedding_padding_multiple', 'orig_padding_multiple', 'bucket_by_length', 'shuffle_train', 'remove_misclassified_examples', 'remove_long_orig_examples', 'unfreeze_last_n_layers', 'n_shards', 'shard_contiguous', 'save_model_while_training', 'save_model_freq', 'wandb', 'device', 'devicenum', 'n_wkrs', 'splits', 'metrics', 'path_data', 'path_checkpoints', 'path_run', 'path_data_cache', 'path_logs', 'path_logfile', 'orig_cname', 'label_cname', 'batch_size_train', 'batch_size_eval', 'acc_steps', 'n_train

In [None]:
vars(cfg)

{'pp_name': 'prithivida/parrot_paraphraser_on_T5',
 'dataset_name': 'rotten_tomatoes',
 'sts_name': 'sentence-transformers/paraphrase-MiniLM-L12-v2',
 'nli_name': 'howey/electra-small-mnli',
 'vm_name': 'textattack/distilbert-base-uncased-rotten-tomatoes',
 'seed': 420,
 'use_small_ds': True,
 'sampling_strategy': 'sample',
 'lr': 4e-05,
 'reward_fn': 'reward_fn_contradiction_and_letter_diff',
 'reward_clip_max': 3,
 'reward_clip_min': 0,
 'reward_base': 0,
 'reward_vm_multiplier': 12,
 'sts_threshold': 0.7,
 'contradiction_threshold': 0.2,
 'pp_letter_diff_threshold': 30,
 'reward_penalty_type': 'ref_logp',
 'kl_coef': 0.2,
 'ref_logp_coef': 0.05,
 'max_pp_length': 48,
 'pp': {'do_sample': True,
  'min_length': 4,
  'max_length': 48,
  'temperature': 0.7,
  'top_p': 0.98,
  'length_penalty': 1.0,
  'repetition_penalty': 1.0},
 'n_eval_seq': 8,
 'eval_decode_method': 'sampling',
 'eval_gen_params': {'num_return_sequences': 8,
  'max_length': 48,
  'do_sample': True,
  'num_beams': 1,
 

In [None]:
# Inspect every paraphrase generated for original example idx 600, across epochs.
# NOTE(review): `df_eval_valid` is not defined in any cell visible here -- this relies on earlier kernel state; confirm where it comes from.
df_eval_valid.query('idx ==600')

Unnamed: 0,idx,orig,orig_n_letters,label,orig_truelabel_probs,pp_l,pp_idx,pp_truelabel_probs,pp_predclass,pp_predclass_probs,vm_scores,label_flip,pp_letter_diff,pp_letter_percent,contradiction_scores,sts_scores,reward,is_valid_pp,is_adv_example,epoch
40,600,it's not original enough .,26,0,0.940268,he's not original enough.,orig_600-epoch_0-pp_1,0.936015,0,0.936015,0.004253,0,1,0.961538,0.011591,0.618556,0.0,0,0,0
41,600,it's not original enough .,26,0,0.940268,it's not original enough.,orig_600-epoch_0-pp_2,0.940268,0,0.940268,0.0,0,1,0.961538,0.003557,1.0,0.0,1,0,0
42,600,it's not original enough .,26,0,0.940268,i'm not original.,orig_600-epoch_0-pp_3,0.916352,0,0.916352,0.023916,0,9,0.653846,0.009075,0.646019,0.0,0,0,0
43,600,it's not original enough .,26,0,0.940268,i don't think it's original enough,orig_600-epoch_0-pp_4,0.878543,0,0.878543,0.061725,0,-8,1.307692,0.004506,0.837635,0.7407,1,0,0
44,600,it's not original enough .,26,0,0.940268,he's not original enough.,orig_600-epoch_0-pp_5,0.936015,0,0.936015,0.004253,0,1,0.961538,0.011591,0.618556,0.0,0,0,0
45,600,it's not original enough .,26,0,0.940268,it's not unique enough.,orig_600-epoch_0-pp_6,0.934703,0,0.934703,0.005565,0,3,0.884615,0.003278,0.695859,0.0,0,0,0
46,600,it's not original enough .,26,0,0.940268,it's not original enough.,orig_600-epoch_0-pp_7,0.940268,0,0.940268,0.0,0,1,0.961538,0.003557,1.0,0.0,1,0,0
47,600,it's not original enough .,26,0,0.940268,it's not enough.,orig_600-epoch_0-pp_8,0.894117,0,0.894117,0.046151,0,10,0.615385,0.004059,0.682818,0.0,0,0,0
40,600,it's not original enough .,26,0,0.940268,it's not original enough.,orig_600-epoch_1-pp_1,0.940268,0,0.940268,0.0,0,1,0.961538,0.003557,1.0,0.0,1,0,1
41,600,it's not original enough .,26,0,0.940268,that's not original enough.,orig_600-epoch_1-pp_2,0.937044,0,0.937044,0.003224,0,-1,1.038462,0.002852,0.927247,0.03869,1,0,1


In [None]:
# df_grp_stats  # log to wandb as hists, don't save 
# df_overall_metrics  # log to wandb as scalars, use to compare train+valid
# df_expanded  # save as csv

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Unnamed: 0,idx,orig,orig_n_letters,label,orig_truelabel_probs,pp_l,pp_idx,pp_truelabel_probs,pp_predclass,pp_predclass_probs,vm_scores,label_flip,pp_letter_diff,pp_letter_percent,contradiction_scores,sts_scores,reward,is_valid_pp,is_adv_example,epoch
0,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangeness gives way inexorably to rote sentimentality and mystical tenderness becomes narrative urgency.,orig_900-pp_1,0.2061,1,0.7939,0.3954986,1,4,0.966102,0.93078,0.978896,0.0,0,0,2
1,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,strangeness inevitably leads to rote sentimentality and magical tenderness becomes a narrative enticement..,orig_900-pp_2,0.290491,1,0.709509,0.3111077,1,11,0.90678,0.138746,0.912779,3.0,1,1,2
2,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangerity gives way inexorably to rote sentimentality and mystical tenderness becomes narrative -... '',orig_900-pp_3,0.40821,1,0.59179,0.1933883,1,5,0.957627,0.048963,0.868479,2.32066,1,1,2
3,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,inexorably it gives rise to a rote sentimentality and mystical tenderness becomes a narrative advantage.,orig_900-pp_4,0.06435,1,0.93565,0.5372486,1,14,0.881356,0.977789,0.909396,0.0,0,0,2
4,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangeness inexorably passes on to a rote sentimentality and mystical tenderness becomes narrative prophecy.,orig_900-pp_5,0.182131,1,0.817869,0.419467,1,1,0.991525,0.622755,0.939616,0.0,0,0,2
5,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,inexorably initial strangeness gives way to rote sentimentality and mystical tenderness becomes narrative-agent.,orig_900-pp_6,0.269235,1,0.730766,0.3323639,1,6,0.949153,0.557876,0.977993,0.0,0,0,2
6,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,inexorably the initial strangeness morphs into rote sentimentality and mystical tenderness becomes narrative esl,orig_900-pp_7,0.178586,1,0.821414,0.4230123,1,6,0.949153,0.056653,0.950465,3.0,1,1,2
7,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative power.,orig_900-pp_8,0.12717,1,0.87283,0.474428,1,6,0.949153,0.96321,0.984255,0.0,0,0,2
8,100,"an entertaining , if somewhat standardized , action movie .",59,1,0.937699,a charming if slightly standardized action film.,orig_100-pp_1,0.936667,1,0.936667,0.001031458,0,11,0.813559,0.004294,0.887767,0.012378,1,0,2
9,100,"an entertaining , if somewhat standardized , action movie .",59,1,0.937699,a fun if somewhat uniform action movie.,orig_100-pp_2,0.909019,1,0.909019,0.02867925,0,20,0.661017,0.006304,0.899923,0.344151,1,0,2


In [None]:
from pprint import pprint

In [None]:
# NOTE: this overwrites the `epoch` variable left behind by the eval loop.
epoch=2
# Display the paraphrase-level dataframe from the most recent eval pass.
df_expanded

Unnamed: 0,idx,orig,orig_n_letters,label,orig_truelabel_probs,pp_l,pp_idx,pp_truelabel_probs,pp_predclass,pp_predclass_probs,vm_scores,label_flip,pp_letter_diff,pp_letter_percent,contradiction_scores,sts_scores,reward,is_valid_pp,is_adv_example,epoch
0,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,the initial strangeness inexorably give way to the rote sentimentality and the mystical tenderness turns into narrative expediency.,orig_900-pp_1,0.16786,1,0.83214,0.4337382,1,-13,1.110169,0.063116,0.979874,3.0,1,1,2
1,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangeness inexorably gives way to rote sentimentality and the mystical tenderness becomes a narrative point.,orig_900-pp_2,0.120555,1,0.879445,0.4810435,1,0,1.0,0.077189,0.97251,3.0,1,1,2
2,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,inexorably strangeness inexorably gives way to sentimentality and mystical tenderness becomes a narrative ruse.,orig_900-pp_3,0.339279,1,0.660721,0.262319,1,7,0.940678,0.01261,0.942503,3.0,1,1,2
3,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,the initial strangeness gives way to rote sentimentality and the mystical tenderness is the narratives of the esoteric,orig_900-pp_4,0.121511,1,0.878489,0.4800876,1,0,1.0,0.246013,0.873858,0.0,0,0,2
4,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,inexorably the initial strangeness gives way to the rote sentimentality and mystical tenderness becomes the narrative power.,orig_900-pp_5,0.095084,1,0.904916,0.506514,1,-6,1.050847,0.974344,0.972545,0.0,0,0,2
5,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial shock and confusion inexorably give way to rote sentimental and mystical tenderness becomes narrative eloquence.,orig_900-pp_6,0.28143,1,0.71857,0.3201682,1,-2,1.016949,0.575564,0.89205,0.0,0,0,2
6,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initially oddness gives way to rote sentimentality and mystical tenderness becomes narrative purpose...,orig_900-pp_7,0.543799,0,0.543799,0.05779958,0,15,0.872881,0.906632,0.952789,0.0,0,0,2
7,900,initial strangeness inexorably gives way to rote sentimentality and mystical tenderness becomes narrative expedience .,118,0,0.601598,initial strangeness inevitably gives way to rote sentimentality and mystifying tenderness becomes narrative expediency.,orig_900-pp_8,0.724889,0,0.724889,-0.1232904,0,-1,1.008475,0.090988,0.94463,0.0,1,0,2
8,100,"an entertaining , if somewhat standardized , action movie .",59,1,0.937699,an entertaining if somewhat standard action film,orig_100-pp_1,0.906897,1,0.906897,0.03080124,0,11,0.813559,0.011093,0.929442,0.369615,1,0,2
9,100,"an entertaining , if somewhat standardized , action movie .",59,1,0.937699,"a satirical, if a little standardized, action movie.",orig_100-pp_2,0.59733,1,0.59733,0.3403689,0,7,0.881356,0.209774,0.850707,0.0,0,0,2


In [None]:
# Keep wandb disabled and build a Trainer again with the same arguments as the earlier cell, then run training.
cfg.wandb['mode'] = 'disabled'
trainer = Trainer(cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, ref_pp_model, sts_model, nli_tokenizer, nli_model, optimizer,
                  ds, initial_eval=False, use_cpu=False)

# Dump the full configuration so it is recorded next to the training output.
print(vars(cfg))
trainer.train()



{'pp_name': 'prithivida/parrot_paraphraser_on_T5', 'dataset_name': 'rotten_tomatoes', 'sts_name': 'sentence-transformers/paraphrase-MiniLM-L12-v2', 'nli_name': 'howey/electra-small-mnli', 'vm_name': 'textattack/distilbert-base-uncased-rotten-tomatoes', 'seed': 420, 'use_small_ds': True, 'sampling_strategy': 'sample', 'lr': 4e-05, 'reward_fn': 'reward_fn_contradiction_and_letter_diff', 'reward_clip_max': 3, 'reward_clip_min': 0, 'reward_base': 0, 'reward_vm_multiplier': 12, 'sts_threshold': 0.6, 'contradiction_threshold': 0.2, 'pp_letter_diff_threshold': 30, 'reward_penalty_type': 'ref_logp', 'kl_coef': 0.2, 'ref_logp_coef': 0.01, 'pp': {'do_sample': True, 'min_length': 4, 'max_length': 48, 'temperature': 1, 'length_penalty': 1, 'top_p': 1, 'repetition_penalty': 1}, 'orig_max_length': 32, 'pin_memory': True, 'zero_grad_with_none': False, 'pad_token_embeddings': False, 'embedding_padding_multiple': 8, 'orig_padding_multiple': 8, 'bucket_by_length': True, 'shuffle_train': False, 'remove_m

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

travis_attack.trainer: INFO     Now on epoch 1 of 5


HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=8.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=2.0, style=ProgressStyle(description_width=…




travis_attack.trainer: INFO     Now on epoch 2 of 5


HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=8.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=2.0, style=ProgressStyle(description_width=…




travis_attack.trainer: INFO     Now on epoch 3 of 5


HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=8.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=2.0, style=ProgressStyle(description_width=…




travis_attack.trainer: INFO     Now on epoch 4 of 5


HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=8.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=2.0, style=ProgressStyle(description_width=…




travis_attack.trainer: INFO     Now on epoch 5 of 5


HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=4.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=8.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=2.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='Batches', max=1.0, style=ProgressStyle(description_width=…




In [None]:
# ## TO RESUME RUN
# cfg = Config()
# cfg.run_id = '2jq83fdx'
# cfg.run_name = "pleasant-wind-125"
# cfg.path_run = f"{cfg.path_checkpoints}{cfg.run_name}/"
# run = resume_wandb_run(cfg)


In [None]:
# Load the raw (not yet postprocessed) dataframes saved under this run's directory, keyed by name.
df_d = get_training_dfs(cfg.path_run, postprocessed=False)
# Postprocess each dataframe and pickle it next to the run's other files.
# Only values of existing keys are replaced, so the dict's key set is unchanged during iteration.
# NOTE: the loop variable `df` overwrites the `df` left behind by the eval loop cell.
for k, df in df_d.items(): 
    df_d[k] = postprocess_df(df, filter_idx=None, num_proc=1)
    df_d[k].to_pickle(f"{cfg.path_run}{k}_postprocessed.pkl")    
create_and_log_wandb_postrun_plots(df_d)
# presumably closes the wandb run held by the trainer -- TODO confirm
trainer.run.finish()
#run.finish()

travis_attack.insights: INFO     Dataframes have shapes ['training_step: (145, 48)', 'train: (145, 32)', 'valid: (30, 32)', 'test: (2, 32)']
travis_attack.insights: INFO     Adding text metrics for column orig_l


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column pp_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Calculating metric differences between orig and pp
travis_attack.insights: INFO     Calculating text pair statistics for (orig, pp) unique pairs





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column orig_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column pp_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Calculating metric differences between orig and pp
travis_attack.insights: INFO     Calculating text pair statistics for (orig, pp) unique pairs





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column orig_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column pp_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Calculating metric differences between orig and pp
travis_attack.insights: INFO     Calculating text pair statistics for (orig, pp) unique pairs





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column orig_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Adding text metrics for column pp_l





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

travis_attack.insights: INFO     Calculating metric differences between orig and pp
travis_attack.insights: INFO     Calculating text pair statistics for (orig, pp) unique pairs





HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…


