In [None]:
#default_exp data

In [None]:
#hide
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [None]:
#export
import torch, random, pandas as pd
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from IPython.display import display, HTML
from travis_attack.trainer import get_vm_probs
from travis_attack.config import Config

In [None]:
#hide
import inspect
from IPython.core.debugger import set_trace

# Preparing data

## Classes 

### Base class

In [None]:
#export
class ProcessedDataset: 
    """Class that wraps a raw dataset (e.g. from huggingface datasets) and performs preprocessing on it.

    Constructing an instance loads the dataset named by `cfg.dataset_name` and immediately runs 
    the whole preprocessing pipeline. Afterwards these attributes are available: 

    * `dsd_raw`: DatasetDict with the raw splits
    * `dsd_tkn`: DatasetDict with the preprocessed and tokenized splits
    * `dld_raw`: dict of DataLoaders (built from `dsd_tkn`) that serve raw text + idx
    * `dld_tkn`: dict of DataLoaders (built from `dsd_tkn`) that serve padded, tokenized batches
    """
    def __init__(self, cfg,  vm_tokenizer, pp_tokenizer, vm_model, sts_model): 
        """Store the config, tokenizers and models, load the raw dataset, then preprocess it. 
        Raises ValueError if `cfg.dataset_name` is not a supported dataset."""
        self._cfg,self._vm_tokenizer,self._pp_tokenizer,self._vm_model,self._sts_model = cfg,vm_tokenizer,pp_tokenizer,vm_model,sts_model
        if   self._cfg.dataset_name == "simple":          self._prep_dsd_raw_simple()
        elif self._cfg.dataset_name == "rotten_tomatoes": self._prep_dsd_raw_rotten_tomatoes()
        # ValueError is a subclass of Exception, so existing `except Exception` callers still work.
        else: raise ValueError("cfg.dataset_name must be either 'simple' or 'rotten_tomatoes'")
        self._preprocess_dataset() 
            
    def _prep_dsd_raw_simple(self): 
        """Load the simple dataset and package it up in a DatasetDict (dsd) 
        with one entry per split in `cfg.splits`. Each split is read from 
        `{cfg.path_data}simple_dataset_{split}.csv`."""
        self.dsd_raw = DatasetDict()
        for s in self._cfg.splits:  
            # The csv loader puts everything under a 'train' key, so unwrap it.
            self.dsd_raw[s] = load_dataset('csv', data_files=f"{self._cfg.path_data}simple_dataset_{s}.csv")['train']    
        
    def _prep_dsd_raw_rotten_tomatoes(self):
        """Load the rotten tomatoes dataset and package it up in a DatasetDict (dsd) 
        with splits for train, valid, test."""
        self.dsd_raw = load_dataset("rotten_tomatoes")
        self.dsd_raw['valid'] = self.dsd_raw.pop('validation')  # "valid" is easier than "validation" 
        # make sure that all datasets have the same number of labels as what the victim model predicts
        for split,ds in self.dsd_raw.items(): 
            assert ds.features[self._cfg.label_cname].num_classes == self._cfg.vm_num_labels, \
                f"Split '{split}' does not have cfg.vm_num_labels classes."

    def _prep_dsd_raw_snli(self): 
        """Placeholder for SNLI support. Always raises NotImplementedError."""
        ## For snli
        # remove_minus1_labels = lambda x: x[label_cname] != -1
        # ds_train = ds_train.filter(remove_minus1_labels)
        # valid = valid.filter(remove_minus1_labels)
        # test = test.filter(remove_minus1_labels)
        raise NotImplementedError("SNLI not implemented yet.")
    
    def _preprocess_dataset(self): 
        """Add columns, tokenize, transform, prepare dataloaders, and do other preprocessing tasks.
        Sets `self.dld_raw`, `self.dld_tkn` and `self.dsd_tkn`."""
        ##### TODO: why do you need train_eval? how is it different from train?
        label_cname = self._cfg.label_cname
        dsd = self.dsd_raw.map(self._add_idx, batched=True, with_indices=True)  # add idx column
        if self._cfg.use_small_ds: dsd = self._prep_small_ds(dsd)  # do after adding idx so it's consistent across runs
        dsd = dsd.map(self._add_vm_orig_score, batched=True)  # add VM score
        # Compare against the configured label column rather than a hard-coded 'label'.
        if self._cfg.remove_misclassified_examples: dsd = dsd.filter(lambda x: x['orig_vm_predclass'] == x[label_cname]) 
        dsd = dsd.map(self._add_sts_orig_embeddings, batched=True)  # add STS embeddings 
        dsd = dsd.map(self._tokenize_fn,             batched=True)  # tokenize
        dsd = dsd.map(self._add_n_tokens,            batched=True)  # add n_tokens
        if self._cfg.bucket_by_length: dsd = dsd.sort("n_tokens", reverse=True)  # sort by n_tokens (high to low), useful for cuda memory caching
        self.dld_raw = self._get_dataloaders_dict(dsd, collate_fn=self._collate_fn_raw)  # dict of data loaders that serve raw text
        self.dld_tkn = self._get_dataloaders_dict(dsd, collate_fn=self._collate_fn_tkn)  # dict of data loaders that serve tokenized text
        self.dsd_tkn = dsd 
        
    def _add_idx(self, batch, idx):
        """Add row numbers in an `idx` column (`idx` is supplied by `map(..., with_indices=True)`)."""
        batch['idx'] = idx 
        return batch   
    
    def _add_n_tokens(self, batch): 
        """Add the number of tokens of the orig text (length of each `input_ids` entry)."""
        batch['n_tokens'] = [len(o) for o in batch['input_ids']]
        return batch 
    
    def _add_sts_orig_embeddings(self, batch): 
        """Add the sts embeddings of the orig text"""
        batch['orig_sts_embeddings'] = self._sts_model.encode(batch[self._cfg.orig_cname], batch_size=64, convert_to_tensor=False)
        return batch
    
    def _add_vm_orig_score(self, batch): 
        """Add the victim model's score of the orig text. 
        
        Adds `orig_truelabel_probs` (the vm probability of the true label) and 
        `orig_vm_predclass` (the class predicted by the vm) for every row in the batch."""
        # Use the instance config here: a module-level `cfg` only exists inside the notebook.
        labels = torch.tensor(batch[self._cfg.label_cname], device=self._cfg.device)
        orig_probs,orig_predclass = get_vm_probs(batch[self._cfg.orig_cname], self._cfg, self._vm_tokenizer,
                                                 self._vm_model, return_predclass=True)
        # gather gives shape (batch, 1). Squeeze only dim 1 so a one-row batch stays a list 
        # (a bare squeeze() would turn it into a scalar and tolist() would return a float).
        batch['orig_truelabel_probs'] = torch.gather(orig_probs,1, labels[:,None]).squeeze(1).cpu().tolist()
        batch['orig_vm_predclass'] = orig_predclass.cpu().tolist()
        return batch
    
    def _tokenize_fn(self, batch):  
        """Tokenize a batch of orig text using the paraphrase tokenizer, 
        truncating to `cfg.orig_max_length` tokens."""
        return self._pp_tokenizer(batch[self._cfg.orig_cname], truncation=True, max_length=self._cfg.orig_max_length)  
    
    def _collate_fn_tkn(self, x): 
        """Collate function used by the DataLoader that serves tokenized data.
        Gathers the needed columns from the list of rows `x` and pads them with the paraphrase tokenizer."""
        d = dict()
        for k in ['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings']: 
            d[k] = [o[k] for o in x]
        return self._pp_tokenizer.pad(d, pad_to_multiple_of=self._cfg.orig_padding_multiple, return_tensors="pt")

    def _collate_fn_raw(self, x): 
        """Collate function used by the DataLoader that serves raw data.
        Returns a dict with the orig text column and `idx`, each a list with one entry per row in `x`."""
        d = dict()
        # Use the instance config here: a module-level `cfg` only exists inside the notebook.
        for k in [self._cfg.orig_cname, 'idx']: d[k] = [o[k] for o in x]
        return d 

    def _get_sampler(self, ds): 
        """Returns a RandomSampler. Used so we can keep the same shuffle order across multiple data loaders.
        Used when self._cfg.shuffle_train = True"""
        g = torch.Generator()
        # NOTE(review): the original referenced an undefined name `seed`; this assumes the 
        # config exposes the run seed as `cfg.seed` - confirm against Config.
        g.manual_seed(self._cfg.seed)
        return RandomSampler(ds, generator=g)
    
    def _prep_small_ds(self, dsd):
        """Replaces each split of a datasetdict with a smaller shard of itself 
        (shard 0 of `cfg.n_shards`). Modifies `dsd` in place and returns it."""
        for k,v in dsd.items():  
            dsd[k] = v.shard(self._cfg.n_shards, 0, contiguous=self._cfg.shard_contiguous)
        return dsd
        
    def _get_dataloaders_dict(self, dsd, collate_fn): 
        """Prepare a dict of dataloaders, one per split in `dsd` plus a `train_eval` loader.
        
        * `train` uses `cfg.batch_size_train`; it is shuffled with a seeded sampler if 
          `cfg.shuffle_train` is set, otherwise it keeps the dataset order. 
        * every other split uses `cfg.batch_size_eval` and keeps the dataset order. 
        * `train_eval` serves the train split in dataset order with `cfg.batch_size_eval`. 
        
        Raises an Exception if both `cfg.bucket_by_length` and `cfg.shuffle_train` are set."""
        if self._cfg.bucket_by_length and self._cfg.shuffle_train:  raise Exception("Can only do one of bucket by length or shuffle")
        
        def _make_dl(ds, batch_size, sampler=None): 
            # All loaders share everything except the batch size and (for train) the sampler.
            return DataLoader(ds, batch_size=batch_size, shuffle=False, sampler=sampler, collate_fn=collate_fn, 
                              num_workers=self._cfg.n_wkrs, pin_memory=self._cfg.pin_memory) 
        
        d = dict()
        # Loaders are created for every split regardless of the flags. Previously nothing was
        # created for train/valid/test when both shuffle_train and bucket_by_length were off.
        for split, ds in dsd.items(): 
            if split == "train": 
                sampler = self._get_sampler(ds) if self._cfg.shuffle_train else None
                d[split] = _make_dl(ds, self._cfg.batch_size_train, sampler)
            else: 
                d[split] = _make_dl(ds, self._cfg.batch_size_eval)

        # Add eval dataloader for train 
        d['train_eval'] = _make_dl(dsd['train'], self._cfg.batch_size_eval)
        return d 
    
    def show_random_elements(self, ds, num_examples=10):
        """Print some randomly picked elements in a nice format so you can take a look at them. 
        Use for a dataset `ds` from the `datasets` package, e.g. `self.dsd_raw['train']`. 
        ClassLabel columns are shown with their label names instead of integers."""
        assert num_examples <= len(ds), "Can't pick more elements than there are in the dataset."
        picks = random.sample(range(len(ds)), num_examples)  # distinct row numbers
        df = pd.DataFrame(ds[picks])
        for column, typ in ds.features.items():
            if isinstance(typ, ClassLabel):
                df[column] = df[column].transform(lambda i, names=typ.names: names[i])
        display(HTML(df.to_html()))

## Usage 

### Basics 

Here we have defined a class `ProcessedDataset` that will load and preprocess a dataset. But before processing the dataset you must load both the config object and all models/tokenizers, so we do this first. 

In [None]:
# Imported here rather than in the #export cell, so it is only used by this notebook's examples.
from travis_attack.models import prepare_models
# Start from a default config, then load all models and tokenizers.
cfg = Config()
# prepare_models also returns a config object, which rebinds `cfg` for the cells below.
vm_tokenizer, vm_model, pp_tokenizer, pp_model, sts_model, cfg = prepare_models(cfg)

Currently there are two choices for dataset: 

* `simple`, a dataset of simple sentences with four elements each in the train, test and valid splits
* `rotten_tomatoes`, a dataset of movie reviews scraped from the Rotten Tomatoes site. 

The dataset is specified by the config class. There are two ways to do this. 

1. Edit the `self.dataset_name` variable in the Config class to either `simple` or `rotten_tomatoes`. An error will be thrown if the name is not one of these two. This is the best approach when doing runs.   
2. Use the `adjust_config_for_...` methods of the config class, e.g. `cfg = cfg.adjust_config_for_rotten_tomatoes_dataset()`. This is easiest for automated testing so we will do this here. 

Once the config is specified and loaded, create an object of class `ProcessedDataset` by passing the config, models and tokenizers as arguments. This will do all preprocessing automatically when the object is created (the preprocessing is triggered from the `__init__()` method of the class).  

In [None]:
# Point the config at the "simple" dataset, then build the dataset object.
cfg_simple = cfg.adjust_config_for_simple_dataset()
# All loading and preprocessing happens inside the constructor.
ds = ProcessedDataset(cfg_simple,  vm_tokenizer, pp_tokenizer, vm_model, sts_model)

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

If you want to use a different dataset (here, `rotten_tomatoes`) or the small-dataset option (`cfg.use_small_ds`), adjust the config before creating the `ProcessedDataset` object. 

In [None]:
# Point the config at the rotten_tomatoes dataset, then build the dataset object.
cfg_rt_small_ds = cfg.adjust_config_for_rotten_tomatoes_dataset()
# This rebinds `ds`; every cell below works on the rotten_tomatoes data.
ds = ProcessedDataset(cfg_rt_small_ds,  vm_tokenizer, pp_tokenizer, vm_model, sts_model)

Using custom data configuration default
Reusing dataset rotten_tomatoes_movie_review (/data/tproth/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/e06abb624abab47e1a64608fdfe65a913f5a68c66118408032644a3285208fb5)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

### Accessing datasets

You can access raw data with `ds.dsd_raw` and processed data with `ds.dsd_tkn`. (The `dsd` prefix stands for "DatasetDict".)

In [None]:
ds.dsd_raw

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
    valid: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
})

In [None]:
ds.dsd_tkn

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 7443
    })
    test: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 889
    })
    valid: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 895
    })
})

You can access elements by indexing: 

In [None]:
ds.dsd_raw['valid'][0:2]

{'text': ['compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
  'the soundtrack alone is worth the price of admission .'],
 'label': [1, 1]}

In [None]:
ds.dsd_tkn['valid'][0:2]

{'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]],
 'idx': [1028, 407],
 'input_ids': [[0,
   627,
   1569,
   18,
   934,
   23485,
   283,
   31,
   1782,
   320,
   295,
   219,
   6195,
   4626,
   859,
   1236,
   922,
   5810,
   859,
   2084,
   605,
   354,
   816,
   10,
   6740,
   12,
 

Alternately you can look at some random elements of a dataset with the `ds.show_random_elements()` method. 

In [None]:
ds.show_random_elements(ds.dsd_raw['train'], num_examples=3)

Unnamed: 0,text,label
0,"a great ending doesn't make up for a weak movie , and crazy as hell doesn't even have a great ending .",neg
1,"while the isle is both preposterous and thoroughly misogynistic , its vistas are incredibly beautiful to look at .",pos
2,not exactly the bees knees,neg


In [None]:
ds.show_random_elements(ds.dsd_tkn['train'], num_examples=3)

Unnamed: 0,attention_mask,idx,input_ids,label,n_tokens,orig_sts_embeddings,orig_truelabel_probs,orig_vm_predclass,text
0,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",2952,"[0, 3341, 63, 11676, 34510, 2156, 24, 39197, 1626, 84, 23810, 479, 2]",pos,13,"[0.07683645188808441, -0.43254953622817993, -0.2884201407432556, 0.26368406414985657, -0.15067246556282043, -0.29725706577301025, 0.44652673602104187, -0.09599904716014862, 0.4136199951171875, 0.22868911921977997, 0.16475704312324524, 0.2580578923225403, -0.009651213884353638, -0.15675018727779388, 0.028356363996863365, 0.026996750384569168, 0.3257802426815033, -0.09476791322231293, -0.1322847455739975, 0.40260812640190125, -0.06383713334798813, 0.23907266557216644, 0.16382071375846863, 0.22284190356731415, -0.07003384083509445, -0.2778371274471283, 0.06157108023762703, -0.14136412739753723, -0.14780288934707642, -0.2185499370098114, -0.19067807495594025, -0.05854059010744095, -0.1803152859210968, 0.24854472279548645, -0.1485261619091034, 0.38437846302986145, -0.09991002082824707, 0.07585038244724274, 0.11837996542453766, -0.4445829391479492, -0.026835225522518158, 0.19295825064182281, 0.009423277340829372, 0.24481302499771118, -0.06806370615959167, -0.03962473198771477, -0.05362413823604584, 0.2703951299190521, -0.019495781511068344, -0.2742648720741272, -0.1754714846611023, -0.047957465052604675, -0.09943359345197678, -0.18178799748420715, 0.000764952739700675, -0.1270955502986908, 0.30256763100624084, 0.017745088785886765, 0.1505133956670761, -0.09769076108932495, 0.1602364182472229, -0.016127118840813637, 0.10854503512382507, 0.11427223682403564, 0.08034198731184006, -0.4638966917991638, -0.051793172955513, -0.07844981551170349, -0.012365324422717094, -0.21001778542995453, 0.2400374412536621, 0.03593837469816208, 0.4045008718967438, 0.09282571822404861, 0.11244865506887436, 0.16519945859909058, 0.16836334764957428, -0.2736123502254486, 0.03752522170543671, 0.10063756257295609, -0.001285043079406023, 0.16157469153404236, 0.5426371097564697, 0.2488955408334732, -0.034798894077539444, 0.11979421228170395, 0.05712924897670746, 
-0.1519034504890442, 0.09678438305854797, 0.09584493190050125, -0.3669874370098114, 0.125444233417511, 0.40931230783462524, 0.3035696744918823, 0.18691092729568481, -0.2104356288909912, -0.15142661333084106, 0.016831260174512863, -0.10474022477865219, 0.1896384060382843, ...]",0.931301,1,"like its bizarre heroine , it irrigates our souls ."
1,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",3573,"[0, 13138, 4376, 10, 182, 1375, 8, 20853, 5257, 41887, 7, 5, 18701, 43328, 479, 2]",pos,16,"[-0.20158840715885162, 0.3575786054134369, -0.38341793417930603, 0.18010327219963074, 0.15630246698856354, 0.38274314999580383, -0.09994902461767197, -0.2318796068429947, 0.31283318996429443, 0.11660215258598328, -0.11918359249830246, 0.02431304007768631, 0.149966299533844, 0.10298259556293488, -0.14447380602359772, -0.18221139907836914, 0.00669496413320303, 0.29211005568504333, 0.07370553910732269, 0.11665214598178864, -0.08646053075790405, 0.22458311915397644, 0.4172590970993042, 0.2418132871389389, 0.20363584160804749, -0.09745723754167557, -0.06742429733276367, -0.05254404991865158, -0.16569049656391144, -0.1287563592195511, 0.27005329728126526, 0.06663543730974197, -0.2167079597711563, 0.07535022497177124, -0.027371976524591446, 0.2469516545534134, 0.33272191882133484, 0.37361210584640503, 0.049456264823675156, -0.06388180702924728, -0.11500751972198486, 0.41049230098724365, 0.19662395119667053, 0.14603233337402344, 0.24917030334472656, 0.09485176205635071, -0.07280126214027405, 0.2429356575012207, -0.12741756439208984, 0.2231244444847107, -0.25517117977142334, -0.16563566029071808, -0.19966357946395874, 0.1950351595878601, 0.4458578824996948, 0.07378891110420227, -0.14559735357761383, -0.13437621295452118, 0.13913419842720032, -0.1785784363746643, 0.07854782044887543, -0.1836431622505188, 0.24719442427158356, 0.07032902538776398, 0.10869353264570236, -0.3437102735042572, -0.06741522997617722, 0.17490465939044952, -0.2120242714881897, 0.5022298097610474, -0.406875342130661, -0.1948540061712265, 0.49807024002075195, -0.20378531515598297, 0.1250046342611313, -0.4025477468967438, -0.0820605680346489, 0.01799776591360569, -0.3588275611400604, 0.2247607707977295, 0.15300235152244568, 0.22950297594070435, 0.30480924248695374, 0.14958475530147552, 0.06558987498283386, 0.0847502276301384, 0.11094487458467484, 
0.038049597293138504, 0.2623656094074249, 0.17512740194797516, -0.23059511184692383, 0.1199222281575203, 0.3239966928958893, -0.082222118973732, 0.1591596156358719, -0.09350860118865967, -0.11698610335588455, -0.19592158496379852, 0.4761217534542084, 0.15343964099884033, ...]",0.958571,1,provides a very moving and revelatory footnote to the holocaust .
2,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",5589,"[0, 1264, 95, 23120, 17081, 352, 13, 5, 220, 4817, 396, 2623, 203, 26350, 7, 5, 3768, 479, 2]",neg,19,"[-0.07974279671907425, -0.0712166577577591, -0.039970070123672485, 0.0007532645249739289, 0.19971798360347748, -0.06183060258626938, -0.05134312063455582, 0.32594773173332214, 0.3937385380268097, -0.022144082933664322, 0.21838627755641937, 0.27371323108673096, 0.26927459239959717, 0.06295505166053772, -0.19895893335342407, -0.20393683016300201, -0.11135062575340271, 0.11019532382488251, -0.2996702790260315, 0.0112022515386343, -0.40957605838775635, -0.12017344683408737, 0.18621665239334106, 0.016812538728117943, 0.23999541997909546, -0.1859787106513977, 0.11688186973333359, 0.05328073725104332, -0.23316197097301483, -0.3899984657764435, -0.1998777538537979, -0.43840503692626953, 0.1213247999548912, 0.10185296088457108, 0.08310198783874512, 0.3317270576953888, -0.08138290047645569, 0.1843833029270172, -0.10513550788164139, -0.2397989183664322, -0.24391785264015198, 0.2868635952472687, -0.14646278321743011, 0.15775363147258759, -0.021649273112416267, -0.12723657488822937, 0.09280068427324295, 0.0035478512290865183, -0.025228489190340042, 0.13578663766384125, -0.2502560615539551, -0.04018213227391243, -0.2480120062828064, -0.21175764501094818, 0.11948485672473907, 0.1399693340063095, -0.3697361946105957, 0.042635075747966766, 0.03373183682560921, -0.11777245253324509, 0.01075794454663992, -0.3459096848964691, 0.10292027145624161, 0.09431789070367813, 0.26532837748527527, -0.3107123076915741, 0.03377913311123848, -0.37663134932518005, 0.11032536625862122, 0.5276355147361755, 0.27742302417755127, -0.12793944776058197, 0.0142912482842803, 0.0338447131216526, 0.2502032220363617, 0.19924692809581757, -0.07416054606437683, -0.12118217349052429, 0.2643100917339325, 0.12474198639392853, -0.010701295919716358, 0.16725444793701172, 0.1870589703321457, -0.16427773237228394, -0.39120978116989136, 
-0.03582170605659485, 0.4065021574497223, -0.014093264006078243, -0.1936904788017273, 0.15390048921108246, 0.17351402342319489, -0.01267438754439354, 0.48825645446777344, 0.31152695417404175, -0.4173009693622589, 0.2794550359249115, -0.044307924807071686, 0.370815247297287, -0.21668657660484314, 0.040798477828502655, ...]",0.795055,0,one just waits grimly for the next shock without developing much attachment to the characters .


### Accessing DataLoaders 

We can access dataloaders that serve raw text with `ds.dld_raw`, and dataloaders that serve tokenised text with `ds.dld_tkn`. Both are built from the preprocessed dataset (`ds.dsd_tkn`), and both are dictionaries of dataloaders with keys `['train', 'valid', 'test', 'train_eval']`. 

In [None]:
# List the available dataloaders, then pull one batch from each kind of train loader.
print("Dataloader dict has keys:", ds.dld_raw.keys())
batch_raw = next(iter(ds.dld_raw['train']))  # first batch of raw text
batch_tkn = next(iter(ds.dld_tkn['train']))  # first batch of tokenized text
# Show which fields each batch carries.
print(batch_raw.keys())
print(batch_tkn.keys())
print("Tokenised input is of shape:", batch_tkn['input_ids'].shape)

Dataloader dict has keys: dict_keys(['train', 'test', 'valid', 'train_eval'])
dict_keys(['text', 'idx'])
dict_keys(['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings'])
Tokenised input is of shape: torch.Size([32, 64])


#hide
## Export 

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 30_logging.ipynb.
Converted index.ipynb.
Converted run.ipynb.
