In [1]:
from BaseEncoder import BaseEncoder
from DataLoader import read_examples_from_file, ReviewDataset

from transformers import AutoTokenizer, TFAutoModel
import torch
from torch import nn
from torch.utils.data import DataLoader

from SpanMltri import SpanMltri

Using cuda device


In [2]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Using cuda device


# Data Loader

In [3]:
def decode_te_label(te_label_dict, tokens):
    # input : 
    #     te_label_dict = Dict; Key : span range (string), Value :  term type label (ASPECT, SENTIMENT, O)
    #     tokens = list of tokens from the sentence
    # output : List of pair (phrase, term_type_label)
    te_label_list = []
    
    for span_range in te_label_dict:
        start_idx, end_idx = span_range.split('-')
        start_idx, end_idx = int(start_idx), int(end_idx)
        sentence = ' '.join(tokens[start_idx:end_idx+1])
        te_label = te_label_dict[span_range]
        te_label_list.append((sentence, te_label))
    
    return te_label_list

def decode_relation_label(relation_dict, tokens):
    # input : 
    #     relation = Dict; Key : span range pair (aspect_term_span_range, opinion_term_span_range), Value :  sentiment polarity label (POSITIVE, NEGATIVE, NEUTRAL)
    #     tokens = list of tokens from the sentence
    # output : list of triples (aspect_term, opinion_term, polarity)
    relation_list = []
    
    for span_range_pair in relation_dict:
        aspect_term_span_range, opinion_term_span_range = span_range_pair
        
        aspect_term_start_idx, aspect_term_end_idx = aspect_term_span_range.split('-')
        aspect_term_start_idx, aspect_term_end_idx = int(aspect_term_start_idx), int(aspect_term_end_idx)
        aspect_term = ' '.join(tokens[aspect_term_start_idx:aspect_term_end_idx+1])
        
        opinion_term_start_idx, opinion_term_end_idx = opinion_term_span_range.split('-')
        opinion_term_start_idx, opinion_term_end_idx = int(opinion_term_start_idx), int(opinion_term_end_idx)
        opinion_term = ' '.join(tokens[opinion_term_start_idx:opinion_term_end_idx+1])
        
        relation_label = relation_dict[span_range_pair]
        
        relation_list.append((aspect_term, opinion_term, relation_label))
    
    return relation_list

In [4]:
TRAIN_FILE_PATH = "dataset/train.tsv"
DEV_FILE_PATH = "dataset/dev.tsv"

In [5]:
train_data = ReviewDataset(TRAIN_FILE_PATH)
dev_data = ReviewDataset(DEV_FILE_PATH)

In [6]:
IDX = 2900
print(train_data.texts[IDX])
print(decode_te_label(train_data.te_label_dict[IDX], train_data.texts[IDX]))
print(decode_relation_label(train_data.relation_dict[IDX], train_data.texts[IDX]))

['kamar', 'dibersihkan', 'dan', 'dirapikan', 'setiap', 'hari', ',', 'fasilitas', 'oke', ',', 'hanya', 'resepsionis', 'sering', 'tidak', 'ada', 'di', 'tempat', '.']
[('fasilitas', 'ASPECT'), ('oke', 'SENTIMENT'), ('resepsionis', 'ASPECT'), ('sering tidak ada', 'SENTIMENT')]
[('fasilitas', 'oke', 'PO'), ('resepsionis', 'sering tidak ada', 'NG')]


In [7]:
IDX = 900
print(dev_data.texts[IDX])
print(decode_te_label(dev_data.te_label_dict[IDX], dev_data.texts[IDX]))
print(decode_relation_label(dev_data.relation_dict[IDX], dev_data.texts[IDX]))

['sudah', '2', 'kali', 'menginap', 'disana', 'dan', 'pelayanannya', 'memuaskan', '.']
[('pelayanannya', 'ASPECT'), ('memuaskan', 'SENTIMENT')]
[('pelayanannya', 'memuaskan', 'PO')]


# Train Method

In [15]:
model = SpanMltri().to(device)

In [16]:
BATCH_SIZE = 8

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = DataLoader(dev_data, batch_size=BATCH_SIZE, shuffle=True)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

lambda_t = 0.5
lambda_r = 0.5

In [20]:
def train(train_dataloader, train_data, model, loss_fn, optimizer):
    size = len(train_dataloader.dataset)

    for batch, X in enumerate(train_dataloader):
        X = X.to(device)
        current_te_label_dict = train_data.te_label_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]
        current_relation_dict = train_data.relation_dict[(batch)*BATCH_SIZE:(batch+1)*BATCH_SIZE]

        logits_term_scorer, span_ranges, logits_relation_scorer, span_pair_ranges = model(X)

        y_te_true = []
        CURRENT_BATCH_SIZE = min(len(current_te_label_dict), BATCH_SIZE)
        for i in range(CURRENT_BATCH_SIZE):
            y_ = []
            for span_range in span_ranges:
                if span_range in current_te_label_dict[i]:
                    label = current_te_label_dict[i][span_range]
                    if label == 'ASPECT':
                        y_.append(1)
                    elif label == 'SENTIMENT':
                        y_.append(2)
                else: # label is O
                    y_.append(0)        
            y_te_true.append(torch.Tensor(y_))
        y_te_true = torch.stack(y_te_true)
        y_te_true = y_te_true.to(torch.long)

        logits_term_scorer = logits_term_scorer.reshape(logits_term_scorer.shape[0]*logits_term_scorer.shape[1], logits_term_scorer.shape[-1])
        y_te_true = y_te_true.reshape(-1).to(device)
        te_loss = loss_fn(logits_term_scorer, y_te_true)

        y_paote_true = []
        CURRENT_BATCH_SIZE = min(len(current_relation_dict), BATCH_SIZE)
        for i in range(CURRENT_BATCH_SIZE):
            y_ = []
            for span_pair_range in span_pair_ranges[i]:
                if span_pair_range not in current_relation_dict[i]:
                    y_.append(0)
                else:
                    label = current_relation_dict[i][span_pair_range]
                    if label == 'PO':
                        y_.append(1)
                    elif label == 'NG':
                        y_.append(2)
                    elif label == 'NT':
                        y_.append(3)
            y_paote_true.append(torch.Tensor(y_))
        y_paote_true = torch.stack(y_paote_true)
        y_paote_true = y_paote_true.to(torch.long)

        logits_relation_scorer = logits_relation_scorer.reshape(logits_relation_scorer.shape[0]*logits_relation_scorer.shape[1], logits_relation_scorer.shape[-1])
        y_paote_true = y_paote_true.reshape(-1).to(device)
        paote_loss = loss_fn(logits_relation_scorer, y_paote_true)

        total_loss = lambda_t*te_loss + lambda_r*paote_loss

        # Backpropagation
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if batch % 2 == 0:
            total_loss, current = total_loss.item(), batch * len(X)
            print(f"loss: {total_loss:>7f}  [{current:>5d}/{size:>5d}]")

In [21]:
def test(dev_dataloader, model):
    size = len(dev_dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dev_dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, train_data, model, loss_fn, optimizer)
    test(dev_dataloader, model)
print("Done!")

Epoch 1
-------------------------------
loss: 1.333529  [    0/ 3000]
loss: 1.280262  [   16/ 3000]
loss: 1.256137  [   32/ 3000]
loss: 1.234200  [   48/ 3000]
loss: 1.231892  [   64/ 3000]
loss: 1.233439  [   80/ 3000]
loss: 1.228324  [   96/ 3000]
loss: 1.226486  [  112/ 3000]
loss: 1.234813  [  128/ 3000]
loss: 1.232931  [  144/ 3000]
loss: 1.229365  [  160/ 3000]
loss: 1.226102  [  176/ 3000]
loss: 1.225214  [  192/ 3000]
loss: 1.230003  [  208/ 3000]
loss: 1.223525  [  224/ 3000]
loss: 1.222108  [  240/ 3000]
loss: 1.225379  [  256/ 3000]
loss: 1.223016  [  272/ 3000]
loss: 1.225428  [  288/ 3000]
loss: 1.214058  [  304/ 3000]
loss: 1.226027  [  320/ 3000]
loss: 1.226262  [  336/ 3000]
loss: 1.218978  [  352/ 3000]
loss: 1.224049  [  368/ 3000]
loss: 1.220452  [  384/ 3000]
loss: 1.225622  [  400/ 3000]
loss: 1.221791  [  416/ 3000]
loss: 1.220468  [  432/ 3000]
loss: 1.214044  [  448/ 3000]
loss: 1.226375  [  464/ 3000]
loss: 1.216253  [  480/ 3000]
loss: 1.211061  [  496/ 3000]
