<a href="https://colab.research.google.com/github/shinichiromizuno/QueryMultiTopic/blob/master/BERT_Base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
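# Before running: mount Google Drive at /content/drive (the dataset, the
# QueryMultiTopic code and all outputs live there) and switch the Colab
# runtime to a GPU (Runtime > Change runtime type).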

In [None]:
################ Preprocess ################

In [None]:
# Configure working directories.
# Every intermediate artifact (tokenized text, JSON, .pt shards, checkpoints,
# validation results) is written under work_dir on Google Drive.
import os

work_dir = '/content/drive/MyDrive/work_BERT_Base'

# Fail fast with a clear message if Drive is not mounted, instead of a
# FileNotFoundError from the first os.mkdir(work_dir + '/...') further down.
if not os.path.isdir(os.path.dirname(work_dir)):
  raise RuntimeError(f'{os.path.dirname(work_dir)} not found; mount Google Drive first.')
# Later cells create sub-directories of work_dir, so it must exist.
os.makedirs(work_dir, exist_ok=True)

In [None]:
# Install required packages and configure module path.
# NOTE(review): both installs are unpinned, so a fresh run may get versions
# other than the ones this notebook was last run with; consider pinning.
!pip install transformers
!pip install pyrouge
# Location of the QueryMultiTopic code on Drive; the `!python $MODULE_PATH/src/...`
# commands below read it from the environment.
%env MODULE_PATH=/content/drive/MyDrive/QueryMultiTopic

In [None]:
# Step1. Data Cleansing for ESG Dataset
# For every <dataset dir>/<year>/*.txt under dsdirs, rebuild sentences that were
# broken across lines, keep only lines that look like complete sentences, and
# write them space-joined to DatasetBeforeTokenization/<dataset dir>-<year>-<file>.
from glob import glob
import os
import shutil


def clean_report_lines(lines):
  """Return the sentence-like lines from an iterable of raw text lines.

  - Line/paragraph separators U+2028 / U+2029 inside a line become spaces.
  - A line ending in '-' or ',' is buffered as continuing on the next line.
  - A line starting with a lowercase letter is appended to the buffered text
    (if any); the buffer is cleared after every non-continued line.
  - Only lines that start with an uppercase letter and end with '.' are kept,
    with '- ' (a word hyphenated across lines) removed.
  """
  sentences = []
  prev = ''  # text carried over from lines ending in '-' or ','
  for raw in lines:
    line = raw.strip().replace('\u2028', ' ').replace('\u2029', ' ')
    if not line:
      continue
    if line[-1] in ('-', ','):
      prev += line + ' '
      continue
    if line[0].islower() and prev:
      line = prev + line
    if line[0].isupper() and line[-1] == '.':
      sentences.append(line.replace('- ', ''))
    prev = ''
  return sentences


out_dir = work_dir + '/DatasetBeforeTokenization'
if os.path.exists(out_dir):
  shutil.rmtree(out_dir)
os.makedirs(out_dir)

dsdirs = '/content/drive/MyDrive/DatasetSDGs/'
years = ['2021', '2020', '2019', '2018', '2017']

for dsdir in sorted(glob(dsdirs + '/*')):
  ds_name = os.path.basename(dsdir)
  for year in years:
    for filename in sorted(glob(dsdir + '/' + year + '/*')):
      if not filename.endswith('.txt'):
        continue
      # Explicit encoding: the reports contain non-ASCII text (e.g. U+2028).
      with open(filename, mode='r', encoding='utf-8') as f:
        text = clean_report_lines(f)
      # 'doc.txt' -> 'doc-.txt' gives document files the same '<type>-<labels>'
      # tail as summary files, since Step3 splits file names on '-'.
      out_name = os.path.basename(filename).replace('doc.txt', 'doc-.txt')
      with open(out_dir + '/' + ds_name + '-' + year + '-' + out_name, mode='w', encoding='utf-8') as f:
        f.write(' '.join(text))

In [None]:
# Step2. Sentence Split & Tokenization

# Download Stanford CoreNLP 4.2.2 into the current directory and remove the
# archive (CLASSPATH below assumes the current directory is /content).
!wget https://nlp.stanford.edu/software/stanford-corenlp-4.2.2.zip
!unzip stanford-corenlp-4.2.2.zip
!rm stanford-corenlp-4.2.2.zip
# Stop Python from writing .pyc files.
%env PYTHONDONTWRITEBYTECODE=1

# Configuration
# CLASSPATH points Java at the CoreNLP jar; presumably the tokenize mode of
# preprocess.py (next cell) runs CoreNLP through it — confirm.
%env CLASSPATH=/content/stanford-corenlp-4.2.2/stanford-corenlp-4.2.2.jar
# Tokenizer input: the cleaned text written by Step1.
%env RAW_PATH=$work_dir/DatasetBeforeTokenization

In [None]:
# Sentence Splitting and Tokenization
import os
import shutil
dir = work_dir + '/DatasetAfterTokenization'
if os.path.exists(dir):
  shutil.rmtree(dir)
os.mkdir(dir)

%env TOKENIZED_PATH=$work_dir/DatasetAfterTokenization
!mkdir logs
!python $MODULE_PATH/src/preprocess.py -mode tokenize -raw_path $RAW_PATH -save_path $TOKENIZED_PATH -log_file /content/logs/preprocess.log

In [None]:
# Step3. Creating Simple JSON Files
# Groups the tokenized files by '<num>.<cmp>.<year>' and writes one
# json_data/<num>.<cmp>.<year>.json per group: 'src' holds the 'doc' file's
# sentences, and each 'sum*' file is appended to 'tgt<label>' for every label
# it lists. Sentences are lists of words.
import json
import os
import shutil
from glob import glob

json_dir = work_dir + '/json_data'
if os.path.exists(json_dir):
  shutil.rmtree(json_dir)
os.makedirs(json_dir)


def parse_tokenized_name(path):
  """Return (key, dtype, labels) for a tokenized file path.

  The file name must have five '-'-separated fields,
  '<num>-<cmp>-<year>-<dtype>-<labels>', whose last field contains exactly two
  dots (the labels followed by two extensions). key is '<num>.<cmp>.<year>';
  labels may be empty (document files) or a comma-separated list.
  """
  num, cmp, year, dtype, labels = os.path.basename(path).split('-')
  labels, _, _ = labels.split('.')
  return num + '.' + cmp + '.' + year, dtype, labels


# groups[key] -> {'src': [...], 'tgt<label>': [[...], ...]}
groups = {}
for data in sorted(glob(work_dir + '/DatasetAfterTokenization/*')):
  with open(data) as json_open:
    json_load = json.load(json_open)
  s_list = [[w['word'] for w in s['tokens']] for s in json_load['sentences']]
  key, dtype, labels = parse_tokenized_name(data)
  group = groups.setdefault(key, {})
  if dtype == 'doc':
    group['src'] = s_list
  elif dtype[:3] == 'sum':
    for label in labels.split(','):
      group.setdefault('tgt' + label, []).append(s_list)

for key, group in groups.items():
  if group:  # skip keys that had neither a 'doc' nor a 'sum*' file
    with open(json_dir + '/' + key + '.json', mode='w') as f:
      f.write(json.dumps(group))

In [None]:
# Excluding short sentences
# Rewrites every json_data file in place, keeping only sentences with more than
# `threshold` tokens in 'src' and in each summary of 'tgt1'..'tgt17'.
import json
from glob import glob

threshold = 8

for data in sorted(glob(work_dir + '/json_data/*')):
  with open(data) as json_open:
    json_load = json.load(json_open)
  json_load_trimmed = {'src': [s for s in json_load['src'] if len(s) > threshold]}
  for i in range(1, 17 + 1):
    tgt_idx = 'tgt' + str(i)
    if tgt_idx not in json_load:
      continue
    json_load_trimmed[tgt_idx] = [[s for s in tgt if len(s) > threshold] for tgt in json_load[tgt_idx]]
  with open(data, mode='w') as f:
    f.write(json.dumps(json_load_trimmed))

In [None]:
# Splitting Train, Validation and Test Data
import os
import random

def train_test_valid_list_gen(l_size, tr=0.7, te=0.15, v=0.15, seed=None):
  """Return a shuffled list of l_size split labels ('train', 'test', 'valid').

  The 'test' and 'valid' counts are round(l_size * te) and round(l_size * v);
  every remaining slot is 'train'.

  Args:
    l_size: number of items to assign.
    tr, te, v: train / test / valid fractions; they must sum to 1.
    seed: optional seed for a private random.Random, giving a reproducible
      split without touching the global `random` state. None (the default)
      shuffles with the global generator, as before.

  Raises:
    Exception: if the fractions do not sum to 1.
    ValueError: if rounding leaves a negative number of 'train' items.
  """
  if round(tr+te+v, 10) != 1:
    raise Exception('Total allocation is not equal to 1')
  n_te = round(l_size * te)
  n_v = round(l_size * v)
  n_tr = l_size - (n_te + n_v)
  # round() rounds halves to even, so e.g. l_size=3, te=v=0.5 gives 2 + 2 > 3;
  # without this check the returned list would be longer than l_size.
  if n_tr < 0:
    raise ValueError(f'cannot split {l_size} items with te={te}, v={v}')
  ttvl = ['train'] * n_tr + ['test'] * n_te + ['valid'] * n_v
  (random if seed is None else random.Random(seed)).shuffle(ttvl)
  return ttvl

# Assign every json_data file to train / test / valid by renaming it to
# '<split>.<name>' (70/15/15 with the default fractions).
import os
import random
from glob import glob

# Seed the split so it is reproducible across runs.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

# Skip files that already carry a split prefix, so re-running this cell does
# not produce names like 'test.train.<name>'.
split_prefixes = ('train.', 'test.', 'valid.')
json_files = sorted(
  p for p in glob(work_dir + '/json_data/*')
  if not os.path.basename(p).startswith(split_prefixes))
ttvl = train_test_valid_list_gen(len(json_files))

for data, filetype in zip(json_files, ttvl):
  os.rename(data, work_dir + '/json_data/' + filetype + '.' + os.path.basename(data))

In [None]:
# Check Statistics
# Per split: number of source sentences and number of summary sentences for
# each of the 17 targets.
from glob import glob
import json

for filetype in ['train', 'test', 'valid']:
  cnt_dict = {'src': 0, **{'tgt' + str(i): 0 for i in range(1, 17 + 1)}}
  for data in sorted(glob(work_dir + '/json_data/' + filetype + '*')):
    with open(data) as json_open:
      json_load = json.load(json_open)
    cnt_dict['src'] += len(json_load['src'])
    for i in range(1, 17 + 1):
      tgt_idx = 'tgt' + str(i)
      cnt_dict[tgt_idx] += sum(len(tgt) for tgt in json_load.get(tgt_idx, []))
  print(f'{filetype}: {cnt_dict}')

train: {'src': 123081, 'tgt1': 1090, 'tgt2': 925, 'tgt3': 6620, 'tgt4': 2951, 'tgt5': 4673, 'tgt6': 2306, 'tgt7': 4907, 'tgt8': 7278, 'tgt9': 5417, 'tgt10': 3587, 'tgt11': 4516, 'tgt12': 7092, 'tgt13': 6169, 'tgt14': 1901, 'tgt15': 3175, 'tgt16': 2774, 'tgt17': 4377}
test: {'src': 24920, 'tgt1': 201, 'tgt2': 187, 'tgt3': 1015, 'tgt4': 458, 'tgt5': 682, 'tgt6': 416, 'tgt7': 963, 'tgt8': 1339, 'tgt9': 1277, 'tgt10': 473, 'tgt11': 770, 'tgt12': 1311, 'tgt13': 1527, 'tgt14': 508, 'tgt15': 550, 'tgt16': 491, 'tgt17': 1008}
valid: {'src': 25663, 'tgt1': 310, 'tgt2': 251, 'tgt3': 1554, 'tgt4': 612, 'tgt5': 995, 'tgt6': 272, 'tgt7': 1251, 'tgt8': 1797, 'tgt9': 1491, 'tgt10': 594, 'tgt11': 931, 'tgt12': 1550, 'tgt13': 1312, 'tgt14': 579, 'tgt15': 817, 'tgt16': 651, 'tgt17': 1407}


In [None]:
# Step4. Creating Labeled JSON Files
# For every json_data file, writes json_data_labeled/Baseline.<name> holding:
#   qry_ids<N>, qry_txt<N>  token ids and text of the query for goal N (1..17)
#   tgt<N>                  0/1 label per source sentence for goal N
#   src_ids, src_txt        '[CLS] sentence [SEP]' token ids and text per sentence
#   stpos<N>, nstbi<N>      start sentence and sentence count of every window
#                           that fits in max_tokens together with query N
import json
import torch
from transformers import BertTokenizer
import os
import shutil
from glob import glob

labeled_dir = work_dir + '/json_data_labeled'
if os.path.exists(labeled_dir):
  shutil.rmtree(labeled_dir)
os.makedirs(labeled_dir)

# A new window starts at most `stride` sentences after the previous one;
# float('inf') gives non-overlapping windows.
stride = 5
# Maximum BERT input length; a window plus its query must fit in it.
max_tokens = 512

qrys = [
  'End poverty in all its forms everywhere',
  'End hunger, achieve food security and improved nutrition and promote sustainable agriculture',
  'Ensure healthy lives and promote well-being for all at all ages',
  'Ensure inclusive and equitable quality education and promote lifelong learning opportunities for all',
  'Achieve gender equality and empower all women and girls',
  'Ensure availability and sustainable management of water and sanitation for all',
  'Ensure access to affordable, reliable, sustainable and modern energy for all',
  'Promote sustained, inclusive and sustainable economic growth, full and productive employment and decent work for all',
  'Build resilient infrastructure, promote inclusive and sustainable industrialization and foster innovation',
  'Reduce inequality within and among countries',
  'Make cities and human settlements inclusive, safe, resilient and sustainable',
  'Ensure sustainable consumption and production patterns',
  'Take urgent action to combat climate change and its impacts*',
  'Conserve and sustainably use the oceans, seas and marine resources for sustainable development',
  'Protect, restore and promote sustainable use of terrestrial ecosystems, sustainably manage forests, combat desertification, and halt and reverse land degradation and halt biodiversity loss',
  'Promote peaceful and inclusive societies for sustainable development, provide access to justice for all and build effective, accountable and inclusive institutions at all levels',
  'Strengthen the means of implementation and revitalize the global partnership for sustainable development',
]


def label_sentences(src, tgts):
  """Return one 0/1 label per sentence of src.

  A sentence gets 1 when it lies in a contiguous run of src sentences whose
  tokens, joined with spaces, equal the joined tokens of some summary in tgts.
  src and every element of tgts are lists of sentences (lists of tokens).
  """
  labels = [0] * len(src)
  for tgt in tgts:
    lent = len(tgt)
    if lent == 0:
      # Nothing to match (e.g. every sentence of this summary was removed by
      # the short-sentence filter); an empty run would otherwise "match" at
      # every position without labeling anything.
      continue
    tjoin = ' '.join(tok for sent in tgt for tok in sent)
    for k in range(len(src) - lent + 1):
      if ' '.join(tok for sent in src[k:k + lent] for tok in sent) == tjoin:
        labels[k:k + lent] = [1] * lent
  return labels


def compute_windows(src_ids, max_scnt, stride):
  """Split a document into windows of at most max_scnt tokens.

  Each window starts at sentence stpos and covers nstbi consecutive sentences
  whose total token count is <= max_scnt; the next window starts
  min(nstbi, stride) sentences later. A sentence longer than max_scnt is
  truncated IN PLACE in src_ids when it is reached as a window start.

  Returns:
    (stpos, nstbi): parallel lists of window start positions and sizes.

  Raises:
    ValueError: if max_scnt < 1 (the query leaves no room for the source).
  """
  if max_scnt < 1:
    raise ValueError(f'query leaves no room for source tokens (max_scnt={max_scnt})')
  stpos, nstbi = [], []
  ipos, tflag = 0, False
  while ipos < len(src_ids) and not tflag:
    if len(src_ids[ipos]) > max_scnt:
      src_ids[ipos] = src_ids[ipos][:max_scnt]
    nums, scnt = 0, 0
    # Add sentences until the running total exceeds max_scnt (the sentence that
    # overflowed is dropped again by `nums -= 1`) or the document runs out.
    while scnt <= max_scnt:
      if ipos + nums > len(src_ids) - 1:
        tflag = True  # ran past the last sentence: this window ends the document
        nums += 1
        break
      scnt += len(src_ids[ipos + nums])
      nums += 1
    nums -= 1
    stpos.append(ipos)
    nstbi.append(nums)
    ipos += min(nums, stride)
  return stpos, nstbi


# Queries are tokenized with do_lower_case=False, presumably so the added '[L1]'
# marker is not lowercased. The bert-base-uncased vocabulary only contains
# lowercase word pieces, so the query text itself is lowercased here; otherwise
# capitalized words such as 'End', 'Ensure' or 'Promote' become [UNK].
qry_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=False)
qry_tokenizer.add_tokens(['[L1]'])
qry_subtokens_list = []
qry_subtokens_idxs_list = []
for qry in qrys:
  qry_subtokens = qry_tokenizer.tokenize('[L1] ' + qry.lower())
  qry_subtokens_list.append(qry_subtokens)
  qry_subtokens_idxs_list.append(qry_tokenizer.convert_tokens_to_ids(qry_subtokens))

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
for data in sorted(glob(work_dir + '/json_data/*')):
  with open(data) as json_open:
    json_load = json.load(json_open)
  src = json_load['src']
  ds_dict = {}
  for i, qry_subtoken_idxs in enumerate(qry_subtokens_idxs_list, start=1):
    ds_dict['qry_ids' + str(i)] = qry_subtoken_idxs
  for i, qry in enumerate(qrys, start=1):
    ds_dict['qry_txt' + str(i)] = qry

  # Sentence labels per goal; a goal without summaries gets all zeros.
  for i in range(1, 17 + 1):
    ds_dict['tgt' + str(i)] = label_sentences(src, json_load.get('tgt' + str(i), []))

  # '[CLS] sentence [SEP]' token ids for every source sentence.
  src_txt = [' '.join(sent) for sent in src]
  src_subtoken_idxs_list = [
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize('[CLS] ' + text + ' [SEP]'))
    for text in src_txt]

  # Windows per goal. compute_windows truncates over-long sentences of
  # src_subtoken_idxs_list in place, so the src_ids saved below match them.
  for i in range(1, 17 + 1):
    max_scnt = max_tokens - len(qry_subtokens_idxs_list[i - 1])
    stpos, nstbi = compute_windows(src_subtoken_idxs_list, max_scnt, stride)
    ds_dict['stpos' + str(i)] = stpos
    ds_dict['nstbi' + str(i)] = nstbi

  ds_dict['src_ids'] = src_subtoken_idxs_list
  ds_dict['src_txt'] = src_txt

  with open(labeled_dir + '/Baseline.' + os.path.basename(data), mode='w') as f:
    f.write(json.dumps(ds_dict))

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Check Statistics
# Total number of source sentences, and of sentences labeled 1 for each goal,
# over all labeled files.
from glob import glob
import json

cnt_dict = {'src': 0, **{'tgt' + str(i): 0 for i in range(1, 17 + 1)}}

for data in sorted(glob(work_dir + '/json_data_labeled/*')):
  with open(data) as json_open:
    json_load = json.load(json_open)
  cnt_dict['src'] += len(json_load['src_txt'])
  for i in range(1, 17 + 1):
    tgt_idx = 'tgt' + str(i)
    cnt_dict[tgt_idx] += sum(json_load[tgt_idx])
print(cnt_dict)

{'src': 173664, 'tgt1': 1493, 'tgt2': 1338, 'tgt3': 8891, 'tgt4': 3932, 'tgt5': 6201, 'tgt6': 2849, 'tgt7': 6938, 'tgt8': 10217, 'tgt9': 8102, 'tgt10': 4522, 'tgt11': 6078, 'tgt12': 9676, 'tgt13': 8761, 'tgt14': 2985, 'tgt15': 4482, 'tgt16': 3815, 'tgt17': 6630}


In [None]:
# Step5. Creating PT Files
# Converts each labeled JSON file into pt_data/<split>.<key>.pt holding the
# source text ('src'), the per-goal labels ('tgts') and one example per
# (goal, window) pair ('ds').
import json
import itertools
import os
import shutil
from glob import glob
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

pt_dir = work_dir + '/pt_data'
if os.path.exists(pt_dir):
  shutil.rmtree(pt_dir)
os.makedirs(pt_dir)

sep_vid = tokenizer.vocab['[SEP]']
cls_vid = tokenizer.vocab['[CLS]']
pad_vid = tokenizer.vocab['[PAD]']

label_prefix = 'Baseline.'

for data in sorted(glob(work_dir + '/json_data_labeled/*')):
  with open(data) as json_open:
    json_load = json.load(json_open)
  src_ids = json_load['src_ids']
  dataset = {'src': json_load['src_txt'], 'tgts': {}}
  datasets = []
  for idx in range(1, 18):
    tgt = 'tgt' + str(idx)
    qry_token_ids = json_load['qry_ids' + str(idx)]
    dataset['tgts'][str(idx)] = json_load[tgt]
    for stpos, nstbi in zip(json_load['stpos' + str(idx)], json_load['nstbi' + str(idx)]):
      enpos = stpos + nstbi
      # Query ids followed by the ids of sentences stpos .. enpos - 1.
      src_subtoken_idxs = list(itertools.chain.from_iterable(src_ids[stpos:enpos]))
      qry_src_subtoken_idxs = qry_token_ids + src_subtoken_idxs
      datasets.append({
        'qry': str(idx),
        'stpos': stpos,
        'enpos': enpos,
        'src_ids': qry_src_subtoken_idxs,
        # Segment 0 for the query tokens, 1 for the source tokens.
        'segs': [0] * len(qry_token_ids) + [1] * len(src_subtoken_idxs),
        # Positions of every [CLS] id (Step4 starts each sentence with one).
        'clss': [i for i, t in enumerate(qry_src_subtoken_idxs) if t == cls_vid],
        'tgt': json_load[tgt][stpos:enpos],
      })
  dataset['ds'] = datasets
  # 'Baseline.<split>.<key>.json' -> '<split>.<key>'. Stripping the fixed
  # prefix and the extension (rather than keeping dot-fields 1..4) keeps the
  # whole key even when a name inside it contains a dot.
  name = os.path.splitext(os.path.basename(data))[0]
  if name.startswith(label_prefix):
    name = name[len(label_prefix):]
  torch.save(dataset, pt_dir + '/' + name + '.pt')

In [None]:
# Check Number of Dataset
# Number of (goal, window) examples per split, and how many window sentences
# are labeled 0 / 1 across all splits.
from glob import glob
import torch
ds_counts = []
z_cnt = 0
o_cnt = 0
for data_type in ['train', 'test', 'valid']:
  ds_count = 0
  for ds in glob(work_dir + '/pt_data/' + data_type + '.*'):
    dataset = torch.load(ds)
    ds_count += len(dataset['ds'])
    for d in dataset['ds']:
      z_cnt += d['tgt'].count(0)
      o_cnt += d['tgt'].count(1)
  ds_counts.append(ds_count)
print(f'train: {ds_counts[0]}, test: {ds_counts[1]}, valid: {ds_counts[2]}, ds_total: {sum(ds_counts)}, zero_sent_in_ds: {z_cnt}, one_sent_in_ds: {o_cnt}')

train: 413167, test: 83580, valid: 86139, ds_total: 582886, zero_sent_in_ds: 8288904, one_sent_in_ds: 285986


In [None]:
# Check Number of Dataset by Query
# Per goal: number of (goal, window) examples in each split, and how many window
# sentences are labeled 0 / 1. Each .pt file is loaded once and tallied for all
# goals, instead of reloading every file once per goal.
from glob import glob
import torch

data_types = ['train', 'test', 'valid']
goal_keys = [str(i) for i in range(1, 17 + 1)]
ds_by_qry = {q: {t: 0 for t in data_types} for q in goal_keys}
z_by_qry = {q: 0 for q in goal_keys}
o_by_qry = {q: 0 for q in goal_keys}
for data_type in data_types:
  for ds in glob(work_dir + '/pt_data/' + data_type + '.*'):
    dataset = torch.load(ds)
    for d in dataset['ds']:
      qry = d['qry']
      if qry not in ds_by_qry:
        continue
      ds_by_qry[qry][data_type] += 1
      z_by_qry[qry] += d['tgt'].count(0)
      o_by_qry[qry] += d['tgt'].count(1)

for i, qry in enumerate(goal_keys, start=1):
  ds_counts = [ds_by_qry[qry][t] for t in data_types]
  print(f'qry: {i}, train: {ds_counts[0]}, test: {ds_counts[1]}, valid: {ds_counts[2]}, ds_total: {sum(ds_counts)}, zero_sent_in_ds: {z_by_qry[qry]}, one_sent_in_ds: {o_by_qry[qry]}')

qry: 1, train: 24296, test: 4914, valid: 5065, ds_total: 34275, zero_sent_in_ds: 507023, one_sent_in_ds: 4389
qry: 2, train: 24304, test: 4916, valid: 5067, ds_total: 34287, zero_sent_in_ds: 501610, one_sent_in_ds: 3765
qry: 3, train: 24304, test: 4916, valid: 5067, ds_total: 34287, zero_sent_in_ds: 479025, one_sent_in_ds: 26350
qry: 4, train: 24304, test: 4918, valid: 5067, ds_total: 34289, zero_sent_in_ds: 492703, one_sent_in_ds: 11640
qry: 5, train: 24301, test: 4915, valid: 5066, ds_total: 34282, zero_sent_in_ds: 489762, one_sent_in_ds: 18614
qry: 6, train: 24302, test: 4915, valid: 5067, ds_total: 34284, zero_sent_in_ds: 498716, one_sent_in_ds: 8667
qry: 7, train: 24304, test: 4916, valid: 5067, ds_total: 34287, zero_sent_in_ds: 484713, one_sent_in_ds: 20662
qry: 8, train: 24307, test: 4918, valid: 5067, ds_total: 34292, zero_sent_in_ds: 470572, one_sent_in_ds: 29765
qry: 9, train: 24304, test: 4918, valid: 5067, ds_total: 34289, zero_sent_in_ds: 479634, one_sent_in_ds: 23697
qry:

In [None]:
# Check Output
# Print one .pt file for inspection: its source sentences, its per-goal labels,
# the number of (goal, window) examples, and the first example.
from glob import glob
import torch
dataset = torch.load(glob(work_dir + '/pt_data/*.1.Ahresty.2020.pt')[0])
print(dataset['src'])
print(dataset['tgts'])
print(len(dataset['ds']))
for d in dataset['ds'][:1]:
  print(d)

['We aim to create an affluent society by pursuing each of these areas and integrating them .', 'Our corporate name , Ahresty , comes from “ RST . ”', 'It was created by linking together the pronunciation of the first letters in Research , Service , and Technology .', 'Research means continuous research , investigation , and development of new technologies , new markets , and new sales techniques .', 'Service means providing warm , attentive service through personal interaction .', 'Technology means truly excellent technology that incorporates both physical and soft aspects and is highly beneficial for society .', 'These three areas of Research , Service , and Technology can not be considered independently .', 'Both technology and a spirit of service are necessary to accomplish the research involved in R&D .', 'To explain it in another way , research , service , and technology are intricately linked and each supports the others .', 'It is an organic relationship in which each component

In [None]:
# Build the evaluation dataset (no need to run this if dataset.pt already exists).
# Collects the sentences and the 17 per-goal label lists of every valid / test
# document, plus the goal texts, into work_dir/dataset.pt.
import json
import os
import random
import re
from glob import glob
import torch

n_goals = 17
src_doc_valid = []
tgt_list_valid = []
src_doc_test = []
tgt_list_test = []

for data in sorted(glob(work_dir + '/json_data_labeled/*')):
  # 'Baseline.<split>.<key>.json' -> '<split>'
  filetype = os.path.basename(data).split('.')[1]
  if filetype not in ('valid', 'test'):
    continue  # train files are not needed here
  with open(data) as json_open:
    json_load = json.load(json_open)
  tgt_list = [json_load['tgt' + str(i)] for i in range(1, n_goals + 1)]
  if filetype == 'valid':
    src_doc_valid.append(json_load['src_txt'])
    tgt_list_valid.append(tgt_list)
  else:
    src_doc_test.append(json_load['src_txt'])
    tgt_list_test.append(tgt_list)

# Incorporate SDGs goal text in query text
sdgsdir = '/content/drive/MyDrive/DatasetSDGs/0-SDGs/'


def _natural_key(path):
  """Sort key that compares digit runs numerically ('goal10' after 'goal9')."""
  # re.split with a capturing group alternates text (even) / digits (odd).
  return [int(part) if k % 2 else part for k, part in enumerate(re.split(r'(\d+)', path))]


# Plain sorted() would order unpadded numbered files 1, 10, ..., 17, 2, ... and
# misalign all_query_txt with goals 1..17; natural order avoids that and equals
# sorted() for digit-free or equal-width numbered names.
# NOTE(review): confirm against the actual file names in 0-SDGs.
all_query_txt = []
for goal in sorted(glob(sdgsdir + '/*'), key=_natural_key):
  # Only the first line of each goal file is used ('' for an empty file).
  with open(goal, mode='r') as f:
    all_query_txt.append(f.readline().strip())

dataset = {'src_doc_valid': src_doc_valid, 'tgt_list_valid': tgt_list_valid, 'src_doc_test': src_doc_test, 'tgt_list_test': tgt_list_test, 'all_query_txt': all_query_txt}
torch.save(dataset, work_dir + '/dataset.pt')

In [None]:
################ Training ################

In [None]:
# Install Pytorch
# Pinned versions, presumably the ones the project's src/train.py was written
# against — confirm. train.py runs below via `!python` in a separate process,
# so it uses these versions even though torch was already imported in this
# kernel during preprocessing.
# NOTE(review): the install log shows a cp37 wheel for torch 1.6.0; on a Colab
# runtime with a newer Python this pin may have no matching wheel.
!pip3 install torchtext==0.3.1
!pip3 install torch==1.6.0
!pip install tensorboardX

Collecting torchtext==0.3.1
  Downloading torchtext-0.3.1-py3-none-any.whl (62 kB)
[?25l[K     |█████▎                          | 10 kB 25.6 MB/s eta 0:00:01[K     |██████████▌                     | 20 kB 25.5 MB/s eta 0:00:01[K     |███████████████▊                | 30 kB 17.3 MB/s eta 0:00:01[K     |█████████████████████           | 40 kB 12.1 MB/s eta 0:00:01[K     |██████████████████████████▎     | 51 kB 12.4 MB/s eta 0:00:01[K     |███████████████████████████████▌| 61 kB 12.3 MB/s eta 0:00:01[K     |████████████████████████████████| 62 kB 784 kB/s 
Installing collected packages: torchtext
  Attempting uninstall: torchtext
    Found existing installation: torchtext 0.11.0
    Uninstalling torchtext-0.11.0:
      Successfully uninstalled torchtext-0.11.0
Successfully installed torchtext-0.3.1
Collecting torch==1.6.0
  Downloading torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (748.8 MB)
[K     |████████████████████████████████| 748.8 MB 20 kB/s 
Installing collected pa

Collecting tensorboardX
  Downloading tensorboardX-2.5-py2.py3-none-any.whl (125 kB)
[?25l[K     |██▋                             | 10 kB 20.8 MB/s eta 0:00:01[K     |█████▎                          | 20 kB 26.4 MB/s eta 0:00:01[K     |███████▉                        | 30 kB 30.1 MB/s eta 0:00:01[K     |██████████▌                     | 40 kB 14.4 MB/s eta 0:00:01[K     |█████████████                   | 51 kB 11.1 MB/s eta 0:00:01[K     |███████████████▊                | 61 kB 13.0 MB/s eta 0:00:01[K     |██████████████████▎             | 71 kB 12.1 MB/s eta 0:00:01[K     |█████████████████████           | 81 kB 13.3 MB/s eta 0:00:01[K     |███████████████████████▌        | 92 kB 14.7 MB/s eta 0:00:01[K     |██████████████████████████▏     | 102 kB 12.7 MB/s eta 0:00:01[K     |████████████████████████████▊   | 112 kB 12.7 MB/s eta 0:00:01[K     |███████████████████████████████▍| 122 kB 12.7 MB/s eta 0:00:01[K     |████████████████████████████████| 125 kB 12.

In [None]:
# Training Models
!rm -r $work_dir/models
!mkdir $work_dir/models
!mkdir logs

%env BERT_DATA_PATH=$work_dir/pt_data/
%env MODEL_PATH=$work_dir/models
%env LOG_PATH=/content/logs/baseline
!python $MODULE_PATH/src/train.py -mode train -summarizer baseline -dropout 0.1 -bert_data_path $BERT_DATA_PATH -model_path $MODEL_PATH -lr 2e-3 -visible_gpus 0  -gpu_ranks 0 -world_size 1 -report_every 10 -save_checkpoint_steps 10000 -batch_size 4000 -decay_method noam -train_steps 60000 -accum_count 2 -log_file $LOG_PATH -use_interval true -warmup_steps 10000 -ff_size 2048 -inter_layers 2 -heads 8

In [None]:
################ Validation ################

In [None]:
# Validation
!cp $MODULE_PATH/bert_config_uncased_base_baseline.json ../bert_config_uncased_base.json
!rm -r $work_dir/results
!mkdir $work_dir/results
!rm -r $work_dir/logs
!mkdir $work_dir/logs
%env LOG_PATH=$work_dir/logs/baseline

%env RESULT_PATH=$work_dir/results/
%env BERT_DATA_PATH=$work_dir/pt_data/
%env MODEL_PATH=$work_dir/models
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_10000.pt -block_trigram true
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_20000.pt -block_trigram true
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_30000.pt -block_trigram true
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_40000.pt -block_trigram true
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_50000.pt -block_trigram true
!python $MODULE_PATH/src/train.py -mode valid -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_60000.pt -block_trigram true

In [None]:
# Exploring Optimal Hyper-Parameters
def predict_goals(cos_sim_list, threshold):
  src_pred_list = []
  for cos_sim in cos_sim_list:
    ones = torch.ones(cos_sim.shape).int()
    zeros = torch.zeros(cos_sim.shape).int()
    pred = torch.where(cos_sim > threshold, ones, zeros)
    src_pred_list.append(pred.tolist())
  return src_pred_list

In [None]:
def gen_optim_list(pt_optim_dicts_all, qry, priority):
  """Merge overlapping window predictions into one prediction per sentence.

  Each element of pt_optim_dicts_all maps query ids to a list of window
  records {'stpos', 'enpos', 'pred'}, where pred[k] is the prediction for
  sentence stpos + k. A sentence covered by several windows takes its
  prediction from the window that scores it highest under `priority`,
  with st_idx / en_idx the sentence's distance from the window's first /
  last sentence:
    'center' - st_idx * en_idx (prefer windows where it sits in the middle)
    'top'    - en_idx ** 2     (prefer windows where it is near the start)
    'bottom' - st_idx ** 2     (prefer windows where it is near the end)
  Ties keep the earlier window.

  Returns:
    A list parallel to pt_optim_dicts_all of {qry: {sentence_idx: [pred, score]}}.

  Raises:
    ValueError: if priority is not 'center', 'top' or 'bottom'.
  """
  scorers = {
    'center': lambda st_idx, en_idx: st_idx * en_idx,
    'top': lambda st_idx, en_idx: en_idx ** 2,
    'bottom': lambda st_idx, en_idx: st_idx ** 2,
  }
  if priority not in scorers:
    # An unknown priority used to fail later with UnboundLocalError on pred_eval.
    raise ValueError(f"priority must be 'center', 'top' or 'bottom', got {priority!r}")
  score = scorers[priority]

  optim_dicts_all = []
  for pt_optim_dicts in pt_optim_dicts_all:
    best = {}
    for pt_optim_dict in pt_optim_dicts[qry]:
      stpos = pt_optim_dict['stpos']
      enpos = pt_optim_dict['enpos']
      pred = pt_optim_dict['pred']
      for j in range(stpos, enpos):
        st_idx = j - stpos
        en_idx = enpos - 1 - j
        pred_eval = score(st_idx, en_idx)
        if j not in best or best[j][-1] < pred_eval:
          best[j] = [pred[st_idx], pred_eval]
    optim_dicts_all.append({qry: best})
  return optim_dicts_all

In [None]:
# Show F1 score and accuracy per query (goal) and overall
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

def show_statistics_valid(all_predited_list, all_tgt_list, thres_list):
  """Print the best threshold per goal and the overall scores at those thresholds.

  Args:
    all_predited_list: 0/1 sentence predictions indexed
      [doc][thres_idx][goal] -> list over the document's sentences.
    all_tgt_list: 0/1 sentence labels indexed [doc][goal] -> list.
    thres_list: threshold values, parallel to the thres_idx axis.

  For each of the 17 goals, the threshold with the highest F1 (pooled over
  all documents) is printed with its F1 and accuracy. The total line pools
  every goal's predictions at that goal's own best threshold.
  """
  n_goals = 17
  best_all_predited_list = []  # [goal][doc] -> predictions at the goal's best threshold
  for i in range(n_goals):
    # The labels do not depend on the threshold, so pool them once per goal.
    by_qry_tgt = [label for doc_tgt in all_tgt_list for label in doc_tgt[i]]
    f1_list, acc_list = [], []
    for j in range(len(thres_list)):
      by_qry_predicted = [p for doc_pred in all_predited_list for p in doc_pred[j][i]]
      # sklearn takes (y_true, y_pred). F1 and accuracy are symmetric in the
      # two, so the scores match the previous (swapped) call.
      f1_list.append(f1_score(by_qry_tgt, by_qry_predicted))
      acc_list.append(accuracy_score(by_qry_tgt, by_qry_predicted))
    bs = np.argmax(f1_list)
    print(f'Goal{i+1}, threshold: {thres_list[bs]}, f1_score: {round(f1_list[bs], 3)}, accuracy: {round(acc_list[bs], 3)}')
    best_all_predited_list.append([doc_pred[bs][i] for doc_pred in all_predited_list])

  # Transpose [goal][doc] -> [doc][goal] in plain Python. np.array(...).T
  # reverses *all* axes, so when every document has the same number of
  # sentences the sentence axis ended up first; recent NumPy versions also
  # refuse to build the ragged array at all.
  best_all_predited_list_t = [list(doc_preds) for doc_preds in zip(*best_all_predited_list)]
  single_predicted = []
  single_tgt = []
  for doc_pred, doc_tgt in zip(best_all_predited_list_t, all_tgt_list):
    for qry_pred, qry_tgt in zip(doc_pred, doc_tgt):
      single_predicted += qry_pred
      single_tgt += qry_tgt
  best_f1 = f1_score(single_tgt, single_predicted)
  best_acc = accuracy_score(single_tgt, single_predicted)
  print(f'Total Best f1_score:{round(best_f1, 3)}, accuracy:{round(best_acc, 3)}')

In [None]:
import torch
from glob import glob
import numpy as np

def validate_execute(validation_list, thres_range=range(1, 70, 1)):
  """Sweep decision thresholds for each (checkpoint step, alignment) pair on the validation split.

  For every ``(step, align)`` in ``validation_list`` this loads
  ``<work_dir>/results/valid.step<step>.optim_dicts_all.pt`` and converts it with
  ``gen_optim_list`` into one score tensor per goal and document. It then applies
  ``predict_goals`` at every threshold ``j / 100`` for ``j`` in ``thres_range``.
  Finally it prints the best threshold per goal via ``show_statistics_valid``.

  Args:
    validation_list: iterable of ``(step, align)`` pairs; ``align`` is forwarded
      to ``gen_optim_list`` (this notebook uses 'center', 'top' and 'bottom').
    thres_range: integer thresholds in percent. The default (1..69, i.e.
      0.01..0.69) reproduces the original fixed sweep.
  """
  n_goals = 17
  thresholds = [j / 100 for j in thres_range]
  # The dataset does not depend on the checkpoint, so load it once rather than per pair.
  # Literal paths are used instead of glob(), which would treat '[', '*', '?' in
  # work_dir as pattern characters.
  dataset = torch.load(work_dir + '/dataset.pt')

  for step, align in validation_list:
    print(f'step:{step}, align:{align}')
    pt_optim_dicts_all = torch.load(work_dir + '/results/valid.step' + str(step) + '.optim_dicts_all.pt')

    # optim_list_all[goal][doc] -> 1-D tensor of optim_dict[qry][key][0] in sorted-key order.
    optim_list_all = []
    for goal in range(n_goals):
      qry = str(goal + 1)
      optim_dicts_all = gen_optim_list(pt_optim_dicts_all, qry, align)
      optim_list_all.append([
          torch.tensor([optim_dict[qry][key][0] for key in sorted(optim_dict[qry])])
          for optim_dict in optim_dicts_all
      ])

    # src_pred_list[goal][thres][doc]
    src_pred_list = [
        [predict_goals(optim_list_all[goal], threshold=thres) for thres in thresholds]
        for goal in range(n_goals)
    ]

    # Reorder to [doc][thres][goal] by permuting only these three outer axes.
    # np.array(...).T reverses *every* axis (including the innermost label axis
    # whenever all leaves share a length) and raises on ragged nesting with
    # NumPy >= 1.24, so the permutation is done in plain Python.
    n_docs = len(src_pred_list[0][0]) if thresholds else 0
    src_pred_list_t = [
        [[src_pred_list[goal][t_idx][doc] for goal in range(n_goals)]
         for t_idx in range(len(thresholds))]
        for doc in range(n_docs)
    ]
    show_statistics_valid(src_pred_list_t, dataset['tgt_list_valid'], thresholds)

In [None]:
# Showing Optimal Hyperparameters (Models, Threshold, and Alignment)
# Evaluate each checkpoint from step 65000 to 80000 (every 5000 steps) under every
# alignment mode; pairs are ordered alignment-major.
valiation_list = [
    (step, align)
    for align in ('center', 'top', 'bottom')
    for step in range(65000, 80001, 5000)
]
validate_execute(valiation_list)

step:65000, align:center




Goal1, threshold: 0.02, f1_score: 0.172, accuracy: 0.965
Goal2, threshold: 0.07, f1_score: 0.466, accuracy: 0.986
Goal3, threshold: 0.07, f1_score: 0.343, accuracy: 0.87
Goal4, threshold: 0.1, f1_score: 0.338, accuracy: 0.968
Goal5, threshold: 0.14, f1_score: 0.468, accuracy: 0.961
Goal6, threshold: 0.1, f1_score: 0.19, accuracy: 0.979
Goal7, threshold: 0.19, f1_score: 0.49, accuracy: 0.935
Goal8, threshold: 0.12, f1_score: 0.437, accuracy: 0.899
Goal9, threshold: 0.17, f1_score: 0.33, accuracy: 0.884
Goal10, threshold: 0.15, f1_score: 0.28, accuracy: 0.957
Goal11, threshold: 0.22, f1_score: 0.398, accuracy: 0.951
Goal12, threshold: 0.21, f1_score: 0.45, accuracy: 0.907
Goal13, threshold: 0.32, f1_score: 0.468, accuracy: 0.943
Goal14, threshold: 0.15, f1_score: 0.489, accuracy: 0.977
Goal15, threshold: 0.14, f1_score: 0.424, accuracy: 0.953
Goal16, threshold: 0.35, f1_score: 0.245, accuracy: 0.974
Goal17, threshold: 0.01, f1_score: 0.224, accuracy: 0.745




Total Best f1_score:0.364, accuracy:0.932
step:70000, align:center




Goal1, threshold: 0.57, f1_score: 0.106, accuracy: 0.979
Goal2, threshold: 0.25, f1_score: 0.261, accuracy: 0.983
Goal3, threshold: 0.18, f1_score: 0.299, accuracy: 0.889
Goal4, threshold: 0.6, f1_score: 0.315, accuracy: 0.973
Goal5, threshold: 0.41, f1_score: 0.453, accuracy: 0.95
Goal6, threshold: 0.03, f1_score: 0.09, accuracy: 0.965
Goal7, threshold: 0.02, f1_score: 0.308, accuracy: 0.89
Goal8, threshold: 0.01, f1_score: 0.387, accuracy: 0.873
Goal9, threshold: 0.04, f1_score: 0.258, accuracy: 0.877
Goal10, threshold: 0.65, f1_score: 0.291, accuracy: 0.965
Goal11, threshold: 0.06, f1_score: 0.25, accuracy: 0.913
Goal12, threshold: 0.01, f1_score: 0.335, accuracy: 0.863
Goal13, threshold: 0.11, f1_score: 0.357, accuracy: 0.914
Goal14, threshold: 0.04, f1_score: 0.292, accuracy: 0.972
Goal15, threshold: 0.08, f1_score: 0.271, accuracy: 0.942
Goal16, threshold: 0.03, f1_score: 0.136, accuracy: 0.962
Goal17, threshold: 0.01, f1_score: 0.194, accuracy: 0.773




Total Best f1_score:0.29, accuracy:0.922
step:75000, align:center




Goal1, threshold: 0.42, f1_score: 0.135, accuracy: 0.987
Goal2, threshold: 0.13, f1_score: 0.248, accuracy: 0.977
Goal3, threshold: 0.26, f1_score: 0.323, accuracy: 0.851
Goal4, threshold: 0.69, f1_score: 0.323, accuracy: 0.95
Goal5, threshold: 0.42, f1_score: 0.44, accuracy: 0.945
Goal6, threshold: 0.02, f1_score: 0.143, accuracy: 0.958
Goal7, threshold: 0.34, f1_score: 0.485, accuracy: 0.937
Goal8, threshold: 0.44, f1_score: 0.402, accuracy: 0.893
Goal9, threshold: 0.34, f1_score: 0.333, accuracy: 0.889
Goal10, threshold: 0.46, f1_score: 0.299, accuracy: 0.969
Goal11, threshold: 0.69, f1_score: 0.315, accuracy: 0.955
Goal12, threshold: 0.53, f1_score: 0.402, accuracy: 0.888
Goal13, threshold: 0.67, f1_score: 0.43, accuracy: 0.945
Goal14, threshold: 0.17, f1_score: 0.45, accuracy: 0.976
Goal15, threshold: 0.18, f1_score: 0.343, accuracy: 0.942
Goal16, threshold: 0.52, f1_score: 0.285, accuracy: 0.925
Goal17, threshold: 0.03, f1_score: 0.166, accuracy: 0.58




Total Best f1_score:0.314, accuracy:0.916
step:80000, align:center




Goal1, threshold: 0.54, f1_score: 0.162, accuracy: 0.984
Goal2, threshold: 0.34, f1_score: 0.244, accuracy: 0.984
Goal3, threshold: 0.02, f1_score: 0.296, accuracy: 0.787
Goal4, threshold: 0.51, f1_score: 0.37, accuracy: 0.973
Goal5, threshold: 0.35, f1_score: 0.415, accuracy: 0.95
Goal6, threshold: 0.12, f1_score: 0.186, accuracy: 0.957
Goal7, threshold: 0.2, f1_score: 0.461, accuracy: 0.945
Goal8, threshold: 0.11, f1_score: 0.403, accuracy: 0.902
Goal9, threshold: 0.07, f1_score: 0.29, accuracy: 0.864
Goal10, threshold: 0.65, f1_score: 0.281, accuracy: 0.972
Goal11, threshold: 0.12, f1_score: 0.266, accuracy: 0.922
Goal12, threshold: 0.21, f1_score: 0.428, accuracy: 0.914
Goal13, threshold: 0.25, f1_score: 0.444, accuracy: 0.934
Goal14, threshold: 0.3, f1_score: 0.379, accuracy: 0.97
Goal15, threshold: 0.31, f1_score: 0.395, accuracy: 0.942
Goal16, threshold: 0.24, f1_score: 0.225, accuracy: 0.967
Goal17, threshold: 0.11, f1_score: 0.23, accuracy: 0.812




Total Best f1_score:0.332, accuracy:0.928
step:65000, align:top




Goal1, threshold: 0.01, f1_score: 0.176, accuracy: 0.95
Goal2, threshold: 0.09, f1_score: 0.451, accuracy: 0.987
Goal3, threshold: 0.09, f1_score: 0.339, accuracy: 0.885
Goal4, threshold: 0.19, f1_score: 0.353, accuracy: 0.977
Goal5, threshold: 0.15, f1_score: 0.438, accuracy: 0.96
Goal6, threshold: 0.1, f1_score: 0.182, accuracy: 0.979
Goal7, threshold: 0.23, f1_score: 0.483, accuracy: 0.941
Goal8, threshold: 0.12, f1_score: 0.416, accuracy: 0.895
Goal9, threshold: 0.17, f1_score: 0.328, accuracy: 0.883
Goal10, threshold: 0.2, f1_score: 0.28, accuracy: 0.963
Goal11, threshold: 0.21, f1_score: 0.374, accuracy: 0.947
Goal12, threshold: 0.27, f1_score: 0.429, accuracy: 0.916
Goal13, threshold: 0.34, f1_score: 0.456, accuracy: 0.943
Goal14, threshold: 0.14, f1_score: 0.474, accuracy: 0.974
Goal15, threshold: 0.16, f1_score: 0.411, accuracy: 0.955
Goal16, threshold: 0.37, f1_score: 0.246, accuracy: 0.975
Goal17, threshold: 0.01, f1_score: 0.227, accuracy: 0.744




Total Best f1_score:0.351, accuracy:0.934
step:70000, align:top




Goal1, threshold: 0.35, f1_score: 0.092, accuracy: 0.97
Goal2, threshold: 0.26, f1_score: 0.25, accuracy: 0.982
Goal3, threshold: 0.01, f1_score: 0.292, accuracy: 0.84
Goal4, threshold: 0.53, f1_score: 0.32, accuracy: 0.972
Goal5, threshold: 0.4, f1_score: 0.432, accuracy: 0.948
Goal6, threshold: 0.01, f1_score: 0.066, accuracy: 0.938
Goal7, threshold: 0.02, f1_score: 0.305, accuracy: 0.89
Goal8, threshold: 0.01, f1_score: 0.372, accuracy: 0.87
Goal9, threshold: 0.03, f1_score: 0.258, accuracy: 0.868
Goal10, threshold: 0.66, f1_score: 0.278, accuracy: 0.965
Goal11, threshold: 0.06, f1_score: 0.25, accuracy: 0.914
Goal12, threshold: 0.01, f1_score: 0.33, accuracy: 0.862
Goal13, threshold: 0.07, f1_score: 0.333, accuracy: 0.9
Goal14, threshold: 0.04, f1_score: 0.273, accuracy: 0.971
Goal15, threshold: 0.06, f1_score: 0.248, accuracy: 0.934
Goal16, threshold: 0.03, f1_score: 0.131, accuracy: 0.961
Goal17, threshold: 0.01, f1_score: 0.192, accuracy: 0.773




Total Best f1_score:0.278, accuracy:0.915
step:75000, align:top




Goal1, threshold: 0.42, f1_score: 0.106, accuracy: 0.986
Goal2, threshold: 0.13, f1_score: 0.256, accuracy: 0.977
Goal3, threshold: 0.23, f1_score: 0.316, accuracy: 0.84
Goal4, threshold: 0.69, f1_score: 0.29, accuracy: 0.948
Goal5, threshold: 0.39, f1_score: 0.413, accuracy: 0.94
Goal6, threshold: 0.02, f1_score: 0.145, accuracy: 0.959
Goal7, threshold: 0.31, f1_score: 0.472, accuracy: 0.933
Goal8, threshold: 0.45, f1_score: 0.387, accuracy: 0.891
Goal9, threshold: 0.29, f1_score: 0.327, accuracy: 0.873
Goal10, threshold: 0.41, f1_score: 0.281, accuracy: 0.965
Goal11, threshold: 0.56, f1_score: 0.314, accuracy: 0.946
Goal12, threshold: 0.53, f1_score: 0.391, accuracy: 0.885
Goal13, threshold: 0.6, f1_score: 0.407, accuracy: 0.937
Goal14, threshold: 0.19, f1_score: 0.439, accuracy: 0.977
Goal15, threshold: 0.07, f1_score: 0.317, accuracy: 0.904
Goal16, threshold: 0.53, f1_score: 0.278, accuracy: 0.925
Goal17, threshold: 0.07, f1_score: 0.165, accuracy: 0.673




Total Best f1_score:0.313, accuracy:0.915
step:80000, align:top




Goal1, threshold: 0.6, f1_score: 0.13, accuracy: 0.985
Goal2, threshold: 0.33, f1_score: 0.277, accuracy: 0.985
Goal3, threshold: 0.02, f1_score: 0.294, accuracy: 0.785
Goal4, threshold: 0.51, f1_score: 0.372, accuracy: 0.973
Goal5, threshold: 0.31, f1_score: 0.388, accuracy: 0.945
Goal6, threshold: 0.06, f1_score: 0.147, accuracy: 0.928
Goal7, threshold: 0.19, f1_score: 0.425, accuracy: 0.941
Goal8, threshold: 0.04, f1_score: 0.383, accuracy: 0.862
Goal9, threshold: 0.06, f1_score: 0.298, accuracy: 0.856
Goal10, threshold: 0.57, f1_score: 0.265, accuracy: 0.967
Goal11, threshold: 0.12, f1_score: 0.257, accuracy: 0.92
Goal12, threshold: 0.2, f1_score: 0.411, accuracy: 0.909
Goal13, threshold: 0.23, f1_score: 0.405, accuracy: 0.927
Goal14, threshold: 0.3, f1_score: 0.351, accuracy: 0.969
Goal15, threshold: 0.31, f1_score: 0.365, accuracy: 0.939
Goal16, threshold: 0.27, f1_score: 0.205, accuracy: 0.969
Goal17, threshold: 0.09, f1_score: 0.224, accuracy: 0.796




Total Best f1_score:0.317, accuracy:0.921
step:65000, align:bottom




Goal1, threshold: 0.01, f1_score: 0.171, accuracy: 0.948
Goal2, threshold: 0.15, f1_score: 0.494, accuracy: 0.991
Goal3, threshold: 0.05, f1_score: 0.337, accuracy: 0.849
Goal4, threshold: 0.04, f1_score: 0.316, accuracy: 0.952
Goal5, threshold: 0.15, f1_score: 0.44, accuracy: 0.96
Goal6, threshold: 0.14, f1_score: 0.164, accuracy: 0.983
Goal7, threshold: 0.17, f1_score: 0.471, accuracy: 0.929
Goal8, threshold: 0.11, f1_score: 0.433, accuracy: 0.893
Goal9, threshold: 0.17, f1_score: 0.318, accuracy: 0.882
Goal10, threshold: 0.2, f1_score: 0.263, accuracy: 0.962
Goal11, threshold: 0.22, f1_score: 0.366, accuracy: 0.948
Goal12, threshold: 0.21, f1_score: 0.448, accuracy: 0.906
Goal13, threshold: 0.32, f1_score: 0.472, accuracy: 0.943
Goal14, threshold: 0.16, f1_score: 0.487, accuracy: 0.977
Goal15, threshold: 0.14, f1_score: 0.41, accuracy: 0.951
Goal16, threshold: 0.43, f1_score: 0.252, accuracy: 0.976
Goal17, threshold: 0.02, f1_score: 0.218, accuracy: 0.78




Total Best f1_score:0.357, accuracy:0.931
step:70000, align:bottom




Goal1, threshold: 0.03, f1_score: 0.071, accuracy: 0.932
Goal2, threshold: 0.26, f1_score: 0.269, accuracy: 0.983
Goal3, threshold: 0.01, f1_score: 0.288, accuracy: 0.839
Goal4, threshold: 0.56, f1_score: 0.289, accuracy: 0.971
Goal5, threshold: 0.4, f1_score: 0.426, accuracy: 0.947
Goal6, threshold: 0.01, f1_score: 0.085, accuracy: 0.939
Goal7, threshold: 0.01, f1_score: 0.298, accuracy: 0.869
Goal8, threshold: 0.01, f1_score: 0.376, accuracy: 0.87
Goal9, threshold: 0.04, f1_score: 0.243, accuracy: 0.874
Goal10, threshold: 0.6, f1_score: 0.259, accuracy: 0.962
Goal11, threshold: 0.08, f1_score: 0.232, accuracy: 0.917
Goal12, threshold: 0.01, f1_score: 0.332, accuracy: 0.861
Goal13, threshold: 0.1, f1_score: 0.364, accuracy: 0.912
Goal14, threshold: 0.04, f1_score: 0.275, accuracy: 0.971
Goal15, threshold: 0.06, f1_score: 0.27, accuracy: 0.934
Goal16, threshold: 0.03, f1_score: 0.13, accuracy: 0.961
Goal17, threshold: 0.01, f1_score: 0.192, accuracy: 0.773




Total Best f1_score:0.274, accuracy:0.913
step:75000, align:bottom




Goal1, threshold: 0.15, f1_score: 0.101, accuracy: 0.978
Goal2, threshold: 0.28, f1_score: 0.262, accuracy: 0.987
Goal3, threshold: 0.24, f1_score: 0.309, accuracy: 0.842
Goal4, threshold: 0.68, f1_score: 0.314, accuracy: 0.948
Goal5, threshold: 0.42, f1_score: 0.428, accuracy: 0.944
Goal6, threshold: 0.02, f1_score: 0.135, accuracy: 0.959
Goal7, threshold: 0.5, f1_score: 0.462, accuracy: 0.945
Goal8, threshold: 0.44, f1_score: 0.384, accuracy: 0.89
Goal9, threshold: 0.31, f1_score: 0.311, accuracy: 0.876
Goal10, threshold: 0.47, f1_score: 0.296, accuracy: 0.969
Goal11, threshold: 0.61, f1_score: 0.27, accuracy: 0.947
Goal12, threshold: 0.53, f1_score: 0.391, accuracy: 0.885
Goal13, threshold: 0.68, f1_score: 0.415, accuracy: 0.945
Goal14, threshold: 0.15, f1_score: 0.446, accuracy: 0.974
Goal15, threshold: 0.22, f1_score: 0.348, accuracy: 0.949
Goal16, threshold: 0.48, f1_score: 0.272, accuracy: 0.92
Goal17, threshold: 0.02, f1_score: 0.165, accuracy: 0.53




Total Best f1_score:0.297, accuracy:0.911
step:80000, align:bottom




Goal1, threshold: 0.54, f1_score: 0.138, accuracy: 0.984
Goal2, threshold: 0.28, f1_score: 0.227, accuracy: 0.98
Goal3, threshold: 0.02, f1_score: 0.29, accuracy: 0.786
Goal4, threshold: 0.57, f1_score: 0.358, accuracy: 0.975
Goal5, threshold: 0.36, f1_score: 0.393, accuracy: 0.949
Goal6, threshold: 0.13, f1_score: 0.178, accuracy: 0.959
Goal7, threshold: 0.2, f1_score: 0.443, accuracy: 0.943
Goal8, threshold: 0.11, f1_score: 0.399, accuracy: 0.901
Goal9, threshold: 0.06, f1_score: 0.269, accuracy: 0.852
Goal10, threshold: 0.53, f1_score: 0.266, accuracy: 0.965
Goal11, threshold: 0.12, f1_score: 0.258, accuracy: 0.92
Goal12, threshold: 0.18, f1_score: 0.413, accuracy: 0.906
Goal13, threshold: 0.28, f1_score: 0.432, accuracy: 0.936
Goal14, threshold: 0.29, f1_score: 0.343, accuracy: 0.967
Goal15, threshold: 0.36, f1_score: 0.379, accuracy: 0.944
Goal16, threshold: 0.26, f1_score: 0.227, accuracy: 0.969
Goal17, threshold: 0.11, f1_score: 0.227, accuracy: 0.811




Total Best f1_score:0.32, accuracy:0.926


In [None]:
################ Test ################

In [None]:
# Copy the baseline BERT config to the relative path the test run reads it from
# (the recorded Namespace shows bert_config_path='../bert_config_uncased_base.json').
!cp $MODULE_PATH/bert_config_uncased_base_baseline.json ../bert_config_uncased_base.json
# Input/output locations passed to train.py, all under this experiment's work_dir.
%env LOG_PATH=$work_dir/logs/baseline

%env RESULT_PATH=$work_dir/results/
%env BERT_DATA_PATH=$work_dir/pt_data/
%env MODEL_PATH=$work_dir/models
# Specify the optimal model: model_step_65000.pt should be the checkpoint selected on
# the validation split (step 65000 with 'center' alignment had the best total F1 there).
# NOTE(review): the recorded output below lists paths under .../work_baseline, while
# work_dir is set to .../work_BERT_Base at the top of the notebook; that output appears
# to come from a different run — re-run to refresh it.
!python $MODULE_PATH/src/train.py -mode test -summarizer baseline -bert_data_path $BERT_DATA_PATH -visible_gpus 0  -gpu_ranks 0 -batch_size 512 -log_file $LOG_PATH  -result_path $RESULT_PATH -test_from $MODEL_PATH/model_step_65000.pt -block_trigram true

env: LOG_PATH=/content/drive/MyDrive/work_baseline/logs/baseline
env: RESULT_PATH=/content/drive/MyDrive/work_baseline/results/
env: BERT_DATA_PATH=/content/drive/MyDrive/work_baseline/pt_data/
env: MODEL_PATH=/content/drive/MyDrive/work_baseline/models
[2022-03-27 11:37:32,685 INFO] Loading checkpoint from /content/drive/MyDrive/work_baseline/models/model_step_65000.pt
Namespace(accum_count=1, batch_size=512, bert_config_path='../bert_config_uncased_base.json', bert_data_path='/content/drive/MyDrive/work_baseline/pt_data/', beta1=0.9, beta2=0.999, block_trigram=True, dataset='', decay_method='', dropout=0.1, ff_size=2048, gpu_ranks=[0], heads=8, hidden_size=128, inter_layers=2, log_file='/content/drive/MyDrive/work_baseline/logs/baseline', lr=1, max_grad_norm=0, mode='test', model_path='../models/', optim='adam', param_init=0, param_init_glorot=True, recall_eval=False, report_every=1, report_rouge=True, result_path='/content/drive/MyDrive/work_baseline/results/', rnn_size=512, save_ch

In [None]:
# Show F1 score and accuracy overall and per query (goal)
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

def show_statistics_test(all_predited_list, all_tgt_list, n_goals=17):
  """Print per-goal and pooled F1 / accuracy on the test split.

  Args:
    all_predited_list: predictions indexed as ``all_predited_list[doc][goal]``;
      each leaf is a sequence of labels.
    all_tgt_list: targets indexed as ``all_tgt_list[doc][goal]``, aligned with the
      predictions after concatenation over documents.
    n_goals: number of goals to evaluate (17 by default).

  The "Total" line pools every goal's labels into one sequence before scoring.
  Returns None.
  """
  tgt_list_all, pred_list_all = [], []
  for goal in range(n_goals):
    by_qry_tgt = [label for doc in all_tgt_list for label in doc[goal]]
    by_qry_pred = [label for doc in all_predited_list for label in doc[goal]]
    # sklearn metrics take (y_true, y_pred) in that order.
    f1 = f1_score(by_qry_tgt, by_qry_pred)
    acc = accuracy_score(by_qry_tgt, by_qry_pred)
    print(f'Goal{goal+1}, f1_score: {round(f1, 3)}, accuracy: {round(acc, 3)}')

    # Accumulate the pooled lists directly instead of re-concatenating afterwards.
    tgt_list_all.extend(by_qry_tgt)
    pred_list_all.extend(by_qry_pred)

  f1_all = f1_score(tgt_list_all, pred_list_all)
  acc_all = accuracy_score(tgt_list_all, pred_list_all)
  print(f'Total, f1_score: {round(f1_all, 3)}, accuracy: {round(acc_all, 3)}')

In [None]:
import torch
from glob import glob
import numpy as np

def test_execute(step, align, thres_list):
  """Apply fixed per-goal thresholds to a checkpoint's test-split scores and print metrics.

  Loads ``<work_dir>/dataset.pt`` and
  ``<work_dir>/results/test.step<step>.optim_dicts_all.pt``, and converts the latter
  with ``gen_optim_list`` into one score tensor per goal and document. Goal ``g``
  (0-based) is thresholded at ``thres_list[g]`` via ``predict_goals``. F1 and
  accuracy against ``dataset['tgt_list_test']`` are then reported via
  ``show_statistics_test``.

  Args:
    step: checkpoint step that appears in the results file name.
    align: alignment mode forwarded to ``gen_optim_list``.
    thres_list: one threshold per goal, Goal1 first (17 values).

  Raises:
    ValueError: if ``thres_list`` has fewer than 17 entries.
  """
  n_goals = 17
  if len(thres_list) < n_goals:
    raise ValueError(f'thres_list must provide {n_goals} thresholds, got {len(thres_list)}')

  # Literal paths rather than glob(), which would treat '[', '*', '?' in work_dir as patterns.
  dataset = torch.load(work_dir + '/dataset.pt')
  pt_optim_dicts_all = torch.load(work_dir + '/results/test.step' + str(step) + '.optim_dicts_all.pt')

  # optim_list_all[goal][doc] -> 1-D tensor of optim_dict[qry][key][0] in sorted-key order.
  optim_list_all = []
  for goal in range(n_goals):
    qry = str(goal + 1)
    optim_dicts_all = gen_optim_list(pt_optim_dicts_all, qry, align)
    optim_list_all.append([
        torch.tensor([optim_dict[qry][key][0] for key in sorted(optim_dict[qry])])
        for optim_dict in optim_dicts_all
    ])

  # src_pred_list[goal][doc]
  src_pred_list = [
      predict_goals(optim_list_all[goal], threshold=thres_list[goal])
      for goal in range(n_goals)
  ]
  # Swap only the outer (goal, doc) axes -> [doc][goal]. np.array(...).T would reverse
  # every axis and raises on ragged nesting with NumPy >= 1.24.
  src_pred_list_t = [list(per_doc) for per_doc in zip(*src_pred_list)]
  show_statistics_test(src_pred_list_t, dataset['tgt_list_test'])

In [None]:
# Showing Test Results
# Per-goal thresholds taken from the validation sweep for step 65000 with 'center'
# alignment — the combination with the highest total validation F1 in the recorded run.
test_thres_list = [
    0.02,  # Goal1
    0.07,  # Goal2
    0.07,  # Goal3
    0.1,   # Goal4
    0.14,  # Goal5
    0.1,   # Goal6
    0.19,  # Goal7
    0.12,  # Goal8
    0.17,  # Goal9
    0.15,  # Goal10
    0.22,  # Goal11
    0.21,  # Goal12
    0.32,  # Goal13
    0.15,  # Goal14
    0.14,  # Goal15
    0.35,  # Goal16
    0.01,  # Goal17
]
test_execute(65000, 'center', test_thres_list)



Goal1, f1_score: 0.078, accuracy: 0.971
Goal2, f1_score: 0.289, accuracy: 0.987
Goal3, f1_score: 0.237, accuracy: 0.873
Goal4, f1_score: 0.286, accuracy: 0.969
Goal5, f1_score: 0.36, accuracy: 0.957
Goal6, f1_score: 0.375, accuracy: 0.962
Goal7, f1_score: 0.375, accuracy: 0.916
Goal8, f1_score: 0.329, accuracy: 0.879
Goal9, f1_score: 0.365, accuracy: 0.897
Goal10, f1_score: 0.256, accuracy: 0.954
Goal11, f1_score: 0.277, accuracy: 0.929
Goal12, f1_score: 0.36, accuracy: 0.885
Goal13, f1_score: 0.427, accuracy: 0.922
Goal14, f1_score: 0.338, accuracy: 0.963
Goal15, f1_score: 0.375, accuracy: 0.949
Goal16, f1_score: 0.048, accuracy: 0.978
Goal17, f1_score: 0.182, accuracy: 0.736
Total, f1_score: 0.302, accuracy: 0.925
