In [None]:
import sys
sys.path.append("../")
import utils

import metrics_wrapper

In [None]:
# Metric identifier: selects the score file loaded below and, further down,
# the plot title, colour range and output file name.
METRIC = "comet20"
# METRIC = "chrf"
PATTERNS_REF = ["R1", "R2", "R4", "R3"]

# One score table per reference pattern, all loaded from the same per-metric file.
metric_score = dict(
    (
        ref_name,
        utils.load_metric_scores(
            f"../../computed/metric_scores_{METRIC}.json",
            pattern_ref=ref_name,
            aggregate="average",
            refids_path="../../computed/metric_scores_none.json"
        ),
    )
    for ref_name in PATTERNS_REF
)

In [None]:

import copy
import random
# Seed the stdlib RNG: select_subset below draws every operation and candidate
# via random.choices, so this fixes the sequence of draws that follows it.
random.seed(0)

def select_subset(S, B, quality_lambda, temp=1, max_patience=10):
    """Randomly allocate segments to four reference vendors under a budget.

    The allocation starts with every segment of ``S`` assigned to vendor 1.
    While the allocation costs less than ``B``, one operation is drawn per
    iteration:

    * ``"REPLACE"`` (weight ``quality_lambda``): pick an assigned
      (segment, vendor) pair, favouring low-utility vendors, and move that
      segment to a *different* vendor of equal or higher utility that does not
      hold it yet.
    * ``"ADD"`` (weight ``1 - quality_lambda``): additionally assign some
      segment to a vendor that does not hold it yet, favouring high-utility
      vendors.

    Args:
        S: set of hashable segments (must be exactly a ``set``). Not modified.
        B: budget; the loop runs while the current cost is strictly below it.
        quality_lambda: weight of REPLACE versus ADD, expected in [0, 1].
        temp: temperature of the utility-based sampling weights
            (``utility ** (+-1 / temp)``).
        max_patience: number of consecutive REPLACE draws without any valid
            target after which the search gives up.

    Returns:
        dict mapping vendor id (1-4) to the set of segments assigned to it.
        This is the last allocation observed with cost below ``B`` (or the
        initial allocation when that already costs ``B`` or more).
    """
    assert type(S) is set

    # Fixed universe of segments. It has to be a different object from R[1]:
    # REPLACE removes segments from R[1], and the candidate lists are built
    # from `universe - R[i]`. Sharing one set would silently shrink the
    # universe and freeze every segment that ever left vendor 1.
    universe = copy.deepcopy(S)

    def cost(R):
        # per-segment price of vendors 1..4
        return 1*len(R[1])+1.5*len(R[2])+2*len(R[3])+2.5*len(R[4])

    def utility(i):
        return {
            1: 1/1,
            2: 2/1,
            3: 3/2,
            4: 2/3,
        }[i]

    R = {
        1: set(universe),
        2: set(),
        3: set(),
        4: set(),
    }

    out = copy.deepcopy(R)
    # Consecutive REPLACE draws that found no valid target. Kept outside the
    # loop so it can actually accumulate up to max_patience.
    patience = 0
    while cost(R) < B:
        # snapshot of the last allocation that is still under budget
        out = copy.deepcopy(R)
        operation = random.choices(
            population=["REPLACE", "ADD"],
            weights=[quality_lambda, 1-quality_lambda],
            k=1
        )[0]
        candidates_new = [
            (x, i, utility(i))
            for i in R.keys()
            # what could be added to R[i]?
            for x in universe-R[i]
        ]
        candidates_old = [
            (x, i, utility(i))
            for i in R.keys()
            # what could be removed from R[i]?
            for x in R[i]
        ]

        if not candidates_new or not candidates_old:
            break

        if operation == "REPLACE":
            # low-utility assignments are more likely to be replaced
            candidate_old = random.choices(
                population=candidates_old,
                weights=[x[2]**(-1/temp) for x in candidates_old],
                k=1,
            )[0]

            # filter candidates to have a higher utility and be the same segment but from a different vendor
            candidates_new = [
                x for x in candidates_new
                if x[2] >= candidate_old[2] and x[0] == candidate_old[0] and x[1] != candidate_old[1]
            ]

            if not candidates_new:
                patience += 1
                if patience >= max_patience:
                    break
                continue

            candidate_new = random.choices(
                population=candidates_new,
                weights=[x[2]**(1/temp) for x in candidates_new],
                k=1,
            )[0]

            # commit transaction
            R[candidate_new[1]].add(candidate_new[0])
            R[candidate_old[1]].remove(candidate_old[0])
            patience = 0

        elif operation == "ADD":
            candidate_new = random.choices(
                population=candidates_new,
                weights=[x[2]**(1/temp) for x in candidates_new],
                k=1,
            )[0]

            # commit transaction
            R[candidate_new[1]].add(candidate_new[0])
            patience = 0
        else:
            raise Exception("Unknown operation")

    return out

In [None]:
import numpy as np
import collections
import tqdm

# Items loaded here are read below through their
# "src", "tgt", "system" and "score" fields.
data_wmt = utils.load_wmt(
    annotation_path="../../data/annotations.json",
    wmt_path="../../data/data_tmp/",
)

# Distinct source segments; this is the set handed to select_subset.
S = {item["src"] for item in data_wmt}

def evaluate_subset(R):
    """Score an allocation ``R`` (vendor id -> collection of source segments).

    For every item in ``data_wmt`` the metric value is the average of
    ``metric_score[f"R{i}"][(src, tgt)]`` over all vendors ``i`` that hold the
    item's source segment in ``R``. The resulting records are passed to
    ``utils.compute_segment_tau`` and the first element of its result is
    returned. A source segment held by no vendor raises ``KeyError``.
    """
    # src -> list of vendor ids holding it, in the iteration order of R
    vendors_by_src = {}
    for vendor, sources in R.items():
        for src in sources:
            vendors_by_src.setdefault(src, []).append(vendor)

    records = []
    for item in data_wmt:
        record = {
            "src": item["src"],
            "tgt": item["tgt"],
            "system": item["system"],
            "human": item["score"],
        }
        per_vendor_scores = [
            metric_score[f"R{vendor}"][(item["src"], item["tgt"])]
            for vendor in vendors_by_src[item["src"]]
        ]
        record["metric"] = np.average(per_vendor_scores)
        records.append(record)

    tau, _ = utils.compute_segment_tau(records)
    return tau

# Budgets: 160, 320, ..., 960.
Bs = [160 * multiple for multiple in range(1, 7)]
QUALITY_LAMBDAs = np.linspace(0.001, 0.999, 10)

# (budget, quality_lambda) -> correlation averaged over 10 random allocations
CORRs = {}
for quality_lambda in tqdm.tqdm(QUALITY_LAMBDAs):
    for B in Bs:
        corrs = [
            evaluate_subset(
                select_subset(S=S, B=B, quality_lambda=quality_lambda, temp=1)
            )
            for _ in range(10)
        ]
        CORRs[(B, quality_lambda)] = np.average(corrs)

In [None]:
import matplotlib.pyplot as plt

# Heatmap matrix, already in plotting orientation:
# one row per quality_lambda, one column per budget.
# (The former row-wise "normalize" loop only copied rows unchanged and computed
# unused min/max values, so it is dropped.)
img = np.array([
    [CORRs[(b_v, l_v)] for b_v in Bs]
    for l_v in QUALITY_LAMBDAs
])

plt.figure(figsize=(1.9, 2))

# Mark, for every budget column except the first, the row (lambda) with the
# highest value.
# NOTE(review): the first column is presumably skipped because at the smallest
# budget select_subset returns the initial allocation for every lambda - confirm.
for col_i, col in enumerate(img.T):
    if col_i == 0:
        continue
    row_i = np.argmax(col)
    plt.scatter(
        x=[col_i], y=[row_i],
        marker="*", s=50,
        color="black"
    )


# Colour limits are hand-picked per metric.
plt.imshow(
    img,
    cmap="gray",
    aspect="auto",
    vmin=0.07 if METRIC == "chrf" else 0.165,
    vmax=0.15 if METRIC == "chrf" else 0.192,
)
cbar = plt.colorbar(
    aspect=10, pad=0.03,
    ticks=
        [0.08, 0.11, 0.14] if METRIC == "chrf" else
        None
)
cbar.outline.set_linewidth(0)



ax = plt.gca()
plt.title(metrics_wrapper.METRIC_NAMES[METRIC], fontsize=12)
plt.xlabel("Budget", labelpad=-5)
# Only every fifth budget gets a tick (first and last column for six budgets).
plt.xticks(
    list(range(len(Bs)))[::5],
    [f"${b//160}|S|$" for b in Bs][::5]
)
# Row 0 is the smallest lambda (mostly ADD), the last row the largest (mostly REPLACE).
plt.yticks(
    [1, len(QUALITY_LAMBDAs)//2-0.5, len(QUALITY_LAMBDAs)-1-1],
    ["$\\rightarrow$Quantity", "$\\lambda$", "Quality$\\leftarrow$"],
    rotation=90,
    va="center",
    fontsize=9,
)

ax.tick_params(axis='y', which='major', pad=-1)
ax.tick_params(left=False, bottom=True)
ax.spines[['left', 'right', 'top', 'bottom']].set_visible(False)
plt.tight_layout(pad=0)

plt.savefig(f"../../computed/budget_allocation_{METRIC}.pdf")