In [None]:
# Pick up edits to imported modules (e.g. travis_attack) without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Makes the %lprun line-by-line profiling magic available.
%load_ext line_profiler

In [None]:
## Imports, logger and run-level constants
import argparse
import os
from datetime import datetime

import torch
import wandb
from fastcore.basics import in_jupyter, patch_to

from travis_attack.utils import set_seed, set_session_options, setup_logging, setup_parser, display_all
from travis_attack.config import Config
from travis_attack.models import prepare_models, get_optimizer
from travis_attack.data import ProcessedDataset
from travis_attack.trainer import Trainer
#from travis_attack.insights import (postprocess_df, create_and_log_wandb_postrun_plots, get_training_dfs)

import logging 
logger = logging.getLogger("paraphrase_eval")

# Output directory for baseline results, and a timestamp for this run
# formatted as YYYY-MM-DD_HHMMSS.
path_baselines = "./pp_eval_baselines/"
datetime_now = f"{datetime.now():%Y-%m-%d_%H%M%S}"

In [None]:
def setup_pp_eval_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the paraphrase-evaluation script.

    No option declares a default, so each one parses to ``None`` when it is
    not passed on the command line. This lets callers distinguish "not given"
    from an explicit value.

    Returns:
        argparse.ArgumentParser: parser accepting ``--dataset_name``,
        ``--split``, ``--sts_threshold`` (float) and
        ``--contradiction_threshold`` (float).
    """
    parser = argparse.ArgumentParser(description="Compute paraphrase-evaluation baselines.")
    parser.add_argument("--dataset_name",
                        help="Dataset to evaluate, e.g. rotten_tomatoes, financial or simple.")
    parser.add_argument("--split",
                        help="Data split to evaluate, e.g. train or valid.")
    parser.add_argument("--sts_threshold", type=float,
                        help="STS similarity threshold (float).")
    parser.add_argument("--contradiction_threshold", type=float,
                        help="NLI contradiction threshold (float).")
    return parser

In [None]:
######### CONFIG (default values) #########
# Default run settings. When executed as a script (not inside Jupyter), any
# option passed on the command line replaces the matching default below;
# options that were not passed (parsed as None) keep their default.
d = {
    "datetime": datetime_now,
    "dataset_name": "rotten_tomatoes",
    "split": "valid",
    "sts_threshold": 0.7,
    "contradiction_threshold": 0.2,
}
###########################################

if not in_jupyter():
    cli_overrides = vars(setup_pp_eval_parser().parse_args())
    d.update({key: val for key, val in cli_overrides.items() if val is not None})

In [None]:
# Build the run config: start from Config() defaults, overlay the values in `d`,
# then apply the dataset-specific adjustment.
cfg = Config()
for k, v in d.items():
    setattr(cfg, k, v)

# Supported dataset names mapped to the Config method that adjusts the config for them.
_DATASET_ADJUSTERS = {
    "rotten_tomatoes": "adjust_config_for_rotten_tomatoes_dataset",
    "financial":       "adjust_config_for_financial_dataset",
    "simple":          "adjust_config_for_simple_dataset",
}
if cfg.dataset_name not in _DATASET_ADJUSTERS:
    # Fail fast: an unknown name (e.g. a typo in --dataset_name) would otherwise
    # skip the adjustment silently and run with an unadjusted config.
    raise ValueError(
        f"Unknown dataset_name {cfg.dataset_name!r}; expected one of {sorted(_DATASET_ADJUSTERS)}"
    )
getattr(cfg, _DATASET_ADJUSTERS[cfg.dataset_name])()

if cfg.use_small_ds:
    cfg = cfg.small_ds()
set_seed(cfg.seed)
set_session_options()
setup_logging(cfg, disable_other_loggers=True)
vm_tokenizer, vm_model, pp_tokenizer, pp_model, ref_pp_model, sts_model, nli_tokenizer, nli_model, cfg = prepare_models(cfg)

In [None]:
# Assemble the evaluation pipeline. load_processed_from_file=True asks
# ProcessedDataset to load the processed dataset from file (the captured output
# below reports a cache hit); initial_eval=False presumably skips an evaluation
# pass when the Trainer is constructed — confirm in Trainer.
optimizer = get_optimizer(cfg, pp_model)
ds = ProcessedDataset(
    cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model,
    load_processed_from_file=True,
)
trainer = Trainer(
    cfg,
    vm_tokenizer, vm_model,
    pp_tokenizer, pp_model, ref_pp_model,
    sts_model,
    nli_tokenizer, nli_model,
    optimizer,
    ds,
    initial_eval=False,
    use_cpu=False,
)
    

travis_attack.data: INFO     Will load dataset rotten_tomatoes with use_small_ds set to False
travis_attack.data: INFO     Will load dataset rotten_tomatoes with use_small_ds set to False
travis_attack.data: INFO     Cache file found for processed dataset, so loading that dataset.
travis_attack.data: INFO     Cache file found for processed dataset, so loading that dataset.


In [None]:
@patch_to(Trainer)
def get_ref_model_baseline(self, split=None):
    """Run a no-grad forward and loss pass over one split, storing per-batch results.

    The split's tokenized and raw dataloaders are zipped together. For each batch:
    - the tokenized tensors are moved to ``self._cfg.device``;
    - ``self._pp_model_forward`` and ``self._loss_fn`` are called (the loss value is
      discarded);
    - the batch variables are gathered with ``self._add_batch_vars_to_batch_d``;
    - ``self.batch_d`` is appended to ``self.data_d[split]``.

    NOTE(review): despite the name, the forward pass goes through
    ``self._pp_model_forward`` and only ``self.pp_model`` is switched to eval mode,
    not ``self.ref_pp_model``. Presumably this relies on ``pp_model`` still matching
    the reference model (no training has run in this notebook) — confirm.

    Args:
        split: Split to evaluate (e.g. ``"train"`` or ``"valid"``). Defaults to
            ``self._cfg.split``. ``"train"`` is read through the ``"train_eval"``
            dataloader.

    Raises:
        ValueError: if ``split`` is not given and ``self._cfg`` has no ``split``.
    """
    # Previously the body read an undefined global `split`; take it as an argument,
    # falling back to the value stored on the config.
    if split is None:
        split = getattr(self._cfg, "split", None)
    if split is None:
        raise ValueError("No split given and `cfg.split` is not set.")

    # Only build GPU-usage messages when they will actually be logged.
    debug = logger.isEnabledFor(logging.DEBUG)
    if debug:
        # NOTE(review): assumes show_gpu is provided by travis_attack.utils — confirm.
        from travis_attack.utils import show_gpu

    # Current logic: push all batches together into one big list.
    self._reset_batch_dicts()
    # Put models in eval mode for the forward pass.
    if self.pp_model.training: self.pp_model.eval()
    if self.vm_model.training: self.vm_model.eval()
    # The "train_eval" dataloader is the same as train but a bigger batch size and explicitly no shuffling
    dl_key = "train_eval" if split == "train" else split
    dl_raw = self.ds.dld_raw[dl_key]
    dl_tkn = self.ds.dld_tkn[dl_key]
    device = self._cfg.device
    with torch.no_grad():
        for self.batch_num, (data, raw) in enumerate(zip(dl_tkn, dl_raw)):
            logger.debug(f"EVAL: {split} with dl_key {dl_key}")
            logger.debug(f"Elements in data_d[{split}]: {len(self.data_d[split])}")
            if debug:
                logger.debug(show_gpu(f'EVAL, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after loading data: '))
            assert data['input_ids'].shape[0] == len(raw['text_with_prefix'])
            self._reset_batch_dicts()
            assert len(self.batch_d) == len(self.batch_time_d) == len(self.batch_wandb_d) == 0
            # Eval data isn't loaded on GPU by default unlike train data. This is because the train
            # dataloader goes through the accelerator `prepare` function, but eval dataloaders don't.
            # Reassigning existing keys while iterating is safe (the dict size doesn't change).
            for k, v in data.items():
                if v.device != device: data[k] = v.to(device)
            pp_output, pp_l = self._pp_model_forward(data)
            _ = self._loss_fn(data, raw, pp_output, pp_l)
            self._add_batch_vars_to_batch_d(raw, data, pp_l)
            self.data_d[split].append(self.batch_d)
            if debug:
                logger.debug(show_gpu(f'EVAL, epoch {self.epoch}, batch {self.batch_num}, GPU memory usage after loss_fn pass: '))

    
    

    
# Raw dataset and raw dataloader for the configured split. "train" is read via
# the "train_eval" dataloader (same data, bigger batches, no shuffling).
eval_split = d['split']
dl_key = 'train_eval' if eval_split == 'train' else eval_split
dataset = ds.dsd_raw[eval_split]
dataloader = ds.dld_raw[dl_key]

In [None]:
from fastcore.basics import patch_to

In [None]:
# Display the raw dataloader selected above for the evaluation split.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x2abba6658250>