In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from itertools import chain
from tqdm import tqdm

from importlib import reload

# Utility variable
import sys, getopt
sys.path.insert(0, '../..')

# var
import var.var as V
import var.path as P

# utils
import utils.data as D
import utils.io as IO
import utils.preprocess as PP
import utils.torch as Tor

## Process Command Line Options

In [None]:
## Parse command-line flags.
# Option string "d:acns:p:f:": -d, -s, -p and -f take a value; -a, -c and -n are
# boolean switches. The resulting (flag, value) pairs in `opts` are interpreted by the
# option loop in the following cell.
# NOTE(review): -f is accepted here but never handled by that loop; presumably it only
# absorbs the `-f <connection file>` argument Jupyter passes to the kernel -- confirm.
opts, args = getopt.getopt(sys.argv[1:], "d:acns:p:f:")

In [None]:
# Defaults, used when the corresponding command-line flag is absent.
DATE = "2022-12-09_mixed_3"  # run tag embedded in MODEL_SAVE_DIR

TRAIN_OR_ALL = 'train'  # 'train' keeps a held-out test split; 'all' (-a) trains on every sample
COMMENT_AUGMENTATION = False  # -c: feed the augmented comments to the model instead of the raw ones
NEG_SAMPLE = False  # -n: use the comment sentiment labels as cosine-embedding targets

# Forwarded to the pHAN constructor as sent_temperature / pers_temperature.
# The defaults are ints; values given with -s / -p are parsed as floats.
SENT_TEMPERATURE = 1
PERS_TEMPERATURE = 1

# Override the defaults with whatever was passed on the command line.
for opt, arg in opts:
    if opt == '-d':
        DATE = arg
    elif opt == '-a':
        TRAIN_OR_ALL = 'all'
    elif opt == '-c':
        COMMENT_AUGMENTATION = True
    elif opt == '-n':
        NEG_SAMPLE = True
    elif opt == '-s':
        SENT_TEMPERATURE = float(arg)
    elif opt == '-p':
        PERS_TEMPERATURE = float(arg)

In [None]:
# Echo the effective configuration so it is visible in the notebook / job output.
# NOTE(review): the last label is misspelled ("persepctive"); it is runtime output and
# is deliberately left untouched here.
print("date:", DATE)
print("train or all:", TRAIN_OR_ALL)
print("comment augmentation:", COMMENT_AUGMENTATION)
print("negative comment sample:", NEG_SAMPLE)
print("sentence temperature:", SENT_TEMPERATURE)
print("persepctive temperature:", PERS_TEMPERATURE)

## Params

In [None]:
# Name of the run directory. It encodes whether comment augmentation is used
# ("w" / "wo"), the data split (TRAIN_OR_ALL), the run tag (DATE) and, when sentiment
# labels are used as negative samples, a trailing "_neg".
# Built from a single template so the four variants cannot drift apart.
MODEL_SAVE_DIR = "significance_pHAN_cmt_cos_dist_{}_cmt_aug_{}_{}{}".format(
    "w" if COMMENT_AUGMENTATION else "wo",
    TRAIN_OR_ALL,
    DATE,
    "_neg" if NEG_SAMPLE else "",
)
    
# Name of the BERTopic model whose aggregated topic info is loaded further down;
# it depends on the data split.
BERTOPIC_MODEL_NAME = "BERTopic_custom_mcs_100_ckip_diversified_low_{}".format(TRAIN_OR_ALL)

# Encoder checkpoint and tokenizer name.
# NOTE(review): the tokenizer name differs from the model name and is only used on the
# USE_BERT_EMBED path -- confirm the two vocabularies are compatible.
BERT_MODEL_NAME = 'ckiplab/bert-base-chinese'
BERT_TOKENIZER_NAME = 'bert-base-chinese'

# Checkpoints and the training log are written to
# <P.FP_SIGNIFICANCE_PHAN_DIR>/<TRAIN_OR_ALL>/<w|wo>/<MODEL_SAVE_DIR>.
if COMMENT_AUGMENTATION:
    MODEL_SAVE_DIR_PATH = os.path.join(P.FP_SIGNIFICANCE_PHAN_DIR, TRAIN_OR_ALL, 'w', MODEL_SAVE_DIR)
else:
    MODEL_SAVE_DIR_PATH = os.path.join(P.FP_SIGNIFICANCE_PHAN_DIR, TRAIN_OR_ALL, 'wo', MODEL_SAVE_DIR)

# exist_ok=True replaces the exists()-then-makedirs() pair, which can fail when another
# process creates the directory between the check and the call.
os.makedirs(MODEL_SAVE_DIR_PATH, exist_ok=True)

# Which sentence encoder to load: SentenceTransformer (SBERT) or a plain BERT model
# with its tokenizer. At least one must be enabled.
# NOTE(review): if both are True, the BERT cell runs after the SBERT cell and
# overwrites `bert_model` / `bert_tokenizer`.
USE_SBERT_EMBED = True
USE_BERT_EMBED = False
assert USE_SBERT_EMBED or USE_BERT_EMBED, "enable USE_SBERT_EMBED or USE_BERT_EMBED"

# Index handed to torch.device() when the device is selected.
GPU_NUM = 0

# Both values come from the project-wide variable module and are passed to the pHAN
# constructor.
TOP_K = V.TOP_K
NUM_PERSPECTIVE = V.MAX_NUM_PERSPECTIVE

# Batch size of the train/test DataLoaders.
BATCH_SIZE = 128
## pHAN params (all forwarded to the PerspectiveHierarchicalAttentionNetwork constructor
## and stored in every checkpoint)
BERT_DIM = 768
SENT_DIM = 768
CXT_DIM = 128
PRJ_DIM = 768
COMPRESSION = False
PROJECTION = False
ATTENTION_EMPTY_MASK = True
FREEZE_BERT = True
DROPOUT_RATE = 0.1
LEAKY_RELU_NEG_SLOPE = 0.1
ENC_BZ = 128  # passed as encode_batch_size

## gradient params
## NOTE(review): gradient clipping is commented out in train_loop, so these two values
## are currently only recorded in the checkpoint.
GRADIENT_MAX_NORM = 0.5
GRADIENT_CLIP_VALUE = 0.5

# Seeds NumPy's global RNG only.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

In [None]:
# Show the run directory name derived from the flags above.
print("model save dir:", MODEL_SAVE_DIR)

In [None]:
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# Seed PyTorch as well: np.random.seed() above does not affect torch's RNGs, which
# drive DataLoader shuffling and dropout. RANDOM_STATE is stored in every checkpoint,
# so runs are expected to be repeatable from it.
torch.manual_seed(RANDOM_STATE)

# Device that the models and tensors are moved to.
device = torch.device(GPU_NUM)

In [None]:
import pytz
# Time zone used for the per-epoch timestamps written to the training log.
timezone = pytz.timezone('Asia/Taipei')

In [None]:
import logging

# Detach every handler already attached to the root logger. Iterating over a copy
# ([:]) makes it safe to remove handlers while looping; re-running the notebook then
# does not stack duplicate handlers.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

In [None]:
# The training log lives next to the checkpoints.
log_file = os.path.join(MODEL_SAVE_DIR_PATH, 'training_log.log')

# Configure the root logger; records below INFO are dropped.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# file handler: writes INFO and above to training_log.log
fh = logging.FileHandler(log_file)
fh.setLevel(logging.INFO)
# console handler: same INFO threshold as the file handler
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# add the handlers to logger
logger.addHandler(ch)
logger.addHandler(fh)

In [None]:
# Disable Hugging Face tokenizer parallelism.
# NOTE(review): presumably set because the DataLoaders below use worker processes,
# which the tokenizers library warns about when its own parallelism is on -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### Utils

In [None]:
def defaultdict_init_defaultdict_init_by_int():
    """Return a new ``defaultdict(int)``.

    A module-level factory like this can serve as the ``default_factory`` of an outer
    defaultdict, where a lambda could not be pickled.
    """
    # Imported locally: this file does not import ``defaultdict`` anywhere, so the
    # original body raised NameError when called.
    from collections import defaultdict

    return defaultdict(int)

def defaultdict_init_defaultdict_init_by_float():
    """Return a new ``defaultdict(float)``.

    A module-level factory like this can serve as the ``default_factory`` of an outer
    defaultdict, where a lambda could not be pickled.
    """
    # Imported locally: this file does not import ``defaultdict`` anywhere, so the
    # original body raised NameError when called.
    from collections import defaultdict

    return defaultdict(float)

## Read raw data

In [None]:
# Reviewer-comment table. The data-preparation loop below filters it by `year` and `id`
# and reads its `comment`, `augmented_comments` and `grade` columns.
df_comments = D.read_df_comments()
df_comments

In [None]:
# Applicant table. The data-preparation loop below filters it by `year` and `id` and
# reads its `train_or_test` column to decide which split a sample belongs to.
df_applicants = D.read_df_applicants()
df_applicants

In [None]:
# Load every pickled per-applicant record found under
# <FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR>/custom_bertopic/<TRAIN_OR_ALL>/<year>/
# for all entries of V.YEAR_DIRS except the last one.
individual_data = []

for year in V.YEAR_DIRS[:-1]:
    _dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL, year)

    # os.listdir() returns entries in arbitrary order; sorting makes the sample order
    # (and hence the seeded shuffling downstream) the same on every machine and run.
    for file in sorted(os.listdir(_dir)):
        # Jupyter's checkpoint folder is not a data file.
        if file == '.ipynb_checkpoints':
            continue

        fn = os.path.join(_dir, file)

        # NOTE: pickle.load can execute arbitrary code -- only read files produced by
        # this project's own pipeline.
        with open(fn, "rb") as f:
            data = pickle.load(f)
        individual_data.append(data)

In [None]:
# Number of per-applicant records loaded.
len(individual_data)

In [None]:
# Aggregated topic information saved for the selected BERTopic model.
fn = os.path.join(
    P.FP_COMMENT_CLUSTERING_TOPIC_HIERARCHY_DIR, 
    "{}_topic_aggregate_info.pkl".format(BERTOPIC_MODEL_NAME)
)

# NOTE: pickle.load can execute arbitrary code -- only read files produced by this
# project's own pipeline.
with open(fn, "rb") as f:
    topic_aggregate_info = pickle.load(f)
    # Only the per-topic mean embeddings are needed here.
    perspective_mean_embed_dict = topic_aggregate_info['topic_aggregate_embed_mean_dict']

In [None]:
# Stack the per-perspective mean embeddings into one tensor, one row per dict entry,
# in the dict's iteration order (the keys themselves are discarded).
# NOTE(review): presumably the dict iterates in perspective-index order so that row i
# belongs to perspective i -- confirm against the code that builds the dict.
perspective_mean_embed = []

for i, embed in perspective_mean_embed_dict.items():
    perspective_mean_embed.append(embed)
    
# The name is rebound from the temporary list to the stacked tensor.
perspective_mean_embed = torch.tensor(np.stack(perspective_mean_embed))
perspective_mean_embed.shape

## Prepare training data and testing data

In [None]:
# Parallel lists: position i of every train_* (resp. test_*) list describes the same
# (applicant pseudo summary, reviewer comment) pair.
train_pseudo_summary_data = []
train_comment_data = []
train_aug_comment_data = []
train_grade_data = []

test_pseudo_summary_data = []
test_comment_data = []
test_aug_comment_data = []
test_grade_data = []

for data in tqdm(individual_data):
    _year = data['year']
    _id = data['id']
    _name = data['name']
    pseudo_summary = data['pseudo_summary']

    ## skip applicants whose pseudo summary contains no text at all
    ps = ''.join(chain.from_iterable(pseudo_summary))
    if ps == '':
        continue

    ## check train or test data; an applicant with no matching row (or a table without
    ## the column) is treated as training data. Only these two lookup failures are
    ## caught, so unrelated errors are no longer silently swallowed.
    row = df_applicants.query('`year` == {} and `id` == {}'.format(_year, _id))
    try:
        train_or_test = row['train_or_test'].to_list()[0]
    except (IndexError, KeyError):
        train_or_test = 'train'

    ## get corresponding comments
    row = df_comments.query('`year` == {} and `id` == {}'.format(_year, _id))
    comments = row['comment'].to_list()
    augmented_comments = row['augmented_comments'].to_list()
    grades = row['grade'].to_list()

    ## pick the destination lists; with TRAIN_OR_ALL == 'all' the test applicants are
    ## folded into the training set
    if train_or_test == 'train' or TRAIN_OR_ALL == 'all':
        aug_comment_out, comment_out = train_aug_comment_data, train_comment_data
        pseudo_summary_out, grade_out = train_pseudo_summary_data, train_grade_data
    elif train_or_test == 'test':
        aug_comment_out, comment_out = test_aug_comment_data, test_comment_data
        pseudo_summary_out, grade_out = test_pseudo_summary_data, test_grade_data
    else:
        ## any other split value is ignored
        continue

    ## one sample per non-empty comment; the pseudo summary is repeated for each
    for comment, augmented_comment, grade in zip(comments, augmented_comments, grades):
        ## remove empty comment
        if PP.is_empty_sent(comment):
            continue

        aug_comment_out.append(augmented_comment)
        comment_out.append(comment)
        pseudo_summary_out.append(pseudo_summary)
        grade_out.append(grade)

### Sentiment analysis to assign positive and negative samples for computing the cosine embedding loss

In [None]:
from transformers import BertForSequenceClassification
from transformers import BertTokenizer

In [None]:
# Pretrained sentiment classifier used to label each comment as a positive (1) or
# negative (-1) sample; the model is moved to the training device.
sentiment_analysis_model_name = 'IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment'

sentiment_analysis_tokenizer = BertTokenizer.from_pretrained(sentiment_analysis_model_name)
sentiment_analysis_model = BertForSequenceClassification.from_pretrained(sentiment_analysis_model_name).to(device)

In [None]:
def sentiment_analysis_inference(text):
    """Label each text with a sentiment sign for the cosine-embedding loss.

    Args:
        text: list of strings to classify.

    Returns:
        A list of ints with one entry per input text: 1 when the softmax probability
        of class index 1 (treated as "positive") is above 0.3, i.e. neutral or
        positive, and -1 otherwise (negative). An empty input gives an empty list.
    """
    # torch.cat() raises on an empty list, so short-circuit when there is nothing to score.
    if len(text) == 0:
        return []

    dataset = Tor.BatchSentenceDataset(text)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    prob_batch = []
    with torch.no_grad():
        for batch in dataloader:
            encoding = sentiment_analysis_tokenizer(
                batch, padding=True, return_tensors='pt', truncation='longest_first', max_length=510
            )

            # Move every tensor of the encoding onto the model's device.
            for key in encoding:
                if isinstance(encoding[key], Tensor):
                    encoding[key] = encoding[key].to(device)

            output = sentiment_analysis_model(**encoding)
            positive_prob = torch.nn.functional.softmax(output.logits, dim=-1)[:, 1]
            prob_batch.append(positive_prob)

    positive_probs = torch.cat(prob_batch)
    # Threshold on the tensor (same comparison as before) and copy the result to the
    # host once, instead of synchronising with the device for every element.
    is_positive = (positive_probs > 0.3).cpu().tolist()
    ## -1 represent negative, 1 represent neutral or positive
    sentiment_label = [1 if flag else -1 for flag in is_positive]

    return sentiment_label

In [None]:
# %%time
# Label the comments that will be fed to the model: the raw comments without
# augmentation, the augmented ones with it. The test split is only labelled when a
# held-out split exists (TRAIN_OR_ALL == 'train'); otherwise
# test_comment_sentiment_data is left undefined.
if not COMMENT_AUGMENTATION:
    train_comment_sentiment_data = sentiment_analysis_inference(train_comment_data)
    if TRAIN_OR_ALL == 'train':
        test_comment_sentiment_data = sentiment_analysis_inference(test_comment_data)
else:
    train_comment_sentiment_data = sentiment_analysis_inference(train_aug_comment_data)
    if TRAIN_OR_ALL == 'train':
        test_comment_sentiment_data = sentiment_analysis_inference(test_aug_comment_data)

In [None]:
# if not COMMENT_AUGMENTATION:
#     D.write_comment_sentiment(train_comment_sentiment_data, 'train')
#     D.write_comment_sentiment(test_comment_sentiment_data, 'test')
# else:
#     D.write_aug_comment_sentiment(train_comment_sentiment_data, 'train')
#     D.write_aug_comment_sentiment(test_comment_sentiment_data, 'test')

In [None]:
# if TRAIN_OR_ALL == 'train':
#     if not COMMENT_AUGMENTATION:
#         train_comment_sentiment_data = D.read_comment_sentiment('train')
#         test_comment_sentiment_data = D.read_comment_sentiment('test')
#     else:
#         train_comment_sentiment_data = D.read_aug_comment_sentiment('train')
#         test_comment_sentiment_data = D.read_aug_comment_sentiment('test')
# elif TRAIN_OR_ALL == 'all':
#     if not COMMENT_AUGMENTATION:
#         train_comment_sentiment_data = D.read_comment_sentiment('all')
#     else:
#         train_comment_sentiment_data = D.read_aug_comment_sentiment('all')
        
#     test_comment_sentiment_data = []

In [None]:
# Sanity check: the five training lists are parallel, so all lengths should be equal.
len(train_pseudo_summary_data), len(train_comment_data), len(train_aug_comment_data), len(train_grade_data), len(train_comment_sentiment_data)

In [None]:
from collections import Counter
# Share of each grade label in the training set, in percent.
for g, c in Counter(train_grade_data).items():
    print(g, c / len(train_grade_data) * 100)

In [None]:
# Sanity check for the test lists (they only exist when a held-out split is used).
# The tuple is printed explicitly: Jupyter only echoes a bare expression when it is the
# last top-level statement of a cell, never one nested inside an `if`.
if TRAIN_OR_ALL == 'train':
    print((
        len(test_pseudo_summary_data), len(test_comment_data), len(test_aug_comment_data),
        len(test_grade_data), len(test_comment_sentiment_data)
    ))

### class weight

In [None]:
# from sklearn.utils import class_weight

In [None]:
# class_weights = class_weight.compute_class_weight(
#     'balanced', classes=np.unique(test_grade_data), y=test_grade_data
# )
# class_weights

# Train DNN

### Apply one hot encoder to grade label

In [None]:
from sklearn.preprocessing import OneHotEncoder

In [None]:
# Grade labels the one-hot encoder is fitted on in the next cell.
V.TRAIN_GRADE_LABELS

In [None]:
# Fit the encoder on the known grade labels and show the resulting one-hot matrix.
# NOTE(review): OneHotEncoder expects 2-D input, so V.TRAIN_GRADE_LABELS is presumably
# a column of labels; the encoder's column order should also agree with
# V.GRADE_INDEX_TO_LABEL used by logit_to_label -- confirm both.
enc = OneHotEncoder()
one_hot_vector = enc.fit_transform(V.TRAIN_GRADE_LABELS).toarray()
one_hot_vector

In [None]:
# One-hot encode the grades: reshape to a single column (n_samples, 1), then transform
# with the fitted encoder and densify.
train_ext_grade_data = np.array(train_grade_data).reshape(-1, 1)
train_ext_grade_data = enc.transform(train_ext_grade_data).toarray()

# The test grades are only encoded when a held-out split is used.
if TRAIN_OR_ALL == 'train':
    test_ext_grade_data = np.array(test_grade_data).reshape(-1, 1)
    test_ext_grade_data = enc.transform(test_ext_grade_data).toarray()

In [None]:
# Shapes of the one-hot grade matrices. They are printed explicitly: Jupyter only
# echoes the last top-level expression of a cell, and this cell ends with an `if`
# statement, so the bare `.shape` expressions were never displayed.
print("train:", train_ext_grade_data.shape)

if TRAIN_OR_ALL == 'train':
    print("test:", test_ext_grade_data.shape)

### Create dataset

In [None]:
class PseudoSummaryEvaluationDataset(Dataset):
    """Index-aligned view over five parallel sequences describing the samples.

    Item ``idx`` is the tuple
    ``(pseudo_summary, comment, aug_comment, grade, comment_sentiment)``
    taken at position ``idx`` of the corresponding sequences.
    """

    def __init__(self, pseudo_summaries, comments, aug_comments, grades, comments_sentiment):
        # The sequences are stored as given (no copy is made).
        self.pseudo_summaries, self.comments, self.aug_comments = (
            pseudo_summaries, comments, aug_comments
        )
        self.grades, self.comments_sentiment = grades, comments_sentiment

    def __len__(self):
        # The grade sequence defines the number of samples.
        return len(self.grades)

    def __getitem__(self, idx):
        columns = (
            self.pseudo_summaries,
            self.comments,
            self.aug_comments,
            self.grades,
            self.comments_sentiment,
        )
        return tuple(column[idx] for column in columns)

In [None]:
# Training data: shuffled every epoch. The identity collate_fn keeps each batch as a
# plain list of sample tuples, which get_data() later splits into columns.
# NOTE(review): a lambda collate_fn combined with num_workers > 0 requires the workers
# to be forked; under the "spawn" start method the lambda cannot be pickled -- confirm
# the target platform.
train_dataset = PseudoSummaryEvaluationDataset(
    train_pseudo_summary_data, train_comment_data, train_aug_comment_data, train_ext_grade_data, train_comment_sentiment_data
)
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda batch: batch,
    num_workers=8, pin_memory=True
)

# Test data: only built when a held-out split is used; not shuffled.
if TRAIN_OR_ALL == 'train':
    test_dataset = PseudoSummaryEvaluationDataset(
        test_pseudo_summary_data, test_comment_data, test_aug_comment_data, test_ext_grade_data, test_comment_sentiment_data
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, collate_fn=lambda batch: batch,
        num_workers=8, pin_memory=True
    )

## Model Building

### Load BERT model

In [None]:
# SBERT path: load the encoder through SentenceTransformer. No separate tokenizer
# object is created, so `bert_tokenizer` is None on this path.
if USE_SBERT_EMBED:
    from sentence_transformers import SentenceTransformer
    
    bert_tokenizer = None
    bert_model = SentenceTransformer(BERT_MODEL_NAME).to(device)

In [None]:
# Plain BERT path: load the tokenizer and the encoder separately. If USE_SBERT_EMBED
# is also True, this cell overwrites the objects created by the SBERT cell.
if USE_BERT_EMBED:
    from transformers import BertTokenizerFast, AutoModel

    bert_tokenizer = BertTokenizerFast.from_pretrained(BERT_TOKENIZER_NAME)
    bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME).to(device)

### Attention Network

In [None]:
from torch import nn

### Perspective HAN

In [None]:
import utils.pHAN as PHAN

## Training loop & testing loop

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [None]:
## utils
def get_data(batch):
    """Split a batch (a list of 5-field sample tuples) into per-field collections.

    Returns, in order: pseudo summaries, comments and augmented comments as lists,
    then grades and comment sentiments as NumPy arrays.
    """
    def column(position):
        # All values found at `position` across the samples, in batch order.
        return [sample[position] for sample in batch]

    return column(0), column(1), column(2), np.array(column(3)), np.array(column(4))

def logit_to_label(logits, return_numpy=False):
    """Map each row of `logits` to the grade label of its arg-max column.

    The column index is translated through V.GRADE_INDEX_TO_LABEL. The labels are
    returned as a list, or as a NumPy array when `return_numpy` is true.
    """
    winning_idx = logits.argmax(dim=1).detach().cpu().numpy()
    labels = [V.GRADE_INDEX_TO_LABEL[position] for position in winning_idx]
    return np.array(labels) if return_numpy else labels

def log_gradients(layer, layer_name, mean=True):
    """Log the mean gradient of every parameter of `layer`.

    Args:
        layer: module whose ``named_parameters()`` are inspected.
        layer_name: label used in the dividing line written before the values.
        mean: currently unused; kept so existing calls stay valid.
    """
    IO.log_dividing_line(logger, "Gradient of {}".format(layer_name))
    for name, param in layer.named_parameters():
        # Parameters that took no part in the last backward pass have no gradient.
        if param.grad is None:
            logger.info("{} : no gradient".format(name))
            continue

        try:
            logger.info("{} gradient mean: {}".format(name, torch.mean(param.grad)))
        except Exception:
            # Logging must never abort training. Unlike the previous bare `except:`,
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            logger.info("{} : no gradient".format(name))

def decouple_loss_fn_dict(d):
    """Return the classification, cosine-distance and contrastive loss functions
    stored in `d`, as a 3-tuple in that order."""
    return tuple(d[key] for key in ("cls_loss_fn", "cos_dis_loss_fn", "con_loss_fn"))

def decouple_loss_weight_dict(d):
    """Return the classification, cosine-distance and contrastive loss weights
    stored in `d`, as a 3-tuple in that order."""
    return tuple(d[key] for key in ("cls_loss_weight", "cos_dis_loss_weight", "con_loss_weight"))

In [None]:
def train_loop(
    dataloader, model, loss_fn_dict, optimizer, eta, loss_weight_dict, epoch
):
    """Train `model` for one epoch, optimising the cosine-embedding loss only.

    Args:
        dataloader: yields batches as lists of sample tuples (see get_data).
        model: called as ``model(pseudo_summaries, comments, eta)`` and expected to
            return a 4-tuple whose first two items are the projected embedding and
            the logits; it must also provide ``encode(comments)``.
        loss_fn_dict: dict with keys "cls_loss_fn", "cos_dis_loss_fn", "con_loss_fn".
        optimizer: optimizer stepped once per batch.
        eta: value forwarded unchanged to the model as its third argument.
        loss_weight_dict: dict with keys "cls_loss_weight", "cos_dis_loss_weight",
            "con_loss_weight".
        epoch: epoch number, used for logging only.

    The classification and contrastive terms are currently disabled; the full
    objective they belong to is
    ``cls_loss_weight * cls_loss + cos_dis_loss_weight * cos_dis_loss + con_loss_weight * con_loss``.
    """
    # Everything is still unpacked (which also checks the dict keys), but only the
    # cosine-embedding loss function is used below.
    cls_loss_fn, cos_dis_loss_fn, con_loss_fn = decouple_loss_fn_dict(loss_fn_dict)
    cls_loss_weight, cos_dis_loss_weight, con_loss_weight = decouple_loss_weight_dict(loss_weight_dict)

    size = len(dataloader.dataset)
    seen = 0  # number of samples processed so far in this epoch

    for _id, batch in enumerate(dataloader):
        model.train()
        IO.log_dividing_line(logger, "Epoch {}, Batch {}".format(epoch, _id))
        batch_pseudo_summaries, batch_comments, batch_aug_comments, batch_grades, batch_comments_sentiment = get_data(batch)
        batch_comments_sentiment = torch.tensor(batch_comments_sentiment).to(device)

        # Compute prediction: the model sees the augmented comments when augmentation
        # is on, but the target embedding always comes from the raw comments.
        model_comments = batch_aug_comments if COMMENT_AUGMENTATION else batch_comments
        projected_embed, logits, _, _ = model(batch_pseudo_summaries, model_comments, eta)
        comments_embed = model.encode(batch_comments)

        # Compute loss. With NEG_SAMPLE the sentiment labels (1 / -1) are the
        # cosine-embedding targets; otherwise every pair is a positive one.
        if NEG_SAMPLE:
            target = batch_comments_sentiment
        else:
            target = torch.ones(comments_embed.shape[0]).to(device)
        cos_dis_loss = cos_dis_loss_fn(projected_embed, comments_embed, target)
        loss = cos_dis_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        ## gradient clipping is currently disabled
#         nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_MAX_NORM)
#         nn.utils.clip_grad_value_(model.parameters(), clip_value=GRADIENT_CLIP_VALUE)
        optimizer.step()

        # Report progress every 10 batches as "samples processed / dataset size".
        # The previous formula, _id * (len(batch) + 1), under- and mis-counted.
        seen += len(batch_pseudo_summaries)
        if _id % 10 == 0:
            logger.info("loss: {:>7f}  [{:>5d}/{:>5d}]".format(loss.item(), seen, size))

            
def test_loop(
    dataloader, model, loss_fn_dict, eta, loss_weight_dict
):
    """Evaluate `model` on `dataloader` and log the average cosine-embedding loss.

    Args:
        dataloader: yields batches as lists of sample tuples (see get_data).
        model: called as ``model(pseudo_summaries, comments, eta)`` and expected to
            return a 4-tuple whose first item is the projected embedding; it must
            also provide ``encode(comments)``.
        loss_fn_dict: dict with keys "cls_loss_fn", "cos_dis_loss_fn", "con_loss_fn".
        eta: value forwarded unchanged to the model as its third argument.
        loss_weight_dict: dict with keys "cls_loss_weight", "cos_dis_loss_weight",
            "con_loss_weight".

    Classification is currently disabled, so no prediction is ever counted as
    correct and the logged accuracy is always 0.0%.
    """
    # Everything is still unpacked (which also checks the dict keys), but only the
    # cosine-embedding loss function is used below.
    cls_loss_fn, cos_dis_loss_fn, con_loss_fn = decouple_loss_fn_dict(loss_fn_dict)
    cls_loss_weight, cos_dis_loss_weight, con_loss_weight = decouple_loss_weight_dict(loss_weight_dict)

    IO.log_dividing_line(logger)
    IO.log_dividing_line(logger, "Testing")
    IO.log_dividing_line(logger)

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0.0, 0

    # Nothing to evaluate: avoid dividing by zero below.
    if num_batches == 0 or size == 0:
        logger.warning("Test set is empty; skipping evaluation.")
        return

    with torch.no_grad():
        for batch in dataloader:
            model.eval()
            batch_pseudo_summaries, batch_comments, batch_aug_comments, batch_grades, batch_comments_sentiment = get_data(batch)
            batch_comments_sentiment = torch.tensor(batch_comments_sentiment).to(device)
            comments_embed = model.encode(batch_comments)

            # Compute prediction: the model sees the augmented comments when
            # augmentation is on; the target embedding comes from the raw comments.
            model_comments = batch_aug_comments if COMMENT_AUGMENTATION else batch_comments
            projected_embed, logits, _, _ = model(batch_pseudo_summaries, model_comments, eta)

            # Compute loss. With NEG_SAMPLE the sentiment labels (1 / -1) are the
            # cosine-embedding targets; otherwise every pair is a positive one.
            if NEG_SAMPLE:
                target = batch_comments_sentiment
            else:
                target = torch.ones(comments_embed.shape[0]).to(device)
            loss = cos_dis_loss_fn(projected_embed, comments_embed, target)

            # Accumulate a plain float rather than a device tensor.
            test_loss += loss.item()

    test_loss /= num_batches
    correct /= size
    logger.info("Test Error: \nAccuracy: {:>0.1f}%, Avg loss: {:>8f} \n".format(100*correct, test_loss))

## Model initialization

In [None]:
# Build the perspective hierarchical attention network from the encoder, the stacked
# perspective mean embeddings and the hyper-parameters set in the Params section,
# then move it to the training device. The last line displays the module structure.
pHAN = PHAN.PerspectiveHierarchicalAttentionNetwork(
    bert_model, bert_tokenizer, perspective_mean_embed, 
    NUM_PERSPECTIVE, TOP_K, BERT_DIM, SENT_DIM, CXT_DIM, PRJ_DIM, 
    sent_temperature=SENT_TEMPERATURE, pers_temperature=PERS_TEMPERATURE, 
    dropout_rate=DROPOUT_RATE, leaky_relu_negative_slope=LEAKY_RELU_NEG_SLOPE, 
    encode_batch_size=ENC_BZ, compression=COMPRESSION, projection=PROJECTION,
    attention_empty_mask=ATTENTION_EMPTY_MASK, freeze_bert=FREEZE_BERT
).to(device)
pHAN

### Number of parameters

In [None]:
# Count the trainable parameters (those with requires_grad set).
# Tensor.numel() gives the element count for a parameter of any shape, including
# 0-dim ones, without building a temporary tensor from each shape. The result is kept
# as an integer tensor, as before.
num_params = torch.tensor(sum(
    parameter.numel() for parameter in pHAN.parameters() if parameter.requires_grad
))

num_params

## Train with whole dataset

### Loss functions

In [None]:
from pytorch_metric_learning import losses

In [None]:
## classification loss: binary cross-entropy with one fixed weight per grade column
## NOTE(review): train_loop / test_loop currently only apply the cosine-embedding loss,
## so this loss is built but not used for optimisation.
weight = torch.tensor([3, 1.5, 1, 1]).to(device)
# weight = torch.tensor(class_weights).to(device)
cls_loss_fn = nn.BCELoss(weight)
# class_weights

In [None]:
## cosine embedding loss between the projected pseudo-summary embedding and the
## comment embedding; this is the only loss the training loop optimises
cos_dis_loss_fn = nn.CosineEmbeddingLoss()

In [None]:
## supervised contrastive loss (pytorch_metric_learning)
## NOTE(review): its use in train_loop / test_loop is commented out.
con_loss_fn = losses.SupConLoss(temperature=1)

In [None]:
# Bundle the three loss functions under the keys decouple_loss_fn_dict() reads.
loss_fn_dict = {
    "cls_loss_fn": cls_loss_fn,
    "cos_dis_loss_fn": cos_dis_loss_fn,
    "con_loss_fn": con_loss_fn
}

### Learning rate scheduler

In [None]:
# Sub-strings of parameter names that belong to the second ("low lr") optimizer group.
low_lr_param_list = ['grade_classifier']

# A parameter goes to the low-lr group when its name contains ANY listed sub-string.
low_lr_params = [
    param for name, param in pHAN.named_parameters()
    if any(key in name for key in low_lr_param_list)
]

# Everything else goes to the base group. This is the exact complement of the test
# above; the previous `sum(key not in name ...)` filter also selected parameters that
# matched only some of the listed names, which would place them in both groups as
# soon as the list held more than one entry.
base_lr_params = [
    param for name, param in pHAN.named_parameters()
    if not any(key in name for key in low_lr_param_list)
]

In [None]:
# Default learning rate; both model parameter groups below currently use this same value.
base_learning_rate = 5 * 1e-4

# Three parameter groups:
#   1. base model parameters,
#   2. parameters matched by low_lr_param_list,
#   3. a dummy one-element tensor whose "learning rate" carries eta: the training loop
#      reads eta back with scheduler.get_last_lr()[-1], so the LR scheduler doubles as
#      the eta schedule.
optimizer = torch.optim.AdamW(
    [
        {"params": base_lr_params, "lr": 5 * 1e-4},
        {"params": low_lr_params, "lr": 5 * 1e-4},
        {"params": torch.tensor([1]), "lr": 1}, ## eta
    ],
    lr=base_learning_rate
)

In [None]:
# Step-wise schedules, expressed as multipliers of each group's initial learning rate.
lr_gamma = 0.8
step_size = 2
# Model parameters: multiply the learning rate by lr_gamma every `step_size` scheduler steps.
sch_lambda_params = lambda epoch: lr_gamma ** (epoch // step_size)
# eta (initial value 1): drops by 0.25 every `step_size` scheduler steps and is clamped
# at 0. The `// 1` is a no-op for integer epochs.
sch_lambda_eta = lambda epoch: max(0, 1 - 0.25 * (epoch // 1 // step_size))

# One lambda per optimizer parameter group, in the same order as the groups.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=[sch_lambda_params, sch_lambda_params, sch_lambda_eta]
)

In [None]:
# Total number of training epochs run by the training loop.
epochs = 20

# for epoch in range(epochs):
#     print(epoch, scheduler.get_last_lr())
#     scheduler.step()

In [None]:
from datetime import datetime

In [None]:
def save_model(t):
    """Write a checkpoint for epoch `t` to ``<MODEL_SAVE_DIR_PATH>/epoch_<t:04d>.pt``.

    The checkpoint holds the model, optimizer and scheduler state dicts, the
    perspective mean embeddings, and the configuration values of this run.
    """
    checkpoint = {
        # training state
        'epoch': t,
        'use_sbert_embed': USE_SBERT_EMBED,
        'use_bert_embed': USE_BERT_EMBED,
        'num_perspective': NUM_PERSPECTIVE,
        'train_or_all': TRAIN_OR_ALL,
        'model_state_dict': pHAN.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # run configuration
        'top_k': TOP_K,
        'bert_model': BERT_MODEL_NAME,
        'bert_tokenizer': BERT_TOKENIZER_NAME,
        'bert_topic_model': BERTOPIC_MODEL_NAME,
        'model_save_dir': MODEL_SAVE_DIR,
        'comment_augmentation': COMMENT_AUGMENTATION,
        'perspective_mean_embed': perspective_mean_embed,
        # pHAN hyper-parameters
        'bert_dim': BERT_DIM,
        'sent_dim': SENT_DIM,
        'cxt_dim': CXT_DIM,
        'prj_dim': PRJ_DIM,
        'sent_temperature': SENT_TEMPERATURE,
        'pers_temperature': PERS_TEMPERATURE,
        'dropout_rate': DROPOUT_RATE,
        'leaky_relu_negative_slope': LEAKY_RELU_NEG_SLOPE,
        'encode_batch_size': ENC_BZ,
        'compression': COMPRESSION,
        'projection': PROJECTION,
        'attention_empty_mask': ATTENTION_EMPTY_MASK,
        'freeze_bert': FREEZE_BERT,
        # gradient settings and seed
        'gradient_max_norm': GRADIENT_MAX_NORM,
        'gradient_clip_value': GRADIENT_CLIP_VALUE,
        'random_state': RANDOM_STATE,
    }

    torch.save(checkpoint, os.path.join(MODEL_SAVE_DIR_PATH, "epoch_{:04d}.pt".format(t)))

## Training

In [None]:
# Weights of the three loss terms: only the cosine-distance term is switched on.
# NOTE(review): train_loop / test_loop set `loss = cos_dis_loss` directly, so these
# weights are unpacked there but do not currently scale anything.
cls_loss_weight = 0
cos_dis_loss_weight = 1
con_loss_weight = 0

# Bundle the weights under the keys decouple_loss_weight_dict() reads.
loss_weight_dict = {
    "cls_loss_weight": cls_loss_weight,
    "cos_dis_loss_weight": cos_dis_loss_weight,
    "con_loss_weight": con_loss_weight
}

In [None]:
# %%time

# Main loop: epochs are numbered from 1.
for t in tqdm(range(1, epochs+1)):
    # Timestamp (Asia/Taipei) written to the log at the start of every epoch.
    now = datetime.now(timezone)
    ts = datetime.strftime(now,'%Y-%m-%d_%H:%M:%S')
    IO.log_dividing_line(logger, "Epoch {}, Timestamp: {}".format(t, ts))
    
    ## get current eta: it is the "learning rate" of the last optimizer parameter group
    eta = scheduler.get_last_lr()[-1]
    
    train_loop(train_dataloader, pHAN, loss_fn_dict, optimizer, eta, loss_weight_dict, t)
    
    # Evaluation only makes sense when a held-out test split exists.
    if TRAIN_OR_ALL == 'train':
        test_loop(test_dataloader, pHAN, loss_fn_dict, eta, loss_weight_dict)

    # Advance the learning-rate and eta schedules for the next epoch.
    scheduler.step()
    
    ## [TODO] log the results
    # Checkpoint every step_size * 5 epochs (epochs 10 and 20 with the current values).
    if t % (step_size*5) == 0:
        ## Save the model
        IO.log_dividing_line(logger, "Saving model from epoch {:04d}...".format(t))
        save_model(t)
    
logger.info("Done!")