In [2]:
# | default_exp dataset

# Protein Dataset

> A PyTorch dataset that pairs a tag extracted from each protein's description with its tokenized sequence.

In [3]:
# | hide
from nbdev.showdoc import *

In [4]:
#| hide
# Run nbdev's export step so the cells marked `#| export` are written out
# to the Python module named by `default_exp`.
import nbdev

nbdev.nbdev_export()

In [5]:
#| export
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict

import torch
from torch.utils.data import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
#| export
def extract_property_tag(name: str) -> Callable[[str], Optional[str]]:
    """Build a function that pulls the value of a `name=value` tag out of a string.

    The returned function searches its argument for the first `name=` and
    returns the text that follows it, up to (not including) the next space or
    the end of the string. It returns `None` when no such tag is found.

    Note that the value is cut at its first space, so only the first word of a
    multi-word value is returned.

    Args:
        name: Tag name to look for; it is matched literally.

    Returns:
        A function mapping a string to the extracted tag value, or `None`.
    """
    # Escape the tag name so regex metacharacters in it are matched literally.
    # The value ends at the next space *or* at the end of the string, so a tag
    # that is the last token in the string is still found.
    # Compiled once here instead of on every call of `inner`.
    pattern = re.compile(rf"{re.escape(name)}=(.+?)(?: |$)")

    def inner(sequence: str) -> Optional[str]:
        match = pattern.search(sequence)
        return match.group(1) if match else None

    return inner

In [None]:
# Example: build an extractor for the "OS" tag.
extractor = extract_property_tag(name="OS")

In [7]:
#| export
class ProteinSequence(TypedDict, total=True):
    """A single protein record as a plain dict; all three keys are required."""

    # Not read by any code in this notebook; presumably the record identifier.
    id: str
    # The text that `ProteinDataset` hands to the tokenizer.
    seq: str
    # The text that `ProteinDataset` hands to the tag extractor.
    desc: str

In [8]:
#| export
class ProteinDataset(Dataset):
    """Map-style dataset of `(tag, encoded sequence)` pairs.

    For every record, the tag is whatever `tag_extractor` returns for the
    record's `desc`, and the encoded sequence is a tensor built from the token
    ids the tokenizer produces for the record's `seq`. All records are encoded
    once, up front, in the constructor.
    """

    def __init__(self, data: List["ProteinSequence"], tokenizer: Any, tag_extractor: Callable):
        """Extract tags and tokenize every sequence in `data`.

        Args:
            data: Records providing at least the `desc` and `seq` keys.
            tokenizer: Object exposing `encode_batch(list_of_str)`, returning
                one item per input string, each with an `.ids` attribute.
                It is not called directly.
            tag_extractor: Called with each record's `desc`; its return value
                is stored as that record's tag (this may be `None`, e.g. when
                `extract_property_tag` finds no match).
        """
        xs = []
        ys = []

        # Single pass so `data` is only iterated once.
        for item in data:
            xs.append(tag_extractor(item["desc"]))
            ys.append(item["seq"])

        # Tokenize all sequences in one batch call, then keep only the ids.
        encoded_ys = tokenizer.encode_batch(ys)
        encoded_ys = [torch.tensor(e.ids) for e in encoded_ys]

        self.xs: List[Optional[str]] = xs
        self.ys: List[torch.Tensor] = encoded_ys

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.xs)

    def __getitem__(self, idx: int) -> Tuple[Optional[str], torch.Tensor]:
        """Return the `(tag, token-id tensor)` pair stored at position `idx`."""
        return self.xs[idx], self.ys[idx]

In [10]:
def get_batch(data, block_size, batch_size):
    """Draw `batch_size` random windows of length `block_size` from `data`.

    Returns a pair `(x, y)` of stacked windows, where each row of `y` is the
    corresponding row of `x` shifted one position further along `data`.
    `data` needs more than `block_size` elements so that a start offset can
    be drawn.
    """
    # One random start offset per batch row, each in [0, len(data) - block_size).
    starts = torch.randint(len(data) - block_size, (batch_size,))

    inputs = []
    targets = []
    for start in starts:
        stop = start + block_size
        inputs.append(data[start:stop])
        # Same window, moved one step to the right.
        targets.append(data[start + 1:stop + 1])

    return torch.stack(inputs), torch.stack(targets)

In [14]:
# Toy sequence 1..10; with consecutive integers every entry of `y` is the
# matching entry of `x` plus one, which makes the shift easy to see.
data = torch.arange(1, 11)
x, y = get_batch(data, block_size=4, batch_size=2)

In [15]:
x

tensor([[6, 7, 8, 9],
        [1, 2, 3, 4]])

In [16]:
y

tensor([[ 7,  8,  9, 10],
        [ 2,  3,  4,  5]])