# Categorical Transformation for DL
> A collection of tools for categorical transformation

In [2]:
# default_exp category

## Imports

In [74]:
# export
import pandas as pd
import numpy as np
from pathlib import Path
import json
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from typing import Iterable

class C2I:
    """
    Category to indices

    >>> c2i = C2I(["class 1", "class 2", "class 3"], pad_mst=True)
    >>> c2i["class 2"]
    2
    >>> c2i[["class 2", "class 3"]]
    array([2, 3])
    >>> c2i[np.array(["class 2", "unseen"])]
    array([2, 0])

    Indexing with a single category returns a plain int.
    Indexing with a np.ndarray (or subclass) uses a vectorized lookup and
    returns an integer array of the same shape; a list is looked up
    element by element and returns a 1-d integer array.

    With pad_mst=True index 0 is reserved for the "[MST]" (missing) token
    and every unknown category maps to 0. With pad_mst=False an unknown
    category raises KeyError.
    """
    def __init__(
        self,
        arr:Iterable,
        pad_mst:bool=False,
    ):
        self.pad_mst = pad_mst
        self.pad = ["[MST]",] if self.pad_mst else []
        self.dict = dict(
            (v,k) for k,v in enumerate(self.pad + list(arr)))
        self.get_int = self.get_get_int()
        # otypes is given explicitly: without it np.vectorize has to call
        # the function on the first element to infer the output dtype and
        # therefore raises on size-0 inputs.
        self.get_int_ = np.vectorize(self.get_int, otypes=[int])

    def get_get_int(self,):
        """
        Build the scalar lookup function matching the pad_mst setting
        """
        if self.pad_mst:
            def get_int(idx:str) -> int:
                # unknown categories fall back to the "[MST]" slot
                return self.dict.get(idx, 0)
        else:
            def get_int(idx:str) -> int:
                return self.dict[idx]
        return get_int

    def __repr__(self) -> str:
        return f"C2I:{self.__len__()} categories"

    def __len__(self):
        return len(self.dict)

    def __getitem__(self, k):
        """
        Translate a category, or a list / np.ndarray of categories,
        to indices
        """
        if isinstance(k, np.ndarray):
            # use vectorized function
            return self.get_int_(k)
        if isinstance(k, list):
            # a list is unhashable, so it can never be a category itself;
            # look the elements up one by one without letting numpy
            # coerce mixed element types
            return np.array([self.get_int(item) for item in k], dtype=int)
        # use the original python function
        return self.get_int(k)
        
class Category:
    """
    Manage categorical translations
    c = Category(
            ["class 1", "class 2", ..., "class n"],
            pad_mst=True,)

    c.c2i[np.array(["class 3","class 6"])]   # categories -> indices
    c.i2c[[3, 2, 1]]                         # indices -> categories

    c2i is a C2I lookup, i2c is a np.ndarray of the categories
    (prefixed with "[MST]" when pad_mst is True).
    """
    def __init__(
        self,
        arr:Iterable,
        pad_mst:bool=False
    ):
        self.pad_mst=pad_mst
        # Materialize once: arr is read for both lookups, and a one-shot
        # iterable (e.g. a generator) would be empty on the second pass.
        items = list(arr)
        self.c2i = C2I(items, pad_mst=pad_mst)
        self.i2c = np.array(self.c2i.pad+items)

    def save(self,path: Path) -> None:
        """
        save category information to json file
        the file holds the i2c list, including the "[MST]" pad if present
        """
        with open(path,"w",encoding="utf-8") as f:
            json.dump(self.i2c.tolist(),f)

    @classmethod
    def load(cls, path:Path):
        """
        load category information from a json file
        a leading "[MST]" entry is taken to mean pad_mst=True
        """
        with open(path,"r",encoding="utf-8") as f:
            items = json.load(f)
        # length guard: an empty saved list has no first element to inspect
        if len(items) > 0 and items[0]=="[MST]":
            return cls(items[1:], pad_mst=True)
        else:
            return cls(items, pad_mst=False)

    def __len__(self):
        return len(self.i2c)

    def __repr__(self):
        return f"Category Manager with {self.__len__()}"

## Indexing forward and backward

In [75]:
# Fifty categories named Cate_1 ... Cate_50, without a missing-token pad
cates = Category([f"Cate_{i}" for i in range(1, 51)])

In [76]:
# Evaluate the object so the notebook displays its __repr__
cates

Category Manager with 50

In [77]:
# First five entries of the index -> category array
cates.i2c[:5]

array(['Cate_1', 'Cate_2', 'Cate_3', 'Cate_4', 'Cate_5'], dtype='<U7')

In [78]:
# Draw 1000 random category indices. The lower bound is 0 so the first
# category (index 0) is exercised as well; the upper bound is exclusive,
# so the values cover 0..49, i.e. all 50 categories.
test_c = np.random.randint(0, 50, 1000)

### Indices to categories

In [79]:
# Indexing the i2c array with an integer array maps every index to its category
labels = cates.i2c[test_c]
labels[:20]

array(['Cate_49', 'Cate_27', 'Cate_42', 'Cate_15', 'Cate_19', 'Cate_30',
       'Cate_33', 'Cate_18', 'Cate_45', 'Cate_43', 'Cate_3', 'Cate_14',
       'Cate_50', 'Cate_26', 'Cate_42', 'Cate_33', 'Cate_38', 'Cate_38',
       'Cate_50', 'Cate_2'], dtype='<U7')

### Category to indices

In [80]:
# labels[:20] is a np.ndarray, so the lookup goes through the vectorized function
cates.c2i[labels[:20]]

array([48, 26, 41, 14, 18, 29, 32, 17, 44, 42,  2, 13, 49, 25, 41, 32, 37,
       37, 49,  1])

Using vectorized function

In [83]:
%%time
# labels is a np.ndarray, so this uses the np.vectorize-wrapped lookup
indices_generated = cates.c2i[labels]

CPU times: user 445 µs, sys: 10 µs, total: 455 µs
Wall time: 452 µs


Using the original python function

In [84]:
%%time
# Same translation, calling the scalar lookup once per label
indices_generated2 = [cates.c2i.get_int(label) for label in labels]

CPU times: user 730 µs, sys: 0 ns, total: 730 µs
Wall time: 735 µs


Transform forward and backward and check fidelity

In [85]:
# Fraction of positions where the round trip reproduces the original index
np.mean(cates.c2i[labels] == test_c)

1.0

## With missing tokens

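We can set `pad_mst=True` to handle missing tokens: index 0 is reserved for the `[MST]` placeholder, and any category that was not seen at construction time is mapped to 0.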

In [86]:
# A str is iterated character by character, so the categories are A, T, C, G,
# placed after the "[MST]" pad that occupies index 0
nt = Category("ATCG", pad_mst=True)

### Categories to indices

In [87]:
# "O" is not one of the known categories, so it is mapped to index 0
nt.c2i[np.array(list("AAACCTTATTGCAGCOAAT"))]

array([1, 1, 1, 3, 3, 2, 2, 1, 2, 2, 4, 3, 1, 4, 3, 0, 1, 1, 2])

### Indices to categories

In [88]:
# Index 0 translates back to the "[MST]" placeholder
nt.i2c[[1, 1, 1, 3, 3, 2, 2, 2, 2, 4, 3, 1, 4, 3, 0, 1, 1, 2]]

array(['A', 'A', 'A', 'C', 'C', 'T', 'T', 'T', 'T', 'G', 'C', 'A', 'G',
       'C', '[MST]', 'A', 'A', 'T'], dtype='<U5')

## Data save and load

### Save categories

In [89]:
# Writes the i2c list (including the "[MST]" pad) to a JSON file
nt.save("atcg.json")

### Load categories

In [90]:
# load() finds "[MST]" as the first saved entry and rebuilds with pad_mst=True
cm = Category.load("atcg.json")
cm

Category Manager with 5