In [None]:
#default_exp trainer 

In [None]:
#export
import torch, wandb, gc, numpy as np, pandas as pd,os
from collections import defaultdict
from fastcore.basics import store_attr
from tqdm.auto import tqdm
from travis_attack.utils import timecode, show_gpu
from travis_attack.models import save_model
from travis_attack.charts import plot_grad_flow, plot_wandb_charts
from travis_attack.charts import plot_examples_charts 
from travis_attack.charts import plot_summary_charts

In [None]:
#hide
from travis_attack.utils import set_seed, set_session_options, prepare_logger
from travis_attack.config import Config
from travis_attack.models import prepare_models
from travis_attack.data import ProcessedDataset
from travis_attack.trainer import get_optimizer
from accelerate import Accelerator, notebook_launcher
from fastcore.basics import store_attr




# Development scaffolding (hidden cell): builds every object a `Trainer` needs.
# Order matters: the config must exist before seeding, and models must exist
# before the optimizer / dataset / accelerator wrapping below.
accelerator = Accelerator()
cfg = Config()
cfg.device = accelerator.device
cfg = cfg.adjust_config_for_simple_dataset()  # for testing, easier
set_seed(cfg.seed)
set_session_options()
logger = prepare_logger()
# prepare_models returns the config as well as the models, so `cfg` is rebound here.
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)
# The optimizer is built over the paraphrase model's parameters only.
optimizer = get_optimizer(cfg, pp_model)
ds = ProcessedDataset(cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model)
# Only the tokenised *train* dataloader is passed through accelerator.prepare.
vm_model,pp_model,sts_model,optimizer,ds.dld_tkn['train'] = accelerator.prepare(vm_model,pp_model,sts_model,optimizer,ds.dld_tkn['train'])
# Total number of training batches over all epochs (sizes the progress bar in Trainer.training_function).
cfg.n_train_steps = cfg.n_train_epochs * len(ds.dld_tkn['train'])

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

# Trainer 

This class is used to fine-tune the paraphrase model (`pp_model`).

In [None]:
#export
def get_vm_probs(text, cfg, vm_tokenizer, vm_model, return_predclass=False): 
    """Return the victim model's class probabilities for `text`.

    Used in data cleaning and by the reward_fn to get vm_score.

    The model is switched to eval mode if needed and the forward pass runs without
    gradient tracking. Returns `probs` (softmax over dim 1 of the logits), or the
    tuple `(probs, predicted_class)` when `return_predclass` is True.
    """
    if vm_model.training: vm_model.eval()
    with torch.no_grad():
        encoded = vm_tokenizer(text, truncation=True, padding=True,
                               pad_to_multiple_of=cfg.orig_padding_multiple, return_tensors="pt")
        encoded = encoded.to(cfg.device)
        probs = torch.softmax(vm_model(**encoded).logits, 1)
        if not return_predclass:
            return probs
        return probs, torch.argmax(probs, 1)

def get_optimizer(cfg, pp_model):
    """Build the AdamW optimizer over all of `pp_model`'s parameters, with learning rate `cfg.lr`."""
    params = pp_model.parameters()
    return torch.optim.AdamW(params, lr=cfg.lr)

In [None]:
#export 
class Trainer: 
    """Runs the fine-tuning loop for the paraphrase model and logs progress to wandb.

    The constructor stores its arguments as attributes (the config is kept as
    `self._cfg`), starts a wandb run and prepares the in-memory stores that collect
    per-example predictions and per-epoch summaries.

    `training_function` relies on further methods (`training_step`, `eval_dl`,
    `update_training_summary_table`, `add_preds_to_data_d`, `plot_wandb_charts`)
    that are not defined inside this class body.
    """
    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, accelerator, ds,
                logger): 
        store_attr()
        # Keep the config under a private name.
        self._cfg = self.cfg; del self.cfg;
        self.accumulation_num,self.global_step = 0,0
        self._setup_wandb_run()
        self._setup_data_stores()
        self._setup_wandb_examples_plots()
    
    def _setup_wandb_run(self): 
        """Init wandb run, set up paths, create dir for model artifacts if needed."""
        wandb_cfg = self._cfg.wandb
        # Prefer an explicit 'entity' entry; fall back to the project name (the
        # previous behaviour) when the config does not define one.
        entity = wandb_cfg.get('entity', wandb_cfg['project'])
        self.run = wandb.init(project=wandb_cfg['project'], entity=entity, 
                              config=vars(self._cfg), mode=wandb_cfg['mode'],
                              notes=wandb_cfg['run_notes'], save_code=True)
        if wandb_cfg['log_grads']: 
            wandb.watch(self.pp_model, log='gradients', log_freq=wandb_cfg['log_grads_freq'])
        self._cfg.run_name,self._cfg.run_id = self.run.name, self.run.id
        self._cfg.path_run = f"{self._cfg.path_checkpoints}{self.run.name}/"
        os.makedirs(self._cfg.path_run, exist_ok=True)

    def _setup_data_stores(self): 
        """Setup dict `self.data_d` to store observations. Setup column names for wandb tables."""
        # Raw observation data (lists of dicts, later becomes pandas df)
        self.data_d = dict()
        # These have to be in the keys of the output from eval_dl
        self.table_columns = ['idx', 'orig_l',  'truelabel', 'orig_truelabel_probs', 'epoch', 'pp_l',
                     'pp_truelabel_probs', "pp_predclass", "pp_predclass_probs"] + self._cfg.metrics
        self.summary_table_columns = ['epoch','split'] + [f'{m}_avg' for m in self._cfg.metrics]
        for key in self._cfg.splits + ['training_summary']:         self.data_d[key]             = [] 
        if self._cfg.wandb['log_training_step_table']:              self.data_d['training_step'] = []
    
    def _setup_wandb_examples_plots(self): 
        """If we plot a few examples this sets that up.

        Picks, per split, `n_examples_plot` distinct example ids (without replacement)
        and stores them in `self.plt_idx_d`.
        """
        if not self._cfg.wandb['plot_examples']: return
        n_examples = self._cfg.wandb['n_examples_plot']
        self.plt_idx_d = dict()
        for split in self._cfg.splits:
            idxs = self.ds.dsd[split]['idx']
            self.plt_idx_d[split] = np.random.choice(idxs, size=n_examples, replace=False).tolist()
    
    def training_function(self): 
        """Train for `n_train_epochs` epochs, evaluating every `eval_freq` epochs, then evaluate
        on the test split and write every data store to csv under the run directory."""
        #self.logger.info(show_gpu(f' GPU memory usage after loading models:'))
        progress_bar = tqdm(range(self._cfg.n_train_steps))
        self.pp_model.zero_grad(set_to_none=self._cfg.zero_grad_with_none) 
        n_train_batches = len(self.ds.dld_tkn['train'])
        for self.epoch in range(self._cfg.n_train_epochs): 
            self.logger.info(f"Now on epoch {self.epoch} of {self._cfg.n_train_epochs}")
            if not self.pp_model.training: self.pp_model.train()
            with timecode() as time_train_one_epoch:
                for self.batch_num, (data, raw) in enumerate(zip(self.ds.dld_tkn['train'], self.ds.dld_raw['train'])): 
                    if self.batch_num % 10 == 0 :   self.logger.info(f"Now processing batch {self.batch_num} out of {n_train_batches}")
                    self.training_step(data, raw) 
                    self.accumulation_num += 1  ; self.global_step += 1 ;  progress_bar.update(1) 
                    
            wandb.log({'time/train_one_epoch_time': time_train_one_epoch.t,
                       'time/train_one_epoch_thoroughput': len(self.ds.dsd['train']) / time_train_one_epoch.t,
                       'epoch': self.epoch}, commit=True)

            # Same wandb config dict that _setup_wandb_run reads its gradient-logging flags from.
            if self._cfg.wandb['log_grads'] and self.epoch % self._cfg.wandb['log_grads_freq'] == 0: 
                plt = plot_grad_flow(self.pp_model.named_parameters())
                wandb.log({"gradient flow": wandb.Image(plt)})  # doesn't work as a non-image (i.e. plotly)
                del plt 
            #gc.collect() 
            #torch.cuda.empty_cache()

            if self._cfg.save_model_while_training and (self.epoch + 1) % self._cfg.save_model_freq == 0:
                # NOTE(review): save_model is imported from travis_attack.models; confirm it takes only the epoch.
                save_model(self.epoch)

            # Evaluation loop
            if self.epoch % self._cfg.eval_freq == 0: 
                self.logger.info(f"Now doing train eval")
                with timecode() as time_eval_train:
                    train_set_preds = self.eval_dl(dl_tkn=self.ds.dld_tkn['train_eval'], 
                                                   dl_raw=self.ds.dld_raw['train_eval'])

                self.logger.info(f"Now doing valid eval")
                with timecode() as time_eval_valid:
                    valid_set_preds = self.eval_dl(dl_tkn=self.ds.dld_tkn['valid'],
                                                   dl_raw=self.ds.dld_raw['valid'])

                # update the tables every epoch and log them
                with timecode() as time_update_training_summary_table:
                    self.update_training_summary_table(train_set_preds, split='train')
                    self.update_training_summary_table(valid_set_preds, split='valid')
                with timecode() as time_add_eval_preds_to_data_d:    
                    self.add_preds_to_data_d(train_set_preds, split='train')
                    self.add_preds_to_data_d(valid_set_preds, split='valid')
                self.plot_wandb_charts()
                del train_set_preds
                del valid_set_preds
                with timecode() as time_eval_gc_collect:
                    gc.collect() 
                with timecode() as time_eval_empty_cache:
                    torch.cuda.empty_cache()
                wandb.log({'time/eval_train_time': time_eval_train.t, 'time/eval_valid_time': time_eval_valid.t,
                           'time/eval_train_thoroughput': len(self.ds.dsd['train']) / time_eval_train.t,
                           'time/eval_valid_thoroughput': len(self.ds.dsd['valid']) / time_eval_valid.t,
                           'time/eval_update_training_summary_table': time_update_training_summary_table.t, 
                           'time/eval_add_preds_to_data_d': time_add_eval_preds_to_data_d.t,
                           'time/eval_gc_collect': time_eval_gc_collect.t, 
                           'time/eval_empty_cache': time_eval_empty_cache.t,
                           'epoch': self.epoch}, commit=True)

        self.logger.info(f"Now doing test eval")        
        # Eval on test set 
        test_set_preds = self.eval_dl(dl_tkn = self.ds.dld_tkn['test'], dl_raw=self.ds.dld_raw['test'])
        self.add_preds_to_data_d(test_set_preds, split='test')

        # Data -> df and save dfs to file.
        # Keys are the splits, 'training_summary' and sometimes 'training_step', so this
        # loop also writes training_summary.csv (no separate write needed).
        for key in self.data_d.keys():
            self.data_d[key] = pd.DataFrame(self.data_d[key]) # dict of list of dict -> dict of dataframe
            self.data_d[key].to_csv(f"{self._cfg.path_run}{key}.csv", index=False)

        # plot_wandb_charts()  # don't think I need this
        # NOTE(review): add_wandb_run_summary_statistics is neither defined nor imported in the visible code.
        add_wandb_run_summary_statistics(self.run)     

In [None]:
# Trainer.__init__ requires `logger` as its final argument.
trainer = Trainer(cfg, vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, accelerator, ds, logger)

In [None]:
# NOTE(review): the Trainer class defined in this notebook has no `test1` method, so the
# output shown for this cell presumably comes from an earlier version of the class.
trainer.test1()

0
1
2
3
4
5
6
7
8
9


In [None]:

    
        
        
    def training_step(self, data, raw, accelerator=None): 
        """Run one training batch, with gradient accumulation.

        Generates paraphrases, computes the loss (scaled by 1/accumulation_steps),
        backpropagates through `self.accelerator`, and steps the optimizer once every
        `accumulation_steps` batches. Optionally records per-example results in the
        'training_step' data store. Timings are logged to wandb with commit=False.

        `accelerator` is unused (the method uses `self.accelerator`); it is kept, now
        optional, so existing positional callers and `training_step(data, raw)` both work.
        """
        log_step_table = self._cfg.wandb['log_training_step_table']
        with timecode() as time_generate_pp:
            pp_output, pp_l = self.pp_model_forward(data)

        #logger.info(show_gpu(f'Batch {self.batch_num}, GPU memory usage after forward pass: '))

        with self.accelerator.autocast():
            with timecode() as time_loss_fn:
                if log_step_table: 
                    results_d = self.loss_fn(data, raw, pp_output, pp_l, return_components=True)
                    loss_batch = results_d['loss_batch']
                else: 
                    loss_batch = self.loss_fn(data, raw, pp_output, pp_l, return_components=False)

            loss_batch = loss_batch / self._cfg.accumulation_steps  # Normalize our loss for gradient accumulation

        with timecode() as time_backwards:
            self.accelerator.backward(loss_batch) 

        # Timings that are always available; optional ones are added below only when
        # the corresponding branch actually ran.
        log_d = {'time/generate_pp': time_generate_pp.t, 'time/loss_fn': time_loss_fn.t, 
                 'time/backwards_pass': time_backwards.t}

        #logger.info(show_gpu(f'Batch {self.batch_num}, GPU memory usage after backwards pass: '))
        if (self.accumulation_num + 1) % self._cfg.accumulation_steps == 0: 
            with timecode() as time_opt_step:
                # NOTE(review): `optimizer` is resolved as a module/notebook-level global;
                # Trainer does not receive an optimizer in __init__.
                optimizer.step()
            self.pp_model.zero_grad(set_to_none=self._cfg.zero_grad_with_none)
            log_d['time/optimizer_step'] = time_opt_step.t

        if log_step_table: 
            with timecode() as time_add_to_training_step_table:
                results_d = self.process_results_d1(results_d, raw)
                self.add_preds_to_data_d(results_d, split='training_step') 
            log_d['time/add_to_training_step_table'] = time_add_to_training_step_table.t

        # orig_length / orig_batch_size / pp_length / pp_batch_size are module-level
        # globals set by pp_model_forward for the current batch.
        log_d.update({'epoch': self.epoch, 'global_step': self.global_step, 'batch_num': self.batch_num,
                      'orig_length': orig_length, 'orig_batch_size': orig_batch_size,
                      'pp_length': pp_length, 'pp_batch_size': pp_batch_size})
        wandb.log(log_d, commit=False)

    def get_paraphrases(self, input_ids, attention_mask):
        """Wrapper for generating paraphrases (pp's). Most keywords are passed on to
        self.pp_model.generate_with_grad, so see docs for that function.

        Returns `(pp_output, pp_l)`: the raw generation output and the decoded
        paraphrase strings (special tokens skipped).
        """
        # Only greedy search supported at the moment
        # NOTE(review): `self.pp_model_params` is not set by the visible Trainer.__init__ — confirm where it comes from.
        pp_output = self.pp_model.generate_with_grad(input_ids=input_ids, 
                                                     attention_mask=attention_mask, 
                                                     **self.pp_model_params,
                                                     do_sample=False, 
                                                     return_dict_in_generate=True,
                                                     output_scores=True,
                                                     remove_invalid_values=False, 
                                                     pad_token_id = self.pp_tokenizer.pad_token_id,
                                                     eos_token_id = self.pp_tokenizer.eos_token_id)
        pp_l = self.pp_tokenizer.batch_decode(pp_output.sequences, skip_special_tokens=True)
        return pp_output, pp_l

    def get_pp_logp(self, pp_output, log_times=True): 
        """log(p(pp|orig)) basically.
        works for greedy search, will need tweaking for other types probably

        Takes the generation output (`.sequences`, `.scores`) and returns one summed
        token log-probability per generated sequence, ignoring padding positions.
        Optionally logs token-entropy / token-probability metrics to wandb.

        Relies on the module-level globals `orig_batch_size`, `pp_batch_size` and
        `pp_length` set by `pp_model_forward` for the current batch.
        NOTE(review): `vocab_size`, `check_scores_log_softmax_sums`,
        `get_start_end_special_token_ids` and `check_no_nans_or_infs` are not defined or
        imported in the visible code; `self.pp_model_params` is not set by the visible
        Trainer.__init__.
        """
        ### TODO: this looks like logp to me, not plogp. Find out if this is right and if so rename, if not, fix
        ### We want to align tokens with token probabilities. The first token is given at the start 
        # and has no probability attached to it, so we remove it. 
        seq_without_first_tkn = pp_output.sequences[:, 1:]
        assert seq_without_first_tkn.shape == torch.Size([orig_batch_size, pp_length - 1])

        ### Convert from tuple of scores to one big tensor of scores 
        scores_stacked = torch.stack(pp_output.scores, 1)
        ### TESTS 
        # We check shape and that there is no +inf or nan in scores. 
        # Scores can have -inf in them - see explanation in `exploring_generation`.  
        assert scores_stacked.shape == torch.Size([orig_batch_size, (pp_length - 1), vocab_size])
        assert torch.all(~torch.isnan(scores_stacked))
        assert torch.all(~torch.isposinf(scores_stacked))
        # Rough check that all idx before min_length are -inf for all elements in batch
        # We do min_length - 1 because sequences are allowed to have length min_length so that idx 
        # shouldn't be set to -inf
        # Not a 100% test but very likely to identify
        idx_neginf = torch.nonzero(torch.isneginf(scores_stacked))
        assert len(idx_neginf[idx_neginf[:,2] == self.pp_tokenizer.eos_token_id, :]) == \
                  (self.pp_model_params["min_length"] -1) * orig_batch_size  
        del idx_neginf

        ### Take log softmax of scores and then extract those that correspond 
        # to the generated sequences    
        scores_log_softmax = scores_stacked.log_softmax(2)
        seq_token_log_probs = torch.gather(scores_log_softmax,2,seq_without_first_tkn[:,:,None]).squeeze(-1)
        ### TESTS 
        # -inf is possible in scores_log_softmax and seq_token_log_probs before the attention mask is added. 
        assert torch.all(~torch.isnan(   scores_log_softmax))
        assert torch.all(~torch.isposinf(scores_log_softmax))
        check_scores_log_softmax_sums(scores_log_softmax)
        # probs should be 1-1 with the filtered tkns: check shape to confirm
        assert seq_token_log_probs.shape == seq_without_first_tkn.shape  
        # Check that the last token probability corresponds to a possible end token
        # this has to be tested before the attention mask is multiplied with it because if the 
        # padding token is 0 then this will be 0 too (and not the same as scores_log_softmax)
        output_end_ids = get_start_end_special_token_ids(self.pp_tokenizer)['output_end_id']
        assert all([o in scores_log_softmax[:, -1, output_end_ids] for o in seq_token_log_probs[:,-1]])
        del output_end_ids
        ## THIS ONE IS LONG - a test rather than assert 
        # check_seq_token_log_prob_values_are_correct(seq_without_first_tkn, scores_log_softmax, 
        #                                             seq_token_log_probs) 

        ### Generate attention mask to identify padding tokens. Then apply it to the 
        # sequence probabilities so that we don't consider probability of padding tokens 
        # when getting sequence probabilities. 
        # Also replace the -inf values in seq_token_log_probs with a large negative number because if we 
        # leave them in we end up with nan's introduced after multiplying with attention_mask, 
        # since  -inf * 0 = nan 
        attention_mask = self.pp_model._prepare_attention_mask_for_generation(
            seq_without_first_tkn, self.pp_tokenizer.pad_token_id, self.pp_tokenizer.eos_token_id
        )
        seq_token_log_probs = torch.nan_to_num(seq_token_log_probs, nan=None, posinf=None, neginf=-10000)
        seq_token_log_probs = seq_token_log_probs * attention_mask
        ### TESTS
        assert seq_token_log_probs.shape == attention_mask.shape
        # check attention mask only has 0 for padding tokens and not eos tokens or anything else
        assert torch.all(seq_without_first_tkn[attention_mask == 0] == self.pp_tokenizer.pad_token_id)
        check_no_nans_or_infs(seq_token_log_probs)
        # check that we aren't picking extremely rare tokens
        assert torch.all(seq_token_log_probs  > -10)  

        ### Get sequence probabilities by summing up token log probabilities 
        seq_log_prob = seq_token_log_probs.sum(-1)
        ## TESTS 
        assert seq_log_prob.shape == torch.Size([pp_batch_size])
        check_no_nans_or_infs(seq_log_prob)

        # Flags are read from the wandb config dict, as done elsewhere for 'log_grads' /
        # 'log_training_step_table'.
        # NOTE(review): confirm the key names 'log_token_entropy' and 'log_token_probabilities'.
        if self._cfg.wandb['log_token_entropy']:
            with timecode() as time_log_entropy:
                ent_d = self._get_entropy_metrics(scores_stacked, attention_mask)
            ent_d['time/log_entropy'] = time_log_entropy.t
            if log_times:   # need a better way to handle this. 
                wandb.log(ent_d, commit=False)

        if self._cfg.wandb['log_token_probabilities']: 
            with timecode() as time_log_token_probabilities:
                token_prob_d = self._get_token_probability_metrics(scores_log_softmax, attention_mask, k=3)
            token_prob_d['time/log_token_probabilities'] = time_log_token_probabilities.t
            if log_times: 
                wandb.log(token_prob_d, commit=False)

        return seq_log_prob

    def reward_fn(self, data, raw, pp_l, return_components=False, log_times=True): 
        """Compute the reward for each paraphrase in `pp_l`.

        `pp_l` is the list of paraphrases; the originals are in `raw['text']`.
        The reward combines the drop in the victim model's true-label probability
        (`vm_score`) with the semantic similarity to the original (`sts_score`):
        -0.5 when sts < 0.5, otherwise 0.5 + vm_score * sts.

        Returns a dict with only "reward" unless `return_components` is True, in which
        case the intermediate quantities are included as well.
        """
        # Victim model probability differences between orig and pp
        with timecode() as time_vm_scores:
            pp_probs, pp_predclass = get_vm_probs(pp_l, self._cfg, self.vm_tokenizer, self.vm_model,
                                                  return_predclass=True)
            # squeeze(1) rather than squeeze() so a batch of one stays 1-D
            pp_truelabel_probs   = torch.gather(pp_probs, 1, data['label'][:,None]).squeeze(1)
            pp_predclass_probs   = torch.gather(pp_probs, 1, pp_predclass[ :,None]).squeeze(1)
            label_flip = (pp_predclass != data['label']) * 1
            vm_scores = (data['orig_truelabel_probs'] - pp_truelabel_probs)

        # STS scores
        with timecode() as time_sts_scores:
            # batch_size is the number of paraphrases (len(raw) would count the keys of the raw batch)
            pp_embeddings   = self.sts_model.encode(pp_l, batch_size=len(pp_l), convert_to_tensor=True, device=self._cfg.device)
            # Row-wise cosine similarity between each original and its own paraphrase, i.e. the
            # diagonal of the full cosine-similarity matrix.
            # assumes both embedding tensors are (batch, dim) — TODO confirm
            sts_scores = torch.nn.functional.cosine_similarity(data['orig_sts_embeddings'], pp_embeddings, dim=1)

        # Reward calculation 
        rewards = torch.tensor([-0.5 if sts < 0.5 else 0.5+v*sts for v,sts in zip(vm_scores, sts_scores)],device=self._cfg.device)

        if log_times:
            wandb.log({'epoch': self.epoch, 'global_step': self.global_step, 
                       'time/vm_scores': time_vm_scores.t, 'time/sts_scores': time_sts_scores.t }, 
                       commit=False)

        if return_components: 
            return {
                "orig_l": raw['text'],
                "pp_l": pp_l,  
                "truelabel": data['label'],
                "orig_truelabel_probs": data['orig_truelabel_probs'],
                "pp_truelabel_probs":  pp_truelabel_probs,
                "pp_predclass": pp_predclass,
                "pp_predclass_probs": pp_predclass_probs,
                "vm_score": vm_scores, 
                "sts_score": sts_scores,
                "reward": rewards,
                "label_flip": label_flip
            }
        else:  return {"reward": rewards}

    def pp_model_forward(self, data): 
        """Generate paraphrases for a tokenised batch and record the batch dimensions.

        Returns `(pp_output, pp_l)` from `self.get_paraphrases`. As a side effect the
        module-level globals `orig_batch_size`, `orig_length`, `pp_batch_size` and
        `pp_length` are set from the shapes of `data['input_ids']` and
        `pp_output.sequences`.
        """
        global orig_batch_size,orig_length,pp_batch_size,pp_length
        orig_batch_size, orig_length = data['input_ids'].shape[0], data['input_ids'].shape[1]
        pp_output, pp_l = self.get_paraphrases(data['input_ids'], data['attention_mask']) 
        # for greedy search pp_batch_size is equal to orig_batch_size but this won't be for beam search
        pp_batch_size, pp_length = pp_output.sequences.shape[0], pp_output.sequences.shape[1]
        return pp_output, pp_l

    def loss_fn(self, data, raw, pp_output, pp_l, return_components=False, log_times=True): 
        """Loss for a batch: mean over examples of -reward * log p(paraphrase).

        Returns the scalar batch loss when `return_components` is False; otherwise a
        dict with the reward components plus 'pp_logp', 'loss' (both detached) and
        'loss_batch'.
        """
        with timecode() as time_reward_fn:
            d = self.reward_fn(data, raw, pp_l, return_components=return_components, log_times=log_times)

        # NOTE(review): assumes the flag lives at cfg.normalise_rewards — confirm.
        if self._cfg.normalise_rewards: 
            # keep a copy of the unnormalised reward
            d['orig_reward'] = d['reward'].detach().clone()
            # NOTE(review): std is 0 (or nan for a batch of one) when all rewards are equal, which gives nan/inf here.
            d['reward'] = (d['reward']-torch.mean(d['reward']))/torch.std(d['reward'])

        with timecode() as time_pp_logp:
            d['pp_logp'] = self.get_pp_logp(pp_output,log_times=log_times)

        with timecode() as time_loss_fn_loss_calc:
            d['loss'] = -d['reward'] * d['pp_logp']
            d['loss_batch'] = torch.mean(d['loss'])
            if return_components ==  False: return d['loss_batch'] 

        # remove some items from compgraph
        with timecode() as time_loss_fn_detach:
            d['pp_logp'] = d['pp_logp'].detach()  
            d['loss']    = d['loss'].detach()

    #     # This was taking a lot of time so removed it. Add it in if needed. 
    #     #gc.collect() 

        if log_times:   # true for training, not eval. 
            wandb.log({'epoch': self.epoch, 'global_step': self.global_step, 
                       'time/reward_fn': time_reward_fn.t, 'time/pp_logp': time_pp_logp.t, 
                       'time/loss_fn_loss_calc': time_loss_fn_loss_calc.t, 
                       'time/pp_logp_detach': time_loss_fn_detach.t, 
                      }, 
                     commit=False)
        return d

    def process_results_d1(self, results_d, raw): 
        """REFACTOR THIS LATER

        Prepare a training-step results dict for the data store: add 'epoch' and 'idx',
        convert tensors to Python lists, and repeat scalar ints/floats once per example
        in the batch. The dict is modified in place and returned.
        """
        results_d['epoch'] = self.epoch
        results_d['idx'] = raw['idx']
        # Use the actual size of this batch; the last batch can be smaller than the
        # configured training batch size.
        n_examples = len(raw['idx'])
        for k,v in results_d.items(): 
            if torch.is_tensor(v): 
                results_d[k] = v.detach().cpu().tolist()
            elif type(v) == int or type(v) == float: 
                # make into list repeated n times
                results_d[k] = [v] * n_examples
        return results_d    
        
    def _get_entropy_metrics(self, scores_stacked, attention_mask): 
        ent = Categorical(logits = scores_stacked).entropy().detach()
        assert ent.shape == attention_mask.shape == torch.Size([pp_batch_size, pp_length - 1])

        ent = ent * attention_mask  # stop values after eos token from contributing to ent score 
        # first remove structure (otherwise we have ragged arrays)
        # then remove corresponding attention mask values
        # we can't just filter by ent[ent != 0] because we might have zero tokens during the sequence
        att_flat= attention_mask.flatten()
        indices = torch.nonzero(att_flat)
        ent_flat = ent.flatten()[indices].flatten()
        assert ent_flat.shape[0] == (torch.sum(att_flat)*1).item()
        # check everything we filter out is zero 
        torch.isclose(ent.flatten()[torch.nonzero(~(att_flat > 0))].sum(), torch.tensor(0.), 1e-3)
        ent_d = dict(
            ent_min             = ent_flat.quantile(0).item(),
            ent_lower_quartile  = ent_flat.quantile(0.25).item(), 
            ent_median          = ent_flat.median().item(), 
            ent_mean            = ent_flat.mean().item(), 
            ent_upper_quartile  = ent_flat.quantile(0.75).item(), 
            ent_max             = ent_flat.quantile(1).item(), 
            self.epoch=self.epoch, self.global_step=self.global_step
        )
        return ent_d

    def _get_token_probability_metrics(self, scores_log_softmax, attention_mask, k=3): 
        token_prob_d = dict()
        tkn_kmaxprob, tkn_kmaxidx = torch.topk(scores_log_softmax,largest=True,  k=k, dim=2)
        tkn_kmaxprob = tkn_kmaxprob.detach()  # log these 
        assert tkn_kmaxprob.shape == torch.Size([pp_batch_size, pp_length - 1, k])

        # % of first prob over 0.9, 0.75, 0.5, 0.3, 0.1
        top_probs = tkn_kmaxprob[:,:,0].exp()
        top_probs = (top_probs * attention_mask).flatten()
        top_probs = top_probs[top_probs != 0]
        prob_threshold_l = [0.99, 0.975, 0.95, 0.90, 0.75, 0.5, 0.3, 0.1]
        for p in prob_threshold_l: 
            token_prob_d[f"top_token_prob_over_{str(p)}"] = (torch.sum(top_probs > p) / top_probs.shape[0]).item()

        # avg + median + lower + upper quartile of first, second, third choice probs
        tkn_kmaxprob_mask = tkn_kmaxprob * attention_mask[:,:,None]  # broadcasting over kth dim
        for i in range(k): 
            probs = tkn_kmaxprob_mask[:,:, i].flatten()
            probs = probs[probs != 0]
            token_prob_d[f"rank_{i+1}_token_prob_mean"] = probs.mean().item()
            token_prob_d[f"rank_{i+1}_token_prob_median"] = probs.median().item()
            token_prob_d[f"rank_{i+1}_token_prob_0.25_quantile"] = probs.quantile(0.25).item()
            token_prob_d[f"rank_{i+1}_token_prob_0.75_quantile"] = probs.quantile(0.75).item()

        # tokens over probs above 0.1, 0.01, 0.001, 0.0001, 1/vocab_size prob 
        allprobs = (scores_log_softmax.detach().exp() * attention_mask[:,:,None]).flatten()
        allprobs = allprobs[allprobs != 0]
        for p in [0.1, 0.01, 0.001, 0.0001, 0.00001]: 
            token_prob_d[f"%_of_tokens_above_prob_{p}"] =  (torch.sum(allprobs > p) / allprobs.shape[0]).item()
        token_prob_d[f"%_of_tokens_above_prob_1/vocab_size"] = \
            (torch.sum(allprobs > (1/vocab_size)) / allprobs.shape[0]).item()

        token_prob_d['self.epoch'] = self.epoch
        token_prob_d['self.global_step'] = self.global_step
        return token_prob_d
    
    
    
    def table2df(table):
        """Convert a wandb table (anything exposing `.data` and `.columns`) into a pandas DataFrame."""
        rows, column_names = table.data, table.columns
        return pd.DataFrame(data=rows, columns=column_names)

    def process_results_d_for_wandb(self, results_d): 
        """Flatten the per-batch entries of `results_d` into flat Python lists.

        Each value is handled by type: scalars (int/float) and empty lists are left
        alone; lists of tensors are stacked (0-dim) or concatenated and converted to a
        list; lists of lists are chained together; any other list is left alone.
        The dict is modified in place and returned.
        """
        for k,v in results_d.items(): 
            if type(v) == float or type(v) == int: 
                continue
            if type(v) == list and len(v) == 0:   # nothing to flatten (and v[0] would fail)
                continue
            # v[0] is arbitrary - we are just checking the first item in the list to see the type
            if torch.is_tensor(v[0]): 
                # case where we have a list of scalars - the cat function doesn't work here 
                if  v[0].size() == torch.Size([]): x = torch.stack(v)
                else:                              x = torch.cat(v)
                results_d[k] = x.detach().cpu().squeeze().tolist()  # convert to list (squeeze is for single scalar list)
            elif type(v[0]) == list:  # has to go after the is_tensor check
                results_d[k] = [item for batch in v for item in batch]
            elif type(v) == list: 
                continue
            else: 
                raise Exception("shouldn't get here")
        return results_d

    def eval_dl(self, dl_tkn, dl_raw): 
        """Get evaluation metrics for a dataloader.

        `dl_tkn` and `dl_raw` are iterated in lockstep (tokenised batch, raw batch).
        Models are put in eval mode and everything runs without gradient tracking.
        Returns a dict of flattened per-example results (see
        `process_results_d_for_wandb`) with the scalar 'epoch' added.
        """
        # Put models in eval mode and do the forward pass 
        # Current logic: push all batches together into one big list.     
        if self.pp_model.training: self.pp_model.eval()
        if self.vm_model.training: self.vm_model.eval()
        results_d = defaultdict(list)
        with torch.no_grad(): 
            for data, raw in zip(dl_tkn, dl_raw):
                # move every tensor of the batch onto the configured device
                for k in data: 
                    if data[k].device != self._cfg.device: data[k] = data[k].to(self._cfg.device)

                pp_output, pp_l = self.pp_model_forward(data)
                d = self.loss_fn(data, raw, pp_output, pp_l, return_components=True, log_times=False)
                d['idx'] = raw['idx']

                for k,v in d.items(): 
                    results_d[k].append(v) 
                # drop per-batch references; done inside the loop so an empty dataloader
                # does not hit undefined names
                del pp_output, pp_l, d
        results_d = self.process_results_d_for_wandb(results_d)

        # Calculate additional metrics 
        results_d['epoch'] = self.epoch
        return results_d

    def add_preds_to_data_d(self, results_d, split):
        """Append one row per example from `results_d` to `self.data_d[split]`.

        Only the columns in `self.table_columns` are kept (this filters out e.g.
        loss_batch). Raises if `split` is not a key of `self.data_d` or is
        "training_summary" (that table is handled by `update_training_summary_table`).
        """
        if split not in self.data_d.keys() or split == "training_summary": # training summary table logic is elsewhere
            raise Exception("split not in table keys or split == training_summary ") 

        # Need epoch to be repeated to the same length as the rest of the fields 
        # (this isn't the batch size because we concat a bunch of stuff)
        # we don't want to change the `epoch` key because it screws up logging of the other metrics. 
        # So we make a new dict (shallow copy: the caller's dict keeps its own 'epoch').
        d1 = dict(results_d)
        d1['epoch'] = [self.epoch for o in range(len(d1['pp_l']))]
        dcols = [d1[c] for c in self.table_columns]  # filter out loss_batch
        assert len(set([len(o) for o in dcols])) == 1  # all lists should be of the same length 

        for row in zip(*dcols):
            self.data_d[split].append(dict(zip(self.table_columns, row)))

    def update_training_summary_table(results_d, split):
        """Append one summary row for `split` at the current epoch to `data_d['training_summary']`.

        The row holds `epoch`, `split`, and for every entry of `metrics` a `<metric>_avg`
        value computed as `np.mean(results_d[metric])`. `epoch`, `metrics` and `data_d` are
        taken from the enclosing scope.
        """
        # Key names must line up with those in summary_table_columns
        summary_row = {'epoch': epoch, 'split': split}
        summary_row.update({f'{metric}_avg': np.mean(results_d[metric]) for metric in metrics})
        data_d['training_summary'].append(summary_row)

#     def log_wandb_tables(run): 
#         """Log wandb tables to the UI"""
#         d = dict()
#         d["eval/training_summary_table"] = table_d['training_summary']
#       #  print(len(d["eval/training_summary_table"].data))
#         run.log(d)


    def plot_wandb_charts(self): 
        """Build the per-example and summary charts and log them to wandb.

        When `self._cfg.wandb_plot_examples` is truthy, for each of the 'train' and 'valid'
        splits the rows of `data_d[split]` whose `idx` is in `plt_idx_d[split]` are sorted by
        (`idx`, `epoch`) and one chart per metric is logged with `commit=False`.
        Afterwards one summary chart per metric is logged from `data_d['training_summary']`;
        only the log call for the last entry of `metrics` uses `commit=True`.

        `data_d`, `plt_idx_d` and `metrics` are taken from the enclosing scope.
        """
        if self._cfg.wandb_plot_examples: 
            # Examples charts 
            # NOTE: the loop variable must stay named `split`; the query string refers to it via `@split`.
            for split in ['train', 'valid']:
                split_data = data_d[split]
                # A split may be held either as a list of row dicts or as an existing DataFrame
                if type(split_data) is list:
                    df = pd.DataFrame(split_data)
                else:
                    df = split_data
                df = df.query("idx in @plt_idx_d[@split]").sort_values(['idx', 'epoch'])
                for metric in metrics: 
                    examples_table = wandb.Table(dataframe=df)
                    chart = plot_examples_chart(split, table=examples_table, metric=metric)
                    wandb.log({f"individual_examples/{split}_{metric}_vs_epoch_examples": chart}, commit=False)

        ## Summary charts 
        for metric in metrics: 
            summary_df = pd.DataFrame(data_d['training_summary'])
            chart = plot_summary_charts(metric, table=wandb.Table(dataframe=summary_df))
            # Only the final metric commits, so everything logged above lands together
            is_last_metric = bool(metric == metrics[len(metrics) - 1])
            wandb.log({f"summary_charts/avg_{metric}_vs_epoch": chart}, commit=is_last_metric)

            
    def add_wandb_run_summary_statistics(run):
        """Write end-of-run summary statistics into `run.summary`.

        The best epoch is the one whose 'valid' row in `data_d['training_summary']` has the
        lowest `loss_avg`. For that epoch, every `<metric>_avg` value of the 'train' and
        'valid' rows is stored as `<metric>_avg_train` / `<metric>_avg_valid`, and the epoch
        itself as `best_epoch`. The column means of `data_d['test']` (restricted to
        `metrics`) are stored as `<metric>_avg_test`.

        Args:
            run: object with a dict-like `summary` attribute that is assigned into.

        Note:
            `data_d` and `metrics` are taken from the enclosing scope. `data_d['test']`
            must already be a DataFrame here, since `.filter(...)` is called on it.
        """
        ## Training summary statistics 
        df_summary = pd.DataFrame(data_d['training_summary']) 
        # We calculate the best epoch according to the validation set.
        # idxmin() returns an index *label*, so the row is fetched with label-based .loc;
        # positional .iloc only agrees with it while the index is a default RangeIndex.
        best_epoch_idx = df_summary.query("split=='valid'")['loss_avg'].idxmin() 
        valid_row = df_summary.loc[best_epoch_idx]
        best_epoch = valid_row['epoch'].item()
        run.summary['best_epoch'] = best_epoch
        # iloc transforms 1row df to series (so it is same as  valid_row)
        # NOTE: the query string refers to the local `best_epoch` via `@best_epoch`.
        train_row = df_summary.query("split=='train' & epoch==@best_epoch").iloc[0]  
        for metric in metrics: 
            run.summary[f"{metric}_avg_train"] = train_row[f"{metric}_avg"].item()
            run.summary[f"{metric}_avg_valid"] = valid_row[f"{metric}_avg"].item()

        ## Summary statistics of the test set 
        # From the last epoch atm because we don't have early stopping 
        test_metrics = data_d['test'].filter(metrics, axis=1).mean()
        for metric, val in test_metrics.items(): 
            run.summary[f"{metric}_avg_test"] = val
    
    
    def get_start_end_special_token_ids(tokenizer): 
        """The token ids that input/output sequences should start and end with.

        Args:
            tokenizer: object exposing `name_or_path`, `bos_token_id`, `pad_token_id`
                and `eos_token_id`.

        Returns:
            dict with keys `input_start_id`, `input_end_id`, `output_start_id` and
            `output_end_id`. The `*_end_id` entries are two-element lists
            `[pad_token_id, eos_token_id]` (either is an acceptable last token);
            `input_start_id` is None when inputs have no fixed start token.

        Raises:
            Exception: if `tokenizer.name_or_path` is not one of the supported models.
        """
        d = {}
        if tokenizer.name_or_path in ['eugenesiow/bart-paraphrase', 'tdopierre/ProtAugment-ParaphraseGenerator']: 
            d["input_start_id"] =  tokenizer.bos_token_id
            d["input_end_id"] =  [tokenizer.pad_token_id, tokenizer.eos_token_id]
            d["output_start_id"] =  tokenizer.eos_token_id 
            d["output_end_id"] =  [tokenizer.pad_token_id, tokenizer.eos_token_id]
        elif tokenizer.name_or_path == "tuner007/pegasus_paraphrase":
            d["input_start_id"] =  None
            d["input_end_id"] =  [tokenizer.pad_token_id, tokenizer.eos_token_id] 
            d["output_start_id"] =  tokenizer.pad_token_id
            d["output_end_id"] =  [tokenizer.pad_token_id, tokenizer.eos_token_id]
        else: 
            raise Exception("unrecognised tokenizer")
        return d



    def check_no_nans_or_infs(x):
        """Assert that tensor `x` contains no NaN, no -inf and no +inf entries."""
        assert not torch.isnan(x).any()
        assert not torch.isneginf(x).any()
        assert not torch.isposinf(x).any()

    def assert_start_and_end_tokens_are_correct(tokenizer, orig_token_ids, pp_token_ids):
        """Make sure input sequences (orig) and output sequences (pp) start and end with the 
        right special tokens (depends on tokenizer).

        Args:
            tokenizer: tokenizer passed to `get_start_end_special_token_ids` to look up
                the expected start/end ids.
            orig_token_ids: 2-D tensor of input token ids; column 0 and column -1 are checked.
            pp_token_ids: 2-D tensor of output token ids; column 0 and column -1 are checked.

        Raises:
            AssertionError: if any row starts or ends with an unexpected token id.
        """
        # Use the tokenizer that was passed in rather than a tokenizer from the enclosing
        # scope, so the check really does depend on the argument.
        start_end_token_d = get_start_end_special_token_ids(tokenizer)

        # Input
        # Some tokenizers have no fixed input start token (None), in which case it is skipped
        if start_end_token_d['input_start_id'] is not None: 
            assert torch.all(orig_token_ids[:,0] == start_end_token_d['input_start_id'])
        # The last input token may be either of the two allowed end ids
        orig_last = orig_token_ids[:,-1]
        assert torch.all(torch.logical_or(orig_last == start_end_token_d['input_end_id'][0], 
                                          orig_last == start_end_token_d['input_end_id'][1]))

        # Output
        assert torch.all(pp_token_ids[:,0] == start_end_token_d['output_start_id'])
        pp_last = pp_token_ids[:,-1]
        assert torch.all(torch.logical_or(pp_last == start_end_token_d['output_end_id'][0], 
                                          pp_last == start_end_token_d['output_end_id'][1]))

    def check_scores_log_softmax_sums(scores_log_softmax):
        """Assert that the exponentiated log-softmax scores form valid probability distributions.

        The scores are exponentiated and summed over dim 2. The result must have
        `orig_batch_size` rows and `pp_length - 1` columns, and every entry must be 1
        (within atol=1e-4). `orig_batch_size`, `pp_length` and `self` are taken from the
        enclosing scope.
        """
        prob_totals = scores_log_softmax.exp().sum(2)
        # Check the reduction ran over the right axis: summing the token probabilities of
        # each generation step should leave a shape of [orig_batch_size, pp_length - 1]
        n_examples = prob_totals.shape[0]
        assert n_examples == orig_batch_size  
        n_steps = prob_totals.shape[1]
        assert n_steps == pp_length - 1
        # Each generation step's probabilities must sum to 1
        expected_totals = torch.ones(prob_totals.size(), device=self._cfg.device)
        assert torch.allclose(prob_totals, expected_totals, atol = 1e-4)

    def check_seq_token_log_prob_values_are_correct(seq_without_first_tkn, scores_log_softmax, seq_token_log_probs): 
        """Exhaustively verify the gathered per-token log probabilities.

        For every example and generation step, the entry of `scores_log_softmax` at the
        token id found in `seq_without_first_tkn` must equal the matching entry of
        `seq_token_log_probs`. Loops over `orig_batch_size` x `pp_length - 1` positions
        (both taken from the enclosing scope) in Python, so it is slow for large batches:
        run it as a test rather than as an assert in every batch.
        """
        matches = [
            scores_log_softmax[i_ex, i_step, seq_without_first_tkn[i_ex][i_step].item()]
            == seq_token_log_probs[i_ex, i_step]
            for i_ex in range(orig_batch_size)
            for i_step in range(pp_length - 1)
        ]
        assert all(matches)    

In [None]:
#hide
# Run nbdev's notebook2script to convert the project's notebooks
# (see the "Converted ..." lines in the cell output).
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 30_logging.ipynb.
Converted index.ipynb.
Converted run.ipynb.
