In [None]:
import argparse
import os
import torch
import gensim
import dgl
import dgl.function as fn
import torch.optim as optim
import math
import time
from torch.utils.data import TensorDataset 
from torch.utils.data import DataLoader
from torch import nn
from encoder import DocuEncoder, ClassEncoder, DocumentTokenizer
from layer import GCN
from classifier import TextClassifier
from preprocessor import TaxoDataManager, DocumentManager
from gensim.test.utils import datapath
from gensim.models import word2vec

In [None]:
import discord_notify as dn

# The webhook URL is a credential: anyone holding it can post to the channel.
# Read it from the environment instead of hardcoding it in the notebook.
# NOTE: the URL that was previously committed in this cell should be treated as
# leaked and regenerated in the Discord channel settings.
WEBHOOK_URL = os.environ.get('DISCORD_WEBHOOK_URL')


class _ConsoleNotifier:
    """Fallback used when no webhook is configured: prints messages locally."""

    def send(self, message, print_message=True):
        if print_message:
            print(message)


# Later cells call `notifier.send(...)`; fall back to console output so the
# notebook still runs on machines where the webhook is not configured.
notifier = dn.Notifier(WEBHOOK_URL) if WEBHOOK_URL else _ConsoleNotifier()

In [None]:
# Paths and sizes shared by every later cell.
DATA_ROOT = 'data/'
TRAINING_DATA_DIR = os.path.join(DATA_ROOT, 'training_data/amazon/')
TOKEN_LENGTH = 500  # passed to DocumentTokenizer and to the classifier
CLASS_LENGTH = 768  # NOTE(review): presumably the document-encoder output width; confirm

# Pretrained word2vec model shared by the taxonomy loader and the class encoder.
EMBEDDING_PATH = os.path.join(DATA_ROOT, 'pretrained/embedding')
word2vec_model = word2vec.Word2Vec.load(EMBEDDING_PATH)

# Load every taxonomy artifact for the 'amazon' dataset.
taxo_manager = TaxoDataManager(
    TRAINING_DATA_DIR, 'taxonomy.json', 'amazon', word2vec_model)
taxo_manager.load_all()

document_tokenizer = DocumentTokenizer(DATA_ROOT, TOKEN_LENGTH)

# The taxonomy graph and its node features are placed on the first GPU.
graph = taxo_manager.get_graph().to('cuda:0')
features = taxo_manager.get_feature().cuda()

# Class encoder: a GCN built from the word2vec vector size.
dim = word2vec_model.wv.vector_size
gcn = GCN(dim, dim, dim, 2, nn.ReLU())  # NOTE(review): `2` is presumably the layer count; confirm in layer.GCN
class_encoder = ClassEncoder(gcn, word2vec_model)

In [None]:
def _encode_document(document_manager, document_id, num_classes):
    """Encode one document as a flat model-input row plus its label column.

    Returns ``(row, label)``:
      * ``row`` -- 1-D int32 tensor: the document's token ids followed by a
        per-class mask. The mask is 0 for classes listed in ``non_negative``
        but not in ``positive``, and 1 everywhere else.
      * ``label`` -- float tensor of shape ``(num_classes, 1)`` with 1 at every
        class that is both in ``non_negative`` and in ``positive``.
    """
    tokens = torch.tensor(document_manager.get_tokens(document_id), dtype=torch.int32).reshape(-1)
    positive, non_negative = document_manager.get_output_label(document_id)

    label = torch.zeros(num_classes, 1)
    mask = torch.ones(num_classes, dtype=torch.int32)
    for class_id in non_negative:
        if class_id in positive:
            label[class_id, 0] = 1
        else:
            mask[class_id] = 0
    return torch.cat((tokens, mask)), label


def create_data_loader(data_name, document_file, token_length, batch_size):
    """Tokenize a training split and wrap it in a shuffling DataLoader.

    Each dataset item is ``(x, y)``:
      * ``x`` -- int32 tensor of length ``token_length + num_classes``: the
        document's tokens followed by its per-class mask.
      * ``y`` -- float tensor of shape ``(num_classes, 1)`` with the labels.

    Args:
        data_name: dataset name; files are read from
            ``DATA_ROOT/training_data/<data_name>/``.
        document_file: document file passed to ``DocumentManager``.
        token_length: number of tokens per document. Must equal the length
            produced by the shared ``document_tokenizer``.
        batch_size: DataLoader batch size.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` with ``shuffle=True``.

    Raises:
        ValueError: if the split has no documents, or a document does not
            encode to exactly ``token_length + num_classes`` values.
    """
    elapsed_start = time.time()
    training_data_dir = os.path.join(DATA_ROOT, f'training_data/{data_name}/')
    training_document_manager = DocumentManager(document_file, training_data_dir, f'{data_name}_train', document_tokenizer.Tokenize, taxo_manager, force_token_reload=True)
    training_document_manager.load_tokens()
    training_document_manager.load_dicts()

    num_classes = len(graph.nodes())
    row_length = token_length + num_classes

    rows, labels = [], []
    for document_id in training_document_manager.get_ids():
        row, label = _encode_document(training_document_manager, document_id, num_classes)
        # Reject a token-length mismatch here instead of letting a reshape
        # silently misalign token and mask columns across documents.
        if row.numel() != row_length:
            raise ValueError(
                f'document {document_id!r} encodes to {row.numel()} values, expected '
                f'{row_length} (token_length={token_length} + {num_classes} classes)')
        rows.append(row)
        labels.append(label)

    if not rows:
        raise ValueError(f'no training documents found for {data_name!r} in {document_file!r}')

    # Stack once. Concatenating inside the loop was O(n^2) in the document count.
    train_x = torch.stack(rows)    # (num_documents, token_length + num_classes), int32
    train_y = torch.stack(labels)  # (num_documents, num_classes, 1), float32
    train_dataset = TensorDataset(train_x, train_y)

    notifier.send(f'{len(train_dataset)}개에 대한 데이터 로더 생성 완료. 걸린 시간: {round(time.time() - elapsed_start, 2)}.')

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

def train_coreclass(text_classifier, epoch, data_loader, loss_function, optimizer, accumulation_steps=8):
    """Supervised training on core-class labels, with gradient accumulation.

    Gradients are summed over ``accumulation_steps`` consecutive batches before
    each optimizer step. Losses are not divided by ``accumulation_steps``. A
    trailing group smaller than ``accumulation_steps`` at the end of an epoch is
    still applied; previously those gradients were dropped by the next epoch's
    ``zero_grad()``.

    Args:
        text_classifier: model; moved to CUDA and put in train mode.
        epoch: number of passes over ``data_loader``.
        data_loader: yields ``(inputs, targets)`` batches.
        loss_function: called as ``loss_function(predicted, targets)``.
        optimizer: optimizer over ``text_classifier``'s parameters.
        accumulation_steps: batches per optimizer step. Defaults to 8, the
            value previously hard-coded.

    Raises:
        ValueError: if ``accumulation_steps`` is less than 1.
    """
    if accumulation_steps < 1:
        raise ValueError(f'accumulation_steps must be >= 1, got {accumulation_steps}')

    text_classifier.cuda()
    text_classifier.train()
    train_start = time.time()

    for e in range(epoch):
        start = time.time()
        running_loss = 0.0
        pending = 0  # batches whose gradients have not been applied yet
        optimizer.zero_grad()

        for i, (inputs, targets) in enumerate(data_loader):
            predicted = text_classifier(inputs.cuda())
            loss = loss_function(predicted, targets.cuda())
            batch_loss = loss.item()
            loss.backward()
            pending += 1
            running_loss += batch_loss

            if pending == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0
                print('[%d, %5d] batch loss: %.3f' % (e + 1, i + 1, batch_loss))

        # Apply the leftover partial group instead of discarding it.
        if pending:
            optimizer.step()
            optimizer.zero_grad()

        print(f'[{e + 1}] total loss: {round(running_loss, 3)}. elapsed time: {time.time() - start}')

    notifier.send(f'{epoch} epoch 학습 완료. 걸린 시간: {round(time.time() - train_start, 2)}.')

def target_distribution(prediction):
    """Sharpen per-class probabilities into self-training targets.

    For every entry ``p`` this returns ``w_pos / (w_pos + w_neg)``, where
    ``w_pos = p**2 / sum(p)`` and ``w_neg = (1 - p)**2 / sum(1 - p)``. Both sums
    are taken along axis 0 (across the batch). This resembles the auxiliary
    target of DEC, applied to a binary positive/negative split per class.
    Confident entries are pushed further toward 0 or 1. The output has the
    same shape as ``prediction``.
    """
    complement = 1 - prediction
    w_pos = prediction ** 2 / prediction.sum(axis=0)
    w_neg = complement ** 2 / complement.sum(axis=0)
    return w_pos / (w_pos + w_neg)

def train_self(text_classifier, epoch, data_loader, loss_function, optimizer, update_period=25, eps=1e-7, detect_anomaly=False):
    """Self-training: fit the classifier to a sharpened copy of its own predictions.

    For every batch, the target is ``target_distribution`` of that batch's own
    predictions. It is detached from the autograd graph so it acts as a
    constant label. The loss is ``loss_function(log(predicted), target)``,
    suited to ``KLDivLoss``, which expects log-probabilities as input. Labels
    yielded by ``data_loader`` are ignored.

    Args:
        text_classifier: model; moved to CUDA and put in train mode.
        epoch: number of passes over ``data_loader``.
        data_loader: yields ``(inputs, labels)`` batches; only ``inputs`` is used.
        loss_function: e.g. ``nn.KLDivLoss(reduction='batchmean')``.
        optimizer: optimizer over ``text_classifier``'s parameters.
        update_period: kept for backward compatibility and currently ignored.
            The previous implementation reused one batch's target for the next
            ``update_period - 1`` batches. That paired targets with different,
            shuffled documents. It also back-propagated through a graph already
            freed by the earlier ``backward()``, which raises a RuntimeError on
            the second batch.
        eps: predictions are clamped to ``[eps, 1 - eps]`` when computing the
            target, and to ``>= eps`` before ``log``. This avoids ``-inf``/NaN
            when the output activation saturates.
        detect_anomaly: enable autograd anomaly detection for the duration of
            this call only. Previously it was switched on globally and never
            restored.
    """
    text_classifier.cuda()
    text_classifier.train()
    train_start = time.time()

    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for e in range(epoch):
            start = time.time()
            running_loss = 0.0
            optimizer.zero_grad()

            for i, (inputs, _labels) in enumerate(data_loader):
                predicted = text_classifier(inputs.cuda())
                # The target must belong to the same documents as `predicted` and
                # must not carry gradients (it is treated as a fixed label).
                target = target_distribution(predicted.detach().clamp(eps, 1 - eps))
                # NOTE(review): this KL term only compares the positive-class
                # probabilities; the (1 - p) component is not included.
                loss = loss_function(torch.log(predicted.clamp_min(eps)), target)
                batch_loss = loss.item()
                print(f'[{e + 1}, {i + 1}] loss: {batch_loss}')
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                running_loss += batch_loss

            print(f'[{e + 1}] total loss: {round(running_loss, 3)}. elapsed time: {time.time() - start}')

    notifier.send(f'{epoch} epoch 학습 완료. 걸린 시간: {round(time.time() - train_start, 2)}.')


In [None]:
# Build the Amazon training loader.
# Rows are sized with `token_length`, while the shared document_tokenizer was
# built with TOKEN_LENGTH. Reuse the constant so the two cannot drift apart.
BATCH_SIZE = 4
data_loader = create_data_loader(
    'amazon', 'amazon-coreclass-1000.jsonl', token_length=TOKEN_LENGTH, batch_size=BATCH_SIZE)

In [None]:
# Document encoder for the classifier. The optimizer cell gives its parameters
# their own learning rate.
document_encoder = DocuEncoder(DATA_ROOT)
text_classifier = TextClassifier(
    class_encoder,
    document_encoder,
    (dim, CLASS_LENGTH),
    TOKEN_LENGTH,
    graph,
    features,
    nn.Sigmoid(),  # output activation; paired with BCELoss in supervised training
    False,  # NOTE(review): meaning of this flag is defined in classifier.TextClassifier; confirm
)

In [None]:
# Per-group learning rates. The document encoder uses a much smaller rate than
# the class encoder and the classifier's `weight`, which use the default.
DOCUMENT_ENCODER_LR = 5e-5
DEFAULT_LR = 4e-3

param_groups = [
    {'params': text_classifier.document_encoder.parameters(), 'lr': DOCUMENT_ENCODER_LR},
    {'params': text_classifier.class_encoder.parameters()},
    {'params': text_classifier.weight},
]
optimizer = optim.AdamW(param_groups, lr=DEFAULT_LR)

In [None]:
# Supervised training on core-class labels with summed binary cross-entropy.
train_coreclass(
    text_classifier,
    epoch=20,
    data_loader=data_loader,
    loss_function=nn.BCELoss(reduction='sum'),
    optimizer=optimizer,
)

In [None]:
# Self-training with KL divergence. Reuses the optimizer (and its state) from
# the supervised stage.
train_self(
    text_classifier,
    epoch=20,
    data_loader=data_loader,
    loss_function=nn.KLDivLoss(reduction='batchmean'),
    optimizer=optimizer,
)