In [1]:
import sys
# Extend the import path with the project's ../src-py directory; this has to run
# before the project-local imports in later cells (presumably where
# `project_debater_api` and `ca_utils` live — confirm).
sys.path.append('../src-py/')
from tqdm import tqdm
# Hook tqdm's progress bars into pandas.
tqdm.pandas()
import pandas as pd
# Show full cell contents when displaying DataFrames instead of truncating long text.
pd.set_option('display.max_colwidth', None)

In [2]:
from flair.data import Corpus, Sentence
from flair.datasets import ColumnCorpus    
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [3]:
from project_debater_api import *

In [6]:
# Folder with the column-formatted IBM claim-target data (train_ds.tsv / test_ds.tsv are read from it).
data_folder = '../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/'
# Folder passed to the trainer as the output location for the fine-tuned tagger.
# NOTE(review): later cells reach data-ceph through '../../../' rather than '../../' — confirm which depth is current.
model_folder = '../../data-ceph/arguana/arg-generation/claim-target-tagger/model'

### Train a claim-target tagger on the IBM dataset:

In [7]:
# Column index -> column name for the tab-separated files.
columns = {
    0: 'text',
    1: 'pos',
    2: 'ct',
}
# Build the corpus from the column-formatted train and test files in data_folder;
# no dev file is passed here.
corpus: Corpus = ColumnCorpus(
    data_folder,
    columns,
    train_file='train_ds.tsv',
    test_file='test_ds.tsv',
)

2022-06-20 12:53:07,539 Reading data from ../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds
2022-06-20 12:53:07,540 Train: ../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/train_ds.tsv
2022-06-20 12:53:07,540 Dev: None
2022-06-20 12:53:07,541 Test: ../../data-ceph/arguana/arg-generation/claim-target-tagger/data/ibm_ds/test_ds.tsv


In [8]:
# Label column the tagger learns; used for both the dictionary and the tagger
# so the two cannot drift apart.
label_type = 'ct'

# 1. build the tag dictionary from the labels found in the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type)
print(label_dict)

# 2. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(
    model='xlm-roberta-large',
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
)

# 3. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=label_dict,
    tag_type=label_type,  # was a second hard-coded 'ct' literal
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

# 4. initialize trainer
trainer = ModelTrainer(tagger, corpus)

# 5. run fine-tuning; the call is the cell's last expression, so its result is displayed
trainer.fine_tune(
    model_folder,
    learning_rate=5.0e-6,
    mini_batch_size=4,
    # mini_batch_chunk_size=1,  # leave this parameter out to speed up computation if you have a big GPU
)

2022-06-20 12:53:14,830 Computing label dictionary. Progress:


100%|██████████| 1157/1157 [00:00<00:00, 30142.05it/s]

2022-06-20 12:53:14,872 Corpus contains the labels: pos (#14298), ct (#14298)
2022-06-20 12:53:14,873 Created (for label 'ct') Dictionary with 4 tags: <unk>, B-CT, I-CT, O
Dictionary with 4 tags: <unk>, B-CT, I-CT, O





2022-06-20 12:53:52,636 ----------------------------------------------------------------------------------------------------
2022-06-20 12:53:52,705 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): XLMRobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(250002, 1024, padding_idx=1)
        (position_embeddings): Embedding(514, 1024, padding_idx=1)
        (token_type_embeddings): Embedding(1, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0): RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_feature

{'test_score': 0.7972097658196312,
 'dev_score_history': [0.5808383233532933,
  0.7311827956989247,
  0.7753623188405797,
  0.8072727272727273,
  0.8145454545454545,
  0.8475836431226765,
  0.8389513108614233,
  0.8432835820895523,
  0.8507462686567164,
  0.8314606741573034],
 'train_loss_history': [0.949919008878515,
  0.40578732975968484,
  0.28271249064705106,
  0.2324002127852648,
  0.2194415353511049,
  0.18234073505706924,
  0.15147190432366534,
  0.14496121027377656,
  0.11475107674469999,
  0.11619692441975707],
 'dev_loss_history': [tensor(0.3845, device='cuda:0'),
  tensor(0.2300, device='cuda:0'),
  tensor(0.2998, device='cuda:0'),
  tensor(0.3240, device='cuda:0'),
  tensor(0.2827, device='cuda:0'),
  tensor(0.3348, device='cuda:0'),
  tensor(0.3911, device='cuda:0'),
  tensor(0.3459, device='cuda:0'),
  tensor(0.3379, device='cuda:0'),
  tensor(0.3535, device='cuda:0')]}

### Extract targets and stances from Reddit conclusions:

In [5]:
from ca_utils import *

2022-05-19 21:30:29,281 loading file ../../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt
2022-05-19 21:30:59,358 SequenceTagger predicts: Dictionary with 5 tags: O, S-CT, B-CT, E-CT, I-CT


In [6]:
def extract_targets_and_stances(df):
    """Attach conclusion targets and stances to every row of ``df``.

    Each distinct value of ``df.title`` is sent once to ``extract_targets``;
    the returned targets, together with the conclusions they came from, are
    sent to ``get_stances``. The per-conclusion results are then looked up for
    every row, so repeated titles are only processed once.

    Args:
        df: DataFrame with a ``title`` column holding the conclusion text.

    Returns:
        A copy of ``df`` with two extra columns, ``conclusion_targets`` and
        ``conclusion_stance``. The frame passed in is not modified.
    """
    unique_conclusions = df.title.unique().tolist()
    unique_conclusions_targets = extract_targets(unique_conclusions)
    unique_conclusions_stances = get_stances(unique_conclusions_targets, unique_conclusions)

    conc_to_targets = dict(zip(unique_conclusions, unique_conclusions_targets))
    conc_to_stances = dict(zip(unique_conclusions, unique_conclusions_stances))

    # Callers pass boolean-filtered slices; assigning columns onto such a slice
    # raises SettingWithCopyWarning and mutates the caller's object, so the new
    # columns are written to an explicit copy instead.
    df = df.copy()
    df['conclusion_targets'] = df.title.apply(conc_to_targets.__getitem__)
    df['conclusion_stance'] = df.title.apply(conc_to_stances.__getitem__)

    return df

----------------

In [23]:
# Dev sample (dev_sample_all): add conclusion targets and stances, then write the
# result back to the same pickle it was read from.
dev_pickle = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/sample_valid_conclusion_all_preprocessed.pkl'

dev_df = pd.read_pickle(dev_pickle)
has_title = dev_df['title'].str.len() > 0  # keep only rows whose title is non-empty
dev_df = extract_targets_and_stances(dev_df[has_title])
dev_df.to_pickle(dev_pickle)

ProConClient: 100%|██████████| 1997/1997 [00:33<00:00, 60.30it/s]


In [22]:
# Test sample (test_sample_all): add conclusion targets and stances, then write the
# result back to the same pickle it was read from.
sample_test_pickle = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/sample_test_conclusion_all_preprocessed.pkl'

test_df = pd.read_pickle(sample_test_pickle)
has_title = test_df['title'].str.len() > 0  # keep only rows whose title is non-empty
test_df = extract_targets_and_stances(test_df[has_title])
test_df.to_pickle(sample_test_pickle)

ProConClient: 100%|██████████| 2000/2000 [00:35<00:00, 56.97it/s]


In [7]:
#Extract conclusion target and stances for test_all
test_df = pd.read_pickle('../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_conclusion_all_preprocessed.pkl')

test_df = test_df[test_df.title.str.len() > 0]
test_df = extract_targets_and_stances(test_df)
test_df.to_pickle('../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/test_conclusion_all_preprocessed.pkl')

ProConClient: 100%|██████████| 8519/8519 [02:21<00:00, 60.32it/s]


In [8]:
# Row count of the processed test frame.
test_df.shape[0]

8533

In [9]:
# Preview the first rows of the conclusion-related columns.
preview_columns = ['title', 'bart_conclusion', 'conclusion_targets', 'conclusion_stance']
test_df[preview_columns].head()

Unnamed: 0,title,bart_conclusion,conclusion_targets,conclusion_stance
410850,people should come with instructions,i think people should be required by law to use a cheat sheet if they meet someone they,people should come with instructions,0.997129
410858,People should not be heavily criticized for things they put on social media in the distant past,i think the internet should stop being as harsh on people for things they put on social,distant past,-0.952858
410902,We shouldn't focus on slowing climate change,joint statement:: there are other environmental issues that are a greater problem for,focus on slowing climate change,-0.997431
410910,The Australian PM was right to tell students to stop activism around global warming,I believe that activism is a terrible way to combat climate change,stop activism around global warming,0.999497
410916,Feeding cats or dogs a diet with meat is indefensible.,if a cat or dog eats her life then it's a animal killer and they should be,Feeding cats or dogs a diet with meat,-0.984038
