In [24]:
import warnings
warnings.filterwarnings('ignore')

import io
import random
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from sentence_embedding.bert import data, model

# Seed each random number generator used in this notebook (Python's
# `random`, NumPy and MXNet) so that repeated runs start from the same state.
random.seed(100)
np.random.seed(100)
mx.random.seed(10000)

# Device on which the model and all input batches are placed.
ctx = mx.cpu()

In [None]:
def model4(testfilename, params_saved, num_classes=2, max_len=128, batch_size=32):
    """Predict a class index for every row of a TSV file with a fine-tuned BERT classifier.

    The pre-trained ``bert_12_768_12`` encoder is wrapped in a
    ``BERTClassifier``, the weights stored at ``params_saved`` are loaded into
    it, and column 0 of every row of ``testfilename`` is classified.

    Parameters
    ----------
    testfilename : str
        Path to a tab-separated file. Only field 0 of each row is read and no
        leading rows are discarded.
    params_saved : str
        Location of the saved classifier parameters, passed unchanged to
        ``nlp.utils.load_parameters``.
    num_classes : int, optional
        Number of output classes of the classifier head (default 2).
    max_len : int, optional
        Maximum token sequence length; every sample is padded to this length
        (default 128).
    batch_size : int, optional
        Number of samples per inference batch (default 32).

    Returns
    -------
    list of int
        Predicted class index for each row, in the same order as the rows of
        ``testfilename``. The list is empty when the file has no rows.
    """
    # Imported under an alias so the module name is not rebound to the
    # classifier instance built below.
    from sentence_embedding.bert import model as bert_model

    bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',
                                                dataset_name='book_corpus_wiki_en_uncased',
                                                pretrained=True, ctx=ctx, use_pooler=True,
                                                use_decoder=False, use_classifier=False)

    classifier = bert_model.classification.BERTClassifier(bert_base,
                                                          num_classes=num_classes,
                                                          dropout=0.1)
    # The encoder is already initialized from the pre-trained weights; only the
    # classification head needs initializing before the saved parameters load.
    classifier.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
    classifier.hybridize(static_alloc=True)

    nlp.utils.load_parameters(classifier, params_saved)

    # Read the single text column (field 0) of the tab-separated input file.
    data_test_raw = nlp.data.TSVDataset(testfilename,
                                        field_separator=nlp.data.Splitter('\t'),
                                        num_discard_samples=0,
                                        field_indices=[0])
    if len(data_test_raw) == 0:
        # Nothing to classify; also avoids indexing an empty dataset below.
        return []

    # Show the first sentence as a quick sanity check of the parsed input.
    print(data_test_raw[0][0])

    # Use the vocabulary from the pre-trained model for tokenization.
    bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)

    # Unlabelled, single-sentence samples, each padded to `max_len` tokens.
    transform = data.transform.BERTDatasetTransform(bert_tokenizer, max_len,
                                                    has_label=False,
                                                    pad=True,
                                                    pair=False)
    print('vocabulary used for tokenization = \n%s'%vocabulary)
    print('%s token id = %s'%(vocabulary.padding_token, vocabulary[vocabulary.padding_token]))
    print('%s token id = %s'%(vocabulary.cls_token, vocabulary[vocabulary.cls_token]))
    print('%s token id = %s'%(vocabulary.sep_token, vocabulary[vocabulary.sep_token]))

    data_test = data_test_raw.transform(transform)

    # Iterate sequentially and without shuffling so that results[i] is the
    # prediction for row i of the file. A shuffled, length-bucketed sampler
    # would emit the rows in a different order on every call, which breaks the
    # row alignment between the outputs of different models. Bucketing is not
    # needed here because pad=True already gives every sample the same length.
    bert_dataloader = mx.gluon.data.DataLoader(data_test,
                                               batch_size=batch_size,
                                               shuffle=False)

    results = []
    for input_ids, valid_length, type_ids in bert_dataloader:
        out = classifier(input_ids.as_in_context(ctx),
                         type_ids.as_in_context(ctx),
                         valid_length.astype('float32').as_in_context(ctx))

        # Index of the highest-scoring class for every sample in the batch.
        top1 = mx.nd.topk(out, k=1, ret_typ='indices', dtype='int32').asnumpy()
        results.extend(top1.ravel().tolist())

    return results


# Input file to classify and the saved parameter sets of the four models.
testfilename = "C:/Users/nwang/Desktop/nlp/bert/data/q2_balance_113_5/4classes/dev.tsv"
params_saved = [
    'C:/Users/nwang/Desktop/nlp/bert/data/q2_balance_113_5/4classes/model_1',
    'C:/Users/nwang/Desktop/nlp/bert/data/q2_balance_113_5/4classes/model_2',
    'C:/Users/nwang/Desktop/nlp/bert/data/q2_balance_113_5/4classes/model_3',
    'C:/Users/nwang/Desktop/nlp/bert/data/q2_balance_113_5/4classes/model_4',
]

# Run every model over the same file and place the predictions side by side:
# one row per input sample, one column per parameter set (in list order).
one = np.column_stack([model4(testfilename, pa) for pa in params_saved])

one
