In [1]:
import sys
sys.path.append('../src-py/')
from tqdm import tqdm
tqdm.pandas()
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [2]:
from flair.data import Corpus, Sentence
from flair.datasets import ColumnCorpus    
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [3]:
from project_debater_api import *

In [4]:
# Locations of the IBM claim-target training data and of the fine-tuned tagger checkpoint.
_TAGGER_ROOT = '../../../data-ceph/arguana/arg-generation/claim-target-tagger/'
data_folder = _TAGGER_ROOT + 'data/ibm_ds/'
model_folder = _TAGGER_ROOT + 'model'

### Train a claim-target tagger on the IBM dataset

In [6]:
# Column layout of the IBM TSV files: token text, POS tag, claim-target tag ('ct').
columns = dict(enumerate(['text', 'pos', 'ct']))

# Only train and test files are given (the load log reports "Dev: None"); the dev
# scores reported during training suggest flair derives a dev split on its own --
# presumably sampled from train, confirm against the installed flair version.
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file='train.tsv',
    test_file='test.tsv',
)

2022-01-07 14:48:21,655 Reading data from ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds
2022-01-07 14:48:21,656 Train: ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/train.tsv
2022-01-07 14:48:21,657 Dev: None
2022-01-07 14:48:21,658 Test: ../../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/test.tsv


In [11]:
# Fine-tune XLM-R large as a claim-target ('ct') sequence tagger on the IBM corpus.
label_type = 'ct'

# Tag dictionary built from the 'ct' column of the corpus.
label_dict = corpus.make_label_dictionary(label_type=label_type)
print(label_dict)

# Fine-tuneable transformer embeddings WITH document context (use_context=True);
# each word is represented by its first sub-token from the last layer.
embeddings = TransformerWordEmbeddings(model='xlm-roberta-large',
                                       layers="-1",
                                       subtoken_pooling="first",
                                       fine_tune=True,
                                       use_context=True,
                                       )

# Bare-bones sequence tagger (no CRF, no RNN, no reprojection).
# tag_type is tied to label_type so the tagger always trains on the same column
# the tag dictionary was built from.
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type=label_type,
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

trainer = ModelTrainer(tagger, corpus)

# Run fine-tuning; the returned dict (test score, dev/train loss histories) is the
# cell's displayed output. If GPU memory is tight, add mini_batch_chunk_size=1
# (slower, but processes each mini-batch in smaller chunks).
trainer.fine_tune(model_folder,
                  learning_rate=5.0e-6,
                  mini_batch_size=4,
                  )

2022-01-07 15:01:45,191 Computing label dictionary. Progress:


100%|██████████| 877/877 [00:00<00:00, 19214.40it/s]

2022-01-07 15:01:45,240 Corpus contains the labels: pos (#11355), ct (#11355)
2022-01-07 15:01:45,240 Created (for label 'ct') Dictionary with 4 tags: <unk>, O, B-CT, I-CT
Dictionary with 4 tags: <unk>, O, B-CT, I-CT





2022-01-07 15:01:54,759 ----------------------------------------------------------------------------------------------------
2022-01-07 15:01:54,791 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(250002, 1024, padding_idx=1)
        (position_embeddings): Embedding(514, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_feature

{'test_score': 0.7769347496206374,
 'dev_score_history': [0.5333333333333333,
  0.7452830188679246,
  0.7902439024390244,
  0.8350515463917525,
  0.8775510204081632,
  0.8730964467005077,
  0.8820512820512821,
  0.8832487309644671,
  0.8787878787878788,
  0.8820512820512821],
 'train_loss_history': [1.0426235850715049,
  0.42617399000822315,
  0.29372392421475935,
  0.2260485619470602,
  0.20374770659286134,
  0.18255921371144995,
  0.14026931725029088,
  0.14063569918816818,
  0.11979448386688905,
  0.10413014359435109],
 'dev_loss_history': [tensor(0.3534, device='cuda:0'),
  tensor(0.1715, device='cuda:0'),
  tensor(0.1730, device='cuda:0'),
  tensor(0.2387, device='cuda:0'),
  tensor(0.2016, device='cuda:0'),
  tensor(0.2484, device='cuda:0'),
  tensor(0.2894, device='cuda:0'),
  tensor(0.2741, device='cuda:0'),
  tensor(0.2803, device='cuda:0'),
  tensor(0.2976, device='cuda:0')]}

### Extract targets and stances from Reddit conclusions

In [5]:
from flair.models import SequenceTagger
from flair.tokenization import SegtokSentenceSplitter

In [6]:
def extract_targets(claims, model=None):
    """Tag each claim with the claim-target tagger and return one target string per claim.

    For every claim the highest-scoring 'ct' span is returned. If the tagger finds
    no span, the claim text as reconstructed by flair (``to_original_text``) is used.

    Args:
        claims: iterable of claim strings.
        model: optional, already-loaded SequenceTagger. When None, the fine-tuned
            model is loaded from ``model_folder + '/final-model.pt'``. Pass a model
            to avoid reloading it from disk on every call.

    Returns:
        list of str, aligned with ``claims``.
    """
    claims = list(claims)
    if not claims:
        # Nothing to tag -- avoid loading the model at all.
        return []

    sentences = [Sentence(claim) for claim in claims]
    if model is None:
        model = SequenceTagger.load(model_folder + '/final-model.pt')
    model.predict(sentences)

    # Pick the best predicted target span per sentence.
    targets = []
    for sentence in sentences:
        spans = sentence.get_spans('ct')
        if spans:
            # max() keeps the first span on score ties, like a stable descending sort.
            best = max(spans, key=lambda span: span.score)
            targets.append(best.text)
        else:
            targets.append(sentence.to_original_text())
    return targets

In [7]:
def extract_targets_and_stances(df):
    """Attach a conclusion target and stance to every row of ``df``.

    Targets and stances are computed once per distinct ``title`` (the conclusion
    text) and then mapped back onto every row sharing that title.

    Args:
        df: DataFrame with a ``title`` column holding the conclusions.

    Returns:
        A new DataFrame equal to ``df`` plus ``conclusion_targets`` and
        ``conclusion_stance`` columns. ``df`` itself is left unmodified, so passing a
        filtered slice does not trigger pandas' SettingWithCopyWarning.
    """
    unique_conclusions = df.title.unique().tolist()
    unique_conclusions_targets = extract_targets(unique_conclusions)
    # get_stances is not defined in this notebook; presumably it comes from the
    # project_debater_api star import and returns one stance per (target, claim)
    # pair -- confirm against that module.
    unique_conclusions_stances = get_stances(unique_conclusions_targets, unique_conclusions)

    conc_to_targets = dict(zip(unique_conclusions, unique_conclusions_targets))
    conc_to_stances = dict(zip(unique_conclusions, unique_conclusions_stances))

    return df.assign(
        conclusion_targets=df.title.map(conc_to_targets),
        conclusion_stance=df.title.map(conc_to_stances),
    )

In [50]:
# Extract conclusion targets and stances for the validation sample (comp_remove_75sem_perc).
# NOTE(review): unlike the *_all_sample cells, rows with an empty title are not filtered here.
valid_comp_in_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc.pkl'
valid_comp_out_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_with_targets.pkl'

dev_df = extract_targets_and_stances(pd.read_pickle(valid_comp_in_path))
dev_df.to_pickle(valid_comp_out_path)

In [20]:
# Extract conclusion targets and stances for the test sample (comp_remove_75sem_perc).
# NOTE(review): the input filename is spelled "concusion" while the output uses
# "conclusion"; kept as-is since it presumably matches the file on disk.
test_comp_in_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_concusion_comp_remove_75sem_perc_sample.pkl'
test_comp_out_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_conclusion_comp_remove_75sem_perc_with_targets.pkl'

test_df = extract_targets_and_stances(pd.read_pickle(test_comp_in_path))
test_df.to_pickle(test_comp_out_path)

2022-02-17 16:40:46,107 loading file ../../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt


ProConClient: 100%|██████████| 2336/2336 [00:39<00:00, 74.02it/s]

----------------

In [24]:
# Extract conclusion targets and stances for the full validation sample,
# dropping rows whose title is empty first.
valid_all_in_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/valid_conclusion_all_sample.pkl'
valid_all_out_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/valid_conclusion_all_sample_with_targets.pkl'

dev_df = pd.read_pickle(valid_all_in_path)
dev_df = dev_df.loc[dev_df['title'].str.len() > 0]
dev_df = extract_targets_and_stances(dev_df)
dev_df.to_pickle(valid_all_out_path)

2022-02-17 16:48:21,201 loading file ../../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt


ProConClient: 100%|██████████| 1497/1497 [00:32<00:00, 62.47it/s]

In [9]:
# Extract conclusion targets and stances for the full test sample, dropping rows
# whose title is empty first. The result gets its own name so the validation frame
# held in dev_df is not overwritten with test data.
# NOTE(review): the input filename is spelled "concusion"; kept as-is since it
# presumably matches the file on disk.
test_all_in_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_concusion_all_sample.pkl'
test_all_out_path = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_conclusion_all_sample_with_targets.pkl'

test_all_df = pd.read_pickle(test_all_in_path)
# .copy() so later column assignments act on an independent frame, not a view.
test_all_df = test_all_df.loc[test_all_df['title'].str.len() > 0].copy()
test_all_df = extract_targets_and_stances(test_all_df)
test_all_df.to_pickle(test_all_out_path)

2022-03-08 13:01:03,790 loading file ../../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt


Downloading:   0%|          | 0.00/616 [00:00<?, ?B/s]

ProConClient: 100%|██████████| 1899/1899 [00:42<00:00, 59.82it/s]