In [None]:
import argparse
import os
import torch
import gensim
import dgl
import dgl.function as fn
import torch.optim as optim
import math
import time
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch import nn
from encoder import DocuEncoder, ClassEncoder, DocumentTokenizer
from layer import GCN
from classifier import TextClassifier
from preprocessor import TaxoDataManager, DocumentManager
from gensim.test.utils import datapath
from gensim.models import word2vec

In [None]:
import discord_notify as dn

# The Discord webhook URL is a credential: anyone who holds it can post to the
# channel. It is therefore read from the environment instead of being committed
# with the notebook. The URL that used to be hard-coded here should be revoked
# and regenerated, since it remains visible in version-control history.
WEBHOOK_URL = os.environ.get('DISCORD_WEBHOOK_URL')
if not WEBHOOK_URL:
    raise RuntimeError(
        'Set the DISCORD_WEBHOOK_URL environment variable to the Discord '
        'webhook URL used for training notifications.'
    )

# Shared notifier used by the dataset-creation and training helpers below.
notifier = dn.Notifier(WEBHOOK_URL)
# notifier.send("Notebook run started: train_amazon.ipynb", print_message=False)

In [None]:
# --- Paths and sizes -------------------------------------------------------
# All data (training documents, pretrained embeddings, checkpoints) lives
# under DATA_ROOT, relative to the notebook's working directory.
DATA_ROOT = 'data/'
TRAINING_DATA_DIR = os.path.join(DATA_ROOT, 'training_data/amazon/')
# Token count handed to DocumentTokenizer and TextClassifier below.
TOKEN_LENGTH = 500
# Second element of the (dim, CLASS_LENGTH) pair passed to TextClassifier.
# NOTE(review): presumably the document encoder's hidden size — confirm.
CLASS_LENGTH = 768

# Pretrained gensim Word2Vec model; also supplies the embedding size `dim`.
word2vec_model = word2vec.Word2Vec.load(os.path.join(DATA_ROOT, 'pretrained/embedding'))

# Taxonomy for the 'amazon' corpus, read from taxonomy.json in TRAINING_DATA_DIR.
taxo_manager = TaxoDataManager(TRAINING_DATA_DIR, 'taxonomy.json', 'amazon', word2vec_model)
taxo_manager.load_all()

document_tokenizer = DocumentTokenizer(DATA_ROOT, TOKEN_LENGTH)
# The taxonomy graph and its node features are moved to the GPU here, so this
# cell requires a CUDA device to run.
graph = taxo_manager.get_graph().to('cuda:0')
features = taxo_manager.get_feature().cuda()

# The word-vector size is used for every width argument of the GCN.
dim = word2vec_model.wv.vector_size
# NOTE(review): argument meaning is defined in layer.GCN; presumably
# (in, hidden, out, number of layers, activation) — confirm.
gcn = GCN(dim, dim, dim, 2, nn.ReLU())
class_encoder = ClassEncoder(gcn, word2vec_model)

In [None]:
def create_dataset(data_name, document_file, token_length, num_val=None):
    """Build the training dataset for one corpus.

    Every document becomes one sample:

    * input row: the document's tokens followed by a per-class mask, flattened
      to length ``token_length + num_classes`` (dtype ``int32``);
    * label: a ``(num_classes, 1)`` float tensor.

    For each class ``j`` in the document's ``non_negative`` set, the label is
    set to 1 when ``j`` is also in ``positive``; otherwise the mask entry for
    ``j`` is set to 0. All other label entries stay 0 and mask entries stay 1.

    Relies on the notebook-level ``DATA_ROOT``, ``document_tokenizer``,
    ``taxo_manager``, ``graph`` and ``notifier`` objects.

    Args:
        data_name: Corpus name; selects ``training_data/<data_name>/`` under
            ``DATA_ROOT`` and the ``<data_name>_train`` manager name.
        document_file: Document file name passed to ``DocumentManager``.
        token_length: Expected number of tokens per document.
        num_val: If given, the number of samples to split off for validation.

    Returns:
        A ``TensorDataset`` when ``num_val`` is ``None``; otherwise the list
        ``[validation_subset, training_subset]`` produced by ``random_split``.

    Raises:
        ValueError: If there are no documents, or a document's token count
            does not match ``token_length``.
    """
    elapsed_start = time.time()
    training_data_dir = os.path.join(DATA_ROOT, f'training_data/{data_name}/')
    training_document_manager = DocumentManager(document_file, training_data_dir, f'{data_name}_train', document_tokenizer.Tokenize, taxo_manager, force_token_reload=True)
    training_document_manager.load_tokens()
    training_document_manager.load_dicts()

    num_classes = len(graph.nodes())
    row_length = token_length + num_classes

    # Collect per-document tensors and stack once at the end; concatenating
    # onto a growing tensor inside the loop re-copies all earlier rows on
    # every iteration (quadratic in the number of documents).
    input_rows = []
    label_rows = []
    for document_id in training_document_manager.get_ids():
        tokens = torch.tensor(training_document_manager.get_tokens(document_id), dtype=torch.int32)
        tokens = torch.reshape(tokens, (-1, 1))
        positive, non_negative = training_document_manager.get_output_label(document_id)
        output = torch.zeros(num_classes, 1)
        mask = torch.ones(num_classes, 1, dtype=torch.int32)

        for j in non_negative:
            if j in positive:
                output[j][0] = 1
            else:
                mask[j] = 0

        sample = torch.cat((tokens, mask), 0).reshape(-1)
        # Fail loudly here: a wrong token count would otherwise either break
        # the final reshape or silently shift tokens across sample boundaries.
        if sample.numel() != row_length:
            raise ValueError(
                f'document {document_id!r} has {tokens.numel()} tokens, '
                f'expected token_length={token_length}'
            )
        input_rows.append(sample)
        label_rows.append(output)

    if not input_rows:
        raise ValueError(f'no documents found for {data_name!r} in {document_file!r}')

    train_x = torch.stack(input_rows)  # (num_documents, token_length + num_classes)
    train_y = torch.stack(label_rows)  # (num_documents, num_classes, 1)

    dataset = TensorDataset(train_x, train_y)
    num_dataset = len(dataset)
    if num_val is not None:
        dataset = random_split(dataset, [num_val, num_dataset - num_val])

    notifier.send(f'{num_dataset}개 데이터셋 생성 완료. 걸린 시간: {round(time.time() - elapsed_start, 2)}.')

    return dataset
    
    
def validate_coreclass(model, dataloader, criterion):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is evaluated in eval mode with autograd disabled, so dropout and
    similar layers are deterministic and no computation graph is kept. The
    training/eval flag of every sub-module is restored afterwards, including
    sub-modules that were deliberately left in eval mode by the caller.

    Args:
        model: ``torch.nn.Module`` mapping an input batch to predictions.
        dataloader: Sized iterable of ``(data, labels)`` batches.
        criterion: Callable ``criterion(prediction, labels)`` returning a
            scalar tensor.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``.
    """
    # Snapshot per-module modes; model.train() would flip every sub-module on.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    valid_loss = 0.0
    try:
        with torch.no_grad():
            for data, labels in dataloader:
                if torch.cuda.is_available():
                    data, labels = data.cuda(), labels.cuda()

                target = model(data)
                loss = criterion(target, labels)
                valid_loss += loss.item()
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    return valid_loss / len(dataloader)

def train_coreclass(text_classifier, epoch, data_loader, loss_function, optimizer, valid_dataloader = None, save_path = None, accumulation_steps = 8):
    """Supervised training on the core-class labels with gradient accumulation.

    Gradients are accumulated over ``accumulation_steps`` batches before each
    optimizer step. A final step is taken on the last batch of every epoch so
    that gradients from a trailing partial window are applied, not discarded.

    After each epoch, if ``valid_dataloader`` is given, the validation loss is
    computed with ``validate_coreclass``. Whenever it improves and
    ``save_path`` is given, the model's ``state_dict`` is written there. A
    Discord notification is sent through the notebook-level ``notifier`` when
    training finishes.

    Args:
        text_classifier: Model to train; moved to the GPU when one is available.
        epoch: Number of epochs.
        data_loader: Sized iterable of ``(inputs, labels)`` batches.
        loss_function: Callable ``loss_function(prediction, labels)``.
        optimizer: Optimizer over the model's parameters.
        valid_dataloader: Optional validation batches; validation and
            checkpointing are skipped when ``None``.
        save_path: Optional checkpoint path; nothing is saved when ``None``.
        accumulation_steps: Batches accumulated per optimizer step.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError('accumulation_steps must be >= 1')

    # Same device rule as the validate_* helpers: use the GPU only if present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    text_classifier.to(device)
    train_start = time.time()
    min_valid_loss = np.inf
    num_batches = len(data_loader)

    for e in range(epoch):
        start = time.time()
        running_loss = 0.0
        optimizer.zero_grad()
        for i, (inputs, outputs) in enumerate(data_loader):
            predicted = text_classifier(inputs.to(device))
            loss = loss_function(predicted, outputs.to(device))
            loss.backward()
            running_loss += loss.item()

            # Step at the end of each accumulation window and on the final
            # batch, so leftover gradients are not thrown away.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

        running_loss /= num_batches

        valid_loss = None
        if valid_dataloader is not None:
            valid_loss = validate_coreclass(text_classifier, valid_dataloader, loss_function)
            if min_valid_loss > valid_loss:
                if save_path is not None:
                    print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) Saving The Model')
                    torch.save(text_classifier.state_dict(), save_path)
                min_valid_loss = valid_loss

        valid_loss_text = round(valid_loss, 3) if valid_loss is not None else 'n/a'
        print(f'[{e + 1}] train loss: {round(running_loss, 3)}. validation_loss: {valid_loss_text}. elapsed time: {time.time() - start}')


    notifier.send(f'{epoch} epoch 코어클래스 학습 완료. 걸린 시간: {round(time.time() - train_start, 2)}.')
    
def safe_div(a, b, epsilon=1e-8):
    """Divide ``a`` by ``b`` with the denominator floored at ``epsilon``.

    Every element of ``b`` below ``epsilon`` (zeros and negatives included) is
    replaced by ``epsilon`` before dividing, so the result never divides by zero.
    """
    floored_denominator = b.clamp(min=epsilon)
    quotient = a / floored_denominator
    return quotient

def safe_log(a, epsilon=1e-8):
    """Natural log of ``a`` with its elements floored at ``epsilon``.

    Values below ``epsilon`` are replaced by ``epsilon`` first, which keeps the
    result finite for zero or negative inputs.
    """
    floored = a.clamp(min=epsilon)
    return floored.log()

def target_distribution(prediction):
    """Compute the sharpened target distribution for a batch of predictions.

    The squared prediction and the squared complement are each divided by
    their own sum over axis 0. The result is the first weight's share of the
    two weights combined. All divisions go through ``safe_div``.
    """
    complement = 1 - prediction
    positive_weight = safe_div(prediction ** 2, prediction.sum(axis=0))
    negative_weight = safe_div(complement ** 2, complement.sum(axis=0))
    combined = positive_weight + negative_weight
    return safe_div(positive_weight, combined)

def validate_self(model, dataloader, criterion):
    """Return the mean per-batch self-training loss over ``dataloader``.

    For each batch the model's prediction is compared, via ``criterion``, with
    the target produced by ``target_distribution`` from that same prediction.
    The labels in the dataloader are ignored. Evaluation runs in eval mode with
    autograd disabled, and the training/eval flag of every sub-module is
    restored afterwards.

    Note that ``target_distribution`` sums over the batch axis, so the result
    depends on how samples are grouped into batches.

    Args:
        model: ``torch.nn.Module`` mapping an input batch to predictions.
        dataloader: Sized iterable of ``(data, labels)`` batches.
        criterion: Callable ``criterion(prediction, target)`` returning a
            scalar tensor.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``.
    """
    # Snapshot per-module modes; model.train() would flip every sub-module on.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    valid_loss = 0.0
    try:
        with torch.no_grad():
            for data, _ in dataloader:
                if torch.cuda.is_available():
                    data = data.cuda()

                predicted = model(data)
                target = target_distribution(predicted)
                loss = criterion(predicted, target)
                valid_loss += loss.item()
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    return valid_loss / len(dataloader)

def train_self(text_classifier, epoch, data_loader, loss_function, optimizer, update_period, valid_dataloader, save_path):
    """Self-training: fit the model's predictions to their sharpened targets.

    For every batch the target is ``target_distribution(prediction)``; labels
    in ``data_loader`` are ignored. The target is computed without gradient
    tracking so it acts as a fixed reference for that batch, and the loss only
    pulls the prediction towards it.

    Gradients are accumulated over ``update_period`` batches before each
    optimizer step, with a final step on the last batch of every epoch. After
    each epoch the validation loss from ``validate_self`` is computed, and the
    model's ``state_dict`` is saved to ``save_path`` whenever it improves. A
    Discord notification is sent through the notebook-level ``notifier`` when
    training finishes.

    Args:
        text_classifier: Model to train; moved to the GPU when one is available.
        epoch: Number of epochs.
        data_loader: Sized iterable of ``(inputs, labels)`` batches.
        loss_function: Callable ``loss_function(prediction, target)``.
        optimizer: Optimizer over the model's parameters.
        update_period: Batches accumulated per optimizer step.
        valid_dataloader: Validation batches.
        save_path: Checkpoint path for the best model.

    Raises:
        ValueError: If ``update_period`` is smaller than 1.
    """
    if update_period < 1:
        raise ValueError('update_period must be >= 1')

    # Same device rule as the validate_* helpers: use the GPU only if present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    text_classifier.to(device)
    train_start = time.time()
    min_valid_loss = np.inf
    num_batches = len(data_loader)

    for e in range(epoch):
        start = time.time()
        running_loss = 0.0
        optimizer.zero_grad()

        for i, (inputs, _) in enumerate(data_loader):
            predicted = text_classifier(inputs.to(device))
            # Build the target outside the autograd graph so that gradients
            # flow only through `predicted`.
            with torch.no_grad():
                target = target_distribution(predicted)
            loss = loss_function(predicted, target)
            loss.backward()
            running_loss += loss.item()
            # Count completed batches (i + 1): `i % update_period == 0` would
            # step right after the first batch of every epoch. Also flush on
            # the final batch so leftover gradients are applied.
            if (i + 1) % update_period == 0 or (i + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad()

        running_loss /= num_batches
        valid_loss = validate_self(text_classifier, valid_dataloader, loss_function)
        if min_valid_loss > valid_loss:
            print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) Saving The Model')
            min_valid_loss = valid_loss
            torch.save(text_classifier.state_dict(), save_path)

        print(f'[{e + 1}] train loss: {round(running_loss, 3)}. validation_loss: {round(valid_loss, 3)}. elapsed time: {time.time() - start}')

    notifier.send(f'{epoch} epoch 자기 학습 완료. 걸린 시간: {round(time.time() - train_start, 2)}.')


In [None]:
# Build the Amazon core-class dataset and split off 5000 validation samples.
# TOKEN_LENGTH is reused so the dataset row layout cannot drift from the
# tokenizer / classifier configuration.
val_dataset, train_dataset = create_dataset('amazon', 'amazon-coreclass-45000.jsonl', token_length=TOKEN_LENGTH, num_val=5000)
# Validation batches keep a fixed order: the self-training target depends on
# which samples share a batch, so shuffling would make validation losses from
# different epochs non-comparable.
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [None]:
# Assemble the classifier from the GCN-based class encoder, a document encoder
# loaded from DATA_ROOT, the (dim, CLASS_LENGTH) size pair, the token length,
# the taxonomy graph with its node features, and a sigmoid output activation.
# NOTE(review): the meaning of the trailing positional `False` is defined in
# classifier.TextClassifier — confirm and consider passing it by keyword.
text_classifier = TextClassifier(class_encoder, DocuEncoder(DATA_ROOT), (dim, CLASS_LENGTH), TOKEN_LENGTH, graph, features, nn.Sigmoid(), False)

In [None]:
# AdamW with three parameter groups: the document encoder gets its own smaller
# learning rate (5e-5); the class encoder and the classifier's `weight` fall
# back to the optimizer-wide default of 4e-3.
optimizer = optim.AdamW(
    params=[
        dict(params=text_classifier.document_encoder.parameters(), lr=5e-5),
        dict(params=text_classifier.class_encoder.parameters()),
        dict(params=text_classifier.weight),
    ],
    lr=4e-3,
)

In [None]:
# Core-class stage: 20 epochs of supervised training with summed binary
# cross-entropy; the best model (lowest validation loss) is saved to save_path.
save_path = os.path.join(DATA_ROOT, 'trained/text-classifier-amazon.pt')
# Create the checkpoint directory up front; otherwise torch.save would only
# fail after the first full training epoch has already been spent.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
train_coreclass(text_classifier, 20, train_dataloader, torch.nn.BCELoss(reduction='sum'), optimizer, val_dataloader, save_path)

In [None]:
def kl_div_loss(predicted, target):
    """Summed KL-style divergence of ``target`` from ``predicted``.

    Computes ``sum(target * log(target / predicted))`` over all elements, using
    ``safe_div`` and ``safe_log`` so zero entries cannot produce inf or nan.
    """
    ratio = safe_div(target, predicted)
    pointwise = target * safe_log(ratio)
    return pointwise.sum()

# Self-training stage: 5 epochs with kl_div_loss, passing 25 as update_period
# (the batch interval between optimizer steps).
# NOTE(review): this reuses `optimizer` and `save_path` from the core-class
# stage. train_self starts its best-loss tracker at infinity, so its first
# epoch overwrites the checkpoint written by train_coreclass — confirm that
# this is intended.
train_self(text_classifier, 5, train_dataloader, kl_div_loss, optimizer, 25, val_dataloader, save_path)