In [None]:
#default_exp data

In [None]:
#hide
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
#export
import torch, random, pandas as pd
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from IPython.display import display, HTML
from travis_attack.models import get_vm_probs
from travis_attack.config import Config

In [None]:
#hide
import inspect
from IPython.core.debugger import set_trace

# Preparing data

## Classes 

### Base class

In [None]:
#export
class ProcessedDataset:
    """Wrap a raw dataset (e.g. from huggingface `datasets`) and preprocess it.

    Constructing an object runs the whole pipeline: load the raw splits, add
    an `idx` column, score the original text with the victim model (vm),
    optionally drop misclassified rows, add STS embeddings, tokenize with the
    paraphrase (pp) tokenizer and build dataloaders.

    Attributes set by the pipeline:
        dsd_raw: DatasetDict holding the raw text (plus `idx`).
        dsd_tkn: DatasetDict holding the tokenized / annotated rows.
        dld_raw: dict of DataLoaders over `dsd_raw`.
        dld_tkn: dict of DataLoaders over `dsd_tkn`.
    `cfg.ds_length` is also filled with the number of rows per split.
    """
    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model):
        """Store the config/models, load the dataset named by `cfg.dataset_name` and preprocess it.

        Raises:
            ValueError: if `cfg.dataset_name` is not 'simple' or 'rotten_tomatoes'.
        """
        self._cfg,self._vm_tokenizer,self._vm_model,self._pp_tokenizer,self._sts_model = cfg,vm_tokenizer,vm_model,pp_tokenizer,sts_model
        if   self._cfg.dataset_name == "simple":          self._prep_dsd_raw_simple()
        elif self._cfg.dataset_name == "rotten_tomatoes": self._prep_dsd_raw_rotten_tomatoes()
        else: raise ValueError("cfg.dataset_name must be either 'simple' or 'rotten_tomatoes'")
        self._preprocess_dataset()
        self._update_cfg()

    def _prep_dsd_raw_simple(self):
        """Load the simple dataset and package it up in a DatasetDict (dsd)
        with one entry per split listed in `cfg.splits`."""
        self.dsd_raw = DatasetDict()
        for s in self._cfg.splits:
            # load_dataset('csv', ...) puts a single file under the 'train' key, so unwrap it.
            self.dsd_raw[s] = load_dataset('csv', data_files=f"{self._cfg.path_data}simple_dataset_{s}.csv")['train']

    def _prep_dsd_raw_rotten_tomatoes(self):
        """Load the rotten tomatoes dataset and package it up in a DatasetDict (dsd)
        with splits for train, valid, test."""
        self.dsd_raw = load_dataset("rotten_tomatoes")
        self.dsd_raw['valid'] = self.dsd_raw.pop('validation')  # "valid" is easier than "validation"
        # make sure that all datasets have the same number of labels as what the victim model predicts
        for ds in self.dsd_raw.values():
            assert ds.features[self._cfg.label_cname].num_classes == self._cfg.vm_num_labels

    def _prep_dsd_raw_snli(self):
        """Placeholder for SNLI support; always raises NotImplementedError."""
        ## For snli
        # remove_minus1_labels = lambda x: x[label_cname] != -1
        # ds_train = ds_train.filter(remove_minus1_labels)
        # valid = valid.filter(remove_minus1_labels)
        # test = test.filter(remove_minus1_labels)
        raise NotImplementedError("SNLI not implemented yet.")

    def _preprocess_dataset(self):
        """Add columns, tokenize, transform, prepare dataloaders, and do other preprocessing tasks."""
        ##### TODO: why do you need train_eval? how is it different from train?
        self.dsd_raw = self.dsd_raw.map(self._add_idx, batched=True, with_indices=True)  # add idx column
        if self._cfg.use_small_ds: self._shard_dsd_raw()  # do after adding idx so it's consistent across runs
        # add VM score & filter out misclassified examples. (at this point lengths are diff between dsd_raw and dsd_tkn)
        self.dsd_tkn = self.dsd_raw.map(self._add_vm_orig_score, batched=True)
        if self._cfg.remove_misclassified_examples:
            # Use the configured label column, matching _add_vm_orig_score.
            label_cname = self._cfg.label_cname
            self.dsd_tkn = self.dsd_tkn.filter(lambda x: x['orig_vm_predclass'] == x[label_cname])
        self.dsd_tkn = self.dsd_tkn.map(self._add_sts_orig_embeddings, batched=True)  # add STS score
        self.dsd_tkn = self.dsd_tkn.map(self._tokenize_fn,             batched=True)  # tokenize
        self.dsd_tkn = self.dsd_tkn.map(self._add_n_tokens,            batched=True)  # add n_tokens
        if self._cfg.bucket_by_length: self.dsd_tkn = self.dsd_tkn.sort("n_tokens", reverse=True)  # sort by n_tokens (high to low), useful for cuda memory caching
        # NOTE(review): dsd_raw is not sorted here, so with bucket_by_length the raw and
        # tokenized datasets hold the same rows but in a different order - confirm this is intended.

        # filter out rows in dsd_raw that aren't in dsd_tkn
        for s in self._cfg.splits:
            kept_idx = set(self.dsd_tkn[s]['idx'])  # set => O(1) membership test per row
            self.dsd_raw[s] = self.dsd_raw[s].filter(lambda x: x['idx'] in kept_idx)
        for s in self._cfg.splits: assert len(self.dsd_raw[s]) == len(self.dsd_tkn[s])  # check ds has same number of elements in raw and tkn

        # Prepare dataloaders
        self.dld_raw = self._get_dataloaders_dict(self.dsd_raw, collate_fn=self._collate_fn_raw)  # dict of data loaders that serve raw text
        self.dld_tkn = self._get_dataloaders_dict(self.dsd_tkn, collate_fn=self._collate_fn_tkn)  # dict of data loaders that serve tokenized text

    def _add_idx(self, batch, idx):
        """Add row numbers (the indices supplied by `map(..., with_indices=True)`) as column `idx`."""
        batch['idx'] = idx
        return batch

    def _add_n_tokens(self, batch):
        """Add the number of tokens of the orig text (length of each `input_ids` entry)."""
        batch['n_tokens'] = [len(o) for o in batch['input_ids']]
        return batch

    def _add_sts_orig_embeddings(self, batch):
        """Add the sts embeddings of the orig text"""
        batch['orig_sts_embeddings'] = self._sts_model.encode(batch[self._cfg.orig_cname], batch_size=64, convert_to_tensor=False)
        return batch

    def _add_vm_orig_score(self, batch):
        """Add the vm score of the orig text.

        Adds `orig_truelabel_probs` (victim-model probability of the true label)
        and `orig_vm_predclass` (victim-model predicted class) to the batch.
        """
        labels = torch.tensor(batch[self._cfg.label_cname], device=self._cfg.device)
        orig_probs,orig_predclass = get_vm_probs(batch[self._cfg.orig_cname], self._cfg, self._vm_tokenizer,
                                                 self._vm_model, return_predclass=True)
        # squeeze(1), not squeeze(): a bare squeeze collapses a batch of size 1 to a 0-d
        # tensor, and .tolist() would then return a float instead of a one-element list.
        batch['orig_truelabel_probs'] = torch.gather(orig_probs, 1, labels[:,None]).squeeze(1).cpu().tolist()
        batch['orig_vm_predclass'] = orig_predclass.cpu().tolist()
        return batch

    def _tokenize_fn(self, batch):
        """Tokenize a batch of orig text using the paraphrase tokenizer."""
        return self._pp_tokenizer(batch[self._cfg.orig_cname], truncation=True, max_length=self._cfg.orig_max_length)

    def _collate_fn_tkn(self, x):
        """Collate function used by the DataLoader that serves tokenized data.

        Gathers a fixed set of columns from the list of rows `x` and pads them
        with the paraphrase tokenizer, returning PyTorch tensors.
        """
        # NOTE(review): 'label' is hard-coded here while other methods use cfg.label_cname;
        # kept as-is because consumers of the batch read this key.
        d = dict()
        for k in ['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings']:
            d[k] = [o[k] for o in x]
        return self._pp_tokenizer.pad(d, pad_to_multiple_of=self._cfg.orig_padding_multiple, return_tensors="pt")

    def _collate_fn_raw(self, x):
        """Collate function used by the DataLoader that serves raw data."""
        d = dict()
        for k in [self._cfg.orig_cname, 'idx']: d[k] = [o[k] for o in x]
        return d

    def _get_sampler(self, ds):
        """Returns a RandomSampler. Used so we can keep the same shuffle order across multiple data loaders.
        Used when self._cfg.shuffle_train = True"""
        g = torch.Generator()
        # Seed from the config so every sampler built here starts from the same state.
        # (The original read an undefined global `seed`.) Assumes Config exposes `seed` - confirm.
        g.manual_seed(self._cfg.seed)
        return RandomSampler(ds, generator=g)

    def _shard_dsd_raw(self):
        """Replaces dsd_raw with a smaller shard of itself (shard 0 of `cfg.n_shards`)."""
        for k,v in self.dsd_raw.items():
            self.dsd_raw[k] = v.shard(self._cfg.n_shards, 0, contiguous=self._cfg.shard_contiguous)

    def _get_dataloaders_dict(self, dsd, collate_fn):
        """Prepare a dict of dataloaders, one per split in `dsd`, plus 'train_eval'.

        The 'train' loader uses `cfg.batch_size_train`; every other loader uses
        `cfg.batch_size_eval`. Only the 'train' loader is shuffled, and only when
        `cfg.shuffle_train` is set (via a seeded RandomSampler). 'train_eval'
        serves the train split unshuffled with the eval batch size.

        Raises:
            ValueError: if both `cfg.bucket_by_length` and `cfg.shuffle_train` are set.
        """
        if self._cfg.bucket_by_length and self._cfg.shuffle_train:  raise ValueError("Can only do one of bucket by length or shuffle")
        shared_kwargs = dict(collate_fn=collate_fn, num_workers=self._cfg.n_wkrs, pin_memory=self._cfg.pin_memory)
        d = dict()
        for split, ds in dsd.items():
            is_train = split == "train"
            batch_size = self._cfg.batch_size_train if is_train else self._cfg.batch_size_eval
            # Built for every split regardless of flags, so the dict is never missing
            # splits when neither shuffle_train nor bucket_by_length is enabled.
            sampler = self._get_sampler(ds) if (is_train and self._cfg.shuffle_train) else None
            d[split] = DataLoader(ds, batch_size=batch_size, shuffle=False, sampler=sampler, **shared_kwargs)

        # Add eval dataloader for train
        d['train_eval'] = DataLoader(dsd['train'], batch_size=self._cfg.batch_size_eval, shuffle=False, **shared_kwargs)
        return d

    def _update_cfg(self):
        """Record the number of rows of each split of `dsd_raw` in `cfg.ds_length`."""
        self._cfg.ds_length = {k: len(v) for k, v in self.dsd_raw.items()}

    def show_random_elements(self, ds, num_examples=10):
        """Print some elements in a nice format so you can take a look at them.
        Use for a dataset `ds` from the `datasets` package (e.g. `self.dsd_raw['train']`).
        ClassLabel columns are shown with their label names."""
        assert num_examples <= len(ds), "Can't pick more elements than there are in the dataset."
        picks = random.sample(range(len(ds)), num_examples)  # unique row positions
        df = pd.DataFrame(ds[picks])
        for column, typ in ds.features.items():
            if isinstance(typ, ClassLabel):
                names = typ.names
                df[column] = df[column].transform(lambda i: names[i])
        display(HTML(df.to_html()))

## Usage 

### Basics 

Here we have defined a class `ProcessedDataset` that will load and preprocess a dataset. But before processing the dataset you must load both the config object and all models/tokenizers, so we do this first. 

In [None]:
# Build the config first, then load every model/tokenizer that ProcessedDataset needs.
from travis_attack.models import prepare_models
cfg = Config()
# prepare_models also returns cfg, so rebind it here to pick up whatever it hands back.
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)

Currently there are two choices for dataset: 

* `simple`, a dataset of simple sentences with four elements each in the train, test and valid splits
* `rotten_tomatoes`, a dataset of movie reviews scraped from the Rotten Tomatoes site. 

The dataset is specified by the config class. There are two ways to do this. 

1. Edit the `self.dataset_name` variable in the Config class to either `simple` or `rotten_tomatoes`. An error will be thrown if the name is not one of these two. This is the best approach when doing real runs.   
2. Use the `adjust_config_for_...` methods of the config class, e.g. `cfg = Config().adjust_config_for_rotten_tomatoes_dataset()`. This is easiest for automated testing, so we will do this here. 

Once the config is specified and loaded, create an object of class `ProcessedDataset` by passing the config, models and tokenizers as arguments. All preprocessing is done automatically when the object is created (the preprocessing code is called from the `__init__()` function of the class).  

In [None]:
# Point the config at the `simple` dataset, then build the dataset object;
# all preprocessing runs inside the ProcessedDataset constructor.
cfg_simple = cfg.adjust_config_for_simple_dataset()
ds = ProcessedDataset(cfg_simple,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

If you want to use the small dataset, adjust the config before creating the `ProcessedDataset` object. 

In [None]:
# Switch the config to rotten_tomatoes and request the small variant of the dataset.
# NOTE(review): presumably `.small_ds()` turns on `use_small_ds` (sharding) - confirm in Config.
cfg_rt_small_ds = cfg.adjust_config_for_rotten_tomatoes_dataset().small_ds()
# Rebinds `ds`: from here on it refers to the rotten_tomatoes dataset, not the simple one.
ds = ProcessedDataset(cfg_rt_small_ds,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default
Reusing dataset rotten_tomatoes_movie_review (/data/tproth/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/e06abb624abab47e1a64608fdfe65a913f5a68c66118408032644a3285208fb5)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
cfg_rt_small_ds.ds_length

{'train': 196, 'test': 23, 'valid': 23}

### Accessing datasets

You can access raw data with `ds.dsd_raw` and processed data with the `ds.dsd_tkn`. (The dsd here stands for "DatasetDict")

In [None]:
ds.dsd_raw

DatasetDict({
    train: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 196
    })
    test: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 23
    })
    valid: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 23
    })
})

In [None]:
ds.dsd_tkn

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 196
    })
    test: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 23
    })
    valid: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 23
    })
})

You can access elements by indexing: 

In [None]:
ds.dsd_raw['valid'][0:2]

{'idx': [0, 40],
 'label': [1, 1],
 'text': ['compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
  'a perceptive , good-natured movie .']}

In [None]:
ds.dsd_tkn['valid'][0:2]

{'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]],
 'idx': [320, 520],
 'input_ids': [[0,
   405,
   18,
   479,
   479,
   479,
   966,
   5,
   1823,
   1351,
   7,
   192,
   41,
   3025,
   2156,
   202,
   2021,
   7,
   434,
   11,
   39,
   5127,
   2202,
   2156,
   464,
   150,
   2405,
   1528,
   7,
   39,
   7797,
   19,
   10,
   822,
   1060,
   182,
   2087,
   16,
   2156,
   1341,
   3273,
   352,
   2156,
   59,
  

Alternately you can look at some random elements of a dataset with the `ds.show_random_elements()` method. 

In [None]:
ds.show_random_elements(ds.dsd_raw['train'], num_examples=3)

Unnamed: 0,idx,label,text
0,3480,pos,"mostly honest , this somber picture reveals itself slowly , intelligently , artfully ."
1,4880,neg,"apparently designed as a reverie about memory and regret , but the only thing you'll regret is remembering the experience of sitting through it ."
2,3640,pos,"when you think you've figured out bielinsky's great game , that's when you're in the most trouble : he's the con , and you're just the mark ."


In [None]:
ds.show_random_elements(ds.dsd_tkn['train'], num_examples=3)

Unnamed: 0,attention_mask,idx,input_ids,label,n_tokens,orig_sts_embeddings,orig_truelabel_probs,orig_vm_predclass,text
0,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",8240,"[0, 405, 16, 1099, 2156, 53, 1819, 45, 396, 13032, 25, 4000, 479, 2]",neg,14,"[0.07409124821424484, -0.28555047512054443, -0.06796510517597198, -0.3252895772457123, 0.036855921149253845, 0.3267824649810791, -0.015050191432237625, 0.1567767858505249, 0.23650391399860382, 0.03273976594209671, -0.09359420835971832, 0.2742508053779602, 0.32041049003601074, 0.2548767924308777, -0.032434217631816864, -0.12088172882795334, 0.7475268840789795, -0.3853236436843872, 0.009074598550796509, 0.14479413628578186, 0.059270359575748444, 0.11770723760128021, 0.35336387157440186, -0.045286357402801514, -0.18193283677101135, -0.16301120817661285, 0.08815793693065643, -0.08909820765256882, -0.09062936156988144, 0.24540287256240845, -0.22299952805042267, -0.1411578357219696, 0.08367875963449478, -0.020001769065856934, -0.48106276988983154, -0.04740080237388611, -0.19724921882152557, 0.07344017922878265, -0.007313099689781666, -0.2797631621360779, 0.0035986825823783875, 0.20114031434059143, -0.09965568780899048, 0.020152201876044273, -0.14113999903202057, 0.1161009818315506, 0.02756551466882229, -0.043488629162311554, -0.09224851429462433, 0.005897983908653259, 0.03382517397403717, 0.1340554803609848, -0.040826257318258286, -0.2407383918762207, -0.13136343657970428, -0.08395872265100479, 0.30115050077438354, 0.0979730486869812, -0.39226090908050537, 0.1936008334159851, 0.14521317183971405, -0.3308068811893463, 0.06126178801059723, -0.2474205642938614, 0.22160477936267853, -0.25219130516052246, -0.23798421025276184, 0.16496378183364868, -0.4071105122566223, 0.042752884328365326, 0.20724201202392578, -0.08050593733787537, 0.3702399730682373, 0.37275031208992004, -0.1345374584197998, 0.15251857042312622, 0.3488125801086426, -0.393639475107193, 0.07935017347335815, -0.21165278553962708, 0.15413376688957214, -0.057682741433382034, 0.06879235059022903, -0.12775708734989166, 0.10400110483169556, -0.38888782262802124, -0.008198261260986328, 
0.18483886122703552, -0.07993965595960617, 0.07461892068386078, -0.028302567079663277, 0.33864718675613403, 0.17056851089000702, 0.33880072832107544, 0.4056403636932373, 0.15868745744228363, -0.1332210749387741, -0.2087058126926422, -0.3613317310810089, 0.47166791558265686, ...]",0.891341,0,"it is bad , but certainly not without merit as entertainment ."
1,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",2360,"[0, 627, 1374, 1683, 16, 21531, 8, 15955, 480, 8, 10, 7782, 8745, 7, 120, 15, 10, 792, 8, 2156, 37463, 2156, 27373, 2156, 22633, 479, 2]",pos,27,"[0.04401284456253052, -0.16428114473819733, -0.022668588906526566, 0.06744283437728882, -0.28163883090019226, -0.37925055623054504, 0.5876341462135315, 0.2901967167854309, -0.042840976268053055, 0.02269667014479637, 0.10843411833047867, -0.15031231939792633, -0.10917647182941437, -0.08137007802724838, 0.30406785011291504, 0.41915276646614075, 0.4364165663719177, -0.13273179531097412, -0.3431737720966339, 0.43271544575691223, -0.0889233872294426, 0.07603029161691666, -0.19508618116378784, 0.10890885442495346, -0.48713722825050354, -0.41958820819854736, -0.21349890530109406, -0.09140163660049438, -0.12284262478351593, 0.07895946502685547, -0.240946963429451, -0.10187531262636185, 0.4471614956855774, -0.17021271586418152, -0.19228853285312653, 0.27584782242774963, -0.38175585865974426, -0.033521562814712524, 0.10815444588661194, -0.25065308809280396, 0.021973855793476105, -0.17884160578250885, 0.04595806822180748, 0.29307445883750916, -0.20278792083263397, 0.10606871545314789, 0.2252230942249298, 0.06923412531614304, 0.2900986969470978, -0.06988903880119324, 0.29655948281288147, -0.4341394007205963, -0.07459859549999237, -0.44544002413749695, 0.03734232857823372, 0.3711625635623932, -0.04192529246211052, -0.16738931834697723, -0.2558805048465729, 0.2206665575504303, 0.013594625517725945, 0.34035345911979675, -0.13952438533306122, 0.24023175239562988, 0.34313473105430603, -0.2783070206642151, -0.4136805534362793, -0.054238516837358475, -0.13204915821552277, 0.21631714701652527, 0.5743158459663391, -0.09112890809774399, 0.00302168820053339, 0.12183459103107452, 0.3105069696903229, -0.22926218807697296, -0.2347143143415451, -0.21948403120040894, -0.3696134388446808, -0.10159849375486374, -0.2504853904247284, -0.17711061239242554, 
0.4149620532989502, 0.2324047088623047, -0.2801794409751892, -0.2146194726228714, 0.40810856223106384, -0.16749858856201172, -0.041506774723529816, 0.5186586380004883, -0.20605482161045074, 0.031907256692647934, -0.19117681682109833, -0.03963850811123848, 0.09361729770898819, -0.05620230734348297, -0.4075450897216797, -0.09582684934139252, -0.5087288022041321, 0.17758481204509735, ...]",0.817904,1,"the overall effect is awe and affection -- and a strange urge to get on a board and , uh , shred , dude ."
2,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",8160,"[0, 281, 10, 9284, 6066, 14481, 8, 1900, 9001, 2156, 1900, 30, 1530, 10698, 5, 4392, 350, 3615, 479, 2]",neg,20,"[0.2546972930431366, 0.31509658694267273, -0.27262750267982483, 0.07603449374437332, 0.5300198197364807, 0.05333120748400688, 0.02824675478041172, 0.21855942904949188, 0.44942358136177063, 0.03154435381293297, 0.3224776089191437, 0.1438864767551422, 0.2147073596715927, -0.06450001150369644, 0.19891870021820068, -0.3067059814929962, 0.2084534913301468, -0.24242067337036133, -0.3110296130180359, 0.189059779047966, -0.3393832743167877, -0.05364064499735832, 0.5092564225196838, -0.2554610073566437, 0.25382623076438904, -0.17573469877243042, 0.16240203380584717, 0.06380569189786911, -0.2885981500148773, -0.3036135137081146, 0.2608518600463867, 0.3256901502609253, 0.1099996492266655, -0.2423398494720459, 0.26518669724464417, 0.011865802109241486, -0.009212568402290344, 0.4233095645904541, -0.08658035844564438, -0.26237180829048157, -0.3861198127269745, 0.48275575041770935, 0.18322139978408813, 0.10145892947912216, -0.13132549822330475, -0.14981836080551147, -0.44625774025917053, 0.284708172082901, -0.28664588928222656, -0.18883949518203735, -0.14426255226135254, 0.11285009980201721, -0.15344053506851196, 0.19297651946544647, -0.10687289386987686, -0.17021144926548004, -0.45022425055503845, 0.1612005978822708, 0.2535587251186371, 0.1981431245803833, -0.25220850110054016, 0.11309175938367844, 0.025743544101715088, -0.26750752329826355, 0.3559701144695282, -0.126265749335289, 0.01946205459535122, -0.1899697631597519, 0.08389642089605331, 0.4294525384902954, 0.3849283754825592, -0.06310809403657913, 0.2618269622325897, 0.2347196787595749, -0.058131616562604904, 0.0105327432975173, -0.2610444128513336, 0.08814212679862976, 0.21297456324100494, -0.19235694408416748, -0.11726375669240952, -0.0457448810338974, 0.04035811498761177, -0.2645147144794464, -0.14847487211227417, -0.2706614136695862, 
0.1790958195924759, 0.1490056812763214, -0.04857505485415459, 0.10515721887350082, -0.07829747349023819, -0.2905997335910797, 0.27712059020996094, 0.12021821737289429, -0.11091139167547226, 0.16729740798473358, -0.12929250299930573, 0.20075982809066772, 0.1181543692946434, 0.2683846056461334, ...]",0.917005,0,"as a hybrid teen thriller and murder mystery , murder by numbers fits the profile too closely ."


### Accessing DataLoaders 

We can access dataloaders for the raw text in `ds.dsd_raw` with `ds.dld_raw`, and for the tokenised text in `ds.dsd_tkn` with `ds.dld_tkn`. Both of these are dictionaries of dataloaders with keys `['train', 'valid', 'test', 'train_eval']`. 

In [None]:
# Pull the first batch from the raw and the tokenized train dataloaders and inspect them.
print("Dataloader dict has keys:", ds.dld_raw.keys())
batch_raw = next(iter(ds.dld_raw['train']))
batch_tkn = next(iter(ds.dld_tkn['train']))
# Each batch is a dict; show which fields each collate function provides.
print(batch_raw.keys())
print(batch_tkn.keys())
print("Tokenised input is of shape:", batch_tkn['input_ids'].shape)

Dataloader dict has keys: dict_keys(['train', 'test', 'valid', 'train_eval'])
dict_keys(['text', 'idx'])
dict_keys(['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings'])
Tokenised input is of shape: torch.Size([16, 64])


#hide
## Export 

In [None]:
#hide
# Run nbdev's notebook-to-script export (the output below lists the converted notebooks).
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 02_tests.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 30_logging.ipynb.
Converted 35_charts.ipynb.
Converted index.ipynb.
Converted run.ipynb.
