In [None]:
# default_exp utils

In [None]:
#export
import os
import random
import torch
import torch.nn as nn
import transformers
import numpy as np

from fastcore.test import *
from fastcore.transform import Transform
from fastcore.foundation import L
from fastai.text.data import TensorText
from fastai.text.core import Tokenizer

## Some Utility functions

In [None]:
#export
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators with `seed`.

    Also records `seed` in the `PYTHONHASHSEED` environment variable and puts
    cuDNN into its deterministic configuration.
    """
    random.seed(seed)
    # The running interpreter has already fixed its hash seed at startup; this
    # value is only picked up by Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (silently ignored
    # when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different convolution algorithms from run to run,
    # so it has to be off for results to be repeatable.
    torch.backends.cudnn.benchmark = False

In [None]:
#hide
# Checkpoint whose tokenizer is used by the example/test cells in this notebook
MODEL_NAME = 'bert-base-uncased'
# Loads the tokenizer for `MODEL_NAME` with lower-casing enabled
bert_tokenizer = transformers.BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=True)

In [None]:
#export
def bert_cls_splitter(m):
    """Return the backbone parameters of `m` as a single parameter group.

    `m.model` must expose `embeddings`, `encoder.layer` (an iterable of
    encoder blocks) and `pooler`. The result is an `L` holding one list with
    the parameters of all of those modules, in that order.
    """
    # Unpack every encoder block instead of hard-coding indices 0-11, so a
    # backbone with a different depth is covered completely.
    backbone = nn.Sequential(m.model.embeddings,
                             *m.model.encoder.layer,
                             m.model.pooler)
    # No separate classifier group: the wrapped model ends at the pooler.
    groups = L([backbone])
    # Collect the parameters directly; this matches fastai's `params` helper,
    # which the import cell of this notebook does not bring into scope.
    return groups.map(lambda group: list(group.parameters()))

In [None]:
#export
class HFTokenizer():
    "Adapter that applies `tokenizer.tokenize` to every text of a batch (e.g. a `transformers` tokenizer)"
    def __init__(self, tokenizer):
        self.tok = tokenizer

    def tokenize(self, text):
        "Return whatever `self.tok.tokenize` produces for a single `text`"
        return self.tok.tokenize(text)

    def __call__(self, items):
        "Lazily yield the tokenized form of each element of `items`, in order"
        # Always yield the tokenized text before passing it to the Tokenizer Transform
        yield from (self.tokenize(text) for text in items)

In [None]:
#hide
# `HFTokenizer.__call__` iterates over its argument, so pass a list of texts;
# a bare string would be tokenized one character at a time.
list(HFTokenizer(bert_tokenizer)(['I am a fish']))

[['i', 'am', 'a', 'fish']]

In [None]:
#export
class Add_Special_Cls(Transform):
    "Wrap a sequence of token ids with the special tokens of `tokenizer` and return it as `TensorText`"
    # Position of this transform when transforms are sorted by `order`
    order = 7
    def __init__(self, tokenizer):
        self.tok = tokenizer

    def encodes(self, o):
        "Materialize `o` as a list, add the special tokens via `self.tok`, and wrap the result in `TensorText`"
        ids = list(o)
        with_special = self.tok.build_inputs_with_special_tokens(ids)
        return TensorText(with_special)

In [None]:
#hide
# The ids must come back wrapped by the special-token ids `bert_tokenizer` adds (101 ... 102)
test_eq(
    Add_Special_Cls(bert_tokenizer)([0, 1, 2]),
    TensorText([101, 0, 1, 2, 102]),
)