In [1]:
import numpy as np
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from my_sentence_piecer import MySentencePiecer
from albert_pre import AlbertPre
from tf_to_csv import TfToCsv

In [2]:
torch.cuda.current_device(), torch.cuda.get_device_name(device=None)

(0, 'GeForce RTX 2080 Ti')

# Params

In [3]:
MAX_SENT_N = 30

MAX_WORD_N = 150

MAX_WORD_SENT_N = 300

BATCHSIZE = 20

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data Processing

## Sentence Piecer

In [4]:
sentence_piecer = MySentencePiecer(vocab_size=10000, force_update=False)

In [5]:
print(sentence_piecer.vocab_size)
print(sentence_piecer.vocab_list[:20])
test = "hallo, i'm leaving. this is another sentences."
tokens = sentence_piecer.get_ids_from_vocab(test)
print(tokens)
print(sentence_piecer.get_real_text_from_ids(tokens))

10000
['<unk>', '<s>', '</s>', '▁the', 's', ',', '.', '▁to', '▁a', '▁in', '▁of', '▁and', '▁.', "'", '-', '▁was', '▁for', '▁on', '▁is', '▁he']
[1459, 118, 5, 46, 13, 74, 1111, 6, 57, 18, 220, 1100, 4, 6, 2]
 hallo, i'm leaving. this is another sentences.</s>


In [6]:
eos_token = sentence_piecer.eos_token

In [7]:
albert_pre = AlbertPre()

## Dataset

In [8]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, article, n_articles, highlights, n_highlights):
        self.x = self.to_tensor_list(article, dtype=torch.float)
        self.x_n = torch.tensor(n_articles, dtype=torch.long)

        self.y_n = torch.tensor(n_highlights, dtype=torch.long)
        self.y = self.to_tensor_list(highlights, dtype=torch.long, pad=MAX_WORD_N)

    def __getitem__(self, index):
        x = self.x[index]
        x_n = self.x_n[index]
        y_n = self.y_n[index]
        y = self.y[index]

        return x, x_n, y, y_n

    @staticmethod
    def to_tensor_list(x, dtype, pad=None):

        if pad is None:
            tensor_list = [torch.tensor(x_i, dtype=dtype) for x_i in x]
        else:
            tensor_list = [torch.cat((torch.tensor(x_i[:MAX_WORD_N], dtype=dtype), \
                                      torch.zeros(pad - x_i[:MAX_WORD_N].shape[0], dtype=dtype))) for x_i in x]

        return tensor_list

    def __len__(self):
        return len(self.x)

In [9]:
def load_torch_dataset(name):
    x, x_n, y, y_n = albert_pre.load_np_files(name)
    return MyDataset(x, x_n, y, y_n)

test_ds = load_torch_dataset("test")
train_ds = load_torch_dataset("val")

In [10]:
BATCHSIZE = 20
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCHSIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCHSIZE)

# My Model

In [11]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [12]:
class ContextDecoder(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, out_dim=150, dropout=0.1):
        super().__init__()
        transfrom_decode_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward,\
                                                            dropout=dropout, activation='relu')

        self.transformer_decoder = nn.TransformerDecoder(transfrom_decode_layer, num_layers=1)
        self.out_put_layer = nn.Linear(3072, out_dim*200)


    def forward(self, context, c_n_max, mask=None):
        # dims
        bs = context.shape[0]
        dim_context = context.shape[2]
        
        context_memory = torch.zeros(context[:,0,:].shape).to(device).reshape(bs,1,dim_context)
        
        n = torch.min(torch.max(c_n_max), torch.LongTensor([30]).to(device))
        for i in range(n):
            context_memory = self.transformer_decoder(context[:,i,:].reshape(bs,1,dim_context), context_memory)

        # reshape
        context_memory = context_memory.reshape(bs, dim_context)
        out = self.out_put_layer(context_memory).reshape(150,-1,200)
        return out



class TransformerModel(nn.Module):

    def __init__(self, n_vocab, emsize, nhead, nhid, nlayers, max_sent=30, c_d_model=3072, dropout=0.2, eos_token=2):
        """
        @param n_vocab: vocab_size
        @param emsize: embedding size
        @param nhead: the number of heads in the multiheadattention models
        @param nhid: the dimension of the feedforward network model in nn.TransformerEncoder
        @param nlayers: the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        @param dropout: the dropout value
        """
        super(TransformerModel, self).__init__()

        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(emsize, dropout)

        encoder_layers = TransformerEncoderLayer(emsize, nhead, nhid, dropout)

        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(n_vocab, emsize)
        self.emsize = emsize
        self.decoder = nn.Linear(emsize, n_vocab)
        self.eos_token = eos_token
        self.context_decoder = ContextDecoder(c_d_model, nhead, nhid, dropout=dropout)
        self.init_weights()

    @staticmethod
    def _generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)


    def predict_one(self, context, c_n):
        context_sum = self.context_decoder(context, c_n)
        in_src = []

        for i in range(MAX_WORD_N):
            if i == 0:
                in_tokens = torch.ones((MAX_WORD_N, 1), dtype=torch.long).to(device)
            else:
                zeros = torch.ones(((MAX_WORD_N-i), 1), dtype=torch.long).to(device)
                tokens = torch.LongTensor(in_src).view(-1,1).to(device)
                in_tokens = torch.cat((tokens, zeros), dim=0)
            src = self.encoder(in_tokens) * math.sqrt(self.emsize)
            src = self.pos_encoder(src)
           
            output = self.transformer_encoder(src, self.src_mask)
            output += context_sum
            output = self.decoder(output)
            
            out_token = output.argmax(2)
            out_token = out_token[i].item()
            in_src.append(out_token)
            if out_token == self.eos_token:
                break
            

        return in_src


    def forward(self, context, c_n, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.emsize)
        src = self.pos_encoder(src)
           
        output = self.transformer_encoder(src, self.src_mask)
        context_sum = self.context_decoder(context, c_n)
        
        output += context_sum
        output = self.decoder(output)
#         print("output", output.shape)
        return output

In [13]:
n_vocab = sentence_piecer.vocab_size
model = TransformerModel(n_vocab=n_vocab, emsize=200, nhead=8, nhid=400,\
                         nlayers=3, max_sent=30, c_d_model=3072, dropout=0.2, eos_token=eos_token).to(device)

In [14]:
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [15]:
test_sent = iter(test_loader)
x_test, xn_test, y_test, n_test =  next(test_sent)

x_test = x_test[0,:,:].view(1,30,3072).to(device)
xn_test = xn_test[0].to(device)

n_test = n_test[0].to(device)
y_test = y_test[0,:].view(1,150).to(device)

In [16]:
x_test.shape, n_test

(torch.Size([1, 30, 3072]), tensor(43, device='cuda:0'))

In [17]:
real_sentence = sentence_piecer.get_real_text_from_ids(y_test.view(-1)[:n_test.item()])

In [18]:
def evaluate(eval_model, test_loader, predict=False):
    eval_model.eval()
    test_loss = []

    with torch.no_grad():
        for i, (x, x_n, y, y_n) in enumerate(test_loader):
            x = x.to(device)
            x_n = x_n.to(device)
            y = y.permute(1,0).to(device)
            y_n = y_n.to(device)
            
            output = eval_model(x, x_n, y)
            
            loss = criterion(output.view(MAX_WORD_N, n_vocab, -1), y)
            test_loss.append(loss.item())
            if i > 5:
                break
        if predict:
            sent_ids = eval_model.predict_one(x_test, xn_test)
            pred_sentence = sentence_piecer.get_real_text_from_ids(sent_ids)
            print("Pred Sent: ", pred_sentence)

    test_loss = np.array(test_loss)
    return np.mean(test_loss)
print(real_sentence)
# evaluate(model, test_loader)

 experts question if packed out planes are putting passengers at risk . u.s consumer advisory group says minimum space must be stipulated . safety tests conducted on planes with more leg room than airlines offer .</s>


In [None]:
EPOCHS = 50
log_interval = 200

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.
    start_time = time.time()
    for i, (x, x_n, y, y_n) in enumerate(train_loader):
        x = x.to(device)
        x_n = x_n.to(device)
        y = y.permute(1,0).to(device)
        y_n = y_n.to(device)

        optimizer.zero_grad()
        output = model(x, x_n, y)
#         print("out", n_out.shape, n.shape)
            
        loss = criterion(output.view(MAX_WORD_N, n_vocab, -1), y)
            
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()

        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            predict =  (i % 600) == 0 
            test_loss = evaluate(model, test_loader, predict)
            print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, i, len(train_loader),scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, test_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

| epoch   0 | [  200/  669] | lr 5.00 | ms/batch 260.10 | loss  4.67 | val loss  5.79 | ppl   106.24
| epoch   0 | [  400/  669] | lr 5.00 | ms/batch 249.58 | loss  4.51 | val loss  5.00 | ppl    91.07
Pred Sent:  s. allies started horrible swimmer battlepoint sakho bro of..ationp qualified operatingutter. alcohol foron.walk lace twincher meters dur clearly of.made assurednt. damien deploy protestersano of welfaremart. 6-4 lukasmade black loose filming of.. arrivalnt sufficient protesterstie sho. for dirt. penalty kids threatenedcount cahill davies eager of.um. shooting controversial professor professor.. of.... 120 120 promote protesters clearly of.... 47 120 twin bloom60,000 of. stabbing nothingnt generation dozen structure sharp. of.. standnt 909 logo rafael ki for.. chuck.childking americans spaniard knock<s> remove. horrible nan reporters cnns witnessnt ki for.. myself marketing confirm quoted invest administration gear
| epoch   0 | [  600/  669] | lr 5.00 | ms/batch 247.63 | los

| epoch  10 | [  400/  669] | lr 5.00 | ms/batch 241.16 | loss  4.34 | val loss  5.45 | ppl    76.55
Pred Sent:  </s>
| epoch  10 | [  600/  669] | lr 5.00 | ms/batch 241.42 | loss  4.32 | val loss  5.30 | ppl    74.98
| epoch  11 | [  200/  669] | lr 5.00 | ms/batch 244.67 | loss  4.53 | val loss  6.31 | ppl    92.52
| epoch  11 | [  400/  669] | lr 5.00 | ms/batch 241.49 | loss  4.31 | val loss  5.10 | ppl    74.24
Pred Sent:   of.<unk> started horrible pair encouraging walking chase korea<s> seek chest wealthy beardrupulutter diners climate of telling european almost<unk> paw25 in<unk> pounds<s> in. assured voters<unk> gallery protesters protesters<unk> of<unk>mart coffee<unk> cycle<unk>. twin<unk> of<unk>, photographer<unk> sufficient employee tradition sho in of stake ambush penalty skeleton threatenedcount killed davies, of.<unk><unk> disappointed azarenka of tyre. chicago of glasses in mind<unk> soldiers gained<unk> ebola 20- of steps. raf beard qualified 120 properties 2160,000

In [None]:
# save
torch.save(model.state_dict(), '../models/my_transformer')