In [5]:
import argparse
import os
import sys
import logging
import pickle
from functools import partial
import time
from tqdm import tqdm
from collections import Counter
import random
import numpy as np

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor

from transformers import AdamW, T5Tokenizer
from mvp.t5 import MyT5ForConditionalGeneration
from transformers import get_linear_schedule_with_warmup

from mvp.data_utils import *
from mvp.eval_utils import *
from mvp.process import *
torch.set_float32_matmul_precision('high')

In [6]:
# setting args
class Args:
    """Run configuration for this notebook.

    Every attribute gets a default below. Keyword arguments override any existing
    attribute, e.g. ``Args(dataset='rest15', top_k=3)``; unknown names raise
    ``TypeError`` so typos are not silently ignored.

    Args:
        path: project root. When omitted, the ``ABSA_PATH`` environment variable is
            used if set, otherwise ``/home/elicer/ABSA``. ``data_path`` is derived from it.
        **overrides: values replacing the defaults of existing attributes.
    """

    def __init__(self, path=None, **overrides):
        if path is None:
            path = os.environ.get('ABSA_PATH', '/home/elicer/ABSA')
        self.path = path
        self.data_path = f'{self.path}/data'
        self.method = 'mvp'  # task
        self.paraphrase = False  # task
        self.task = 'asqp'  # task
        self.dataset = 'rest16'  # data
        self.eval_data_split = 'test'  # test or dev
        self.top_k = 5
        self.ctrl_token = "post"
        self.data_ratio = 1.0
        self.model_name_or_path = 't5-base'  # base model to use
        self.load_ckpt_name = None  # checkpoint file (under output_dir) to load, if any
        self.do_train = False  # train or not
        self.do_inference = True  # inference or not
        self.max_seq_length = 200  # maximum input sequence length
        self.n_gpu = 1  # number of GPUs
        self.train_batch_size = 16
        self.eval_batch_size = 64
        self.gradient_accumulation_steps = 1
        self.learning_rate = 1e-4
        self.num_train_epochs = 20
        self.seed = 25
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.warmup_steps = 0.0
        self.multi_path = False
        self.num_path = 1
        self.beam_size = 1
        self.save_top_k = 1
        self.check_val_every_n_epoch = 10
        self.single_view_type = "rank"
        self.sort_label = False
        self.load_path_cache = False
        self.lowercase = False
        self.multi_task = False
        self.constrained_decode = False
        self.agg_strategy = 'vote'

        # Apply caller overrides last; only attributes defined above may be overridden.
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"Args got an unexpected keyword argument '{key}'")
            setattr(self, key, value)

def init_args(args=None):
    """Finalize the run configuration and create its output directory.

    Args:
        args: configuration object to finalize in place; a fresh ``Args()`` is
            created when omitted.

    Returns:
        The same object, with task/method-dependent settings applied
        (``lowercase`` for asqp; ``top_k``/``single_view_type``/``agg_strategy`` for
        dlo; ``paraphrase`` for the paraphrase method) and ``output_dir`` set.
        The directory is created if it does not exist yet.
    """
    if args is None:
        args = Args()

    if args.task == 'asqp':
        args.lowercase = True

    if args.method == 'dlo':
        args.top_k = 1
        args.single_view_type = "heuristic"
        args.agg_strategy = 'heuristic'

    run_root = f'{args.path}/outputs/{args.method}/{args.task}/{args.dataset}'
    if args.method == 'paraphrase':
        args.paraphrase = True
        # Paraphrase runs use a single view, so top_k is not part of the directory name.
        args.output_dir = f'{run_root}/{args.ctrl_token}_data{args.data_ratio}'
    else:
        args.output_dir = f'{run_root}/top_{args.top_k}_{args.ctrl_token}_data{args.data_ratio}'

    # exist_ok=True already tolerates an existing directory; no separate check needed.
    os.makedirs(args.output_dir, exist_ok=True)
    return args

# Build the finalized run configuration (also creates the output directory).
args = init_args()

# Echo the two settings that identify this run.
for label, value in (('method:', args.method), ('output path:', args.output_dir)):
    print(label, value)

method: mvp
output path: /home/elicer/ABSA/outputs/mvp/asqp/rest16/top_5_post_data1.0


In [7]:
class T5FineTuner(pl.LightningModule):
    """
    Fine-tune a pre-trained T5 model with PyTorch Lightning.

    Args:
        config: run configuration (an ``Args`` instance in this notebook). Task and
            dataset names, batch sizes, sequence length and optimizer settings are
            read from it.
        tfm_model: the seq2seq model being fine-tuned.
        tokenizer: tokenizer matching ``tfm_model``.
    """

    def __init__(self, config, tfm_model, tokenizer):
        super().__init__()
        # Keep the (large) model object out of the saved hyperparameters.
        self.save_hyperparameters(ignore=['tfm_model'])
        self.config = config
        self.model = tfm_model
        self.tokenizer = tokenizer

    def forward(self,
                input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None):
        """Delegate to the wrapped model; its output starts with the loss when ``labels`` is given."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )

    def _step(self, batch):
        """Return the teacher-forced loss for ``batch``.

        Pad positions of ``target_ids`` become -100 (the index ignored by the
        Hugging Face loss). The labels are a copy, so ``batch["target_ids"]`` keeps its
        pad ids and can still be decoded by the caller afterwards.
        """
        lm_labels = batch["target_ids"].clone()
        lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100

        outputs = self(input_ids=batch["source_ids"],
                       attention_mask=batch["source_mask"],
                       labels=lm_labels,
                       decoder_attention_mask=batch['target_mask'])
        return outputs[0]

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss."""
        loss = self._step(batch)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Greedy-decode ``batch``, score it against the targets and log loss/F1.

        Args:
            batch: dict with ``source_ids``, ``source_mask``, ``target_ids`` and
                ``target_mask``.
            stage: metric prefix such as ``"val"`` or ``"test"``; nothing is logged
                when it is falsy.
        """
        outs = self.model.generate(input_ids=batch['source_ids'],
                                   attention_mask=batch['source_mask'],
                                   max_length=self.config.max_seq_length,
                                   return_dict_in_generate=True,
                                   output_scores=True,
                                   num_beams=1)

        dec = [self.tokenizer.decode(ids, skip_special_tokens=True)
               for ids in outs.sequences]
        target = [self.tokenizer.decode(ids, skip_special_tokens=True)
                  for ids in batch["target_ids"]]

        # Use this module's own config rather than the notebook-global `args`.
        scores, _, _ = compute_scores(dec, target, self.config.paraphrase, verbose=False)
        f1 = torch.tensor(scores['f1'], dtype=torch.float64)

        loss = self._step(batch)

        if stage:
            for metric_name, value in (("loss", loss), ("f1", f1)):
                self.log(f"{stage}_{metric_name}",
                         value,
                         prog_bar=True,
                         on_step=False,
                         on_epoch=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Prepare AdamW and a linear warmup/decay schedule stepped every optimizer step.

        Expects ``config.lr_scheduler_init`` to hold the keyword arguments for
        ``get_linear_schedule_with_warmup``. NOTE(review): ``Args`` does not define
        it, so the training code must set it before fitting.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        # NOTE(review): T5-style models presumably name norm weights "...layer_norm.weight",
        # which "LayerNorm.weight" does not match; harmless while weight_decay == 0.0.
        named_params = list(self.model.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in named_params
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.config.weight_decay,
            },
            {
                "params": [p for n, p in named_params
                           if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.config.learning_rate,
                          eps=self.config.adam_epsilon)
        scheduler = {
            "scheduler": get_linear_schedule_with_warmup(
                optimizer, **self.config.lr_scheduler_init),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Build the shuffled loader over the ``train`` split."""
        print("load training data.")
        train_dataset = ABSADataset(tokenizer=self.tokenizer,
                                    task_name=self.config.task,
                                    data_name=self.config.dataset,
                                    data_type="train",
                                    top_k=self.config.top_k,
                                    args=self.config,
                                    max_len=self.config.max_seq_length)

        return DataLoader(
            train_dataset,
            batch_size=self.config.train_batch_size,
            # Keep the last partial batch in few-shot runs (data_ratio <= 0.3).
            drop_last=self.config.data_ratio > 0.3,
            shuffle=True,
            num_workers=2)

    def val_dataloader(self):
        """Build the loader over the ``dev`` split (``num_path`` views per example)."""
        val_dataset = ABSADataset(tokenizer=self.tokenizer,
                                  task_name=self.config.task,
                                  data_name=self.config.dataset,
                                  data_type="dev",
                                  top_k=self.config.num_path,
                                  args=self.config,
                                  max_len=self.config.max_seq_length)
        return DataLoader(val_dataset,
                          batch_size=self.config.eval_batch_size,
                          num_workers=2)

    @staticmethod
    def rindex(_list, _value):
        """Index of the last occurrence of ``_value`` in ``_list`` (ValueError if absent)."""
        return len(_list) - _list[::-1].index(_value) - 1

    def _token_ids(self, text):
        """Token ids ``self.tokenizer`` produces for ``text``, as a flat Python list."""
        return self.tokenizer(text, return_tensors='pt')['input_ids'].tolist()[0]

    def _write_force_tokens_file(self, path):
        """Tokenize the constrained-decoding vocabularies and save them as JSON at ``path``.

        Written layout:
            all_tokens[task][dataset] -> ids of the words in ``force_words[task][dataset]``
            cate_tokens[key]          -> ids of the categories in ``cate_list[key]``
            sentiment_tokens[task]    -> ids of "great", "ok", "bad" (same for every task)
            special_tokens            -> ids of "[O", "[A", "[S", "[C", "[SS" minus "[" (784)
        """
        import json

        # Loop variables are named task_name/dataset so they never shadow a caller's `task`.
        all_tokens = {}
        for task_name, datasets in force_words.items():
            all_tokens[task_name] = {
                dataset: [i for w in words for i in self._token_ids(w)]
                for dataset, words in datasets.items()
            }

        cate_tokens = {
            key: [i for w in categories for i in self._token_ids(w)]
            for key, categories in cate_list.items()
        }

        sentiment_ids = [i for w in ['great', 'ok', 'bad'] for i in self._token_ids(w)]
        # Keyed by task, matching the force_tokens['sentiment_tokens'][task] lookup.
        sentiment_tokens = {task_name: list(sentiment_ids) for task_name in force_words}

        special_tokens = [i for w in ['[O', '[A', '[S', '[C', '[SS']
                          for i in self._token_ids(w) if i != 784]

        dic = {"cate_tokens": cate_tokens,
               "all_tokens": all_tokens,
               "sentiment_tokens": sentiment_tokens,
               "special_tokens": special_tokens}
        with open(path, 'w') as f:
            json.dump(dic, f, indent=4)

    def prefix_allowed_tokens_fn(self, task, data_name, source_ids, batch_id,
                                 input_ids):
        """
        Constrained decoding: return the token ids allowed to follow ``input_ids``.

        Targets look like ``[O] ... [A] ... [C] ... [S] ... [SSEP] ...``. Once the first
        three arguments are bound, the ``(batch_id, input_ids)`` tail presumably matches
        ``generate(prefix_allowed_tokens_fn=...)``. NOTE(review): confirm against the caller.

        - Inside a bracketed marker, the next token is fixed by the current one.
        - Outside a marker, the allowed content depends on the marker name after the
          last "[":
          - SP: sentiment words.
          - AT: source tokens, plus "it" unless the task is aste.
          - AC: category tokens.
          - OT: source tokens, plus "null" for acos.
          - SS: end of the separator.

        Reads the module-level ``force_tokens`` mapping (presumably from the ``mvp``
        star imports). NOTE(review): it is not reloaded after this method writes
        ``force_tokens.json``.

        To look up ids: ``self.tokenizer("text", return_tensors='pt')['input_ids'].tolist()[0]``
        """
        if not os.path.exists('./force_tokens.json'):
            self._write_force_tokens_file('./force_tokens.json')

        to_id = {
            'OT': [667],
            'AT': [188],
            'SP': [134],
            'AC': [254],
            'SS': [4256],
            'EP': [8569],
            '[': [784],
            ']': [908],
            'it': [34],
            'null': [206, 195]
        }
        eos_id = 1  # presumably T5's "</s>": allows generation to stop here

        left_brace_index = (input_ids == to_id['['][0]).nonzero()
        right_brace_index = (input_ids == to_id[']'][0]).nonzero()
        num_left_brace = len(left_brace_index)
        num_right_brace = len(right_brace_index)
        last_right_brace_pos = (right_brace_index[-1][0]
                                if right_brace_index.nelement() > 0 else -1)
        last_left_brace_pos = (left_brace_index[-1][0]
                               if left_brace_index.nelement() > 0 else -1)
        cur_id = input_ids[-1]

        # 1) Inside a marker the last token fixes the next one.
        if cur_id in to_id['[']:
            return force_tokens['special_tokens']
        elif cur_id in to_id['AT'] + to_id['OT'] + to_id['EP'] + to_id['SP'] + to_id['AC']:
            return to_id[']']
        elif cur_id in to_id['SS']:
            return to_id['EP']

        # 2) Find the marker we are in (or whether one is still open).
        if last_left_brace_pos == -1:
            return to_id['['] + [eos_id]  # nothing opened yet: open a marker or stop
        if last_right_brace_pos == -1 or last_left_brace_pos > last_right_brace_pos:
            return to_id[']']  # a marker is still open: close it
        cur_term = input_ids[last_left_brace_pos + 1]

        # 3) Content allowed after that marker. Copies are taken so the shared
        #    force_tokens lists are never mutated by the += / extend below.
        if cur_term in to_id['SP']:
            ret = list(force_tokens['sentiment_tokens'][task])
        elif cur_term in to_id['AT']:
            ret = source_ids[batch_id].tolist()
            if task != 'aste':
                ret.extend(to_id['it'] + [eos_id])
        elif cur_term in to_id['SS']:
            ret = [3] + to_id[']'] + [eos_id]
        elif cur_term in to_id['AC']:
            ret = list(force_tokens['cate_tokens'][data_name])
        elif cur_term in to_id['OT']:
            ret = source_ids[batch_id].tolist()
            if task == "acos":
                ret.extend(to_id['null'])
        else:
            raise ValueError(cur_term)

        if num_left_brace == num_right_brace:
            # Between markers: forbid "]" and marker-name tokens in the free text.
            allowed = set(ret)
            allowed.discard(to_id[']'][0])
            for w in force_tokens['special_tokens']:
                allowed.discard(w)
            ret = list(allowed)
        elif num_left_brace > num_right_brace:
            ret += to_id[']']
        else:
            raise ValueError(
                f"more ']' ({num_right_brace}) than '[' ({num_left_brace}) in the decoded prefix")
        ret.extend(to_id['['] + [eos_id])  # a new marker may always start (or stop)
        return ret

In [10]:
# Evaluate the trained checkpoint on every zero-shot training split
# (train_zero_1 ... train_zero_10) and append the scores to result.txt.
ZERO_SHOT_SPLIT_IDS = range(1, 11)


def format_scores(split, agg_strategy, scores):
    """Format precision/recall/F1 from ``scores`` as one result line for ``split``."""
    return "{} {} precision: {:.2f} recall: {:.2f} F1 = {:.2f}".format(
        split, agg_strategy, scores['precision'], scores['recall'], scores['f1'])


print("\n****** Conduct inference on trained checkpoint ******")
print(f"Load trained model from {args.output_dir}")
print('Note that a pretrained model is required and `do_train` should be False')

# The asqp/rest16 run keeps its weights under "final2"; every other run under "final".
if args.task == 'asqp' and args.dataset == 'rest16':
    model_path = os.path.join(args.output_dir, "final2")
else:
    model_path = os.path.join(args.output_dir, "final")

if args.paraphrase:
    # Only the stock "t5-base" may be downloaded; any other name must be available locally.
    tokenizer = T5Tokenizer.from_pretrained(
        args.model_name_or_path,
        local_files_only=args.model_name_or_path != "t5-base")
else:
    tokenizer = T5Tokenizer.from_pretrained(model_path)

# The model does not depend on the evaluation split: load it once, not once per split.
tfm_model = MyT5ForConditionalGeneration.from_pretrained(model_path)
model = T5FineTuner(args, tfm_model, tokenizer)

if args.load_ckpt_name:
    ckpt_path = os.path.join(args.output_dir, args.load_ckpt_name)
    print("Loading ckpt:", ckpt_path)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint["state_dict"])

log_file_path = os.path.join(args.output_dir, "result.txt")

for split_id in ZERO_SHOT_SPLIT_IDS:
    args.eval_data_split = f'train_zero_{split_id}'

    # compute the performance scores and append them to the log
    with open(log_file_path, "a+") as f:
        config_str = f"seed: {args.seed}, beam: {args.beam_size}, constrained: {args.constrained_decode}\n"
        print(config_str)
        f.write(config_str)

        if args.multi_task:
            f1s = []
            for task in task_data_list:
                for data in task_data_list[task]:
                    # `evaluate` takes `args` first, exactly like the single-task call below.
                    scores = evaluate(args, model, task, data, data_type=args.eval_data_split)
                    print(task, data, scores)
                    exp_results = format_scores(args.eval_data_split, args.agg_strategy, scores)
                    f.write(f"{task}: \t{data}: \t{exp_results}\n")
                    f.flush()
                    f1s.append(scores['f1'])
            if f1s:  # avoid ZeroDivisionError when task_data_list is empty
                f.write(f"Average F1: \t{sum(f1s) / len(f1s)}\n")
                f.flush()
        else:
            scores = evaluate(args,
                              model,
                              args.task,
                              args.dataset,
                              data_type=args.eval_data_split)
            exp_results = format_scores(args.eval_data_split, args.agg_strategy, scores)
            print()
            print(exp_results)
            f.write(exp_results + "\n")
            f.flush()


****** Conduct inference on trained checkpoint ******
Load trained model from /home/elicer/ABSA/outputs/mvp/asqp/rest16/top_5_post_data1.0
Note that a pretrained model is required and `do_true` should be False
seed: 25, beam: 1, constrained: False

Total examples = 723
Total examples = 723
723 723 723


100%|██████████| 12/12 [01:10<00:00,  5.90s/it]


pred labels count Counter({1: 452, 2: 182, 3: 72, 4: 15, 5: 1, 6: 1})
gold  [O] not any longer [A] place [C] restaurant general [S] bad
pred  [O] good [A] place [C] restaurant general [S] great

gold  [O] null [A] it [C] service general [S] bad
pred  [O] never brought us complimentary noodles [A] noodles [C] food quality [S] bad [SSEP] [O] ignored repeated requests for sugar [A] it [C] service general [S] bad

gold  [O] outrageously good [A] food [C] food quality [S] great
pred  [O] good [A] food [C] food quality [S] great

gold  [O] well [A] it [C] restaurant prices [S] great [SSEP] [O] well [A] it [C] food quality [S] great
pred  [O] can not eat this well [A] it [C] food quality [S] bad [SSEP] [O] well [A] it [C] food prices [S] bad

gold  [O] null [A] cart attendant [C] service general [S] bad
pred  [O] walked away [A] cart attendant [C] service general [S] bad

gold  [O] null [A] it [C] service general [S] bad
pred  [O] asked her three times [A] it [C] service general [S] bad

gold

100%|██████████| 12/12 [01:08<00:00,  5.72s/it]


pred labels count Counter({1: 494, 2: 161, 3: 56, 4: 12})
gold  [O] not any longer [A] place [C] restaurant general [S] bad
pred  [O] good [A] place [C] restaurant general [S] great

gold  [O] null [A] it [C] service general [S] bad
pred  [O] threw our dishes on the table [A] it [C] service general [S] bad

gold  [O] outrageously good [A] food [C] food quality [S] great
pred  [O] good [A] food [C] food quality [S] great

gold  [O] well [A] it [C] restaurant prices [S] great [SSEP] [O] well [A] it [C] food quality [S] great
pred  [O] can not eat this well [A] it [C] food quality [S] bad [SSEP] [O] well [A] it [C] food prices [S] bad

gold  [O] null [A] cart attendant [C] service general [S] bad
pred  [O] walked away [A] cart attendant [C] service general [S] bad

gold  [O] null [A] it [C] service general [S] bad
pred  [O] asked her three times [A] it [C] service general [S] bad

gold  [O] great [A] fish [C] restaurant general [S] great [SSEP] [O] glad [A] it [C] restaurant general [S] g

100%|██████████| 12/12 [01:12<00:00,  6.07s/it]


pred labels count Counter({1: 447, 2: 180, 3: 78, 4: 14, 5: 3, 6: 1})
gold  [O] not any longer [A] place [C] restaurant general [S] bad
pred  [O] good [A] place [C] restaurant general [S] great

gold  [O] null [A] it [C] service general [S] bad
pred  [O] never brought us complimentary noodles [A] noodles [C] food quality [S] bad [SSEP] [O] ignored repeated requests for sugar [A] it [C] service general [S] bad

gold  [O] outrageously good [A] food [C] food quality [S] great
pred  [O] good [A] food [C] food quality [S] great

gold  [O] well [A] it [C] restaurant prices [S] great [SSEP] [O] well [A] it [C] food quality [S] great
pred  [O] can not eat this well [A] it [C] food quality [S] bad [SSEP] [O] well [A] it [C] food prices [S] bad

gold  [O] null [A] cart attendant [C] service general [S] bad
pred  [O] walked away [A] cart attendant [C] service general [S] bad

gold  [O] null [A] it [C] service general [S] bad
pred  [O] asked her three times [A] it [C] service general [S] bad

gold

100%|██████████| 12/12 [01:11<00:00,  5.92s/it]


pred labels count Counter({1: 450, 2: 175, 3: 82, 4: 13, 5: 3})
gold  [O] not any longer [A] place [C] restaurant general [S] bad
pred  [O] good [A] place [C] restaurant general [S] great

gold  [O] null [A] it [C] service general [S] bad
pred  [O] never brought us complimentary noodles [A] noodles [C] food quality [S] bad [SSEP] [O] ignored repeated requests for sugar [A] it [C] service general [S] bad

gold  [O] outrageously good [A] food [C] food quality [S] great
pred  [O] good [A] food [C] food quality [S] great

gold  [O] well [A] it [C] restaurant prices [S] great [SSEP] [O] well [A] it [C] food quality [S] great
pred  [O] can not eat this well [A] it [C] food quality [S] bad [SSEP] [O] well [A] it [C] food prices [S] bad

gold  [O] null [A] cart attendant [C] service general [S] bad
pred  [O] walked away [A] cart attendant [C] service general [S] bad

gold  [O] null [A] it [C] service general [S] bad
pred  [O] came back [A] dish [C] food quality [S] great

gold  [O] great [A] f

100%|██████████| 12/12 [01:10<00:00,  5.87s/it]


pred labels count Counter({1: 438, 2: 183, 3: 87, 4: 13, 5: 2})
gold  [O] not any longer [A] place [C] restaurant general [S] bad
pred  [O] good [A] place [C] restaurant general [S] great

gold  [O] null [A] it [C] service general [S] bad
pred  [O] never brought us complimentary noodles [A] noodles [C] food quality [S] bad [SSEP] [O] ignored repeated requests for sugar [A] it [C] service general [S] bad

gold  [O] outrageously good [A] food [C] food quality [S] great
pred  [O] good [A] food [C] food quality [S] great

gold  [O] well [A] it [C] restaurant prices [S] great [SSEP] [O] well [A] it [C] food quality [S] great
pred  [O] can not eat this well [A] it [C] food quality [S] bad [SSEP] [O] well [A] it [C] food prices [S] bad

gold  [O] null [A] cart attendant [C] service general [S] bad
pred  [O] walked away [A] cart attendant [C] service general [S] bad

gold  [O] null [A] it [C] service general [S] bad
pred  [O] asked her three times [A] it [C] service general [S] bad

gold  [O] 

 50%|█████     | 6/12 [00:40<00:40,  6.72s/it]


KeyboardInterrupt: 