### Visualization of *Capital*, *Language*, *Northern Neighbor* with invariant subject
This notebook estimates a relation operator for each prompt–subject pair of three object-retrieval relations (capital, language, and northern neighbor) and visualizes the distances between those operators for a fixed subject. At the end there is code to calculate the Jacobian for each f(s, c), which isn't used yet.

In [1]:
import sys
# Add the parent directory to the import search path (presumably so that the
# `src` package imported in later cells resolves — confirm against the repo
# layout).
sys.path.append("..")

In [2]:
# IPython shell escape: run `nvidia-smi` to show the GPU status of this host.
!nvidia-smi

zsh:1: command not found: nvidia-smi


In [None]:
from src import models

# Device string passed to the model loader and to the operator calls in later
# cells.
# NOTE(review): "cuda:5" is hard-coded, so this presumably needs a host with at
# least six CUDA devices — confirm before running elsewhere.
device = "cuda:5"
# Load the model bundle; later cells read `mt.model` and `mt.tokenizer` from it.
mt = models.load_model("gptj", device=device)

In [None]:
def comma_sep_lines_to_pairs(string):
    """Parse newline-separated, comma-separated records into lists of fields.

    Each non-blank line of ``string`` is split on commas and every field is
    stripped of surrounding whitespace. Blank or whitespace-only lines (for
    example one produced by a trailing newline) are skipped instead of being
    turned into ``[""]`` entries, which would break callers that unpack each
    entry as ``subject, target``.

    Args:
        string: Text with one comma-separated record per line.

    Returns:
        A list holding one inner list of stripped fields per non-blank line,
        in input order.
    """
    return [
        [field.strip() for field in line.split(",")]
        for line in string.split("\n")
        if line.strip()
    ]

# [country, capital city] pairs, written as one "country, capital" record per
# line and parsed by comma_sep_lines_to_pairs.
CAPITOLS = comma_sep_lines_to_pairs(
    """\
United States, Washington D.C.
Canada, Ottawa
Mexico, Mexico City
Brazil, Brasília
Argentina, Buenos Aires
Chile, Santiago
Peru, Lima
Colombia, Bogotá
Venezuela, Caracas
Spain, Madrid
France, Paris
Germany, Berlin
Italy, Rome
Russia, Moscow
China, Beijing
Japan, Tokyo
South Korea, Seoul
India, New Delhi
Pakistan, Islamabad
Nigeria, Abuja
Egypt, Cairo
Saudi Arabia, Riyadh
Turkey, Ankara
Australia, Canberra""")

# [country, language] pairs for the same countries as CAPITOLS. Some targets
# name more than one language (e.g. "English and French").
LANGUAGES = comma_sep_lines_to_pairs("""\
United States, English
Canada, English and French
Mexico, Spanish
Brazil, Portuguese
Argentina, Spanish
Chile, Spanish
Peru, Spanish
Colombia, Spanish
Venezuela, Spanish
Spain, Spanish
France, French
Germany, German
Italy, Italian
Russia, Russian
China, Mandarin Chinese
Japan, Japanese
South Korea, Korean
India, Hindi
Pakistan, Urdu
Nigeria, English
Egypt, Arabic
Saudi Arabia, Arabic
Turkey, Turkish
Australia, English""")

# [country, northern neighbor] pairs. The country set differs from
# CAPITOLS/LANGUAGES: e.g. Canada and Japan have no entry here, and
# South Africa appears only in the border tables.
BORDER_NORTH = comma_sep_lines_to_pairs("""\
United States, Canada
Mexico, United States
Brazil, Venezuela
Argentina, Bolivia
Chile, Peru
Peru, Ecuador
Colombia, Venezuela
Venezuela, Colombia
Spain, France
France, Germany
Germany, Denmark
Italy, Switzerland
Russia, Kazakhstan
China, Russia
South Korea, North Korea
India, China
Pakistan, Afghanistan
South Africa, Namibia
Egypt, Libya
Saudi Arabia, Iraq
Turkey, Bulgaria""")

# [country, southern neighbor] pairs. Not referenced by the other cells
# visible in this part of the notebook.
BORDER_SOUTH = comma_sep_lines_to_pairs("""\
United States, Mexico
Canada, United States
Mexico, Guatemala
Brazil, Bolivia
Argentina, Chile
Chile, Argentina
Peru, Chile
Colombia, Ecuador
Venezuela, Brazil
France, Spain
Germany, Switzerland
Russia, Georgia
Nigeria, Cameroon
South Africa, Lesotho
Egypt, Sudan
Saudi Arabia, Yemen
Turkey, Syria""")

# Show the parsed northern-neighbor table as the cell output.
BORDER_NORTH

In [None]:
def line_sep_prompts(string):
    """Split a block of prompt templates into one cleaned template per line.

    Each line is stripped of surrounding whitespace, and its square-bracket
    placeholders (``[name]``) are rewritten as curly-brace placeholders
    (``{name}``).
    """
    # One translation table handles both bracket characters in a single pass.
    bracket_to_brace = str.maketrans("[]", "{}")
    return [
        raw_line.strip().translate(bracket_to_brace)
        for raw_line in string.split("\n")
    ]

# Paraphrased templates for the country -> language relation. [country] is the
# subject slot and [language] the object slot; in every template the subject
# slot comes first.
LANGUAGE_PROMPTS = line_sep_prompts("""\
[country] is a country where the language of [language] is spoken.
The people of [country] communicate in [language].
[country] is home to speakers of [language].
The people in [country] converse using the language of [language].
The inhabitants of [country] use [language] to communicate.
In [country], the language primarily spoken is [language].""")

# Paraphrased templates for the country -> capital relation. [country] is the
# subject slot and [city] the object slot.
CAPITOL_PROMPTS = line_sep_prompts("""\
The capital city of [country] is [city].
[country] is home to the capital city of [city].
The political capital of [country] is [city].
The seat of government for [country] is [city].
The government of [country] is centered in [city].""")

# Paraphrased templates for the country -> northern-neighbor relation.
# [country] is the subject slot and [other] the object slot.
# NOTE(review): "[country] lies to the north of [other]" appears to state the
# reverse of a (country, northern neighbor) pair — confirm the intended wording.
NORTH_PROMPTS = line_sep_prompts("""\
The northern frontier of [country] meets that of [other].
[country] lies to the north of [other].
The northerly boundary of [country] is shared with [other].
[country]'s northern flank abuts [other].
[country]'s northernmost point touches [other].
The northernmost part of [country] adjoins [other].
To the north, [country] is contiguous with [other].
The northern edge of [country] meets [other].
[country]'s northern line of demarcation is with [other].
The northern boundary of [country] is contiguous with [other].""")

# Paraphrased templates for the country -> southern-neighbor relation.
# [country] is the subject slot and [other] the object slot, the same
# convention NORTH_PROMPTS uses. With [country] in both positions the subject
# and object could not be told apart once the brackets become braces.
# NOTE(review): "[country] lies to the south of [other]" appears to state the
# reverse of a (country, southern neighbor) pair — confirm the intended wording.
SOUTH_PROMPTS = line_sep_prompts("""\
The southern frontier of [country] meets that of [other].
[country]'s southern border abuts [other].
[country] lies to the south of [other].
[country]'s southern flank meets [other].
The southernmost point of [country] borders [other].
The southern edge of [country] meets [other].
The southern line of demarcation of [country] is shared with [other].
[country]'s southern boundary is contiguous with [other].""")

# Show the parsed northern-neighbor templates as the cell output.
NORTH_PROMPTS

### Estimated Relation Operator
Estimate one relation operator per (prompt, subject) pair for three relations: country → capital, country → language, and country → northern neighbor.

In [None]:
from src import estimate

from tqdm.auto import tqdm

# One table of estimated operators per relation, keyed by
# (operator prompt, subject).
ops_capitols = {}
ops_languages = {}
ops_north = {}

for ops, prompts, samples in (
    (ops_capitols, CAPITOL_PROMPTS, CAPITOLS),
    (ops_languages, LANGUAGE_PROMPTS, LANGUAGES),
    (ops_north, NORTH_PROMPTS, BORDER_NORTH),
):
    for template in prompts:
        # Derive the operator prompt once per template rather than once per
        # subject, and without overwriting the loop variable: keep only the
        # text before the object placeholder, turn the subject placeholder
        # into a positional "{}" slot, and trim trailing periods and spaces.
        prompt = (
            template
                .split("{city}")[0]
                .split("{language}")[0]
                .split("{other}")[0]
                .replace("{country}", "{}")
                .rstrip(". ")
        )
        # The progress bar is labelled with the full, unmodified template.
        for subject, _ in tqdm(samples, desc=template):
            operator = estimate.relation_operator_from_sample(
                mt.model,
                mt.tokenizer,
                subject,
                prompt,
                device=device,
            )
            ops[prompt, subject] = operator

In [None]:
from collections import defaultdict

# Operator tables by relation name; the insertion order here is the order in
# which later cells iterate the categories.
CATEGORIES = {
    "language": ops_languages,
    "capitol": ops_capitols,
    "north": ops_north,
}

# dists[a][b] collects (prompt a, prompt b, subject, distance) for every
# ordered pair of operators from categories a and b that share a subject but
# were estimated from different prompts.
dists = defaultdict(lambda: defaultdict(list))
for cat_a, table_a in tqdm(CATEGORIES.items()):
    for cat_b, table_b in CATEGORIES.items():
        # Each stored value is unpacked as a pair; only its first element
        # (which exposes `.weight`) is used here.
        for (prompt_a, subj_a), (op_a, _meta_a) in table_a.items():
            for (prompt_b, subj_b), (op_b, _meta_b) in table_b.items():
                if prompt_a == prompt_b or subj_a != subj_b:
                    continue
                weight_delta = op_a.weight.sub(op_b.weight)
                distance = weight_delta.norm().item()
                dists[cat_a][cat_b].append(
                    (prompt_a, prompt_b, subj_a, distance)
                )

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Mean recorded distance for every (row category, column category) pair; the
# distance is the last field of each tuple stored in `dists`.
labels = list(CATEGORIES)
data = np.array([
    [np.mean([entry[-1] for entry in dists[row][col]]) for col in labels]
    for row in labels
])

# Annotated heatmap whose colour scale runs from 0 to the largest mean.
sns.heatmap(
    data=data,
    xticklabels=labels,
    yticklabels=labels,
    vmin=0,
    vmax=data.max(),
    annot=True,
    fmt=".2f",
)

The heatmap above shows average distances, but let's show classification accuracy instead.

In [None]:
# Rebuild `dists` for nearest-neighbour classification: dists[c1][p1] lists
# (category, prompt, subject, squared distance) for every operator estimated
# from a different prompt than p1, across all categories.
# NOTE(review): entries are grouped by prompt only, so all subjects sharing p1
# land in one list — confirm per-prompt (not per-operator) grouping is intended.
dists = defaultdict(lambda: defaultdict(list))
for c1, ops1 in tqdm(CATEGORIES.items()):
    for c2, ops2 in CATEGORIES.items():
        for (p1, s1), (o1, m1) in ops1.items():
            for (p2, s2), (o2, m2) in ops2.items():
                if p1 != p2:
                    # Sum of squared element-wise differences of the weights.
                    diff = o1.weight.sub(o2.weight)
                    dist = diff.pow(2).sum().item()
                    dists[c1][p1].append((c2, p2, s2, dist))

In [None]:
# scores[c1][c2] counts the prompts of category c1 whose closest recorded
# operator (smallest distance) belongs to category c2.
scores = defaultdict(lambda: defaultdict(int))
for c1 in dists:
    for p1, ds in dists[c1].items():
        # Only the closest entry is needed, so take the minimum directly
        # instead of sorting the whole list. min() returns the first of any
        # tied entries, the same one a stable sort would place first.
        best = min(ds, key=lambda x: x[-1])[0]
        scores[c1][best] += 1

# Turn each row of counts into fractions of that row's total.
accuracies = {}
for c1, counts in scores.items():
    total = sum(counts.values())  # constant for the row, so computed once
    accuracies[c1] = {c2: count / total for c2, count in counts.items()}

accuracies

In [None]:
# Matrix of nearest-neighbour fractions: rows are the query category, columns
# the category of the closest operator. Pairs that never occurred count as 0.
labels = list(CATEGORIES)
data = np.array([
    [accuracies.get(row, {}).get(col, 0) for col in labels]
    for row in labels
])

plt.title("Classification Accuracy")
# Fractions lie in [0, 1], so the colour scale is pinned to that range.
sns.heatmap(
    data=data,
    xticklabels=labels,
    yticklabels=labels,
    vmin=0.0,
    vmax=1.0,
    annot=True,
    fmt=".2f",
)

What if we condition on J's accuracy?

In [None]:
# Ground-truth [subject, target] tables, keyed by the same relation names that
# CATEGORIES uses for the operator tables.
PAIRS_BY_CATEGORY = dict(
    language=LANGUAGES,
    capitol=CAPITOLS,
    north=BORDER_NORTH,
)

def compute_accuracy(category, prompt, subject):
    """Score one estimated operator on the other subjects of its relation.

    The value stored under ``CATEGORIES[category][prompt, subject]`` is applied
    to every subject in ``PAIRS_BY_CATEGORY[category]`` except ``subject``
    itself. A subject counts as correct when its expected target, lower-cased
    and stripped, starts with at least one non-empty prediction (the first
    element of each returned item, lower-cased and stripped).

    Args:
        category: Relation name; a key of ``CATEGORIES`` and
            ``PAIRS_BY_CATEGORY``.
        prompt: Prompt part of the operator's key.
        subject: Subject part of the operator's key; excluded from the
            evaluation set.

    Returns:
        The fraction of the remaining subjects that were predicted correctly.

    Raises:
        KeyError: If no operator is stored for ``(prompt, subject)``.
        ZeroDivisionError: If the relation has no subject other than
            ``subject``.
    """
    # NOTE(review): other cells unpack the stored value as a 2-tuple, while it
    # is called directly here — confirm the stored object supports both.
    operator = CATEGORIES[category][prompt, subject]
    dataset = [pair for pair in PAIRS_BY_CATEGORY[category] if pair[0] != subject]
    n_correct = 0
    for s, t in dataset:
        target = t.lower().strip()
        predictions = [o[0].lower().strip() for o in operator(s, device=device)]
        # Skip empty predictions: every string starts with "", so an empty
        # prediction would otherwise be counted as a hit.
        n_correct += any(p and target.startswith(p) for p in predictions)
    return n_correct / len(dataset)

