In [18]:
import sys
sys.path.append('../src-py/')
from tqdm import tqdm
tqdm.pandas()
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [19]:
import flair
from flair.data import Corpus, Sentence
from flair.datasets import ColumnCorpus
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [20]:
from project_debater_api import *

In [21]:
# Workspace of the claim-target tagger: training data and the trained model both live here.
TAGGER_ROOT = '../../../data-ceph/arguana/arg-generation/claim-target-tagger'
data_folder = TAGGER_ROOT + '/data/ibm_ds/'
model_folder = TAGGER_ROOT + '/model'

### Train a claim-target tagger on the IBM dataset

In [5]:
# Column layout of the TSV files: token text, 'pos', and the 'ct' label column that the
# tagger is trained on.
columns = {0: 'text', 1: 'pos', 2: 'ct'}

# Seed all RNGs before building the corpus: no dev_file is given, so the dev split used
# during training is derived from the train file, and fine-tuning itself is stochastic.
SEED = 42
flair.set_seed(SEED)

# Init a corpus from the column-formatted train/test files in data_folder.
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file='train_ds.tsv',
                              test_file='test_ds.tsv')

2022-05-01 16:29:04,885 Reading data from ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds
2022-05-01 16:29:04,886 Train: ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/train_ds.tsv
2022-05-01 16:29:04,887 Dev: None
2022-05-01 16:29:04,888 Test: ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/test_ds.tsv


In [6]:
# Label column to learn (the 'ct' column of the corpus).
label_type = 'ct'

# Label dictionary built from the training data; the SequenceTagger expands it into
# B-/I-/E-/S- tags plus O.
label_dict = corpus.make_label_dictionary(label_type=label_type)
print(label_dict)

# Fine-tuneable XLM-R large word embeddings (last layer, first-subtoken pooling)
# WITH document context.
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# Bare-bones sequence tagger (no CRF, no RNN, no reprojection).
# NOTE(review): hidden_size is presumably unused when use_rnn=False — confirm.
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type=label_type,  # must match the dictionary's label type
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

trainer = ModelTrainer(tagger, corpus)

# Fine-tune and write the model files (including final-model.pt, which is loaded in the
# extraction section) to model_folder.
# If GPU memory is tight, add mini_batch_chunk_size=1 (leave it out on a big GPU for speed).
trainer.fine_tune(model_folder,
                  learning_rate=5.0e-6,
                  mini_batch_size=4,
                  )

2022-05-01 16:29:07,968 Computing label dictionary. Progress:


1157it [00:00, 50027.93it/s]

2022-05-01 16:29:08,019 Dictionary created for label 'ct' with 2 values: CT (seen 1143 times)
Dictionary with 2 tags: <unk>, CT





2022-05-01 16:29:21,188 SequenceTagger predicts: Dictionary with 5 tags: O, S-CT, B-CT, E-CT, I-CT
2022-05-01 16:29:21,393 ----------------------------------------------------------------------------------------------------


  "There should be no best model saved at epoch 1 except there "


2022-05-01 16:29:21,400 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(250002, 1024, padding_idx=1)
        (position_embeddings): Embedding(514, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSe

100%|██████████| 33/33 [00:01<00:00, 17.50it/s]


2022-05-01 16:30:02,137 Evaluating as a multi-label problem: False
2022-05-01 16:30:02,149 DEV : loss 0.6870161890983582 - f1-score (micro avg)  0.0
2022-05-01 16:30:02,152 BAD EPOCHS (no improvement): 4
2022-05-01 16:30:02,153 ----------------------------------------------------------------------------------------------------
2022-05-01 16:30:05,930 epoch 2 - iter 29/290 - loss 0.80676257 - samples/sec: 30.74 - lr: 0.000005
2022-05-01 16:30:09,804 epoch 2 - iter 58/290 - loss 0.76814420 - samples/sec: 29.96 - lr: 0.000005
2022-05-01 16:30:13,724 epoch 2 - iter 87/290 - loss 0.74564540 - samples/sec: 29.61 - lr: 0.000005
2022-05-01 16:30:17,645 epoch 2 - iter 116/290 - loss 0.74784677 - samples/sec: 29.60 - lr: 0.000005
2022-05-01 16:30:21,523 epoch 2 - iter 145/290 - loss 0.73762055 - samples/sec: 29.93 - lr: 0.000005
2022-05-01 16:30:25,326 epoch 2 - iter 174/290 - loss 0.71779486 - samples/sec: 30.52 - lr: 0.000005
2022-05-01 16:30:29,162 epoch 2 - iter 203/290 - loss 0.69554402 - s

100%|██████████| 33/33 [00:01<00:00, 17.55it/s]

2022-05-01 16:30:43,047 Evaluating as a multi-label problem: False
2022-05-01 16:30:43,057 DEV : loss 0.49243220686912537 - f1-score (micro avg)  0.6783
2022-05-01 16:30:43,060 BAD EPOCHS (no improvement): 4
2022-05-01 16:30:43,061 ----------------------------------------------------------------------------------------------------





2022-05-01 16:30:46,908 epoch 3 - iter 29/290 - loss 0.37728256 - samples/sec: 30.20 - lr: 0.000004
2022-05-01 16:30:50,771 epoch 3 - iter 58/290 - loss 0.36272528 - samples/sec: 30.04 - lr: 0.000004
2022-05-01 16:30:54,606 epoch 3 - iter 87/290 - loss 0.39249552 - samples/sec: 30.28 - lr: 0.000004
2022-05-01 16:30:58,439 epoch 3 - iter 116/290 - loss 0.38959920 - samples/sec: 30.28 - lr: 0.000004
2022-05-01 16:31:02,235 epoch 3 - iter 145/290 - loss 0.39415008 - samples/sec: 30.58 - lr: 0.000004
2022-05-01 16:31:06,100 epoch 3 - iter 174/290 - loss 0.40498847 - samples/sec: 30.04 - lr: 0.000004
2022-05-01 16:31:09,948 epoch 3 - iter 203/290 - loss 0.39318432 - samples/sec: 30.16 - lr: 0.000004
2022-05-01 16:31:13,729 epoch 3 - iter 232/290 - loss 0.39040536 - samples/sec: 30.70 - lr: 0.000004
2022-05-01 16:31:17,583 epoch 3 - iter 261/290 - loss 0.38170420 - samples/sec: 30.12 - lr: 0.000004
2022-05-01 16:31:21,369 epoch 3 - iter 290/290 - loss 0.37915503 - samples/sec: 30.66 - lr: 0.

100%|██████████| 33/33 [00:01<00:00, 17.34it/s]

2022-05-01 16:31:23,306 Evaluating as a multi-label problem: False
2022-05-01 16:31:23,315 DEV : loss 0.3301384150981903 - f1-score (micro avg)  0.7088
2022-05-01 16:31:23,318 BAD EPOCHS (no improvement): 4
2022-05-01 16:31:23,319 ----------------------------------------------------------------------------------------------------





2022-05-01 16:31:27,755 epoch 4 - iter 29/290 - loss 0.31406910 - samples/sec: 26.17 - lr: 0.000004
2022-05-01 16:31:31,537 epoch 4 - iter 58/290 - loss 0.29301066 - samples/sec: 30.69 - lr: 0.000004
2022-05-01 16:31:35,260 epoch 4 - iter 87/290 - loss 0.29689222 - samples/sec: 31.18 - lr: 0.000004
2022-05-01 16:31:39,174 epoch 4 - iter 116/290 - loss 0.30550631 - samples/sec: 29.65 - lr: 0.000004
2022-05-01 16:31:43,026 epoch 4 - iter 145/290 - loss 0.28137027 - samples/sec: 30.14 - lr: 0.000004
2022-05-01 16:31:46,908 epoch 4 - iter 174/290 - loss 0.28374901 - samples/sec: 29.90 - lr: 0.000004
2022-05-01 16:31:50,712 epoch 4 - iter 203/290 - loss 0.29041878 - samples/sec: 30.51 - lr: 0.000003
2022-05-01 16:31:54,542 epoch 4 - iter 232/290 - loss 0.29081964 - samples/sec: 30.31 - lr: 0.000003
2022-05-01 16:31:58,300 epoch 4 - iter 261/290 - loss 0.28833174 - samples/sec: 30.88 - lr: 0.000003
2022-05-01 16:32:02,123 epoch 4 - iter 290/290 - loss 0.28093797 - samples/sec: 30.36 - lr: 0.

100%|██████████| 33/33 [00:01<00:00, 17.26it/s]

2022-05-01 16:32:04,090 Evaluating as a multi-label problem: False
2022-05-01 16:32:04,100 DEV : loss 0.4771294891834259 - f1-score (micro avg)  0.7312





2022-05-01 16:32:04,104 BAD EPOCHS (no improvement): 4
2022-05-01 16:32:04,104 ----------------------------------------------------------------------------------------------------
2022-05-01 16:32:08,036 epoch 5 - iter 29/290 - loss 0.28865009 - samples/sec: 29.53 - lr: 0.000003
2022-05-01 16:32:11,884 epoch 5 - iter 58/290 - loss 0.28721383 - samples/sec: 30.16 - lr: 0.000003
2022-05-01 16:32:15,726 epoch 5 - iter 87/290 - loss 0.25325380 - samples/sec: 30.21 - lr: 0.000003
2022-05-01 16:32:20,053 epoch 5 - iter 116/290 - loss 0.25687301 - samples/sec: 26.83 - lr: 0.000003
2022-05-01 16:32:23,814 epoch 5 - iter 145/290 - loss 0.25714270 - samples/sec: 30.86 - lr: 0.000003
2022-05-01 16:32:27,590 epoch 5 - iter 174/290 - loss 0.24267253 - samples/sec: 30.74 - lr: 0.000003
2022-05-01 16:32:31,389 epoch 5 - iter 203/290 - loss 0.23505762 - samples/sec: 30.56 - lr: 0.000003
2022-05-01 16:32:35,251 epoch 5 - iter 232/290 - loss 0.23301916 - samples/sec: 30.05 - lr: 0.000003
2022-05-01 16:3

100%|██████████| 33/33 [00:01<00:00, 17.44it/s]


2022-05-01 16:32:44,892 Evaluating as a multi-label problem: False
2022-05-01 16:32:44,901 DEV : loss 0.5390682816505432 - f1-score (micro avg)  0.7372
2022-05-01 16:32:44,906 BAD EPOCHS (no improvement): 4
2022-05-01 16:32:44,906 ----------------------------------------------------------------------------------------------------
2022-05-01 16:32:48,765 epoch 6 - iter 29/290 - loss 0.15000575 - samples/sec: 30.09 - lr: 0.000003
2022-05-01 16:32:52,654 epoch 6 - iter 58/290 - loss 0.24351116 - samples/sec: 29.84 - lr: 0.000003
2022-05-01 16:32:56,523 epoch 6 - iter 87/290 - loss 0.21480192 - samples/sec: 30.00 - lr: 0.000003
2022-05-01 16:33:00,365 epoch 6 - iter 116/290 - loss 0.21280186 - samples/sec: 30.21 - lr: 0.000003
2022-05-01 16:33:04,224 epoch 6 - iter 145/290 - loss 0.24328920 - samples/sec: 30.08 - lr: 0.000003
2022-05-01 16:33:07,990 epoch 6 - iter 174/290 - loss 0.23137018 - samples/sec: 30.82 - lr: 0.000002
2022-05-01 16:33:11,783 epoch 6 - iter 203/290 - loss 0.23270854 

100%|██████████| 33/33 [00:02<00:00, 13.60it/s]

2022-05-01 16:33:25,749 Evaluating as a multi-label problem: False
2022-05-01 16:33:25,758 DEV : loss 0.5501019358634949 - f1-score (micro avg)  0.7636
2022-05-01 16:33:25,762 BAD EPOCHS (no improvement): 4
2022-05-01 16:33:25,762 ----------------------------------------------------------------------------------------------------





2022-05-01 16:33:29,551 epoch 7 - iter 29/290 - loss 0.14697487 - samples/sec: 30.65 - lr: 0.000002
2022-05-01 16:33:33,482 epoch 7 - iter 58/290 - loss 0.15002215 - samples/sec: 29.53 - lr: 0.000002
2022-05-01 16:33:37,306 epoch 7 - iter 87/290 - loss 0.15058176 - samples/sec: 30.35 - lr: 0.000002
2022-05-01 16:33:41,134 epoch 7 - iter 116/290 - loss 0.15386545 - samples/sec: 30.33 - lr: 0.000002
2022-05-01 16:33:44,832 epoch 7 - iter 145/290 - loss 0.14683373 - samples/sec: 31.39 - lr: 0.000002
2022-05-01 16:33:48,652 epoch 7 - iter 174/290 - loss 0.16276921 - samples/sec: 30.39 - lr: 0.000002
2022-05-01 16:33:52,501 epoch 7 - iter 203/290 - loss 0.16921189 - samples/sec: 30.16 - lr: 0.000002
2022-05-01 16:33:56,399 epoch 7 - iter 232/290 - loss 0.17172012 - samples/sec: 29.78 - lr: 0.000002
2022-05-01 16:34:00,224 epoch 7 - iter 261/290 - loss 0.18184380 - samples/sec: 30.34 - lr: 0.000002
2022-05-01 16:34:04,110 epoch 7 - iter 290/290 - loss 0.17782460 - samples/sec: 29.87 - lr: 0.

100%|██████████| 33/33 [00:01<00:00, 17.51it/s]

2022-05-01 16:34:06,052 Evaluating as a multi-label problem: False





2022-05-01 16:34:06,062 DEV : loss 0.5791816711425781 - f1-score (micro avg)  0.7647
2022-05-01 16:34:06,067 BAD EPOCHS (no improvement): 4
2022-05-01 16:34:06,067 ----------------------------------------------------------------------------------------------------
2022-05-01 16:34:09,905 epoch 8 - iter 29/290 - loss 0.19255112 - samples/sec: 30.26 - lr: 0.000002
2022-05-01 16:34:13,693 epoch 8 - iter 58/290 - loss 0.19081356 - samples/sec: 30.64 - lr: 0.000002
2022-05-01 16:34:17,966 epoch 8 - iter 87/290 - loss 0.18303466 - samples/sec: 27.16 - lr: 0.000002
2022-05-01 16:34:21,739 epoch 8 - iter 116/290 - loss 0.19049393 - samples/sec: 30.76 - lr: 0.000001
2022-05-01 16:34:25,538 epoch 8 - iter 145/290 - loss 0.18064759 - samples/sec: 30.55 - lr: 0.000001
2022-05-01 16:34:29,328 epoch 8 - iter 174/290 - loss 0.18161132 - samples/sec: 30.63 - lr: 0.000001
2022-05-01 16:34:33,125 epoch 8 - iter 203/290 - loss 0.17904067 - samples/sec: 30.57 - lr: 0.000001
2022-05-01 16:34:36,858 epoch 8

100%|██████████| 33/33 [00:01<00:00, 17.47it/s]


2022-05-01 16:34:46,413 Evaluating as a multi-label problem: False
2022-05-01 16:34:46,422 DEV : loss 0.511544406414032 - f1-score (micro avg)  0.7794
2022-05-01 16:34:46,426 BAD EPOCHS (no improvement): 4
2022-05-01 16:34:46,427 ----------------------------------------------------------------------------------------------------
2022-05-01 16:34:50,208 epoch 9 - iter 29/290 - loss 0.28618755 - samples/sec: 30.71 - lr: 0.000001
2022-05-01 16:34:54,100 epoch 9 - iter 58/290 - loss 0.21005557 - samples/sec: 29.82 - lr: 0.000001
2022-05-01 16:34:57,932 epoch 9 - iter 87/290 - loss 0.18516733 - samples/sec: 30.29 - lr: 0.000001
2022-05-01 16:35:01,765 epoch 9 - iter 116/290 - loss 0.17993167 - samples/sec: 30.29 - lr: 0.000001
2022-05-01 16:35:05,513 epoch 9 - iter 145/290 - loss 0.17693395 - samples/sec: 30.97 - lr: 0.000001
2022-05-01 16:35:09,250 epoch 9 - iter 174/290 - loss 0.17303323 - samples/sec: 31.07 - lr: 0.000001
2022-05-01 16:35:13,004 epoch 9 - iter 203/290 - loss 0.16733077 -

100%|██████████| 33/33 [00:02<00:00, 13.49it/s]


2022-05-01 16:35:27,044 Evaluating as a multi-label problem: False
2022-05-01 16:35:27,054 DEV : loss 0.5512567162513733 - f1-score (micro avg)  0.7849
2022-05-01 16:35:27,058 BAD EPOCHS (no improvement): 4
2022-05-01 16:35:27,058 ----------------------------------------------------------------------------------------------------
2022-05-01 16:35:30,841 epoch 10 - iter 29/290 - loss 0.11398264 - samples/sec: 30.69 - lr: 0.000001
2022-05-01 16:35:34,646 epoch 10 - iter 58/290 - loss 0.10573766 - samples/sec: 30.51 - lr: 0.000000
2022-05-01 16:35:38,443 epoch 10 - iter 87/290 - loss 0.10254649 - samples/sec: 30.57 - lr: 0.000000
2022-05-01 16:35:42,185 epoch 10 - iter 116/290 - loss 0.12466263 - samples/sec: 31.04 - lr: 0.000000
2022-05-01 16:35:46,054 epoch 10 - iter 145/290 - loss 0.13699229 - samples/sec: 30.01 - lr: 0.000000
2022-05-01 16:35:49,856 epoch 10 - iter 174/290 - loss 0.13197114 - samples/sec: 30.54 - lr: 0.000000
2022-05-01 16:35:53,694 epoch 10 - iter 203/290 - loss 0.13

100%|██████████| 33/33 [00:01<00:00, 17.41it/s]

2022-05-01 16:36:07,213 Evaluating as a multi-label problem: False
2022-05-01 16:36:07,222 DEV : loss 0.5474958419799805 - f1-score (micro avg)  0.782
2022-05-01 16:36:07,226 BAD EPOCHS (no improvement): 4





2022-05-01 16:36:10,049 ----------------------------------------------------------------------------------------------------
2022-05-01 16:36:10,051 Testing using last state of model ...


100%|██████████| 244/244 [00:15<00:00, 15.57it/s]

2022-05-01 16:36:25,758 Evaluating as a multi-label problem: False
2022-05-01 16:36:25,769 0.7622	0.8079	0.7844	0.6452
2022-05-01 16:36:25,770 
Results:
- F-score (micro) 0.7844
- F-score (macro) 0.7844
- Accuracy 0.6452

By class:
              precision    recall  f1-score   support

          CT     0.7622    0.8079    0.7844       968

   micro avg     0.7622    0.8079    0.7844       968
   macro avg     0.7622    0.8079    0.7844       968
weighted avg     0.7622    0.8079    0.7844       968

2022-05-01 16:36:25,770 ----------------------------------------------------------------------------------------------------





{'test_score': 0.7843530591775327,
 'dev_score_history': [0.0,
  0.6783216783216782,
  0.7087719298245615,
  0.7311827956989249,
  0.7372262773722628,
  0.7636363636363634,
  0.7647058823529411,
  0.7794117647058822,
  0.7849056603773584,
  0.7819548872180451],
 'train_loss_history': [1.1407576502377446,
  0.6412852737494023,
  0.3791550337100714,
  0.2809379669566499,
  0.24066042675218985,
  0.22712329009040827,
  0.17782459732661007,
  0.17970192214440647,
  0.16248181861427086,
  0.1379217608272952],
 'dev_loss_history': [0.6870161890983582,
  0.49243220686912537,
  0.3301384150981903,
  0.4771294891834259,
  0.5390682816505432,
  0.5501019358634949,
  0.5791816711425781,
  0.511544406414032,
  0.5512567162513733,
  0.5474958419799805]}

### Extract targets and stances from Reddit conclusions

In [7]:
from flair.models import SequenceTagger
from flair.tokenization import SegtokSentenceSplitter

In [8]:
# predict tags for sentences
# Load the fine-tuned claim-target tagger written by trainer.fine_tune().
model = SequenceTagger.load(f'{model_folder}/final-model.pt')

2022-05-01 16:37:45,833 loading file ../../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt
2022-05-01 16:37:53,798 SequenceTagger predicts: Dictionary with 5 tags: O, S-CT, B-CT, E-CT, I-CT


In [9]:
def extract_targets(model, claims, label_type='ct'):
    """Tag each claim with a sequence tagger and return one target string per claim.

    Args:
        model: trained flair SequenceTagger whose predicted spans carry ``label_type``.
        claims: list of claim strings.
        label_type: label type of the predicted target spans (default ``'ct'``).

    Returns:
        A list aligned with ``claims``. Each entry is the text of the highest-scoring
        predicted span (the first one on ties), or the full original claim text when
        no span was predicted.
    """
    sentences = [Sentence(claim) for claim in claims]
    if not sentences:
        return []
    model.predict(sentences)

    targets = []
    for sentence in sentences:
        spans = sentence.get_spans(label_type)
        if spans:
            # max() keeps the first maximal span, same as a stable descending sort + [0].
            best = max(spans, key=lambda span: span.score)
            targets.append(best.text)
        else:
            # No target detected: fall back to the whole claim.
            targets.append(sentence.to_original_text())
    return targets

In [10]:
def extract_targets_and_stances(df, tagger=None):
    """Add ``conclusion_targets`` and ``conclusion_stance`` columns derived from ``df.title``.

    Each distinct title is processed once: targets come from ``extract_targets`` and
    stances from ``get_stances`` (project_debater_api). The results are then mapped
    back onto every row with that title.

    Args:
        df: DataFrame with a ``title`` column holding the conclusion texts.
        tagger: sequence tagger passed to ``extract_targets``. Defaults to the
            notebook-global ``model`` so existing calls keep working.

    Returns:
        A new DataFrame equal to ``df`` plus the two columns; ``df`` is not modified.
    """
    if tagger is None:
        tagger = model  # notebook-global tagger loaded from final-model.pt

    unique_conclusions = df.title.unique().tolist()
    if unique_conclusions:
        targets = extract_targets(tagger, unique_conclusions)
        stances = get_stances(targets, unique_conclusions)
    else:
        # Nothing to process: skip the tagger and the stance API call.
        targets, stances = [], []

    conc_to_target = dict(zip(unique_conclusions, targets))
    conc_to_stance = dict(zip(unique_conclusions, stances))

    return df.assign(
        conclusion_targets=df.title.map(conc_to_target),
        conclusion_stance=df.title.map(conc_to_stance),
    )

----------------

In [23]:
# Extract conclusion targets and stances for the dev sample (sample_valid_conclusion_all).
dev_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/sample_valid_conclusion_all_preprocessed.pkl'
dev_df = pd.read_pickle(dev_path)

# Keep only rows with a non-empty title (missing/NaN titles are dropped too); .copy()
# detaches the filtered frame from the loaded one before columns are added.
dev_df = dev_df[dev_df.title.str.len() > 0].copy()
dev_df = extract_targets_and_stances(dev_df)

# NOTE(review): this overwrites the input pickle, so rows dropped above are removed from
# the file permanently; consider writing to a separate output path.
dev_df.to_pickle(dev_path)

ProConClient: 100%|██████████| 1997/1997 [00:33<00:00, 60.30it/s]


In [22]:
# Extract conclusion targets and stances for the test sample (sample_test_conclusion_all).
test_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/sample_test_conclusion_all_preprocessed.pkl'
test_df = pd.read_pickle(test_path)

# Keep only rows with a non-empty title (missing/NaN titles are dropped too); .copy()
# detaches the filtered frame from the loaded one before columns are added.
test_df = test_df[test_df.title.str.len() > 0].copy()
test_df = extract_targets_and_stances(test_df)

# NOTE(review): this overwrites the input pickle, so rows dropped above are removed from
# the file permanently; consider writing to a separate output path.
test_df.to_pickle(test_path)

ProConClient: 100%|██████████| 2000/2000 [00:35<00:00, 56.97it/s]
