In [16]:
import tensorflow as tf
from tensorflow import keras
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering

In [17]:
%%time
model_dir = "pretrained/google/electra-small-discriminator"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_dir)
print(repr(model.config))

Some layers from the model checkpoint at pretrained/google/electra-small-discriminator were not used when initializing TFElectraForQuestionAnswering: ['discriminator_predictions']
- This IS expected if you are initializing TFElectraForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFElectraForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFElectraForQuestionAnswering were not initialized from the model checkpoint at pretrained/google/electra-small-discriminator and are newly initialized: ['qa_outputs']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ElectraConfig {
  "_name_or_path": "pretrained/google/electra-small-discriminator",
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "embedding_size": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "transformers_version": "4.5.1",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

CPU times: user 750 ms, sys: 93.8 ms, total: 844 ms
Wall time: 658 ms


In [18]:
# Configure the model for training: Adam with learning rate 5e-5, sparse
# categorical cross-entropy computed on raw logits (from_logits=True), and
# sparse categorical accuracy as the tracked metric. Then print the layer /
# parameter summary.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Same `keras.*` namespace as the optimizer and loss above, passed in the
    # list form that `Model.compile` documents for `metrics`.
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
model.summary()

Model: "tf_electra_for_question_answering_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
electra (TFElectraMainLayer) multiple                  13483008  
_________________________________________________________________
qa_outputs (Dense)           multiple                  514       
Total params: 13,483,522
Trainable params: 13,483,522
Non-trainable params: 0
_________________________________________________________________


In [19]:
# Write the model (config + weights) and the tokenizer files into the same
# directory so that "tmp" can be reloaded on its own with `from_pretrained`.
# NOTE(review): no fine-tuning step is visible before this cell, so the saved
# QA head is presumably still the freshly initialised one — confirm intent.
model.save_pretrained("tmp")
# Trailing `;` keeps the tuple of written file paths out of the cell output.
tokenizer.save_pretrained("tmp");

In [20]:
# Encode a single (question, context) pair as TensorFlow tensors and print
# the resulting encoding.
question = "Who was Jim Henson?"
text = "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="tf")
print(f"{inputs!r}")

{'input_ids': <tf.Tensor: shape=(1, 14), dtype=int32, numpy=
array([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,
         2001,  1037,  3835, 13997,   102]], dtype=int32)>, 'token_type_ids': <tf.Tensor: shape=(1, 14), dtype=int32, numpy=array([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(1, 14), dtype=int32, numpy=array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>}


In [21]:
# Run the encoded example through the model and print the returned output
# object (start/end logits for the question-answering head).
outputs = model(inputs)
print(f"{outputs!r}")

TFQuestionAnsweringModelOutput(loss=None, start_logits=<tf.Tensor: shape=(1, 14), dtype=float32, numpy=
array([[-0.03157917, -0.03710311,  0.16405778, -0.01507592, -0.01097412,
        -0.18119028, -0.03076549,  0.04689135,  0.12549038,  0.17952093,
         0.06050755,  0.1644939 ,  0.14850104, -0.03068607]],
      dtype=float32)>, end_logits=<tf.Tensor: shape=(1, 14), dtype=float32, numpy=
array([[ 0.1771944 ,  0.2459357 ,  0.1530178 ,  0.20741835,  0.06197455,
         0.2094365 ,  0.17667554,  0.00232536, -0.09494966,  0.12351818,
        -0.22235042,  0.11328951,  0.04692179,  0.176589  ]],
      dtype=float32)>, hidden_states=None, attentions=None)
