In [14]:
import numpy as np
import mxnet as mx
import pandas as pd
import logging, time
from mxnet import gluon
from mxnet.metric import Accuracy, F1, CompositeEvalMetric
import gluonnlp as nlp
from gluonnlp.model import get_bert_model
from gluonnlp.data import BERTTokenizer
from bert import BERTClassifier
from bert_util import BERTDatasetTransform
from d2l import try_gpu
# Seed NumPy and MXNet from a single constant so both generators always
# share the same seed and it only has to be changed in one place.
SEED = 9102
np.random.seed(SEED)
mx.random.seed(SEED)

### Set basic model and training parameters

In [15]:
# --- training hyper-parameters -------------------------------------------
batch_size = 64
# NOTE(review): defined here but the validation DataLoader below is built
# with `batch_size`; confirm which one is intended.
valid_batch_size = 64
# NOTE(review): 0.01 is far above the rates usually used to fine-tune BERT
# (on the order of 1e-5) -- confirm this value is deliberate.
lr = 0.01
weight_decay = 0.01
epsilon = 1e-6          # passed to the Adam optimizer as 'epsilon'
nclass = 72             # number of output classes of the classifier head
epochs = 5
warmup_ratio = 0.1      # fraction of all steps spent on learning-rate warm-up
grad_clip = 1.0         # max global gradient norm
log_interval = 1        # log training metrics every N batches

# --- pretrained BERT checkpoint ------------------------------------------
model_name = 'bert_12_768_12'
bert_dataset = 'wiki_cn_cased'
pretrained = True

# --- runtime: device and logging -----------------------------------------
# presumably returns a GPU context when one is available -- see d2l.try_gpu
ctx = try_gpu()
logging.getLogger().setLevel(logging.DEBUG)
# Route `warnings.warn` messages through the logging system as well.
logging.captureWarnings(True)

### Load the pretrained BERT model and build the tokenizer

In [16]:
# Load the BERT encoder (with its pooler, but without the decoder and the
# built-in classifier) together with the matching vocabulary.
bert, vocab = get_bert_model(
    model_name=model_name,
    dataset_name=bert_dataset,
    pretrained=pretrained,
    ctx=ctx,
    use_pooler=True,
    use_decoder=False,
    use_classifier=False,
)

In [17]:
# Softmax cross-entropy over the `nclass` output logits.
loss_function = gluon.loss.SoftmaxCELoss()

# Put a classification head on top of the BERT encoder loaded above.
model = BERTClassifier(bert, dropout=0.1, num_classes=nclass)
# Only the freshly added head is initialised here; the encoder parameters
# come from `get_bert_model`.
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)

In [18]:
# Case-preserving tokenizer built on the vocabulary of the loaded model.
bert_tokenizer = BERTTokenizer(vocab, lower=False)

# Switch every model parameter to half precision, then show the structure.
model.cast('float16')
print(model)

# Hybridize both the network and the loss with static memory allocation.
model.hybridize(static_alloc=True)
loss_function.hybridize(static_alloc=True)

BERTClassifier(
  (bert): BERTModel(
    (encoder): BERTEncoder(
      (dropout_layer): Dropout(p = 0.1, axes=())
      (layer_norm): BERTLayerNorm(eps=1e-12, axis=-1, center=True, scale=True, in_channels=768)
      (transformer_cells): HybridSequential(
        (0): BERTEncoderCell(
          (dropout_layer): Dropout(p = 0.1, axes=())
          (attention_cell): MultiHeadAttentionCell(
            (_base_cell): DotProductAttentionCell(
              (_dropout_layer): Dropout(p = 0.1, axes=())
            )
            (proj_query): Dense(768 -> 768, linear)
            (proj_key): Dense(768 -> 768, linear)
            (proj_value): Dense(768 -> 768, linear)
          )
          (proj): Dense(768 -> 768, linear)
          (ffn): BERTPositionwiseFFN(
            (ffn_1): Dense(768 -> 3072, linear)
            (activation): GELU()
            (ffn_2): Dense(3072 -> 768, linear)
            (dropout_layer): Dropout(p = 0.1, axes=())
            (layer_norm): BERTLayerNorm(eps=1e-12, axis

### Load the training and validation data and build the dataloaders

In [19]:
# Locations of the input files (relative to the notebook's working directory).
DATA_FOLDER = 'data/'
TRAIN_DATA = 'train.csv'
WORD_EMBED = 'sgns.weibo.bigram-char'
LABEL_FILE = 'train.label'
# Only the first N_ROWS rows are read, which keeps this run small.
N_ROWS = 100

train_df = pd.read_csv(DATA_FOLDER + TRAIN_DATA, sep='|', nrows=N_ROWS)
# Keep the first two columns of every row as a [col0, col1] pair
# (presumably [text, label] -- confirm against the CSV layout).
# Explicit positional selection replaces `iterrows()` + `row[0]`, whose
# integer-key fallback on a label-indexed Series is deprecated in pandas.
dataset = train_df.iloc[:, :2].values.tolist()
# Hold out 10% of the samples for validation.
train_dataset, valid_dataset = nlp.data.train_valid_split(dataset, .1)
len(train_dataset), len(valid_dataset)

(90, 10)

In [20]:
# data transformation
max_len = 140 # actually each weibo post should be within 140 characters
# Derive the label set from `nclass` so it cannot drift away from the number
# of outputs the classifier head was built with.
labels = list(range(nclass))
trans = BERTDatasetTransform(bert_tokenizer, max_len, labels, pad=True,
                             pair=False, label_dtype='int32')
# Apply the transform eagerly so each sample is tokenized only once.
data_train = train_dataset.transform(trans, lazy=False)
# Per-sample sequence lengths (second field of every transformed sample);
# the bucket sampler groups samples by these.
data_train_length = data_train.transform(
    lambda input_id, length, segment_id, label_id: length)
num_samples_train = len(data_train)
num_samples_train, data_train_length[-1]

(90, array(14, dtype=int32))

In [21]:
# How to merge samples into a batch, field by field:
# input ids -> padded, valid length -> stacked,
# segment ids -> padded, label -> stacked as int32.
batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
    nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack('int32'))

# bucket sampler: groups training samples of similar length into batches
# NOTE(review): with very few samples some of the 10 buckets may stay empty,
# which GluonNLP reports as a warning -- consider fewer buckets for tiny runs.
batch_sampler = nlp.data.sampler.FixedBucketSampler(
    data_train_length, batch_size=batch_size, num_buckets=10, ratio=0, shuffle=True)

# training dataloader (batches come from the bucket sampler)
dataloader_train = gluon.data.DataLoader(
    dataset=data_train, num_workers=12, batch_sampler=batch_sampler, batchify_fn=batchify_fn)

# validation dataloader: fixed order, sized by `valid_batch_size` (previously
# the training `batch_size` was used and `valid_batch_size` was never read)
data_valid = valid_dataset.transform(trans, lazy=False)
dataloader_valid = gluon.data.DataLoader(
    data_valid, batch_size=valid_batch_size, num_workers=12, shuffle=False,
    batchify_fn=batchify_fn)

  str(unused_bucket_keys))



### Training and evaluation helper functions

In [22]:
def evaluate(dataloader_eval, metric):
    """Score the global ``model`` on every batch of ``dataloader_eval``.

    ``metric`` is reset first, updated with the labels and model outputs of
    each batch, and its final value(s) are written to the log at INFO level.
    Nothing is returned. The function reads the notebook globals ``model``
    and ``ctx``.
    """
    metric.reset()
    for token_ids, lengths, segment_ids, targets in dataloader_eval:
        ids_on_ctx = token_ids.as_in_context(ctx)
        segments_on_ctx = segment_ids.as_in_context(ctx)
        # The model was cast to float16, so the valid lengths are fed in the
        # same dtype.
        lengths_on_ctx = lengths.astype('float16', copy=False).as_in_context(ctx)
        predictions = model(ids_on_ctx, segments_on_ctx, lengths_on_ctx)
        metric.update([targets], [predictions])

    names, values = metric.get()
    # A single metric reports a bare name/value; normalise to lists.
    if not isinstance(names, list):
        names, values = [names], [values]
    template = 'validation metrics:' + ','.join(
        name + ':%.4f' for name in names)
    logging.info(template, *values)

In [23]:
def _warmup_linear_decay_lr(step_num, num_train_steps, num_warmup_steps, base_lr):
    """Return the learning rate to use at 1-based step ``step_num``.

    While ``step_num < num_warmup_steps`` the rate grows linearly towards
    ``base_lr``; afterwards it shrinks linearly so that it reaches zero at
    ``num_train_steps``. The result is clamped at zero, so steps beyond the
    planned total never produce a negative rate, and a zero-length decay
    window does not divide by zero.
    """
    if step_num < num_warmup_steps:
        return base_lr * step_num / num_warmup_steps
    decay_steps = max(1, num_train_steps - num_warmup_steps)
    offset = (step_num - num_warmup_steps) * base_lr / decay_steps
    return max(0.0, base_lr - offset)


def train(metric, train_data, dev_data, grad_clip, epochs):
    """Fine-tune the global ``model`` and validate it after every epoch.

    Parameters
    ----------
    metric : evaluation metric, reset at the start of every epoch and updated
        with each training batch; also handed to ``evaluate``.
    train_data : sized iterable of ``(input_ids, valid_length, type_ids,
        label)`` batches; ``len(train_data)`` must give the batches per epoch.
    dev_data : validation batches, passed to ``evaluate``.
    grad_clip : maximum global gradient norm.
    epochs : number of passes over ``train_data``.

    Reads the notebook globals ``model``, ``loss_function``, ``ctx``, ``lr``,
    ``epsilon``, ``weight_decay``, ``warmup_ratio`` and ``log_interval``.
    After each epoch the parameters are written to
    ``model/bert-<unix time>.params``. Returns ``None``.
    """
    import os  # local import: only needed to create the checkpoint directory

    logging.info('Training bert model on %s', ctx)
    optimizer_params = {'learning_rate': lr, 'epsilon': epsilon,
                        'wd': weight_decay, 'multi_precision': True}
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            optimizer_params, update_on_kvstore=False)

    # Plan the schedule from the number of batches the loader really yields.
    # Estimating it as num_samples / batch_size under-counts when a bucket
    # sampler is used (every bucket contributes its own partial batch), which
    # made the decay run out early and cut each epoch short.
    num_train_steps = len(train_data) * epochs
    num_warmup_steps = round(num_train_steps * warmup_ratio)
    step_num = 0

    # do not apply weight decay on LayerNorm and bias terms
    for param in model.collect_params('.*beta|.*gamma|.*bias').values():
        param.wd_mult = .0
    # collect differentiable params
    params = [
        p for p in model.collect_params().values() if p.grad_req != 'null'
    ]

    for epoch_id in range(epochs):
        metric.reset()
        step_loss = 0
        start = time.time()
        for batch_id, seqs in enumerate(train_data):
            step_num += 1
            # learning rate schedule: linear warm-up, then linear decay
            trainer.set_learning_rate(_warmup_linear_decay_lr(
                step_num, num_train_steps, num_warmup_steps, lr))
            # forward and backward
            with mx.autograd.record():
                input_ids, valid_length, type_ids, label = seqs
                out = model(
                    input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
                    valid_length.astype('float16', copy=False).as_in_context(ctx))
                ls = loss_function(out, label.as_in_context(ctx)).mean()
            ls.backward()
            # update
            trainer.allreduce_grads()
            # NOTE(review): the recorded run of this notebook aborts with an
            # MXNet "Not implmented!" dot_engine error. A likely cause is the
            # norm computation below running on float16 gradients -- confirm,
            # and if so clip on float32 copies (or train in float32).
            nlp.utils.clip_grad_global_norm(params, grad_clip)
            # The loss is already a batch mean, hence a batch size of 1 here.
            trainer.update(1)
            step_loss += ls.asscalar()
            metric.update([label], [out])
            if (batch_id + 1) % (log_interval) == 0:
                metric_nm, metric_val = metric.get()
                if not isinstance(metric_nm, list):
                    metric_nm = [metric_nm]
                    metric_val = [metric_val]
                eval_str = '[Epoch %d Batch %d/%d] loss=%.4f, lr=%.7f, metrics=' + \
                    ','.join([i + ':%.4f' for i in metric_nm])
                logging.info(eval_str, epoch_id + 1, batch_id + 1, len(train_data),
                             step_loss / log_interval,
                             trainer.learning_rate, *metric_val)
                step_loss = 0
        # Wait for all queued work before validating.
        mx.nd.waitall()
        evaluate(dev_data, metric)
        # save params (create the target directory on first use)
        os.makedirs('model', exist_ok=True)
        token = str(round(time.time()))
        model.save_parameters('model/bert-'+token+'.params')
        logging.info('params saved as token %s', token)
        end = time.time()
        logging.info('Time cost=%.1fs', end - start)

In [24]:
# Composite metric holding a single child: classification accuracy.
metric = CompositeEvalMetric(metrics=[Accuracy()])
metric

<mxnet.metric.CompositeEvalMetric at 0x7f25082cfa20>

### Training

In [25]:
# Fine-tune on the training loader and validate on the held-out loader
# after every epoch.
train(
    metric=metric,
    train_data=dataloader_train,
    dev_data=dataloader_valid,
    grad_clip=grad_clip,
    epochs=epochs,
)

INFO:root:Training bert model on gpu(0)


MXNetError: [22:59:56] /home/travis/build/dmlc/mxnet-distro/mxnet-build/3rdparty/mshadow/mshadow/./././dot_engine-inl.h:571: Not implmented!

Stack trace returned 10 entries:
[bt] (0) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x40ba6a) [0x7f2541220a6a]
[bt] (1) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x40c081) [0x7f2541221081]
[bt] (2) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x3c10d22) [0x7f2544a25d22]
[bt] (3) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x415b393) [0x7f2544f70393]
[bt] (4) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x2e8) [0x7f2543b84e78]
[bt] (5) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2cc1689) [0x7f2543ad6689]
[bt] (6) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2ccafc4) [0x7f2543adffc4]
[bt] (7) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2ccf2b3) [0x7f2543ae42b3]
[bt] (8) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2ccf506) [0x7f2543ae4506]
[bt] (9) /home/steven/miniconda3/envs/dl/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x2ccb6f4) [0x7f2543ae06f4]

