In [1]:
import os
import torch

from xlm.utils import AttrDict
from xlm.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from xlm.model.transformer import TransformerModel

FAISS library was not found.
FAISS not available. Switching to standard nearest neighbors search implementation.


## Reload a pretrained model

In [2]:
model_path = "/projectnb/statnlp/gkuwanto/XLM/dumped/baseline_para_0/q3v4i6kl9t/best-valid_mlm_ppl.pth"
# Load the checkpoint onto the CPU regardless of the device it was saved from:
# this works on machines without CUDA and avoids keeping a second copy of the
# weights on the GPU (the model itself is moved to the GPU later, before training).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
reloaded = torch.load(model_path, map_location='cpu')
# the checkpoint stores the training hyper-parameters under 'params'
params = AttrDict(reloaded['params'])
print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

Supported languages: en, id


## Build dictionary / update parameters / build model

In [3]:
# build dictionary / update parameters
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])

# record the vocabulary size and the ids of the special tokens on `params`
params.n_words = len(dico)
for attr_name, special_word in (
    ('bos_index', BOS_WORD),
    ('eos_index', EOS_WORD),
    ('pad_index', PAD_WORD),
    ('unk_index', UNK_WORD),
    ('mask_index', MASK_WORD),
):
    setattr(params, attr_name, dico.index(special_word))

# build model / reload weights
# NOTE(review): the two positional flags are presumably XLM's `is_encoder` and
# `with_output` -- confirm against xlm.model.transformer.TransformerModel.
model = TransformerModel(params, dico, True, True)
model.eval()

from collections import OrderedDict

# Checkpoint keys may carry a leading 'module.' (e.g. when saved from a wrapped
# model). Strip it only when it is a prefix, so that a key which merely
# *contains* 'module.' (e.g. 'submodule.weight') is left intact.
prefix = 'module.'
reloaded_model = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in reloaded['model'].items()
)
model.load_state_dict(reloaded_model)

<All keys matched successfully>


## Get sentence representations

Sentences have to be in the BPE format, i.e. tokenized sentences on which you applied fastBPE.

In [4]:
# Below is one way to bpe-ize sentences
codes = "" # path to the codes of the model
fastbpe = os.path.join(os.getcwd(), 'tools/fastBPE/fast')

def to_bpe(sentences):
    """Apply fastBPE to a list of tokenized sentences.

    Args:
        sentences: iterable of str, one sentence per entry.

    Returns:
        list of str: the BPE-ized sentences, right-stripped. When the
        module-level `codes` path is empty, no BPE can be applied and the
        sentences are returned (right-stripped) as-is, i.e. they are assumed
        to be BPE-ized already.

    Raises:
        subprocess.CalledProcessError: if the fastBPE binary exits with an error.
        OSError: if the fastBPE binary cannot be executed.
    """
    import subprocess
    import tempfile

    if not codes:
        # No BPE codes configured: pass the input through unchanged.
        return [sent.rstrip() for sent in sentences]

    # A private temporary directory avoids clashes on fixed /tmp file names.
    with tempfile.TemporaryDirectory() as tmp_dir:
        raw_path = os.path.join(tmp_dir, 'sentences')
        bpe_path = os.path.join(tmp_dir, 'sentences.bpe')

        # write sentences to the *input* file
        with open(raw_path, 'w', encoding='utf-8') as fwrite:
            for sent in sentences:
                fwrite.write(sent + '\n')

        # apply bpe: `fast applybpe <output> <input> <codes>`
        # (argument list instead of a shell string; a failure raises instead of
        # being silently ignored)
        subprocess.run([fastbpe, 'applybpe', bpe_path, raw_path, codes], check=True)

        # load bpe-ized sentences
        with open(bpe_path, encoding='utf-8') as f:
            return [line.rstrip() for line in f]


In [5]:
# Below are already BPE-ized sentences

# list of sentence strings (one entry per sentence)
sentences = [
     'warung ini dimiliki oleh pengusaha pabrik tahu yang sudah puluhan tahun terkenal membuat tahu putih di bandung . tahu berkualitas , dipadu keahlian memasak , dipadu kretivitas , jadilah warung yang menyajikan menu utama berbahan tahu , ditambah menu umum lain seperti ayam . semuanya selera indonesia . harga cukup terjangkau . jangan lewatkan tahu bletoka nya , tidak kalah dengan yang asli dari tegal !',
    'aaa'
]

# bpe-ize sentences
sentences = to_bpe(sentences)
print('\n\n'.join(sentences))

# check how many tokens are OOV (tokenize once, then count)
bpe_tokens = ' '.join(sentences).split()
n_w = len(bpe_tokens)
n_oov = sum(1 for w in bpe_tokens if w not in dico.word2id)
print('Number of out-of-vocab words: %s/%s' % (n_oov, n_w))

# add </s> sentence delimiters; each sentence becomes a list of tokens
delimited = []
for sent in sentences:
    delimited.append(('</s> %s </s>' % sent.strip()).split())
sentences = delimited

warung ini dimiliki oleh pengusaha pabrik tahu yang sudah puluhan tahun terkenal membuat tahu putih di bandung . tahu berkualitas , dipadu keahlian memasak , dipadu kretivitas , jadilah warung yang menyajikan menu utama berbahan tahu , ditambah menu umum lain seperti ayam . semuanya selera indonesia . harga cukup terjangkau . jangan lewatkan tahu bletoka nya , tidak kalah dengan yang asli dari tegal !

aaa
Number of out-of-vocab words: 2/67


### Create batch

In [6]:
# batch size and length of the longest sentence (in tokens)
bs = len(sentences)
lengths = torch.LongTensor([len(sent) for sent in sentences])
slen = max(len(sent) for sent in sentences)

# (slen, bs) matrix of word ids: one column per sentence, padded with pad_index
word_ids = torch.full((slen, bs), params.pad_index, dtype=torch.long)
for col, tokens in enumerate(sentences):
    ids = torch.LongTensor([dico.index(w) for w in tokens])
    word_ids[:len(ids), col] = ids

# language ids: when the model knows several languages, every position of
# every sentence is tagged with the 'id' language; otherwise none are passed
if params.n_langs > 1:
    langs = torch.LongTensor([params.lang2id['id']]).unsqueeze(0).expand(slen, bs)
else:
    langs = None


In [7]:
# Display the padded batch of word ids (one column per sentence).
word_ids

tensor([[    1,     1],
        [ 3367, 14369],
        [   20,     1],
        [ 1017,     2],
        [   45,     2],
        [ 1891,     2],
        [ 1616,     2],
        [  177,     2],
        [   16,     2],
        [   48,     2],
        [ 2375,     2],
        [   60,     2],
        [  772,     2],
        [   82,     2],
        [  177,     2],
        [  539,     2],
        [   18,     2],
        [  647,     2],
        [   14,     2],
        [  177,     2],
        [ 1089,     2],
        [   15,     2],
        [25377,     2],
        [ 3301,     2],
        [ 2851,     2],
        [   15,     2],
        [25377,     2],
        [    3,     2],
        [   15,     2],
        [ 4504,     2],
        [ 3367,     2],
        [   16,     2],
        [ 3075,     2],
        [  969,     2],
        [  323,     2],
        [ 4369,     2],
        [  177,     2],
        [   15,     2],
        [ 1702,     2],
        [  969,     2],
        [  303,     2],
        [   76, 

### Forward

In [8]:
# Run the encoder (non-causal) over the batch, then make the result contiguous.
hidden_states = model('fwd', x=word_ids, lengths=lengths, langs=langs, causal=False)
tensor = hidden_states.contiguous()
print(tensor.size())

torch.Size([68, 2, 1024])


In [9]:
from torch import nn

# Classification head applied on top of the encoder output: dropout followed by
# a linear layer with 3 output classes (the sentiment labels used below).
# The input width is taken from the encoder output instead of being hard-coded
# to 1024, so the head also fits checkpoints with a different model dimension.
proj = nn.Sequential(
    nn.Dropout(params.dropout),
    nn.Linear(tensor.size(-1), 3)
).cuda()

In [10]:
# A freshly constructed module is in training mode, so dropout would randomly
# perturb this demo prediction. Switch the head to eval mode and skip autograd
# bookkeeping (the training loop below calls proj.train() again).
proj.eval()
with torch.no_grad():
    # tensor[0]: first hidden state of each sentence, shape (batch_size, model_dimension)
    logits = proj(tensor[0].cuda())

In [11]:
# Index of the highest-scoring class for each sentence in the batch.
torch.max(logits.data, dim=1).indices

tensor([1, 1], device='cuda:0')

The variable `tensor` is of shape `(sequence_length, batch_size, model_dimension)`.

`tensor[0]` is a tensor of shape `(batch_size, model_dimension)` that corresponds to the first hidden state of the last layer of each sentence.

This is the vector that we use to fine-tune on the GLUE and XNLI tasks.

# Fine-tuning on SmSA (document sentiment classification)

In [12]:
from data_loader_utils import DocumentSentimentDataset, DocumentSentimentDataLoader

import random

import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm


from utils.forward_fn import forward_sequence_classification
from utils.metrics import document_sentiment_metrics_fn



def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global generator, and PyTorch on the CPU
    and on *all* CUDA devices (`manual_seed_all`, rather than only the current
    device).

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(33333)

In [13]:
# Locations of the train / validation / test splits, relative to the
# notebook's working directory.
# NOTE(review): the test file name suggests its labels are masked -- confirm
# before computing any metric on the test split.
train_dataset_path = './dataset/smsa_doc-sentiment-prosa/train_preprocess.tsv'
valid_dataset_path = './dataset/smsa_doc-sentiment-prosa/valid_preprocess.tsv'
test_dataset_path = './dataset/smsa_doc-sentiment-prosa/test_preprocess_masked_label.tsv'

In [15]:
# One dataset per split, all built with the pretrained dictionary (`dico`),
# the model parameters and lowercase=True.
train_dataset, valid_dataset, test_dataset = (
    DocumentSentimentDataset(split_path, dico, params, lowercase=True)
    for split_path in (train_dataset_path, valid_dataset_path, test_dataset_path)
)

# Settings shared by the three loaders; only the training loader shuffles.
loader_kwargs = dict(params=params, max_seq_len=512, batch_size=16, num_workers=16)
train_loader = DocumentSentimentDataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DocumentSentimentDataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DocumentSentimentDataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [16]:
# Peek at the first training example.
train_dataset[0]

(tensor([1669,  252, 1099, 1318,  677,  792,    3,  367,  677,  367,    3,  425,
          367,  543,  367,  958,  367,  868,  367,    3,  749,  958,  655,  877,
            3,  824,  655,  677,  792, 1318,  357,  252,  877,  252,    3,  824,
          252,  458, 1099,  367,  868,    3,  710,  252,  877, 1318,    3,  882,
          252,  677,  792,    3,  357, 1318,  425,  252,  877,    3,  824, 1318,
          958, 1318,  877,  252,  677,    3,  710,  252,  877, 1318,  677,    3,
          710,  655, 1099,  868,  655,  677,  252,  958,    3,  543,  655,  543,
          458, 1318,  252,  710,    3,  710,  252,  877, 1318,    3,  824, 1318,
          710,  367,  877,    3,  425,  367,    3,  458,  252,  677,  425, 1318,
          677,  792,    3,   14,    3,  710,  252,  877, 1318,    3,  458,  655,
         1099,  868, 1318,  252,  958,  367,  710,  252,  357,    3,   15,    3,
          425,  367,  824,  252,  425, 1318,    3,  868,  655,  252,  877,  958,
          367,  252,  677,  

In [17]:
# label name -> class index, and the inverse mapping
w2i = DocumentSentimentDataset.LABEL2INDEX
i2w = DocumentSentimentDataset.INDEX2LABEL
print(w2i)
print(i2w)

{'positive': 0, 'neutral': 1, 'negative': 2}
{0: 'positive', 1: 'neutral', 2: 'negative'}


In [18]:
import time

## Train, validate and test

In [19]:
from model_utils import forward_sequence_classification

In [20]:
# Move both modules to the GPU *before* constructing their optimizers, so the
# optimizers are built from the parameters that are actually trained.
model = model.cuda()
proj = proj.cuda()

# Separate Adam optimizers for the encoder and for the classification head.
optimizer_m = optim.Adam(model.parameters(), lr=3e-5)
optimizer_p = optim.Adam(proj.parameters(), lr=3e-5)

In [21]:
n_epochs = 5

for epoch in range(n_epochs):
    # ---- Training phase: update both the encoder and the classification head ----
    model.train()
    proj.train()

    total_train_loss = 0
    n_train_batches = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    # enable_grad() guarantees gradients are tracked even if an earlier cell
    # disabled them globally, without changing the global setting itself.
    with torch.enable_grad():
        for n_train_batches, batch_data in enumerate(train_pbar, start=1):
            # Forward model (the last element of the batch is not passed to the forward function)
            loss, logits, batch_hyp, batch_label = forward_sequence_classification(proj, model, batch_data[:-1], i2w=i2w, device='cuda')

            # Backward pass and parameter update for both optimizers
            optimizer_m.zero_grad()
            optimizer_p.zero_grad()
            loss.backward()
            optimizer_m.step()
            optimizer_p.step()

            total_train_loss += loss.item()

            list_hyp += batch_hyp
            list_label += batch_label

            # running mean of the training loss over the batches seen so far
            train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f}".format((epoch+1),
                total_train_loss/n_train_batches))

    # Calculate train metric (max(..., 1) avoids dividing by zero on an empty loader)
    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {}".format((epoch+1),
        total_train_loss/max(n_train_batches, 1), metrics))

    # ---- Validation phase: no parameter updates, no gradient tracking ----
    model.eval()
    proj.eval()

    total_loss = 0
    n_valid_batches = 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    # no_grad() is scoped to this block, so autograd is not left disabled
    # for whatever runs after this cell.
    with torch.no_grad():
        for n_valid_batches, batch_data in enumerate(pbar, start=1):
            loss, logits, batch_hyp, batch_label = forward_sequence_classification(proj, model, batch_data[:-1], i2w=i2w, device='cuda')

            # Calculate total loss
            total_loss += loss.item()

            # Collect predictions and labels for the evaluation metrics
            list_hyp += batch_hyp
            list_label += batch_label

            pbar.set_description("VALID LOSS:{:.4f}".format(total_loss/n_valid_batches))

    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/max(n_valid_batches, 1), metrics))
    

(Epoch 1) TRAIN LOSS:0.7074: 100%|██████████| 688/688 [06:53<00:00,  1.66it/s]


(Epoch 1) TRAIN LOSS:0.7074 {'ACC': 0.6940909090909091, 'F1': 0.572908541790616, 'REC': 0.5539061361734273, 'PRE': 0.6215411110939723}


VALID LOSS:0.6415: 100%|██████████| 79/79 [00:15<00:00,  5.15it/s]


(Epoch 1) VALID LOSS:0.6415 {'ACC': 0.7317460317460317, 'F1': 0.5926110504552285, 'REC': 0.5670945243547362, 'PRE': 0.7273314924014516}


(Epoch 2) TRAIN LOSS:0.5875: 100%|██████████| 688/688 [06:52<00:00,  1.67it/s]


(Epoch 2) TRAIN LOSS:0.5875 {'ACC': 0.755, 'F1': 0.6856904185317573, 'REC': 0.6734095055331242, 'PRE': 0.7012846177094132}


VALID LOSS:0.5580: 100%|██████████| 79/79 [00:15<00:00,  5.13it/s]


(Epoch 2) VALID LOSS:0.5580 {'ACC': 0.7674603174603175, 'F1': 0.6959390261185757, 'REC': 0.6837106545386137, 'PRE': 0.7134656201514282}


(Epoch 3) TRAIN LOSS:0.5150: 100%|██████████| 688/688 [06:55<00:00,  1.66it/s]


(Epoch 3) TRAIN LOSS:0.5150 {'ACC': 0.7864545454545454, 'F1': 0.7295356843167031, 'REC': 0.7213072883526804, 'PRE': 0.73891677647396}


VALID LOSS:0.5537: 100%|██████████| 79/79 [00:15<00:00,  5.13it/s]


(Epoch 3) VALID LOSS:0.5537 {'ACC': 0.7690476190476191, 'F1': 0.69283344257811, 'REC': 0.6644854570649897, 'PRE': 0.746791823094898}


(Epoch 4) TRAIN LOSS:0.4705: 100%|██████████| 688/688 [06:55<00:00,  1.65it/s]


(Epoch 4) TRAIN LOSS:0.4705 {'ACC': 0.8092727272727273, 'F1': 0.7641763127276856, 'REC': 0.7558122754590082, 'PRE': 0.7735716302651056}


VALID LOSS:0.5215: 100%|██████████| 79/79 [00:15<00:00,  5.14it/s]


(Epoch 4) VALID LOSS:0.5215 {'ACC': 0.7944444444444444, 'F1': 0.7226798721851195, 'REC': 0.6988590168762769, 'PRE': 0.759524512358524}


(Epoch 5) TRAIN LOSS:0.4188: 100%|██████████| 688/688 [06:52<00:00,  1.67it/s]


(Epoch 5) TRAIN LOSS:0.4188 {'ACC': 0.8320909090909091, 'F1': 0.7975780663667237, 'REC': 0.7900075699931337, 'PRE': 0.8059019127421904}


VALID LOSS:0.5142: 100%|██████████| 79/79 [00:15<00:00,  5.14it/s]


(Epoch 5) VALID LOSS:0.5142 {'ACC': 0.8, 'F1': 0.7405651530924792, 'REC': 0.7175984086301884, 'PRE': 0.7751295336787565}


In [22]:

# Evaluate on test
model.eval()
proj.eval()

list_hyp = []

pbar = tqdm(test_loader, leave=True, total=len(test_loader))
# no_grad() is scoped to the prediction loop, so autograd is not left
# disabled globally once this cell has finished.
with torch.no_grad():
    for batch_data in pbar:
        # only the predicted labels are kept; loss, logits and labels are ignored here
        _loss, _logits, batch_hyp, _batch_label = forward_sequence_classification(proj, model, batch_data[:-1], i2w=i2w, device='cuda')
        list_hyp += batch_hyp

# Collect predictions: one row per test example, with an explicit 'index' column
df = pd.DataFrame({'label':list_hyp}).reset_index()

print(df)

100%|██████████| 32/32 [00:03<00:00,  8.10it/s]

     index     label
0        0  positive
1        1  negative
2        2  negative
3        3  negative
4        4  negative
..     ...       ...
495    495  negative
496    496  negative
497    497  negative
498    498  negative
499    499  negative

[500 rows x 2 columns]





In [23]:
# Count how many test predictions fall into each label.
df['label'].value_counts()

negative    289
positive    189
neutral      22
Name: label, dtype: int64

In [24]:
# Write the test predictions to CSV (the DataFrame's own row index is not written).
# NOTE(review): absolute, machine-specific output path -- consider a configurable output directory.
df.to_csv('/projectnb/statnlp/gik/XLM/output/pred-smsa.csv', index=False)

In [25]:
# Save the fine-tuned encoder and the classification head as separate state dicts.
for module, out_path in (
    (model, '/projectnb/statnlp/gik/XLM/output/smsa_xlm_finetuned_model.pth'),
    (proj, '/projectnb/statnlp/gik/XLM/output/smsa_proj.pth'),
):
    torch.save(module.state_dict(), out_path)