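# Changes in this version:
# - removed TorchScript tracing
# - switched the model to DistilBERT
# - output is exported to ONNX
# - GPU version (tensors are moved to CUDA when it is available)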
# https://github.com/huggingface/transformers/issues/227
#https://huggingface.co/transformers/serialization.html

In [8]:
from transformers import DistilBertTokenizer,DistilBertTokenizerFast
from transformers import DistilBertForTokenClassification, AdamW, DistilBertConfig
from transformers import BatchEncoding
from tokenizers import Encoding
import torch


# read label
from typing_extensions import TypedDict
from typing import List,Any
# Type aliases for containers of token ids.
IntList = List[int]  # A list of token ids
IntListList = List[IntList]  # A list of token-id lists, e.g. a batch


import itertools
class LabelSet:
    """Two-way mapping between label strings and contiguous integer ids.

    The untagged label ``"o"`` always has id 0. Every other distinct label
    receives the next free id (1, 2, ...) in first-seen order.
    """

    def __init__(self, labels: List[str]):
        """Build the label -> id and id -> label tables.

        Args:
            labels: Label names. Any iterable of strings works (it is only
                iterated once). ``"o"`` and labels that were already seen
                are skipped, so ids stay contiguous and both tables stay
                consistent with each other.
        """
        self.labels_to_id = {"o": 0}
        self.ids_to_label = {0: "o"}

        for label in labels:
            # Skipping already-registered labels (including "o") prevents a
            # repeated label from being re-numbered and leaving a stale entry
            # behind in ids_to_label.
            if label in self.labels_to_id:
                print("skip:{}".format(label))
                continue
            num = len(self.labels_to_id)
            self.labels_to_id[label] = num
            self.ids_to_label[num] = label

    def get_aligned_label_ids_from_aligned_label(self, aligned_labels):
        """Return the id of each label in ``aligned_labels``, in order.

        Labels that are not registered map to ``None``.
        """
        return [self.labels_to_id.get(label) for label in aligned_labels]

    def get_untagged_id(self):
        """Return the id of the untagged label ``"o"``."""
        return self.labels_to_id["o"]

    def get_labels(self):
        """Return the label -> id dictionary (the live object, not a copy)."""
        return self.labels_to_id

# Slot names to tag; "O" is the untagged (outside) label.
slots = [
    "O",
    "file_name",
    "file_type",
    "data_source",
    "contact_name",
    "to_contact_name",
    "file_keyword",
    "date",
    "time",
    "meeting_starttime",
    "file_action",
    "file_action_context",
    "position_ref",
    "order_ref",
    "file_recency",
    "sharetarget_type",
    "sharetarget_name",
    "file_folder",
    "data_source_name",
    "data_source_type",
    "attachment",
]

# Register the slot names in lower case, so "O" becomes the untagged label "o".
slots_label_set = LabelSet(labels=[slot.lower() for slot in slots])



# Previously: enc = BertTokenizer.from_pretrained("bert-base-uncased")
fast_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased') # Load a pre-trained tokenizer

# Tokenize the input text.
# The special tokens are not written into the text by hand (see the two
# earlier variants kept below); calling the tokenizer adds them.
#text = "[CLS] a visually stunning rumination on love"
#text = "[CLS] a visually stunning rumination on love [SEP]"
text = "a visually stunning rumination on love"
fast_tokenized_batch : BatchEncoding = fast_tokenizer(text)
# Index 0 selects the Encoding of the first (and only) input in the batch.
fast_tokenized_text :Encoding  =fast_tokenized_batch[0]
fast_tokenized_text_tokens_copy = fast_tokenized_text.tokens


# For debugging: compare the two token views. In the recorded run,
# tokenize() returns only the word pieces, while the Encoding's tokens also
# include the surrounding '[CLS]' and '[SEP]'.
fast_tokenized_text_old = fast_tokenizer.tokenize(text)
print("fast token ouput version 1: {}".format(fast_tokenized_text_old))
print("fast token ouput version 2 being used here: {}".format(fast_tokenized_text_tokens_copy))

# Mask one of the input tokens.
# NOTE(review): this step was carried over from the tutorial this script is
# based on; it may be unnecessary here — TODO check.
# Index 3 counts '[CLS]' as position 0, so in the recorded run it replaces 'stunning'.
masked_index = 3
fast_tokenized_text_tokens_copy[masked_index] = '[MASK]'
indexed_tokens = fast_tokenizer.convert_tokens_to_ids(fast_tokenized_text_tokens_copy)

# For debugging: show the tokens after the replacement and their vocabulary ids.
print("fast token ouput version 2 being used here after replace: {}".format(fast_tokenized_text_tokens_copy))

print("indexed_tokens: {}".format(indexed_tokens))
# NOTE(review): segments_ids holds a single 0 rather than one entry per token,
# so segments_tensors has shape [1, 1]. It is only printed below; dummy_input
# is built from tokens_tensor alone.
segments_ids = [0]

# Create a dummy input for the export.
# The tensors must be moved to the GPU when one is used, see
# https://github.com/huggingface/transformers/issues/227
# Discussion on converting a list to a tensor with torch.tensor:
# https://discuss.pytorch.org/t/best-way-to-convert-a-list-to-a-tensor/59949/2
# CPU-only variant, kept for reference:
#tokens_tensor = torch.tensor([indexed_tokens])
#segments_tensors = torch.tensor([segments_ids])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("create input for device: {}".format(device))
# Wrapping the id list in another list adds a leading batch dimension of size 1.
tokens_tensor = torch.tensor([indexed_tokens]).to(device)
segments_tensors = torch.tensor([segments_ids]).to(device)
dummy_input = tokens_tensor


# For debugging: number of tokens in the first sequence, then in every
# sequence of the batch.
print("tokens_tensor shape for chunk: {}".format(tokens_tensor[0].shape[0])) 
for token_tensor in tokens_tensor:
    print("token_tensor shape for chunk: {}".format(token_tensor.shape[0]))


# For debugging: full shapes and contents of both tensors.
# (In the recorded run tokens_tensor is [1, 9] and segments_tensors is [1, 1].)
print("tokens_tensor shape: {}".format(tokens_tensor.shape))
print("segments_tensor shape: {}".format(segments_tensors.shape))

print("tokens_tensor: {}".format(tokens_tensor))
print("segments_tensor: {}".format(segments_tensors))


# Build the token-classification model with one output class per registered
# slot label. The earlier BertConfig/BertModel setup with `torchscript=True`
# was dropped: the model is exported to ONNX instead of being traced to
# TorchScript.
num_labels = len(slots_label_set.get_labels())
model = DistilBertForTokenClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False,
)

# dummy_input was moved to `device`; the model has to live on the same device,
# otherwise the export fails with a device mismatch when CUDA is available.
model.to(device)
# Make evaluation mode explicit so dropout is inactive while the graph is traced.
model.eval()

torch.onnx.export(
    model=model,
    args=(dummy_input,),  # one-element tuple: input_ids is the only model input
    f='traced_distill_bert.onnx.bin',
    input_names=["input_ids"],
    verbose=True,
    output_names=["logits"],
    do_constant_folding=True,
    opset_version=11,
    # Only axis 1 (the sequence length) is dynamic; axis 0 keeps the batch
    # size of the dummy input.
    dynamic_axes={'input_ids': {1: '?'}, 'logits': {1: '?'}},
)


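# To also save a TorchScript copy for download, uncomment the line below.
# NOTE: it needs a `traced_model`, which is not defined in the code above.
#torch.jit.save(traced_model, "traced_distill_bert.pt")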

skip:o
fast token ouput version 1: ['a', 'visually', 'stunning', 'rum', '##ination', 'on', 'love']
fast token ouput version 2 being used here: ['[CLS]', 'a', 'visually', 'stunning', 'rum', '##ination', 'on', 'love', '[SEP]']
fast token ouput version 2 being used here after replace: ['[CLS]', 'a', 'visually', '[MASK]', 'rum', '##ination', 'on', 'love', '[SEP]']
indexed_tokens: [101, 1037, 17453, 103, 19379, 12758, 2006, 2293, 102]
create input for device: cpu
tokens_tensor shape for chunk: 9
token_tensor shape for chunk: 9
tokens_tensor shape: torch.Size([1, 9])
segments_tensor shape: torch.Size([1, 1])
tokens_tensor: tensor([[  101,  1037, 17453,   103, 19379, 12758,  2006,  2293,   102]])
segments_tensor: tensor([[0]])
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_proj