In [None]:
from transformers import BertModel, BertTokenizer, BertConfig
import torch
from typing import Tuple, Optional, List
from torch import nn
from transformers import BertPreTrainedModel
import os

# Output directory for the exported TorchScript file (written in step 3, read in step 4).
os.makedirs("nlp_model", exist_ok=True)

# NOTE(review): presumably kept as a last-resort workaround for Intel OpenMP's
# "OMP: Error #15 ... already initialized" crash (duplicate OpenMP runtimes).
# It is deliberately left disabled; uncomment only if that crash occurs.
# os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

### Step 1. Load the BERT tokenizer, tokenize a sample text, and build `tokens_tensor` and `segments_tensors`


In [None]:
enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize a sentence pair; the [CLS]/[SEP] markers are written into the text by hand.
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Segment ids: 0 for the first sentence (up to and including the first [SEP]),
# 1 for everything after it. They are derived from the tokens instead of being
# hard-coded, so the list always has one entry per token even if `text` is edited.
# They are computed before masking so that masking a [SEP] cannot shift the boundary.
first_sep = tokenized_text.index(enc.sep_token)
segments_ids = [0] * (first_sep + 1) + [1] * (len(tokenized_text) - first_sep - 1)

# Mask one of the input tokens (for this text, index 8 should be the second "henson").
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
assert len(segments_ids) == len(indexed_tokens), "expected one segment id per token"

# Creating a dummy input (batch of one sequence)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]


In [None]:
# Display the token-id tensor built in step 1 (a batch containing one sequence).
tokens_tensor

In [None]:
# Display the segment ids: 0 for the first sentence, 1 for the second.
segments_ids

### Step 2. Define a custom BERT model that returns the pooled sentence embedding

In [None]:


class Mybert4Sentence(BertPreTrainedModel):
    """BERT encoder that returns the pooled output of ``BertModel`` as a sentence vector.

    Adapted from ``transformers.BertForSequenceClassification``, without the
    dropout + linear classification head. ``forward`` never used that head, and
    its randomly initialized weights only added unused parameters (and a
    "newly initialized weights" warning) when loading from a pretrained
    checkpoint.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Run the BERT encoder and return its pooled output.

        Beware of the positional order: the SECOND positional argument is
        ``attention_mask``, not ``token_type_ids``. Pass segment ids by keyword
        (``token_type_ids=...``). For example, ``torch.jit.trace(model, [ids, segs])``
        would feed the segment ids in as the attention mask.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to ``BertModel``.
            labels: accepted only for signature compatibility with
                ``BertForSequenceClassification``; it is ignored and no loss is computed.
            return_dict: defaults to ``config.use_return_dict`` when ``None``.

        Returns:
            Element 1 of the ``BertModel`` output (the pooler output), presumably of
            shape ``(batch_size, hidden_size)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Positional indexing works both for tuple outputs (e.g. torchscript=True
        # configs) and for ModelOutput objects; index 1 is the pooler output.
        pooled_output = outputs[1]

        return pooled_output


### Step 3. Load the pretrained weights, trace the model with TorchScript, and save it to a `.pt` file

In [None]:
# torchscript=True prepares the model for tracing (tuple outputs instead of ModelOutput).
model2 = Mybert4Sentence.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
    torchscript=True,
)

In [None]:
class _BertExportWrapper(nn.Module):
    """Gives the traced module the positional signature ``(input_ids, token_type_ids)``.

    ``Mybert4Sentence.forward`` takes ``attention_mask`` as its second positional
    argument. Tracing it directly with ``[tokens_tensor, segments_tensors]`` would
    therefore bake the segment ids in as the attention mask, masking out the whole
    first sentence. This wrapper passes both tensors by keyword and derives the
    attention mask from the padding token id. When no pad id is known, it uses an
    all-ones mask.
    """

    def __init__(self, encoder: nn.Module, pad_token_id: Optional[int] = None):
        super().__init__()
        self.encoder = encoder
        self.pad_token_id = pad_token_id

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        if self.pad_token_id is None:
            attention_mask = torch.ones_like(input_ids)
        else:
            # 1 for real tokens, 0 for padding positions.
            attention_mask = (input_ids != self.pad_token_id).long()
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )


export_model = _BertExportWrapper(model2, pad_token_id=model2.config.pad_token_id).eval()
# The traced module keeps the (input_ids, token_type_ids) order, so
# loaded_model(*dummy_input) in step 4 maps each tensor to the right argument.
traced_model = torch.jit.trace(export_model, (tokens_tensor, segments_tensors))
torch.jit.save(traced_model, "nlp_model/traced_bert.pt")

### Step 4. Load the `.pt` file back and check its output

In [None]:
# Reload the serialized TorchScript module and switch it to inference mode;
# eval() returns the module itself, which is displayed as the cell output.
loaded_model = torch.jit.load("nlp_model/traced_bert.pt").eval()
loaded_model

In [None]:
# Shape of the exported model's output (presumably (1, hidden_size) for one sequence).
loaded_model(*dummy_input).size()

In [None]:
# Peek at the first 10 values along the last dimension of the exported model's output.
loaded_model(*dummy_input)[..., :10]