In [None]:
# Lint as: python3
"""Example tagging for Toxic Spans based on Spacy.
Requires:
  pip install 'spacy<3' scikit-learn
Install models:
  python -m spacy download en_core_web_sm
"""

import ast
import csv
import random
import statistics
import sys
import itertools
import string
import sklearn
import spacy

sys.path.append('../evaluation')

In [None]:
accu = 0  # Best mean validation F1 seen so far; raised by the training loop.

In [None]:
def f1(predictions, gold):
    """
    Dice coefficient (equivalently F1) between two collections of offsets, e.g. character indices.
    >>> assert f1([0, 1, 4, 5], [0, 1, 6]) == 0.5714285714285714
    :param predictions: predicted offsets; duplicates are ignored
    :param gold: ground-truth offsets; duplicates are ignored
    :return: 2 * |P & G| / (|P| + |G|) in [0, 1]; 1.0 when both inputs are
        empty and 0.0 when exactly one of them is empty
    """
    n_gold = len(gold)
    n_pred = len(predictions)
    if n_gold == 0 or n_pred == 0:
        # Two empty inputs agree perfectly; empty vs. non-empty scores zero.
        return float(n_gold == n_pred)
    pred_set = set(predictions)
    gold_set = set(gold)
    overlap = len(pred_set & gold_set)
    return 2.0 * overlap / (len(pred_set) + len(gold_set))


def evaluate(pred, gold):
    """Score a predictions file against a ground-truth file with span F1.

    Based on https://github.com/felipebravom/EmoInt/blob/master/codalab/scoring_program/evaluation.py

    Every line of both files must look like ``<text id><TAB><python list of offsets>``.

    :param pred: file-like object (anything with ``readlines()``) with predictions
    :param gold: file-like object with the ground truth
    :return: tuple ``(mean F1, standard error of the mean F1)`` over all text
        ids. The standard error uses the sample standard deviation (ddof=1),
        like ``scipy.stats.sem``, and is ``nan`` when fewer than two ids are
        scored. Both values are ``nan`` when the files are empty.
    :raises ValueError: on a malformed line, a non-integer id, a malformed gold
        span list, or a prediction for an id missing from the gold file
    :raises SystemExit: when the files have different line counts, or when an
        id does not end up with exactly one gold and one predicted span list
    """
    pred_lines = pred.readlines()
    gold_lines = gold.readlines()

    # Only files with the same number of lines can be aligned.
    if len(pred_lines) != len(gold_lines):
        sys.exit('Predictions and gold data have different number of lines.')

    # Maps text id -> [gold_spans, pred_spans, ...].
    data_dic = {}
    for n, line in enumerate(gold_lines):
        parts = line.split('\t')
        if len(parts) != 2:
            raise ValueError('Format problem for gold line %d.' % n)
        data_dic[int(parts[0])] = [ast.literal_eval(parts[1])]

    for n, line in enumerate(pred_lines):
        parts = line.split('\t')
        if len(parts) != 2:
            raise ValueError('Format problem for pred line %d.' % n)
        text_id = int(parts[0])
        if text_id not in data_dic:
            raise ValueError('Invalid text id for pred line %d.' % n)
        try:
            data_dic[text_id].append(ast.literal_eval(parts[1]))
        except (ValueError, SyntaxError):
            # Invalid predictions are replaced by a default value. literal_eval
            # raises SyntaxError for unparsable text (e.g. a truncated list)
            # and ValueError for expressions that are not plain literals.
            data_dic[text_id].append([])

    scores = []
    for entry in data_dic.values():
        # More than two entries means an id was predicted twice; fewer means
        # it was never predicted (another id must then be duplicated).
        if len(entry) != 2:
            sys.exit('Repeated id in test data.')
        gold_spans, pred_spans = entry
        scores.append(f1(pred_spans, gold_spans))

    if not scores:
        return (float('nan'), float('nan'))
    mean_score = statistics.mean(scores)
    if len(scores) > 1:
        sem_score = statistics.stdev(scores) / len(scores) ** 0.5
    else:
        sem_score = float('nan')
    return (mean_score, sem_score)



In [None]:
def _contiguous_ranges(span_list):
    """Extracts continguous runs [1, 2, 3, 5, 6, 7] -> [(1,3), (5,7)]."""
    output = []
    for _, span in itertools.groupby(
        enumerate(span_list), lambda p: p[1] - p[0]):
        span = list(span)
        output.append((span[0][1], span[-1][1]))
    return output

SPECIAL_CHARACTERS = string.whitespace


def fix_spans(spans, text, special_characters=SPECIAL_CHARACTERS, min_length=3):
    """Applies minor edits to trim spans and remove very short runs.

    The offsets in `spans` are grouped into runs of consecutive offsets with
    _contiguous_ranges. Each run is then processed in three steps:
      * clipped to the bounds of `text`;
      * stripped of leading/trailing `special_characters`;
      * kept only if it is still at least `min_length` characters long.

    :param spans: character offsets; presumably sorted ascending, since only
        consecutive increasing offsets are merged into one run
    :param text: the string the offsets index into
    :param special_characters: characters trimmed from both ends of each run
    :param min_length: shortest run (in characters, after trimming) to keep.
        The default of 3 preserves the original behaviour, which dropped
        single- and two-character runs. A run made only of special characters
        trims down to one character.
    :return: a flat list of the surviving character offsets
    """
    cleaned = []
    last = len(text) - 1
    for begin, end in _contiguous_ranges(spans):
        # Negative offsets would silently index from the end of the text.
        begin = max(begin, 0)
        if begin > last:
            # The whole run lies past the end of the text: nothing to keep.
            continue
        # Clip a run that overshoots the text instead of discarding it whole.
        end = min(end, last)
        if end < begin:
            # The run was made only of negative offsets.
            continue
        while text[begin] in special_characters and begin < end:
            begin += 1
        while text[end] in special_characters and begin < end:
            end -= 1
        if end - begin + 1 >= min_length:
            cleaned.extend(range(begin, end + 1))
    return cleaned

In [None]:
def spans_to_ents(doc, spans, label):
  """Converts character-offset spans into spaCy-style entity tuples.

  Consecutive tokens whose characters overlap `spans` are merged into one
  (start_char, end_char, label) entity. Whitespace tokens (pos_ == 'SPACE')
  are ignored: they neither start nor end an entity.

  Args:
    doc: an iterable of tokens exposing `.idx`, `.text` and `.pos_`.
    spans: a set of toxic character offsets.
    label: the entity label to attach to every produced tuple.

  Returns:
    A list of (start_char, end_char, label) tuples in document order.
  """
  ents = []
  current = None  # [start, end] of the entity being built, or None.
  for token in doc:
    if token.pos_ == 'SPACE':
      continue
    start = token.idx
    stop = start + len(token.text)
    if not spans.isdisjoint(range(start, stop)):
      if current is None:
        current = [start, stop]
      else:
        current[1] = stop
    elif current is not None:
      ents.append((current[0], current[1], label))
      current = None
  if current is not None:
    ents.append((current[0], current[1], label))
  return ents


def read_datafile(filename, encoding='utf-8'):
  """Reads a CSV file of toxic-span annotations.

  The file must have a `spans` column holding a Python list literal of
  character offsets and a `text` column holding the post. The offsets are
  cleaned with fix_spans() before being returned.

  Args:
    filename: path to the CSV file.
    encoding: text encoding of the file. Defaults to UTF-8 rather than the
      platform's locale encoding.

  Returns:
    A list of (fixed_span_offsets, text) tuples in file order.
  """
  data = []
  # newline='' lets the csv module handle newlines embedded in quoted fields.
  with open(filename, encoding=encoding, newline='') as csvfile:
    for row in csv.DictReader(csvfile):
      text = row['text']
      # literal_eval only accepts Python literals, so unlike eval it cannot
      # execute code from the annotation column.
      fixed = fix_spans(ast.literal_eval(row['spans']), text)
      data.append((fixed, text))
  return data

In [None]:
# Read training data: a list of (fixed_span_offsets, text) pairs.
print('loading training data')
train = read_datafile('train.csv')

# Read the validation split into `test`; the training cell scores every epoch
# on it to decide which model to save.
print('loading test data')
test = read_datafile('validation.csv')


In [None]:
# Train and evaluate a spaCy named-entity tagger for toxic spans (spaCy v2 API).
# After every epoch the model is scored on `test` (the validation split). It is
# saved to BEST_MODEL_DIR whenever its mean F1 beats `accu`.

BEST_MODEL_DIR = "./drive/My Drive/best"
N_ITERATIONS = 30
DROPOUT = 0.5

# The pretrained pipeline is used only to tokenise and tag the training texts;
# the POS tags let spans_to_ents skip whitespace tokens.
nlp = spacy.load("en_core_web_sm")

print('preparing training data')
training_data = []
for spans, text in train:
    doc = nlp(text)
    ents = spans_to_ents(doc, set(spans), 'TOXIC')
    training_data.append((doc.text, {'entities': ents}))

# A blank English pipeline holding a single, freshly initialised NER component.
toxic_tagging = spacy.blank('en')
toxic_tagging.vocab.strings.add('TOXIC')
# Create the component from the pipeline it is added to, so that it shares
# toxic_tagging's vocab instead of en_core_web_sm's.
ner = toxic_tagging.create_pipe("ner")
toxic_tagging.add_pipe(ner, last=True)
ner.add_label('TOXIC')
pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
unaffected_pipes = [
    pipe for pipe in toxic_tagging.pipe_names
    if pipe not in pipe_exceptions]

print('training')
with toxic_tagging.disable_pipes(*unaffected_pipes):
    toxic_tagging.begin_training()
    for iteration in range(N_ITERATIONS):
        random.shuffle(training_data)
        losses = {}
        # Batch size compounds from 4 towards a cap of 32.
        batches = spacy.util.minibatch(
            training_data, size=spacy.util.compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            toxic_tagging.update(texts, annotations, drop=DROPOUT, losses=losses)
        print("Losses", losses)

        # Score this epoch on the validation split.
        print('evaluation')
        scores = []
        for spans, text in test:
            pred_spans = []
            doc = toxic_tagging(text)
            for ent in doc.ents:
                pred_spans.extend(
                    range(ent.start_char, ent.start_char + len(ent.text)))
            scores.append(f1(pred_spans, spans))
        avg_f1 = statistics.mean(scores)
        print('avg F1 %g' % avg_f1)
        # Keep only the best-scoring model on disk.
        if avg_f1 > accu:
            accu = avg_f1
            toxic_tagging.to_disk(BEST_MODEL_DIR)

# if __name__ == '__main__':
#   main()

loading training data
loading test data
preparing training data
training
Losses {'ner': 25542.330320830177}
evaluation
avg F1 0.561472
Losses {'ner': 21688.396416402713}
evaluation
avg F1 0.487952
Losses {'ner': 21494.040469275318}
evaluation
avg F1 0.540415
Losses {'ner': 20818.088185393455}
evaluation
avg F1 0.595986
Losses {'ner': 19128.068120488657}
evaluation
avg F1 0.544117
Losses {'ner': 19222.784765028082}
evaluation
avg F1 0.557138
Losses {'ner': 18748.401597556338}
evaluation
avg F1 0.587888
Losses {'ner': 17869.124502382554}
evaluation
avg F1 0.586877
Losses {'ner': 17297.707755795396}
evaluation
avg F1 0.550164
Losses {'ner': 17737.55657736829}
evaluation
avg F1 0.601662
Losses {'ner': 16987.667335124912}
evaluation
avg F1 0.562182
Losses {'ner': 16495.107530350273}
evaluation
avg F1 0.570561
Losses {'ner': 15976.735554215864}
evaluation
avg F1 0.599318
Losses {'ner': 16003.423004134667}
evaluation
avg F1 0.565677
Losses {'ner': 15314.599330999636}
evaluation
avg F1 0.55773

In [None]:
# Rebind `test` to the trial split; the evaluation cell below scores the
# reloaded best model on it.
test = read_datafile('trial.csv')

In [None]:
# NOTE(review): `Language` is not referenced in the visible cells; kept in case
# other cells rely on it.
from spacy.lang.en import Language
# Reload the best model written to disk by the training loop.
nlpi = spacy.load("./drive/My Drive/best")

In [None]:
# Score the reloaded best model on `test`, averaging per-post span F1.
print('evaluation')
scores = []
for spans, text in test:
    doc = nlpi(text)
    # Expand every predicted entity into the character offsets it covers.
    pred_spans = [
        offset
        for ent in doc.ents
        for offset in range(ent.start_char, ent.start_char + len(ent.text))
    ]
    score = f1(pred_spans, spans)
    scores.append(score)
print('avg F1 %g' % statistics.mean(scores))

evaluation
avg F1 0.663341
