In [1]:
import numpy as np
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from my_sentence_piecer import MySentencePiecer
from albert_pre import AlbertPre
from tf_to_csv import TfToCsv

In [2]:
torch.cuda.current_device(), torch.cuda.get_device_name(device=None)

(0, 'GeForce RTX 2080 Ti')

# Params

In [3]:
MAX_SENT_N = 30

MAX_WORD_N = 150

MAX_WORD_SENT_N = 300

BATCHSIZE = 20

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data Processing

## Sentence Piecer

In [4]:
sentence_piecer = MySentencePiecer(vocab_size=10000, force_update=False)

In [5]:
print(sentence_piecer.vocab_size)
print(sentence_piecer.vocab_list[:20])
test = "hallo, i'm leaving. this is another sentences."
tokens = sentence_piecer.get_ids_from_vocab(test)
print(tokens)
print(sentence_piecer.get_real_text_from_ids(tokens))

10000
['<unk>', '<s>', '</s>', '▁the', 's', ',', '.', '▁to', '▁a', '▁in', '▁of', '▁and', '▁.', "'", '-', '▁was', '▁for', '▁on', '▁is', '▁he']
[1459, 118, 5, 46, 13, 74, 1111, 6, 57, 18, 220, 1100, 4, 6, 2]
 hallo, i'm leaving. this is another sentences.</s>


In [6]:
len(sentence_piecer.vocab_list)

10000

In [7]:
albert_pre = AlbertPre()

## Dataset

In [8]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, article, n_highlights, highlights,transform=None):
        self.x = self.to_tensor_list(article, dtype=torch.float)

        self.y_n = torch.tensor(n_highlights, dtype=torch.long)
        self.y = self.to_tensor_list(highlights, dtype=torch.long, pad=MAX_WORD_N)

    def __getitem__(self, index):
        x = self.x[index]
        y_n = self.y_n[index]
        y = self.y[index]

        return x, y_n, y

    @staticmethod
    def to_tensor_list(x, dtype, pad=None):

        if pad is None:
            tensor_list = [torch.tensor(x_i, dtype=dtype) for x_i in x]
        else:
            tensor_list = [torch.cat((torch.tensor(x_i[:MAX_WORD_N], dtype=dtype), \
                                      torch.zeros(pad - x_i[:MAX_WORD_N].shape[0], dtype=dtype))) for x_i in x]

        return tensor_list

    def __len__(self):
        return len(self.x)

In [9]:
def load_torch_dataset(name):
    x, x_n, y_n,y = albert_pre.load_np_files(name)
    return MyDataset(x,y_n,y)

test_ds = load_torch_dataset("test")
train_ds = load_torch_dataset("val")

In [10]:
BATCHSIZE = 10
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCHSIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCHSIZE)

# My Model

In [11]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [12]:
class ContextDecoder(nn.Module):
    def __init__(self, max_sent, d_model, nhead, dim_feedforward, out_dim=150, dropout=0.1):
        super().__init__()
        self.max_sent = max_sent
        transfrom_decode_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward,\
                                                            dropout=dropout, activation='relu')

        self.transformer_decoder = nn.TransformerDecoder(transfrom_decode_layer, num_layers=1)
        self.out_put_layer = nn.Linear(3072, out_dim*200)


    def forward(self, context, mask=None):
        # dims
        bs = context.shape[0]
        dim_context = context.shape[2]

        context_memory = torch.zeros(context[:,0,:].shape).to(device).reshape(bs,1,dim_context)

        for i in range(self.max_sent):
            context_memory = self.transformer_decoder(context[:,i,:].reshape(bs,1,dim_context), context_memory)

        # reshape
        context_memory = context_memory.reshape(bs, dim_context)
        out = self.out_put_layer(context_memory).reshape(150,-1,200)
        return out



class TransformerModel(nn.Module):

    def __init__(self, n_vocab, emsize, nhead, nhid, nlayers, max_sent=30, c_d_model=3072, dropout=0.2):
        """
        @param n_vocab: vocab_size
        @param emsize: embedding size
        @param nhead: the number of heads in the multiheadattention models
        @param nhid: the dimension of the feedforward network model in nn.TransformerEncoder
        @param nlayers: the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        @param dropout: the dropout value
        """
        super(TransformerModel, self).__init__()

        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(emsize, dropout)

        encoder_layers = TransformerEncoderLayer(emsize, nhead, nhid, dropout)

        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(n_vocab, emsize)
        self.emsize = emsize
        self.decoder = nn.Linear(emsize, n_vocab)
        self.context_decoder = ContextDecoder(max_sent, c_d_model, nhead, nhid, dropout=dropout)
        self.init_weights()

    @staticmethod
    def _generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)


    def predict_one(self, context, n):
        context_sum = self.context_decoder(context)
        in_src = []

        for i in range(torch.max(n)):
            if i == 0:
                in_tokens = torch.ones((MAX_WORD_N, 1), dtype=torch.long).to(device)
            else:
                zeros = torch.ones(((MAX_WORD_N-i), 1), dtype=torch.long).to(device)
                tokens = torch.LongTensor(in_src).view(-1,1).to(device)
                in_tokens = torch.cat((tokens, zeros), dim=0)
            src = self.encoder(in_tokens) * math.sqrt(self.emsize)
            src = self.pos_encoder(src)
           
            output = self.transformer_encoder(src, self.src_mask)
            output += context_sum
            output = self.decoder(output)
            out_token = output.argmax(2)
            out_token = out_token[i].item()
            in_src.append(out_token)

        return in_src


    def forward(self, context, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.emsize)
        src = self.pos_encoder(src)
           
        output = self.transformer_encoder(src, self.src_mask)
        context_sum = self.context_decoder(context)
        
        output += context_sum
        output = self.decoder(output)
#         print("output", output.shape)
        return output

In [13]:
n_vocab = sentence_piecer.vocab_size
model = TransformerModel(n_vocab=n_vocab, emsize=200, nhead=2, nhid=200,\
                         nlayers=1, max_sent=30, c_d_model=3072, dropout=0.2).to(device)

In [14]:
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

In [15]:
test_sent = iter(test_loader)
x_test, n_test, y_test =  next(test_sent)

x_test = x_test[0,:,:].view(1,30,3072).to(device)
n_test = n_test[0].to(device)
y_test = y_test[0,:].view(1,150).to(device)

In [16]:
x_test.shape, n_test

(torch.Size([1, 30, 3072]), tensor(43, device='cuda:0'))

In [17]:
real_sentence = sentence_piecer.get_real_text_from_ids(y_test.view(-1)[:n_test.item()])

In [18]:
def evaluate(eval_model, test_loader):
    eval_model.eval()
    test_loss = []

    with torch.no_grad():
        for i, (x, n, y) in enumerate(test_loader):
            x = x.to(device)
            n = n.to(device)
            y = y.permute(1,0).to(device)

            output = eval_model(x, y)
            loss = criterion(output.view(MAX_WORD_N, n_vocab, -1), y)
            test_loss.append(loss.item())
            if i > 5:
                break

        sent_ids = eval_model.predict_one(x_test, n_test)
        pred_sentence = sentence_piecer.get_real_text_from_ids(sent_ids)

#     print("REAL Sent: ", real_sentence)
    print("Pred Sent: ", pred_sentence)

    test_loss = np.array(test_loss)
    return np.mean(test_loss)
print(real_sentence)

 experts question if packed out planes are putting passengers at risk . u.s consumer advisory group says minimum space must be stipulated . safety tests conducted on planes with more leg room than airlines offer .</s>
Pred Sent:   account re shark debt shark beaten appointed shark satisfied shark satisfied account teenager verbal attend appointed documentary compwilfried shark devoted muslims teenager morning teenager jail command command command sharkwrittenfra concealabove command information informationfrafra ambitious stability rob want


In [19]:
EPOCHS = 20
log_interval = 200

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.
    start_time = time.time()
    for i, (x, n, y) in enumerate(train_loader):
        x = x.to(device)
        n = n.to(device)
        y = y.permute(1,0).to(device)
      

        optimizer.zero_grad()
        output = model(x, y)
        loss = criterion(output.view(MAX_WORD_N, n_vocab, -1), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()

        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            test_loss = evaluate(model, test_loader)
            print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, i, len(train_loader),scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, test_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

Pred Sent:   its me wantmmer a a a a a a them all a land a a a a a a ""he a a a a a a a a world me a a a a a a a aalhe a
| epoch   0 | [  200/ 1337] | lr 5.00 | ms/batch 219.57 | loss  6.28 | val loss  9.54 | ppl   533.92
Pred Sent:  al justa tax tax tax tax tax inquiry taxalal taxa tax tax tax tax inquiry inquiryalna tax tax tax tax tax inquiry taxalaaa00 tax inquiry inquiry inquiry inquiryala tax
| epoch   0 | [  400/ 1337] | lr 5.00 | ms/batch 213.96 | loss  5.23 | val loss  4.33 | ppl   186.01
Pred Sent:   this mark mark mark in in bonus in in in one have in in in in in in in in have wife in in in in in in in in them in in in in in in in in in. in in
| epoch   0 | [  600/ 1337] | lr 5.00 | ms/batch 214.32 | loss  4.70 | val loss  4.79 | ppl   109.59
Pred Sent:   a a a tax a tax to to tax tax a a a a to to to tax tax to a a a to to to to to to tax a a a a to to to to to to a a a
| epoch   0 | [  800/ 1337] | lr 5.00 | ms/batch 217.04 | loss  4.33 | val loss  3.73 | ppl    76.14
Pred

Pred Sent:  <s> appealinglevel appealing appealing appealing appealing appealing appealing appealing<s> 0level explain appealing appealinghm debris refugee appealing<s>idlevel appealing appealing appealing appealing appealing appealing appealing, appealinglevel appealing appealing setting original original expire appealing<s> classesmart
| epoch   4 | [ 1000/ 1337] | lr 5.00 | ms/batch 229.39 | loss  3.67 | val loss  3.53 | ppl    39.17
Pred Sent:  </s> visiting</s></s></s></s></s> joke joke joke</s></s></s></s></s></s>hm joke joke joke</s></s></s></s></s></s> joke joke joke joke</s></s></s></s></s></s></s> joke joke joke</s></s></s>
| epoch   4 | [ 1200/ 1337] | lr 5.00 | ms/batch 228.97 | loss  3.68 | val loss  3.53 | ppl    39.84
Pred Sent:   has aguero southeastairvivvivviv lambvivvivs 0 commission pipeline novakvivreadvivvivvivs commissionviv pipeline sessionviv novakviv session sessionset poppy commission commission appealing novak immune studio appealings commissionlevel
| epoch

Pred Sent:   tobillion toair asking appealing appealing a a appealing a a a a a a65 a to users a a a to a to to a to appealing a a a write a a a wheat wheat appealing a a a
| epoch   8 | [  800/ 1337] | lr 5.00 | ms/batch 220.11 | loss  3.42 | val loss  3.69 | ppl    30.64
Pred Sent:  <unk> chemical 75 sex social hand appealing stuff chemical possible say breach ipswich pipelinequimore individuals individuals individuals running<s> rush dropped youngestrie running spectator spectator running rush have superintendent rush rush rush rush rush rush rushrie world merely users
| epoch   8 | [ 1000/ 1337] | lr 5.00 | ms/batch 220.93 | loss  3.43 | val loss  3.75 | ppl    30.95
Pred Sent:   her santiagorich lit quest hand tiny 3,000 boxing 3,000 ‘ 3,000 activist treat historian expirehouse joke slur nhs them dogs airwick fun eric upon uponari upond itself upon brand rack fun seeking jokewing fun</s> bloom ties
| epoch   8 | [ 1200/ 1337] | lr 5.00 | ms/batch 220.43 | loss  3.45 | val loss  3.

Pred Sent:  ed hipbat valentine norway fbi lit penalties deeply novak<s> objects bad lit novak novak novak novak beef users world novak sniff investigating novak juan m 3,000 novak asking them romance write write took strategytan novak cuddle novak atan hal
| epoch  12 | [  600/ 1337] | lr 5.00 | ms/batch 220.29 | loss  3.29 | val loss  3.84 | ppl    26.97
Pred Sent:   short cautious bloomfriend point pen 09:4 double asking asking – objects deciding horrifying asking include cautious include asking asking have wheat cancellationping leak 09:4 leak bieber bieber appealing adent wheat wheat essential pen god somehow cautious to to traditional transfer
| epoch  12 | [  800/ 1337] | lr 5.00 | ms/batch 220.56 | loss  3.19 | val loss  3.81 | ppl    24.17
Pred Sent:  johnrichrich includezan lambert pleased treat totally slice being cautious carriageleg deemed include65billion items users ""shaw vitamin youngest qpr jaguar carriage include include carriage "" dragon usersourie watched users in

Pred Sent:  john panamaiest lit afterwards trace lit fianc65 lit them litgrand litgrand novak want novak demonstration lit them placing depression lit 2002 depression juice beef lit lit would lit pra ricaiest want depression novakrich novakd 08:4 started
| epoch  16 | [  200/ 1337] | lr 5.00 | ms/batch 224.35 | loss  3.37 | val loss  3.84 | ppl    29.14
Pred Sent:   mean scariest keep raw dish keep parapet doubt down author novak rawship novak novak novak remain castle world novakarian maj keepkel placed nor novak novak world sheriff novak novak novak novak novak novak passport novak world keep novak
| epoch  16 | [  400/ 1337] | lr 5.00 | ms/batch 220.60 | loss  3.05 | val loss  3.98 | ppl    21.09
Pred Sent:  ed lit name cafe lit nose five mosspet novak in reached in lit affidavit novak novak novak beef five world five affordableiyapet five m beefpet rupert president qualifier gates novak novak in seekingpet turned novak inlan in
| epoch  16 | [  600/ 1337] | lr 5.00 | ms/batch 220.2

Pred Sent:   penalties scar removing sto objects dish scar penaltiespet 3,000 will objects include franchise volunteers novak billy jones itemswan<s> santiago rupertpet novak approximately deciding rupert include lit, 6-2 novak novak franchise mainly constant novak ca persistent<unk>iest 999
| epoch  19 | [ 1000/ 1337] | lr 5.00 | ms/batch 219.70 | loss  2.87 | val loss  4.18 | ppl    17.72
Pred Sent:   her scar decided charm toyota however valentine properties portion nonprofit being jeff however aim however etihad soul80crib individuals being depression aim however however however buckingham norhousecrib their qualifier howeverwan hiding pen buckingham buckingham packaging alexandra. other however
| epoch  19 | [ 1200/ 1337] | lr 5.00 | ms/batch 219.96 | loss  2.92 | val loss  4.08 | ppl    18.63
