In [1]:
import itertools
import json
import os
import pickle
import re

import easydict
import MeCab
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from nltk import sent_tokenize
from torch.utils.data import Dataset, DataLoader

In [2]:
FILE_PATH = 'D:\\data\\text\\news-articles\\kbanker_articles_subtitles.csv'
CONFIG_PATH = 'config.json'
SEED = 42

# Fall back to the CPU so the notebook still runs on machines without a CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Seed torch so model initialisation is reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Hyper-parameters used below (window_size, batch_size, embedding_dim, h_dim)
# come from config.json. Read it as UTF-8 explicitly: the platform default on
# Windows is the locale codepage, not UTF-8.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    args = easydict.EasyDict(json.load(f))

In [3]:
# --- Text pre-processing ---------------------------------------------------
# Ordered (compiled pattern, replacement) pairs applied by pre_process_raw_article.
# Raw strings avoid the invalid-escape-sequence warnings ('\(', '\s', '\.', ...)
# that plain string literals trigger (DeprecationWarning, SyntaxWarning on 3.12+).
_ARTICLE_REPLACEMENTS = [
    # Curly double quotes -> straight double quote.
    (re.compile(r'[“”]'), '"'),
    # Curly single quotes -> straight single quote.
    (re.compile(r'[‘’]'), "'"),
    # Drop parenthesised asides, e.g. "(note)".
    (re.compile(r'\([^)]*\)'), ''),
    # Replace every char that is not a Hangul syllable, Latin letter, digit,
    # '.', '?', '!', a quote, or whitespace with a space.
    (re.compile(r'''[^가-힣'"A-Za-z0-9.\s?!]'''), ' '),
    # Put a space after a '.' that is followed by a non-digit, so decimals such
    # as "3.5" stay intact. A trailing '.' at the end of the text is left alone.
    # NOTE(review): the original pattern also had a leading look-ahead
    # "(?=[^0-9])" before "\." which is a no-op; if the intent was "not preceded
    # by a digit", that needs a look-behind "(?<=[^0-9])" instead.
    (re.compile(r'\.(?=[^0-9])'), '. '),
    # Collapse runs of two or more whitespace characters into one space.
    (re.compile(r'\s\s+'), ' '),
]


def pre_process_raw_article(article):
    """Normalise a raw news article before sentence splitting.

    Args:
        article: str, the raw article text.

    Returns:
        str: the article with quotes normalised, parenthesised text removed,
        disallowed characters replaced by spaces, a space inserted after
        sentence periods, and whitespace runs collapsed.
    """
    for pattern, replacement in _ARTICLE_REPLACEMENTS:
        article = pattern.sub(replacement, article)
    return article

def _mecab_tagger():
    """Return a single shared MeCab.Tagger, creating it on first use.

    Building a Tagger loads the MeCab dictionary, so constructing one per
    sentence is needlessly slow.
    """
    tagger = getattr(_mecab_tagger, '_instance', None)
    if tagger is None:
        tagger = MeCab.Tagger()
        _mecab_tagger._instance = tagger
    return tagger


def mecab_tokenize(sentence):
    """Split `sentence` into MeCab surface tokens.

    Args:
        sentence: str to tokenize.

    Returns:
        list of str: the surface form (text before the first tab) of every
        token line in MeCab's output, in order. Blank lines and the
        terminating 'EOS' line are skipped.
    """
    tokens = []
    for line in _mecab_tagger().parse(sentence).split('\n'):
        # Tolerate CRLF output.
        line = line.rstrip('\r')
        # Compare the whole line. A substring test ('EOS' in line) would also
        # drop genuine tokens such as "VIDEOS".
        if line == '' or line == 'EOS':
            continue
        # Output lines look like "surface\tfeature,feature,...". Split on the
        # tab, not on commas, so a literal "," token is kept intact.
        tokens.append(line.split('\t', 1)[0])
    return tokens

In [4]:
class NLPCorpusDataset(Dataset):
    """Next-word-prediction dataset built from a CSV of news articles.

    Each sample is a pair ``(context, target)``:
    - ``context`` is a list of ``window_size`` consecutive token ids from one sentence.
    - ``target`` is a one-element list holding the id of the token that follows them.
    """

    def __init__(self, csv_file, root_dir, window_size=None):
        """
        Args:
            csv_file (string): Path to a CSV file with an ``article`` column.
                Rows whose article is missing are skipped.
            root_dir (string): Stored as ``self.root_dir``; not otherwise used.
            window_size (int, optional): Number of context tokens per sample.
                Defaults to ``args.window_size`` from the notebook config.
        """
        if window_size is None:
            window_size = args.window_size
        self.root_dir = root_dir
        self.window_size = window_size

        corpus = self._tokenize_articles(csv_file)

        # Sort the vocabulary so word ids are the same on every run. Iterating
        # a set of str depends on PYTHONHASHSEED and would reshuffle the ids.
        vocab = sorted(set(itertools.chain.from_iterable(corpus)))
        self.word_to_idx = {word: idx for idx, word in enumerate(vocab)}
        self.idx_to_word = dict(enumerate(vocab))

        # Build (context window, next token) training pairs.
        self.x = []
        self.y = []
        for sentence in corpus:
            ids = [self.word_to_idx[word] for word in sentence]
            # Slide a window over the sentence; the token right after each
            # window is its label. Sentences with <= window_size tokens yield nothing.
            for start in range(len(ids) - window_size):
                self.x.append(ids[start:start + window_size])
                self.y.append([ids[start + window_size]])

    @staticmethod
    def _tokenize_articles(csv_file):
        """Read, clean, sentence-split and MeCab-tokenize every article in `csv_file`."""
        articles = pd.read_csv(csv_file, encoding='utf-8')['article'].dropna().values
        cleaned = (pre_process_raw_article(article) for article in articles)
        sentences = itertools.chain.from_iterable(sent_tokenize(article) for article in cleaned)
        return [mecab_tokenize(sentence) for sentence in sentences]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Accept a scalar tensor index as well as a plain int.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.x[idx], self.y[idx]

In [None]:
# Tokenising the whole corpus is slow. Only rebuild the dataset when no pickled
# copy exists; a later cell loads it from that file.
DATASET_CACHE_PATH = 'kbanker_nlp_dataset.pkl'
if not os.path.exists(DATASET_CACHE_PATH):
    nlp_dataset = NLPCorpusDataset(csv_file=FILE_PATH, root_dir='.')

In [5]:
import pickle

# Local cache of the tokenised dataset. pickle.load can execute arbitrary code,
# so only ever point this at a file produced by this notebook.
DATASET_CACHE_PATH = 'kbanker_nlp_dataset.pkl'
if os.path.exists(DATASET_CACHE_PATH):
    with open(DATASET_CACHE_PATH, 'rb') as f:
        nlp_dataset = pickle.load(f)
else:
    # First run: persist the dataset built by the previous cell.
    with open(DATASET_CACHE_PATH, 'wb') as f:
        pickle.dump(nlp_dataset, f)

In [6]:
class EmbeddingModule(nn.Module):
    """Feed-forward next-word model over a fixed window of context tokens.

    The concatenated context embeddings feed two paths whose logits are summed:
    - a tanh hidden layer (``linear1`` -> tanh -> ``linear2``);
    - a direct linear connection to the output (``motorway``).

    This layout resembles Bengio et al.'s neural probabilistic language model.
    """

    def __init__(self, vocab_size, embed_dim, h_dim, window_size=None):
        """
        Args:
            vocab_size: number of distinct token ids (= number of output classes).
            embed_dim: size of each token embedding.
            h_dim: width of the tanh hidden layer.
            window_size: context tokens per sample; defaults to ``args.window_size``.
        """
        super(EmbeddingModule, self).__init__()
        if window_size is None:
            window_size = args.window_size
        self.window_size = window_size
        context_dim = embed_dim * window_size

        self.embedding = nn.Embedding(vocab_size, embed_dim).float()
        self.linear1 = nn.Linear(context_dim, h_dim)
        self.linear2 = nn.Linear(h_dim, vocab_size)
        # Direct (skip) connection from the concatenated embeddings to the logits.
        self.motorway = nn.Linear(context_dim, vocab_size)

    def forward(self, x):
        """Return unnormalised next-token logits.

        Args:
            x: LongTensor of token ids with shape (batch, window_size). A single
                1-D window of shape (window_size,) is treated as a batch of one.

        Returns:
            FloatTensor of shape (batch, vocab_size).
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # (batch, window, embed) -> (batch, window * embed). Flattening per sample
        # (instead of reshaping to args.batch_size) lets any batch size through,
        # including the final, smaller batch.
        context = self.embedding(x).flatten(start_dim=1)
        hidden = torch.tanh(self.linear1(context))
        return self.linear2(hidden) + self.motorway(context)

In [9]:
def collate_fn(data):
    """Transpose a list of ``(context, target)`` samples into two parallel tuples.

    Args:
        data: list of 2-item samples, e.g. as returned by
            ``NLPCorpusDataset.__getitem__``.

    Returns:
        tuple ``(contexts, targets)``: the first and second item of every
        sample, in batch order. No tensor conversion is done here.
    """
    columns = tuple(zip(*data))
    contexts, targets = columns
    return contexts, targets

dataloader = DataLoader(nlp_dataset, batch_size=args.batch_size,
                        shuffle=False, num_workers=0, collate_fn=collate_fn)

model = EmbeddingModule(len(nlp_dataset.word_to_idx),
                        args.embedding_dim, args.h_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

LOG_EVERY = 100   # report the loss every LOG_EVERY batches
MAX_STEP = 1000   # stop after the batch with this index (MAX_STEP + 1 updates)

model.train()
for epoch in range(1):
    print(f'Epoch {epoch}')
    for i, (seqs, labels) in enumerate(dataloader):
        x = torch.tensor(seqs, dtype=torch.long, device=device)
        # Each label is a one-element list, giving shape (batch, 1). Flatten with -1
        # rather than args.batch_size so a final, smaller batch also works.
        y = torch.tensor(labels, dtype=torch.long, device=device).view(-1)
        y_pred = model(x)
        loss = criterion(y_pred, y)

        if i % LOG_EVERY == LOG_EVERY - 1:
            print(i, loss.item())
            # Spot-check the embedding of token id 0. With plain SGD it only
            # changes on steps whose batch contains id 0.
            with torch.no_grad():
                print(model.embedding.weight[:1])

        optimizer.zero_grad()
        # The graph is rebuilt every step, so retain_graph=True is unnecessary
        # and would only keep intermediate buffers alive.
        loss.backward()
        optimizer.step()

        if i == MAX_STEP:
            break

Epoch 0
99 11.532383918762207
tensor([[ 0.7010,  1.0316, -1.0672,  0.3053, -1.1856,  0.3842, -0.5599,  0.4130,
         -0.8032,  1.1080,  1.2751, -1.6904,  0.2204,  0.7843, -0.6552,  0.5292,
          0.8784,  0.3009,  0.7784,  1.0987, -1.3345,  0.6207,  2.1669, -1.0885,
          1.4550, -1.0350,  0.0527, -0.7317,  0.2984, -0.6801, -0.9115, -0.6598,
          0.2099,  1.1501, -0.6335, -1.1527, -0.3100, -0.9189, -1.1825, -2.6594,
         -0.6708,  0.6252, -0.0678,  0.3141,  0.4518,  0.0474,  1.3313, -0.4428,
         -0.5219,  1.7468,  1.8137,  0.8343,  1.4478, -1.6874,  0.2213, -1.3763,
         -0.6182,  0.2024, -0.4245, -2.2423, -0.0670,  0.9719,  0.9227, -0.3641,
         -0.7674,  0.3908, -1.2225, -0.0809, -0.0822,  0.0715,  0.7930, -1.2627,
          0.1679,  0.2641, -0.2950,  1.5672, -0.5050, -0.8829,  1.0967,  2.0495,
         -1.4386, -0.1893,  0.3223, -0.1637,  0.8473, -1.4654, -1.2605, -0.3227,
          1.0614, -1.9751, -0.8907, -1.5332, -0.0130, -0.5478,  0.8994, -0.2433

899 10.517586708068848
tensor([[ 0.7010,  1.0316, -1.0672,  0.3053, -1.1856,  0.3842, -0.5599,  0.4130,
         -0.8032,  1.1080,  1.2751, -1.6904,  0.2204,  0.7843, -0.6552,  0.5292,
          0.8784,  0.3009,  0.7784,  1.0987, -1.3345,  0.6207,  2.1669, -1.0885,
          1.4550, -1.0350,  0.0527, -0.7317,  0.2984, -0.6801, -0.9115, -0.6598,
          0.2099,  1.1501, -0.6335, -1.1527, -0.3100, -0.9189, -1.1825, -2.6594,
         -0.6708,  0.6252, -0.0678,  0.3141,  0.4518,  0.0474,  1.3313, -0.4428,
         -0.5219,  1.7468,  1.8137,  0.8343,  1.4478, -1.6874,  0.2213, -1.3763,
         -0.6182,  0.2024, -0.4245, -2.2423, -0.0670,  0.9719,  0.9227, -0.3641,
         -0.7674,  0.3908, -1.2225, -0.0809, -0.0822,  0.0715,  0.7930, -1.2627,
          0.1679,  0.2641, -0.2950,  1.5672, -0.5050, -0.8829,  1.0967,  2.0495,
         -1.4386, -0.1893,  0.3223, -0.1637,  0.8473, -1.4654, -1.2605, -0.3227,
          1.0614, -1.9751, -0.8907, -1.5332, -0.0130, -0.5478,  0.8994, -0.2433,
     