Postprocessing of the predictions for task 2: technique classification.

In [1]:
import pandas as pd
import numpy as np

In [1]:
TRAIN_FILE = "tc-dev-complex-repetitions.tsv"
ORIGINAL_LABELS = "tc-dev-base-labels.txt"
ALTERNATIVE_LABELS = "tc-dev-alternative-labels.txt"
FILE_TO_SAVE = "tc-predictions.txt"

- `TRAIN_FILE`: the instances to post-process, with their repetition features (despite the name, here this is the unlabelled dev set)
- `ORIGINAL_LABELS`: a submission file with the labels generated by the base model
- `ALTERNATIVE_LABELS`: a submission file with the labels generated by an alternative model that never predicts "Repetition"
- `FILE_TO_SAVE`: the output file for the post-processed predictions
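The post-processing cell below reads the label files line by line and indexes their tab-separated columns as article id, technique, span start and span end. A minimal parser under that assumption could look like this. The `LabelEntry` type, the `parse_label_line` helper and the sample line are hypothetical and not part of the original pipeline.

```python
from typing import NamedTuple


class LabelEntry(NamedTuple):
    """One line of a label file, assuming the column order used in this notebook."""
    article_id: str
    technique: str
    span_start: int
    span_end: int


def parse_label_line(line: str) -> LabelEntry:
    # Assumed layout: article_id <TAB> technique <TAB> span_start <TAB> span_end
    article_id, technique, start, end = line.strip().split("\t")
    return LabelEntry(article_id, technique, int(start), int(end))


# A made-up line for illustration
entry = parse_label_line("730093263\tRepetition\t352\t357\n")
print(entry.technique, entry.span_end - entry.span_start)  # Repetition 5
```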

### Reading the data

The data must contain two features that describe repetitions:
- `n_of_repetitions`: the number of repetitions of the span
- `not_first_occurrence`: 1 if this is not the first occurrence of the span in the article, 0 otherwise
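The notebook does not show how these features were computed. Purely to illustrate what they encode, here is a hypothetical way to derive similar columns from a table of candidate spans: within each article, count how many *other* spans have the same lowercased text, and flag every occurrence after the first one in `span_start` order. This looks only at the candidate spans, not at the full article text, so it will not necessarily reproduce the values stored in `TRAIN_FILE`. The function name `add_repetition_features` is an assumption.

```python
import pandas as pd


def add_repetition_features(spans: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical re-derivation of the two repetition features.

    Expects the columns document_id, span_start and text.
    """
    df = spans.sort_values(["document_id", "span_start"]).copy()
    # Case-insensitive matching key for the span text
    key = df["text"].str.strip().str.lower().rename("text_key")
    groups = df.groupby(["document_id", key])
    # Number of *other* spans in the same article with the same text
    df["n_of_repetitions"] = groups["text"].transform("count") - 1
    # 1 for every occurrence after the first one (by span_start), else 0
    df["not_first_occurrence"] = (groups.cumcount() > 0).astype(int)
    # Restore the caller's row order
    return df.sort_index()


toy = pd.DataFrame({
    "document_id": [1, 1, 1, 2],
    "span_start": [10, 50, 90, 5],
    "text": ["black", "heroes", "Black", "black"],
})
print(add_repetition_features(toy))
```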

In [3]:
dev_df = pd.read_csv(TRAIN_FILE, sep="\t")

In [4]:
dev_df.head()

| Unnamed: 0 | document_id | label | span_start | span_end | text | n_of_repetitions | not_first_occurrence | n_of_lemmatized_repetitions |
|---|---|---|---|---|---|---|---|---|
| 0 | 730093263 | ? | 123 | 128 | white | 1 | 0 | 1 |
| 1 | 730093263 | ? | 352 | 357 | black | 2 | 0 | 2 |
| 2 | 730093263 | ? | 1370 | 1393 | “true American heroes.” | 0 | 0 | 0 |
| 3 | 730093263 | ? | 2434 | 2439 | black | 2 | 1 | 2 |
| 4 | 730093263 | ? | 2699 | 2807 | If these two men had survived, and Quentin Lam... | 0 | 0 | 0 |


In [5]:
n_of_repetitions = dev_df["n_of_repetitions"].values
not_first_occurrence = dev_df["not_first_occurrence"].values

### Making predictions

In [6]:
y_pred = []
for i in range(len(n_of_repetitions)):
    if n_of_repetitions[i] > 0 and not_first_occurrence[i] == 1:
        y_pred.append(1)
    else:
        y_pred.append(0)
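The loop above implements a simple rule: a span is predicted as "Repetition" (1) when it is repeated at least once and is not its first occurrence in the article. The same rule can be written as a single vectorized NumPy expression. Calling `.tolist()` keeps `y_pred` a plain Python list, so `y_pred.count(1)` still works. The toy arrays below stand in for the real columns and copy the five rows shown by `dev_df.head()`.

```python
import numpy as np

# Toy stand-ins for dev_df["n_of_repetitions"].values and
# dev_df["not_first_occurrence"].values (first five rows of the dev set)
n_of_repetitions = np.array([1, 2, 0, 2, 0])
not_first_occurrence = np.array([0, 0, 0, 1, 0])

# 1 = "Repetition", 0 = not a repetition
y_pred = ((n_of_repetitions > 0) & (not_first_occurrence == 1)).astype(int).tolist()
print(y_pred)  # [0, 0, 0, 1, 0]
```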

A quick sanity check that the rule produces a reasonable number of "Repetition" predictions:

In [7]:
y_pred.count(1)

150
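One way to judge whether this count is reasonable is to compare it with how many spans the base model itself labelled "Repetition" in `ORIGINAL_LABELS`. A sketch, assuming the tab-separated layout used by the post-processing cell below (technique in the second column); the helper name `count_label` and the toy lines are hypothetical.

```python
from collections import Counter


def count_label(lines, label="Repetition"):
    """Count the tab-separated label lines whose technique equals `label`."""
    techniques = Counter(
        line.strip().split("\t")[1] for line in lines if line.strip()
    )
    return techniques[label]


toy_lines = [
    "730093263\tRepetition\t352\t357\n",
    "730093263\tOther_Technique\t123\t128\n",
    "730093263\tRepetition\t2434\t2439\n",
]
print(count_label(toy_lines))  # 2
```

In the notebook itself, the same helper would be called on the open file, e.g. `with open(ORIGINAL_LABELS) as f: print(count_label(f))`.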

### Editing the original labels with "Repetition" predictions and handling duplicates
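In the post-processing cell, an entry counts as a duplicate when an earlier line of `ORIGINAL_LABELS` has the same article id, span start and span end. These duplicates can also be counted up front, before any label is changed. A small sketch with `collections.Counter` over those keys; the helper `find_duplicate_keys` and the toy lines are hypothetical.

```python
from collections import Counter


def find_duplicate_keys(lines):
    """Return {(article_id, span_start, span_end): count} for keys seen more than once."""
    keys = Counter()
    for line in lines:
        columns = line.strip().split("\t")
        # Same key as in the post-processing cell: article id, start, end
        keys[(columns[0], columns[2], columns[3])] += 1
    return {key: n for key, n in keys.items() if n > 1}


toy_lines = [
    "730093263\tRepetition\t2434\t2439",
    "730093263\tOther_Technique\t123\t128",
    "730093263\tOther_Technique\t2434\t2439",
]
dups = find_duplicate_keys(toy_lines)
print(dups)  # {('730093263', '2434', '2439'): 2}
# Number of lines the post-processing cell would report as duplicates
print(sum(n - 1 for n in dups.values()))  # 1
```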

In [9]:
with open(ORIGINAL_LABELS, "r") as f:
    original_lines = [line.strip() for line in f]

with open(ALTERNATIVE_LABELS, "r") as f:
    alternative_lines = [line.strip().split("\t") for line in f]

# The code assumes that line i of both label files corresponds to
# row i of dev_df, and therefore to y_pred[i]
assert len(original_lines) == len(alternative_lines) == len(y_pred)

with open(FILE_TO_SAVE, "w") as f:
    # (article_id, span_start, span_end) keys of the entries written so far;
    # a set makes the membership test O(1) instead of O(n)
    postprocessed_entries = set()
    duplicate_counter = 0
    for idx in range(len(original_lines)):
        # Columns: article_id, technique label, span_start, span_end
        columns = original_lines[idx].split("\t")
        entry_key = (columns[0], columns[2], columns[3])
        is_duplicate = entry_key in postprocessed_entries

        if is_duplicate:
            print("There is a duplicate!")
            duplicate_counter += 1
            print("Old entry:", columns)

        if not is_duplicate:
            if y_pred[idx] == 1:
                # The Repetition postprocessor predicts a repetition: assign it
                columns[1] = "Repetition"
            elif columns[1] == "Repetition":
                # The base model predicted Repetition but the postprocessor
                # does not: fall back to the alternative model's label
                columns[1] = alternative_lines[idx][1]
        elif y_pred[idx] == 0 and columns[1] != "Repetition":
            # A duplicate unrelated to repetition: use the alternative label
            columns[1] = alternative_lines[idx][1]
        # Any other duplicate (originally Repetition, or predicted as a
        # repetition) keeps its original label

        if is_duplicate:
            print("New entry:", columns)

        postprocessed_entries.add(entry_key)
        f.write("\t".join(columns) + "\n")

print("There are {} duplicates".format(duplicate_counter))
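The branching in the cell above depends on only four inputs, so it can be expressed as a small pure function that is easy to check in isolation. This is a hypothetical refactoring (the name `postprocess_label` is not part of the original notebook) intended to reproduce the cell's rules for a single entry.

```python
def postprocess_label(original: str, alternative: str,
                      predicted_repetition: bool, is_duplicate: bool) -> str:
    """Return the final technique label for one entry (hypothetical helper)."""
    if not is_duplicate:
        if predicted_repetition:
            # First time this span is seen and the postprocessor says Repetition
            return "Repetition"
        if original == "Repetition":
            # The base model's Repetition is not confirmed: use the alternative label
            return alternative
        return original
    if not predicted_repetition and original != "Repetition":
        # Duplicate span with no repetition involved: use the alternative label
        return alternative
    # Remaining duplicates keep the base model's label
    return original


# Print the full truth table with toy labels "A" (original) and "B" (alternative)
for dup in (False, True):
    for rep in (False, True):
        for orig in ("A", "Repetition"):
            print(f"duplicate={dup} predicted={rep} original={orig} ->",
                  postprocess_label(orig, "B", rep, dup))
```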