In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import math
import os
import pickle
import random
import time

import numpy as np

# Korean text processing helpers: Unicode normalization and Hangul character checks
import unicodedata
import string
import hanja
from hanja import hangul

def fix_unicode(s):
    """Return ``s`` in Unicode NFC (canonical composition) form.

    Decomposed sequences are composed, e.g. the conjoining jamo U+1100 U+1161
    become the precomposed syllable U+AC00, so identical-looking text always
    maps to the same code points before character indexing.
    """
    # unicodedata.normalize already returns a str; re-joining its characters
    # one by one (as before) only produced an identical copy.
    return unicodedata.normalize('NFC', s)

# Characters removed from every training text: bare compatibility jamo
# (consonants first, then vowels), two further compatibility letters, and
# the tab / space whitespace characters.
# NOTE(review): some compatibility jamo are not listed (e.g. vowels ㅒ ㅚ ㅟ,
# compound finals ㄶ ㄾ ㄿ ㅀ ㅄ) -- confirm whether that is intentional.
stopwords = (
    list('ㄱㄲㄳㄴㄷㄸㄹㄺㄻㄼㄽㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ')
    + list('ㅏㅐㅑㅓㅔㅕㅖㅗㅘㅙㅛㅜㅝㅞㅠㅡㅢㅣ')
    + ['ㅻ', 'ㆍ']
    + ['\t', ' ']
)

In [None]:
class data2Dataset(Dataset):
    """Character-level dataset that yields fixed-length index tensors.

    Each item is ``(token_ids, label_id, positions)``:

    * ``token_ids`` -- exactly ``max_len`` int64 character indices. Longer
      inputs are truncated; shorter ones are right-padded with the ``'PAD'``
      token. Characters absent from ``char2idx`` map to ``char2idx['UNK']``.
    * ``label_id`` -- ``label2idx[label]`` as a 0-d int64 tensor.
    * ``positions`` -- ``0 .. n-1`` for the ``n`` kept characters, followed
      by zeros in the padded slots.

    Looks up the notebook-level dicts ``char2idx`` and ``label2idx`` at item
    time, so both must exist before the first item is fetched.
    """

    def __init__(self, data, label, max_len):
        self.X = data
        self.y = label
        self.max_len = max_len

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        raw_text = self.X[idx]
        raw_label = self.y[idx]

        # Truncate first; whatever survives is real text, the rest is padding.
        kept = list(raw_text)[:self.max_len]
        n_real = len(kept)
        n_pad = self.max_len - n_real

        padded = kept + ['PAD'] * n_pad
        id_list = [char2idx.get(ch, char2idx['UNK']) for ch in padded]

        positions = F.pad(torch.arange(0, n_real), (0, n_pad), 'constant', 0)
        label_id = label2idx[raw_label]

        token_tensor = torch.tensor(id_list).long()
        label_tensor = torch.tensor(label_id).long()
        return token_tensor, label_tensor, positions.long()

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position table wrapped in a frozen ``nn.Embedding``.

    Row ``pos`` holds ``sin(pos * w_i)`` in even columns ``2i`` and
    ``cos(pos * w_i)`` in odd columns ``2i + 1``, where
    ``w_i = 10000 ** (-2i / embedding_dim)``.

    Args:
        embedding_dim: width of each position vector (odd widths supported).
        max_len: number of table rows; valid indices are ``0 .. max_len-1``.
    """

    def __init__(self, embedding_dim, max_len=100):
        super(PositionalEmbedding, self).__init__()

        pe = torch.zeros(max_len, embedding_dim).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, embedding_dim, 2).float() *
                    -(math.log(10000.0) / embedding_dim)).exp()

        angles = position * div_term  # (max_len, ceil(embedding_dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer odd column than there
        # are frequencies; slice so the assignment shapes agree (previously
        # this raised a shape-mismatch error).
        pe[:, 1::2] = torch.cos(angles[:, :embedding_dim // 2])

        # Attribute name is part of the state_dict key -- keep it stable so
        # existing checkpoints still load.
        self.pos_embed = nn.Embedding.from_pretrained(pe, freeze=True)

    def forward(self, x):
        """Look up position vectors for the integer index tensor ``x``.

        Returns a tensor of shape ``x.shape + (embedding_dim,)``.
        """
        return self.pos_embed(x)

In [None]:
class TokenEmbedding(nn.Module):
    """Learned token embedding plus fixed sinusoidal position embedding.

    ``forward(x, p)`` takes token indices ``x`` and position indices ``p``
    of the same shape and returns ``dropout(token(x) + position(p))``, with
    dropout probability 0.1.

    The submodule names ``token`` / ``position`` / ``dropout`` are part of
    the state_dict keys and must not change.
    """

    def __init__(self, vocab_size, embedding_dim, max_len):
        super().__init__()
        self.token = nn.Embedding(vocab_size, embedding_dim)
        self.position = PositionalEmbedding(embedding_dim=embedding_dim,
                                            max_len=max_len)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x, p):
        combined = self.token(x) + self.position(p)
        return self.dropout(combined)

In [None]:
class Rumble(nn.Module):
    """Transformer encoder over character tokens with a per-token vocab head.

    Inputs are batch-first: ``x`` (token ids) and ``p`` (position ids) are
    ``(batch, seq)`` index tensors, as the DataLoader over ``data2Dataset``
    produces. ``forward`` returns log-probabilities of shape
    ``(batch, seq, vocab_size)``.

    Submodule names (``embedding``, ``encoder``, ``linear``) are state_dict
    keys and are kept unchanged so existing checkpoints load.
    """

    def __init__(self, vocab_size, embedding_dim, max_len, num_layers, num_heads):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_len = max_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.embedding = TokenEmbedding(vocab_size=self.vocab_size,
                                        embedding_dim=self.embedding_dim,
                                        max_len=self.max_len)

        # No batch_first argument: the layer expects (seq, batch, dim) input.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                   nhead=num_heads,
                                                   dim_feedforward=1024,
                                                   dropout=0.1,
                                                   activation='gelu')
        layer_norm = nn.LayerNorm(embedding_dim)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=num_layers,
                                             norm=layer_norm)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, p):
        x = self.embedding(x, p)  # (batch, seq, dim)
        # The encoder is sequence-first. Feeding (batch, seq, dim) directly
        # made self-attention mix different samples at the same position
        # instead of the characters within one sample.
        # NOTE(review): a checkpoint trained with the old layout was optimised
        # under batch-as-sequence attention -- re-validate or retrain it.
        x = self.encoder(x.transpose(0, 1)).transpose(0, 1)
        x = self.linear(x)
        return F.log_softmax(x, dim=-1)

In [None]:
class Soundwave(nn.Module):
    """Linear classification head over a flattened encoder feature map.

    Args:
        vocab_size: stored on the module; not used by the layer itself.
        embedding_dim: feature width at each position.
        max_len: number of positions; each sample must hold exactly
            ``embedding_dim * max_len`` values.
        num_labels: number of output classes. ``None`` (default) falls back
            to the notebook-level global ``label_size``, which is what the
            class previously read implicitly.

    ``forward`` returns ``(batch, num_labels)`` log-probabilities.
    """

    def __init__(self, vocab_size, embedding_dim, max_len, num_labels=None):
        super().__init__()
        if num_labels is None:
            num_labels = label_size  # backward-compatible implicit global
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.max_len = max_len
        self.num_labels = num_labels
        self.linear = nn.Linear(embedding_dim * max_len, num_labels)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs such as a
        # transposed encoder output; for contiguous input it is identical.
        x = x.reshape(-1, self.embedding_dim * self.max_len)
        x = self.linear(self.dropout(x))
        return F.log_softmax(x, dim=-1)

In [None]:
with open('training_pair.pickle', 'rb') as pp:
    data = pickle.load(pp)

texts = [''.join(fix_unicode(x).split()) for y, x in data]
text_filtered = []
for text in texts:
    text = [x for x in text if (x in string.printable) or (hangul.is_hangul(x))]
    text = [x for x in text if x not in stopwords]
    text_filtered.append(text)

labels = [fix_unicode(y) for y, x in data]
label2idx = {x: c for c, x in enumerate(sorted(labels))}
idx2label = {c: x for c, x in enumerate(sorted(labels))}

with open('char2idx_wMASK.pickle', 'rb') as pp:
    char2idx = pickle.load(pp)

In [None]:
vocab_size = len(char2idx)
label_size = len(idx2label)
embedding_dim = 256
max_len = 40
num_layers = 8
num_heads = 8
batch_size = 32
lr = 2e-5

train_dataset = data2Dataset(text_filtered, labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                          num_workers=12, shuffle=True, pin_memory=True)

rumble = Rumble(vocab_size=vocab_size, embedding_dim=embedding_dim, max_len=max_len, 
                     num_layers=num_layers, num_heads=num_heads).to(device)
rumble_path = '/path/to/rumble/model/file.pt'
rumble.load_state_dict(torch.load(rumble_path))

soundwave = Soundwave(vocab_size=vocab_size, embedding_dim=embedding_dim, 
                      max_len=max_len).to(device)
optimizer = optim.Adam(soundwave.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss_func = nn.NLLLoss(ignore_index=0)

In [None]:
num_epoch = 2
plot_train = []
beg = time.time()
tot_length = len(train_loader)

# The pretrained encoder is only a feature extractor here (the optimizer holds
# soundwave's parameters only): run it in eval mode so its dropout is off, and
# under no_grad so no autograd graph is built for it.
rumble.eval()
soundwave.train()
for epoch in range(num_epoch):
    train_loss = 0
    for e, d in enumerate(train_loader):
        X, y, p = d
        X, y, p = X.to(device), y.to(device), p.to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            feature = rumble.embedding(X, p)  # (batch, seq, dim)
            # rumble.encoder is sequence-first (built without batch_first),
            # so transpose in and back out; contiguous() keeps view-based
            # heads working on the result.
            feature = rumble.encoder(feature.transpose(0, 1))
            feature = feature.transpose(0, 1).contiguous()

        y_hat = soundwave(feature)
        loss = loss_func(y_hat, y)

        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        iter_num = e + epoch * tot_length

        if iter_num % 250 == 1:
            # e + 1 batches have been summed so far this epoch; dividing by e
            # over-stated the mean and raised ZeroDivisionError when e == 0.
            avg_loss = train_loss / (e + 1)
            print('loss at: {}, {:.3f}, time so far: {}'.format(
                iter_num, avg_loss, time.time() - beg))
            torch.save(soundwave.state_dict(),
                       'Soundwave_{iter}_iter.pt'.format(iter=iter_num))
            plot_train.append(avg_loss)

    scheduler.step()