In [1]:
from BaseEncoder import BaseEncoder
from DataLoader import read_examples_from_file, ReviewDataset

from transformers import AutoTokenizer, TFAutoModel
import torch
from torch import nn
from torch.utils.data import DataLoader

from SpanMltri import SpanMltri

In [2]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Using cuda device


# Data Loader

In [3]:
def decode_te_label(te_label_dict, tokens):
    # input : 
    #     te_label_dict = Dict; Key : span range (string), Value :  term type label (ASPECT, SENTIMENT, O)
    #     tokens = list of tokens from the sentence
    # output : List of pair (phrase, term_type_label)
    te_label_list = []
    
    for span_range in te_label_dict:
        start_idx, end_idx = span_range.split('-')
        start_idx, end_idx = int(start_idx), int(end_idx)
        sentence = ' '.join(tokens[start_idx:end_idx+1])
        te_label = te_label_dict[span_range]
        te_label_list.append((sentence, te_label))
    
    return te_label_list

def decode_relation_label(relation_dict, tokens):
    # input : 
    #     relation = Dict; Key : span range pair (aspect_term_span_range, opinion_term_span_range), Value :  sentiment polarity label (POSITIVE, NEGATIVE, NEUTRAL)
    #     tokens = list of tokens from the sentence
    # output : list of triples (aspect_term, opinion_term, polarity)
    relation_list = []
    
    for span_range_pair in relation_dict:
        aspect_term_span_range, opinion_term_span_range = span_range_pair
        
        aspect_term_start_idx, aspect_term_end_idx = aspect_term_span_range.split('-')
        aspect_term_start_idx, aspect_term_end_idx = int(aspect_term_start_idx), int(aspect_term_end_idx)
        aspect_term = ' '.join(tokens[aspect_term_start_idx:aspect_term_end_idx+1])
        
        opinion_term_start_idx, opinion_term_end_idx = opinion_term_span_range.split('-')
        opinion_term_start_idx, opinion_term_end_idx = int(opinion_term_start_idx), int(opinion_term_end_idx)
        opinion_term = ' '.join(tokens[opinion_term_start_idx:opinion_term_end_idx+1])
        
        relation_label = relation_dict[span_range_pair]
        
        relation_list.append((aspect_term, opinion_term, relation_label))
    
    return relation_list

In [4]:
TRAIN_FILE_PATH = "dataset/train.tsv"
DEV_FILE_PATH = "dataset/dev.tsv"

In [5]:
train_data = ReviewDataset(TRAIN_FILE_PATH)
dev_data = ReviewDataset(DEV_FILE_PATH)

Some weights of the model checkpoint at indolem/indobert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at indolem/indobert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predic

In [6]:
IDX = 393
print(train_data.texts[IDX])
print(decode_te_label(train_data.te_label_dict[IDX], train_data.texts[IDX]))
print(decode_relation_label(train_data.relation_dict[IDX], train_data.texts[IDX]))

['kain', 'kasur', 'selama', '4', 'malam', 'menginap', 'tidak', 'pernah', 'di', 'ganti', ',', 'dan', 'ac', 'kurang', 'dingin', '.']
[('kasur selama', 'ASPECT'), ('pernah di ganti ,', 'SENTIMENT'), ('kurang', 'ASPECT'), ('dingin .', 'SENTIMENT')]
[('kasur selama', 'pernah di ganti ,', 'NG'), ('kurang', 'dingin .', 'NG')]


In [7]:
IDX = 38
print(dev_data.texts[IDX])
print(decode_te_label(dev_data.te_label_dict[IDX], dev_data.texts[IDX]))
print(decode_relation_label(dev_data.relation_dict[IDX], dev_data.texts[IDX]))

['sebenarnya', 'kamar', 'baik', 'tetapi', 'kurang', 'profesional', ',', 'di', 'kasih', 'tahu', 'sarapan', 'nasi', 'goreng', ',', 'minta', 'di', 'antar', 'jam', '8.30', 'ehh', 'datang', 'jam', '9', 'itu', 'pun', 'di', 'telepon', 'dulu', 'dan', 'yang', 'datang', 'bahkan', 'mie', 'muah', 'telor', '.', 'hm', '.']
[('baik', 'ASPECT'), ('tetapi', 'SENTIMENT'), ('profesional ,', 'SENTIMENT')]
[('baik', 'tetapi', 'PO')]


# Train Method

In [8]:
model = SpanMltri().to(device)

Some weights of the model checkpoint at indolem/indobert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
BATCH_SIZE = 16

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=False)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

lambda_t = 0.5
lambda_r = 0.5

In [10]:
def train(train_dataloader, train_data, model, loss_fn, optimizer):
    size = len(train_dataloader.dataset)

    for batch, X in enumerate(train_dataloader):
        X = X.to(device)
        current_te_label_dict = train_data.te_label_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        current_relation_dict = train_data.relation_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        sentences = train_data.texts[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        logits_term_scorer, span_ranges, logits_relation_scorer, span_pair_ranges = model(X)

        y_te_true = []
        CURRENT_BATCH_SIZE = min(len(current_te_label_dict), BATCH_SIZE)
        for i in range(CURRENT_BATCH_SIZE):
            y_ = []
            for span_range in span_ranges:
                if span_range in current_te_label_dict[i]:
                    label = current_te_label_dict[i][span_range]
                    if label == 'ASPECT':
                        y_.append(1)
                    elif label == 'SENTIMENT':
                        y_.append(2)
                else: # label is O
                    y_.append(0)        
            y_te_true.append(torch.Tensor(y_))
        y_te_true = torch.stack(y_te_true)
        y_te_true = y_te_true.to(torch.long)

        logits_term_scorer = logits_term_scorer.reshape(logits_term_scorer.shape[0]*logits_term_scorer.shape[1], logits_term_scorer.shape[-1])
        y_te_true = y_te_true.reshape(-1).to(device)
        te_loss = loss_fn(logits_term_scorer, y_te_true)

        y_paote_true = []
        CURRENT_BATCH_SIZE = min(len(current_relation_dict), BATCH_SIZE)
        for i in range(CURRENT_BATCH_SIZE):
            y_ = []
            for span_pair_range in span_pair_ranges[i]:
                if span_pair_range not in current_relation_dict[i]:
                    y_.append(0)
                else:
                    label = current_relation_dict[i][span_pair_range]
                    if label == 'PO':
                        y_.append(1)
                    elif label == 'NG':
                        y_.append(2)
                    elif label == 'NT':
                        y_.append(3)
            y_paote_true.append(torch.Tensor(y_))
        y_paote_true = torch.stack(y_paote_true)
        y_paote_true = y_paote_true.to(torch.long)

        logits_relation_scorer = logits_relation_scorer.reshape(logits_relation_scorer.shape[0]*logits_relation_scorer.shape[1], logits_relation_scorer.shape[-1])
        y_paote_true = y_paote_true.reshape(-1).to(device)
        paote_loss = loss_fn(logits_relation_scorer, y_paote_true)

        total_loss = lambda_t*te_loss + lambda_r*paote_loss

        # Backpropagation
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if batch % 2 == 0:
            total_loss, current = total_loss.item(), batch * len(X)
            print(f"loss: {total_loss:>7f}  [{current:>5d}/{size:>5d}]")

In [11]:
def test(dev_dataloader, model):
    size = len(dev_dataloader.dataset)
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch, X in enumerate(dev_dataloader):
            X = X.to(device)
            
            current_te_label_dict = dev_data.te_label_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
            current_relation_dict = dev_data.relation_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
            
            logits_term_scorer, span_ranges, logits_relation_scorer, span_pair_ranges = model(X)

            y_te_true = []
            CURRENT_BATCH_SIZE = min(len(current_te_label_dict), BATCH_SIZE)
            for i in range(CURRENT_BATCH_SIZE):
                y_ = []
                for span_range in span_ranges:
                    if span_range in current_te_label_dict[i]:
                        label = current_te_label_dict[i][span_range]
                        if label == 'ASPECT':
                            y_.append(1)
                        elif label == 'SENTIMENT':
                            y_.append(2)
                    else: # label is O
                        y_.append(0)        
                y_te_true.append(torch.Tensor(y_))
            y_te_true = torch.stack(y_te_true)
            y_te_true = y_te_true.to(torch.long)

            logits_term_scorer = logits_term_scorer.reshape(logits_term_scorer.shape[0]*logits_term_scorer.shape[1], logits_term_scorer.shape[-1])
            y_te_true = y_te_true.reshape(-1).to(device)
            te_loss = loss_fn(logits_term_scorer, y_te_true)

            y_paote_true = []
            CURRENT_BATCH_SIZE = min(len(current_relation_dict), BATCH_SIZE)
            for i in range(CURRENT_BATCH_SIZE):
                y_ = []
                for span_pair_range in span_pair_ranges[i]:
                    if span_pair_range not in current_relation_dict[i]:
                        y_.append(0)
                    else:
                        label = current_relation_dict[i][span_pair_range]
                        if label == 'PO':
                            y_.append(1)
                        elif label == 'NG':
                            y_.append(2)
                        elif label == 'NT':
                            y_.append(3)
                y_paote_true.append(torch.Tensor(y_))
            y_paote_true = torch.stack(y_paote_true)
            y_paote_true = y_paote_true.to(torch.long)

            logits_relation_scorer = logits_relation_scorer.reshape(logits_relation_scorer.shape[0]*logits_relation_scorer.shape[1], logits_relation_scorer.shape[-1])
            y_paote_true = y_paote_true.reshape(-1).to(device)
            paote_loss = loss_fn(logits_relation_scorer, y_paote_true)

            total_loss += lambda_t*te_loss.item() + lambda_r*paote_loss.item()

    total_loss /= size
    print(f"Test Error: \n Avg loss: {total_loss:>8f} \n")

In [12]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, train_data, model, loss_fn, optimizer)
    test(dev_dataloader, model)
print("Done!")

Epoch 1
-------------------------------
loss: 16.687113  [    0/ 3000]
loss: 2.232355  [   32/ 3000]
loss: 0.308499  [   64/ 3000]
loss: 0.123945  [   96/ 3000]
loss: 0.058107  [  128/ 3000]
loss: 0.043705  [  160/ 3000]
loss: 0.040378  [  192/ 3000]
loss: 0.041156  [  224/ 3000]
loss: 0.034098  [  256/ 3000]
loss: 0.030528  [  288/ 3000]
loss: 0.038016  [  320/ 3000]
loss: 0.031810  [  352/ 3000]
loss: 0.035299  [  384/ 3000]
loss: 0.033669  [  416/ 3000]
loss: 0.023997  [  448/ 3000]
loss: 0.037645  [  480/ 3000]
loss: 0.027837  [  512/ 3000]
loss: 0.034670  [  544/ 3000]
loss: 0.033726  [  576/ 3000]
loss: 0.033039  [  608/ 3000]
loss: 0.028358  [  640/ 3000]
loss: 0.027258  [  672/ 3000]
loss: 0.031088  [  704/ 3000]
loss: 0.028634  [  736/ 3000]
loss: 0.033256  [  768/ 3000]
loss: 0.035022  [  800/ 3000]
loss: 0.032138  [  832/ 3000]
loss: 0.030711  [  864/ 3000]
loss: 0.032198  [  896/ 3000]
loss: 0.036880  [  928/ 3000]
loss: 0.034458  [  960/ 3000]
loss: 0.037017  [  992/ 3000]

KeyboardInterrupt: 

In [4]:
model = SpanMltri().to(device)

In [5]:
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

SpanMltri(
  (base_encoder): BaseEncoder(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(31923, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
           

In [92]:
with torch.no_grad():
    for batch, X in enumerate(dev_dataloader):
        X = X.to(device)
        
        sentences = train_data.texts[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        current_te_label_dict = train_data.te_label_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        current_relation_dict = train_data.relation_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        
        logits_term_scorer, span_ranges, logits_relation_scorer, span_pair_ranges = model(X)
        
        y_te_true = []
        CURRENT_BATCH_SIZE = min(len(current_te_label_dict), BATCH_SIZE)
        for i in range(CURRENT_BATCH_SIZE):
            y_ = []
            for span_range in span_ranges:
                if span_range in current_te_label_dict[i]:
                    label = current_te_label_dict[i][span_range]
                    if label == 'ASPECT':
                        y_.append(1)
                    elif label == 'SENTIMENT':
                        y_.append(2)
                else: # label is O
                    y_.append(0)        
            y_te_true.append(torch.Tensor(y_))
        y_te_true = torch.stack(y_te_true)
        break

In [80]:
IDX = 10

print(sentences[IDX])
print(current_relation_dict[IDX])
y_pred = logits_relation_scorer.argmax(-1)
print(y_pred[IDX])

['agak', 'perlu', 'dibersihkan', '.', 'kamar', 'banyak', 'semut', '.']
{('5-5', '1-3'): 'NG', ('5-5', '6-7'): 'NG'}
tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')


In [94]:
IDX = 6

print(sentences[IDX])
print(current_te_label_dict[IDX])
y_pred = logits_term_scorer.argmax(-1)
print(y_pred[IDX])
print(y_te_true[IDX])

['pelayanan', 'lumayan', 'baik', '.']
{'1-1': 'ASPECT', '2-3': 'SENTIMENT'}
tensor([0, 1, 2, 0, 2, 0, 2, 0, 2, 0, 0, 1, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [95]:
for idx, y in enumerate(y_pred[IDX]):
    if y != 0:
        span_range = span_ranges[idx]
        print(idx, ' ', y.item(), ' ', span_range)

1   1   1-1
2   2   2-2
4   2   4-4
6   2   6-6
8   2   8-8
11   1   11-11
12   2   12-12
17   2   17-17
56   2   16-17


In [91]:
sentences[10][3:5]

['.', 'kamar']