In [72]:
from torch.utils.data import Dataset, DataLoader
from nltk import sent_tokenize
import torch.nn as nn
import pandas as pd
import numpy as np
import itertools
import easydict
import MeCab
import torch
import json
import re

In [75]:
# Input corpus (a CSV that must contain an 'article' column) and the JSON
# file holding the hyper-parameters.
FILE_PATH = 'D:\\data\\text\\news-articles\\kbanker_articles_subtitles.csv'
CONFIG_PATH = 'config.json'

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyper-parameters are exposed as attributes; later cells read
# args.window_size, args.batch_size, args.embedding_dim, args.h_dim and
# args.epoch. The encoding is explicit so the result does not depend on the
# platform's default code page.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    args = easydict.EasyDict(json.load(f))

In [3]:
# Text pre-processing helpers.

# (compiled pattern, replacement) pairs, applied in this order by
# pre_process_raw_article. Raw strings keep the regex escapes (\(, \s, \.)
# away from Python's own string-escape handling, which warns on sequences it
# does not recognise. Compiling once avoids re-parsing the patterns per article.
_ARTICLE_REPLACEMENTS = [
    # Curly double quotes -> straight double quote.
    (re.compile(r'[“”]'), '"'),
    # Curly single quotes -> straight single quote.
    (re.compile(r'[‘’]'), "'"),
    # Drop parenthesised spans, parentheses included.
    (re.compile(r'\([^)]*\)'), ''),
    # Anything that is not Hangul, a quote, ASCII alphanumeric, '.', '?', '!'
    # or whitespace becomes a space.
    (re.compile(r'''[^가-힣'"A-Za-z0-9.\s\?\!]'''), ' '),
    # Put a space after a dot unless a digit follows. Only the character
    # AFTER the dot is inspected, so "3.14" is untouched while "end.Next"
    # and "2019.Next" both gain a space.
    (re.compile(r'\.(?=[^0-9])'), '. '),
    # Collapse runs of two or more whitespace characters into one space.
    (re.compile(r'\s\s+'), ' '),
]


def pre_process_raw_article(article):
    """Normalize the text of one news article.

    Applies the substitutions in ``_ARTICLE_REPLACEMENTS`` in order:
    quote normalization, removal of parenthesised spans, replacement of
    unsupported characters by spaces, spacing after sentence-ending dots and
    whitespace collapsing. Leading/trailing whitespace is not stripped.

    Args:
        article (str): article text

    Returns:
        str: the normalized article text

    """
    for pattern, replacement in _ARTICLE_REPLACEMENTS:
        article = pattern.sub(replacement, article)

    return article

# Shared MeCab tagger, created on first use. Building a Tagger for every
# sentence repeats the tagger construction once per sentence for no benefit.
_MECAB_TAGGER = None


def _get_mecab_tagger():
    """Return the shared MeCab tagger, creating it on first use."""
    global _MECAB_TAGGER
    if _MECAB_TAGGER is None:
        _MECAB_TAGGER = MeCab.Tagger()
    return _MECAB_TAGGER


def mecab_tokenize(sentence, tagger=None):
    """Split a sentence into MeCab surface forms.

    Each non-empty line of the tagger output is expected to look like
    ``surface<TAB>features``; the output ends with a bare ``EOS`` line.

    Args:
        sentence (str): sentence to tokenize
        tagger: optional object with a ``parse(str) -> str`` method. Defaults
            to a shared ``MeCab.Tagger()`` created on first use.

    Returns:
        list of str: surface form of every token, in order

    """
    if tagger is None:
        tagger = _get_mecab_tagger()

    tokens = []
    for line in tagger.parse(sentence).split('\n'):
        # Tolerate CRLF line endings so the EOS comparison below stays exact.
        line = line.rstrip('\r')
        # Skip blank lines and the end-of-sentence marker. Only a line that is
        # exactly "EOS" is the marker; a token whose text merely contains
        # "EOS" must be kept.
        if not line or line == 'EOS':
            continue
        # The surface form is everything before the first tab. Splitting on
        # the tab (rather than on commas) keeps a "," token intact.
        tokens.append(line.split('\t', 1)[0])
    return tokens

In [4]:
class NLPCorpusDataset(Dataset):
    """N-gram language-model dataset built from a CSV of news articles.

    Every article in the 'article' column is pre-processed, split into
    sentences and tokenized. Each sentence then yields one sample per
    position: ``window_size`` consecutive word indices as input and the index
    of the following word as label.

    Args:
        csv_file (str): Path to the csv file
        root_dir (str): root
        window_size (int, optional): number of context words per sample.
            Defaults to ``args.window_size`` from the notebook configuration.

    Attributes:
        root_dir (str): root
        window_size (int): number of context words per sample
        word_to_idx (dict): word_to_idx mapping
        idx_to_word (dict): idx_to_word mapping
        x (list): train data, one list of ``window_size`` word indices per sample
        y (list): label, one single-element list per sample

    Raises:
        ValueError: if ``window_size`` is smaller than 1.

    """

    def __init__(self, csv_file, root_dir, window_size=None):
        if window_size is None:
            # Fall back to the notebook-level configuration.
            window_size = args.window_size
        if window_size < 1:
            raise ValueError('window_size must be >= 1, got %r' % (window_size,))

        articles = pd.read_csv(csv_file, encoding='utf-8')['article'].dropna().values
        # One token list per sentence, across all articles.
        corpus = [
            mecab_tokenize(sentence)
            for article in articles
            for sentence in sent_tokenize(pre_process_raw_article(article))
        ]

        self.root_dir = root_dir
        self.window_size = window_size
        self.word_to_idx, self.idx_to_word = self._build_vocab(corpus)
        self.x, self.y = self._build_examples(corpus, self.word_to_idx, window_size)

    @staticmethod
    def _build_vocab(corpus):
        """Return (word_to_idx, idx_to_word) for all words in the corpus.

        Words are numbered in sorted order. Iterating a plain ``set`` of
        strings gives an order that can change between interpreter runs,
        which would make the indices incompatible with previously saved
        model weights or datasets.
        """
        words = sorted(set(itertools.chain.from_iterable(corpus)))
        word_to_idx = {word: idx for idx, word in enumerate(words)}
        idx_to_word = {idx: word for word, idx in word_to_idx.items()}
        return word_to_idx, idx_to_word

    @staticmethod
    def _build_examples(corpus, word_to_idx, window_size):
        """Return (contexts, targets) sliding-window samples.

        Sentences with ``window_size`` tokens or fewer contribute no samples.
        Each target is wrapped in a one-element list.
        """
        contexts = []
        targets = []
        for sentence in corpus:
            ids = [word_to_idx[word] for word in sentence]
            for start in range(len(ids) - window_size):
                contexts.append(ids[start:start + window_size])
                targets.append([ids[start + window_size]])
        return contexts, targets

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(context, target)`` pair at position ``idx``."""
        if torch.is_tensor(idx):
            # Tensor indices are converted to plain Python values first.
            idx = idx.tolist()

        return self.x[idx], self.y[idx]

In [None]:
# Build the dataset from the raw article CSV. This runs the whole pipeline in
# NLPCorpusDataset.__init__: read the CSV, pre-process every article, split it
# into sentences, tokenize them and build the vocabulary and training windows.
nlp_dataset = NLPCorpusDataset(csv_file=FILE_PATH, root_dir='.')

In [12]:
import pickle
# One-off export of an in-memory dataset, kept commented out for reference:
# with open('kbanker_nlp_dataset.pkl', 'wb') as f:
#     pickle.dump(nlp_dataset, f)

# Load a previously pickled dataset. This rebinds nlp_dataset, replacing any
# dataset object built earlier in the session.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load files from a trusted source. Unpickling also needs the
# NLPCorpusDataset class to be defined in the session.
with open('D:\\data\\text\\torch-dataset\\kbanker_nlp_dataset.pkl', 'rb') as f:
    nlp_dataset = pickle.load(f)

In [49]:
class EmbeddingModule(nn.Module):
    """Feed-forward n-gram language model with a direct embedding-to-output path.

    The ``window_size`` context word indices are embedded and concatenated.
    The logits are the sum of a one-hidden-layer tanh network and a linear
    "motorway" projection, both applied to the concatenated embeddings.

    Args:
        vocab_size (int): number of distinct words (size of the output layer)
        embed_dim (int): dimension of each word embedding
        h_dim (int): width of the hidden tanh layer
        window_size (int, optional): number of context words per sample.
            Defaults to ``args.window_size`` from the notebook configuration.

    """

    def __init__(self, vocab_size, embed_dim, h_dim, window_size=None):
        super(EmbeddingModule, self).__init__()
        if window_size is None:
            # Fall back to the notebook-level configuration.
            window_size = args.window_size
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.h_dim = h_dim
        self.window_size = window_size

        # Size of the concatenated context embedding fed to both paths.
        context_dim = embed_dim * window_size

        # All weight matrices are drawn from U(-1, 1); biases keep PyTorch's
        # default initialization.
        # NOTE(review): U(-1, 1) is much wider than PyTorch's default Linear
        # initialization and presumably explains the very large initial loss
        # in the training log - confirm and consider a narrower range.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        nn.init.uniform_(self.embedding.weight, -1, 1)
        self.linear1 = nn.Linear(context_dim, h_dim)
        nn.init.uniform_(self.linear1.weight, -1, 1)
        self.linear2 = nn.Linear(h_dim, vocab_size)
        nn.init.uniform_(self.linear2.weight, -1, 1)
        self.motorway = nn.Linear(context_dim, vocab_size)
        nn.init.uniform_(self.motorway.weight, -1, 1)

        self.tanh = nn.Tanh()

    def forward(self, x):
        """Compute next-word logits.

        Args:
            x (LongTensor): word indices of shape (batch, window_size)

        Returns:
            Tensor: logits of shape (batch, vocab_size)

        """
        # The batch dimension comes from the input itself, so any batch size
        # works (including a short final batch). A wrong window size is still
        # rejected by linear1's input-dimension check.
        embedded = self.embedding(x).view(x.size(0), -1)
        hidden = self.tanh(self.linear1(embedded))
        return self.linear2(hidden) + self.motorway(embedded)

In [None]:
def collate_fn(data):
    """Regroup a list of (context, target) samples into two parallel tuples.

    Args:
        data (list): samples as returned by the dataset's ``__getitem__``

    Returns:
        tuple: ``(contexts, targets)``, each a tuple with one entry per sample

    """
    seqs, labels = zip(*data)
    return seqs, labels


LOG_EVERY = 100               # print the loss once every this many batches
MAX_BATCHES_PER_EPOCH = 1001  # stop each epoch after this many batches

# NOTE(review): with shuffle=False and the per-epoch batch cap below, every
# epoch sees the same leading batches of the dataset - confirm this is intended.
dataloader = DataLoader(nlp_dataset, batch_size=args.batch_size,
                        shuffle=False, num_workers=0, collate_fn=collate_fn)

model = EmbeddingModule(len(nlp_dataset.word_to_idx),
                        args.embedding_dim, args.h_dim).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(args.epoch):
    print('Epoch ' + str(epoch))
    for i, (seqs, labels) in enumerate(dataloader):
        x = torch.LongTensor(seqs).to(device)
        # Labels arrive as one single-element list per sample; view(-1)
        # flattens them to 1-D without assuming the batch is full.
        y = torch.LongTensor(labels).view(-1).to(device)

        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        # A fresh graph is built on every forward pass, so it does not need
        # to be retained after the backward pass.
        loss.backward()
        optimizer.step()

        if i % LOG_EVERY == LOG_EVERY - 1:
            print(i, loss.item())

        if i + 1 == MAX_BATCHES_PER_EPOCH:
            break

Epoch 0
99 63.13578414916992
199 63.04270553588867
299 64.4983139038086
399 64.17352294921875
499 64.94847869873047
599 64.23686218261719
699 63.3019905090332
799 62.33510208129883
899 62.901512145996094
999 63.828792572021484
Epoch 1
99 62.62953186035156
199 62.32461929321289
299 63.7624397277832
399 63.61264419555664
499 64.48085021972656
599 63.7266845703125
699 62.833984375
799 61.30979919433594
899 62.07927322387695
999 63.38001251220703
Epoch 2
99 62.12385559082031
199 61.63299560546875
299 63.055423736572266
399 63.054622650146484
499 64.00579071044922
599 63.22053146362305
699 62.382171630859375
799 60.31037521362305
899 61.2607536315918
999 62.94895935058594
Epoch 3
99 61.623714447021484
199 60.984073638916016
299 62.363380432128906
399 62.503211975097656
499 63.52195358276367
599 62.722198486328125
699 61.94243240356445
799 59.328678131103516
899 60.45521926879883
999 62.51863479614258
Epoch 4
99 61.128238677978516
199 60.37434387207031
299 61.67583465576172
399 61.9541168212

In [62]:
# Display the batch size read from the configuration file.
args.batch_size

128