In [None]:
#default_exp data

In [None]:
#hide
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
#export
import torch, random, pandas as pd
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from IPython.display import display, HTML
from travis_attack.models import get_vm_probs
from travis_attack.config import Config

In [None]:
#hide
import inspect
from IPython.core.debugger import set_trace

# Preparing data

## Classes 

### Base class

In [None]:
#export
class ProcessedDataset:
    """Wrap a raw dataset (e.g. from huggingface datasets) and preprocess it.

    Constructing an instance loads the raw splits named by `cfg.dataset_name`,
    runs the preprocessing pipeline, and exposes:

    * `dsd_raw`: DatasetDict of raw rows (plus an `idx` column),
    * `dsd_tkn`: DatasetDict with victim-model scores, STS embeddings and tokenised text,
    * `dld_raw` / `dld_tkn`: dicts of DataLoaders over the two DatasetDicts,
      with one entry per split plus an unshuffled `train_eval` entry.
    """

    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model):
        """Store config/models, load the raw dataset chosen by `cfg.dataset_name`, then preprocess it.

        Raises:
            ValueError: if `cfg.dataset_name` is not 'simple' or 'rotten_tomatoes'.
        """
        self._cfg = cfg
        self._vm_tokenizer, self._vm_model = vm_tokenizer, vm_model
        self._pp_tokenizer, self._sts_model = pp_tokenizer, sts_model
        if   self._cfg.dataset_name == "simple":          self._prep_dsd_raw_simple()
        elif self._cfg.dataset_name == "rotten_tomatoes": self._prep_dsd_raw_rotten_tomatoes()
        # ValueError is a subclass of Exception, so existing `except Exception` callers still work.
        else: raise ValueError("cfg.dataset_name must be either 'simple' or 'rotten_tomatoes'")
        self._preprocess_dataset()

    def _prep_dsd_raw_simple(self):
        """Load the simple dataset and package it up in a DatasetDict (dsd)
        with one entry per split listed in `cfg.splits`."""
        self.dsd_raw = DatasetDict()
        for s in self._cfg.splits:
            # load_dataset('csv', ...) puts a single file under the 'train' key, so unwrap it.
            self.dsd_raw[s] = load_dataset('csv', data_files=f"{self._cfg.path_data}simple_dataset_{s}.csv")['train']

    def _prep_dsd_raw_rotten_tomatoes(self):
        """Load the rotten tomatoes dataset and package it up in a DatasetDict (dsd)
        with splits for train, valid, test."""
        self.dsd_raw = load_dataset("rotten_tomatoes")
        self.dsd_raw['valid'] = self.dsd_raw.pop('validation')  # "valid" is easier than "validation"
        # make sure that all datasets have the same number of labels as what the victim model predicts
        for _, ds in self.dsd_raw.items():
            assert ds.features[self._cfg.label_cname].num_classes == self._cfg.vm_num_labels

    def _prep_dsd_raw_snli(self):
        """Placeholder for SNLI support; always raises NotImplementedError."""
        ## For snli (sketch): rows with label == -1 would need to be filtered out of every split
        # remove_minus1_labels = lambda x: x[label_cname] != -1
        # ds_train = ds_train.filter(remove_minus1_labels)
        # valid = valid.filter(remove_minus1_labels)
        # test = test.filter(remove_minus1_labels)
        raise NotImplementedError("SNLI not implemented yet.")

    def _preprocess_dataset(self):
        """Add columns, tokenize, transform, prepare dataloaders, and do other preprocessing tasks."""
        ##### TODO: why do you need train_eval? how is it different from train?
        self.dsd_raw = self.dsd_raw.map(self._add_idx, batched=True, with_indices=True)  # add idx column
        if self._cfg.use_small_ds: self._shard_dsd_raw()  # do after adding idx so it's consistent across runs
        # add VM score & filter out misclassified examples. (at this point lengths are diff between dsd_raw and dsd_tkn)
        self.dsd_tkn = self.dsd_raw.map(self._add_vm_orig_score, batched=True)
        if self._cfg.remove_misclassified_examples:
            # Compare against the configured label column (the same one _add_vm_orig_score reads).
            label_cname = self._cfg.label_cname
            self.dsd_tkn = self.dsd_tkn.filter(lambda x: x['orig_vm_predclass'] == x[label_cname])
        self.dsd_tkn = self.dsd_tkn.map(self._add_sts_orig_embeddings, batched=True)  # add STS embeddings
        self.dsd_tkn = self.dsd_tkn.map(self._tokenize_fn,             batched=True)  # tokenize
        self.dsd_tkn = self.dsd_tkn.map(self._add_n_tokens,            batched=True)  # add n_tokens
        # sort by n_tokens (high to low), useful for cuda memory caching
        if self._cfg.bucket_by_length: self.dsd_tkn = self.dsd_tkn.sort("n_tokens", reverse=True)

        # filter out rows in dsd_raw that aren't in dsd_tkn
        # NOTE(review): this keeps dsd_raw in its original row order, while dsd_tkn may have been
        # re-sorted above, so rows at the same position in the two are not guaranteed to match.
        for s in self._cfg.splits:
            kept_idx = set(self.dsd_tkn[s]['idx'])  # set gives O(1) membership per row
            self.dsd_raw[s] = self.dsd_raw[s].filter(lambda x: x['idx'] in kept_idx)
        # check ds has same number of elements in raw and tkn
        for s in self._cfg.splits: assert len(self.dsd_raw[s]) == len(self.dsd_tkn[s])

        # Prepare dataloaders
        self.dld_raw = self._get_dataloaders_dict(self.dsd_raw, collate_fn=self._collate_fn_raw)  # dict of data loaders that serve raw text
        self.dld_tkn = self._get_dataloaders_dict(self.dsd_tkn, collate_fn=self._collate_fn_tkn)  # dict of data loaders that serve tokenized text

    def _add_idx(self, batch, idx):
        """Add row numbers as an `idx` column (`idx` is supplied by `map(..., with_indices=True)`)."""
        batch['idx'] = idx
        return batch

    def _add_n_tokens(self, batch):
        """Add `n_tokens`: the length of each row's `input_ids`."""
        batch['n_tokens'] = [len(o) for o in batch['input_ids']]
        return batch

    def _add_sts_orig_embeddings(self, batch):
        """Add the sts embeddings of the orig text as `orig_sts_embeddings`."""
        batch['orig_sts_embeddings'] = self._sts_model.encode(batch[self._cfg.orig_cname], batch_size=64, convert_to_tensor=False)
        return batch

    def _add_vm_orig_score(self, batch):
        """Add the victim-model score of the orig text.

        Adds `orig_truelabel_probs` (probability assigned to the true label) and
        `orig_vm_predclass` (predicted class), each as a list with one entry per row.
        """
        labels = torch.tensor(batch[self._cfg.label_cname], device=self._cfg.device)
        orig_probs, orig_predclass = get_vm_probs(batch[self._cfg.orig_cname], self._cfg, self._vm_tokenizer,
                                                  self._vm_model, return_predclass=True)
        # gather keeps the class dimension as size 1; squeeze only that dimension so a batch
        # containing a single row still produces a one-element list rather than a bare float.
        batch['orig_truelabel_probs'] = torch.gather(orig_probs, 1, labels[:, None]).squeeze(1).cpu().tolist()
        batch['orig_vm_predclass'] = orig_predclass.cpu().tolist()
        return batch

    def _tokenize_fn(self, batch):
        """Tokenize a batch of orig text using the paraphrase tokenizer."""
        return self._pp_tokenizer(batch[self._cfg.orig_cname], truncation=True, max_length=self._cfg.orig_max_length)

    def _collate_fn_tkn(self, x):
        """Collate function used by the DataLoader that serves tokenized data.

        Gathers the selected columns from the list of rows `x` and lets the paraphrase
        tokenizer pad them and return tensors.
        """
        d = dict()
        for k in ['idx', 'attention_mask', 'input_ids', 'orig_truelabel_probs', 'orig_sts_embeddings']:
            d[k] = [o[k] for o in x]
        # Labels are read from the configured label column but always served under the 'label' key.
        d['label'] = [o[self._cfg.label_cname] for o in x]
        return self._pp_tokenizer.pad(d, pad_to_multiple_of=self._cfg.orig_padding_multiple, return_tensors="pt")

    def _collate_fn_raw(self, x):
        """Collate function used by the DataLoader that serves raw data.

        Returns a dict of lists holding the orig text column and `idx`.
        """
        d = dict()
        for k in [self._cfg.orig_cname, 'idx']: d[k] = [o[k] for o in x]
        return d

    def _get_sampler(self, ds):
        """Returns a seeded RandomSampler. Used so we can keep the same shuffle order across multiple data loaders.
        Used when self._cfg.shuffle_train = True"""
        # The previous version referenced an undefined global `seed`.
        # Assumes Config exposes a `seed` attribute (TODO confirm); falls back to 0 if it does not.
        seed = getattr(self._cfg, "seed", 0)
        g = torch.Generator()
        g.manual_seed(seed)
        return RandomSampler(ds, generator=g)

    def _shard_dsd_raw(self):
        """Replaces each split of dsd_raw with its first shard (shard index 0 of `cfg.n_shards`)."""
        for k, v in self.dsd_raw.items():
            self.dsd_raw[k] = v.shard(self._cfg.n_shards, 0, contiguous=self._cfg.shard_contiguous)

    def _get_dataloaders_dict(self, dsd, collate_fn):
        """Prepare a dict of dataloaders: one per split in `dsd`, plus `train_eval`.

        * 'train' uses `cfg.batch_size_train`; it is shuffled with a seeded sampler only
          when `cfg.shuffle_train` is set.
        * every other split uses `cfg.batch_size_eval` and is never shuffled.
        * 'train_eval' serves the train split unshuffled with the eval batch size.

        Raises:
            Exception: if both `cfg.bucket_by_length` and `cfg.shuffle_train` are set.
        """
        if self._cfg.bucket_by_length and self._cfg.shuffle_train:  raise Exception("Can only do one of bucket by length or shuffle")
        loader_kwargs = dict(collate_fn=collate_fn, num_workers=self._cfg.n_wkrs, pin_memory=self._cfg.pin_memory)
        d = dict()
        for split, ds in dsd.items():
            is_train = split == "train"
            batch_size = self._cfg.batch_size_train if is_train else self._cfg.batch_size_eval
            # sampler=None with shuffle=False means sequential order. This also covers the case
            # where neither shuffle_train nor bucket_by_length is set, which previously produced
            # no per-split loaders at all.
            sampler = self._get_sampler(ds) if (self._cfg.shuffle_train and is_train) else None
            d[split] = DataLoader(ds, batch_size=batch_size, sampler=sampler, shuffle=False, **loader_kwargs)

        # Add eval dataloader for train
        d['train_eval'] = DataLoader(dsd['train'], batch_size=self._cfg.batch_size_eval, shuffle=False, **loader_kwargs)
        return d

    def show_random_elements(self, ds, num_examples=10):
        """Display `num_examples` distinct random rows of `ds` as an HTML table.

        `ds` is a single dataset from the `datasets` package, e.g. `self.dsd_raw['train']`.
        ClassLabel columns are shown with their label names instead of integer ids.
        """
        assert num_examples <= len(ds), "Can't pick more elements than there are in the dataset."
        picks = random.sample(range(len(ds)), num_examples)  # distinct row positions
        df = pd.DataFrame(ds[picks])
        for column, typ in ds.features.items():
            if isinstance(typ, ClassLabel):
                names = typ.names
                df[column] = df[column].transform(lambda i: names[i])
        display(HTML(df.to_html()))

## Usage 

### Basics 

Here we have defined a class `ProcessedDataset` that will load and preprocess a dataset. But before processing the dataset you must load both the config object and all models/tokenizers, so we do this first. 

In [None]:
from travis_attack.models import prepare_models
# Start from the default config; later cells reuse `cfg` and the objects returned below.
cfg = Config()
# prepare_models returns the tokenizers/models and also hands back a cfg, which replaces the one above.
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)

Currently there are two choices for dataset: 

* `simple`, a dataset of simple sentences with four elements each in the train, test and valid splits
* `rotten_tomatoes`, a dataset of movie reviews scraped from the Rotten Tomatoes site. 

The dataset is specified by the config class. There are two ways to do this. 

1. Edit the `self.dataset_name` variable in the Config class to either `simple` or `rotten_tomatoes`. An error will be thrown if the name is not one of these two. This is the best approach when doing runs.   
2. Use the `adjust_config_for_...` methods of the config class, e.g. `cfg = Config().adjust_config_for_rotten_tomatoes_dataset()`. This is easiest for automated testing, so we will do this here. 

Once the config is specified and loaded, create an object of class `ProcessedDataset` by passing the config, models and tokenizers as arguments. All preprocessing happens automatically when the object is created (it is triggered from the `__init__()` method of the class).  

In [None]:
# Point the config at the `simple` dataset, then build the dataset object
# (all preprocessing runs inside the constructor).
cfg_simple = cfg.adjust_config_for_simple_dataset()
ds = ProcessedDataset(cfg_simple,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

If you want to use the small dataset adjust the config before creating the `ProcessedDataset` object. 

In [None]:
# Switch the config to rotten_tomatoes and request the small dataset before constructing
# the object; this rebinds `ds`, which the cells below inspect.
cfg_rt_small_ds = cfg.adjust_config_for_rotten_tomatoes_dataset().small_ds()
ds = ProcessedDataset(cfg_rt_small_ds,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default
Reusing dataset rotten_tomatoes_movie_review (/data/tproth/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/e06abb624abab47e1a64608fdfe65a913f5a68c66118408032644a3285208fb5)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

### Accessing datasets

You can access raw data with `ds.dsd_raw` and processed data with `ds.dsd_tkn`. (The "dsd" here stands for "DatasetDict".)

In [None]:
ds.dsd_raw

DatasetDict({
    train: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 196
    })
    test: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 23
    })
    valid: Dataset({
        features: ['idx', 'label', 'text'],
        num_rows: 23
    })
})

In [None]:
ds.dsd_tkn

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 196
    })
    test: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 23
    })
    valid: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 23
    })
})

You can access elements by indexing: 

In [None]:
ds.dsd_raw['valid'][0:2]

{'idx': [0, 40],
 'label': [1, 1],
 'text': ['compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
  'a perceptive , good-natured movie .']}

In [None]:
ds.dsd_tkn['valid'][0:2]

{'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]],
 'idx': [320, 520],
 'input_ids': [[0,
   405,
   18,
   479,
   479,
   479,
   966,
   5,
   1823,
   1351,
   7,
   192,
   41,
   3025,
   2156,
   202,
   2021,
   7,
   434,
   11,
   39,
   5127,
   2202,
   2156,
   464,
   150,
   2405,
   1528,
   7,
   39,
   7797,
   19,
   10,
   822,
   1060,
   182,
   2087,
   16,
   2156,
   1341,
   3273,
   352,
   2156,
   59,
  

Alternately you can look at some random elements of a dataset with the `ds.show_random_elements()` method. 

In [None]:
ds.show_random_elements(ds.dsd_raw['train'], num_examples=3)

Unnamed: 0,idx,label,text
0,2240,pos,"a smart , sassy and exceptionally charming romantic comedy ."
1,6960,neg,"[nelson's] movie about morally compromised figures leaves viewers feeling compromised , unable to find their way out of the fog and the ashes ."
2,1680,pos,"if you come from a family that eats , meddles , argues , laughs , kibbitzes and fights together , then go see this delightful comedy ."


In [None]:
ds.show_random_elements(ds.dsd_tkn['train'], num_examples=3)

Unnamed: 0,attention_mask,idx,input_ids,label,n_tokens,orig_sts_embeddings,orig_truelabel_probs,orig_vm_predclass,text
0,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",6200,"[0, 627, 129, 8354, 606, 77, 5, 7751, 1747, 3825, 8, 47, 120, 7, 989, 5, 7364, 479, 2]",neg,19,"[0.2718726396560669, -0.10213348269462585, -0.1086643859744072, -0.2356581836938858, 0.27173563838005066, 0.17503871023654938, 0.12863856554031372, 0.19121694564819336, 0.4207119941711426, 0.007817253470420837, -0.02263920195400715, 0.10309284180402756, -0.23597969114780426, -0.11926668137311935, -0.08704448491334915, -0.05439629405736923, 0.10555088520050049, 0.03964345157146454, -0.44482049345970154, -0.22134195268154144, -0.10061632841825485, -0.20155005156993866, -0.08711832016706467, 0.34374916553497314, 0.13874274492263794, -0.09023632854223251, -0.2435479611158371, -0.21252116560935974, 0.17128978669643402, -0.11253011971712112, -0.38164660334587097, -0.11789140105247498, 0.13599206507205963, 0.1033342108130455, 0.043450262397527695, 0.25180506706237793, -0.3405490219593048, -0.2393033653497696, -0.07114789634943008, -0.12654420733451843, -0.20744846761226654, 0.04983171448111534, -0.2545621991157532, 0.25938358902931213, -0.3537670075893402, -0.20977991819381714, 0.15646576881408691, 0.0785306841135025, 0.278666615486145, 0.005268873181194067, -0.11387224495410919, 0.15580874681472778, -0.22373265027999878, 0.07511617243289948, -0.2665550708770752, -0.03529684618115425, 0.24930939078330994, -0.11912687122821808, -0.34383586049079895, -0.043090708553791046, -0.23505821824073792, 0.07891485095024109, -0.06368077546358109, 0.054752446711063385, 0.04944201931357384, -0.2987542748451233, -0.16236701607704163, 0.11134135723114014, -0.2388170212507248, 0.07395632565021515, 0.12151885032653809, -0.041891805827617645, -0.10519231110811234, -0.1890822798013687, 0.13673673570156097, 0.14068511128425598, 0.1253659874200821, -0.22653533518314362, -0.11867418140172958, 0.32348430156707764, 0.18691153824329376, -0.6645033359527588, -0.006045620422810316, -0.29660555720329285, -0.029627876356244087, 
0.20086131989955902, -0.187443345785141, 0.631697416305542, -0.12789827585220337, -0.12147552520036697, -0.007419958710670471, 0.11761238425970078, -0.12032697349786758, 0.3506828248500824, -0.09424198418855667, -0.08409081399440765, 0.0910419449210167, -0.11686234921216965, 0.0737936943769455, 0.0014112889766693115, ...]",0.770797,0,the only excitement comes when the credits finally roll and you get to leave the theater .
1,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",6080,"[0, 3866, 994, 4468, 23, 39, 275, 817, 5188, 3121, 3541, 14, 32, 6681, 9869, 53, 67, 29836, 8513, 8, 23134, 12038, 25606, 228, 417, 7469, 2156, 6637, 2156, 7005, 55, 101, 1428, 2649, 2955, 26351, 87, 205, 17195, 281, 479, 2]",neg,42,"[-0.12965631484985352, -0.10750909149646759, -0.392819344997406, 0.23402263224124908, 0.010695393197238445, -0.07399005442857742, 0.15130046010017395, -0.030590228736400604, -0.18184790015220642, -0.12506601214408875, 0.0993395745754242, 0.09256643801927567, 0.33045467734336853, 0.17623376846313477, -0.009037100709974766, -0.14154788851737976, 0.7564421892166138, -0.05878685042262077, -0.03853224962949753, 0.3701777160167694, -0.3265226185321808, -0.011549143120646477, 0.2488141804933548, -0.08391375094652176, 0.009564959444105625, -0.212877094745636, 0.048634715378284454, -0.07062599062919617, -0.3811725378036499, -0.12673339247703552, -0.01026322040706873, -0.24414096772670746, -0.1232338696718216, 0.09485022723674774, -0.19253979623317719, -0.054395392537117004, 0.21332338452339172, 0.40326613187789917, 0.2155093401670456, -0.5088616609573364, -0.3400716781616211, 0.4591353237628937, 0.08890822529792786, 0.1843874752521515, -0.15239113569259644, -0.2833177447319031, -0.11168127506971359, 0.25325915217399597, -0.0586787685751915, -0.21112321317195892, -0.2810358703136444, 0.11271234601736069, 0.17352280020713806, -0.21650603413581848, -0.23015564680099487, -0.012010931968688965, 0.2400810718536377, 0.34901440143585205, 0.024952787905931473, -0.11385580897331238, -0.12498290091753006, 0.01934744231402874, 0.07924199104309082, -0.2151636779308319, 0.15958388149738312, -0.3776818513870239, 0.1565454751253128, 0.016666946932673454, -0.31611794233322144, 0.433468759059906, 0.12981261312961578, 0.09754228591918945, 0.2931279242038727, 0.013555794022977352, 0.056666988879442215, 0.01724403351545334, 
0.15035876631736755, -0.11490357667207718, -0.24567006528377533, 0.02843506447970867, 0.027632026001811028, -0.028614016249775887, -0.20672260224819183, 0.059745147824287415, 0.03713641315698624, -0.010408525355160236, 0.1542750895023346, -0.050928905606269836, 0.23415891826152802, 0.058182504028081894, -0.02012523263692856, 0.46659547090530396, 0.15216363966464996, -0.014698346145451069, 0.34962522983551025, 0.1668269783258438, -0.03960828483104706, 0.14027927815914154, -0.1932191550731659, -0.14552442729473114, ...]",0.515333,0,"scorsese at his best makes gangster films that are equally lovely but also relentlessly brutal and brutally intelligent ; perdition , meanwhile , reads more like driving miss daisy than goodfellas ."
2,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",7280,"[0, 102, 10439, 2156, 11923, 2156, 26081, 159, 479, 2]",neg,10,"[0.17786522209644318, 0.3085567057132721, 0.3459327816963196, 0.17433539032936096, 0.034761976450681686, -0.23686936497688293, 0.30796191096305847, 0.5154323577880859, 0.02422725036740303, 0.04894415661692619, 0.01657852903008461, -0.17129473388195038, 0.04543374851346016, 0.46662384271621704, 0.030553702265024185, 0.22557954490184784, -0.03447119519114494, -0.22122028470039368, 0.17454420030117035, 0.2902672290802002, -0.15967757999897003, 0.13103511929512024, -0.03508492186665535, -0.024228790774941444, -0.14071814715862274, -0.22014465928077698, 0.26420536637306213, -0.12379838526248932, -0.19637584686279297, -0.11714394390583038, -0.12160471081733704, 0.08832138776779175, 0.14266009628772736, 0.16409488022327423, -0.2231815904378891, 0.4679659605026245, -0.21566472947597504, 0.27192002534866333, -0.16363315284252167, 0.014075261540710926, 0.06480659544467926, -0.3283059597015381, 0.07540423423051834, 0.006178535055369139, 0.1387406587600708, -0.07707338780164719, 0.04011848568916321, 0.19707314670085907, -0.3595183789730072, 0.16951383650302887, 0.08032499253749847, 0.3893658220767975, -0.09819304198026657, -0.037792954593896866, 0.26548582315444946, 0.3073823153972626, 0.15000665187835693, -0.03882044926285744, -0.2212156057357788, 0.2510394752025604, -0.09592173993587494, -0.4037157893180847, -0.45425763726234436, 0.2856711745262146, 0.2172514945268631, -0.016501085832715034, 0.08840632438659668, -0.8317118883132935, -0.029085353016853333, 0.49915897846221924, 0.7791582942008972, -0.2380441427230835, -0.016743140295147896, 0.21256513893604279, 0.28053489327430725, -0.22307121753692627, -0.062333352863788605, 0.22435373067855835, 0.11218150705099106, -0.08326300978660583, -0.17685595154762268, 0.3336477279663086, -0.17521119117736816, -0.04317280650138855, -0.9503389000892639, 0.01869210973381996, -0.01793409325182438, 0.24861964583396912, 
-0.13385680317878723, -0.25910472869873047, -0.023599697276949883, -0.26051732897758484, 0.30779194831848145, -0.004107794724404812, -0.017796611413359642, 0.10789701342582703, 0.22939962148666382, -0.21993249654769897, -0.6796860694885254, 0.4220023453235626, ...]",0.699222,0,"a mild , reluctant , thumbs down ."


### Accessing DataLoaders 

We can access dataloaders for the raw text in `ds.dsd_raw` with `ds.dld_raw`, and for the tokenised text in `ds.dsd_tkn` with `ds.dld_tkn`. Both of these are dictionaries of dataloaders with keys `['train', 'valid', 'test', 'train_eval']`. 

In [None]:
# Both dataloader dicts share the same keys (one per split, plus 'train_eval').
print("Dataloader dict has keys:", ds.dld_raw.keys())
# Pull the first training batch from each loader to inspect what it serves.
batch_raw = next(iter(ds.dld_raw['train']))
batch_tkn = next(iter(ds.dld_tkn['train']))
print(batch_raw.keys())
print(batch_tkn.keys())
# input_ids is a padded tensor produced by the tokenised collate function.
print("Tokenised input is of shape:", batch_tkn['input_ids'].shape)

Dataloader dict has keys: dict_keys(['train', 'test', 'valid', 'train_eval'])
dict_keys(['text', 'idx'])
dict_keys(['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings'])
Tokenised input is of shape: torch.Size([32, 64])


#hide
## Export 

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 02_tests.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 30_logging.ipynb.
Converted 35_charts.ipynb.
Converted index.ipynb.
Converted run.ipynb.
