# Running a Model on a Task Dataset

In this example, we walk through how to run a model on a specific task dataset. We'll use the `transformer` base model and run it on the Secondary Structure dataset. To do this, we need a few things:

- Dataset
- DataLoader + CollateFn
- TAPEConfig
- TaskModel

## Dataset

Let's start by creating the `SecondaryStructureDataset`. This is located in the `tape_pytorch.datasets` module.

Every TAPE task dataset takes three arguments. The first, `data_path`, is the path to the TAPE data directory. By default, we assume this is a folder called `data` at the top level of the `tape-pytorch` directory, but you can put it wherever you like. Note that this is *not* the path to the task-specific data, which lives inside the TAPE data directory and must be named appropriately.

The second argument, `mode`, selects the dataset split to load. For some tasks this is simply `train`, `val`, or `test`. For other tasks it can also name specific test splits (e.g., for the Secondary Structure task, the test splits are `cb513`, `ts115`, and `casp12`).

The third argument, `tokenizer`, determines how input sequences are tokenized. There are two options: `dummy`, which splits sequences into individual amino acids, and `bpe`, which uses a byte-pair encoding of the input sequences.
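To make the `dummy` option concrete, here is a minimal, hypothetical character-level tokenizer. The vocabulary, the `<pad>`/`<unk>` special tokens, and the integer ids are assumptions for illustration only, not TAPE's actual vocabulary.

```python
from typing import Dict, List

# Hypothetical vocabulary: special tokens first, then the 20 standard amino acids.
# These ids are illustrative and are NOT TAPE's actual vocabulary.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
SPECIAL_TOKENS = ["<pad>", "<unk>"]
VOCAB: Dict[str, int] = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}


def dummy_tokenize(sequence: str) -> List[str]:
    """Split a protein sequence into one token per amino acid."""
    return list(sequence)


def encode(sequence: str) -> List[int]:
    """Map each token to its id, falling back to <unk> for unknown residues."""
    return [VOCAB.get(token, VOCAB["<unk>"]) for token in dummy_tokenize(sequence)]


print(dummy_tokenize("MKTA"))  # ['M', 'K', 'T', 'A']
print(encode("MKXA"))          # [12, 10, 1, 2]
```

A `bpe` tokenizer would instead emit multi-residue tokens learned from frequently occurring substrings, so the same sequence would usually map to fewer tokens.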

```python
from tape_pytorch.datasets import SecondaryStructureDataset

dataset = SecondaryStructureDataset(data_path='../data', mode='train', tokenizer='dummy')
```
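The `DataLoader` used in the next step only needs a map-style dataset, that is, an object implementing `__len__` and `__getitem__`. The class below is a hypothetical in-memory stand-in that shows this protocol; it is not how `SecondaryStructureDataset` is implemented, and the field names `input_ids` and `labels` are assumptions.

```python
from typing import Dict, List

Example = Dict[str, List[int]]


class ToySecondaryStructureDataset:
    """Hypothetical in-memory stand-in for a TAPE task dataset.

    It implements only the map-style protocol (__len__ and __getitem__)
    that torch.utils.data.DataLoader relies on; field names are assumptions.
    """

    def __init__(self, examples: List[Example]) -> None:
        for example in examples:
            # Secondary structure labels are per residue, so lengths must match.
            if len(example["input_ids"]) != len(example["labels"]):
                raise ValueError("each residue needs exactly one label")
        self.examples = examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, index: int) -> Example:
        return self.examples[index]


toy_dataset = ToySecondaryStructureDataset([
    {"input_ids": [12, 10, 18, 2], "labels": [2, 0, 0, 2]},
    {"input_ids": [12, 5], "labels": [2, 1]},
])
print(len(toy_dataset), toy_dataset[1])
```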

## DataLoader + CollateFn

The DataLoader used in TAPE is the standard [torch DataLoader](https://pytorch.org/docs/stable/data.html). However, we also need a custom [collate_fn](https://pytorch.org/docs/stable/data.html#working-with-collate-fn) because the sequences have variable lengths. Task-specific collate functions can also be imported from the `tape_pytorch.datasets` module.

```python
from torch.utils.data import DataLoader
from tape_pytorch.datasets import SecondaryStructureBatch

batch_size = 4
num_workers = 1
collate_fn = SecondaryStructureBatch()

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, collate_fn=collate_fn)
```
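Because the sequences in a batch have different lengths, a collate function has to pad them to a common length. The sketch below shows the general idea in plain Python; it is not the implementation of `SecondaryStructureBatch`, and the field names (`input_ids`, `input_mask`, `labels`) and the padding id `0` are assumptions. Padded label positions use `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so a loss computed that way skips them.

```python
from typing import Dict, List

PAD_ID = 0           # assumed padding token id
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_collate(examples: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Pad a batch of variable-length examples to the longest sequence in it."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "input_mask": [], "labels": []}
    for ex in examples:
        n_pad = max_len - len(ex["input_ids"])
        batch["input_ids"].append(ex["input_ids"] + [PAD_ID] * n_pad)
        # 1 marks a real residue, 0 marks padding.
        batch["input_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
        # Padded positions get a label that the loss will skip.
        batch["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
    return batch


toy_batch = pad_collate([
    {"input_ids": [12, 10, 18], "labels": [2, 0, 0]},
    {"input_ids": [12], "labels": [1]},
])
print(toy_batch)
```

A real collate function would also convert these nested lists into tensors, for example with `torch.tensor(...)`, before handing them to the model.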

## TAPEConfig

A `TAPEConfig` object configures the base model used for each TAPE task. It can be imported from `tape_pytorch.models` and exposes a number of configuration options, which are best set through a config file. Here we load `tape-pytorch/config/transformer_config.json`.

```python
from tape_pytorch.models import TAPEConfig

config = TAPEConfig.from_json_file('../config/transformer_config.json')
```
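A JSON-backed config of this kind typically turns each key in the file into an attribute, which is why extra options such as `num_classes` can simply be attached after loading. The `ToyConfig` class below is a hypothetical stand-in illustrating that pattern, not `TAPEConfig` itself, and the keys written to the temporary file are made up rather than taken from `transformer_config.json`.

```python
import json
import os
import tempfile


class ToyConfig:
    """Hypothetical stand-in for a JSON-backed model config (not TAPEConfig)."""

    def __init__(self, **kwargs) -> None:
        # Every key in the JSON file becomes an attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_json_file(cls, path: str) -> "ToyConfig":
        with open(path) as f:
            return cls(**json.load(f))


# Write a small example file; these keys are illustrative only.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_config.json")
    with open(path, "w") as f:
        json.dump({"hidden_size": 64, "num_hidden_layers": 2}, f)
    toy_config = ToyConfig.from_json_file(path)

toy_config.num_classes = 3  # task-specific options can be added after loading
print(toy_config.hidden_size, toy_config.num_classes)  # 64 3
```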

## TaskModel

The final piece is the task model. Task models live in `tape_pytorch.models.task_models`, and the one we need for Secondary Structure prediction is `SequenceToSequenceClassificationModel`. Before building it, we have to add a config option that tells the model how many classes to predict. Since we're doing 3-class secondary structure prediction, we set `config.num_classes = 3`.

```python
from tape_pytorch.models.task_models import SequenceToSequenceClassificationModel

config.num_classes = 3
task_model = SequenceToSequenceClassificationModel(config)
```
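With `num_classes = 3`, a sequence-to-sequence classifier should produce one score per class for every residue. The hypothetical helper below turns such per-residue scores into a secondary structure string. The mapping of class indices to helix (`H`), strand (`E`), and coil (`C`) is an assumption; check the dataset for the actual label order.

```python
from typing import List

# Assumed label order (helix, strand, coil); verify against the dataset.
SS3_CLASSES = "HEC"


def decode_ss3(logits: List[List[float]]) -> str:
    """Turn per-residue scores of shape [length, num_classes] into an H/E/C string."""
    chars = []
    for scores in logits:
        if len(scores) != len(SS3_CLASSES):
            raise ValueError("expected one score per class for each residue")
        # argmax over the class dimension
        best = max(range(len(scores)), key=scores.__getitem__)
        chars.append(SS3_CLASSES[best])
    return "".join(chars)


print(decode_ss3([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.0, 0.2, 0.9]]))  # HEC
```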

## Run the Model

And that's it! Now we can run the model.

```python
# Fetch a single batch and run a forward pass through the task model.
batch = next(iter(dataloader))
outputs = task_model(**batch)
print(outputs.keys())
print('Loss:', outputs[task_model.LOSS_KEY])
```
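Besides the loss, a common way to inspect secondary structure predictions is per-residue accuracy (for the 3-class task, often called Q3). Padded positions must be excluded, or padding would distort the score. The sketch below is plain Python and assumes predictions have already been reduced to class ids and that padded labels are marked with `-100` (an assumption mirroring PyTorch's default `ignore_index`).

```python
from typing import List

IGNORE_INDEX = -100  # assumed label for padded positions


def masked_accuracy(predictions: List[List[int]], labels: List[List[int]]) -> float:
    """Fraction of non-padded residues whose predicted class matches the label."""
    correct = total = 0
    for pred_row, label_row in zip(predictions, labels):
        for pred, label in zip(pred_row, label_row):
            if label == IGNORE_INDEX:
                continue  # skip padding so it does not affect the score
            total += 1
            correct += int(pred == label)
    return correct / total if total else 0.0


print(masked_accuracy([[0, 1, 2], [2, 0, 0]], [[0, 1, 1], [2, -100, -100]]))  # 0.75
```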



# An Easier Way

There's a lot to keep track of here, and it would be nice to do this more simply. Fortunately, we can! TAPE has a `registry` that stores the dataset, collate function, and task model classes, and lets you look them up by task name alone. The code above then becomes:

```python
from torch.utils.data import DataLoader
from tape_pytorch.models import TAPEConfig
from tape_pytorch.registry import registry

task_name = 'secondary_structure'
batch_size = 4
num_workers = 1

# Look up the task-specific classes by name instead of importing them directly.
dataset = registry.get_dataset_class(task_name)(data_path='../data', mode='train', tokenizer='dummy')
collate_fn = registry.get_collate_fn_class(task_name)()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, collate_fn=collate_fn)

config = TAPEConfig.from_json_file('../config/transformer_config.json')
config.num_classes = 3
task_model = registry.get_task_model_class(task_name)(config)
```
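Conceptually, a registry like this is a dictionary from task names to classes, often filled in by decorators when each class is defined. The sketch below illustrates that pattern with a hypothetical `ToyRegistry`; it is not TAPE's registry implementation, and it only covers datasets, though the same idea extends to collate functions and task models.

```python
from typing import Callable, Dict


class ToyRegistry:
    """Hypothetical name-to-class lookup illustrating the registry pattern."""

    def __init__(self) -> None:
        self._datasets: Dict[str, type] = {}

    def register_dataset(self, name: str) -> Callable[[type], type]:
        """Return a class decorator that records the class under `name`."""
        def decorator(cls: type) -> type:
            if name in self._datasets:
                raise KeyError(f"dataset {name!r} is already registered")
            self._datasets[name] = cls
            return cls  # leave the class itself unchanged
        return decorator

    def get_dataset_class(self, name: str) -> type:
        return self._datasets[name]


toy_registry = ToyRegistry()


@toy_registry.register_dataset("secondary_structure")
class ToySSDataset:
    def __init__(self, mode: str) -> None:
        self.mode = mode


toy_ss_dataset = toy_registry.get_dataset_class("secondary_structure")(mode="train")
print(type(toy_ss_dataset).__name__, toy_ss_dataset.mode)  # ToySSDataset train
```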