In [None]:
#default_exp trainer 

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
#export
import torch, wandb, gc, numpy as np, pandas as pd,os
from wandb.data_types import Histogram
from tqdm.auto import tqdm
from travis_attack.utils import (timecode, show_gpu, merge_dicts, unpack_nested_lists_in_df, 
                                 display_all, append_df_to_csv)
from travis_attack.tests import check_no_nans_or_infs
from travis_attack.models import save_pp_model, resume_pp_model, get_vm_probs, get_start_end_special_token_ids
from travis_attack.charts import plot_grad_flow, plot_examples_chart

In [None]:
#export
import torch, numpy as np, pandas as pd, gc,sys, logging, warnings
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_metric, load_from_disk, DatasetDict
from transformers import (AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, 
                          AutoTokenizer, AdamW, SchedulerType, get_scheduler)
from torch.distributions import Categorical
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import pytorch_cos_sim
from collections import defaultdict
from accelerate import Accelerator, notebook_launcher
from cachetools import cached, LRUCache
from types import MethodType
from timeit import default_timer as timer
from tqdm.auto import tqdm
import itertools
import copy 
import wandb
from undecorated import undecorated

In [None]:
#hide
from travis_attack.utils import set_seed, set_session_options, prepare_logger
from travis_attack.config import Config
from travis_attack.models import prepare_models, get_optimizer
from travis_attack.data import ProcessedDataset
from accelerate import Accelerator, notebook_launcher
from fastcore.basics import store_attr
from IPython.core.debugger import set_trace

In [None]:
#hide
# Notebook-only setup: builds the objects a `Trainer` needs so the class can be exercised interactively below.
accelerator = Accelerator()
cfg = Config()
# the accelerator decides which device (CPU/GPU) everything runs on
cfg.device = accelerator.device
set_seed(cfg.seed)
set_session_options()
logger = prepare_logger()
# note: prepare_models also returns cfg (presumably updated with model-dependent settings — confirm)
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)
optimizer = get_optimizer(cfg, pp_model)
ds = ProcessedDataset(cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model)
# Let accelerate wrap the models, the optimizer and the tokenised train dataloader for the current device setup.
vm_model,pp_model,sts_model,optimizer,ds.dld_tkn['train'] = accelerator.prepare(vm_model,pp_model,sts_model,optimizer,ds.dld_tkn['train'])
# One training step per train batch per epoch (sizes the progress bar in `Trainer._training_function`).
cfg.n_train_steps = cfg.n_train_epochs * len(ds.dld_tkn['train'])

# Trainer 

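This class fine-tunes the paraphrase (pp) model. For each batch it generates paraphrases and rewards each one by how much it lowers the victim model's (vm) probability of the true label, provided the paraphrase stays semantically similar (STS score) to the original. The pp model is then updated with the policy-gradient loss `-reward * log p(pp | orig)`.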

In [None]:
#export 
class Trainer: 
    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, optimizer, accelerator, ds,
                logger): 
        store_attr()
        self._cfg = self.cfg; del self.cfg;
        self.accumulation_num,self.global_step,self.eval_num = 0,0,0
        self._reset_batch_dicts()
        #resume_pp_model(f"{path_checkpoints}devout-durian-172_39")
        self._setup_wandb_run()
        self._setup_data_stores()
        if self._cfg.wandb['plot_examples']: self._setup_wandb_examples_plots()
        self.start_end_token_d = get_start_end_special_token_ids(self.pp_tokenizer)
        
    def train(self): 
        ## TODO: why is num_processes set to 1?
        #%lprun -f _training_function -f  get_pp_logp -f training_step -f  reward_fn -f  loss_fn -f eval_dl  notebook_launcher(_training_function, args=(pp_model, vm_model, dld_tkn, dld_raw, optimizer), num_processes=1, use_fp16=use_fp16)
        notebook_launcher(self._training_function, args=(), 
                           num_processes=1, use_fp16=self._cfg.use_fp16)
    
    def _reset_batch_dicts(self): 
          # train_batch_d holds all info to write to csv, time_d has times, wandb_d has everything to log to wandb
        # there will be overlap between them. 
        self.batch_d,self.batch_time_d,self.batch_wandb_d = dict(),dict(),dict()
    
    def _setup_wandb_run(self): 
        """Init wandb run, set up paths, create dir for model artifacts if needed, """
        ## TODO: set notebook name and add in save_code 
        self.run = wandb.init(project=self._cfg.wandb['project'], entity=self._cfg.wandb['entity'], 
                              config=vars(self._cfg), mode=self._cfg.wandb['mode'],
                              notes=self._cfg.wandb['run_notes'])
        if self._cfg.wandb['log_grads']: 
            wandb.watch(self.pp_model, log='gradients', log_freq=self._cfg.wandb['log_grads_freq'])
        self._cfg.run_name,self._cfg.run_id = self.run.name, self.run.id
        self._cfg.path_run = f"{self._cfg.path_checkpoints}{self.run.name}/"
        if not os.path.exists(self._cfg.path_run): os.makedirs(self._cfg.path_run, exist_ok=True)
    
    def _setup_data_stores(self): 
        """Setup dict `self.data_d` to store observations. Setup column names for wandb tables. """
        # Raw observation data (lists of dicts, later becomes pandas df)
        self.data_d = dict()
        for split in self._cfg.splits + ['training_step']:   self.data_d[split] = [] 
    
    def _setup_wandb_examples_plots(self): 
        """If we plot a few examples this sets that up."""
        def get_examples_plot_idxs(dataset): 
            """Get data indices for the examples plots"""
            return np.random.choice(dataset['idx'], size=self._cfg.wandb['n_examples_plot'], replace=False).tolist()
        self.plt_idx_d = dict()
        for split in self._cfg.splits:  self.plt_idx_d[split] = get_examples_plot_idxs(self.ds.dsd[split])

    def _training_function(self): 
        """Full training run: epoch loop with periodic evaluation, then test evaluation and run teardown.

        Launched by `train` through accelerate's `notebook_launcher`. Each epoch:
          * one `_training_step` per (tokenised, raw) train batch, with a progress bar over `cfg.n_train_steps`;
          * epoch timing logged to wandb, plus an optional gradient-flow plot;
          * optional model saving every `cfg.save_model_freq` epochs;
          * every `cfg.eval_freq` epochs: evaluation on train and valid, metric computation/logging, example
            charts and memory clean-up, with the timings logged to wandb.
        After the last epoch the test split is evaluated and written to `{cfg.path_run}test.csv`. Run summary
        statistics are then added, each split's CSV is read back into `self.df_d`, and the wandb run is finished.
        """
        self.logger.debug(show_gpu(f'GPU memory usage after loading models:'))
        progress_bar = tqdm(range(self._cfg.n_train_steps))
        self.pp_model.zero_grad(set_to_none=self._cfg.zero_grad_with_none) 
        for self.epoch in range(self._cfg.n_train_epochs): 
            self.logger.info(f"Now on epoch {self.epoch} of {self._cfg.n_train_epochs}")
            if not self.pp_model.training: self.pp_model.train()
            with timecode() as time_train_one_epoch:
                for self.batch_num, (data, raw) in enumerate(zip(self.ds.dld_tkn['train'], self.ds.dld_raw['train'])): 
                    self._reset_batch_dicts()
                    self._training_step(data, raw) 
                    self.accumulation_num += 1
                    self.global_step += 1
                    progress_bar.update(1) 
                    
            wandb.log({'time/train_one_epoch_time': time_train_one_epoch.t,
                       'time/train_one_epoch_thoroughput': len(self.ds.dsd_tkn['train']) / time_train_one_epoch.t,
                       'epoch': self.epoch}, commit=True)

            # NOTE(review): the frequency here is the top-level `cfg.wandb_log_grads_freq`, while
            # `_setup_wandb_run` reads `cfg.wandb['log_grads_freq']`; confirm `Config` defines both.
            if self._cfg.wandb['log_grads'] and self.epoch % self._cfg.wandb_log_grads_freq == 0: 
                plt = plot_grad_flow(self.pp_model.named_parameters())
                wandb.log({"gradient flow": wandb.Image(plt)})  # doesn't work as a non-image (i.e. plotly)
                del plt 

            if self._cfg.save_model_while_training and (self.epoch + 1) % self._cfg.save_model_freq == 0:
                # The current epoch lives on `self.epoch` (there is no local `epoch` in this method).
                # NOTE(review): `save_model` is not defined or imported in the visible code (the module imports
                # `save_pp_model`) — confirm the intended helper and its signature.
                save_model(self.epoch)

            # Evaluation loop
            if self.epoch % self._cfg.eval_freq == 0: 
                self.eval_num += 1
                with timecode() as time_eval_train:
                    self._eval_dl(split='train') # or train_eval?
                with timecode() as time_eval_valid:
                    self._eval_dl(split='valid')
                with timecode() as time_eval_compute_metrics: 
                    self._compute_and_log_eval_metrics()
                self._plot_wandb_charts()
                with timecode() as time_eval_gc_collect:
                    gc.collect() 
                with timecode() as time_eval_empty_cache:
                    torch.cuda.empty_cache()
                wandb.log({'time/eval_train_time': time_eval_train.t, 'time/eval_valid_time': time_eval_valid.t,
                           'time/eval_train_thoroughput': len(self.ds.dsd_tkn['train']) / time_eval_train.t,
                           'time/eval_valid_thoroughput': len(self.ds.dsd_tkn['valid']) / time_eval_valid.t, 
                           'time/eval_gc_collect': time_eval_gc_collect.t, 
                           'time/eval_empty_cache': time_eval_empty_cache.t,
                           'time/eval_compute_metrics': time_eval_compute_metrics.t,
                           'epoch': self.epoch}, commit=True)
        # Eval on test set 
        self._eval_dl(split='test')
        
        # Data -> df and save dfs to file 
        for key in ['test']:  
            self.data_d[key] = self._convert_data_d_to_df(key)
            self._set_df_colorder(key)
            self.data_d[key].to_csv(f"{self._cfg.path_run}{key}.csv", index=False)
        
        self._add_wandb_run_summary_statistics()   
        
        # some debug code to check data frames are working: read each split's own CSV back
        # (previously the leftover loop variable `key` made every split read test.csv)
        self.df_d = dict()
        for split in self._cfg.splits+['training_step']: 
            self.df_d[split] = pd.read_csv(f"{self._cfg.path_run}{split}.csv")
        self.run.finish()
        
    def _training_step(self, data, raw): 
        """Forward pass, loss function, backwards pass, parameter update (with gradient accumulation optional), 
        recording results, wandb logging. 
        """
        if not self.pp_model.training: self.pp_model.train()
        if not self.vm_model.training: self.vm_model.train()
        with timecode() as self.batch_time_d['time_generate_pp']:
            pp_output, pp_l = self._pp_model_forward(data)

        logger.debug(show_gpu(f'TRAIN, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after forward pass: '))

        with self.accelerator.autocast():
            with timecode() as self.batch_time_d['time_loss_fn']:
                loss_batch = self._loss_fn(data, raw, pp_output, pp_l)
            loss_batch = loss_batch / self._cfg.accumulation_steps  # Normalize our loss for gradient accumulation

        with timecode() as self.batch_time_d['time_backwards']:
            self.accelerator.backward(loss_batch) 

        logger.debug(show_gpu(f'TRAIN, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after backwards pass: '))
        if (self.accumulation_num + 1) % self._cfg.accumulation_steps == 0: 
            with timecode() as self.batch_time_d['time_opt_step']:
                self.optimizer.step()
            self.pp_model.zero_grad(set_to_none=self._cfg.zero_grad_with_none)
            
        self._prepare_train_batch_d(raw, data, pp_l)
        self.data_d['training_step'].append(self.batch_d)
        
        self._wandb_log_training_step()
    
    def _add_batch_vars_to_batch_d(self, raw, data, pp_l): 
        # Add basics. (results are already added elsewhere)
        self.batch_d = merge_dicts(self.batch_d, { 'idx': raw['idx'],
            'epoch': self.epoch, 'batch_num': self.batch_num, 'global_step': self.global_step,
            'accumulation_num': self.accumulation_num,  "orig_l": raw['text'], 
            "orig_label": data['label'].cpu().tolist(), 
            "orig_truelabel_probs": data['orig_truelabel_probs'].cpu().tolist(),
            'orig_length': self.orig_length, 'orig_batch_size': self.orig_batch_size, 
            "pp_l": pp_l, 'pp_length': self.pp_length, 'pp_batch_size': self.pp_batch_size
        })
        
    def _prepare_train_batch_d(self, raw, data, pp_l): 
        self._add_batch_vars_to_batch_d(raw, data, pp_l)
        # Add times (only for training, not eval)
        for k, v in self.batch_time_d.items(): self.batch_time_d[k] = v.t  # extract time from timecode object
        self.batch_d = merge_dicts(self.batch_d, self.batch_time_d)
    
    def _wandb_log_training_step(self): 
        self.batch_wandb_d = merge_dicts(self.batch_wandb_d, {
            'vm_scores_hist':       Histogram(self.batch_d['vm_score']), 
            'vm_scores_mean':       np.mean(  self.batch_d['vm_score']),
            'sts_scores_hist':      Histogram(self.batch_d['sts_score']),
            'sts_scores_mean':      np.mean(  self.batch_d['sts_score']), 
            'rewards_hist':         Histogram(self.batch_d['reward']),
            'rewards_mean':         np.mean(  self.batch_d['reward']), 
            'pp_logp_hist':         Histogram(self.batch_d['pp_logp']),
            'pp_logp_mean':         np.mean(  self.batch_d['pp_logp']),
            'loss_hist'   :         Histogram(self.batch_d['loss'])})
        self.batch_wandb_d = merge_dicts(self.batch_wandb_d, self.batch_d)
        not_for_wandb_keys = ['orig_l', 'orig_label','orig_truelabel_probs', 'pp_l', 'loss', 'pp_logp', 
                              'reward', 'sts_score', 'vm_score',
                              'pp_predclass_probs', 'label_flip', 'pp_predclass', 'pp_truelabel_probs']
        for k in not_for_wandb_keys:  self.batch_wandb_d.pop(k, None)
        wandb.log(self.batch_wandb_d, commit=True)
        
    def _wandb_log_eval_step(self): 
        ### TODO: implement
        pass
        wandb.log(self.batch_wandb_d, commit=True)
        
    def _convert_data_d_to_df(self, data_d_key): 
        df = pd.DataFrame(self.data_d[data_d_key])
        # check all lists have the same number of elements
        nonscalar_cols = df.columns[[o == np.dtype('object') for o in df.head(1).dtypes]].tolist()
        assert (df[nonscalar_cols].applymap(len) == cfg.batch_size_train).all(None)
        # expand lists and broadcast scalars
        scalar_cols = df.columns[[o != np.dtype('object') for o in df.head(1).dtypes]].tolist()
        df_expanded = unpack_nested_lists_in_df(df, scalar_cols)
        # check shape of new dataframe is correct 
        if data_d_key == "training_step": 
            if self.epoch == 0: 
                df_shape = (self._cfg.ds_length["train"],                       df.shape[1])
            else: 
                df_shape = (self._cfg.ds_length["train"] * self._cfg.eval_freq, df.shape[1])
        elif data_d_key in ["train", "valid", "test"]: 
            df_shape = (self._cfg.ds_length[data_d_key], df.shape[1])
        assert df_expanded.shape == df_shape
        return df_expanded
    
    def _pp_model_forward(self, data): 
        pp_output, pp_l = self._get_paraphrases(data['input_ids'], data['attention_mask'])
        self._assert_start_and_end_tokens_are_correct(orig_ids=data['input_ids'], pp_ids=pp_output.sequences)
        self._update_batch_size_and_length_variables( orig_ids=data['input_ids'], pp_ids=pp_output.sequences)
        return pp_output, pp_l
    
    def _assert_start_and_end_tokens_are_correct(self, orig_ids, pp_ids):
        """Make sure input sequences (orig) and output sequences (pp) start and end with the 
        right special tokens (depends on tokenizer)"""
        # Input
        if self.start_end_token_d['input_start_id'] is not None: 
            assert torch.all(orig_ids[:,0] == self.start_end_token_d['input_start_id'])
        # can probs rewrite this to make it nicer but it's fine for now
        assert torch.all(torch.logical_or(orig_ids[:,-1] == self.start_end_token_d['input_end_id'][0], 
                                          orig_ids[:,-1] == self.start_end_token_d['input_end_id'][1]))

        # Output
        assert torch.all(pp_ids[:,0] == self.start_end_token_d['output_start_id'])
        assert torch.all(torch.logical_or(pp_ids[:,-1] == self.start_end_token_d['output_end_id'][0], 
                                          pp_ids[:,-1] == self.start_end_token_d['output_end_id'][1]))
        
    def _update_batch_size_and_length_variables(self, orig_ids, pp_ids): 
        # Update variables
        # for greedy search self.pp_length is equal to self.orig_batch_size but this won't be for beam search
        self.orig_batch_size     = orig_ids.shape[0]
        self.orig_length         = orig_ids.shape[1]
        self.pp_batch_size       = pp_ids.shape[0]
        self.pp_length           = pp_ids.shape[1] 
    
    def _get_paraphrases(self, orig_ids, attention_mask):
        """Wrapper for generating paraphrases (pp's).  Only greedy search supported at the moment"""
        pp_output = self.pp_model.generate_with_grad(input_ids=orig_ids, 
                                                attention_mask=attention_mask, 
                                                 **self._cfg.pp,
                                                 do_sample=False, 
                                                 return_dict_in_generate=True,
                                                 output_scores=True,
                                                 remove_invalid_values=False, 
                                                 pad_token_id = self.pp_tokenizer.pad_token_id,
                                                 eos_token_id = self.pp_tokenizer.eos_token_id)
        pp_l = self.pp_tokenizer.batch_decode(pp_output.sequences, skip_special_tokens=True)
        return pp_output, pp_l
    
    def _loss_fn(self, data, raw, pp_output, pp_l): 
        with timecode() as self.batch_time_d['time_reward_fn']:
            reward = self._reward_fn(data, raw, pp_l)

        with timecode() as self.batch_time_d['time_pp_logp']:
            pp_logp = self._get_pp_logp(pp_output)

        with timecode() as self.batch_time_d['time_loss_fn_loss_calc']:
            loss = -reward * pp_logp
            loss_batch = torch.mean(loss)

        self.batch_d['pp_logp']    =    pp_logp.detach().cpu().tolist()
        self.batch_d['loss']       =       loss.detach().cpu().tolist()
        self.batch_d['loss_batch'] = loss_batch.detach().cpu().tolist()
        return loss_batch
    
    def _reward_fn(self, data, raw, pp_l): 
        """"""
        # Victim model probability differences between orig and pp
        with timecode() as self.batch_time_d['time_vm_scores']:
            pp_probs = get_vm_probs(pp_l, self._cfg, self.vm_tokenizer, self.vm_model, return_predclass=False)
            pp_predclass = torch.argmax(pp_probs, axis=1)
            pp_truelabel_probs   = torch.gather(pp_probs, 1, data['label'][:,None]).squeeze()
            pp_predclass_probs   = torch.gather(pp_probs, 1, pp_predclass[ :,None]).squeeze()
            label_flip = ((pp_predclass != data['label']) * 1)
            vm_scores = (data['orig_truelabel_probs'] - pp_truelabel_probs)
            
        # STS scores
        with timecode() as self.batch_time_d['time_sts_scores']:
            pp_embeddings  = self.sts_model.encode(pp_l, batch_size=len(raw), convert_to_tensor=True, device=self._cfg.device)
            # This returns a cosine similarity matrix, of which we just want the diagonal
            sts_scores = pytorch_cos_sim(data['orig_sts_embeddings'], pp_embeddings).diagonal()  

        # Reward calculation 
        rewards = torch.tensor([-0.5 if sts < 0.5 else 0.5+v*sts for v,sts in zip(vm_scores, sts_scores)],device=self._cfg.device)

        if self._cfg.normalise_rewards: 
            self.batch_d['reward_unscaled'] = rewards.detach().cpu().tolist()
            rewards = (rewards - torch.mean(rewards)) / torch.std(rewards)
        
        self.batch_d['pp_truelabel_probs']  = pp_truelabel_probs.detach().cpu().tolist()
        self.batch_d['pp_predclass']        = pp_predclass.detach().cpu().tolist()
        self.batch_d['pp_predclass_probs']  = pp_predclass_probs.detach().cpu().tolist()
        self.batch_d['label_flip']          = label_flip.detach().cpu().tolist()
        self.batch_d['label_flip_fraction'] = np.mean(self.batch_d['label_flip'])
        self.batch_d['reward']              = rewards.detach().cpu().tolist()
        self.batch_d['vm_score']            = vm_scores.detach().cpu().tolist()
        self.batch_d['sts_score']           = sts_scores.detach().cpu().tolist()
    
        return rewards
         
    def _get_pp_logp(self, pp_output): 
        """Sequence log-probability of each generated paraphrase, i.e. log p(pp | orig).

        Works for greedy search; other decoding strategies will probably need tweaking. Steps:
          1. drop the first generated token, which has no score attached;
          2. stack the per-step `scores` into a (batch, steps, vocab) tensor and log-softmax it;
          3. gather the log-prob of each generated token;
          4. zero out padding positions with an attention mask and sum over the sequence.
        Many shape/value sanity checks are asserted along the way. In training mode it can also record
        token-entropy and token-probability metrics in `self.batch_wandb_d` (per `cfg.wandb` flags).

        Args:
            pp_output: output of `generate_with_grad(..., return_dict_in_generate=True, output_scores=True)`;
                uses `.sequences` and `.scores`.
        Returns:
            Tensor of shape (pp_batch_size,) with one summed log-probability per paraphrase.
        """
        ### TODO: this looks like logp to me, not plogp. Find out if this is right and if so rename, if not, fix
        ### We want to align tokens with token probabilities. The first token is given at the start 
        # and has no probability attached to it, so we remove it. 
        seq_without_first_tkn = pp_output.sequences[:, 1:]
        assert seq_without_first_tkn.shape == torch.Size([self.orig_batch_size, self.pp_length - 1])

        ### Convert from tuple of scores (one tensor per generation step) to one big tensor of scores 
        scores_stacked = torch.stack(pp_output.scores, 1)
        ### TESTS 
        # We check shape and that there is no +inf or nan in scores. 
        # Scores can have -inf in them - see explanation in `exploring_generation`.  
        assert scores_stacked.shape == torch.Size([self.orig_batch_size, (self.pp_length - 1), self._cfg.vocab_size])
        assert torch.all(~torch.isnan(scores_stacked))
        assert torch.all(~torch.isposinf(scores_stacked))
        # Rough check that all idx before min_length are -inf for all elements in batch
        # We do min_length - 1 because sequences are allowed to have length min_length so that idx 
        # shouldn't be set to -inf
        # Not a 100% test but very likely to identify
        idx_neginf = torch.nonzero(torch.isneginf(scores_stacked))
        assert len(idx_neginf[idx_neginf[:,2] == self.pp_tokenizer.eos_token_id, :]) == \
                  (self._cfg.pp["min_length"] -1) * self.orig_batch_size  
        del idx_neginf

        ### Take log softmax of scores and then extract those that correspond 
        # to the generated sequences    
        scores_log_softmax = scores_stacked.log_softmax(2)
        seq_token_log_probs = torch.gather(scores_log_softmax,2,seq_without_first_tkn[:,:,None]).squeeze(-1)
        ### TESTS 
        # -inf is possible in scores_log_softmax and seq_token_log_probs before the attention mask is added. 
        assert torch.all(~torch.isnan(   scores_log_softmax))
        assert torch.all(~torch.isposinf(scores_log_softmax))
        self._check_scores_log_softmax_sums(scores_log_softmax)
        # probs should be 1-1 with the filtered tkns: check shape to confirm
        assert seq_token_log_probs.shape == seq_without_first_tkn.shape  
        # Check that the last token probability corresponds to a possible end token
        # this has to be tested before the attention mask is multiplied with it because if the 
        # padding token is 0 then this will be 0 too (and not the same as scores_log_softmax)
        output_end_ids = self.start_end_token_d['output_end_id']
        assert all([o in scores_log_softmax[:, -1, output_end_ids] for o in seq_token_log_probs[:,-1]])
        del output_end_ids
        ## THIS ONE IS LONG - a test rather than assert 
        # self._check_seq_token_log_prob_values_are_correct(seq_without_first_tkn, scores_log_softmax, 
        #                                                   seq_token_log_probs) 

        ### Generate attention mask to identify padding tokens. Then apply it to the 
        # sequence probabilities so that we don't consider probability of padding tokens 
        # when getting sequence probabilities. 
        # Also replace the -inf values in seq_token_log_probs with a large negative number because if we 
        # leave them in we end up with nan's introduced after multiplying with attention_mask, 
        # since  -inf * 0 = nan 
        attention_mask = self.pp_model._prepare_attention_mask_for_generation(
            seq_without_first_tkn, self.pp_tokenizer.pad_token_id, self.pp_tokenizer.eos_token_id
        )
        seq_token_log_probs = torch.nan_to_num(seq_token_log_probs, nan=None, posinf=None, neginf=-10000)
        seq_token_log_probs = seq_token_log_probs * attention_mask
        ### TESTS
        # log-probs, mask and generated tokens must line up one-to-one
        assert seq_token_log_probs.shape == attention_mask.shape == seq_without_first_tkn.shape
        # check attention mask only has 0 for padding tokens and not eos tokens or anything else
        assert all(seq_without_first_tkn[attention_mask == 0] == self.pp_tokenizer.pad_token_id)
        check_no_nans_or_infs(seq_token_log_probs)
        # check that we aren't picking extremely rare tokens
        assert torch.all(seq_token_log_probs  > -10)  

        ### Get sequence probabilities by summing up token log probabilities 
        seq_log_prob = seq_token_log_probs.sum(-1)
        ## TESTS 
        assert seq_log_prob.shape == torch.Size([self.pp_batch_size])
        check_no_nans_or_infs(seq_log_prob)
        
        if self.pp_model.training:  # don't bother logging or calculate entropy, token_probs in eval mode
            if self._cfg.wandb['log_token_entropy']:
                with timecode() as self.batch_time_d['time_log_entropy']:
                    self.batch_wandb_d['ent_hist'] = self._get_entropy_hist(scores_stacked, attention_mask) 
            if self._cfg.wandb['log_token_probabilities']: 
                with timecode() as self.batch_time_d['time_log_token_probabilities']:
                    self.batch_wandb_d = merge_dicts(self.batch_wandb_d, 
                        self._get_token_probability_metrics(scores_log_softmax, attention_mask, k=3))
        return seq_log_prob
   
    def _check_scores_log_softmax_sums(self, scores_log_softmax):
        sums = scores_log_softmax.exp().sum(2)
        # check that the axes is right
        # we want to sum over token probabilities at each generation step, so we 
        # should end up with a shape [self.orig_batch_size, self.pp_length]
        assert sums.shape[0] == self.orig_batch_size  
        assert sums.shape[1] == self.pp_length - 1
        # check that they sum to 1 along the self.pp_length axis
        assert torch.allclose(sums, torch.ones(sums.size(), device=self._cfg.device), atol = 1e-4)

    def _check_seq_token_log_prob_values_are_correct(self, seq_without_first_tkn, scores_log_softmax, seq_token_log_probs): 
        """Just enumerates and checks values
        Quite slow for large batches so run as a test rather than an assert in every batch. 
        """
        l = []
        for i_ex in range(self.orig_batch_size):
            for i_step in range(self.pp_length - 1):
                i_tkn = seq_without_first_tkn[i_ex][i_step].item()
                l.append(scores_log_softmax[i_ex,i_step, i_tkn] == seq_token_log_probs[i_ex,i_step])
        assert all(l)    
    
    def _get_entropy_hist(self, scores_stacked, attention_mask): 
        ent = Categorical(logits = scores_stacked).entropy().detach()
        assert ent.shape == attention_mask.shape == torch.Size([self.pp_batch_size, self.pp_length - 1])
        ent = ent * attention_mask  # stop values after eos token from contributing to ent score 
        # first remove structure (otherwise we have ragged arrays), then remove corresponding attention mask values
        # we can't just filter by ent[ent != 0] because we might have zero tokens during the sequence
        att_flat= attention_mask.flatten()
        indices = torch.nonzero(att_flat)
        ent_flat = ent.flatten()[indices].flatten()
        assert ent_flat.shape[0] == (torch.sum(att_flat)*1).item()
        # check everything we filter out is zero 
        torch.isclose(ent.flatten()[torch.nonzero(~(att_flat > 0))].sum(), torch.tensor(0.), 1e-3)
        return Histogram(ent_flat.detach().cpu().tolist())
#         ent_d = dict(
#       #      ent_min             = ent_flat.quantile(0).item(),
#             ent_lower_quartile  = ent_flat.quantile(0.25).item(), 
#             ent_median          = ent_flat.median().item(), 
#       #      ent_mean            = ent_flat.mean().item(), 
#             ent_upper_quartile  = ent_flat.quantile(0.75).item(), 
#       #      ent_max             = ent_flat.quantile(1).item(),   # skews the graph
#             epoch=self.epoch, global_step=self.global_step
#         )
#         return ent_d
        print('tmp')

    def _get_token_probability_metrics(self, scores_log_softmax, attention_mask, k=3): 
        token_prob_d = dict()
        tkn_kmaxprob, _ = torch.topk(scores_log_softmax, largest=True, k=k, dim=2)
        tkn_kmaxprob = tkn_kmaxprob.detach()  
        tkn_kmaxprob = torch.nan_to_num(tkn_kmaxprob, nan=None, posinf=None, neginf=-10000)
        assert tkn_kmaxprob.shape == torch.Size([self.pp_batch_size, self.pp_length - 1, k])

        # % of first prob over 0.9, 0.75, 0.5, 0.3, 0.1
        top_probs = tkn_kmaxprob[:,:,0].exp()
        top_probs = (top_probs * attention_mask).flatten()
        top_probs = top_probs[top_probs != 0]
        prob_threshold_l = [0.99, 0.975, 0.95, 0.90, 0.75, 0.5, 0.3, 0.1]
        for p in prob_threshold_l: 
            token_prob_d[f"top_token_prob_over_{str(p)}"] = (torch.sum(top_probs > p) / top_probs.shape[0]).item()

        # avg + median + lower + upper quartile of first, second, third choice probs
        tkn_kmaxprob_mask = tkn_kmaxprob * attention_mask[:,:,None]  # broadcasting over kth dim
        for i in range(k): 
            probs = tkn_kmaxprob_mask[:,:, i].flatten()
            probs = probs[probs != 0]
            token_prob_d[f"rank_{i+1}_histogram"] = Histogram(probs.detach().cpu().tolist())
            token_prob_d[f"rank_{i+1}_token_prob_mean"] = probs.mean().item()
            token_prob_d[f"rank_{i+1}_token_prob_median"] = probs.median().item()
            token_prob_d[f"rank_{i+1}_token_prob_0.25_quantile"] = probs.quantile(0.25).item()
            token_prob_d[f"rank_{i+1}_token_prob_0.75_quantile"] = probs.quantile(0.75).item()

        # tokens over probs above 0.1, 0.01, 0.001, 0.0001, 1/vocab_size prob 
        allprobs = (scores_log_softmax.detach().exp() * attention_mask[:,:,None]).flatten()
        allprobs = allprobs[allprobs != 0]
        for p in [0.1, 0.01, 0.001, 0.0001, 0.00001]: 
            token_prob_d[f"%_of_tokens_above_prob_{p}"] =  (torch.sum(allprobs > p) / allprobs.shape[0]).item()
        token_prob_d[f"%_of_tokens_above_prob_1/vocab_size"] = \
            (torch.sum(allprobs > (1/self._cfg.vocab_size)) / allprobs.shape[0]).item()
        return token_prob_d
    
    def _eval_dl(self, split): 
        """Get evaluation metrics for a dataloader"""
        ### TODO: delete redundant stuff
        # Put models in eval mode and do the forward pass 
        # Current logic: push all batches together into one big list.   
        self._reset_batch_dicts()
        if self.pp_model.training: self.pp_model.eval()
        if self.vm_model.training: self.vm_model.eval()
        dl_raw = self.ds.dld_raw[split]
        dl_tkn = self.ds.dld_tkn[split]
        with torch.no_grad(): 
            for self.batch_num, (data, raw) in enumerate(zip(dl_tkn, dl_raw)):
                self.logger.debug(show_gpu(f'EVAL, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after loading data: '))
                for k, v in data.items():
                    ## TODO: do you need this line?
                    if data[k].device != self._cfg.device: data[k] = data[k].to(self._cfg.device)
                pp_output, pp_l = self._pp_model_forward(data)
                _ = self._loss_fn(data, raw, pp_output, pp_l)
                self._add_batch_vars_to_batch_d(raw, data, pp_l)
                self.data_d[split].append(self.batch_d)
                self.logger.debug(show_gpu(f'EVAL, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after loss_fn pass: '))
        self._wandb_log_eval_step()  # not implemented yet
        
    def _compute_and_log_eval_metrics(self): 
        """Average each split's metrics, log them to wandb and flush the split data to CSV.

        For each of 'training_step', 'train' and 'valid':
          1. convert ``self.data_d[split]`` to a DataFrame (``_convert_data_d_to_df``) and
             fix its column order (``_set_df_colorder``);
          2. average the columns listed in ``self._cfg.metrics`` (for 'training_step', only
             rows whose ``epoch`` equals ``self.epoch``) and add them to the log dict under
             ``f"{metric}_{split}"``;
          3. append the full frame to ``f"{self._cfg.path_run}{split}.csv"`` and reset
             ``self.data_d[split]`` to an empty list.
        Everything, plus ``epoch``, is then sent in a single committing ``wandb.log`` call.
        """
        log_d = {'epoch': self.epoch}
        metric_cols = self._cfg.metrics
        for split_name in ('training_step', 'train', 'valid'):
            # Accumulated records -> DataFrame with the canonical column order.
            self.data_d[split_name] = self._convert_data_d_to_df(split_name)
            self._set_df_colorder(split_name)

            split_df = self.data_d[split_name][['epoch'] + metric_cols]
            if split_name == "training_step":
                # Only this epoch's training steps count toward this epoch's training_step metrics.
                split_df = split_df[split_df['epoch'] == self.epoch]
            means = split_df.mean()[metric_cols].to_dict()
            suffixed = {f"{name}_{split_name}": value for name, value in means.items()}
            log_d = merge_dicts(log_d, suffixed)

            # Persist the whole split, then free it for the next evaluation round.
            append_df_to_csv(self.data_d[split_name], path = f"{self._cfg.path_run}{split_name}.csv")
            self.data_d[split_name] = []

        wandb.log(log_d, commit=True)

    def _plot_wandb_charts(self): 
        ## TODO: rename to indicate it only plots summary and examples charts 
        ## Can you refactor into regular wandb charts?
        if self._cfg.wandb['plot_examples']: 
            # Examples charts 
            for split in ['train', 'valid']:
                df = pd.DataFrame(data_d[split]) if type(self.data_d[split]) is list else self.data_d[split]
                df = df.query("idx in @plt_idx_d[@split]").sort_values(['idx', 'epoch'])
                for metric in self._cfg.metrics: 
                    chart = plot_examples_chart(split, table=wandb.Table(dataframe=df), metric=metric)
                    wandb.log({f"individual_examples/{split}_{metric}_vs_epoch_examples": chart}, commit=False)  
    
    def _set_df_colorder(self, data_d_key): 
        colorder_eval=['idx','epoch', 'orig_l',  'pp_l','orig_truelabel_probs','pp_truelabel_probs',
        'pp_predclass_probs','orig_label','pp_predclass','label_flip', 'vm_score','sts_score',
        'reward', 'pp_logp','loss','batch_num','global_step','accumulation_num','loss_batch', 'label_flip_fraction',
        'orig_length','orig_batch_size','pp_length','pp_batch_size']
        
        if data_d_key == "training_step": 
            colorder_training_step = colorder_eval + [o for o in self.data_d['training_step'].columns if 'time_' in o]
            self.data_d[data_d_key] = self.data_d[data_d_key][colorder_training_step]
        else:                             
            self.data_d[data_d_key] = self.data_d[data_d_key][colorder_eval]

    def _add_wandb_run_summary_statistics(self):
        """Compute test metrics for the run and log them to the wandb run summary pane. """
        ## Summary statistics of the test set 
        # From the last epoch atm because we don't have early stopping 
        test_metrics = self.data_d['test'].filter(self._cfg.metrics, axis=1).mean()
        for metric, val in zip(test_metrics.index, test_metrics): 
            self.run.summary[f"{metric}_avg_test"] = val

In [None]:
# Set wandb mode to 'disabled' for this interactive run (presumably passed to wandb.init by the Trainer — TODO confirm).
cfg.wandb['mode'] = 'disabled'

In [None]:
# Build the Trainer and run the full training loop.
# NOTE(review): vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, optimizer and ds are not defined in any
# visible cell; presumably they come from earlier cells (prepare_models / get_optimizer / ProcessedDataset).
trainer = Trainer(cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, optimizer,
                  accelerator, ds, logger)
trainer.train()

Launching training on one GPU.


  0%|          | 0/20 [00:00<?, ?it/s]

Now on epoch 0 of 20
Now on epoch 1 of 20
Now on epoch 2 of 20
Now on epoch 3 of 20
Now on epoch 4 of 20
Now on epoch 5 of 20
Now on epoch 6 of 20
Now on epoch 7 of 20
Now on epoch 8 of 20
Now on epoch 9 of 20
Now on epoch 10 of 20
Now on epoch 11 of 20
Now on epoch 12 of 20
Now on epoch 13 of 20
Now on epoch 14 of 20
Now on epoch 15 of 20
Now on epoch 16 of 20
Now on epoch 17 of 20
Now on epoch 18 of 20
Now on epoch 19 of 20


In [None]:
# Scratch: prototype the eval schedule. Prints epoch 0, then every epoch in 1..n_epochs divisible by eval_freq.
n_epochs = 20
eval_freq = 1 
print(0)

for i in range(1, n_epochs+1): 
    if i % eval_freq == 0 : 
        print(i)
        
# NOTE(review): this sanity check only runs after the loop; placing it first would fail fast.
if n_epochs % eval_freq != 0: 
    raise Exception("leftover epochs")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20


In [None]:
# NOTE(review): several outputs in these cells disagree with n_epochs=20, eval_freq=1 set above
# (20 % 1 == 0, not 1; range(1, 20, 1) is not [1, 5, 9, 13, 17]), so they come from an earlier
# kernel state (hidden state); re-run to refresh.
n_epochs % eval_freq

1

In [None]:
print(len(range(1, n_epochs+1)))

20


In [None]:
list(range(1, n_epochs, eval_freq))

[1, 5, 9, 13, 17]

In [None]:
len(range(0, n_epochs, eval_freq))

5

In [None]:
# NOTE(review): hardcoded absolute path to one specific run. Consider trainer._cfg.path_run, which is where
# _compute_and_log_eval_metrics appends f"{split}.csv".
pth = "/data/tproth/model_checkpoints/travis_attack/dummy-SoKrFiSHniynepcHcMHxgD/"
df = pd.read_csv(pth + "training_step.csv")

In [None]:
df.shape

(68, 35)

In [None]:
# Scratch arithmetic, presumably an expected row count to compare with df.shape above — TODO confirm.
16 * 4

64

## Inspect run outputs saved under `trainer._cfg.path_run`

In [None]:
# NOTE(review): the Trainer methods in this notebook keep results in `data_d`, not `df_d`; this cell may rely on an older attribute.
display_all(trainer.df_d['train'])

Unnamed: 0,idx,epoch,orig_l,pp_l,orig_truelabel_probs,pp_truelabel_probs,pp_predclass_probs,orig_label,pp_predclass,label_flip,vm_score,sts_score,reward,pp_logp,loss,batch_num,global_step,accumulation_num,loss_batch,label_flip_fraction,orig_length,orig_batch_size,pp_length,pp_batch_size
0,2,19,I do really like this film,I do like this movie.,0.676335,0.741146,0.741146,1,1,0,-0.064811,0.907538,0.441182,-0.725912,0.320259,0,20,20,0.238757,0.0,8,4,8,4
1,3,19,I hate this film so much,I hate this movie.,0.830579,0.893286,0.893286,0,0,0,-0.062707,0.911643,0.442833,-0.494259,0.218874,0,20,20,0.238757,0.0,8,4,8,4
2,0,19,I love eating this bread,I love this bread.,0.832408,0.906977,0.906977,1,1,0,-0.074569,0.954017,0.42886,-0.237565,0.101882,0,20,20,0.238757,0.0,8,4,8,4
3,1,19,This banana is terrible,This banana is terrible.,0.859391,0.89823,0.89823,0,0,0,-0.038839,0.975812,0.462101,-0.679536,0.314014,0,20,20,0.238757,0.0,8,4,8,4


In [None]:
# Leftover post-mortem debugging cell (there was no exception to inspect); remove before sharing.
%debug

No traceback has been produced, nothing to debug.


In [None]:
# Print each batch yielded by the raw test dataloader (ds.dld_raw['test']).
for o in iter(ds.dld_raw['test']): 
    print(o)

{'text': ['I do really like this film', 'I hate this film so much', 'I love eating this bread', 'This banana is terrible'], 'idx': [2, 3, 0, 1]}


In [None]:
# Scratch: inspect the per-step training records kept on the trainer.
df= trainer.data_d['training_step']

In [None]:
# NOTE(review): data_d['training_step'] is a list at this point (e.g. _compute_and_log_eval_metrics resets it
# to []), hence the AttributeError below.
display_all(df.head(3))

AttributeError: 'list' object has no attribute 'head'

In [None]:
# NOTE(review): `d` is not defined in any visible cell (leftover from a deleted cell).
d

In [None]:
# NOTE(review): rebinding df to its own mean makes this cell non-idempotent (a second run fails).
df = df.mean()[cfg.metrics]

In [None]:
df.columns

In [None]:
# NOTE(review): incomplete statement; this cell is a SyntaxError.
df.keys() = 

In [None]:
# Same selection as _compute_and_log_eval_metrics: the epoch column plus the configured metrics.
df = trainer.data_d['training_step'][['epoch'] + cfg.metrics]

In [None]:
# NOTE(review): stray closing parenthesis; this cell is a SyntaxError.
)

In [None]:
# NOTE(review): `s` is not defined in any visible cell.
s.to_dict()

In [None]:
# NOTE(review): incomplete attribute access; this cell is a SyntaxError.
df.

In [None]:
# NOTE(review): DataFrames do not support NumPy-style `[:, cols]` indexing; use df.loc[:, ['loss']].
df[:,  ['loss']]

In [None]:
df.columns

In [None]:
# NOTE(review): `x` is not defined in any visible cell, and `[:, ...]` indexing needs .loc.
df[:, [x]]

In [None]:
# NOTE(review): after training, data_d[split] may be a list (_compute_and_log_eval_metrics resets it to []),
# in which case these .head() calls raise AttributeError.
display_all(trainer.data_d['training_step'].head(4))

In [None]:
display_all(trainer.data_d['valid'].head(4))

In [None]:
display_all(trainer.data_d['valid'].head(8))

In [None]:
# Leftover post-mortem debugging cell; remove before sharing.
%debug

In [None]:
# NOTE(review): scores_log_softmax and k are not defined in any visible cell.
# torch.topk returns (values, indices); only the top-k values along dim 2 are kept.
tkn_kmaxprob, _ = torch.topk(scores_log_softmax, largest=True, k=k, dim=2)


In [None]:
# Leftover post-mortem debugging cell; remove before sharing.
%debug

In [None]:
trainer

In [None]:
# Leftover post-mortem debugging cell; remove before sharing.
%debug

In [None]:
#hide
# Export this notebook's #export cells to the library modules (nbdev).
from nbdev.export import notebook2script
notebook2script()