In [1]:
# 为环境变量做好准备
import os
from io import StringIO
from dotenv import load_dotenv

env_text = """
HOME=./eztransfer_modelzoo
"""

load_dotenv(stream=StringIO(env_text), override=True, verbose=True)
os.environ.get("HOME")

'./eztransfer_modelzoo'

In [2]:
# 让本地的 easytransfer 的优先级更高
import sys
sys.path.insert(0, os.path.abspath('../'))
sys.path

['d:\\code\\github\\EasyTransfer',
 'd:\\code\\github\\EasyTransfer\\examples',
 'c:\\Users\\tzh\\.vscode\\extensions\\ms-toolsai.jupyter-2021.11.1001550889\\pythonFiles\\vscode_datascience_helpers',
 'c:\\Users\\tzh\\.vscode\\extensions\\ms-toolsai.jupyter-2021.11.1001550889\\pythonFiles',
 'c:\\Users\\tzh\\.vscode\\extensions\\ms-toolsai.jupyter-2021.11.1001550889\\pythonFiles\\lib\\python',
 'C:\\Anaconda3\\envs\\tf2\\python39.zip',
 'C:\\Anaconda3\\envs\\tf2\\DLLs',
 'C:\\Anaconda3\\envs\\tf2\\lib',
 'C:\\Anaconda3\\envs\\tf2',
 '',
 'C:\\Anaconda3\\envs\\tf2\\lib\\site-packages',
 'C:\\Anaconda3\\envs\\tf2\\lib\\site-packages\\win32',
 'C:\\Anaconda3\\envs\\tf2\\lib\\site-packages\\win32\\lib',
 'C:\\Anaconda3\\envs\\tf2\\lib\\site-packages\\Pythonwin']

In [3]:
# 导入 easytransfer, 并验证, 因为原始的没有 __version__ 这个参数
import importlib
import easytransfer as easytransfer
importlib.reload(easytransfer)
easytransfer.__version__

INFO:tensorflow:*********** tf.__version__ is 2.7.0 ******


'1.0.0'

In [4]:
# jupyter 特殊操作
old_argv = sys.argv
print(old_argv)
sys.argv = old_argv[:1]

['ipykernel_launcher', '--ip=127.0.0.1', '--stdin=9008', '--control=9006', '--hb=9005', '--Session.signature_scheme="hmac-sha256"', '--Session.key=b"56db95a5-7993-4fcb-85d8-2ebb311f9f5a"', '--shell=9007', '--transport="tcp"', '--iopub=9009', '--f=C:\\Users\\tzh\\AppData\\Local\\Temp\\tmp-45620OAqltQMMcAn2.json']


In [5]:
import json

from easytransfer import Config, base_model, layers, model_zoo, preprocessors
from easytransfer.datasets import CSVReader, CSVWriter
from easytransfer.evaluators import classification_eval_metrics
from easytransfer.losses import softmax_cross_entropy

import tensorflow as tf

tf.random.set_seed(32)

In [6]:
class TextClassification(base_model):
    """
    定义文本分类模型
    """

    def __init__(self, **kwargs):
        super(TextClassification, self).__init__(**kwargs)
        self.user_defined_config = kwargs["user_defined_config"]

    def build_logits(self, features, mode=None):
        """构图

        Args:
            features ([type]): [description]
            mode ([type], optional): [description]. Defaults to None.

        Returns:
            [type]: [description]
        """
        # 负责对原始数据进行预处理，生成模型需要的特征，比如：input_ids, input_mask, segment_ids等
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path, user_defined_config=self.user_defined_config
        )
        # 负责构建网络的backbone
        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name="dense")
        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)
        _, pooled_output = model([input_ids, input_mask, segment_ids], mode=mode)
        logits = dense(pooled_output)

        # 用于 continue finetune
        # self.check_and_init_from_checkpoint(mode)
        return logits, label_ids

    def build_loss(self, logits, labels):
        """定义损失函数

        Args:
            logits ([type]): logits returned from build_logits
            labels ([type]): labels returned from build_logits

        Returns:
            [type]: [description]
        """
        return softmax_cross_entropy(labels, self.num_labels, logits)

    def build_eval_metrics(self, logits, labels):
        """定义评估指标

        Args:
            logits ([type]): logits returned from build_logits
            labels ([type]): labels returned from build_logits

        Returns:
            [type]: [description]
        """
        return classification_eval_metrics(logits, labels, self.num_labels)

    def build_predictions(self, output):
        """定义预测输出

        Args:
            output ([type]): returned from build_logits

        Returns:
            [type]: [description]
        """
        logits, _ = output
        predictions = dict()
        index = tf.argmax(logits, axis=-1, output_type=tf.int32)
        predictions["predict_index"] = index
        predictions["predict_softmax"] = tf.nn.softmax(logits)
        # 核心是理解 shape, 最后一维才是类别数量, 第一个维度是 batch_size
        predictions["predict_prob"] = tf.gather(tf.nn.softmax(logits), index, axis=-1)
        return predictions

In [7]:
def train(config_json):
    config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)

    train_reader = CSVReader(
        input_glob=app.train_input_fp, is_training=True, input_schema=app.input_schema, batch_size=app.train_batch_size,
    )
    eval_reader = CSVReader(
        input_glob=app.eval_input_fp, is_training=False, input_schema=app.input_schema, batch_size=app.eval_batch_size,
    )

    app.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)


def predict(config_json):
    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)

    pred_reader = CSVReader(
        input_glob=app.predict_input_fp,
        is_training=False,
        input_schema=app.input_schema,
        batch_size=app.predict_batch_size,
    )
    pred_writer = CSVWriter(output_glob=app.predict_output_fp, output_schema=app.output_schema)

    result = app.run_predict(reader=pred_reader, writer=None, checkpoint_path=app.predict_checkpoint_path)
    for row in result:
        print(row)
        break


def evaluate(config_json):
    config = Config(mode="evaluate_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)

    eval_reader = CSVReader(
        input_glob=app.eval_input_fp, is_training=False, input_schema=app.input_schema, batch_size=app.eval_batch_size,
    )

    result = app.run_evaluate(reader=eval_reader, checkpoint_path=app.eval_ckpt_path)
    print(result)


In [8]:
config_json = json.load(open("user_config_tf2.json", "r", encoding="utf-8"))
train(config_json)

INFO:tensorflow:***************** modelZooBasePath ./eztransfer_modelzoo\.eztransfer_modelzoo ***************
INFO:tensorflow:total number of training examples 1000
INFO:tensorflow:***********Running in train_and_evaluate_on_the_fly mode***********
INFO:tensorflow:***********Disable Tao***********
INFO:tensorflow:***********NCCL_IB_DISABLE 0***********
INFO:tensorflow:***********NCCL_P2P_DISABLE 0***********
INFO:tensorflow:***********NCCL_SHM_DISABLE 0***********
INFO:tensorflow:***********NCCL_MAX_NRINGS 4***********
INFO:tensorflow:***********NCCL_MIN_NRINGS 2***********
INFO:tensorflow:***********NCCL_LAUNCH_MODE PARALLEL***********
INFO:tensorflow:***********TF_JIT_PROFILING False***********
INFO:tensorflow:***********PAI_ENABLE_HLO_DUMPER False***********
INFO:tensorflow:***********Single worker, Single gpu, Don't use distribution strategy***********
INFO:tensorflow:model_dir: model_dir_tf2
INFO:tensorflow:num workers: 1
INFO:tensorflow:num gpus: 1
INFO:tensorflow:learning rate: 

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 31: model_dir_tf2\model.ckpt-31
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 62...
INFO:tensorflow:Saving checkpoints for 62 into model_dir_tf2\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 62...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Load weights from D:/code/py_nlp_classify/model/bert/google-bert-base-zh/bert_model.ckpt
INFO:tensorflow:empty data to evaluate
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-30T17:40:46
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_dir_tf2\model.ckpt-62
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 15.56522s
INFO:tensorflow:Finished evaluation at 2022-01-30-17:41:01
INFO:tensorflow:Saving dict for global step 62: global_step = 62, loss = 2.3540769, py_accuracy = 0.34, py_macr

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 93...
INFO:tensorflow:Saving checkpoints for 93 into model_dir_tf2\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 93...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Load weights from D:/code/py_nlp_classify/model/bert/google-bert-base-zh/bert_model.ckpt
INFO:tensorflow:empty data to evaluate
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-30T17:47:30
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_dir_tf2\model.ckpt-93
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 14.61137s
INFO:tensorflow:Finished evaluation at 2022-01-30-17:47:45
INFO:tensorflow:Saving dict for global step 93: global_step = 93, loss = 2.188513, py_accuracy = 0.34, py_macro_f1 = 0.26095536, py_micro_f1 = 0.34, py_weighted_f1 = 0.29991698
INFO:tensorflow:Saving 'checkpo

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:global_step/sec: 0.083921
INFO:tensorflow:loss = 2.087132, step = 100 (1191.605 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 124...
INFO:tensorflow:Saving checkpoints for 124 into model_dir_tf2\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 124...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Load weights from D:/code/py_nlp_classify/model/bert/google-bert-base-zh/bert_model.ckpt
INFO:tensorflow:empty data to evaluate
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-30T17:54:42
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_dir_tf2\model.ckpt-124
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 14.28321s
INFO:tensorflow:Finished evaluation at 2022-01-30-17:54:56
INFO:tensorflow:Saving dict for global step 124: global_step = 124, loss = 2.09781, py_accuracy = 0.38,

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 155...
INFO:tensorflow:Saving checkpoints for 155 into model_dir_tf2\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 155...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Load weights from D:/code/py_nlp_classify/model/bert/google-bert-base-zh/bert_model.ckpt
INFO:tensorflow:empty data to evaluate
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-30T18:00:37
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_dir_tf2\model.ckpt-155
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 14.89519s
INFO:tensorflow:Finished evaluation at 2022-01-30-18:00:51
INFO:tensorflow:Saving dict for global step 155: global_step = 155, loss = 2.05618, py_accuracy = 0.43, py_macro_f1 = 0.3517705, py_micro_f1 = 0.43, py_weighted_f1 = 0.3988156
INFO:tensorflow:Saving 'chec

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 157...
INFO:tensorflow:Saving checkpoints for 157 into model_dir_tf2\model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 157...
INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (100 secs).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Load weights from D:/code/py_nlp_classify/model/bert/google-bert-base-zh/bert_model.ckpt
INFO:tensorflow:empty data to evaluate
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-01-30T18:01:24
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from model_dir_tf2\model.ckpt-157
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 16.43908s
INFO:tensorflow:Finished evaluation at 2022-01-30-18:01:40
INFO:tensorflow:Saving dict for global step 157: global_step = 157, loss = 2.0560822, py_accuracy = 0.42, py_macro_f1 = 0.

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Loss for final step: 1.9419066.


In [10]:
config_json["evaluate_config"]["eval_checkpoint_path"] = config_json["evaluate_config"]["eval_checkpoint_path"].format(157)
evaluate(config_json)
# {'loss': 2.0560822, 'py_accuracy': 0.42, 'py_macro_f1': 0.33929557, 'py_micro_f1': 0.42, 'py_weighted_f1': 0.38653803, 'global_step': 157}

INFO:tensorflow:***************** modelZooBasePath ./eztransfer_modelzoo\.eztransfer_modelzoo ***************
INFO:tensorflow:num eval steps: None
INFO:tensorflow:***********Running in evaluate_on_the_fly mode***********
INFO:tensorflow:Using config: {'_model_dir': 'C:\\Users\\tzh\\AppData\\Local\\Temp\\tmpitjpdibk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': intra_op_parallelism_threads: 1024
inter_op_parallelism_threads: 1024
gpu_options {
  per_process_gpu_memory_fraction: 1.0
  allow_growth: true
  force_gpu_compatible: true
}
allow_soft_placement: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': Tr

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 157: model_dir_tf2/model.ckpt-157
{'loss': 2.0560822, 'py_accuracy': 0.42, 'py_macro_f1': 0.33929557, 'py_micro_f1': 0.42, 'py_weighted_f1': 0.38653803, 'global_step': 157}
