In [1]:
# Prepare environment variables for the rest of the notebook.
import os
from io import StringIO
from dotenv import load_dotenv

# Inline .env-style content; HOME is pointed at a relative directory.
# NOTE(review): presumably easytransfer derives its model-zoo cache location
# from HOME -- confirm against the library.
env_text = """
HOME=./eztransfer_modelzoo
"""

# override=True is requested so a HOME value already present in the process
# environment is replaced by the one defined above.
load_dotenv(stream=StringIO(env_text), override=True, verbose=True)
# Display the resulting value to confirm the override took effect.
os.environ.get("HOME")

'./eztransfer_modelzoo'

In [2]:
# Give the local easytransfer checkout priority over any installed copy.
import sys

# Re-running this cell must not stack duplicate entries on sys.path, so drop
# any existing occurrences of the parent directory before putting it first.
repo_root = os.path.abspath('../')
sys.path[:] = [entry for entry in sys.path if entry != repo_root]
sys.path.insert(0, repo_root)
sys.path

['d:\\code\\github\\EasyTransfer',
 'd:\\code\\github\\EasyTransfer\\examples',
 'c:\\Users\\tzh\\.vscode\\extensions\\ms-toolsai.jupyter-2021.11.1001550889\\pythonFiles',
 'c:\\Users\\tzh\\.vscode\\extensions\\ms-toolsai.jupyter-2021.11.1001550889\\pythonFiles\\lib\\python',
 'C:\\Anaconda3\\envs\\nlp\\python37.zip',
 'C:\\Anaconda3\\envs\\nlp\\DLLs',
 'C:\\Anaconda3\\envs\\nlp\\lib',
 'C:\\Anaconda3\\envs\\nlp',
 '',
 'C:\\Users\\tzh\\AppData\\Roaming\\Python\\Python37\\site-packages',
 'C:\\Anaconda3\\envs\\nlp\\lib\\site-packages',
 'C:\\Anaconda3\\envs\\nlp\\lib\\site-packages\\win32',
 'C:\\Anaconda3\\envs\\nlp\\lib\\site-packages\\win32\\lib',
 'C:\\Anaconda3\\envs\\nlp\\lib\\site-packages\\Pythonwin',
 'C:\\Anaconda3\\envs\\nlp\\lib\\site-packages\\IPython\\extensions',
 'C:\\Users\\tzh\\.ipython']

In [3]:
# Import easytransfer and verify which copy was loaded: per the original
# author's note, the upstream package has no __version__ attribute, so being
# able to display it indicates the local checkout is the one in use.
import importlib
import easytransfer
# NOTE(review): presumably reloaded so that a copy imported earlier in this
# kernel session is replaced -- confirm.
importlib.reload(easytransfer)
easytransfer.__version__

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


INFO:tensorflow:*********** tf.__version__ is 1.13.2 ******


'0.1.4'

In [20]:
# Jupyter-specific workaround: keep only the first element of sys.argv.
# NOTE(review): presumably needed because command-line flag parsing inside
# easytransfer/TensorFlow would choke on the kernel launcher's arguments -- confirm.
# Re-running this cell rebinds old_argv to the already-trimmed list, so the
# original arguments are only captured on the first run.
old_argv = sys.argv
print(old_argv)
sys.argv = old_argv[:1]

['C:\\Anaconda3\\envs\\nlp\\lib\\site-packages\\ipykernel_launcher.py', '--ip=127.0.0.1', '--stdin=9008', '--control=9006', '--hb=9005', '--Session.signature_scheme="hmac-sha256"', '--Session.key=b"ade30d99-008a-4d84-ae87-1f16c96a81d3"', '--shell=9007', '--transport="tcp"', '--iopub=9009', '--f=C:\\Users\\tzh\\AppData\\Local\\Temp\\tmp-9924mxh47ph2ELli.json']


In [4]:
import json

from easytransfer import Config, base_model, layers, model_zoo, preprocessors
from easytransfer.datasets import CSVReader, CSVWriter
from easytransfer.evaluators import classification_eval_metrics
from easytransfer.losses import softmax_cross_entropy

import tensorflow as tf

In [5]:
class TextClassification(base_model):
    """Text classification model: pretrained backbone plus a dense output layer."""

    def __init__(self, **kwargs):
        """Initialise the base model and keep a handle on the user config.

        Args:
            **kwargs: Forwarded unchanged to ``base_model.__init__``. Must
                contain ``user_defined_config``.

        Raises:
            KeyError: If ``user_defined_config`` is not supplied.
        """
        super(TextClassification, self).__init__(**kwargs)
        self.user_defined_config = kwargs["user_defined_config"]

    def build_logits(self, features, mode=None):
        """Build the forward graph from raw features to classification logits.

        Args:
            features: Raw input features handed to the preprocessor.
                NOTE(review): exact structure depends on the reader's input
                schema -- confirm.
            mode: Mode value passed through to the backbone model.
                Defaults to None.

        Returns:
            tuple: ``(logits, label_ids)`` where ``logits`` is the dense-layer
            output and ``label_ids`` comes from the preprocessor.
        """
        # The preprocessor turns raw data into the features the model needs,
        # e.g. input_ids, input_mask, segment_ids.
        preprocessor = preprocessors.get_preprocessor(
            self.pretrain_model_name_or_path, user_defined_config=self.user_defined_config
        )
        # Builds the network backbone.
        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name="dense")
        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)
        _, pooled_output = model([input_ids, input_mask, segment_ids], mode=mode)
        logits = dense(pooled_output)

        # Used for continued fine-tuning from an existing checkpoint.
        # self.check_and_init_from_checkpoint(mode)
        return logits, label_ids

    def build_loss(self, logits, labels):
        """Define the training loss.

        Args:
            logits: Logits returned from ``build_logits``.
            labels: Labels returned from ``build_logits``.

        Returns:
            The result of ``softmax_cross_entropy`` over ``self.num_labels`` classes.
        """
        return softmax_cross_entropy(labels, self.num_labels, logits)

    def build_eval_metrics(self, logits, labels):
        """Define the evaluation metrics.

        Args:
            logits: Logits returned from ``build_logits``.
            labels: Labels returned from ``build_logits``.

        Returns:
            The result of ``classification_eval_metrics`` for ``self.num_labels`` classes.
        """
        return classification_eval_metrics(logits, labels, self.num_labels)

    def build_predictions(self, output):
        """Define the prediction outputs.

        Args:
            output: The ``(logits, label_ids)`` pair returned from ``build_logits``.

        Returns:
            dict: ``predict_index`` (argmax class id over the last axis),
            ``predict_softmax`` (softmax of the logits) and ``predict_prob``
            (softmax probability of the predicted class, one value per example).
        """
        logits, _ = output
        # The last axis holds the classes; the first axis is the batch.
        probabilities = tf.nn.softmax(logits)
        index = tf.argmax(logits, axis=-1, output_type=tf.int32)

        predictions = dict()
        predictions["predict_index"] = index
        predictions["predict_softmax"] = probabilities
        # tf.gather(probabilities, index, axis=-1) would pick the columns listed
        # in `index` for *every* row, giving a [batch, batch] tensor instead of
        # one probability per example. Since `index` is the argmax, the
        # probability of the predicted class is simply the row-wise maximum.
        predictions["predict_prob"] = tf.reduce_max(probabilities, axis=-1)
        return predictions

In [6]:
def train(config_json):
    """Run combined training and evaluation for the text classifier.

    Args:
        config_json: Configuration object passed to ``Config`` as ``config_json``.
    """
    run_config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=run_config)

    # Reader over the training input, in training mode.
    train_options = dict(
        input_glob=app.train_input_fp,
        is_training=True,
        input_schema=app.input_schema,
        batch_size=app.train_batch_size,
    )
    train_reader = CSVReader(**train_options)

    # Reader over the evaluation input, in non-training mode.
    eval_options = dict(
        input_glob=app.eval_input_fp,
        is_training=False,
        input_schema=app.input_schema,
        batch_size=app.eval_batch_size,
    )
    eval_reader = CSVReader(**eval_options)

    app.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)


def predict(config_json):
    """Run prediction and print the first result as a quick sanity check.

    Args:
        config_json: Configuration object passed to ``Config`` as ``config_json``.
    """
    run_config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=run_config)

    reader_options = dict(
        input_glob=app.predict_input_fp,
        is_training=False,
        input_schema=app.input_schema,
        batch_size=app.predict_batch_size,
    )
    pred_reader = CSVReader(**reader_options)
    # NOTE(review): the writer is constructed but never handed to run_predict
    # (writer=None below); kept because its constructor may have side effects.
    pred_writer = CSVWriter(output_glob=app.predict_output_fp, output_schema=app.output_schema)

    predicted_rows = app.run_predict(
        reader=pred_reader,
        writer=None,
        checkpoint_path=app.predict_checkpoint_path,
    )
    # Only the first item is printed; iteration stops right after it.
    for first_row in predicted_rows:
        print(first_row)
        break


def evaluate(config_json):
    """Evaluate a checkpoint on the evaluation input and print the result.

    Args:
        config_json: Configuration object passed to ``Config`` as ``config_json``.
    """
    run_config = Config(mode="evaluate_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=run_config)

    reader_options = dict(
        input_glob=app.eval_input_fp,
        is_training=False,
        input_schema=app.input_schema,
        batch_size=app.eval_batch_size,
    )
    metrics = app.run_evaluate(
        reader=CSVReader(**reader_options),
        checkpoint_path=app.eval_ckpt_path,
    )
    print(metrics)


In [21]:
# Load the user configuration. The context manager closes the file handle as
# soon as parsing finishes instead of leaving it open for the kernel's lifetime.
with open("user_config.json", "r", encoding="utf-8") as config_file:
    config_json = json.load(config_file)
train(config_json)

INFO:tensorflow:***************** modelZooBasePath ./eztransfer_modelzoo\.eztransfer_modelzoo ***************
INFO:tensorflow:total number of training examples 53360
INFO:tensorflow:***********Running in train_and_evaluate_on_the_fly mode***********
INFO:tensorflow:***********Disable Tao***********
INFO:tensorflow:***********NCCL_IB_DISABLE 0***********
INFO:tensorflow:***********NCCL_P2P_DISABLE 0***********
INFO:tensorflow:***********NCCL_SHM_DISABLE 0***********
INFO:tensorflow:***********NCCL_MAX_NRINGS 4***********
INFO:tensorflow:***********NCCL_MIN_NRINGS 2***********
INFO:tensorflow:***********NCCL_LAUNCH_MODE PARALLEL***********
INFO:tensorflow:***********TF_JIT_PROFILING False***********
INFO:tensorflow:***********PAI_ENABLE_HLO_DUMPER False***********
INFO:tensorflow:***********Single worker, Single gpu, Don't use distribution strategy***********
INFO:tensorflow:model_dir: model_dir
INFO:tensorflow:num workers: 1
INFO:tensorflow:num gpus: 1
INFO:tensorflow:learning rate: 1e-

KeyboardInterrupt: 