In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
from transformers.modeling_bert import BertConfig, BertModel

from salt.data.dataset.bert_dataset import BERTDataSet
from salt.data.tokenizer import SentencepieceTokenizer

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

In [3]:
##GPU 사용 시
device = torch.device("cuda:7")

In [12]:
# !wget https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1
# !wget https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1

In [5]:
# bertmodel = BertModel(BertConfig.from_dict(bert_params))

In [7]:
from model.modeling_rezero import ElectraForPretrain
from salt.data.tokenizer import SentencepieceTokenizer

In [8]:
pretrain_model = ElectraForPretrain('./model_config/electra_small_config.json')
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_55000_loss10.827.pth'
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_122000_loss7.173.pth'
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_159000_loss10.434.pth'
binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_201000_loss10.286.pth'
checkpoint = torch.load(binary_path, map_location={'cuda:0': 'cpu'})
state_dict = checkpoint['model_state_dict']
state_dict = {k.replace('module.', ''): v for k,v in state_dict.items()}
# pretrain_model.load_state_dict(state_dict)
bertmodel = pretrain_model.dis_model.eval()

03/30/2020 17:00:33 - INFO - model.configuration_utils -   Model config ElectraConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "embedding_size": 256,
  "eos_token_id": null,
  "finetuning_task": null,
  "generator_ratio": 4,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 64,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "model_type": "",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 1,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hid

In [13]:
tokenizer_path = '/home/dmig/work/test_container/tokenizer/tokenizer_22000.model'

In [14]:
tokenizer = SentencepieceTokenizer(tokenizer_path)

In [15]:
import pandas as pd

In [16]:
trainset = pd.read_csv('./ratings_train.txt?dl=1', sep='\t')
testset = pd.read_csv('./ratings_test.txt?dl=1', sep='\t')

In [17]:
trainset.shape

(150000, 3)

In [18]:
testset.shape

(50000, 3)

In [19]:
trainset = trainset.dropna(axis=0)
testset = testset.dropna(axis=0)

In [20]:
traindata = dict()
traindata['text'] =list()
traindata['label'] = trainset['label'].tolist()

In [21]:
testdata = dict()
testdata['text'] =list()
testdata['label'] = testset['label'].tolist()

In [22]:
from salt.preprocessor.normalizer import normalize

In [23]:
for i in trainset['document'].tolist():
    traindata['text'].append([normalize(i)])

In [24]:
for i in testset['document'].tolist():
    testdata['text'].append([normalize(i)])

In [25]:
## Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [26]:
train_data = BERTDataSet(data = traindata, text_idx = 'text', label_idx = 'label', tokenizer = tokenizer, max_seq_len=max_len, pad=True, pair=False)
test_data = BERTDataSet(data = testdata, text_idx = 'text', label_idx = 'label', tokenizer = tokenizer, max_seq_len=max_len, pad=True, pair=False)

[generating BERT dataset]: 100%|██████████| 149995/149995 [00:09<00:00, 15729.02it/s]
[generating BERT dataset]: 100%|██████████| 49997/49997 [00:03<00:00, 16093.12it/s]


In [27]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)



In [28]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=5)

In [29]:
model = BERTClassifier(bertmodel, hidden_size=256, dr_rate=0.5).to(device)

In [30]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [31]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [32]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [33]:
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [34]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [35]:
import math
def gen_attention_mask(token_ids, valid_length):
    attention_mask = torch.zeros_like(token_ids)
    for i, v in enumerate(valid_length):
        attention_mask[i][:v] = 1
    return attention_mask.float()


In [36]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 1 batch id 1 loss 0.7257020473480225 train acc 0.46875
epoch 1 batch id 201 loss 0.7310513257980347 train acc 0.5002332089552238
epoch 1 batch id 401 loss 0.6948056221008301 train acc 0.49551901496259354
epoch 1 batch id 601 loss 0.7033348679542542 train acc 0.4964642262895175
epoch 1 batch id 801 loss 0.7151772379875183 train acc 0.4968984082397004
epoch 1 batch id 1001 loss 0.6797340512275696 train acc 0.49740884115884115
epoch 1 batch id 1201 loss 0.658115804195404 train acc 0.4977362614487927
epoch 1 batch id 1401 loss 0.690140426158905 train acc 0.49852783725910066
epoch 1 batch id 1601 loss 0.7027692198753357 train acc 0.4981066520924422
epoch 1 batch id 1801 loss 0.6577820777893066 train acc 0.5007721404775125
epoch 1 batch id 2001 loss 0.5384700298309326 train acc 0.5226215017491255
epoch 1 batch id 2201 loss 0.3780318796634674 train acc 0.546690424806906

epoch 1 train acc 0.5630171541392174


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 1 test acc 0.8125630164273067


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 2 batch id 1 loss 0.7100959420204163 train acc 0.65625
epoch 2 batch id 201 loss 0.3885534703731537 train acc 0.8194962686567164
epoch 2 batch id 401 loss 0.3700013756752014 train acc 0.8182278678304239
epoch 2 batch id 601 loss 0.4059150815010071 train acc 0.8209494592346089
epoch 2 batch id 801 loss 0.438575804233551 train acc 0.8234433520599251
epoch 2 batch id 1001 loss 0.28438618779182434 train acc 0.8247065434565435
epoch 2 batch id 1201 loss 0.3687033951282501 train acc 0.8257311615320566
epoch 2 batch id 1401 loss 0.43567243218421936 train acc 0.8264297822983583
epoch 2 batch id 1601 loss 0.4300857186317444 train acc 0.8272076046221112
epoch 2 batch id 1801 loss 0.46287357807159424 train acc 0.8285761382565242
epoch 2 batch id 2001 loss 0.390890508890152 train acc 0.8304285357321339
epoch 2 batch id 2201 loss 0.2778278589248657 train acc 0.8326257950931395

epoch 2 train acc 0.8345213719838876


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 2 test acc 0.8388762172929373


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 3 batch id 1 loss 0.6023924350738525 train acc 0.734375
epoch 3 batch id 201 loss 0.33359652757644653 train acc 0.8558768656716418
epoch 3 batch id 401 loss 0.2919536828994751 train acc 0.8559850374064838
epoch 3 batch id 601 loss 0.3527991771697998 train acc 0.8555272462562395
epoch 3 batch id 801 loss 0.3955630362033844 train acc 0.8561953807740325
epoch 3 batch id 1001 loss 0.2924390435218811 train acc 0.8550668081918081
epoch 3 batch id 1201 loss 0.34994643926620483 train acc 0.8540929433805162
epoch 3 batch id 1401 loss 0.4000101387500763 train acc 0.8533302105638829
epoch 3 batch id 1601 loss 0.3715963065624237 train acc 0.8526116489693941
epoch 3 batch id 1801 loss 0.4654548168182373 train acc 0.8528334952803998
epoch 3 batch id 2001 loss 0.36421361565589905 train acc 0.8533858070964517
epoch 3 batch id 2201 loss 0.27078044414520264 train acc 0.8546470354384371

epoch 3 train acc 0.8556591172910548


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 3 test acc 0.8412954333071021


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 4 batch id 1 loss 0.5827285647392273 train acc 0.765625
epoch 4 batch id 201 loss 0.3150353729724884 train acc 0.8644278606965174
epoch 4 batch id 401 loss 0.27514544129371643 train acc 0.8651807980049875
epoch 4 batch id 601 loss 0.3201715350151062 train acc 0.8650686356073212
epoch 4 batch id 801 loss 0.373264878988266 train acc 0.8654611423220974
epoch 4 batch id 1001 loss 0.26721230149269104 train acc 0.8644948801198801
epoch 4 batch id 1201 loss 0.3310320973396301 train acc 0.8639935470441299
epoch 4 batch id 1401 loss 0.3817404806613922 train acc 0.8631334760885082
epoch 4 batch id 1601 loss 0.31175100803375244 train acc 0.8627908338538414
epoch 4 batch id 1801 loss 0.4340623617172241 train acc 0.8628973486951693
epoch 4 batch id 2001 loss 0.40219196677207947 train acc 0.8632949150424788
epoch 4 batch id 2201 loss 0.2521302402019501 train acc 0.8642236483416629

epoch 4 train acc 0.8653148315342488


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 4 test acc 0.8406160854810151


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 5 batch id 1 loss 0.5252546072006226 train acc 0.78125
epoch 5 batch id 201 loss 0.2842329740524292 train acc 0.8732120646766169
epoch 5 batch id 401 loss 0.2991602122783661 train acc 0.8747272443890274
epoch 5 batch id 601 loss 0.34600937366485596 train acc 0.8742200499168054
epoch 5 batch id 801 loss 0.31162580847740173 train acc 0.8743367665418227
epoch 5 batch id 1001 loss 0.2544440031051636 train acc 0.8736263736263736
epoch 5 batch id 1201 loss 0.33059778809547424 train acc 0.8731916111573689
epoch 5 batch id 1401 loss 0.3210889995098114 train acc 0.8720556745182013
epoch 5 batch id 1601 loss 0.2686362564563751 train acc 0.8713499375390381
epoch 5 batch id 1801 loss 0.416944682598114 train acc 0.8715470571904498
epoch 5 batch id 2001 loss 0.3444525897502899 train acc 0.8717750499750125
epoch 5 batch id 2201 loss 0.21955034136772156 train acc 0.8727070081781009

epoch 5 train acc 0.8736472760734979


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 5 test acc 0.8405161813889435


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 6 batch id 1 loss 0.4989190697669983 train acc 0.8125
epoch 6 batch id 201 loss 0.2507878541946411 train acc 0.8795864427860697
epoch 6 batch id 401 loss 0.25159454345703125 train acc 0.8815461346633416
epoch 6 batch id 601 loss 0.29091745615005493 train acc 0.8807456322795341
epoch 6 batch id 801 loss 0.29379281401634216 train acc 0.8816518414481898
epoch 6 batch id 1001 loss 0.20558911561965942 train acc 0.8808847402597403
epoch 6 batch id 1201 loss 0.3273458480834961 train acc 0.8800869067443797
epoch 6 batch id 1401 loss 0.29215070605278015 train acc 0.8793941827266238
epoch 6 batch id 1601 loss 0.2512838542461395 train acc 0.878864772017489
epoch 6 batch id 1801 loss 0.3709067106246948 train acc 0.8791036229872293
epoch 6 batch id 2001 loss 0.33287957310676575 train acc 0.8795680284857571
epoch 6 batch id 2201 loss 0.1947292685508728 train acc 0.8802887891867333

epoch 6 train acc 0.8810430763354233


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 6 test acc 0.8384181954554396


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 7 batch id 1 loss 0.5049951672554016 train acc 0.8125
epoch 7 batch id 201 loss 0.26472243666648865 train acc 0.8889148009950248
epoch 7 batch id 401 loss 0.2126031517982483 train acc 0.8903132793017456
epoch 7 batch id 601 loss 0.27130070328712463 train acc 0.8891170965058236
epoch 7 batch id 801 loss 0.2773917317390442 train acc 0.8901178214731585
epoch 7 batch id 1001 loss 0.22347691655158997 train acc 0.8892357642357642
epoch 7 batch id 1201 loss 0.3346298933029175 train acc 0.8877758118234804
epoch 7 batch id 1401 loss 0.2848488688468933 train acc 0.8871564953604568
epoch 7 batch id 1601 loss 0.22445005178451538 train acc 0.8867992660836976
epoch 7 batch id 1801 loss 0.3938150405883789 train acc 0.8870939755691283
epoch 7 batch id 2001 loss 0.3146064877510071 train acc 0.8872829210394803
epoch 7 batch id 2201 loss 0.20497435331344604 train acc 0.8880125511131304

epoch 7 train acc 0.8887456655786173


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 7 test acc 0.8376589243556954


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 8 batch id 1 loss 0.5074476599693298 train acc 0.8125
epoch 8 batch id 201 loss 0.2426416277885437 train acc 0.8929570895522388
epoch 8 batch id 401 loss 0.16904719173908234 train acc 0.8954566708229427
epoch 8 batch id 601 loss 0.2585350275039673 train acc 0.8938487936772047
epoch 8 batch id 801 loss 0.21219070255756378 train acc 0.8941362359550562
epoch 8 batch id 1001 loss 0.16813042759895325 train acc 0.893419080919081
epoch 8 batch id 1201 loss 0.32563599944114685 train acc 0.8925895087427144
epoch 8 batch id 1401 loss 0.2853933572769165 train acc 0.8923202177016417
epoch 8 batch id 1601 loss 0.16870994865894318 train acc 0.8915911930043723
epoch 8 batch id 1801 loss 0.3742367923259735 train acc 0.8917268184342032
epoch 8 batch id 2001 loss 0.3720835745334625 train acc 0.8917182033983009
epoch 8 batch id 2201 loss 0.21287745237350464 train acc 0.8924068605179464

epoch 8 train acc 0.8931351198507818


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 8 test acc 0.8366798642533936


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 9 batch id 1 loss 0.516160249710083 train acc 0.8125
epoch 9 batch id 201 loss 0.22874940931797028 train acc 0.9005752487562189
epoch 9 batch id 401 loss 0.16607902944087982 train acc 0.90169108478803
epoch 9 batch id 601 loss 0.21741262078285217 train acc 0.9004783693843594
epoch 9 batch id 801 loss 0.1916283220052719 train acc 0.9011001872659176
epoch 9 batch id 1001 loss 0.14797671139240265 train acc 0.8995535714285714
epoch 9 batch id 1201 loss 0.34800055623054504 train acc 0.8981317651956703
epoch 9 batch id 1401 loss 0.2961972653865814 train acc 0.8974616345467523
epoch 9 batch id 1601 loss 0.18215464055538177 train acc 0.897007729544035
epoch 9 batch id 1801 loss 0.33541637659072876 train acc 0.897296640755136
epoch 9 batch id 2001 loss 0.3324943482875824 train acc 0.8974184782608695
epoch 9 batch id 2201 loss 0.19261199235916138 train acc 0.8980577010449795

epoch 9 train acc 0.8987878502262083


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 9 test acc 0.8358806315168207


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 10 batch id 1 loss 0.4459599256515503 train acc 0.8125
epoch 10 batch id 201 loss 0.18447867035865784 train acc 0.9044620646766169
epoch 10 batch id 401 loss 0.1373143196105957 train acc 0.9058213840399002
epoch 10 batch id 601 loss 0.22898145020008087 train acc 0.9043781198003328
epoch 10 batch id 801 loss 0.18592312932014465 train acc 0.9045138888888888
epoch 10 batch id 1001 loss 0.17885860800743103 train acc 0.9032842157842158
epoch 10 batch id 1201 loss 0.32378682494163513 train acc 0.9021518526228143
epoch 10 batch id 1401 loss 0.3038676381111145 train acc 0.9014431655960029
epoch 10 batch id 1601 loss 0.17385226488113403 train acc 0.9010188944409744
epoch 10 batch id 1801 loss 0.3726145625114441 train acc 0.901079261521377
epoch 10 batch id 2001 loss 0.2546845078468323 train acc 0.9011041354322838
epoch 10 batch id 2201 loss 0.22223028540611267 train acc 0.9015291344843253

epoch 10 train acc 0.9021374154198746


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 10 test acc 0.8359205931536494
