In [1]:
import pickle
import torch

In [2]:
def readgold(path):
    pairs = open(path, 'r').readlines()
    pairs = [[int(p) for p in ps.split('\t')] for ps in pairs]
    return torch.LongTensor(pairs)

In [3]:
def maxfp_heuristic(scores, preds, gold):
    tp_first_score = scores[gold[:,0],0][preds[gold[:,0],0] == gold[:,1]]
    tp_second_score = scores[gold[:,0],1][preds[gold[:,0],0] == gold[:,1]]
    fp_second_score = scores[gold[:,0],1][preds[gold[:,0],0] != gold[:,1]]
    fp_first_score = scores[gold[:,0],0][preds[gold[:,0],0] != gold[:,1]]
    tp_diff = tp_first_score - tp_second_score
    fp_diff = fp_first_score - fp_second_score
    thresh = max(fp_diff)
    return thresh

In [4]:
def f1score(scores, preds, gold, thresh):
    diff = scores[:,0] - scores[:,1]
    predictions = -torch.ones(len(preds)).long()
    predictions[diff > thresh] = preds[diff > thresh,0]
    gt = -torch.ones(len(preds)).long()
    gt[gold[:,0]] = gold[:,1]
    precision = (predictions == gt).float().mean()
    tp = (predictions[gold[:,0]] != -1).sum()
    fp = (predictions[gold[:,0]] == -1).sum()
    recall = tp.float()/(tp+fp)
    f1 = 2*precision*recall/(precision+recall)
    return precision, recall, f1

def print_performance(scores, preds, gold, usethresh):
    if usethresh:
        prec, rec, f1 = f1score(scores, preds, gold, usethresh)
        print("thresh: %.2f, precision: %.2f, recall: %.2f, f1: %.2f" %(usethresh, prec, rec, f1))
        return
    for thr in torch.arange(0,0.2,0.01):
        prec, rec, f1 = f1score(scores, preds, gold, thr.item())
        print("thresh: %.2f, precision: %.2f, recall: %.2f, f1: %.2f" %(thr.item(), prec, rec, f1))


In [5]:
goldfile = 'bucc2018.ru-en.gold'
predfile = 'ru-en.training.scores.csls'
gold = readgold(goldfile)
scores, preds = pickle.load(open(predfile, 'rb'))
scores, preds = scores.squeeze(), preds.squeeze()

In [6]:
print_performance(scores, preds, gold, None)

thresh: 0.00, precision: 0.03, recall: 1.00, f1: 0.06
thresh: 0.01, precision: 0.38, recall: 0.91, f1: 0.53
thresh: 0.02, precision: 0.60, recall: 0.85, f1: 0.70
thresh: 0.03, precision: 0.74, recall: 0.80, f1: 0.77
thresh: 0.04, precision: 0.83, recall: 0.75, f1: 0.79
thresh: 0.05, precision: 0.89, recall: 0.72, f1: 0.80
thresh: 0.06, precision: 0.93, recall: 0.68, f1: 0.79
thresh: 0.07, precision: 0.95, recall: 0.65, f1: 0.77
thresh: 0.08, precision: 0.96, recall: 0.61, f1: 0.75
thresh: 0.09, precision: 0.97, recall: 0.58, f1: 0.72
thresh: 0.10, precision: 0.97, recall: 0.54, f1: 0.69
thresh: 0.11, precision: 0.98, recall: 0.50, f1: 0.66
thresh: 0.12, precision: 0.98, recall: 0.46, f1: 0.62
thresh: 0.13, precision: 0.98, recall: 0.42, f1: 0.59
thresh: 0.14, precision: 0.98, recall: 0.38, f1: 0.55
thresh: 0.15, precision: 0.98, recall: 0.34, f1: 0.51
thresh: 0.16, precision: 0.98, recall: 0.31, f1: 0.47
thresh: 0.17, precision: 0.98, recall: 0.27, f1: 0.43
thresh: 0.18, precision: 0.9

thresh = .06

In [264]:
print_performance(scores, preds, gold, 0.06)

thresh: 0.06, precision: 0.97, recall: 0.73, f1: 0.83
