In [1]:
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
from tqdm import tqdm
# NOTE(review): this is assigned after `tensorflow` and `transformers` were already
# imported above; if the backend choice is read at import time it has no effect
# here -- confirm whether the line is still needed.
os.environ["KERAS_BACKEND"] = "tensorflow"

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be changed before the GPU is initialized, so
        # re-running this cell in a live kernel ends up here; report and carry on.
        print(f"Could not change memory growth for {gpu.name}: {err}")


D:\anaconda\envs\tf-gpu-2.10.0-py-3.10\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
D:\anaconda\envs\tf-gpu-2.10.0-py-3.10\lib\site-packages\numpy\.libs\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fixed length of every model input: shorter sequences are padded up to it and
# longer ones are skipped during preprocessing.
max_len = 384
# Default BERT configuration object.
configuration = BertConfig()
# WordPiece tokenizer built from the local vocabulary file, lower-casing its input.
tokenizer = BertWordPieceTokenizer("D:/bert-base-uncased/vocab.txt", lowercase=True)

# Locations of the SQuAD v1.1 train / dev JSON files.
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"

# Fetch both files through `tf.keras.utils.get_file` and keep the local paths.
train_path = tf.keras.utils.get_file(fname="train.json", origin=train_data_url)
eval_path = tf.keras.utils.get_file(fname="eval.json", origin=eval_data_url)

In [3]:
class SquadExample:
    """One SQuAD question/answer pair and the model features derived from it.

    Attributes set by ``__init__``:
        question, context: the question and passage text.
        start_char_idx: character offset at which the answer starts.
        answer_text: text of the answer used as the training target.
        all_answers: every reference answer text for the question.
        skip: ``True`` once ``preprocess`` decides the example is unusable.

    Attributes set by a successful ``preprocess`` (absent when ``skip`` is True):
        input_ids, token_type_ids, attention_mask: lists of length ``max_len``.
        start_token_idx, end_token_idx: first / last context token of the answer.
        context_token_to_char: ``(start, end)`` character offsets per context token.
    """

    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        """Tokenize the example and locate the answer span in token space.

        Uses the module-level ``tokenizer`` and ``max_len``. Sets ``self.skip`` and
        returns early when the answer runs past the end of the context, when no
        context token overlaps the answer, or when context + question is longer
        than ``max_len`` tokens.
        """
        # Collapse every run of whitespace to a single space.
        context = " ".join(str(self.context).split())
        question = " ".join(str(self.question).split())
        answer = " ".join(str(self.answer_text).split())

        # NOTE(review): start_char_idx is used as-is against the normalized context;
        # if the caller measured it on the raw text, the two can drift apart when the
        # raw context contains repeated whitespace -- confirm.
        start_char_idx = self.start_char_idx
        # Exclusive end offset of the answer inside the context.
        end_char_idx = start_char_idx + len(answer)
        # Skip only when the answer runs past the context; an answer that ends
        # exactly at the last character (end == len) is valid and is kept.
        if end_char_idx > len(context):
            self.skip = True
            return

        # Collect every context token whose character span overlaps the answer span.
        # Tokens with an empty span such as (0, 0) never overlap.
        tokenized_context = tokenizer.encode(context)
        ans_token_idx = [
            idx
            for idx, (start, end) in enumerate(tokenized_context.offsets)
            if max(start, start_char_idx) < min(end, end_char_idx)
        ]
        if not ans_token_idx:
            self.skip = True
            return

        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Append the question to the context, dropping the first id of the question
        # encoding so the combined sequence keeps a single leading token.
        tokenized_question = tokenizer.encode(question)
        question_ids = tokenized_question.ids[1:]
        input_ids = tokenized_context.ids + question_ids
        # Segment 0 marks context tokens, segment 1 marks question tokens.
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(question_ids)
        attention_mask = [1] * len(input_ids)

        # Sequences longer than max_len are dropped rather than truncated.
        padding_length = max_len - len(input_ids)
        if padding_length < 0:
            self.skip = True
            return

        # Right-pad all three lists with zeros up to max_len.
        padding = [0] * padding_length
        self.input_ids = input_ids + padding
        self.token_type_ids = token_type_ids + padding
        self.attention_mask = attention_mask + padding
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        self.context_token_to_char = tokenized_context.offsets

# Read both dataset files as UTF-8 explicitly so that parsing does not depend on
# the platform's default text encoding.
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)

with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)


def create_squad_examples(raw_data):
    """Build a preprocessed ``SquadExample`` for every question in ``raw_data``.

    Args:
        raw_data: parsed SQuAD JSON; ``raw_data["data"]`` is a list of items, each
            with ``"paragraphs"`` that carry a ``"context"`` and a list of ``"qas"``.

    Returns:
        A list with one ``SquadExample`` per question, in file order. Examples that
        preprocessing flagged with ``skip`` are still included.
    """
    examples = []
    for article in tqdm(raw_data["data"]):
        for paragraph in article["paragraphs"]:
            passage = paragraph["context"]
            for qa in paragraph["qas"]:
                question_text = qa["question"]
                answers = qa["answers"]
                # The first listed answer supplies the training target.
                first_answer = answers[0]
                example = SquadExample(
                    question_text,
                    passage,
                    first_answer["answer_start"],
                    first_answer["text"],
                    [answer["text"] for answer in answers],
                )
                example.preprocess()
                examples.append(example)
    return examples

def create_inputs_targets(squad_examples):
    """Stack the features of all non-skipped examples into NumPy arrays.

    Args:
        squad_examples: iterable of objects exposing ``skip`` plus the attributes
            ``input_ids``, ``token_type_ids``, ``attention_mask``,
            ``start_token_idx`` and ``end_token_idx``.

    Returns:
        ``(x, y)`` where ``x`` is ``[input_ids, token_type_ids, attention_mask]`` and
        ``y`` is ``[start_token_idx, end_token_idx]``, each entry an ``np.ndarray``
        with one row/element per kept example.
    """
    feature_names = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_token_idx",
        "end_token_idx",
    )
    # Only examples that survived preprocessing carry the feature attributes.
    kept = [example for example in squad_examples if not example.skip]
    arrays = {
        name: np.array([getattr(example, name) for example in kept])
        for name in feature_names
    }

    x = [arrays["input_ids"], arrays["token_type_ids"], arrays["attention_mask"]]
    y = [arrays["start_token_idx"], arrays["end_token_idx"]]
    return x, y


In [4]:
# Build the model inputs/targets for the training split. The printed count covers
# every example, including those that preprocessing flagged with `skip`.
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} 条训练样本")

# Same for the evaluation split, labelled as evaluation (not training) samples.
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} 条评估样本")

100%|████████████████████████████████████████████████████████████████████████████████| 442/442 [00:24<00:00, 18.00it/s]


87599 条训练样本


100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:03<00:00, 15.48it/s]


10570 条训练样本


In [7]:
# Encode one short sentence and print each field of the resulting encoding.
sample_encoding = tokenizer.encode("Hello, world!")
for field in (
    "tokens",
    "ids",
    "type_ids",
    "offsets",
    "attention_mask",
    "special_tokens_mask",
    "overflowing",
):
    print(getattr(sample_encoding, field))

# Show the first training row of every input array, then of both target arrays.
first_inputs = [features[:1] for features in x_train]
first_targets = [targets[:1] for targets in y_train]
print(first_inputs[0], "\n", first_inputs[1], "\n", first_inputs[2])
print(first_targets[0], "\n", first_targets[1])

['[CLS]', 'hello', ',', 'world', '!', '[SEP]']
[101, 7592, 1010, 2088, 999, 102]
[0, 0, 0, 0, 0, 0]
[(0, 0), (0, 5), (5, 6), (7, 12), (12, 13), (0, 0)]
[1, 1, 1, 1, 1, 1]
[1, 0, 0, 0, 0, 1]
[]
[[  101  6549  2135  1010  1996  2082  2038  1037  3234  2839  1012 10234
   1996  2364  2311  1005  1055  2751  8514  2003  1037  3585  6231  1997
   1996  6261  2984  1012  3202  1999  2392  1997  1996  2364  2311  1998
   5307  2009  1010  2003  1037  6967  6231  1997  4828  2007  2608  2039
  14995  6924  2007  1996  5722  1000  2310  3490  2618  4748  2033 18168
   5267  1000  1012  2279  2000  1996  2364  2311  2003  1996 13546  1997
   1996  6730  2540  1012  3202  2369  1996 13546  2003  1996 24665 23052
   1010  1037 14042  2173  1997  7083  1998  9185  1012  2009  2003  1037
  15059  1997  1996 24665 23052  2012 10223 26371  1010  2605  2073  1996
   6261  2984 22353  2135  2596  2000  3002 16595  9648  4674  2061 12083
   9711  2271  1999  8517  1012  2012  1996  2203  1997  1996  2364

使用 `TFBertModel.from_pretrained()` 加载的 BERT 模型，其输出通常包含以下内容（具体取决于模型配置和输入参数）。假设你使用的是 `TFBertModel`（即基础的 BERT 模型，不带分类、问答等任务头部），并且提供了 `input_ids`、`token_type_ids` 和 `attention_mask`，则输出主要包括：

### 1. **`last_hidden_state` (最后隐藏层状态)**:
   - **类型**: 张量 (Tensor)，形状为 `(batch_size, sequence_length, hidden_size)`
   - **描述**: 这是 BERT 模型的最后一层的输出。对于每一个输入 token，它提供了一个 `hidden_size` 维的向量表示。`sequence_length` 是输入序列的长度，`hidden_size` 是 BERT 模型的隐藏层大小，通常为 768（对于 `bert-base-uncased`）。
   
   **用途**: 你可以将这个输出用于进一步的下游任务，如文本分类、命名实体识别、序列标注等。每个 token 的输出向量捕捉了其上下文语义。

### 2. **`pooler_output` (池化输出)**:
   - **类型**: 张量 (Tensor)，形状为 `(batch_size, hidden_size)`
   - **描述**: 这是句子级别的汇总表示：取 `[CLS]` token 在最后一层的隐藏状态，再经过一个全连接层和激活函数（通常是 `tanh`），作为整个句子的表示向量。
   
   **用途**: 这个输出经常用于句子级别的任务，例如句子分类。它可以认为是对整个句子内容的一个全局表示。

### 具体代码示例和输出结果说明：

```python
from transformers import TFBertModel
import tensorflow as tf

# 假设模型已经下载到了本地目录
model = TFBertModel.from_pretrained("D:/bert-base-uncased")

# 输入数据示例
input_ids = tf.constant([[101, 7592, 1010, 2129, 2024, 2017, 102]])  # 输入序列的 token ids
token_type_ids = tf.constant([[0, 0, 0, 0, 0, 0, 0]])  # token 类型 ids，通常用于句子对任务
attention_mask = tf.constant([[1, 1, 1, 1, 1, 1, 1]])  # 注意力掩码，表示哪些 token 需要关注

# 模型前向传播，得到输出
outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

# 查看输出
last_hidden_state = outputs.last_hidden_state
pooler_output = outputs.pooler_output
```

### 输出内容:

1. **`last_hidden_state`**:
   - 形状: `(batch_size, sequence_length, hidden_size)`，即 `(1, 7, 768)`，因为输入序列长度是 7。
   - 内容: 这是每个 token 的隐藏状态。输出的每个元素是该 token 的上下文表示。

2. **`pooler_output`**:
   - 形状: `(batch_size, hidden_size)`，即 `(1, 768)`。
   - 内容: 这是 `[CLS]` token 的最后一层隐藏状态经过全连接层和 `tanh` 激活（即“池化”）后得到的向量，可作为整个输入序列的全局表示。

### 输出总结:
- **`last_hidden_state`** 包含每个输入 token 在经过 BERT 模型后的上下文表示，是用于序列标注等任务的主要输出。
- **`pooler_output`** 则是对句子进行全局表示的输出，常用于句子级任务，如分类。

In [11]:
def create_model(pretrained_path="D:/bert-base-uncased", learning_rate=5e-5):
    """Build and compile the BERT-based answer-span prediction model.

    Args:
        pretrained_path: directory (or identifier) handed to
            ``TFBertModel.from_pretrained``. Defaults to the local checkpoint.
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        A compiled ``tf.keras.Model`` taking ``[input_ids, token_type_ids,
        attention_mask]`` (each of length ``max_len``) and producing two
        probability distributions over token positions: answer start and answer end.
    """
    ## BERT encoder
    encoder = TFBertModel.from_pretrained(pretrained_path)

    ## QA Model
    input_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32)
    # Index 0 of the encoder output is taken as the per-token representation.
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    # One scalar score per token for "answer starts here"; Flatten drops the
    # trailing unit dimension left by Dense(1).
    start_logits = tf.keras.layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = tf.keras.layers.Flatten()(start_logits)

    # Same head shape for "answer ends here".
    end_logits = tf.keras.layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = tf.keras.layers.Flatten()(end_logits)

    # Softmax over token positions turns the scores into probabilities.
    start_probs = tf.keras.layers.Activation(tf.keras.activations.softmax)(start_logits)
    end_probs = tf.keras.layers.Activation(tf.keras.activations.softmax)(end_logits)

    model = tf.keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    # The outputs are already probabilities, hence from_logits=False; the targets
    # are integer token indices, hence the sparse variant.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model

model = create_model()

Some layers from the model checkpoint at D:/bert-base-uncased were not used when initializing TFBertModel: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at D:/bert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


In [13]:
def normalize_text(text):
    """Normalize an answer string for exact-match comparison.

    Lower-cases the text, removes every ASCII punctuation character, replaces the
    standalone articles "a", "an" and "the" with a space, and collapses whitespace.

    Args:
        text: the string to normalize.

    Returns:
        The normalized string.
    """
    lowered = text.lower()
    # Delete all characters listed in string.punctuation.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # \b keeps words that merely contain an article (e.g. "theatre") intact.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)
    return " ".join(without_articles.split())


class ExactMatch(tf.keras.callbacks.Callback):
    """Keras callback that prints the exact-match score after every epoch.

    For each evaluation row the most probable start and end tokens are mapped back
    to a character span of the example's context; the prediction counts as correct
    when its normalized text equals any normalized reference answer.

    Args:
        x_eval: model inputs for the evaluation set.
        y_eval: evaluation targets; ``len(y_eval[0])`` is the score's denominator.
        eval_examples: optional list of the examples ``x_eval`` was built from
            (skipped ones included). When omitted, the module-level
            ``eval_squad_examples`` is read at the end of each epoch.
    """

    def __init__(self, x_eval, y_eval, eval_examples=None):
        # Let the Keras base class set up its own state first.
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.eval_examples = eval_examples

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the evaluation set and print the exact-match score."""
        examples = (
            self.eval_examples if self.eval_examples is not None else eval_squad_examples
        )
        # Rows of x_eval correspond to the non-skipped examples, in order.
        eval_examples_no_skip = [eg for eg in examples if not eg.skip]

        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            # A start position outside the context tokens cannot yield an answer.
            if start >= len(offsets):
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                # End position lies beyond the context tokens: take the rest of it.
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(ans) for ans in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1

        total = len(self.y_eval[0])
        # Guard against an empty evaluation set instead of dividing by zero.
        acc = count / total if total else 0.0
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")

In [15]:
# Fine-tune for three epochs; the callback prints the exact-match score on the
# evaluation split at the end of each one.
exact_match_callback = ExactMatch(x_eval, y_eval)
model.fit(
    x=x_train,
    y=y_train,
    batch_size=8,
    epochs=3,
    callbacks=[exact_match_callback],
)

Epoch 1/3

epoch=1, exact match score=0.75
Epoch 2/3

epoch=2, exact match score=0.75
Epoch 3/3

epoch=3, exact match score=0.74


<keras.callbacks.History at 0x18e6411e530>

In [69]:
def test(n):
    """Compare predicted and labelled answers for the first ``n`` evaluation rows.

    Uses the module-level ``model``, ``x_eval``, ``y_eval`` and
    ``eval_squad_examples``.

    Args:
        n: number of leading evaluation rows to run through the model.

    Returns:
        ``(predicted, expected)``: two lists of normalized answer strings. Rows
        whose predicted start token lies outside the context are left out of both.
    """
    first_inputs = [x_eval[i][:n] for i in range(3)]
    start_probs, end_probs = model.predict(first_inputs)
    # Rows of x_eval correspond to the non-skipped examples, in order.
    kept_examples = [eg for eg in eval_squad_examples if not eg.skip]

    predicted = []
    expected = []
    for row, (start_row, end_row) in enumerate(zip(start_probs, end_probs)):
        example = kept_examples[row]
        spans = example.context_token_to_char
        start_tok = np.argmax(start_row)
        end_tok = np.argmax(end_row)
        if start_tok >= len(spans):
            continue

        char_from = spans[start_tok][0]
        # An end token beyond the context means "up to the end of the context".
        char_to = spans[end_tok][1] if end_tok < len(spans) else None
        predicted.append(normalize_text(example.context[char_from:char_to]))

        # Rebuild the labelled answer from the target token indices.
        gold_from = spans[y_eval[0][row]][0]
        gold_to = spans[y_eval[1][row]][1]
        expected.append(normalize_text(example.context[gold_from:gold_to]))
    return predicted[:n], expected[:n]


for a, b in zip(*test(10)):
    print(f"预测：{a}，标签：{b}")
        


预测：denver broncos，标签：denver broncos
预测：carolina panthers，标签：carolina panthers
预测：levis stadium，标签：santa clara california
预测：denver broncos，标签：denver broncos
预测：golden，标签：golden
预测：arabic numerals，标签：golden anniversary
预测：february 7 2016，标签：february 7 2016
预测：american football conference，标签：american football conference
预测：arabic numerals，标签：golden anniversary
预测：american football conference，标签：american football conference
