In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
from transformers.modeling_bert import BertConfig, BertModel

from salt.data.dataset.bert_dataset import BERTDataSet
from salt.data.tokenizer import SentencepieceTokenizer

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

In [3]:
##GPU 사용 시
device = torch.device("cuda:7")

In [4]:
# !wget https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1
# !wget https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1

In [5]:
# bertmodel = BertModel(BertConfig.from_dict(bert_params))

In [6]:
from model.modeling_electra import ElectraForPretrain
from salt.data.tokenizer import SentencepieceTokenizer

03/30/2020 11:57:47 - INFO - model.file_utils -   PyTorch version 1.4.0 available.


In [7]:
pretrain_model = ElectraForPretrain('./model_config/electra_small_config.json')
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_55000_loss10.827.pth'
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_122000_loss7.173.pth'
# binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_159000_loss10.434.pth'
binary_path = '/home/dmig/work/test_container/binary/electra/token_22000/small/models/electra_step_201000_loss10.286.pth'
checkpoint = torch.load(binary_path, map_location={'cuda:0': 'cpu'})
state_dict = checkpoint['model_state_dict']
state_dict = {k.replace('module.', ''): v for k,v in state_dict.items()}
pretrain_model.load_state_dict(state_dict)
bertmodel = pretrain_model.dis_model.eval()

03/30/2020 11:58:17 - INFO - model.configuration_utils -   Model config ElectraConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "embedding_size": 256,
  "eos_token_id": null,
  "finetuning_task": null,
  "generator_ratio": 4,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 64,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "model_type": "",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 1,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hid

In [8]:
tokenizer_path = '/home/dmig/work/test_container/tokenizer/tokenizer_22000.model'

In [9]:
tokenizer = SentencepieceTokenizer(tokenizer_path)

In [10]:
import pandas as pd

In [11]:
trainset = pd.read_csv('./ratings_train.txt?dl=1', sep='\t')
testset = pd.read_csv('./ratings_test.txt?dl=1', sep='\t')

In [12]:
trainset.shape

(150000, 3)

In [13]:
testset.shape

(50000, 3)

In [14]:
trainset = trainset.dropna(axis=0)
testset = testset.dropna(axis=0)

In [15]:
traindata = dict()
traindata['text'] =list()
traindata['label'] = trainset['label'].tolist()

In [16]:
testdata = dict()
testdata['text'] =list()
testdata['label'] = testset['label'].tolist()

In [17]:
from salt.preprocessor.normalizer import normalize

In [18]:
for i in trainset['document'].tolist():
    traindata['text'].append([normalize(i)])

In [19]:
for i in testset['document'].tolist():
    testdata['text'].append([normalize(i)])

In [20]:
## Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [21]:
train_data = BERTDataSet(data = traindata, text_idx = 'text', label_idx = 'label', tokenizer = tokenizer, max_seq_len=max_len, pad=True, pair=False)
test_data = BERTDataSet(data = testdata, text_idx = 'text', label_idx = 'label', tokenizer = tokenizer, max_seq_len=max_len, pad=True, pair=False)

[generating BERT dataset]: 100%|██████████| 149995/149995 [00:09<00:00, 15908.48it/s]
[generating BERT dataset]: 100%|██████████| 49997/49997 [00:03<00:00, 15992.54it/s]


In [22]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)



In [23]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=5)

In [24]:
model = BERTClassifier(bertmodel, hidden_size=256, dr_rate=0.5).to(device)

In [25]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [26]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [27]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [28]:
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [29]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [30]:
import math
def gen_attention_mask(token_ids, valid_length):
    attention_mask = torch.zeros_like(token_ids)
    for i, v in enumerate(valid_length):
        attention_mask[i][:v] = 1
    return attention_mask.float()


In [32]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 1 batch id 1 loss 0.30781763792037964 train acc 0.890625
epoch 1 batch id 201 loss 0.14335723221302032 train acc 0.9384328358208955
epoch 1 batch id 401 loss 0.13531407713890076 train acc 0.9378117206982544
epoch 1 batch id 601 loss 0.2521513104438782 train acc 0.9382539517470881
epoch 1 batch id 801 loss 0.20836767554283142 train acc 0.9377730961298377
epoch 1 batch id 1001 loss 0.21005457639694214 train acc 0.9370785464535465
epoch 1 batch id 1201 loss 0.12816943228244781 train acc 0.9368495004163198
epoch 1 batch id 1401 loss 0.23530995845794678 train acc 0.9363735724482513
epoch 1 batch id 1601 loss 0.22132864594459534 train acc 0.9360360712054966
epoch 1 batch id 1801 loss 0.14328327775001526 train acc 0.9359296918378679
epoch 1 batch id 2001 loss 0.247671976685524 train acc 0.9359070464767616
epoch 1 batch id 2201 loss 0.09174485504627228 train acc 0.9358388232621536

epoch 1 train acc 0.9359102470930232


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 1 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 2 batch id 1 loss 0.3279961347579956 train acc 0.875
epoch 2 batch id 201 loss 0.19918538630008698 train acc 0.9359452736318408
epoch 2 batch id 401 loss 0.09562499821186066 train acc 0.9363310473815462
epoch 2 batch id 601 loss 0.2714020311832428 train acc 0.9372920133111481
epoch 2 batch id 801 loss 0.19827067852020264 train acc 0.9379876716604245
epoch 2 batch id 1001 loss 0.26120755076408386 train acc 0.9372502497502497
epoch 2 batch id 1201 loss 0.15501363575458527 train acc 0.9366673605328892
epoch 2 batch id 1401 loss 0.23144735395908356 train acc 0.9362397394718058
epoch 2 batch id 1601 loss 0.24361149966716766 train acc 0.935957995003123
epoch 2 batch id 1801 loss 0.15718580782413483 train acc 0.9357127984453082
epoch 2 batch id 2001 loss 0.2596357762813568 train acc 0.9356805972013993
epoch 2 batch id 2201 loss 0.05511324480175972 train acc 0.9353773852794185

epoch 2 train acc 0.9358003363362172


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 2 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 3 batch id 1 loss 0.4099917709827423 train acc 0.84375
epoch 3 batch id 201 loss 0.23668067157268524 train acc 0.9331467661691543
epoch 3 batch id 401 loss 0.14883995056152344 train acc 0.9349283042394015
epoch 3 batch id 601 loss 0.2642366588115692 train acc 0.9357841098169717
epoch 3 batch id 801 loss 0.23070818185806274 train acc 0.9358224094881398
epoch 3 batch id 1001 loss 0.23921459913253784 train acc 0.9358922327672328
epoch 3 batch id 1201 loss 0.13604795932769775 train acc 0.9351842214820982
epoch 3 batch id 1401 loss 0.23888327181339264 train acc 0.9346672019985724
epoch 3 batch id 1601 loss 0.258232057094574 train acc 0.9348063710181137
epoch 3 batch id 1801 loss 0.11772524565458298 train acc 0.9350274153248196
epoch 3 batch id 2001 loss 0.26564034819602966 train acc 0.935297976011994
epoch 3 batch id 2201 loss 0.1357443630695343 train acc 0.9352638005452067

epoch 3 train acc 0.9354603725990158


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 3 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 4 batch id 1 loss 0.3544045388698578 train acc 0.890625
epoch 4 batch id 201 loss 0.19850526750087738 train acc 0.9361784825870647
epoch 4 batch id 401 loss 0.1742056906223297 train acc 0.9366427680798005
epoch 4 batch id 601 loss 0.22665847837924957 train acc 0.9365900582362728
epoch 4 batch id 801 loss 0.283020555973053 train acc 0.9363490948813983
epoch 4 batch id 1001 loss 0.21246609091758728 train acc 0.9360327172827173
epoch 4 batch id 1201 loss 0.13960054516792297 train acc 0.9358087010824313
epoch 4 batch id 1401 loss 0.27803295850753784 train acc 0.9354144361170592
epoch 4 batch id 1601 loss 0.19567273557186127 train acc 0.9354212211118051
epoch 4 batch id 1801 loss 0.17061394453048706 train acc 0.9354091476957246
epoch 4 batch id 2001 loss 0.19127364456653595 train acc 0.9355868940529735
epoch 4 batch id 2201 loss 0.08103649318218231 train acc 0.9357252385279419

epoch 4 train acc 0.9357670065580602


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 4 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 5 batch id 1 loss 0.29936718940734863 train acc 0.859375
epoch 5 batch id 201 loss 0.18927431106567383 train acc 0.9340796019900498
epoch 5 batch id 401 loss 0.18593120574951172 train acc 0.9348114089775561
epoch 5 batch id 601 loss 0.2827187776565552 train acc 0.9356801164725458
epoch 5 batch id 801 loss 0.14695227146148682 train acc 0.9360759987515606
epoch 5 batch id 1001 loss 0.22952759265899658 train acc 0.9350961538461539
epoch 5 batch id 1201 loss 0.11875797808170319 train acc 0.934533721898418
epoch 5 batch id 1401 loss 0.22969718277454376 train acc 0.9345222162740899
epoch 5 batch id 1601 loss 0.23069745302200317 train acc 0.9342207995003123
epoch 5 batch id 1801 loss 0.10851491987705231 train acc 0.9344201138256524
epoch 5 batch id 2001 loss 0.3451370298862457 train acc 0.9347826086956522
epoch 5 batch id 2201 loss 0.11281037330627441 train acc 0.9347881644706951

epoch 5 train acc 0.934947249037622


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 5 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 6 batch id 1 loss 0.3720177412033081 train acc 0.84375
epoch 6 batch id 201 loss 0.19181612133979797 train acc 0.9371113184079602
epoch 6 batch id 401 loss 0.12695062160491943 train acc 0.9366427680798005
epoch 6 batch id 601 loss 0.3218015432357788 train acc 0.9371880199667221
epoch 6 batch id 801 loss 0.1945500522851944 train acc 0.9371683832709113
epoch 6 batch id 1001 loss 0.3145107626914978 train acc 0.9363917332667333
epoch 6 batch id 1201 loss 0.11246959865093231 train acc 0.9358737510407993
epoch 6 batch id 1401 loss 0.2476348727941513 train acc 0.9358493932905068
epoch 6 batch id 1601 loss 0.26063552498817444 train acc 0.9355578544659587
epoch 6 batch id 1801 loss 0.11427875608205795 train acc 0.9355132565241533
epoch 6 batch id 2001 loss 0.2355523556470871 train acc 0.9357821089455273
epoch 6 batch id 2201 loss 0.1739010065793991 train acc 0.9356471490231713

epoch 6 train acc 0.935723600335344


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 6 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 7 batch id 1 loss 0.4277228116989136 train acc 0.828125
epoch 7 batch id 201 loss 0.2132609635591507 train acc 0.9346237562189055
epoch 7 batch id 401 loss 0.18565106391906738 train acc 0.9363310473815462
epoch 7 batch id 601 loss 0.1921117752790451 train acc 0.9372660149750416
epoch 7 batch id 801 loss 0.2286350429058075 train acc 0.9365246566791511
epoch 7 batch id 1001 loss 0.30213844776153564 train acc 0.935939060939061
epoch 7 batch id 1201 loss 0.10168814659118652 train acc 0.935717631140716
epoch 7 batch id 1401 loss 0.274747759103775 train acc 0.9353475196288366
epoch 7 batch id 1601 loss 0.2221226841211319 train acc 0.9352455496564647
epoch 7 batch id 1801 loss 0.1927114725112915 train acc 0.935305038867296
epoch 7 batch id 2001 loss 0.20698638260364532 train acc 0.9355478510744628
epoch 7 batch id 2201 loss 0.11199753731489182 train acc 0.9356542480690595

epoch 7 train acc 0.9360271338499088


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 7 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 8 batch id 1 loss 0.31938332319259644 train acc 0.875
epoch 8 batch id 201 loss 0.22547514736652374 train acc 0.9366449004975125
epoch 8 batch id 401 loss 0.1103774681687355 train acc 0.9370324189526185
epoch 8 batch id 601 loss 0.2328304648399353 train acc 0.937603993344426
epoch 8 batch id 801 loss 0.18136587738990784 train acc 0.9374219725343321
epoch 8 batch id 1001 loss 0.16291822493076324 train acc 0.9369536713286714
epoch 8 batch id 1201 loss 0.13058659434318542 train acc 0.9363681307243963
epoch 8 batch id 1401 loss 0.19932805001735687 train acc 0.9359720735189151
epoch 8 batch id 1601 loss 0.20019502937793732 train acc 0.9356066520924422
epoch 8 batch id 1801 loss 0.17374001443386078 train acc 0.9354959050527485
epoch 8 batch id 2001 loss 0.27287405729293823 train acc 0.935508808095952
epoch 8 batch id 2201 loss 0.11554523557424545 train acc 0.9353986824170831

epoch 8 train acc 0.9357369322466068


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 8 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 9 batch id 1 loss 0.33788183331489563 train acc 0.859375
epoch 9 batch id 201 loss 0.17800867557525635 train acc 0.9361784825870647
epoch 9 batch id 401 loss 0.11398687213659286 train acc 0.9370324189526185
epoch 9 batch id 601 loss 0.2912108898162842 train acc 0.9371100249584027
epoch 9 batch id 801 loss 0.19433587789535522 train acc 0.9374804931335831
epoch 9 batch id 1001 loss 0.22225099802017212 train acc 0.9365478271728271
epoch 9 batch id 1201 loss 0.15657417476177216 train acc 0.9356786011656952
epoch 9 batch id 1401 loss 0.27760007977485657 train acc 0.9352471448965025
epoch 9 batch id 1601 loss 0.26320281624794006 train acc 0.9351284353529045
epoch 9 batch id 1801 loss 0.19801293313503265 train acc 0.935105496946141
epoch 9 batch id 2001 loss 0.26209381222724915 train acc 0.9351808470764618
epoch 9 batch id 2201 loss 0.12161140143871307 train acc 0.9349514425261245

epoch 9 train acc 0.9351438172176362


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 9 test acc 0.8752997122762148


HBox(children=(FloatProgress(value=0.0, max=2344.0), HTML(value='')))

epoch 10 batch id 1 loss 0.3170332908630371 train acc 0.875
epoch 10 batch id 201 loss 0.14119859039783478 train acc 0.9334577114427861
epoch 10 batch id 401 loss 0.17781388759613037 train acc 0.9358245012468828
epoch 10 batch id 601 loss 0.27473291754722595 train acc 0.9360700915141431
epoch 10 batch id 801 loss 0.12813779711723328 train acc 0.936583177278402
epoch 10 batch id 1001 loss 0.264318585395813 train acc 0.9361263736263736
epoch 10 batch id 1201 loss 0.08390761911869049 train acc 0.9355745212323064
epoch 10 batch id 1401 loss 0.20121996104717255 train acc 0.9352582976445396
epoch 10 batch id 1601 loss 0.28321683406829834 train acc 0.9351089163023111
epoch 10 batch id 1801 loss 0.2013881653547287 train acc 0.9352009300388673
epoch 10 batch id 2001 loss 0.2241167277097702 train acc 0.935454147926037
epoch 10 batch id 2201 loss 0.08333182334899902 train acc 0.9355974557019536

epoch 10 train acc 0.9356104341118342


HBox(children=(FloatProgress(value=0.0, max=782.0), HTML(value='')))


epoch 10 test acc 0.8752997122762148
