In [1]:
import torch
from transformers import BertForMaskedLM, AutoTokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
unique_labels = list(range(10))

In [3]:
class CustomBERTClass(torch.nn.Module):
    def __init__(self):
        super(CustomBERTClass, self).__init__()
        self.l1 = BertForMaskedLM.from_pretrained('distilbert-base-uncased', return_dict=True, output_hidden_states=True)
        self.l1.resize_token_embeddings(len(tokenizer))
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, 1024)
        self.l4 = torch.nn.Dropout(0.3)
        self.l5 = torch.nn.Linear(1024, len(unique_labels))

        # explainerに読み込ませるためにアトリビュートを生やす
        self.config = self.l1.config
        self.base_model_prefix = self.l1.base_model_prefix
        self.device = self.l1.device
        self.get_input_embeddings = self.l1.get_input_embeddings
        self.bert = self.l1.bert


    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        output = self.l1(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        cls_embeddings = output.hidden_states[-1][:,0,:]
        output = self.l2(cls_embeddings)
        output = self.l3(output)
        output = self.l4(output)
        output = self.l5(output)
        return output

In [4]:
model = CustomBERTClass()

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing BertForMaskedLM: ['distilbert.transformer.layer.3.ffn.lin2.bias', 'vocab_layer_norm.weight', 'distilbert.transformer.layer.5.attention.k_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.3.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.1.ffn.lin1.bias', 'distilbert.transformer.layer.2.attention.q_lin.bias', 'distilbert.transformer.layer.4.ffn.lin2.bias', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.1.attention.v_lin.bias', 'distilbert.transformer.layer.4.ffn.lin1.weight', 'distilbert.transfor

In [5]:
from transformers_interpret import SequenceClassificationExplainer

In [6]:
cls_explainer = SequenceClassificationExplainer(
    model=model,
    tokenizer=tokenizer
)

In [7]:
# 出力の次元があっていないのでエラー
cls_explainer("I like you.")

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [8]:
class BERTExplainer(SequenceClassificationExplainer):
    @property
    def predicted_class_index(self) -> int:
        "Returns predicted class index (int) for model with last calculated `input_ids`"
        if len(self.input_ids) > 0:
            # we call this before _forward() so it has to be calculated twice
            preds = self.model(self.input_ids)
            self.pred_class = torch.argmax(torch.softmax(preds, dim=0)[0])
            return torch.argmax(torch.softmax(preds, dim=1)[0]).cpu().detach().numpy()

        else:
            raise InputIdsNotCalculatedError("input_ids have not been created yet.`")

    def _forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):

        if self.accepts_position_ids:
            preds = self.model(
                input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
            )
            #preds = preds[0]

        else:
            preds = self.model(input_ids, attention_mask)

        # if it is a single output node
        if len(preds[0]) == 1:
            self._single_node_output = True
            self.pred_probs = torch.sigmoid(preds)[0][0]
            return torch.sigmoid(preds)[:, :]

        self.pred_probs = torch.softmax(preds, dim=1)[0][self.selected_index]
        return torch.softmax(preds, dim=1)[:, self.selected_index]

In [9]:
cls_explainer = BERTExplainer(
    model=model,
    tokenizer=tokenizer
)

In [10]:
cls_explainer("I like you.")

[('[CLS]', 0.0),
 ('i', -0.020300999706232583),
 ('like', -0.3233927557783712),
 ('you', -0.3678551104818028),
 ('.', -0.8716006038395218),
 ('[SEP]', 0.0)]