In [None]:
#default_exp data

In [None]:
#hide
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
#export
import torch, random, pandas as pd, os, warnings, shutil, uuid
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from IPython.display import display, HTML
from travis_attack.models import get_vm_probs
from travis_attack.config import Config
from travis_attack.utils import robust_rmtree, timecode
from IPython.core.debugger import set_trace
import logging
logger = logging.getLogger("travis_attack.data")


In [None]:
#hide
import inspect


# Preparing data

## Classes 

### Base class

In [None]:
#export
class ProcessedDataset: 
    """Wrap a raw dataset (e.g. from huggingface datasets) and perform preprocessing on it.

    After construction the object exposes:
      * `dsd_raw`: DatasetDict holding only the `idx`, original-text and label columns.
      * `dsd_tkn`: DatasetDict holding every processed column except the original text
        (VM scores, STS embeddings, tokenizer outputs, lengths, ...).
      * `dld_raw` / `dld_tkn`: dicts of DataLoaders over `dsd_raw` / `dsd_tkn`, keyed by split
        name plus an extra `train_eval` key.
    Dataset and dataloader statistics are also written back onto the config object (see `_update_cfg`).
    """
    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model,
                 load_processed_from_file=True): 
        """Load the processed dataset from the cache or build it, then prepare dataloaders.

        Args:
            cfg: Config object. Read for dataset name, paths, batch sizes and preprocessing options;
                updated in place with dataset/dataloader statistics.
            vm_tokenizer, vm_model: victim-model tokenizer and model, used to score the original text.
            pp_tokenizer: paraphrase-model tokenizer, used to tokenize the original text.
            sts_model: sentence-similarity model, used to embed the original text.
            load_processed_from_file: if True, load the completed version from the cache when both
                cache directories exist (otherwise warn and process the raw data). If False, always
                process the data, overwriting the cache.
        """
        self._cfg,self._vm_tokenizer,self._vm_model,self._pp_tokenizer,self._sts_model = cfg,vm_tokenizer,vm_model,pp_tokenizer,sts_model
        # Sharded (small) datasets get their own cache directories so they don't clobber the full ones.
        shard_suffix = f"_{self._cfg.n_shards}_shards" if self._cfg.use_small_ds else ""
        self.cache_path_raw = f"{self._cfg.path_data_cache}{self._cfg.dataset_name}_raw{shard_suffix}"
        self.cache_path_tkn = f"{self._cfg.path_data_cache}{self._cfg.dataset_name}_tkn{shard_suffix}"
        
        logger.info(f"Will load dataset {self._cfg.dataset_name} with use_small_ds set to {self._cfg.use_small_ds}")
        
        if load_processed_from_file:
            if os.path.exists(self.cache_path_raw) and os.path.exists(self.cache_path_tkn):
                logger.info("Cache file found for processed dataset, so loading that dataset.")
                self.dsd_raw = load_from_disk(self.cache_path_raw) 
                self.dsd_tkn = load_from_disk(self.cache_path_tkn)
                self._prep_dataloaders()
            else: 
                warnings.warn("Cache file not found, so will now process the raw dataset.")
                self._preprocess_dataset() 
        else:   
            self._preprocess_dataset() 
        self._update_cfg()
        
        logger.debug(f"Dataset lengths: {self._cfg.ds_length}")
        logger.debug(f"Total training steps: {self._cfg.n_train_steps}")
        logger.debug(f"Last batch size in each epoch is: {self._cfg.dl_leftover_batch_size}")
        logger.debug(f"Dataloader batch sizes are: {self._cfg.dl_batch_sizes}") 
            
    def _prep_dsd_simple(self): 
        """Load the simple dataset and package it up in a DatasetDict (dsd) 
        with splits for train, valid, test.

        Each split `s` in `cfg.splits` is read from `<cfg.path_data>simple_dataset_<s>.csv`.
        """
        dsd = DatasetDict()
        for s in self._cfg.splits:  
            # load_dataset('csv', ...) always returns a single 'train' split, so take that.
            dsd[s] = load_dataset('csv', 
                data_files=f"{self._cfg.path_data}simple_dataset_{s}.csv", keep_in_memory=False)['train']
        return dsd
        
    def _prep_dsd_rotten_tomatoes(self):
        """Load the rotten tomatoes dataset and package it up in a DatasetDict (dsd) 
        with splits for train, valid, test."""
        dsd = load_dataset("rotten_tomatoes")
        dsd['valid'] = dsd.pop('validation')  # "valid" is easier than "validation" 
        # make sure that all datasets have the same number of labels as what the victim model predicts
        for ds in dsd.values(): assert ds.features[self._cfg.label_cname].num_classes == self._cfg.vm_num_labels 
        return dsd 
    
    def _prep_dsd_raw_snli(self): 
        """Placeholder for SNLI support; always raises NotImplementedError."""
        ## For snli
        # remove_minus1_labels = lambda x: x[label_cname] != -1
        # ds_train = ds_train.filter(remove_minus1_labels)
        # valid = valid.filter(remove_minus1_labels)
        # test = test.filter(remove_minus1_labels)
        raise NotImplementedError("SNLI not implemented yet.")
    
    def _preprocess_dataset(self): 
        """Add columns, tokenize, transform, prepare dataloaders, and do other preprocessing tasks.

        Sets `self.dsd_raw`, `self.dsd_tkn`, `self.cnames_dsd_raw` and `self.cnames_dsd_tkn`,
        writes both DatasetDicts to the cache and builds the dataloaders.
        Raises Exception if `cfg.dataset_name` is not 'simple' or 'rotten_tomatoes'.
        """
        if   self._cfg.dataset_name == "simple":          dsd = self._prep_dsd_simple()
        elif self._cfg.dataset_name == "rotten_tomatoes": dsd = self._prep_dsd_rotten_tomatoes()
        else: raise Exception("cfg.dataset_name must be either 'simple' or 'rotten_tomatoes'")
        text_cname, label_cname = self._cfg.orig_cname, self._cfg.label_cname
        dsd = dsd.map(self._add_idx, batched=True, with_indices=True)  # add idx column
        if self._cfg.use_small_ds: dsd = self._shard_dsd(dsd)  # do after adding idx so it's consistent across runs
        # add VM score & filter out misclassified examples.
        # use a common variable dsd, add all columns, and later filter columns to get dsd_raw and dsd_tkn
        dsd = dsd.map(self._add_vm_orig_score, batched=True)  
        if self._cfg.remove_misclassified_examples:
            dsd = dsd.filter(lambda x: x['orig_vm_predclass'] == x[label_cname]) 
        dsd = dsd.map(self._add_sts_orig_embeddings, batched=True)  # add STS embeddings
        dsd = dsd.map(self._tokenize_fn,             batched=True)  # tokenize
        dsd = dsd.map(self._add_n_tokens,            batched=True)  # add n_tokens
        dsd = dsd.map(self._add_n_letters,           batched=True)  # add n_letters
        if self._cfg.bucket_by_length: dsd = dsd.sort("n_tokens", reverse=True)  # sort by n_tokens (high to low), useful for cuda memory caching
        # Split dsd into dsd_raw and dsd_tkn. All splits must share one schema for this to be valid.
        cnames_per_split = list(dsd.column_names.values())
        assert all(cnames == cnames_per_split[0] for cnames in cnames_per_split), "All splits must have the same columns"
        all_cnames = cnames_per_split[0]
        self.cnames_dsd_raw = ['idx', text_cname, label_cname]
        self.cnames_dsd_tkn = [o for o in all_cnames if o != text_cname] 
        self.dsd_raw = dsd.remove_columns([o for o in all_cnames if o not in self.cnames_dsd_raw])
        self.dsd_tkn = dsd.remove_columns([text_cname])
        for s in self._cfg.splits: assert len(self.dsd_raw[s]) == len(self.dsd_tkn[s])  # check ds has same number of elements in raw and tkn
        self._cache_processed_ds()
        self._prep_dataloaders()
        
    def _prep_dataloaders(self): 
        """Build `self.dld_raw` and `self.dld_tkn` from `self.dsd_raw` and `self.dsd_tkn`."""
        self.dld_raw = self._get_dataloaders_dict(self.dsd_raw, collate_fn=self._collate_fn_raw)  # dict of data loaders that serve raw text
        self.dld_tkn = self._get_dataloaders_dict(self.dsd_tkn, collate_fn=self._collate_fn_tkn)  # dict of data loaders that serve tokenized text
        
    def _add_idx(self, batch, idx):
        """Add row numbers (the indices passed by `map(..., with_indices=True)`) as an `idx` column."""
        batch['idx'] = idx 
        return batch   
    
    def _add_n_tokens(self, batch): 
        """Add the number of tokens of the orig text (length of each `input_ids` list)."""
        batch['n_tokens'] = [len(o) for o in batch['input_ids']]
        return batch 
    
    def _add_n_letters(self, batch): 
        """Add the number of characters of the orig text."""
        batch['n_letters'] = [len(o) for o in batch[self._cfg.orig_cname]]
        return batch
    
    def _add_sts_orig_embeddings(self, batch): 
        """Add the sts embeddings of the orig text (as plain lists, not tensors)."""
        batch['orig_sts_embeddings'] = self._sts_model.encode(batch[self._cfg.orig_cname], batch_size=64, convert_to_tensor=False)
        return batch
    
    def _add_vm_orig_score(self, batch): 
        """Add the victim model's probability of the true label and its predicted class for the orig text."""
        labels = torch.tensor(batch[self._cfg.label_cname], device=self._cfg.device)
        orig_probs,orig_predclass = get_vm_probs(batch[self._cfg.orig_cname], self._cfg, self._vm_tokenizer,
                                                 self._vm_model, return_predclass=True)
        # Pick each row's true-label probability. squeeze(1) (not squeeze()) keeps a 1-D result even
        # for a batch of one row, so .tolist() always returns a list the same length as the batch.
        batch['orig_truelabel_probs'] = torch.gather(orig_probs, 1, labels[:, None]).squeeze(1).cpu().tolist()
        batch['orig_vm_predclass'] = orig_predclass.cpu().tolist()
        return batch
    
    def _tokenize_fn(self, batch):  
        """Tokenize a batch of orig text using the paraphrase tokenizer (truncated to cfg.orig_max_length)."""
        return self._pp_tokenizer(batch[self._cfg.orig_cname], truncation=True, max_length=self._cfg.orig_max_length)  
    
    def _collate_fn_tkn(self, x): 
        """Collate function used by the DataLoader that serves tokenized data. 

        x is a list (with length batch_size) of dicts that must all have the same keys (an
        AssertionError is raised otherwise). The values are gathered per key and padded with
        `pp_tokenizer.pad` to a multiple of `cfg.orig_padding_multiple`, returned as PyTorch tensors.
        """
        # check all keys are the same in the list. the assert is quick (~1e-5 seconds)
        for o in x: assert set(o) == set(x[0])
        d = {k: [o[k] for o in x] for k in x[0].keys()}
        return self._pp_tokenizer.pad(d, pad_to_multiple_of=self._cfg.orig_padding_multiple, return_tensors="pt")

    def _collate_fn_raw(self, x): 
        """Collate function used by the DataLoader that serves raw data.

        x is a list of dicts with identical keys; returns a dict mapping each key to the list of values.
        """
        for o in x: assert set(o) == set(x[0])  # check all keys are the same in list
        return {k: [o[k] for o in x] for k in x[0].keys()}

    def _get_sampler(self, ds): 
        """Return a RandomSampler over `ds` driven by a generator seeded with `cfg.seed`.

        A fresh, identically-seeded generator is made on every call, so the raw and tokenized
        train dataloaders (same length) shuffle in the same order. Used when cfg.shuffle_train = True.
        """
        g = torch.Generator()
        g.manual_seed(self._cfg.seed)  # NOTE(review): assumes Config defines `seed`
        return RandomSampler(ds, generator=g)
    
    def _shard_dsd(self, dsd):
        """Replace each split of dsd (in place) with shard 0 of cfg.n_shards shards, and return dsd."""
        for k,v in dsd.items():  
            dsd[k] = v.shard(self._cfg.n_shards, 0, contiguous=self._cfg.shard_contiguous)
        return dsd
        
    def _get_dataloaders_dict(self, dsd, collate_fn): 
        """Prepare a dict of dataloaders for each split of dsd, plus a 'train_eval' loader.

        The 'train' loader is shuffled (via `_get_sampler`) when cfg.shuffle_train is set and is
        sequential otherwise. Every other split is sequential and uses cfg.batch_size_eval.
        'train_eval' serves the train split sequentially with cfg.batch_size_eval.
        Raises Exception if both cfg.bucket_by_length and cfg.shuffle_train are set (shuffling would
        undo the length ordering).
        """
        if self._cfg.bucket_by_length and self._cfg.shuffle_train:  raise Exception("Can only do one of bucket by length or shuffle")
        loader_kwargs = dict(collate_fn=collate_fn, num_workers=self._cfg.n_wkrs, pin_memory=self._cfg.pin_memory)
        d = dict()
        for split, ds in dsd.items(): 
            if split == "train" and self._cfg.shuffle_train:
                d[split] = DataLoader(ds, batch_size=self._cfg.batch_size_train,
                                      sampler=self._get_sampler(ds), **loader_kwargs)
            else:
                # No shuffling: keeps any length-bucketed order, and also covers the case where
                # neither option is set (previously that case produced no split loaders at all).
                batch_size = self._cfg.batch_size_train if split == "train" else self._cfg.batch_size_eval
                d[split] = DataLoader(ds, batch_size=batch_size, shuffle=False, **loader_kwargs)

        # Add eval dataloader for train: same as train but bigger batch size and explicitly no shuffling.
        d['train_eval'] = DataLoader(dsd['train'], batch_size=self._cfg.batch_size_eval, shuffle=False, **loader_kwargs)
        return d 
    
    def _update_cfg(self): 
        """Write dataset/dataloader statistics onto the config.

        Sets on cfg:
          * ds_length: rows per split of dsd_raw.
          * dl_n_batches: number of batches per dataloader in dld_raw.
          * dl_leftover_batch_size: size of each loader's final partial batch (0 if batches divide evenly).
          * dl_batch_sizes: list of every batch size per loader.
          * n_train_steps: cfg.n_train_epochs * number of 'train' batches.
        """
        cfg = self._cfg
        cfg.ds_length,cfg.dl_n_batches,cfg.dl_leftover_batch_size,cfg.dl_batch_sizes = dict(),dict(),dict(),dict()

        def get_dl_batch_sizes(batch_size, n_batches, leftover):
            """All batches are full except (if leftover != 0) the last, which has `leftover` rows."""
            if leftover == 0:
                return [batch_size] * n_batches
            return [batch_size] * (n_batches - 1) + [leftover]

        for k,v in self.dsd_raw.items(): cfg.ds_length[k] = len(v)   # Dataset lengths
        for k,dl in self.dld_raw.items(): 
            ds_k = "train" if k == "train_eval" else k  # train_eval serves the train split
            batch_size = cfg.batch_size_train if k == "train" else cfg.batch_size_eval
            cfg.dl_n_batches[k] = len(dl)   # Dataloader lengths 
            cfg.dl_leftover_batch_size[k] = cfg.ds_length[ds_k] % batch_size
            cfg.dl_batch_sizes[k] = get_dl_batch_sizes(batch_size, cfg.dl_n_batches[k], cfg.dl_leftover_batch_size[k])
            
        # Total number of training steps
        cfg.n_train_steps = cfg.n_train_epochs * cfg.dl_n_batches['train']
    
    def _cache_processed_ds(self):
        """Save dsd_raw and dsd_tkn to their cache paths, first moving any existing cache dirs aside."""
        def _reset_dir(path): 
            """Move an existing `path` into `<path_data_cache>old_files/<random hex>`, then recreate it empty."""
            if os.path.exists(path) and os.path.isdir(path):    
                # Deleting the old files (e.g. with robust_rmtree) sometimes threw errors, probably because
                # of race conditions, so as a workaround move them to an old-files directory to be cleaned
                # periodically.
                path_old_files = f"{self._cfg.path_data_cache}old_files/"
                os.makedirs(path_old_files, exist_ok=True)
                shutil.move(path, f"{path_old_files}{uuid.uuid4().hex}") 
            os.makedirs(path, exist_ok=True)
        _reset_dir(self.cache_path_raw)
        _reset_dir(self.cache_path_tkn)
        self.dsd_raw.save_to_disk(dataset_dict_path = self.cache_path_raw)
        self.dsd_tkn.save_to_disk(dataset_dict_path = self.cache_path_tkn)
        
    def show_random_elements(self, ds, num_examples=10):
        """Display `num_examples` distinct, randomly chosen rows of `ds` as an HTML table.

        `ds` is a single split from the `datasets` package (e.g. `self.dsd_raw['train']`).
        ClassLabel columns are shown with their label names instead of integer ids.
        Raises AssertionError if `num_examples` exceeds the number of rows.
        """
        assert num_examples <= len(ds), "Can't pick more elements than there are in the dataset."
        picks = random.sample(range(len(ds)), num_examples)  # distinct indices without a rejection loop
        df = pd.DataFrame(ds[picks])
        for column, typ in ds.features.items():
            if isinstance(typ, ClassLabel):
                df[column] = df[column].transform(lambda i: typ.names[i])
        display(HTML(df.to_html()))

## Usage 

### Basics 

Here we have defined a class `ProcessedDataset` that will load and preprocess a dataset. But before processing the dataset you must load both the config object and all models/tokenizers, so we do this first. 

In [None]:
from travis_attack.models import prepare_models
cfg = Config()
# prepare_models returns the tokenizers/models and also a config object, which replaces `cfg` here.
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)

Currently there are two choices for dataset: 

* `simple`, a dataset of simple sentences with four elements each in the train, test and valid splits
* `rotten_tomatoes`, a dataset of movie reviews scraped from the Rotten Tomatoes site. 

The dataset is specified by the config class. There are two ways to do this. 

1. Set the `dataset_name` attribute in the `Config` class to either `simple` or `rotten_tomatoes`. An error is raised if the name is anything else. This is the best approach for full runs.   
2. Use the `adjust_config_for_..._dataset()` methods of the config class, e.g. `cfg = Config().adjust_config_for_rotten_tomatoes_dataset()`. This is easiest for automated testing, so we do this here. 

Once the config is specified and loaded, create a `ProcessedDataset` object by passing it the config, models and tokenizers. All preprocessing happens automatically when the object is created: `__init__()` either loads a cached version or calls `_preprocess_dataset()`. Pass `load_processed_from_file=False` to force reprocessing.  

In [None]:
cfg_simple = cfg.adjust_config_for_simple_dataset()
# load_processed_from_file=False: always reprocess (ignoring any cache) and overwrite the cache.
ds = ProcessedDataset(cfg_simple, vm_tokenizer, vm_model, pp_tokenizer, sts_model, 
                      load_processed_from_file=False)

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

If you want to use the small dataset, adjust the config (e.g. with `small_ds()`) before creating the `ProcessedDataset` object. 

In [None]:
# Rotten tomatoes, reduced to one shard per split via small_ds(); reprocessed from scratch.
cfg_rt_small_ds = cfg.adjust_config_for_rotten_tomatoes_dataset().small_ds()
ds = ProcessedDataset(cfg_rt_small_ds, vm_tokenizer, vm_model, pp_tokenizer, sts_model,
                      load_processed_from_file=False)

Using custom data configuration default
Reusing dataset rotten_tomatoes_movie_review (/data/tproth/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

### Accessing datasets

You can access raw data with `ds.dsd_raw` and processed data with `ds.dsd_tkn` ("dsd" stands for "DatasetDict").

In [None]:
# Raw splits keep only the idx, text and label columns
ds.dsd_raw

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'idx'],
        num_rows: 154
    })
    test: Dataset({
        features: ['text', 'label', 'idx'],
        num_rows: 18
    })
    valid: Dataset({
        features: ['text', 'label', 'idx'],
        num_rows: 16
    })
})

In [None]:
# Processed splits keep every added column except the original text
ds.dsd_tkn

DatasetDict({
    train: Dataset({
        features: ['label', 'idx', 'orig_truelabel_probs', 'orig_vm_predclass', 'orig_sts_embeddings', 'input_ids', 'attention_mask', 'n_tokens', 'n_letters'],
        num_rows: 154
    })
    test: Dataset({
        features: ['label', 'idx', 'orig_truelabel_probs', 'orig_vm_predclass', 'orig_sts_embeddings', 'input_ids', 'attention_mask', 'n_tokens', 'n_letters'],
        num_rows: 18
    })
    valid: Dataset({
        features: ['label', 'idx', 'orig_truelabel_probs', 'orig_vm_predclass', 'orig_sts_embeddings', 'input_ids', 'attention_mask', 'n_tokens', 'n_letters'],
        num_rows: 16
    })
})

You can access elements by indexing: 

In [None]:
# Slicing a split returns a dict mapping each column to a list of values
ds.dsd_raw['valid'][0:2]

{'text': ['berling and béart . . . continue to impress , and isabelle huppert . . . again shows uncanny skill in getting under the skin of her characters .',
  "i'm not sure which will take longer to heal : the welt on johnny knoxville's stomach from a riot-control projectile or my own tortured psyche ."],
 'label': [1, 0],
 'idx': [400, 750]}

In [None]:
# Same rows from the processed split (matching idx values)
ds.dsd_tkn['valid'][0:2]

{'label': [1, 0],
 'idx': [400, 750],
 'orig_truelabel_probs': [0.9531264901161194, 0.8985466957092285],
 'orig_vm_predclass': [1, 0],
 'orig_sts_embeddings': [[0.10162420570850372,
   -0.15061667561531067,
   -0.02832075022161007,
   0.18470776081085205,
   -0.144411101937294,
   -0.02936149388551712,
   0.3306199908256531,
   0.07828425616025925,
   -0.030337156727910042,
   -0.14721500873565674,
   0.23712067306041718,
   -0.288776159286499,
   -0.07359367609024048,
   -0.28062987327575684,
   0.04253583401441574,
   0.07229325920343399,
   0.15865543484687805,
   0.22507694363594055,
   -0.143656387925148,
   -0.08837874233722687,
   -0.3990294337272644,
   -0.023890824988484383,
   0.05965694040060043,
   0.1094372570514679,
   0.11948876827955246,
   -0.30829957127571106,
   -0.10492965579032898,
   0.1315850019454956,
   0.11578385531902313,
   -0.45641112327575684,
   -0.06523998081684113,
   -0.08991163969039917,
   -0.30539849400520325,
   0.14492005109786987,
   -0.265550404

Alternatively, you can look at some random elements of a dataset with the `ds.show_random_elements()` method. 

In [None]:
# 3 random training rows as an HTML table; ClassLabel columns are shown by name (pos/neg)
ds.show_random_elements(ds.dsd_raw['train'], num_examples=3)

Unnamed: 0,text,label,idx
0,the affectionate loopiness that once seemed congenital to demme's perspective has a tough time emerging from between the badly dated cutesy-pie mystery scenario and the newfangled hollywood post-production effects .,neg,6150
1,a battle between bug-eye theatre and dead-eye matinee .,neg,6400
2,"it is a comedy that's not very funny and an action movie that is not very thrilling ( and an uneasy alliance , at that ) .",neg,6450


In [None]:
# Same for the processed dataset; long list columns (e.g. embeddings) are abbreviated in the display
ds.show_random_elements(ds.dsd_tkn['train'], num_examples=3)

Unnamed: 0,label,idx,orig_truelabel_probs,orig_vm_predclass,orig_sts_embeddings,input_ids,attention_mask,n_tokens,n_letters
0,neg,5900,0.812017,0,"[0.26082396507263184, 0.16894088685512543, -0.000696819624863565, 0.42390045523643494, 0.12588629126548767, -0.2456263303756714, 0.25331953167915344, -0.10000850260257721, -0.07515019178390503, -0.02511535957455635, 0.03716929629445076, -0.1675442010164261, 0.21177415549755096, 0.10612250864505768, -0.11261694133281708, 0.018164006993174553, 0.08856771141290665, -0.0724453255534172, -0.4939577579498291, 0.15241022408008575, 0.16478994488716125, 0.30741605162620544, 0.19510920345783234, -0.04479996860027313, 0.16139273345470428, -0.10825827717781067, -0.02531084045767784, -0.34616410732269287, -0.07442427426576614, -0.15140996873378754, 0.05133926495909691, -0.49399861693382263, -0.27262526750564575, -0.12282334268093109, 0.14226046204566956, 0.2190028727054596, 0.08068602532148361, 0.23904889822006226, 0.07611984759569168, -0.2089087814092636, -0.07901982218027115, 0.37364712357521057, -0.11654812097549438, 0.11241912841796875, -0.16392573714256287, -0.2552502155303955, -0.04766825586557388, -0.14634528756141663, -0.02990572713315487, -0.1224118322134018, -0.28583237528800964, -0.05020933970808983, 0.08913134038448334, -0.13715822994709015, 0.12391725927591324, -0.1633434295654297, 0.29116320610046387, 0.30185750126838684, -0.2698788046836853, 0.22278627753257751, 0.02050643600523472, 0.17432183027267456, 0.25925225019454956, -0.0584811270236969, 0.16076985001564026, -0.3483920991420746, -0.04350963607430458, 0.12565301358699799, 0.0007046845857985318, 0.7770388126373291, 0.14977848529815674, -0.3200789988040924, -0.11471429467201233, 0.1588069200515747, 0.2392265349626541, 0.06355749815702438, -0.11811453104019165, 0.017058227211236954, 0.1150524839758873, 0.015598387457430363, 0.23242153227329254, -0.08645912259817123, 0.37758150696754456, 0.18046456575393677, -0.29646632075309753, -0.17986299097537994, -0.017268827185034752, -0.1653253436088562, -0.07298078387975693, -0.14704126119613647, -0.08836515247821808, -0.2627905309200287, 
0.11998080462217331, 0.11300908774137497, 0.20445899665355682, 0.12975050508975983, -0.13014553487300873, 0.1109667420387268, -0.2775782644748688, 0.04304569587111473, ...]","[16478, 415, 131, 116, 6796, 112, 21024, 119, 115, 8704, 110, 108, 47011, 61377, 1759, 117, 19965, 110, 107, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",20,89
1,neg,6550,0.560361,0,"[-0.28279241919517517, -0.10911419987678528, 0.185452401638031, -0.31022754311561584, 0.2721474766731262, 0.143635094165802, 0.1588427573442459, 0.07887238264083862, -0.07462634146213531, -0.04302912577986717, 0.19168832898139954, 0.12744538486003876, 0.03302061930298805, 0.35184308886528015, 0.02633996494114399, -0.12195242941379547, 0.2733384966850281, -0.02968870848417282, 0.2663077116012573, 0.033906809985637665, 0.014440209604799747, 0.02836601994931698, 0.3171355128288269, 0.17385075986385345, -0.09732065349817276, -0.2866568863391876, 0.17569024860858917, -0.004056242294609547, -0.11474941670894623, -0.38688334822654724, 0.010693566873669624, 0.08699995279312134, 0.1914653778076172, -0.07150154560804367, -0.11033782362937927, -0.035256337374448776, 0.3646480143070221, -0.07560033351182938, 0.10404468327760696, -0.15902841091156006, -0.18942876160144806, -0.0454980842769146, 0.07817275822162628, -0.11899391561746597, -0.021510759368538857, -0.2692556381225586, 0.07824939489364624, -0.23629973828792572, -0.06936605274677277, -0.2710849642753601, 0.015071428380906582, 0.1451604962348938, -0.08288583159446716, -0.007211702410131693, 0.09828613698482513, -0.18080736696720123, 0.1366795301437378, -0.08440432697534561, 0.07456615567207336, -0.09058354794979095, -0.36064258217811584, -0.12619048357009888, 0.03774487227201462, -0.33537110686302185, 0.23234142363071442, -0.15576718747615814, -0.21634942293167114, -0.6222754716873169, 0.014109872281551361, -0.1254936158657074, -0.07818818837404251, 0.0686282068490982, 0.1632554680109024, 0.03296481445431709, 0.05164318159222603, -0.0860288068652153, 0.10972880572080612, 0.05365408584475517, 0.013793966732919216, -0.15731367468833923, -0.30272218585014343, -0.21046552062034607, 0.3018110990524292, -0.06533841788768768, -0.35051441192626953, -0.3832939565181732, -0.06097899377346039, 0.08307389169931412, -0.17937450110912323, 0.3539979159832001, -0.29797446727752686, 0.2458031326532364, 
0.3172275125980377, 0.2503081262111664, -0.0986585021018982, -0.25740835070610046, 0.11290610581636429, -0.01475759968161583, -0.3525673449039459, 0.19793599843978882, ...]","[3721, 1121, 117, 288, 109, 209, 474, 120, 131, 116, 1092, 2190, 115, 2453, 2092, 110, 108, 114, 896, 141, 109, 1348, 121, 14787, 12103, 1772, 3021, 8805, 143, 51265, 110, 158, 120, 2841, 130, 610, 372, 587, 113, 109, 4508, 5088, 113, 52182, 51167, 115, 109, 450, 121, 11513, 47414, 415, 278, 110, 107, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",56,230
2,pos,950,0.894819,1,"[-0.1295185387134552, 0.08725076913833618, 0.16249580681324005, 0.09270603209733963, -0.05479922890663147, -0.4042092561721802, 0.044919490814208984, -0.015336044132709503, -0.05520963668823242, -0.1714986115694046, 0.13645341992378235, -0.21943482756614685, -0.01515096053481102, -0.1524168848991394, -0.021992869675159454, -0.1730937957763672, 0.4978390336036682, -0.39657658338546753, -0.06187061592936516, 0.11838138103485107, 0.0286897923797369, -0.09177327156066895, -0.11579828709363937, 0.0384468212723732, 0.27413299679756165, -0.08362582325935364, -0.33973729610443115, 0.289218008518219, -0.34868723154067993, 0.010074760764837265, 0.41043534874916077, 0.2695227265357971, -0.47849276661872864, 0.02920924872159958, 0.17637252807617188, 0.02537361904978752, 0.25593501329421997, -0.2687920928001404, 0.19637033343315125, 0.10471389442682266, 0.5327776670455933, 0.26942288875579834, 0.09989957511425018, 0.0004130508750677109, -0.22754546999931335, -0.16906039416790009, 0.1634337306022644, -0.46608710289001465, -0.26222971081733704, 0.03476766496896744, 0.020135406404733658, 0.16166219115257263, -0.11212167143821716, 0.0611935555934906, 0.1479577124118805, -0.35471445322036743, -0.14004448056221008, -0.2564741373062134, -0.34316256642341614, -0.32408127188682556, 0.12188054621219635, -0.09297027438879013, 0.07313328981399536, -0.019537940621376038, -0.11266116797924042, -0.49935227632522583, -0.5597988963127136, 0.057982150465250015, 0.30742019414901733, -0.04193183407187462, 0.0877077728509903, 0.11447340250015259, 0.32417207956314087, 0.15870186686515808, -0.13938716053962708, 0.20942211151123047, 0.16791647672653198, -0.40716519951820374, -0.47167089581489563, -0.10710559040307999, -0.06997077912092209, -0.1667996644973755, -0.3509700298309326, -0.18834857642650604, -0.059045568108558655, -0.02201288379728794, 0.10729056596755981, -0.11800283938646317, 0.17891091108322144, 0.17431680858135223, -0.21734818816184998, 0.3496575355529785, 
0.34084752202033997, 0.1076493114233017, 0.3619900345802307, -0.01222921907901764, 0.14957047998905182, -0.015920229256153107, 0.03864670917391777, 0.060621462762355804, ...]","[14729, 113, 1316, 218, 129, 332, 154, 197, 372, 22245, 113, 85810, 19425, 110, 108, 155, 126, 131, 116, 309, 848, 6028, 110, 107, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",25,104


### Accessing DataLoaders 

We can access dataloaders for the raw text in `ds.dsd_raw` with `ds.dld_raw`, and for the tokenised text in `ds.dsd_tkn` with `ds.dld_tkn`. Both are dictionaries of dataloaders with keys `['train', 'valid', 'test', 'train_eval']`. `train_eval` serves the training split in order, with the evaluation batch size. 

In [None]:
# Pull the first training batch from both the raw and the tokenised dataloaders and compare them.
print("Dataloader dict has keys:", ds.dld_raw.keys())
batch_raw, batch_tkn = next(iter(ds.dld_raw['train'])), next(iter(ds.dld_tkn['train']))
for batch in (batch_raw, batch_tkn):
    print(batch.keys())
print("Tokenised input is of shape:", batch_tkn['input_ids'].shape)

Dataloader dict has keys: dict_keys(['train', 'test', 'valid', 'train_eval'])
dict_keys(['text', 'label', 'idx'])
dict_keys(['label', 'idx', 'orig_truelabel_probs', 'orig_vm_predclass', 'orig_sts_embeddings', 'input_ids', 'attention_mask', 'n_tokens', 'n_letters'])
Tokenised input is of shape: torch.Size([4, 56])


#hide
## Export 

In [None]:
#hide
# Convert the project notebooks' exported cells into library modules (see the "Converted ..." output)
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 02_tests.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 25_insights.ipynb.
Converted index.ipynb.
Converted run.ipynb.
Converted show_examples.ipynb.
