In [None]:
#default_exp data

In [None]:
#hide
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
#export
import torch, random, pandas as pd
from torch.utils.data import DataLoader, RandomSampler
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from IPython.display import display, HTML
from travis_attack.trainer import get_vm_probs
from travis_attack.config import Config

In [None]:
#hide
import inspect
from IPython.core.debugger import set_trace

# Preparing data

## Classes 

### Base class

In [None]:
#export
class ProcessedDataset: 
    """Wrap a raw dataset (e.g. from huggingface `datasets`) and preprocess it.

    All work happens on construction: the dataset named by `cfg.dataset_name` is
    loaded into `self.dsd_raw`, then `_preprocess_dataset` builds `self.dsd_tkn`
    (the processed DatasetDict), `self.dld_raw` (dataloaders serving raw text) and
    `self.dld_tkn` (dataloaders serving tokenized text).
    """
    def __init__(self, cfg, vm_tokenizer, vm_model, pp_tokenizer, sts_model): 
        """Load and preprocess the dataset selected by `cfg.dataset_name`.

        Args:
            cfg: config object (dataset name, column names, batch sizes, device, ...).
            vm_tokenizer, vm_model: victim-model tokenizer and model, used to score the original text.
            pp_tokenizer: paraphrase tokenizer, used to tokenize the original text.
            sts_model: model whose `encode` method produces STS embeddings of the original text.

        Raises:
            Exception: if `cfg.dataset_name` is not 'simple' or 'rotten_tomatoes'.
        """
        self._cfg,self._vm_tokenizer,self._vm_model,self._pp_tokenizer,self._sts_model = cfg,vm_tokenizer,vm_model,pp_tokenizer,sts_model
        if   self._cfg.dataset_name == "simple":          self._prep_dsd_raw_simple()
        elif self._cfg.dataset_name == "rotten_tomatoes": self._prep_dsd_raw_rotten_tomatoes()
        else: raise Exception("cfg.dataset_name must be either 'simple' or 'rotten_tomatoes'")
        self._preprocess_dataset() 
            
    def _prep_dsd_raw_simple(self): 
        """Load the simple dataset and package it up in a DatasetDict (dsd) 
        with one split per entry of `cfg.splits`."""
        self.dsd_raw = DatasetDict()
        for s in self._cfg.splits:  
            # The filename is appended directly to `path_data`, so it must end with a path separator.
            self.dsd_raw[s] = load_dataset('csv', data_files=f"{self._cfg.path_data}simple_dataset_{s}.csv")['train']    
        
    def _prep_dsd_raw_rotten_tomatoes(self):
        """Load the rotten tomatoes dataset and package it up in a DatasetDict (dsd) 
        with splits for train, valid, test."""
        self.dsd_raw = load_dataset("rotten_tomatoes")
        self.dsd_raw['valid'] = self.dsd_raw.pop('validation')  # "valid" is easier than "validation" 
        # make sure that all datasets have the same number of labels as what the victim model predicts
        for _,ds in self.dsd_raw.items(): assert ds.features[self._cfg.label_cname].num_classes == self._cfg.vm_num_labels 

    def _prep_dsd_raw_snli(self): 
        """Placeholder for SNLI support; always raises NotImplementedError."""
        ## For snli
        # remove_minus1_labels = lambda x: x[label_cname] != -1
        # ds_train = ds_train.filter(remove_minus1_labels)
        # valid = valid.filter(remove_minus1_labels)
        # test = test.filter(remove_minus1_labels)
        raise NotImplementedError("SNLI not implemented yet.")
    
    def _preprocess_dataset(self): 
        """Add columns, tokenize, and build dataloaders.

        Sets `self.dsd_tkn`, `self.dld_raw` and `self.dld_tkn`. `self.dsd_raw` is left untouched
        (in particular it is not sharded when `cfg.use_small_ds` is set).
        """
        label_cname = self._cfg.label_cname
        dsd = self.dsd_raw.map(self._add_idx, batched=True, with_indices=True)  # add idx column
        if self._cfg.use_small_ds: dsd = self._prep_small_ds(dsd)  # do after adding idx so it's consistent across runs
        dsd = dsd.map(self._add_vm_orig_score, batched=True)  # add VM score
        if self._cfg.remove_misclassified_examples:
            # Keep only rows the victim model already classifies correctly.
            dsd = dsd.filter(lambda x: x['orig_vm_predclass'] == x[label_cname])
        dsd = dsd.map(self._add_sts_orig_embeddings, batched=True)  # add STS embeddings
        dsd = dsd.map(self._tokenize_fn,             batched=True)  # tokenize
        dsd = dsd.map(self._add_n_tokens,            batched=True)  # add n_tokens
        if self._cfg.bucket_by_length: dsd = dsd.sort("n_tokens", reverse=True)  # sort by n_tokens (high to low), useful for cuda memory caching
        self.dld_raw = self._get_dataloaders_dict(dsd, collate_fn=self._collate_fn_raw)  # dict of data loaders that serve raw text
        self.dld_tkn = self._get_dataloaders_dict(dsd, collate_fn=self._collate_fn_tkn)  # dict of data loaders that serve tokenized text
        self.dsd_tkn = dsd 
        
    def _add_idx(self, batch, idx):
        """Add row numbers (the indices supplied by `map(..., with_indices=True)`)."""
        batch['idx'] = idx 
        return batch   
    
    def _add_n_tokens(self, batch): 
        """Add the number of tokens of the orig text."""
        batch['n_tokens'] = [len(o) for o in batch['input_ids']]
        return batch 
    
    def _add_sts_orig_embeddings(self, batch): 
        """Add the sts embeddings of the orig text."""
        batch['orig_sts_embeddings'] = self._sts_model.encode(batch[self._cfg.orig_cname], batch_size=64, convert_to_tensor=False)
        return batch
    
    def _add_vm_orig_score(self, batch): 
        """Add the victim model's probability of the true label and its predicted class for the orig text."""
        labels = torch.tensor(batch[self._cfg.label_cname], device=self._cfg.device)
        orig_probs,orig_predclass = get_vm_probs(batch[self._cfg.orig_cname], self._cfg, self._vm_tokenizer,
                                                 self._vm_model, return_predclass=True)
        # gather -> (batch, 1). Squeeze only dim 1: a bare squeeze() would also drop the batch
        # dim for a 1-row batch, making tolist() return a float instead of a list.
        batch['orig_truelabel_probs'] = torch.gather(orig_probs, 1, labels[:, None]).squeeze(1).cpu().tolist()
        batch['orig_vm_predclass'] = orig_predclass.cpu().tolist()
        return batch
    
    def _tokenize_fn(self, batch):  
        """Tokenize a batch of orig text using the paraphrase tokenizer."""
        return self._pp_tokenizer(batch[self._cfg.orig_cname], truncation=True, max_length=self._cfg.orig_max_length)  
    
    def _collate_fn_tkn(self, x): 
        """Collate function used by the DataLoader that serves tokenized data (pads to a tensor batch)."""
        d = dict()
        for k in ['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings']: 
            d[k] = [o[k] for o in x]
        return self._pp_tokenizer.pad(d, pad_to_multiple_of=self._cfg.orig_padding_multiple, return_tensors="pt")

    def _collate_fn_raw(self, x): 
        """Collate function used by the DataLoader that serves raw data.

        Returns a dict mapping the orig-text column and 'idx' to lists of values.
        """
        # Uses this object's config, not a notebook-global `cfg`.
        return {k: [o[k] for o in x] for k in [self._cfg.orig_cname, 'idx']}

    def _get_sampler(self, ds): 
        """Return a RandomSampler seeded from the config.

        A fresh generator with the same seed is created on every call, so the raw and
        tokenized train loaders shuffle in the same order. Used when `cfg.shuffle_train` is True.
        """
        g = torch.Generator()
        # NOTE(review): assumes the config exposes the run seed as `seed` — confirm against Config.
        g.manual_seed(self._cfg.seed)
        return RandomSampler(ds, generator=g)
    
    def _prep_small_ds(self, dsd):
        """Replace every split of a datasetdict with its first shard (of `cfg.n_shards`)."""
        for k,v in dsd.items():  
            dsd[k] = v.shard(self._cfg.n_shards, 0, contiguous=self._cfg.shard_contiguous)
        return dsd

    def _make_dataloader(self, ds, batch_size, collate_fn, sampler=None):
        """Build a DataLoader with the worker/pinning settings from the config (unshuffled unless `sampler` is given)."""
        return DataLoader(ds, batch_size=batch_size, sampler=sampler, shuffle=False,
                          collate_fn=collate_fn, num_workers=self._cfg.n_wkrs,
                          pin_memory=self._cfg.pin_memory)
        
    def _get_dataloaders_dict(self, dsd, collate_fn): 
        """Prepare a dict of DataLoaders: one per split in `dsd`, plus 'train_eval'.

        - 'train' uses `cfg.batch_size_train`. It is shuffled with a seeded sampler when
          `cfg.shuffle_train` is set, and served in dataset order otherwise (with
          `cfg.bucket_by_length` the dataset has already been sorted by length).
        - Every other split uses `cfg.batch_size_eval` and is never shuffled.
        - 'train_eval' serves the train split unshuffled with `cfg.batch_size_eval`,
          so the training data can be evaluated in a fixed order.

        Raises:
            Exception: if both `cfg.bucket_by_length` and `cfg.shuffle_train` are set.
        """
        if self._cfg.bucket_by_length and self._cfg.shuffle_train:  raise Exception("Can only do one of bucket by length or shuffle")
        d = dict()
        for split, ds in dsd.items(): 
            is_train = split == "train"
            batch_size = self._cfg.batch_size_train if is_train else self._cfg.batch_size_eval
            sampler = self._get_sampler(ds) if (is_train and self._cfg.shuffle_train) else None
            d[split] = self._make_dataloader(ds, batch_size, collate_fn, sampler=sampler)
        d['train_eval'] = self._make_dataloader(dsd['train'], self._cfg.batch_size_eval, collate_fn)
        return d 
    
    def show_random_elements(self, ds, num_examples=10):
        """Display `num_examples` distinct random rows of `ds` as an HTML table.

        `ds` is a dataset from the `datasets` package (e.g. `self.dsd_raw['train']`).
        ClassLabel columns are shown as their label names instead of integers.
        """
        assert num_examples <= len(ds), "Can't pick more elements than there are in the dataset."
        picks = random.sample(range(len(ds)), num_examples)  # distinct indices
        df = pd.DataFrame(ds[picks])
        for column, feature in ds.features.items():
            if isinstance(feature, ClassLabel):
                df[column] = df[column].map(lambda i, names=feature.names: names[i])
        display(HTML(df.to_html()))

## Usage 

### Basics 

Here we have defined a class `ProcessedDataset` that will load and preprocess a dataset. But before processing the dataset you must load both the config object and all models/tokenizers, so we do this first. 

In [None]:
from travis_attack.models import prepare_models

# prepare_models takes a Config and returns the tokenizers/models together with a
# config object, which is the `cfg` used for the rest of this notebook.
(vm_tokenizer, vm_model,
 pp_tokenizer, pp_model,
 sts_model, cfg) = prepare_models(Config())

Currently there are two choices for dataset: 

* `simple`, a dataset of simple sentences with four elements each in the train, test and valid splits
* `rotten_tomatoes`, a dataset of movie reviews scraped from the Rotten Tomatoes site. 

The dataset is specified by the config class. There are two ways to do this. 

1. Edit the `self.dataset_name` variable in the Config class to either `simple` or `rotten_tomatoes`. An error will be thrown if the name is not one of these two. This is the best way to use when doing runs.   
2. Use the `adjust_config_for_...` methods of the config class, e.g. `cfg_rt = cfg.adjust_config_for_rotten_tomatoes_dataset()`. This is the easiest option for automated testing, so we use it here. 

Once the config is specified and loaded, create a `ProcessedDataset` object by passing the config, models and tokenizers as arguments. All preprocessing happens automatically when the object is created (it is triggered from the class's `__init__()` method).  

In [None]:
# Configure for the `simple` dataset, then load and preprocess it.
# NOTE(review): if the adjust_config_* methods modify `cfg` in place and return it,
# `cfg_simple` and configs derived the same way later alias one object — confirm.
cfg_simple = cfg.adjust_config_for_simple_dataset()
ds = ProcessedDataset(cfg_simple,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default-b253756c445fb811
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-b253756c445fb811/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-c802946231f72062
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-c802946231f72062/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-43a49c5188c42e69
Reusing dataset csv (/data/tproth/.cache/huggingface/datasets/csv/default-43a49c5188c42e69/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

If you want to use the small dataset, adjust the config (e.g. with `.small_ds()`) before creating the `ProcessedDataset` object. 

In [None]:
# Rotten Tomatoes with `.small_ds()` applied: the processed splits are sharded down for quick runs.
# This rebinds `ds`; the cells below inspect this Rotten Tomatoes object.
cfg_rt_small_ds = cfg.adjust_config_for_rotten_tomatoes_dataset().small_ds()
ds = ProcessedDataset(cfg_rt_small_ds,  vm_tokenizer, vm_model, pp_tokenizer, sts_model)

Using custom data configuration default
Reusing dataset rotten_tomatoes_movie_review (/data/tproth/.cache/huggingface/datasets/rotten_tomatoes_movie_review/default/1.0.0/e06abb624abab47e1a64608fdfe65a913f5a68c66118408032644a3285208fb5)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

### Accessing datasets

You can access raw data with `ds.dsd_raw` and processed data with `ds.dsd_tkn`. (Here "dsd" stands for "DatasetDict".)

In [None]:
# Raw splits as loaded. Not sharded: the small-dataset option only affects dsd_tkn.
ds.dsd_raw

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
    valid: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
})

In [None]:
# Processed splits: adds idx, VM scores, STS embeddings, token ids/masks and n_tokens.
ds.dsd_tkn

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 130
    })
    test: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 15
    })
    valid: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'n_tokens', 'orig_sts_embeddings', 'orig_truelabel_probs', 'orig_vm_predclass', 'text'],
        num_rows: 16
    })
})

You can access elements by indexing: 

In [None]:
# Slicing a split returns a dict mapping each column name to a list of values.
ds.dsd_raw['valid'][0:2]

{'text': ['compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
  'the soundtrack alone is worth the price of admission .'],
 'label': [1, 1]}

In [None]:
# The same kind of slice on the processed split, including token ids and attention masks.
ds.dsd_tkn['valid'][0:2]

{'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1],
  [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]],
 'idx': [60, 960],
 'input_ids': [[0,
   12979,
   4104,
   18701,
   1116,
   438,
   5777,
   18,
   9869,
   8,
   2770,
   2156,
   31,
   69,
   308,
   26735,
   2156,
   13855,
   7,
   5,
   471,
   9,
   5,
   1380,
   9,
   390,
   18,
   3541,
   14,
   3616,
   7,
   1877,
   5,
   42639,
   9,
   32693,
   30802,
   12,
   506,
   36562,
   30,
   8959,
   2182,
   40629,
   918,
   19,
   10,
   33937,
 

Alternatively, you can look at some random elements of a dataset with the `ds.show_random_elements()` method. 

In [None]:
# Three random raw training rows as an HTML table (ClassLabel ints shown as label names).
ds.show_random_elements(ds.dsd_raw['train'], num_examples=3)

Unnamed: 0,text,label
0,"some movies were made for the big screen , some for the small screen , and some , like ballistic : ecks vs . sever , were made for the palm screen .",neg
1,would that greengrass had gone a tad less for grit and a lot more for intelligibility .,neg
2,"almost everything about the film is unsettling , from the preposterous hairpiece worn by lai's villainous father to the endless action sequences .",neg


In [None]:
# Three random processed training rows, showing every added column.
ds.show_random_elements(ds.dsd_tkn['train'], num_examples=3)

Unnamed: 0,attention_mask,idx,input_ids,label,n_tokens,orig_sts_embeddings,orig_truelabel_probs,orig_vm_predclass,text
0,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",7140,"[0, 33394, 5, 1270, 18, 3741, 6435, 12, 261, 12, 627, 12, 3628, 14, 3649, 5, 6120, 951, 342, 11, 7, 283, 62, 19, 41, 38279, 352, 46046, 3693, 5494, 12745, 4286, 479, 2]",neg,34,"[-0.04196130111813545, -0.009238973259925842, -0.3204658031463623, -0.03222867101430893, 0.06881716102361679, -0.03910493105649948, 0.27109187841415405, 0.20691175758838654, 0.0867067202925682, -0.15915031731128693, 0.020389346405863762, 0.26467210054397583, 0.205707386136055, -0.19415383040905, 0.030920574441552162, 0.1559268981218338, 0.11858794838190079, 0.2234838753938675, -0.11196062713861465, 0.21912512183189392, 0.12638849020004272, 0.16174715757369995, 0.2730186879634857, 0.09718873351812363, 0.2680603563785553, -0.11813471466302872, 0.05766817927360535, -0.04775218293070793, -0.37661904096603394, -0.1680448353290558, -0.1476079672574997, -0.13525497913360596, 0.0857386440038681, -0.04185601323843002, -0.09892675280570984, 0.0362948477268219, -0.007448195014148951, 0.5905905961990356, -0.09983377158641815, 0.11830045282840729, -0.16939152777194977, -0.07516764104366302, -0.34158164262771606, 0.06370783597230911, -0.08166901767253876, -0.042753469198942184, -0.1133575439453125, -0.1314639151096344, -0.2440081089735031, 0.09848330169916153, 0.1827378273010254, 0.015506694093346596, -0.2789808213710785, -0.0493127703666687, 0.015380859375, 0.02625752054154873, 0.028081530705094337, 0.06733235716819763, -0.18548305332660675, 0.26405346393585205, -0.1820208579301834, 0.01674460992217064, 0.012826412916183472, -0.009701496921479702, 0.12419740110635757, -0.23365411162376404, -0.03039911389350891, -0.025195926427841187, -0.4195925295352936, 0.5111600756645203, -0.0054117352701723576, -0.19782604277133942, -0.030812839046120644, 0.17143206298351288, 0.21343672275543213, 0.23953931033611298, -0.22269316017627716, 0.026624226942658424, 0.3169982433319092, -0.015234175138175488, 
-0.05387616157531738, 0.017674501985311508, -0.03724076598882675, 0.04830005764961243, -0.2180975377559662, 0.08899800479412079, 0.018246149644255638, 0.2746174931526184, -0.01744389906525612, -0.01190308015793562, -0.16709621250629425, -0.39918261766433716, 0.6197574138641357, 0.0384308397769928, 0.1539965569972992, -0.16340729594230652, -0.1247488260269165, 0.044038742780685425, -0.13146823644638062, 0.3809778094291687, ...]",0.905443,0,consider the title's clunk-on-the-head that suggests the overtime someone put in to come up with an irritatingly unimaginative retread concept .
1,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",3120,"[0, 14746, 5, 1569, 29652, 5, 7636, 642, 1264, 8, 5, 740, 5166, 32876, 763, 2156, 24, 5684, 10, 2579, 12989, 479, 2]",pos,23,"[0.020312774926424026, -0.4342547655105591, -0.11203484237194061, 0.06356408447027206, -0.2472907453775406, 0.23195934295654297, 0.026600506156682968, -0.13035474717617035, -0.17619594931602478, -0.1673208773136139, 0.24095813930034637, 0.2227906435728073, -0.28435206413269043, -0.06695634871721268, 0.04540446400642395, -0.15197722613811493, 0.4307054877281189, 0.17793914675712585, 0.21090035140514374, -0.29525861144065857, -0.174674853682518, 0.1559278517961502, 0.22573716938495636, 0.0655001625418663, -0.13132180273532867, -0.13206128776073456, 0.25834912061691284, -0.06781020015478134, -0.014282864518463612, -0.2692868709564209, -0.26822662353515625, 0.5002068877220154, 0.01568022184073925, 0.011855524964630604, -0.4588114023208618, 0.05000201240181923, -0.23240965604782104, 0.10988220572471619, 0.13011987507343292, 0.02135503850877285, -0.31529998779296875, 0.22454825043678284, 0.3885307013988495, -0.033428989350795746, -0.20892013609409332, 0.11880960315465927, -0.3772817850112915, 0.008186757564544678, 0.17276744544506073, -0.04272422939538956, -0.03144402429461479, -0.1372782289981842, 0.04585927352309227, -0.09268304705619812, 0.07099775969982147, 0.11562521010637283, -0.07953808456659317, 0.12489095330238342, -0.04774516075849533, 0.04359965771436691, 0.03469322249293327, -0.01991434022784233, -0.389232873916626, 0.023120155557990074, 0.2393990308046341, -0.40428078174591064, 0.3099798560142517, 0.09177684783935547, -0.2791033983230591, -0.061158906668424606, -0.24221259355545044, 0.3895144760608673, 0.007583014201372862, -0.3246764838695526, -0.29141902923583984, -0.03138533607125282, 0.3038386404514313, -0.3452931344509125, -0.334575355052948, 0.026498019695281982, 0.13110381364822388, 0.19747871160507202, -0.12946228682994843, -0.20182059705257416, 
-0.037941671907901764, 0.003365838434547186, 0.13819922506809235, 0.0592874251306057, -0.2996733784675598, -0.06252466142177582, -0.0723121389746666, 0.5948566198348999, -0.34012967348098755, 0.3228849172592163, 0.22052548825740814, 0.1321076899766922, 0.16654090583324432, 0.02773849479854107, 0.0009489202639088035, -0.09560590982437134, ...]",0.863916,1,"when the movie mixes the cornpone and the cosa nostra , it finds a nice rhythm ."
2,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",6900,"[0, 627, 1219, 939, 303, 2185, 1747, 30780, 12677, 30, 42, 822, 2156, 61, 16, 15192, 1043, 922, 7223, 2622, 8, 34, 1473, 383, 7, 224, 2156, 16, 14, 24, 606, 420, 1195, 350, 32907, 25, 38147, 4405, 479, 2]",neg,40,"[0.18801912665367126, -0.10782486200332642, -0.28841546177864075, 0.07453859597444534, 0.3511906862258911, 0.04877807945013046, -0.07187788188457489, -0.04469688981771469, -0.022374307736754417, -0.1601536124944687, 0.2535260021686554, 0.13563890755176544, -0.012134546414017677, 0.20821364223957062, 0.053624581545591354, 0.03267998620867729, 0.32723015546798706, -0.3744050860404968, -0.157009556889534, 0.11462514102458954, 0.135099396109581, 0.2841282784938812, 0.475480854511261, 0.22668801248073578, 0.25136157870292664, -0.0413767546415329, -0.0355643555521965, -0.13597463071346283, -0.19619262218475342, -0.48532748222351074, -0.21454568207263947, 0.11448192596435547, 0.18778006732463837, 0.042475391179323196, -0.02271394431591034, 0.1362188309431076, -0.012045664712786674, 0.1488579511642456, -0.1274537444114685, -0.37674543261528015, -0.3765082061290741, 0.18845713138580322, 0.1247776672244072, 0.03637618198990822, -0.10104376077651978, 0.05746263638138771, -0.15823054313659668, 0.0757000669836998, -0.13739892840385437, -0.1401820182800293, -0.1711084395647049, -0.0691545307636261, -0.02869676984846592, -0.3340966999530792, -0.08003940433263779, -0.25785234570503235, 0.046544190496206284, 0.016493570059537888, -0.024365032091736794, -0.17687088251113892, -0.09296205639839172, -0.22306883335113525, -0.23137523233890533, -0.010635781101882458, 0.4442282021045685, -0.21966774761676788, -0.03467832878232002, -0.2432144433259964, -0.11449000984430313, 0.23409685492515564, 0.09828513860702515, 0.22060738503932953, 0.20325124263763428, 0.1981322467327118, -0.06249312311410904, 0.12440680712461472, 0.23496490716934204, 
-0.0374331958591938, -0.13474079966545105, 0.024354320019483566, 0.0883830115199089, 0.04915573447942734, 0.02160516194999218, 0.05301307514309883, -0.16829432547092438, 0.22986352443695068, 0.16089969873428345, 0.3052110970020294, -0.14377351105213165, 0.04528259113430977, -0.14061902463436127, -0.025395579636096954, 0.10156853497028351, 0.22802916169166565, 0.3172958493232727, -0.2131151705980301, 0.02391723357141018, 0.014955279417335987, -0.08001270145177841, -0.06634616106748581, ...]",0.938107,0,"the reason i found myself finally unmoved by this film , which is immaculately produced and has serious things to say , is that it comes across rather too plainly as allegory ."


### Accessing DataLoaders 

We can access dataloaders for the raw text in `ds.dsd_raw` with `ds.dld_raw`, and for the tokenised text in `ds.dsd_tkn` with `ds.dld_tkn`. Both are dictionaries of dataloaders with keys `['train', 'test', 'valid', 'train_eval']`. `train_eval` serves the training split in a fixed order with the evaluation batch size. 

In [None]:
# Peek at the first training batch served by each kind of dataloader.
print("Dataloader dict has keys:", ds.dld_raw.keys())
batch_raw, batch_tkn = (next(iter(ds.dld_raw['train'])),
                        next(iter(ds.dld_tkn['train'])))
print(*(b.keys() for b in (batch_raw, batch_tkn)), sep="\n")
print("Tokenised input is of shape:", batch_tkn['input_ids'].shape)

Dataloader dict has keys: dict_keys(['train', 'test', 'valid', 'train_eval'])
dict_keys(['text', 'idx'])
dict_keys(['idx', 'attention_mask', 'input_ids', 'label', 'orig_truelabel_probs', 'orig_sts_embeddings'])
Tokenised input is of shape: torch.Size([32, 64])


#hide
## Export 

In [None]:
#hide
# Convert the project's notebooks into library modules with nbdev (see the "Converted ..." output).
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 03_config.ipynb.
Converted 07_models.ipynb.
Converted 10_data.ipynb.
Converted 20_trainer.ipynb.
Converted 30_logging.ipynb.
Converted index.ipynb.
Converted run.ipynb.
