<a href="https://colab.research.google.com/github/shinichiromizuno/QueryMultiTopic/blob/master/Sentence_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount your Google Drive.
# Connect to GPU in Google Colab.

In [None]:
################ Preprocess ################

In [None]:
# Working directories. The preprocessed dataset produced by Multi-BERTSum is
# reused as-is, so that project's work directory serves as the source here.
source_dir = '/content/drive/MyDrive/work_Multi_BERTSum'
work_dir = '/content/drive/MyDrive/work_Sentence_BERT'

In [None]:
# Install required packages.
!pip install sentence-transformers

In [None]:
# Steps 1-4 are skipped: we reuse the preprocessed dataset produced by Multi-BERTSum.

In [None]:
# Prepare the train, validation and test data.
import json
import random
from glob import glob
import torch

NUM_GOALS = 17  # number of SDGs; each labeled JSON has keys 'tgt1' ... 'tgt17'


def _read_targets(json_load, num_goals=NUM_GOALS):
  """Return the per-goal label lists ['tgt1', ..., 'tgt<num_goals>'] of one labeled document."""
  return [json_load['tgt' + str(i)] for i in range(1, num_goals + 1)]


# Sentences ('src_txt') and per-goal labels, grouped by split.
src_doc_train, tgt_list_train = [], []
src_doc_valid, tgt_list_valid = [], []
src_doc_test, tgt_list_test = [], []
splits = {
  'train': (src_doc_train, tgt_list_train),
  'valid': (src_doc_valid, tgt_list_valid),
  'test': (src_doc_test, tgt_list_test),
}

for data in sorted(glob(source_dir + '/json_data_labeled/*')):
  # The split name is the second dot-separated field of the file name;
  # files belonging to any other split are skipped without being parsed.
  filetype = data.split('/')[-1].split('.')[1]
  if filetype not in splits:
    continue
  # `with` closes each file handle instead of leaking one per file.
  with open(data, encoding='utf-8') as json_open:
    json_load = json.load(json_open)
  src_docs, tgt_lists = splits[filetype]
  src_docs.append(json_load['src_txt'])
  tgt_lists.append(_read_targets(json_load))

# Incorporate SDGs goal text in query text: each goal's query is the first
# line of its description file.
# NOTE(review): later cells pair all_query_txt[i] with 'tgt<i+1>', so this
# relies on the sorted path order matching goal numbering (e.g. zero-padded
# file names) — confirm.
sdgsdir = '/content/drive/MyDrive/DatasetSDGs/0-SDGs/'

all_query_txt = []
for goal in sorted(glob(sdgsdir + '/*')):
  with open(goal, mode='r', encoding='utf-8') as f:
    all_query_txt.append(f.readline().strip())

dataset = {'src_doc_train': src_doc_train, 'tgt_list_train': tgt_list_train, 'src_doc_valid': src_doc_valid, 'tgt_list_valid': tgt_list_valid, 'src_doc_test': src_doc_test, 'tgt_list_test': tgt_list_test, 'all_query_txt': all_query_txt}
torch.save(dataset, work_dir + '/dataset.pt')

In [None]:
import torch
from glob import glob
from sentence_transformers import SentenceTransformer, InputExample, losses, util

# Load the cached dataset directly by path: a missing file then raises a clear
# FileNotFoundError instead of an IndexError on an empty glob() result.
dataset = torch.load(work_dir + '/dataset.pt')

# One (query, sentence) pair per goal and sentence, labelled with that
# sentence's label for the goal (cast to float for the similarity loss).
# NOTE(review): zip() silently truncates if a document's sentence count and
# its label count differ — presumably they always match; confirm.
train_examples = []
for src_doc, tgt_list in zip(dataset['src_doc_train'], dataset['tgt_list_train']):
  for i in range(17):
    qry = dataset['all_query_txt'][i]
    for src, tgt in zip(src_doc, tgt_list[i]):
      train_examples.append(InputExample(texts=[qry, src], label=float(tgt)))

In [None]:
# Check statistics: total number of (query, sentence) training pairs.
n_train_pairs = len(train_examples)
n_train_pairs

2092377

In [None]:
################ Training ################

In [None]:
from torch.utils.data import DataLoader

import json
import torch
from transformers import BertTokenizer
import os
import shutil
from glob import glob

SEED = 42  # fixes the DataLoader shuffling order so reruns are reproducible
# Epochs (inclusive) trained in this run; epoch i > 1 resumes from model_epoch<i-1>.
FIRST_EPOCH, LAST_EPOCH = 7, 7

torch.manual_seed(SEED)

for i in range(FIRST_EPOCH, LAST_EPOCH + 1):
  # Named `model_dir` rather than `dir`, which would shadow the builtin.
  model_dir = work_dir + '/model_epoch' + str(i)

  # Epoch 1 starts from the pre-trained all-mpnet-base-v2 checkpoint; later
  # epochs continue from the checkpoint saved by the previous epoch.
  if i == 1:
    model = SentenceTransformer('all-mpnet-base-v2')
  else:
    model = SentenceTransformer(work_dir + '/model_epoch' + str(i - 1))

  # Define the train dataset loader and the train loss.
  train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=30)
  train_loss = losses.CosineSimilarityLoss(model)

  # One epoch per outer iteration. NOTE(review): each fit() call presumably
  # restarts the 100-step warmup / LR schedule — confirm this is intended.
  model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)

  # Replace any existing checkpoint only after training succeeded, so a failed
  # run does not destroy the previous checkpoint for this epoch.
  if os.path.exists(model_dir):
    shutil.rmtree(model_dir)
  os.makedirs(model_dir)
  model.save(model_dir)



Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/69746 [00:00<?, ?it/s]

In [None]:
################ Validation ################

In [None]:
def sentence_smilarity(src_doc_test, all_query_txt, model):
  """Cosine similarity between every query and every sentence of each document.

  Args:
    src_doc_test: list of documents, each a list of sentence strings.
    all_query_txt: list of query strings.
    model: a SentenceTransformer (anything exposing ``encode``).

  Returns:
    One ``util.cos_sim`` matrix per document, with one row per query and one
    column per sentence of that document.
  """
  # The query embeddings do not depend on the document: encode them once
  # instead of once per document.
  emb1 = model.encode(all_query_txt)
  cos_sim_list = []
  for src_doc in src_doc_test:
    emb2 = model.encode(src_doc)
    cos_sim_list.append(util.cos_sim(emb1, emb2))
  return cos_sim_list

def predict_goals(cos_sim_list, threshold):
  """Binarize similarity matrices: 1 where the similarity exceeds ``threshold``, else 0.

  Args:
    cos_sim_list: iterable of similarity tensors (one per document).
    threshold: similarities strictly greater than this are predicted positive.

  Returns:
    A list with one entry per document: the nested Python list of 0/1 ints
    produced from that document's boolean mask.
  """
  return [(sim > threshold).int().tolist() for sim in cos_sim_list]

In [None]:
# Show F1 score and accuracy per query (goal) and overall.
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

def show_statistics(all_predited_list, all_tgt_list, thres_list):
  """Print, per goal, the F1-maximizing threshold and its scores, then the overall scores.

  Args:
    all_predited_list: predictions indexed as [doc][threshold][goal] -> list of
      0/1 per sentence.
    all_tgt_list: gold labels indexed as [doc][goal] -> list of 0/1 per sentence.
    thres_list: threshold values aligned with the threshold axis of
      ``all_predited_list`` (used only for printing).
  """
  n_goals = 17
  best_by_goal = []  # [goal][doc] -> predictions at that goal's best threshold
  for i in range(n_goals):
    f1_list, acc_list = [], []
    for j in range(len(thres_list)):
      by_qry_predicted, by_qry_tgt = [], []
      for doc_pred, doc_tgt in zip(all_predited_list, all_tgt_list):
        by_qry_predicted += doc_pred[j][i]
        by_qry_tgt += doc_tgt[i]
      # scikit-learn metrics take (y_true, y_pred) in that order.
      f1_list.append(f1_score(by_qry_tgt, by_qry_predicted))
      acc_list.append(accuracy_score(by_qry_tgt, by_qry_predicted))
    bs = int(np.argmax(f1_list))  # first threshold reaching the maximum F1
    print(f'Goal{i+1}, threshold: {thres_list[bs]}, f1_score: {round(f1_list[bs], 3)}, accuracy: {round(acc_list[bs], 3)}')
    best_by_goal.append([doc_pred[bs][i] for doc_pred in all_predited_list])

  # Transpose [goal][doc] -> [doc][goal] with zip: np.array(...).T on these
  # ragged per-sentence lists raises on NumPy >= 1.24, and mis-shapes the data
  # when every document happens to have the same number of sentences.
  best_by_doc = [list(goal_preds) for goal_preds in zip(*best_by_goal)]
  single_predicted, single_tgt = [], []
  for doc_pred, doc_tgt in zip(best_by_doc, all_tgt_list):
    for qry_pred, qry_tgt in zip(doc_pred, doc_tgt):
      single_predicted += qry_pred
      single_tgt += qry_tgt
  best_f1 = f1_score(single_tgt, single_predicted)
  best_acc = accuracy_score(single_tgt, single_predicted)
  print(f'Total Best f1_score:{round(best_f1, 3)}, accuracy:{round(best_acc, 3)}')

In [None]:
# Show the optimal hyperparameters: the F1-maximizing threshold per goal on the validation split, for each saved epoch.
import numpy as np
import torch
from glob import glob
from sentence_transformers import SentenceTransformer, InputExample, losses, util

dataset = torch.load(work_dir + '/dataset.pt')

# Epochs whose saved checkpoints (model_epoch<m>) are evaluated on the validation split.
VALID_EPOCHS = range(1, 7)
# Candidate thresholds 0.00, 0.01, ..., 0.69 (kept as integers, divided by 100 below).
thres_range = range(0, 70, 1)
n_valid_docs = len(dataset['src_doc_valid'])

for m in VALID_EPOCHS:
  print(m)
  # Load each checkpoint once per epoch rather than once per goal.
  saved_model = SentenceTransformer(work_dir + '/model_epoch' + str(m))
  src_pred_list = []  # [goal][threshold][doc] -> [[0/1 per sentence]] (one query row)
  for i in range(17):
    cos_sim_list = sentence_smilarity(dataset['src_doc_valid'], [dataset['all_query_txt'][i]], saved_model)
    src_pred_list.append([predict_goals(cos_sim_list, threshold=(j/100)) for j in thres_range])
  # Reorder to [doc][threshold][goal] -> [0/1 per sentence], dropping the single
  # query row. Plain indexing handles documents of different lengths, whereas
  # np.array() on these ragged lists raises on NumPy >= 1.24.
  src_pred_list_t = [
    [[src_pred_list[i][j][d][0] for i in range(17)] for j in range(len(thres_range))]
    for d in range(n_valid_docs)
  ]
  show_statistics(src_pred_list_t, dataset['tgt_list_valid'], [i/100 for i in thres_range])

1




Goal1, threshold: 0.07, f1_score: 0.182, accuracy: 0.973
Goal2, threshold: 0.21, f1_score: 0.303, accuracy: 0.989
Goal3, threshold: 0.13, f1_score: 0.321, accuracy: 0.886
Goal4, threshold: 0.21, f1_score: 0.302, accuracy: 0.965
Goal5, threshold: 0.1, f1_score: 0.346, accuracy: 0.935
Goal6, threshold: 0.27, f1_score: 0.302, accuracy: 0.987
Goal7, threshold: 0.24, f1_score: 0.43, accuracy: 0.944
Goal8, threshold: 0.11, f1_score: 0.407, accuracy: 0.894
Goal9, threshold: 0.13, f1_score: 0.3, accuracy: 0.893
Goal10, threshold: 0.25, f1_score: 0.251, accuracy: 0.966
Goal11, threshold: 0.19, f1_score: 0.322, accuracy: 0.944
Goal12, threshold: 0.17, f1_score: 0.413, accuracy: 0.918
Goal13, threshold: 0.17, f1_score: 0.414, accuracy: 0.93
Goal14, threshold: 0.1, f1_score: 0.313, accuracy: 0.956
Goal15, threshold: 0.11, f1_score: 0.352, accuracy: 0.946
Goal16, threshold: 0.12, f1_score: 0.24, accuracy: 0.951
Goal17, threshold: 0.1, f1_score: 0.227, accuracy: 0.894




Total Best f1_score:0.336, accuracy:0.939
2




Goal1, threshold: 0.15, f1_score: 0.177, accuracy: 0.983
Goal2, threshold: 0.15, f1_score: 0.29, accuracy: 0.988
Goal3, threshold: 0.15, f1_score: 0.322, accuracy: 0.901
Goal4, threshold: 0.24, f1_score: 0.303, accuracy: 0.97
Goal5, threshold: 0.16, f1_score: 0.337, accuracy: 0.948
Goal6, threshold: 0.39, f1_score: 0.295, accuracy: 0.989
Goal7, threshold: 0.16, f1_score: 0.429, accuracy: 0.935
Goal8, threshold: 0.13, f1_score: 0.4, accuracy: 0.906
Goal9, threshold: 0.16, f1_score: 0.304, accuracy: 0.911
Goal10, threshold: 0.23, f1_score: 0.265, accuracy: 0.967
Goal11, threshold: 0.22, f1_score: 0.324, accuracy: 0.95
Goal12, threshold: 0.17, f1_score: 0.41, accuracy: 0.921
Goal13, threshold: 0.2, f1_score: 0.405, accuracy: 0.934
Goal14, threshold: 0.12, f1_score: 0.28, accuracy: 0.963
Goal15, threshold: 0.14, f1_score: 0.379, accuracy: 0.958
Goal16, threshold: 0.18, f1_score: 0.247, accuracy: 0.968
Goal17, threshold: 0.09, f1_score: 0.233, accuracy: 0.882




Total Best f1_score:0.338, accuracy:0.946
3




Goal1, threshold: 0.16, f1_score: 0.178, accuracy: 0.983
Goal2, threshold: 0.15, f1_score: 0.256, accuracy: 0.986
Goal3, threshold: 0.16, f1_score: 0.323, accuracy: 0.907
Goal4, threshold: 0.26, f1_score: 0.341, accuracy: 0.971
Goal5, threshold: 0.13, f1_score: 0.33, accuracy: 0.943
Goal6, threshold: 0.33, f1_score: 0.334, accuracy: 0.988
Goal7, threshold: 0.13, f1_score: 0.418, accuracy: 0.933
Goal8, threshold: 0.1, f1_score: 0.398, accuracy: 0.897
Goal9, threshold: 0.1, f1_score: 0.306, accuracy: 0.889
Goal10, threshold: 0.16, f1_score: 0.28, accuracy: 0.962
Goal11, threshold: 0.32, f1_score: 0.32, accuracy: 0.959
Goal12, threshold: 0.19, f1_score: 0.414, accuracy: 0.925
Goal13, threshold: 0.16, f1_score: 0.398, accuracy: 0.929
Goal14, threshold: 0.21, f1_score: 0.294, accuracy: 0.969
Goal15, threshold: 0.19, f1_score: 0.376, accuracy: 0.96
Goal16, threshold: 0.15, f1_score: 0.264, accuracy: 0.963
Goal17, threshold: 0.1, f1_score: 0.232, accuracy: 0.894




Total Best f1_score:0.339, accuracy:0.945
4




Goal1, threshold: 0.1, f1_score: 0.169, accuracy: 0.977
Goal2, threshold: 0.16, f1_score: 0.27, accuracy: 0.987
Goal3, threshold: 0.27, f1_score: 0.34, accuracy: 0.927
Goal4, threshold: 0.22, f1_score: 0.343, accuracy: 0.971
Goal5, threshold: 0.17, f1_score: 0.334, accuracy: 0.95
Goal6, threshold: 0.35, f1_score: 0.315, accuracy: 0.988
Goal7, threshold: 0.14, f1_score: 0.414, accuracy: 0.933
Goal8, threshold: 0.1, f1_score: 0.393, accuracy: 0.899
Goal9, threshold: 0.16, f1_score: 0.31, accuracy: 0.912
Goal10, threshold: 0.28, f1_score: 0.285, accuracy: 0.971
Goal11, threshold: 0.27, f1_score: 0.329, accuracy: 0.955
Goal12, threshold: 0.24, f1_score: 0.419, accuracy: 0.933
Goal13, threshold: 0.23, f1_score: 0.393, accuracy: 0.935
Goal14, threshold: 0.14, f1_score: 0.282, accuracy: 0.967
Goal15, threshold: 0.17, f1_score: 0.403, accuracy: 0.963
Goal16, threshold: 0.26, f1_score: 0.248, accuracy: 0.972
Goal17, threshold: 0.08, f1_score: 0.233, accuracy: 0.892




Total Best f1_score:0.341, accuracy:0.949
5




Goal1, threshold: 0.28, f1_score: 0.158, accuracy: 0.987
Goal2, threshold: 0.18, f1_score: 0.306, accuracy: 0.989
Goal3, threshold: 0.23, f1_score: 0.335, accuracy: 0.923
Goal4, threshold: 0.26, f1_score: 0.345, accuracy: 0.972
Goal5, threshold: 0.14, f1_score: 0.337, accuracy: 0.946
Goal6, threshold: 0.34, f1_score: 0.315, accuracy: 0.987
Goal7, threshold: 0.16, f1_score: 0.42, accuracy: 0.937
Goal8, threshold: 0.12, f1_score: 0.398, accuracy: 0.907
Goal9, threshold: 0.11, f1_score: 0.316, accuracy: 0.9
Goal10, threshold: 0.2, f1_score: 0.289, accuracy: 0.967
Goal11, threshold: 0.34, f1_score: 0.343, accuracy: 0.961
Goal12, threshold: 0.18, f1_score: 0.421, accuracy: 0.925
Goal13, threshold: 0.24, f1_score: 0.407, accuracy: 0.938
Goal14, threshold: 0.19, f1_score: 0.276, accuracy: 0.968
Goal15, threshold: 0.27, f1_score: 0.386, accuracy: 0.965
Goal16, threshold: 0.14, f1_score: 0.265, accuracy: 0.96
Goal17, threshold: 0.14, f1_score: 0.241, accuracy: 0.915




Total Best f1_score:0.347, accuracy:0.95
6




Goal1, threshold: 0.06, f1_score: 0.16, accuracy: 0.972
Goal2, threshold: 0.2, f1_score: 0.299, accuracy: 0.989
Goal3, threshold: 0.14, f1_score: 0.333, accuracy: 0.913
Goal4, threshold: 0.26, f1_score: 0.347, accuracy: 0.973
Goal5, threshold: 0.21, f1_score: 0.343, accuracy: 0.953
Goal6, threshold: 0.35, f1_score: 0.312, accuracy: 0.988
Goal7, threshold: 0.12, f1_score: 0.421, accuracy: 0.936
Goal8, threshold: 0.13, f1_score: 0.392, accuracy: 0.91
Goal9, threshold: 0.1, f1_score: 0.316, accuracy: 0.906
Goal10, threshold: 0.27, f1_score: 0.296, accuracy: 0.972
Goal11, threshold: 0.22, f1_score: 0.331, accuracy: 0.952
Goal12, threshold: 0.16, f1_score: 0.423, accuracy: 0.922
Goal13, threshold: 0.18, f1_score: 0.399, accuracy: 0.935
Goal14, threshold: 0.24, f1_score: 0.301, accuracy: 0.973
Goal15, threshold: 0.16, f1_score: 0.392, accuracy: 0.961
Goal16, threshold: 0.2, f1_score: 0.289, accuracy: 0.967
Goal17, threshold: 0.11, f1_score: 0.245, accuracy: 0.915




Total Best f1_score:0.347, accuracy:0.949


In [None]:
################ Test ################

In [None]:
# Show F1 score and accuracy per query (goal) and overall.
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

def show_statistics_test(all_predited_list, all_tgt_list):
  """Print per-goal and overall F1 / accuracy for fixed-threshold test predictions.

  Args:
    all_predited_list: predictions indexed as [doc][goal] -> list of 0/1 per sentence.
    all_tgt_list: gold labels with the same [doc][goal] layout.
  """
  tgt_list_all, pred_list_all = [], []
  for i in range(17):
    # Flatten goal i across all documents.
    by_qry_tgt = [label for doc in all_tgt_list for label in doc[i]]
    by_qry_pred = [label for doc in all_predited_list for label in doc[i]]
    # scikit-learn metrics take (y_true, y_pred) in that order.
    f1 = f1_score(by_qry_tgt, by_qry_pred)
    acc = accuracy_score(by_qry_tgt, by_qry_pred)
    print(f'Goal{i+1}, f1_score: {round(f1, 3)}, accuracy: {round(acc, 3)}')
    tgt_list_all += by_qry_tgt
    pred_list_all += by_qry_pred

  f1_all = f1_score(tgt_list_all, pred_list_all)
  acc_all = accuracy_score(tgt_list_all, pred_list_all)
  print(f'Total, f1_score: {round(f1_all, 3)}, accuracy: {round(acc_all, 3)}')

In [None]:
import numpy as np
import torch
from glob import glob
from sentence_transformers import SentenceTransformer, InputExample, losses, util

def test_execute(m, thres_list):
  """Evaluate checkpoint model_epoch<m> on the test split with fixed per-goal thresholds.

  Args:
    m: epoch number of the saved checkpoint under ``work_dir``.
    thres_list: one threshold per goal (17), in goal order.

  Raises:
    ValueError: if ``thres_list`` has fewer than 17 thresholds.

  Prints the per-goal and overall scores via ``show_statistics_test``.
  """
  if len(thres_list) < 17:
    raise ValueError(f'thres_list needs one threshold per goal (17), got {len(thres_list)}')
  dataset = torch.load(work_dir + '/dataset.pt')
  saved_model = SentenceTransformer(work_dir + '/model_epoch' + str(m))
  src_pred_list = []  # [goal][doc] -> [[0/1 per sentence]] (one query row)
  for i in range(17):
    cos_sim_list = sentence_smilarity(dataset['src_doc_test'], [dataset['all_query_txt'][i]], saved_model)
    src_pred_list.append(predict_goals(cos_sim_list, threshold=thres_list[i]))
  # Reorder to [doc][goal] -> [0/1 per sentence] with plain indexing.
  # np.squeeze(...).T on these ragged lists relies on object arrays (an error on
  # NumPy >= 1.24) and mis-shapes single-document or equal-length inputs.
  src_pred_list_t = [
    [src_pred_list[i][d][0] for i in range(17)]
    for d in range(len(dataset['src_doc_test']))
  ]
  show_statistics_test(src_pred_list_t, dataset['tgt_list_test'])

In [None]:
# Showing Test Results: evaluate one checkpoint with a fixed threshold per goal (Goal1..Goal17).
# NOTE(review): these thresholds match none printed by the validation cell
# (epochs 1-6); presumably they come from a validation run of epoch 7 —
# confirm and record their provenance.
test_epoch = 7
thres_list = [
  0.32, 0.17, 0.24, 0.25, 0.15, 0.36, 0.19, 0.11, 0.15,
  0.28, 0.27, 0.14, 0.39, 0.23, 0.23, 0.32, 0.16,
]
test_execute(test_epoch, thres_list)

  result = getattr(asarray(obj), method)(*args, **kwds)


Goal1, f1_score: 0.042, accuracy: 0.989
Goal2, f1_score: 0.175, accuracy: 0.984
Goal3, f1_score: 0.279, accuracy: 0.926
Goal4, f1_score: 0.269, accuracy: 0.966
Goal5, f1_score: 0.309, accuracy: 0.939
Goal6, f1_score: 0.424, accuracy: 0.981
Goal7, f1_score: 0.328, accuracy: 0.932
Goal8, f1_score: 0.287, accuracy: 0.885
Goal9, f1_score: 0.317, accuracy: 0.916
Goal10, f1_score: 0.253, accuracy: 0.963
Goal11, f1_score: 0.253, accuracy: 0.945
Goal12, f1_score: 0.33, accuracy: 0.905
Goal13, f1_score: 0.402, accuracy: 0.933
Goal14, f1_score: 0.336, accuracy: 0.973
Goal15, f1_score: 0.349, accuracy: 0.965
Goal16, f1_score: 0.178, accuracy: 0.973
Goal17, f1_score: 0.173, accuracy: 0.925
Total, f1_score: 0.298, accuracy: 0.947
