In [2]:
import argparse
import os
import torch
import gensim
import dgl
import dgl.function as fn
import torch.optim as optim
import math
import time
from torch.utils.data import TensorDataset 
from torch.utils.data import DataLoader
from torch import nn
from encoder import DocuEncoder, ClassEncoder, DocumentTokenizer
from layer import GCN
from classifier import TextClassifier
from preprocessor import TaxoDataManager, DocumentManager
from gensim.test.utils import datapath
from gensim.models import word2vec

In [None]:
# Resolve the project root. `__file__` only exists when this code runs as a
# script or module; a notebook kernel does not define it, so fall back to the
# current working directory in that case.
try:
    TEXT_CLASSIFIER_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    TEXT_CLASSIFIER_DIR = os.getcwd()
DATA_ROOT = os.path.join(TEXT_CLASSIFIER_DIR, 'data')
TRAINING_DATA_DIR = os.path.join(DATA_ROOT, 'training_data/amazon/')
TOKEN_LENGTH = 500  # handed to DocumentTokenizer and TextClassifier
CLASS_LENGTH = 768  # second element of the (dim, CLASS_LENGTH) pair given to TextClassifier

# Load the pretrained word2vec model stored under DATA_ROOT/pretrained/embedding.
word2vec_model = word2vec.Word2Vec.load(os.path.join(DATA_ROOT, 'pretrained/embedding'))

# Taxonomy manager for the amazon training data, built from taxonomy.json and
# the embedding model.
# NOTE(review): what load_all() actually reads is defined in
# preprocessor.TaxoDataManager and is not visible from this notebook.
taxo_manager = TaxoDataManager(TRAINING_DATA_DIR, 'taxonomy.json', 'amazon', word2vec_model)
taxo_manager.load_all()

document_tokenizer = DocumentTokenizer(DATA_ROOT, TOKEN_LENGTH)
# Both the taxonomy graph and its features are moved to the GPU here, so this
# cell requires a CUDA device.
graph = taxo_manager.get_graph().to('cuda:0')
features = taxo_manager.get_feature().cuda()

# The GCN is sized from the word2vec vector size.
# NOTE(review): the arguments are presumably (in, hidden, out, num_layers,
# activation) -- confirm against layer.GCN.
dim = word2vec_model.wv.vector_size
gcn = GCN(dim, dim, dim, 2, nn.ReLU)
class_encoder = ClassEncoder(gcn, word2vec_model)

In [None]:
def create_data_loader(data_name, document_file, token_length, batch_size):
    """Build a shuffled DataLoader of (input, label) tensors for one dataset.

    Each document contributes exactly one sample. Its input row is the
    document's token ids followed by a per-class mask; its label is a
    per-class column. For every class index in the document's
    ``non_negative`` list: if the class is also in ``positive`` the label is
    set to 1, otherwise the mask entry for that class is set to 0. All other
    classes keep label 0 and mask 1.

    Args:
        data_name: dataset name; selects ``training_data/<data_name>/`` under
            DATA_ROOT and is passed to DocumentManager as ``<data_name>_train``.
        document_file: document file name handed to DocumentManager.
        token_length: number of token ids per document. The final reshape
            assumes every document yields exactly this many tokens.
        batch_size: batch size of the returned DataLoader.

    Returns:
        A shuffling DataLoader over ``TensorDataset(train_x, train_y)`` where
        ``train_x`` is int32 with shape (num_docs, token_length + num_classes)
        and ``train_y`` is float with shape (num_docs, num_classes, 1).

    Raises:
        ValueError: if the document manager yields no document ids.
    """
    training_data_dir = os.path.join(DATA_ROOT, f'training_data/{data_name}/')
    training_document_manager = DocumentManager(
        document_file, training_data_dir, f'{data_name}_train',
        document_tokenizer.Tokenize, taxo_manager)
    training_document_manager.load_tokens()
    training_document_manager.load_dicts()

    # Only the node count is needed, so the graph is not copied to the GPU here.
    num_classes = len(taxo_manager.get_graph().nodes())

    # Collect per-document tensors and concatenate once at the end instead of
    # growing train_x / train_y with a torch.cat per sample.
    sample_inputs = []
    sample_labels = []
    for document_id in training_document_manager.get_ids():
        tokens = torch.tensor(training_document_manager.get_tokens(document_id), dtype=torch.int32)
        tokens = torch.reshape(tokens, (-1, 1))
        positive, non_negative = training_document_manager.get_output_label(document_id)
        output = torch.zeros(num_classes, 1)
        mask = torch.ones(num_classes, 1, dtype=torch.int32)

        for j in non_negative:
            if j in positive:
                output[j][0] = 1
            else:
                mask[j] = 0

        # Append once per document, after label and mask are complete. Doing
        # this inside the loop above would duplicate the document once per
        # non-negative class with a half-built mask, and would drop documents
        # that have no non-negative classes at all.
        sample_inputs.append(torch.cat((tokens, mask), 0))
        sample_labels.append(output)

    if not sample_inputs:
        raise ValueError(f'no training documents found for {data_name!r} in {document_file!r}')

    train_x = torch.reshape(torch.cat(sample_inputs, 0), (-1, num_classes + token_length))
    train_y = torch.reshape(torch.cat(sample_labels, 0), (-1, num_classes, 1))

    train_dataset = TensorDataset(train_x, train_y)

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
def train_coreclass(text_classifier, epoch, data_loader, loss_function, optimizer, accumulation_steps=8):
    """Train the classifier on labelled batches with gradient accumulation.

    The model is moved to the GPU and put in training mode. Gradients from
    ``accumulation_steps`` consecutive batches are accumulated before each
    optimizer step. Batches left over at the end of an epoch (when the number
    of batches is not a multiple of ``accumulation_steps``) are applied with
    one extra step instead of being discarded by the next ``zero_grad()``.

    Args:
        text_classifier: model called as ``text_classifier(inputs)``.
        epoch: number of passes over ``data_loader``.
        data_loader: iterable of ``(inputs, outputs)`` batches.
        loss_function: called as ``loss_function(predicted, outputs)``; the
            result must support ``.item()`` and ``.backward()``.
        optimizer: optimizer exposing ``step()`` and ``zero_grad()``.
        accumulation_steps: number of batches per optimizer step (default 8).

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError('accumulation_steps must be a positive integer')

    text_classifier.cuda()
    text_classifier.train()

    # A separate loop variable keeps the `epoch` argument from being shadowed.
    for epoch_idx in range(epoch):
        start = time.time()
        running_loss = 0.0
        pending = 0  # batches whose gradients have not been applied yet
        optimizer.zero_grad()
        for i, (inputs, outputs) in enumerate(data_loader):
            predicted = text_classifier(inputs.cuda())
            loss = loss_function(predicted, outputs.cuda())
            batch_loss = loss.item()
            loss.backward()
            pending += 1

            if pending == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0
                # Reports the loss of the last batch in the accumulation window.
                print('[%d, %5d] batch loss: %.3f' % (epoch_idx + 1, i + 1, batch_loss))
            running_loss += batch_loss

        # Apply gradients from a trailing, incomplete accumulation window.
        if pending:
            optimizer.step()
            optimizer.zero_grad()

        print('[%d] total loss: %.3f' % (epoch_idx + 1, running_loss))
        print('elapsed time : %f'%(time.time()-start))

    print('Finished Core Class Training')

In [None]:
def target_distribution(prediction):
    """Sharpen predictions into a target distribution.

    Both ``prediction`` and its complement ``1 - prediction`` are squared and
    divided by their respective sums along axis 0; the result is the share of
    the prediction term in the total of the two terms.

    Args:
        prediction: array-like supporting arithmetic and ``.sum(axis=0)``.

    Returns:
        An object of the same shape as ``prediction``.
    """
    complement = 1 - prediction
    positive_term = prediction ** 2 / prediction.sum(axis=0)
    negative_term = complement ** 2 / complement.sum(axis=0)
    total = positive_term + negative_term
    return positive_term / total

def train_self(text_classifier, epoch, data_loader, loss_function, optimizer, update_period=25):
    """Self-train the classifier against a sharpened copy of its own predictions.

    For each batch the model's prediction is compared with a target produced
    by ``target_distribution``. The target is refreshed every
    ``update_period`` batches (and whenever its shape no longer matches the
    current prediction, e.g. on a smaller final batch), and an optimizer step
    is taken after every batch. Labels yielded by ``data_loader`` are ignored.

    NOTE(review): between refreshes the cached target was computed from a
    different batch of documents than the one it is compared against. Confirm
    whether the intent is a target computed over the whole dataset and looked
    up per document; that needs per-sample indices this loader does not give.

    Args:
        text_classifier: model called as ``text_classifier(inputs)``; it is
            used as-is (not moved to a device or switched to train mode here).
        epoch: number of passes over ``data_loader``.
        data_loader: iterable of ``(inputs, outputs)`` batches.
        loss_function: called as ``loss_function(predicted, target)``, i.e.
            model output first and target second, the same order
            ``train_coreclass`` uses.
        optimizer: optimizer exposing ``step()`` and ``zero_grad()``.
        update_period: number of batches between target refreshes.

    Raises:
        ValueError: if ``update_period`` is smaller than 1.
    """
    if update_period < 1:
        raise ValueError('update_period must be a positive integer')

    for epoch_idx in range(epoch):
        start = time.time()
        running_loss = 0.0
        optimizer.zero_grad()
        target = None

        for i, (inputs, _) in enumerate(data_loader):
            predicted = text_classifier(inputs.cuda())
            if i % update_period == 0 or target.shape != predicted.shape:
                # Detach so the target is a constant: otherwise later batches
                # would backpropagate through an already-freed graph.
                target = target_distribution(predicted.detach())
            loss = loss_function(predicted, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.item()

        print('[%d] total loss: %.3f' % (epoch_idx + 1, running_loss))
        print('elapsed time : %f' % (time.time() - start))

    print('Finished Self Training')

In [None]:
# Train the core-class classifier on the amazon dataset.
# token_length uses TOKEN_LENGTH so the loader's input width cannot drift from
# the value the classifier below is constructed with.
data_loader = create_data_loader(
  'amazon', 'train-with-core-class-1000.jsonl', token_length=TOKEN_LENGTH, batch_size=4)

text_classifier = TextClassifier(class_encoder, DocuEncoder(DATA_ROOT), (dim, CLASS_LENGTH), TOKEN_LENGTH, graph, features, nn.Sigmoid(), False)

# Per-group learning rates: the document encoder uses 5e-5; the class encoder
# and the classifier's own weight fall back to the optimizer default of 4e-3.
optimizer = optim.AdamW([
  {'params': text_classifier.document_encoder.parameters(), 'lr': 5e-5},
  {'params': text_classifier.class_encoder.parameters()},
  {'params': text_classifier.weight}], lr=4e-3)

train_coreclass(text_classifier, 20, data_loader, torch.nn.BCELoss(reduction='sum'), optimizer)