In [None]:
#| default_exp generation.trie

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from tqdm.auto import tqdm
from fastcore.dispatch import *
from xcai.data import XCDataBlock
from dataclasses import dataclass
from typing import Optional, List, Any, Union

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

## Setup

In [None]:
from xcai.test_utils import *

In [None]:
# Build the test fixture from the 'train' config. `Test` is presumably provided by the
# star import of `xcai.test_utils` above — TODO confirm. Later cells read `block.lbl_info`.
block = Test.from_cfg('train')

## Trie

In [None]:
#| export
class TrieNode:
    "A single trie node: the token it stands for, its children and per-node bookkeeping."
    def __init__(self, tok:int):
        self.tok = tok          # token id represented by this node
        self.nxt_toks = {}      # child nodes, keyed by their token id
        self.is_end = False     # True once some inserted sequence ends on this node
        self.cnt = 0            # number of insertions that reached this node
        self.info = None        # payload attached to sequences ending here (None until set)

    @property
    def data(self):
        "All node fields as one tuple: `(tok, nxt_toks, is_end, cnt, info)`."
        return (self.tok, self.nxt_toks, self.is_end, self.cnt, self.info)

    @data.setter
    def data(self, x):
        "Overwrite every field from a 5-item iterable ordered like the `data` getter."
        # Unpack completely before assigning so a wrong-sized `x` leaves the node untouched.
        tok, nxt_toks, is_end, cnt, info = x
        self.tok = tok
        self.nxt_toks = nxt_toks
        self.is_end = is_end
        self.cnt = cnt
        self.info = info


In [None]:
#| export
@dataclass
class TrieOutput:
    "One sequence found by a trie lookup, with the bookkeeping stored on its end node."
    # Full token sequence, from the first token through the end node.
    s: Optional[List] = None
    # `cnt` of the end node.
    cnt: Optional[int] = None
    # `info` payload of the end node.
    info: Optional[Any] = None
    

In [None]:
#| export
class Trie(object):
    "Prefix tree over token sequences that all start with the same first (`bos`) token."
    def __init__(self):
        # `root` is created by the first `insert`; `depth` is the length of the longest inserted sequence.
        self.root, self.depth = None, 0

    @staticmethod
    def _add_info(node:TrieNode, info:Any):
        "Append `info` (one item, or a list of items) to `node.info`, creating the list if needed."
        # Always work on a fresh list: storing the caller's list directly would let a
        # later `extend` on the node silently mutate the caller's argument.
        items = list(info) if isinstance(info, list) else [info]
        if node.info is None: node.info = items
        else: node.info.extend(items)

    def insert(self, toks:Optional[List], info:Optional[Any]=None):
        """Insert the sequence `toks`, optionally attaching `info` to its end node.

        Every node on the path has its `cnt` incremented. Raises `ValueError` when
        `toks` is empty/None or when its first token differs from the root token.
        """
        if toks is None or len(toks) == 0: raise ValueError('`toks` should contain at least the `bos_tok`.')
        if self.root is None: self.root=TrieNode(toks[0])
        if self.root.tok != toks[0]: raise ValueError(f'Expected `bos_tok` to be `{self.root.tok}` but got `{toks[0]}`.')
        # Only record the length once the sequence is known to be accepted.
        self.depth = max(self.depth, len(toks))
        node = self.root
        for tok in toks[1:]:
            node.cnt += 1
            child = node.nxt_toks.get(tok)
            if child is None: child = node.nxt_toks[tok] = TrieNode(tok)
            node = child
        node.is_end = True
        if info is not None: Trie._add_info(node, info)
        node.cnt += 1

    @staticmethod
    def _search(node:TrieNode, p:List, o:List):
        "Append to `o` a `TrieOutput` for every stored sequence ending at or below `node`; `p` is the path to `node`."
        if node.is_end: o.append(TrieOutput(p, node.cnt, node.info))
        # Keep descending past end nodes: a stored sequence can be a prefix of a longer
        # one, and `__contains__` reports both as present.
        for tok, n in node.nxt_toks.items(): Trie._search(n, p+[tok], o)

    def suffixes(self, x:Union[int,List]):
        """Return a `TrieOutput` for every stored sequence that starts with `x`, highest `cnt` first.

        `x` is a single token or a list of tokens. An empty list is returned when the trie
        is empty or when `x` does not match any stored prefix.
        """
        x = [x] if isinstance(x, int) else x
        node, o = self.root, []
        if node is None or len(x) == 0 or node.tok != x[0]: return []
        for tok in x[1:]:
            node = node.nxt_toks.get(tok)
            # Unknown prefix: return an empty list, same as for a `bos` mismatch.
            if node is None: return []
        Trie._search(node, x, o)
        return sorted(o, key=lambda r: r.cnt, reverse=True)

    @staticmethod
    def _prune(node):
        "Recursively shortcut `node -> child -> end` chains so that `node` points straight at the end node."
        # Reassigning `node.nxt_toks` below rebinds the attribute; it does not mutate the
        # dict being iterated here, and that dict has a single entry whenever the rebind happens.
        for n in node.nxt_toks.values():
            Trie._prune(n)
            only_child = len(node.nxt_toks) == 1
            if only_child and len(n.nxt_toks) == 1 and next(iter(n.nxt_toks.values())).is_end:
                node.nxt_toks = n.nxt_toks

    def prune(self):
        "Prune the whole trie in place (no-op on an empty trie); `depth` and `cnt` values are left unchanged."
        if self.root is not None: self._prune(self.root)

    def prefix(self, x:List):
        """Map `x` onto the path it follows in the trie, skipping inner tokens that have no matching child.

        Returns the followed tokens plus the last token of `x` when that last token is an
        end node below the reached position; otherwise returns `None`. Raises `ValueError`
        when the first token of `x` differs from the root token.
        """
        if self.root is None: return None
        node, o = self.root, [x[0]]
        if node.tok != x[0]: raise ValueError(f'`bos_tok`({x[0]}) cannot be "{node.tok}".')
        for tok in x[1:-1]:
            # Tokens without a matching child are skipped rather than treated as a miss.
            if tok in node.nxt_toks: node=node.nxt_toks[tok]; o.append(tok)
        if x[-1] in node.nxt_toks and node.nxt_toks[x[-1]].is_end: return o+x[-1:]
        return None

    def __contains__(self, x:List):
        "True when `x` was inserted as a complete sequence; raises `ValueError` on a `bos` mismatch."
        node = self.root
        if node is None or len(x) == 0: return False
        if node.tok != x[0]: raise ValueError(f'`bos_tok`({x[0]}) cannot be "{node.tok}".')
        for tok in x[1:]: 
            if tok in node.nxt_toks: node = node.nxt_toks[tok]
            else: return False
        return node.is_end

    @property
    def bos_tok(self):
        "Token stored at the root, or `None` while the trie is empty."
        return None if self.root is None else self.root.tok

    @typedispatch
    def update(self, x:List):
        "Insert every sequence in `x` (no info attached)."
        for o in tqdm(x): self.insert(o)

    @typedispatch
    def update(self, x:List, y:List):
        "Insert every sequence in `x`, attaching the matching entry of `y` as its info."
        for p,q in tqdm(zip(x,y), total=len(x)): self.insert(p,q)

    @classmethod
    @typedispatch
    def from_list(cls, x:List):
        "Build a new trie from the sequences in `x` (no info attached)."
        self = cls()
        for o in tqdm(x): self.insert(o)
        return self

    @classmethod
    @typedispatch
    def from_list(cls, x:List, y:List):
        "Build a new trie from the sequences in `x`, attaching the matching entry of `y` as info."
        self = cls()
        for p,q in tqdm(zip(x,y), total=len(x)): self.insert(p,q)
        return self


### Example 1

In [None]:
# Six token sequences that all start with 101 and end with 102; the 4th and 5th are identical.
arr = [[101, 100, 200, 300, 102], 
       [101, 200, 100, 100, 109, 102],
       [101, 200, 100, 100, 301, 102],
       [101, 300, 301, 102],
       [101, 300, 301, 102],
       [101, 200, 100, 222, 301, 401, 501, 444, 102]]

# One info string per sequence in `arr`.
info = ['aa', 'bb', 'dd', 'ee', 'hh', 'ii']

# Build the trie, attaching info[i] to the end node of arr[i].
t = Trie.from_list(arr, info)

  0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# Insert the same sequences again, this time with a list of infos each: list entries are
# appended to what each end node already holds, and the node counts are incremented again.
info = [['a'], ['b','c'], ['d'], ['e','f','g'], ['h'], ['i','j']]
t.update(arr, info)

  0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
# Every stored sequence that starts with token 101, sorted by count (highest first).
t.suffixes(101)

[TrieOutput(s=[101, 300, 301, 102], cnt=4, info=['ee', 'hh', 'e', 'f', 'g', 'h']),
 TrieOutput(s=[101, 100, 200, 300, 102], cnt=2, info=['aa', 'a']),
 TrieOutput(s=[101, 200, 100, 100, 109, 102], cnt=2, info=['bb', 'b', 'c']),
 TrieOutput(s=[101, 200, 100, 100, 301, 102], cnt=2, info=['dd', 'd']),
 TrieOutput(s=[101, 200, 100, 222, 301, 401, 501, 444, 102], cnt=2, info=['ii', 'i', 'j'])]

In [None]:
# Shortcut single-child chains that lead to an end node; this modifies `t` in place.
t.prune()

In [None]:
# Same query after pruning: the sequences are shorter, but pruning itself does not touch cnt/info.
# NOTE(review): the saved output for this cell shows different cnt/info values than the
# pre-prune listing — it looks like it came from a different run; re-run the notebook top to bottom.
t.suffixes(101)

[TrieOutput(s=[101, 300, 102], cnt=2, info=['h']),
 TrieOutput(s=[101, 100, 102], cnt=1, info=['a']),
 TrieOutput(s=[101, 200, 100, 100, 109, 102], cnt=1, info=['b', 'c']),
 TrieOutput(s=[101, 200, 100, 100, 301, 102], cnt=1, info=['d']),
 TrieOutput(s=[101, 200, 100, 222, 102], cnt=1, info=['i', 'j'])]

In [None]:
# Map an original (unpruned) sequence onto the pruned trie: inner tokens without a matching child are skipped.
t.prefix([101, 200, 100, 222, 301, 401, 501, 444, 102])

[101, 200, 100, 222, 102]

### Example 2

In [None]:
# Inspect which fields `block.lbl_info` provides; 'input_ids' is used to build the trie below.
block.lbl_info.keys()

dict_keys(['identifier', 'input_text', 'input_ids', 'token_type_ids', 'attention_mask'])

In [None]:
# Build a trie over the label token sequences; each sequence's info is its position in `toks`.
toks = block.lbl_info['input_ids']
info = [[i] for i in range(len(toks))]
t = Trie.from_list(toks, info)

  0%|          | 0/312330 [00:00<?, ?it/s]

## XCTrie

In [None]:
#| export
class XCTrie:
    "Factory for a `Trie` built from the label (and optionally metadata) token ids of a data block."

    @classmethod
    def from_block(cls, block:XCDataBlock, meta:Optional[List]=None):
        """Build a `Trie` from `block.lbl_info['input_ids']`, optionally extended with metadata sequences.

        Each label sequence is inserted with its own position as info. For every name in
        `meta`, the sequences in `block.train.dset.meta['<name>_meta'].meta_info['input_ids']`
        are inserted too, each with the `indices` of its row in the transposed `lbl_meta`
        matrix as info (presumably the labels linked to that metadata item — TODO confirm).
        Raises `ValueError` for an unknown metadata name or mismatched lengths.
        Note: the returned object is a plain `Trie`; `cls` is not used.
        """
        label_ids = block.lbl_info['input_ids']
        label_pos = [[idx] for idx in range(len(label_ids))]

        trie = Trie.from_list(label_ids, label_pos)
        if meta is None: return trie

        meta_dset = block.train.dset.meta
        for name in meta:
            key = f'{name}_meta'
            if key not in meta_dset: raise ValueError(f'`{name}_meta` does not exist.')
            meta_toks = meta_dset[key].meta_info['input_ids']
            # Transpose so that iterating rows walks over metadata items instead of labels.
            by_meta = meta_dset[key].lbl_meta.T.tocsr()
            meta_info = [row.indices.tolist() for row in by_meta]
            if len(meta_toks) != len(meta_info): raise ValueError(f'`meta_toks` and `meta_info` should have equal length.')
            trie.update(meta_toks, meta_info)

        return trie
        