In [None]:
import math
import os
import pickle
import random
import string
import time
import unicodedata

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

# A little bit of something to process Korean
import hanja
from hanja import hangul

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fix_unicode(s):
    """Return ``s`` in Unicode Normalization Form C (canonical composition).

    NFC composes decomposed sequences into their precomposed form, e.g. the
    conjoining jamo U+1100 U+1161 become the syllable U+AC00, so strings that
    look the same also compare and tokenise the same.
    """
    # normalize() already returns a str; re-joining it char by char was a no-op.
    return unicodedata.normalize('NFC', s)

# Characters dropped during cleaning: standalone Hangul jamo letters
# (consonants, then vowels) that are not part of a composed syllable,
# followed by tab and space.
stopwords = (
    list('ㄱㄲㄳㄴㄷㄸㄹㄺㄻㄼㄽㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ')
    + list('ㅏㅐㅑㅓㅔㅕㅖㅗㅘㅙㅛㅜㅝㅞㅠㅡㅢㅣㅻㆍ')
    + ['\t', ' ']
)

In [None]:
class MLMData(Dataset):
    """Character-level masked-language-model dataset.

    Each item of ``data`` is a sequence of characters. A fresh BERT-style
    corruption is drawn on every access. Each character is selected with
    probability 0.15. A selected character is replaced by MASK (80%), by a
    random non-special vocabulary id (10%), or left unchanged (10%).

    Args:
        data: indexable collection of character sequences.
        char2idx: mapping from character (plus the special keys 'PAD', 'UNK',
            'MASK') to integer id. Defaults to the notebook-level ``char2idx``.
        max_len: fixed output length. Longer sequences are truncated and
            shorter ones are right-padded with PAD. Defaults to the
            notebook-level ``max_len``.
    """

    # Keys that must never be drawn as the "random token" replacement.
    _SPECIAL_TOKENS = ('PAD', 'UNK', 'MASK')

    def __init__(self, data, char2idx=None, max_len=None):
        self.data = data
        # Explicit arguments win. Otherwise the notebook globals are snapshotted
        # once, so the dataset carries its own configuration instead of
        # re-reading module globals on every access. ``MLMData(data)`` still works.
        module_globals = globals()
        self.char2idx = char2idx if char2idx is not None else module_globals['char2idx']
        self.max_len = max_len if max_len is not None else module_globals['max_len']
        # Candidate ids for the random-replacement branch, computed once.
        # Excluding the specials means the input never receives a spurious
        # PAD/MASK/UNK from this branch.
        self._random_ids = [idx for tok, idx in self.char2idx.items()
                            if tok not in self._SPECIAL_TOKENS]
        if not self._random_ids:
            # Degenerate vocabulary with only special tokens: fall back to all ids.
            self._random_ids = list(self.char2idx.values())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, positions, labels)`` for item ``idx``.

        ``token_ids`` and ``labels`` are int64 tensors of length ``max_len``.
        ``positions`` is an int64 numpy array of the same length; the default
        DataLoader collate converts it to a tensor. Padded slots hold PAD in
        ``token_ids``/``labels`` and 0 in ``positions``.
        """
        replaced, label = self.pick_random(self.data[idx])
        replaced = replaced[:self.max_len]
        label = label[:self.max_len]

        n_real = len(replaced)
        n_pad = self.max_len - n_real
        pad_id = self.char2idx['PAD']

        position = np.zeros(self.max_len, dtype=np.int64)
        position[:n_real] = np.arange(n_real)
        replaced = replaced + [pad_id] * n_pad
        label = label + [pad_id] * n_pad

        return torch.tensor(replaced), position, torch.tensor(label)

    def pick_random(self, sample):
        """Apply BERT-style masking to ``sample``.

        Returns:
            ``(replaced, label)``: two lists of ids, each as long as ``sample``.
            ``replaced`` is the corrupted input. ``label`` holds the original id
            (UNK for out-of-vocabulary chars) at selected positions and the PAD
            id everywhere else, so unselected positions can be ignored by the loss.
        """
        pad_id = self.char2idx['PAD']
        unk_id = self.char2idx['UNK']
        mask_id = self.char2idx['MASK']

        replaced, label = [], []
        for char in sample:
            char_id = self.char2idx.get(char, unk_id)
            if random.random() <= 0.15:
                toss = random.random()
                if toss <= 0.80:
                    replaced.append(mask_id)
                elif toss <= 0.90:
                    replaced.append(random.choice(self._random_ids))
                else:
                    replaced.append(char_id)
                label.append(char_id)
            else:
                replaced.append(char_id)
                label.append(pad_id)

        return replaced, label

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position embedding stored as a frozen lookup table.

    Row ``pos`` of the table is
    ``[sin(pos*w_0), cos(pos*w_0), sin(pos*w_1), cos(pos*w_1), ...]``
    with ``w_i = 10000 ** (-2i / embedding_dim)``.

    Args:
        embedding_dim: width of each position vector (odd widths are supported).
        max_len: number of rows in the table. Indices passed to ``forward``
            must lie in ``[0, max_len)``.
    """

    def __init__(self, embedding_dim, max_len):
        super(PositionalEmbedding, self).__init__()

        pe = torch.zeros(max_len, embedding_dim).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)        # (max_len, 1)
        div_term = (torch.arange(0, embedding_dim, 2).float() *
                    -(math.log(10000.0) / embedding_dim)).exp()          # (ceil(dim/2),)

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one fewer odd column than there
        # are frequencies; slicing keeps the shapes aligned. This is a no-op
        # for even dims.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])

        # freeze=True: the table is not updated by the optimizer.
        self.pos_embed = nn.Embedding.from_pretrained(pe, freeze=True)

    def forward(self, x):
        """Look up position vectors for integer tensor ``x``; output shape is ``x.shape + (embedding_dim,)``."""
        return self.pos_embed(x)

In [None]:
class TokenEmbedding(nn.Module):
    """Sum of a learned token embedding and a fixed positional embedding, then dropout.

    Args:
        vocab_size: number of token ids.
        embedding_dim: width of both embeddings.
        max_len: number of rows in the positional table.

    Token id 0 is the embedding's ``padding_idx``: its vector starts at zero
    and receives no gradient.
    """

    def __init__(self, vocab_size, embedding_dim, max_len):
        super().__init__()
        token_table = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.token = token_table
        self.position = PositionalEmbedding(embedding_dim=token_table.embedding_dim,
                                            max_len=max_len)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x, p):
        """Embed token ids ``x`` at position ids ``p``; returns ``x.shape + (embedding_dim,)``."""
        summed = self.token(x) + self.position(p)
        return self.dropout(summed)

In [None]:
class Rumble(nn.Module):
    """Character-level Transformer encoder for masked-language modelling.

    ``forward(x, p)`` takes token ids ``x`` and position ids ``p``, both of
    shape (batch, seq_len). It returns log-probabilities of shape
    (batch, seq_len, vocab_size), which is the input ``nn.NLLLoss`` expects.

    Args:
        vocab_size, embedding_dim, max_len, num_layers, num_heads: model size.
        pad_idx: token id treated as padding. Such positions are excluded as
            attention keys. Defaults to 0, the padding_idx used by TokenEmbedding.
    """

    def __init__(self, vocab_size, embedding_dim, max_len, num_layers, num_heads, pad_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_len = max_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = TokenEmbedding(vocab_size=self.vocab_size,
                                        embedding_dim=self.embedding_dim,
                                        max_len=self.max_len)

        # batch_first=True because the embeddings are (batch, seq, dim). The
        # default (seq, batch, dim) layout made self-attention mix different
        # samples at the same position instead of positions within one sample.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                   nhead=num_heads,
                                                   dim_feedforward=1024,
                                                   dropout=0.1,
                                                   activation='gelu',
                                                   batch_first=True)
        layer_norm = nn.LayerNorm(embedding_dim)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=num_layers,
                                             norm=layer_norm)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, p):
        # True where the input token is padding; those keys are ignored by attention.
        pad_mask = x == self.pad_idx
        # If every key in a row were masked, the attention softmax would be NaN.
        # Such all-padding rows carry no real tokens, so leave them unmasked.
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)

        h = self.embedding(x, p)
        h = self.encoder(h, src_key_padding_mask=pad_mask)
        logits = self.linear(h)
        return F.log_softmax(logits, dim=-1)

In [None]:
# Assuming data is a list of names.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted data.pickle.
with open('data.pickle', 'rb') as f:
    data_raw = pickle.load(f)

# NFC-normalise each name and strip every whitespace character inside it.
data = [''.join(fix_unicode(x).split()) for x in data_raw]

# Per name, keep characters that are ASCII-printable or accepted by
# hangul.is_hangul, then drop stopword characters. Each entry becomes a list of chars.
data_processed = [
    [ch for ch in name
     if ((ch in string.printable) or hangul.is_hangul(ch)) and ch not in stopwords]
    for name in data
]

# Character vocabulary = every character that survives cleaning.
# (This was previously built from the undefined name `stores_filtered`, which
# raised NameError on a fresh kernel.)
vocab = {ch for name in data_processed for ch in name}

# Special tokens take ids 0..2. PAD must stay 0: the token embedding's
# padding_idx and the loss's ignore_index are both 0.
char2idx = {'PAD': 0, 'UNK': 1, 'MASK': 2}
char2idx.update({ch: i for i, ch in enumerate(sorted(vocab), start=len(char2idx))})

with open('char2idx.pickle', 'wb') as f:
    pickle.dump(char2idx, f)

In [None]:
vocab_size = len(char2idx)
embedding_dim = 256
max_len = 40
num_layers = 8
num_heads = 8
batch_size = 500
lr = 1e-4
seed = 42

# Seed every RNG involved in masking, weight init and shuffling so that
# Restart & Run All reproduces the run. torch.manual_seed also fixes the
# base seed that DataLoader derives its worker seeds from.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

train_dataset = MLMData(data_processed)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          # Never ask for more workers than there are cores.
                          num_workers=min(12, os.cpu_count() or 1),
                          shuffle=True,
                          # Pinned host memory only helps host-to-GPU copies.
                          pin_memory=torch.cuda.is_available())
rumble = Rumble(vocab_size=vocab_size, embedding_dim=embedding_dim, max_len=max_len,
                num_layers=num_layers, num_heads=num_heads).to(device)

In [None]:
optimizer = optim.Adam(rumble.parameters(), lr=lr)
# Targets equal to the PAD id are not prediction targets (padding and
# unselected characters), so tie ignore_index to that id instead of a bare 0.
loss_func = nn.NLLLoss(ignore_index=char2idx['PAD'])

In [None]:
num_epoch = 50
log_every = 250      # iterations between loss prints
save_every = 3000    # iterations between checkpoints
plot_train = []
beg = time.time()
tot_length = len(train_loader)
for epoch in range(num_epoch):
    rumble.train()

    train_loss = 0
    for e, (X, p, y) in enumerate(train_loader):
        X, p, y = X.to(device), p.to(device), y.to(device)

        optimizer.zero_grad()
        y_hat = rumble(X, p)
        # NLLLoss wants the class dimension second: (batch, vocab, seq).
        loss = loss_func(y_hat.transpose(1, 2), y)

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        iter_num = e + epoch * tot_length
        # e + 1 batches have been summed this epoch. Dividing by e overstated
        # the mean and raised ZeroDivisionError whenever a log step fell on the
        # first batch of an epoch (e == 0).
        running_avg = train_loss / (e + 1)

        if iter_num % log_every == 1:
            print('loss at: {}, {:.3f}, time so far: {}'.format(
                iter_num, running_avg, time.time() - beg))
            plot_train.append(running_avg)

        if iter_num % save_every == save_every - 1:
            torch.save(rumble.state_dict(),
                       'rumble_{e}_iter.pt'.format(e=iter_num))