# Install libraries

In [None]:
!pip install -r requirements.txt

In [None]:
# https://github.com/teddysum/korean_ABSA_baseline/blob/0d47e3eef31419a2d31d859da2dfba4f9df75531/src/sentiment_analysis.py#L532-L630

def evaluation_f1(true_data, pred_data):
    """Micro precision / recall / F1 for category extraction and for category+polarity.

    Args:
        true_data: list of gold records; each ``record['annotation']`` is a list of
            ``[category, span, polarity]`` entries (polarity at index 2).
        pred_data: list of predicted records aligned index-by-index with ``true_data``;
            each annotation is ``[category, polarity]`` (polarity at index 1).

    Returns:
        dict with keys ``'category extraction result'`` (category only) and
        ``'entire pipeline result'`` (category and polarity), each mapping
        ``'Precision'``, ``'Recall'`` and ``'F1'`` to floats. A ratio whose denominator
        is zero (e.g. no predictions at all) is reported as 0.0 instead of raising
        ZeroDivisionError.
    """

    def _match(category, polarity, candidates, polarity_index):
        """Return (category_found, polarity_also_matches) for the first candidate with `category`.

        Only the first candidate carrying the category is considered.
        """
        for candidate in candidates:
            if candidate[0] == category:
                return True, candidate[polarity_index] == polarity
        return False, False

    def _safe_div(numerator, denominator):
        """Divide, yielding 0.0 instead of raising when the denominator is 0."""
        return numerator / denominator if denominator else 0.0

    def _scores(counts):
        """Precision / recall / F1 from a TP/FP/FN count dict."""
        precision = _safe_div(counts['TP'], counts['TP'] + counts['FP'])
        recall = _safe_div(counts['TP'], counts['TP'] + counts['FN'])
        return {
            'Precision': precision,
            'Recall': recall,
            'F1': _safe_div(2*recall*precision, recall + precision),
        }

    ce_eval = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}
    pipeline_eval = {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0}

    for i in range(len(true_data)):
        gold = true_data[i]['annotation']
        pred = pred_data[i]['annotation']

        # TP / FN: each gold annotation is either recovered by a prediction or missed.
        for y_ano in gold:
            ce_hit, pipeline_hit = _match(y_ano[0], y_ano[2], pred, 1)
            ce_eval['TP' if ce_hit else 'FN'] += 1
            pipeline_eval['TP' if pipeline_hit else 'FN'] += 1

        # FP: each prediction without a matching gold annotation.
        for p_ano in pred:
            ce_hit, pipeline_hit = _match(p_ano[0], p_ano[1], gold, 2)
            if not ce_hit:
                ce_eval['FP'] += 1
            if not pipeline_hit:
                pipeline_eval['FP'] += 1

    return {
        'category extraction result': _scores(ce_eval),
        'entire pipeline result': _scores(pipeline_eval),
    }

# Configure paths and load libraries

In [None]:
import os

# Checkpoint files for the best ACD and ASC models. torch.save does not create missing
# parent directories, so make sure ./model exists before any training cell saves into it.
acd_save_path = './model/acd_best.bin'
asc_save_path = './model/asc_best.bin'
os.makedirs(os.path.dirname(acd_save_path), exist_ok=True)

In [None]:
# Re-import edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import re
import time
import json
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score, f1_score

import torch
import torch.nn.functional as F
from torch.utils.data import (
    TensorDataset, DataLoader, RandomSampler, SequentialSampler, WeightedRandomSampler
)
from transformers import (
    AdamW, get_linear_schedule_with_warmup,
    ElectraTokenizerFast, ElectraForSequenceClassification, ElectraForTokenClassification,
    BertTokenizerFast, BertForSequenceClassification
)

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous. This gives
# accurate error locations while debugging but usually slows training; confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU instead of failing
# at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 1235
batch_size = 4
learning_rate = 1e-5
warmup_proportion = 0.1  # NOTE(review): not read by any visible cell
epochs = 10
warmup_steps = 1000
# max_len= 256

gradient_accumulation_steps = 4
num_train_epochs = epochs

# Seed the Python, NumPy and PyTorch RNGs; the trailing ';' suppresses the Generator repr.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x7f81de31e990>

# Load data

## data

In [None]:
def _read_jsonl(path):
    """Return the records of a JSON-Lines file as a list of dicts (blank lines are skipped)."""
    # The corpus is Korean text: decode as UTF-8 explicitly rather than with the
    # platform's default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def _patch_record(records, index, fixed):
    """Overwrite ``records[index]`` with a hand-written record after checking that the ids agree.

    Raises ValueError instead of silently clobbering a different record if the data
    file's order ever changes.
    """
    found = records[index]['id']
    if found != fixed['id']:
        raise ValueError(f"expected {fixed['id']!r} at index {index}, found {found!r}")
    records[index] = fixed


# Hand-written replacements for specific records (presumably correcting their annotation
# spans — confirm against the released data).
train = _read_jsonl('data/nikluge-sa-2022-train.jsonl')

_patch_record(train, 34, {'id': 'nikluge-sa-2022-train-00035',
 'sentence_form': '물론 쥐시장 가면 천원짜리도 많겠지만 가격대비 면이 좋네.',
 'annotation': [['본품#품질', ['면', 26, 27], 'positive']]})
_patch_record(train, 2529, {'id': 'nikluge-sa-2022-train-02530',
 'sentence_form': '피부가 건조해서 거칠어지신 분들... 피부 속과 겉 모두 건조하신 분들은 #수분집중 라인을 선택해주시면 되겠습니당!',
 'annotation': [['제품 전체#일반', [' #수분집중 라인', 41, 49], 'positive']]})

val = _read_jsonl('data/nikluge-sa-2022-dev.jsonl')

_patch_record(val, 2609, {'id': 'nikluge-sa-2022-dev-02610',
 'sentence_form': '매일 아침저녁으로 열심히 바르고 있는 크림이에요 🙆🏻💞',
 'annotation': [['제품 전체#일반', ['크림', 21, 23], 'positive']]})

In [None]:
# Every aspect category as "entity#attribute"; category ids follow sorted order.
entity_property_pair = sorted([
    '본품#가격', '본품#다양성', '본품#디자인', '본품#인지도', '본품#일반', '본품#편의성', '본품#품질', 
    '브랜드#가격', '브랜드#디자인', '브랜드#인지도', '브랜드#일반', '브랜드#품질', 
    '제품 전체#가격', '제품 전체#다양성', '제품 전체#디자인', '제품 전체#인지도', '제품 전체#일반', '제품 전체#편의성', '제품 전체#품질', 
    '패키지/구성품#가격', '패키지/구성품#다양성', '패키지/구성품#디자인', '패키지/구성품#일반', '패키지/구성품#편의성', '패키지/구성품#품질'
])

# name -> id and id -> name lookups for categories
entity_name_to_id = {name: idx for idx, name in enumerate(entity_property_pair)}
entity_id_to_name = dict(enumerate(entity_property_pair))

# binary label vocabulary
label_id_to_name = ['True', 'False']
label_name_to_id = {name: idx for idx, name in enumerate(label_id_to_name)}

# sentiment polarity vocabulary (index = class id)
polarity_id_to_name = ['positive', 'negative', 'neutral']
polarity_name_to_id = {name: idx for idx, name in enumerate(polarity_id_to_name)}

In [None]:
from copy import deepcopy
def convert(array: np.ndarray, data: list, id_to_name: dict = None):
    """Turn a 0/1 category matrix into prediction records usable by `evaluation_f1`.

    Args:
        array: 2-D 0/1 matrix with one row per record in `data` and one column per category id.
        data: records aligned with the rows of `array`; they are deep-copied, never mutated.
        id_to_name: category id -> name mapping; defaults to the global `entity_id_to_name`.

    Returns:
        A deep copy of `data` whose 'annotation' fields are replaced by
        ``[category_name, 'positive']`` for every column equal to 1. The polarity is a
        fixed 'positive' placeholder; only the category part is predicted here.

    Raises:
        ValueError: if `array` and `data` have different lengths.
    """
    if id_to_name is None:
        id_to_name = entity_id_to_name
    n_rows = array.shape[0]
    if n_rows != len(data):
        raise ValueError(f'array has {n_rows} rows but data has {len(data)} records')
    result = deepcopy(data)
    for row, record in zip(array, result):
        record['annotation'] = [[id_to_name[k], 'positive'] for k in np.flatnonzero(row == 1)]
    return result

# Set up models

## ACD

In [None]:
# ACD model: one logit per (sentence, category) pair.
tokenizer = ElectraTokenizerFast.from_pretrained("kykim/electra-kor-base")
model = ElectraForSequenceClassification.from_pretrained("kykim/electra-kor-base", num_labels = 1)
# The ACD training loop moves every batch to `device` but never the model, so place the
# model there now. The optimizer is built later from these same parameter objects.
model = model.to(device)

## ASC

In [None]:
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from transformers import ElectraModel
from transformers.models.electra.modeling_electra import ElectraPreTrainedModel, ElectraClassificationHead, ElectraDiscriminatorPredictions

class ElectraForCustom(ElectraPreTrainedModel):
    """ELECTRA encoder with two classification heads over the shared encoder output.

    * ``classifier``: an ``ElectraClassificationHead`` with ``config.num_labels`` outputs.
    * ``category_classifier``: a second head with a single output, used by the training
      code as a per-(sentence, category) relevance score.

    ``forward`` returns a plain ``(classifier_logits, category_logits)`` tuple, not a
    ``ModelOutput``/dict.
    """

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.classifier = ElectraClassificationHead(config)

        # Deep copy so setting num_labels for the second head leaves the shared config untouched.
        config_category = deepcopy(config)
        config_category.num_labels = 1
        self.category_classifier = ElectraClassificationHead(config_category)

        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode the input and apply both heads.

        encoder_hidden_states, encoder_attention_mask, past_key_values and use_cache are
        accepted for signature compatibility but not used.

        Returns:
            (classifier_logits, category_logits)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pass every supported encoder argument through so none is silently ignored.
        output = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # output[0] is the last hidden state for both tuple and ModelOutput returns.
        sequence_output = output[0]
        cls = self.classifier(sequence_output)
        rtd = self.category_classifier(sequence_output)

        return cls, rtd

In [None]:
# Tokenizer and both ASC models share one pretrained checkpoint.
_checkpoint = "kykim/electra-kor-base"
tokenizer = ElectraTokenizerFast.from_pretrained(_checkpoint)
# single-logit head: trained with a BCE sentiment loss during encoder pre-training
model_asc_encoder = ElectraForCustom.from_pretrained(_checkpoint, num_labels = 1)
# one logit per polarity class for the final classifier
model_asc = ElectraForCustom.from_pretrained(_checkpoint, num_labels = len(polarity_id_to_name))

# Data loader

## ACD

In [None]:
def get_devset(batch, train, max_len = 512, tok=None, name_to_id=None):
    """Pair every selected sentence with every category and build multi-hot category labels.

    Args:
        batch: iterable of integer indices (ints or 0-d tensors) into `train`.
        train: list of records with 'sentence_form' and 'annotation' (category at index 0).
        max_len: pad/truncate length of each (sentence, category) pair.
        tok: tokenizer callable; defaults to the global `tokenizer`.
        name_to_id: category name -> id; defaults to the global `entity_name_to_id`.

    Returns:
        (inputs, labels): ``inputs`` is a LongTensor of shape
        (len(batch), num_categories, 3, max_len) stacking input_ids, token_type_ids and
        attention_mask in that order; ``labels`` is a (len(batch), num_categories) 0/1 LongTensor.
    """
    if tok is None:
        tok = tokenizer
    if name_to_id is None:
        name_to_id = entity_name_to_id
    # Order categories by id so pair j of every sentence lines up with label column j.
    categories = sorted(name_to_id, key=name_to_id.get)

    result = []
    labels = []
    for b in batch:
        record = train[b]
        cat_ids = [0] * len(name_to_id)
        for a in record['annotation']:
            cat_ids[name_to_id[a[0]]] = 1

        sents = [record['sentence_form']] * len(categories)
        tokenized = tok(sents, categories, padding='max_length', truncation=True, max_length=max_len)
        # Select fields by name: consumers read position 0/1/2 as ids / segment ids / mask.
        result.append([tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']])
        labels.append(cat_ids)

    if result:
        result = torch.tensor(result, dtype=torch.long).transpose(1, 2)
    else:
        result = torch.zeros((0, len(categories), 3, max_len), dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long).reshape(-1, len(name_to_id))
    print(result.shape, labels.shape)
    return result, labels


In [None]:
# ACD loaders: training batches are shuffled, dev batches keep file order.
train_range, dev_range = torch.arange(len(train)), torch.arange(len(val))

train_ds = TensorDataset(*get_devset(train_range, train))
dev_ds = TensorDataset(*get_devset(dev_range, val))

train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=RandomSampler(train_ds))
dev_loader = DataLoader(dev_ds, batch_size=batch_size, sampler=SequentialSampler(dev_ds))

torch.Size([3001, 25, 3, 512]) torch.Size([3001, 25])
torch.Size([2794, 25, 3, 512]) torch.Size([2794, 25])


In [None]:
# Test-split loader, built like the ACD dev loader (batches kept in order, batch size 16).
# NOTE(review): `test` is not assigned in any visible cell of this notebook, so on a fresh
# kernel this cell raises NameError — presumably the test records were loaded like
# `train`/`val` in a removed cell; confirm.

test_range = torch.arange(len(test))
test_ds = TensorDataset(*get_devset(test_range, test))
test_loader = DataLoader(test_ds, sampler=SequentialSampler(test_ds), batch_size=16)

torch.Size([2127, 25, 3, 512]) torch.Size([2127, 25])


# Optimizer

## ACD

In [None]:
# ACD optimisation schedule (overrides the values from the global config cell).
warmup_steps, learning_rate, epochs = 200, 5e-5, 20

In [None]:
# Total optimizer steps for the linear warm-up/decay schedule: one step per
# `gradient_accumulation_steps` batches of `batch_size` sentences, for `epochs` epochs.
num_data = len(train)
num_training_steps = (num_data // (batch_size * gradient_accumulation_steps)) * epochs

# Biases and LayerNorm weights are excluded from weight decay; everything else uses 0.01.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) == skip_decay],
        "weight_decay": decay,
    }
    for skip_decay, decay in ((False, 0.01), (True, 0.0))
]

# LAMB optimizer from the third-party torch_optimizer package.
import torch_optimizer as optim
optimizer = optim.Lamb(optimizer_grouped_parameters, lr=learning_rate)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
)

# Train

## set

In [None]:
from datetime import date

# Date stamp (YYYYMMDD) printed at the top of each training log.
today = date.today().strftime('%Y%m%d')
print(today)

20221107


## ACD

In [None]:
num_train_epochs = epochs

# Dev-set decision rule: predict every category whose logit exceeds `logit_threshold`;
# a sentence with no such category falls back to its `topk` highest-scoring categories.
logit_threshold = 3.0
topk = 1

best_score = 0
global_step = 0
total_loss = 0
f1s = []

print(today)
print(acd_save_path, logit_threshold, topk, num_train_epochs, batch_size, best_score, sep = ', ')
print()


def _flatten_pairs(batch):
    """Move an ACD batch to `device` and flatten it to one row per (sentence, category) pair.

    Inputs arrive as (batch, num_categories, 3, max_len) with input ids, token type ids and
    attention mask stacked on axis 2 (the layout produced by get_devset).

    Returns:
        (input_ids, attention_mask, token_type_ids, labels); the three sequence tensors are
        cut to the longest non-padded pair in the batch.
    """
    inputs, labels = [b.to(device) for b in batch]
    inputs = inputs.reshape(-1, 3, inputs.shape[-1])
    input_ids, input_segment, input_masks = inputs[:, 0], inputs[:, 1], inputs[:, 2]
    len_max = int(input_masks.sum(-1).max())
    return input_ids[:, :len_max], input_masks[:, :len_max], input_segment[:, :len_max], labels


def _category_loss(cls_logit, labels):
    """Cross-entropy against each sentence's gold categories spread uniformly (soft targets)."""
    # clamp: a sentence without any gold category would otherwise give 0/0 = NaN targets.
    target = labels.float() / labels.sum(1, keepdim = True).clamp(min=1)
    return torch.nn.CrossEntropyLoss()(cls_logit, target)


num_categories = len(entity_name_to_id)

for ep in range(int(num_train_epochs)):
    model.train()
    # Discard gradients of an incomplete accumulation window left over from the previous epoch.
    optimizer.zero_grad()
    pbar = tqdm(train_loader)
    for step, batch in enumerate(pbar):
        input_ids, input_masks, segment_ids, labels = _flatten_pairs(batch)

        outputs = model(input_ids, input_masks, segment_ids)

        # one logit per pair -> (sentences, categories)
        cls_logit = outputs['logits'].reshape(-1, num_categories)
        loss = _category_loss(cls_logit, labels)

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        total_loss += loss.item()

        loss.backward()
        # No manual pbar.update(): iterating the tqdm object already advances it once per batch.

        lr = optimizer.param_groups[0]['lr']

        if (step + 1) % 10 == 0:
            pbar.set_description(desc='  loss : {}, lr : {}'.format(total_loss/10, round(lr, 6)))
            total_loss = 0
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

    time.sleep(0.1)

    model.eval()
    dev_pred_batches = [np.zeros((0, num_categories), dtype=np.int64)]
    dev_true_batches = [np.zeros((0, num_categories), dtype=np.int64)]
    dev_losses = []
    for batch in dev_loader:
        input_ids, input_masks, segment_ids, labels = _flatten_pairs(batch)

        with torch.no_grad():
            outputs = model(input_ids, input_masks, segment_ids)
            cls_logit = outputs['logits'].reshape(-1, num_categories)
            dev_losses.append(_category_loss(cls_logit, labels).item())

        true = labels.cpu().numpy()
        pred = (cls_logit > logit_threshold).int().cpu().numpy()
        # Guarantee at least one predicted category per sentence via the top-k fallback.
        empty_rows = np.where(pred.sum(-1) == 0)[0]
        if empty_rows.size:
            fallback = F.one_hot(torch.topk(cls_logit, topk, -1)[1], num_categories).sum(1).cpu().numpy()
            pred[empty_rows] = fallback[empty_rows]

        dev_pred_batches.append(pred)
        dev_true_batches.append(true)

    # Concatenate once per epoch instead of re-stacking the whole array after every batch.
    dev_pred = np.concatenate(dev_pred_batches)
    dev_true = np.concatenate(dev_true_batches)

    # score
    score = evaluation_f1(val, convert(dev_pred.astype(int), val))['category extraction result']['F1']
    f1 = f1_score(dev_true.astype(int), dev_pred.astype(int), average = 'micro')
    f1s.append(f1)
    print('epoch : {}, dev score : {}, loss : {}'.format(ep, score, np.mean(dev_losses)), end = '\t')

    if score > best_score:
        print('-----------best----------')
        best_score = score
        torch.save(model.state_dict(), acd_save_path)
    else:
        print()
    time.sleep(0.1)

## ASC

### Pre-train the encoder on the training set and NSMC

#### data loader

In [None]:
def get_devset(batch, train, max_len = 512, tok=None, name_to_id=None):
    """Pair every selected sentence with every category and build multi-hot category labels.

    NOTE(review): redefines the identical ACD-section helper, presumably so this section can
    run on its own; consider defining it once.

    Args:
        batch: iterable of integer indices (ints or 0-d tensors) into `train`.
        train: list of records with 'sentence_form' and 'annotation' (category at index 0).
        max_len: pad/truncate length of each (sentence, category) pair.
        tok: tokenizer callable; defaults to the global `tokenizer`.
        name_to_id: category name -> id; defaults to the global `entity_name_to_id`.

    Returns:
        (inputs, labels): ``inputs`` is a LongTensor of shape
        (len(batch), num_categories, 3, max_len) stacking input_ids, token_type_ids and
        attention_mask in that order; ``labels`` is a (len(batch), num_categories) 0/1 LongTensor.
    """
    if tok is None:
        tok = tokenizer
    if name_to_id is None:
        name_to_id = entity_name_to_id
    # Order categories by id so pair j of every sentence lines up with label column j.
    categories = sorted(name_to_id, key=name_to_id.get)

    result = []
    labels = []
    for b in batch:
        record = train[b]
        cat_ids = [0] * len(name_to_id)
        for a in record['annotation']:
            cat_ids[name_to_id[a[0]]] = 1

        sents = [record['sentence_form']] * len(categories)
        tokenized = tok(sents, categories, padding='max_length', truncation=True, max_length=max_len)
        # Select fields by name: consumers read position 0/1/2 as ids / segment ids / mask.
        result.append([tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']])
        labels.append(cat_ids)

    if result:
        result = torch.tensor(result, dtype=torch.long).transpose(1, 2)
    else:
        result = torch.zeros((0, len(categories), 3, max_len), dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long).reshape(-1, len(name_to_id))
    print(result.shape, labels.shape)
    return result, labels

def get_devset_sa(batch, train: tuple, max_len = 512, tok=None):
    """Tokenize single sentences for binary sentiment classification.

    Args:
        batch: iterable of integer indices into both sequences of `train`.
        train: ``(x, y)`` pair of aligned sequences: texts and labels convertible to int.
        max_len: pad/truncate length.
        tok: tokenizer callable; defaults to the global `tokenizer`.

    Returns:
        (inputs, labels): ``inputs`` is a LongTensor of shape (len(batch), 3, max_len)
        stacking input_ids, token_type_ids and attention_mask; ``labels`` has shape (len(batch),).
    """
    if tok is None:
        tok = tokenizer
    x, y = train
    batch_x = [x[b] for b in batch]
    batch_y = [y[b] for b in batch]

    tokenized = tok(batch_x, padding='max_length', truncation=True, max_length=max_len)
    # Select fields by name: consumers read position 0/1/2 as ids / segment ids / mask.
    fields = [tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']]

    result = torch.tensor(fields, dtype=torch.long).transpose(0, 1)
    labels = torch.tensor(batch_y, dtype=torch.long)

    print(result.shape, labels.shape)
    return result, labels


In [None]:
# Encoder pre-training: small per-step batches, accumulated over 4 steps per optimizer update.
batch_size, gradient_accumulation_steps = 2, 4

In [None]:
# (sentence, category) pair loaders for the encoder's auxiliary category loss, rebuilt
# with the current batch_size (same construction as the ACD loaders).
train_range, dev_range = torch.arange(len(train)), torch.arange(len(val))

train_ds = TensorDataset(*get_devset(train_range, train))
dev_ds = TensorDataset(*get_devset(dev_range, val))

train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=RandomSampler(train_ds))
dev_loader = DataLoader(dev_ds, batch_size=batch_size, sampler=SequentialSampler(dev_ds))

#### optimizer

In [None]:
# Encoder pre-training schedule.
warmup_steps, learning_rate, epochs = 300, 5e-5, 20

In [None]:
# Total optimizer steps for the linear warm-up/decay schedule: one step per
# `gradient_accumulation_steps` batches, for `epochs` epochs.
num_data = len(train)
num_training_steps = (num_data // (batch_size * gradient_accumulation_steps)) * epochs

# Biases and LayerNorm weights are excluded from weight decay; everything else uses 0.01.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model_asc_encoder.named_parameters() if any(nd in n for nd in no_decay) == skip_decay],
        "weight_decay": decay,
    }
    for skip_decay, decay in ((False, 0.01), (True, 0.0))
]

# LAMB optimizer from the third-party torch_optimizer package.
import torch_optimizer as optim
optimizer = optim.Lamb(optimizer_grouped_parameters, lr=learning_rate)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
)

#### execute

In [None]:
model_asc_encoder.to(device)
# Report where the encoder about to be trained lives (not the ACD `model`).
print(model_asc_encoder.device)

cuda:0


In [None]:
from itertools import cycle

# Fixed-seed shuffle of the NSMC training indices, repeated 3x so each epoch can take its
# own contiguous window of it.
# NOTE(review): rtrain_x / rtrain_y are not defined in any visible cell — confirm where they are loaded.
rng = np.random.default_rng(10)
rtrain_perm = np.tile(rng.permutation(len(rtrain_x)), 3)

# best encoder checkpoint
asc_encoder_path = './model/electra_asc_encoder.bin'

In [None]:
# NSMC held-out split, used to track the encoder's sentiment accuracy after each epoch.
# NOTE(review): rtest_x / rtest_y are not defined in any visible cell — confirm where they are loaded.
dev_range = torch.arange(len(rtest_x))
dev_ds = TensorDataset(*get_devset_sa(dev_range, (rtest_x, rtest_y)))
dev_loader = DataLoader(dev_ds, batch_size=16, sampler=SequentialSampler(dev_ds))

torch.Size([50000, 3, 512]) torch.Size([50000])


In [None]:
num_train_epochs = epochs

best_score = 0
global_step = 0
total_loss = 0
# Weight of the auxiliary category loss; multiplied by 0.9 after every epoch.
gamma = 1
# Each step pairs `batch_size` category examples with `batch_size * alpha` NSMC sentences,
# and every epoch consumes `alpha` NSMC sentences per training record.
alpha = 4

print(today)
print(asc_encoder_path, num_train_epochs, batch_size, best_score, sep = ', ')
print()


def _trim_padding(inputs):
    """Split (N, 3, L) inputs into (input_ids, attention_mask, token_type_ids) cut to the longest row."""
    input_ids, input_segment, input_masks = inputs[:, 0], inputs[:, 1], inputs[:, 2]
    len_max = int(input_masks.sum(-1).max())
    return input_ids[:, :len_max], input_masks[:, :len_max], input_segment[:, :len_max]


def _endless(loader):
    """Yield batches from `loader` indefinitely, starting a fresh pass whenever one ends.

    Unlike itertools.cycle, this keeps no copy of the batches it has already produced.
    """
    if len(loader) == 0:
        return
    while True:
        yield from loader


for ep in range(int(num_train_epochs)):
    model_asc_encoder.train()
    # Discard gradients of an incomplete accumulation window left over from the previous epoch.
    optimizer.zero_grad()

    # This epoch's window of the shuffled NSMC training indices.
    windows = train_ds.tensors[0].shape[0] * alpha
    train_range_sa = torch.LongTensor(rtrain_perm)[ep * windows: (ep+1) * windows]
    train_ds_sa = TensorDataset(*get_devset_sa(train_range_sa, (rtrain_x, rtrain_y)))
    train_loader_sa = DataLoader(train_ds_sa, sampler=RandomSampler(train_ds_sa), batch_size=batch_size * alpha)

    # The NSMC loader sets the epoch length. It comes first in zip so that no extra
    # category batch is requested once it is exhausted.
    pbar = tqdm(zip(train_loader_sa, _endless(train_loader)), total=len(train_loader_sa))

    for step, (batch_sa, batch) in enumerate(pbar):
        # Auxiliary category loss on (sentence, category) pairs from the training set.
        inputs, labels = [b.to(device) for b in batch]
        input_ids, input_masks, segment_ids = _trim_padding(inputs.reshape(-1, 3, inputs.shape[-1]))

        _, outputs = model_asc_encoder(input_ids, input_masks, segment_ids)

        # one category logit per pair -> (sentences, categories)
        cls_logit = outputs.reshape(-1, len(entity_name_to_id))
        # clamp: a sentence without gold categories would otherwise give 0/0 = NaN targets
        lab = labels.float() / labels.sum(1, keepdim = True).clamp(min=1)
        loss_cat = torch.nn.CrossEntropyLoss()(cls_logit, lab)

        # Binary sentiment loss on the NSMC batch.
        inputs, labels = [b.to(device) for b in batch_sa]
        input_ids, input_masks, segment_ids = _trim_padding(inputs)

        outputs_sa, _ = model_asc_encoder(input_ids, input_masks, segment_ids)
        loss_sa = torch.nn.BCEWithLogitsLoss()(outputs_sa.flatten(), labels.float())
        loss = loss_sa + loss_cat * gamma

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        total_loss += loss.item()

        loss.backward()

        lr = optimizer.param_groups[0]['lr']

        if (step + 1) % 10 == 0:
            pbar.set_description(desc='  loss : {}, lr : {}'.format(total_loss/10, round(lr, 6)))
            total_loss = 0
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

    gamma *= 0.9
    time.sleep(0.1)

    model_asc_encoder.eval()
    dev_true = []
    dev_pred = []
    dev_losses = []
    for batch in tqdm(dev_loader, total=len(dev_loader)):
        inputs, labels = [b.to(device) for b in batch]
        input_ids, input_masks, segment_ids = _trim_padding(inputs)

        with torch.no_grad():
            outputs_sa, _ = model_asc_encoder(input_ids, input_masks, segment_ids)
            loss_sa = torch.nn.BCEWithLogitsLoss()(outputs_sa.flatten(), labels.float())
            dev_losses.append(loss_sa.item())

        dev_true += labels.cpu().numpy().tolist()
        dev_pred += (outputs_sa.flatten() > 0.).int().cpu().numpy().tolist()

    # Accuracy over the whole dev set, not just its final batch.
    score = accuracy_score(dev_true, dev_pred)
    print('epoch : {}, dev score : {}, loss : {}'.format(ep, score, np.mean(dev_losses)), end = '\t')

    if score > best_score:
        print('-----------best----------')
        best_score = score
        torch.save(model_asc_encoder.state_dict(), asc_encoder_path)
    else:
        print()
    time.sleep(0.1)

### Train the polarity classifier

#### data loader

In [None]:
def get_trainset(batch, train, max_len = 512, tok=None, name_to_id=None):
    """Build one (sentence, gold category) pair per annotation, labelled with its polarity.

    Args:
        batch: iterable of integer indices (ints or 0-d tensors) into `train`.
        train: records with 'sentence_form' and 'annotation' entries ``[category, span, polarity]``.
        max_len: pad/truncate length.
        tok: tokenizer callable; defaults to the global `tokenizer`.
        name_to_id: polarity name -> class id; defaults to the global `polarity_name_to_id`.

    Returns:
        (inputs, labels): ``inputs`` is a LongTensor of shape (num_annotations, 3, max_len)
        stacking input_ids, token_type_ids and attention_mask; ``labels`` holds the polarity
        ids. Records without annotations contribute no rows.
    """
    if tok is None:
        tok = tokenizer
    if name_to_id is None:
        name_to_id = polarity_name_to_id

    chunks = []
    labels = []
    for b in batch:
        record = train[b]
        annotations = record['annotation']
        if not annotations:
            # Nothing to classify; an empty tokenization cannot be stacked with (k, 3, max_len) rows.
            continue
        entity = [a[0] for a in annotations]
        sents = [record['sentence_form']] * len(entity)
        tokenized = tok(sents, entity, padding='max_length', truncation=True, max_length=max_len)
        # Select fields by name: consumers read position 0/1/2 as ids / segment ids / mask.
        fields = [tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']]
        chunks.append(torch.tensor(fields, dtype=torch.long).transpose(0, 1))
        labels += [name_to_id[a[2]] for a in annotations]

    # Concatenate once instead of re-stacking the growing tensor every iteration (quadratic copying).
    result = torch.cat(chunks) if chunks else torch.empty((0, 3, max_len), dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    print(result.shape, labels.shape)
    return result, labels

In [None]:
# ASC classifier: 16 annotation pairs per batch, one optimizer update per batch.
batch_size, gradient_accumulation_steps = 16, 1

In [None]:
# One row per gold annotation of every training sentence, shuffled each epoch.
train_range = torch.arange(len(train))
train_ds = TensorDataset(*get_trainset(train_range, train))
train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=RandomSampler(train_ds))

In [None]:
# class_weights = {k: sum(polarity_counts.values()) /v for k, v in polarity_counts.items()}
# weights = [ class_weights[polarity_id_to_name[t]] for t in train_ds.tensors[1] ]

# train_loader = DataLoader(train_ds, sampler=WeightedRandomSampler(weights, len(weights), replacement=True), 
# batch_size=batch_size)

In [None]:
# Dev annotations in file order for the per-epoch macro-F1.
dev_range = torch.arange(len(val))
dev_ds = TensorDataset(*get_trainset(dev_range, val))
dev_loader = DataLoader(dev_ds, batch_size=batch_size, sampler=SequentialSampler(dev_ds))

#### optimizer

In [None]:
# ASC classifier schedule.
warmup_steps, learning_rate, epochs = 200, 5e-5, 20

In [None]:
# The ASC dataset has one row per annotation (get_trainset), so its size is len(train_ds),
# not the sentence count len(train). Counting sentences under-estimated the number of
# optimizer steps whenever sentences carry several annotations, and the linear schedule then
# reached a learning rate of 0 before training ended. Derive the step count from the loader.
num_data = len(train_ds)
num_training_steps = (len(train_loader) // gradient_accumulation_steps) * epochs

# Biases and LayerNorm weights are excluded from weight decay; everything else uses 0.01.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model_asc.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model_asc.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

# LAMB optimizer from the third-party torch_optimizer package.
import torch_optimizer as optim
optimizer = optim.Lamb(optimizer_grouped_parameters, lr=learning_rate)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
)

#### execute

In [None]:
# Start the polarity classifier from the pre-trained encoder checkpoint: copy every
# parameter whose name contains 'electra'; both heads keep their fresh initialisation.
state_dict = model_asc.state_dict()
state_dict_pretrained_encoder = torch.load(asc_encoder_path, map_location = 'cpu')
state_dict.update({
    name: state_dict_pretrained_encoder[name]
    for name in state_dict
    if 'electra' in name
})

model_asc.load_state_dict(state_dict)

In [None]:
# Train the polarity classifier on the configured device (Module.to returns the module itself).
model_asc = model_asc.to(device)
print(device)

cuda


In [None]:
best_score = 0
global_step = 0
total_loss = 0


def _asc_batch(batch):
    """Move an ASC batch to `device`; return (input_ids, attention_mask, token_type_ids, labels).

    The three sequence tensors are cut to the longest non-padded row of the batch.
    """
    inputs, labels = [b.to(device) for b in batch]
    input_ids, input_segment, input_masks = inputs[:, 0], inputs[:, 1], inputs[:, 2]
    len_max = int(input_masks.sum(-1).max())
    return input_ids[:, :len_max], input_masks[:, :len_max], input_segment[:, :len_max], labels


# Loop for `epochs`, the same count the scheduler's num_training_steps was derived from.
for ep in range(int(epochs)):
    model_asc.train()
    optimizer.zero_grad()
    pbar = tqdm(train_loader)
    for step, batch in enumerate(pbar):
        input_ids, input_masks, segment_ids, labels = _asc_batch(batch)

        # ElectraForCustom returns a (polarity_logits, category_logits) tuple, not a dict.
        cls_logit, _ = model_asc(input_ids, input_masks, segment_ids)

        loss = torch.nn.CrossEntropyLoss()(cls_logit, labels)

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        total_loss += loss.item()

        loss.backward()
        # No manual pbar.update(): iterating the tqdm object already advances it once per batch.

        lr = optimizer.param_groups[0]['lr']

        if (step + 1) % 10 == 0:
            pbar.set_description(desc='  loss : {}, lr : {}'.format(total_loss/10, round(lr, 6)))
            total_loss = 0
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            global_step += 1

    time.sleep(0.1)

    model_asc.eval()
    dev_pred_list = []
    dev_true_list = []
    dev_losses = []
    for batch in dev_loader:
        input_ids, input_masks, segment_ids, labels = _asc_batch(batch)

        with torch.no_grad():
            cls_logit, _ = model_asc(input_ids, input_masks, segment_ids)
            dev_losses.append(torch.nn.CrossEntropyLoss()(cls_logit, labels).item())

        dev_true_list += labels.cpu().tolist()
        dev_pred_list += cls_logit.argmax(1).cpu().tolist()

    dev_true = np.array(dev_true_list)
    dev_pred = np.array(dev_pred_list)

    # macro-F1 over the three polarity classes
    score = f1_score(dev_true.astype(int), dev_pred.astype(int), average = 'macro')

    print('epoch : {}, dev score : {}, loss : {}'.format(ep, score, np.mean(dev_losses)), end = '\t')

    if score > best_score:
        print('-----------best----------')
        best_score = score
        torch.save(model_asc.state_dict(), asc_save_path)
    else:
        print()
    time.sleep(0.1)

