In [1]:
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict

import spacy
import pandas as pd
import prettytable
import re
import string

In [2]:
train_df = pd.read_csv('./Data/train.csv')
test_df = pd.read_csv('./Data/test.csv')

In [3]:
train_df['first_party'] = train_df['first_party'].str.lower()
train_df['second_party'] = train_df['second_party'].str.lower()
train_df['facts'] = train_df['facts'].str.lower()

test_df['first_party'] = test_df['first_party'].str.lower()
test_df['second_party'] = test_df['second_party'].str.lower()
test_df['facts'] = test_df['facts'].str.lower()

idx = 0
fact = train_df.loc[idx, 'facts'].lower()

nlp = spacy.load('en_core_web_sm')
doc = nlp(fact)
print(doc.text)

on june 27, 1962, phil st. amant, a candidate for public office, made a television speech in baton rouge, louisiana.  during this speech, st. amant accused his political opponent of being a communist and of being involved in criminal activities with the head of the local teamsters union.  finally, st. amant implicated herman thompson, an east baton rouge deputy sheriff, in a scheme to move money between the teamsters union and st. amant’s political opponent. 
thompson successfully sued st. amant for defamation.  louisiana’s first circuit court of appeals reversed, holding that thompson did not show st. amant acted with “malice.”  thompson then appealed to the supreme court of louisiana.  that court held that, although public figures forfeit some of their first amendment protection from defamation, st. amant accused thompson of a crime with utter disregard of whether the remarks were true.  finally, that court held that the first amendment protects uninhibited, robust debate, rather tha

In [4]:
def get_name_re(name, fact_token: spacy.tokens.doc.Doc, first=True):
    name = re.sub(rf'[ .,{string.punctuation}]+', r' ', name.lower())
    name_list = [n for n in name.split() if len(n) >= 1]

    for n in name_list:
        changed_name = re.findall(rf"{n} ?\([a-z]+\)", fact_token.text)
        # print(f"changed name: {changed_name}")
        if changed_name:
            name_list.extend([re.sub(rf'({n}|[ {string.punctuation}])', '', cn) for cn in changed_name])

    # fact_subj = ' '.join([token.text for token in fact_token if 'NN' in token.tag_])
    # print(f"name list: {name_list}")
    # print(f"fact subj: {fact_subj}")
    # res = []
    # for name in name_list:
    #     res.append((name, len(re.findall(name, fact_subj))))
    # print(res)
    abbrev = "firstparty" if first else "secondparty"
    fact_subj = []
    for token in fact_token:
        if 'NN' in token.tag_ and token.text in name_list:
            fact_subj.append(abbrev)
        else:
            fact_subj.append(token.text)
    fact_subj = ' '.join(fact_subj)
    # fact_subj = re.sub(rf"({'|'.join(name_list)})", abbrev, fact)
    fact_subj = re.sub(rf"({abbrev} ?)+", f'{abbrev} ', fact_subj)
    return fact_subj

In [5]:
import re

def replace_name(first_party, second_party, fact_token: spacy.tokens.doc.Doc):
    first_party_name = re.sub(rf'[ .,{string.punctuation}]+', r' ', first_party.lower())
    second_party_name = re.sub(rf'[ .,{string.punctuation}]+', r' ', second_party.lower())
    fp_name_list = [n for n in first_party_name.split() if len(n) >= 1]
    sp_name_list = [n for n in second_party_name.split() if len(n) >= 1]

    fp_name_list_added = fp_name_list.copy()
    sp_name_list_added = sp_name_list.copy()

    for n in fp_name_list:
        changed_name = re.findall(rf"{n} ?\([a-z]+\)", fact_token.text)
        if changed_name:
            fp_name_list_added.extend([re.sub(rf'({n}|[ {string.punctuation}])', '', cn) for cn in changed_name])
    
    for n in sp_name_list:
        changed_name = re.findall(rf"{n} ?\([a-z]+\)", fact_token.text)
        if changed_name:
            sp_name_list_added.extend([re.sub(rf'({n}|[ {string.punctuation}])', '', cn) for cn in changed_name])

    # for name_list in [fp_name_list, sp_name_list]:
    #     for n in name_list:
    #         changed_name = re.findall(rf"{n} ?\([a-z]+\)", fact_token.text)
    #         if changed_name:
    #             name_list.extend([re.sub(rf'({n}|[ {string.punctuation}])', '', cn) for cn in changed_name])
    
    # print(f"fp name list: {fp_name_list}")
    # print(f"sp name list: {sp_name_list}")

    fact_subj = []
    for token in fact_token:
        if 'NN' in token.tag_:
            if token.text in fp_name_list_added:
                fact_subj.append('firstparty')
            elif token.text in sp_name_list_added:
                fact_subj.append('secondparty')
            else:
                fact_subj.append(token.text)
        else:
            fact_subj.append(token.text)
    
    fact_subj = ' '.join(fact_subj)
    fact_subj = re.sub(rf"(firstparty ?)+", f'firstparty ', fact_subj)
    fact_subj = re.sub(rf"(secondparty ?)+", f'secondparty ', fact_subj)
    return fact_subj

In [6]:
from pandarallel import pandarallel

pandarallel.initialize(progress_bar=True, nb_workers=2)

nlp = spacy.load('en_core_web_sm')

# train_df['facts_token'] = train_df['facts'].parallel_apply(nlp)
train_df['new_facts'] = train_df.parallel_apply(lambda x: replace_name(x['first_party'], x['second_party'], nlp(x['facts'])), axis=1)
test_df['new_facts'] = test_df.parallel_apply(lambda x: replace_name(x['first_party'], x['second_party'], nlp(x['facts'])), axis=1)

INFO: Pandarallel will run on 2 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=1239), Label(value='0 / 1239'))), …

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=620), Label(value='0 / 620'))), HB…

In [7]:
column_rename = {'first_party': 'fp',
                 'second_party': 'sp',
                 'first_party_winner': 'label'}

train_df.rename(columns=column_rename, inplace=True)
test_df.rename(columns=column_rename, inplace=True)

In [8]:
train_df

Unnamed: 0,ID,fp,sp,facts,label,new_facts
0,TRAIN_0000,phil a. st. amant,herman a. thompson,"on june 27, 1962, phil st. amant, a candidate ...",1,"on june 27 , 1962 , firstparty . firstparty , ..."
1,TRAIN_0001,stephen duncan,lawrence owens,ramon nelson was riding his bike when he suffe...,0,ramon nelson was riding his bike when he suffe...
2,TRAIN_0002,billy joe magwood,"tony patterson, warden, et al.",an alabama state court convicted billy joe mag...,1,an alabama state court convicted billy firstpa...
3,TRAIN_0003,linkletter,walker,victor linkletter was convicted in state court...,0,victor firstparty was convicted in state court...
4,TRAIN_0004,william earl fikes,alabama,"on april 24, 1953 in selma, alabama, an intrud...",1,"on april 24 , 1953 in selma , secondparty , an..."
...,...,...,...,...,...,...
2473,TRAIN_2473,"hollyfrontier cheyenne refining, llc, et al.","renewable fuels association, et al.",congress amended the clean air act through the...,1,congress amended the clean air act through the...
2474,TRAIN_2474,"grupo mexicano de desarrollo, s. a.","alliance bond fund, inc.","alliance bond fund, inc., an investment fund, ...",1,"secondparty , secondparty . , an investment se..."
2475,TRAIN_2475,peguero,united states,"in 1992, the district court sentenced manuel d...",0,"in 1992 , the district court sentenced manuel ..."
2476,TRAIN_2476,immigration and naturalization service,st. cyr,"on march 8, 1996, enrico st. cyr, a lawful per...",0,"on march 8 , 1996 , enrico secondparty . secon..."


In [9]:
train_df['label'].value_counts()

label
1    1649
0     829
Name: count, dtype: int64

In [10]:
aug_df = pd.DataFrame({'ID': train_df['ID'],
                       'fp': train_df['sp'],
                       'sp': train_df['fp'],
                       'facts': train_df['facts'],
                       'new_facts': train_df['new_facts'],
                       'label': 1-train_df['label']})

aug_df

Unnamed: 0,ID,fp,sp,facts,new_facts,label
0,TRAIN_0000,herman a. thompson,phil a. st. amant,"on june 27, 1962, phil st. amant, a candidate ...","on june 27 , 1962 , firstparty . firstparty , ...",0
1,TRAIN_0001,lawrence owens,stephen duncan,ramon nelson was riding his bike when he suffe...,ramon nelson was riding his bike when he suffe...,1
2,TRAIN_0002,"tony patterson, warden, et al.",billy joe magwood,an alabama state court convicted billy joe mag...,an alabama state court convicted billy firstpa...,0
3,TRAIN_0003,walker,linkletter,victor linkletter was convicted in state court...,victor firstparty was convicted in state court...,1
4,TRAIN_0004,alabama,william earl fikes,"on april 24, 1953 in selma, alabama, an intrud...","on april 24 , 1953 in selma , secondparty , an...",0
...,...,...,...,...,...,...
2473,TRAIN_2473,"renewable fuels association, et al.","hollyfrontier cheyenne refining, llc, et al.",congress amended the clean air act through the...,congress amended the clean air act through the...,0
2474,TRAIN_2474,"alliance bond fund, inc.","grupo mexicano de desarrollo, s. a.","alliance bond fund, inc., an investment fund, ...","secondparty , secondparty . , an investment se...",0
2475,TRAIN_2475,united states,peguero,"in 1992, the district court sentenced manuel d...","in 1992 , the district court sentenced manuel ...",1
2476,TRAIN_2476,st. cyr,immigration and naturalization service,"on march 8, 1996, enrico st. cyr, a lawful per...","on march 8 , 1996 , enrico secondparty . secon...",1


In [11]:
train_df = pd.concat([train_df, aug_df], ignore_index=True)
train_df

Unnamed: 0,ID,fp,sp,facts,label,new_facts
0,TRAIN_0000,phil a. st. amant,herman a. thompson,"on june 27, 1962, phil st. amant, a candidate ...",1,"on june 27 , 1962 , firstparty . firstparty , ..."
1,TRAIN_0001,stephen duncan,lawrence owens,ramon nelson was riding his bike when he suffe...,0,ramon nelson was riding his bike when he suffe...
2,TRAIN_0002,billy joe magwood,"tony patterson, warden, et al.",an alabama state court convicted billy joe mag...,1,an alabama state court convicted billy firstpa...
3,TRAIN_0003,linkletter,walker,victor linkletter was convicted in state court...,0,victor firstparty was convicted in state court...
4,TRAIN_0004,william earl fikes,alabama,"on april 24, 1953 in selma, alabama, an intrud...",1,"on april 24 , 1953 in selma , secondparty , an..."
...,...,...,...,...,...,...
4951,TRAIN_2473,"renewable fuels association, et al.","hollyfrontier cheyenne refining, llc, et al.",congress amended the clean air act through the...,0,congress amended the clean air act through the...
4952,TRAIN_2474,"alliance bond fund, inc.","grupo mexicano de desarrollo, s. a.","alliance bond fund, inc., an investment fund, ...",0,"secondparty , secondparty . , an investment se..."
4953,TRAIN_2475,united states,peguero,"in 1992, the district court sentenced manuel d...",1,"in 1992 , the district court sentenced manuel ..."
4954,TRAIN_2476,st. cyr,immigration and naturalization service,"on march 8, 1996, enrico st. cyr, a lawful per...",1,"on march 8 , 1996 , enrico secondparty . secon..."


In [12]:
train_df['label'].value_counts()

label
1    2478
0    2478
Name: count, dtype: int64

In [13]:
train_data = Dataset.from_pandas(train_df)
test_data = Dataset.from_pandas(test_df)

In [14]:
train_data.features

{'ID': Value(dtype='string', id=None),
 'fp': Value(dtype='string', id=None),
 'sp': Value(dtype='string', id=None),
 'facts': Value(dtype='string', id=None),
 'label': Value(dtype='int64', id=None),
 'new_facts': Value(dtype='string', id=None)}

In [15]:
import datasets

label = datasets.ClassLabel(num_classes=2, names=[0, 1])
train_data.features['label'] = label
train_data = train_data.class_encode_column("label")

Stringifying the column:   0%|          | 0/4956 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/4956 [00:00<?, ? examples/s]

In [16]:
train_data

Dataset({
    features: ['ID', 'fp', 'sp', 'facts', 'label', 'new_facts'],
    num_rows: 4956
})

In [17]:
train_data['label'][:10]

[1, 0, 1, 0, 1, 1, 1, 1, 1, 1]

In [18]:
train_data = train_data.train_test_split(test_size=0.2, shuffle=True, seed=42, stratify_by_column='label')
dataset = DatasetDict({'train': train_data['train'], 'validation': train_data['test'], 'test': test_data})
dataset

DatasetDict({
    train: Dataset({
        features: ['ID', 'fp', 'sp', 'facts', 'label', 'new_facts'],
        num_rows: 3964
    })
    validation: Dataset({
        features: ['ID', 'fp', 'sp', 'facts', 'label', 'new_facts'],
        num_rows: 992
    })
    test: Dataset({
        features: ['ID', 'fp', 'sp', 'facts', 'new_facts'],
        num_rows: 1240
    })
})

In [19]:
import numpy as np

np.mean(train_data['train']['label']), np.mean(train_data['test']['label'])

(0.5, 0.5)

In [20]:
pretrained_model = "distilbert-base-uncased"
# pretrained_model = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)

In [21]:
def preprocess(examples):
    examples['facts_with_party'] = [f'First: {first}.\n Second: {second}.\nFacts: {new_fact}' 
                                    for first, second, new_fact in zip(examples['fp'], examples['sp'], examples['new_facts'])]
    return tokenizer(examples['facts_with_party'], padding='max_length', truncation=True, max_length=512)

dataset = dataset.map(preprocess, batched=True)

Map:   0%|          | 0/3964 [00:00<?, ? examples/s]

Map:   0%|          | 0/992 [00:00<?, ? examples/s]

Map:   0%|          | 0/1240 [00:00<?, ? examples/s]

In [22]:
import evaluate
import numpy as np

accuracy = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    preds, labels = eval_pred
    preds = np.argmax(preds, axis=1)
    return accuracy.compute(predictions=preds, references=labels)

In [23]:
from transformers import DistilBertModel

model = AutoModelForSequenceClassification.from_pretrained(pretrained_model, num_labels=2)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'classifier.we

In [24]:
training_args = TrainingArguments(
    output_dir='./results',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=20,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    evaluation_strategy='epoch',
    save_strategy='epoch',
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
)

In [None]:
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.696901,0.5
2,No log,0.693649,0.5
3,No log,0.693897,0.5
4,No log,0.693715,0.497984
5,0.695000,0.693918,0.5
6,0.695000,0.694052,0.5
7,0.695000,0.694866,0.5
8,0.695000,0.696856,0.486895
9,0.693200,0.699583,0.478831
10,0.693200,0.706079,0.462702




TrainOutput(global_step=2480, training_loss=0.646895500921434, metrics={'train_runtime': 757.9879, 'train_samples_per_second': 104.593, 'train_steps_per_second': 3.272, 'total_flos': 1.050201536544768e+16, 'train_loss': 0.646895500921434, 'epoch': 20.0})

In [28]:
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase
import torch

class CustomDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    def __init__(
        self, 
        tokenizer: PreTrainedTokenizerBase, 
        mlm_probability=0.15, 
        pad_to_multiple_of=None,
        special_tokens_to_mask=["firstparty", "secondparty", "court"]
    ):
        super().__init__(tokenizer=tokenizer, mlm=mlm_probability, pad_to_multiple_of=pad_to_multiple_of)
        self.special_tokens_to_mask = [self.tokenizer.encode(st, add_special_tokens=False)[0] for st in special_tokens_to_mask]

    def mask_tokens(self, inputs: torch.Tensor):
        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        
        for special_token in self.special_tokens_to_mask:
            special_token_mask = labels.eq(special_token)
            probability_matrix.masked_fill_(special_token_mask, 1.0)
        
        masked_indices = torch.bernoulli(probability_matrix).bool()

        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        return inputs, labels


In [29]:
data_collator = CustomDataCollatorForLanguageModeling(tokenizer)

sentence = train_data['train']['new_facts'][0]

inputs = tokenizer.encode(sentence, return_tensors='pt')
masked_inputs, labels = data_collator.mask_tokens(inputs)
masked_sentence = tokenizer.decode(masked_inputs[0])
print(sentence)
print(masked_sentence)

secondparty observed oil producer - operated stations receiving favorable rates from producers and refiners . in response , secondparty passed a statute prohibiting oil producers or refiners from operating gasoline stations within the state and requiring producers and refiners extend temporary price cuts to the stations they supplied . firstparty challenged the statute in anne arundel county circuit court , which ruled the statute invalid . the secondparty court of appeals reversed the ruling . 

[MASK] [MASK]party observed oil [MASK] - operated stations receiving [MASK] rates [MASK] producers and refiner [MASK]. in response, [MASK]party passed a statute prohibiting oil producers or refiners from operating [MASK] stations within the state and requiring producers and refiners extend temporary price cuts [MASK] the stations they supplied [MASK] [MASK]party challenged the statute [MASK] anne arundel county circuit [MASK], which ruled the statute [MASK]. the [MASK]par [MASK] cinema of appe

In [None]:
import re

def mask_text(text, words_to_mask):
    MASK_TOKEN = '[MASK]'
    for word in words_to_mask:
        text = re.sub(word, MASK_TOKEN, text, flags=re.IGNORECASE)
    return text

