In [None]:
import pandas as pd
import json
import pickle
import matplotlib.pyplot as plt
import torch
import random
import torch.nn.functional as F

In [None]:
# Load the expression database. The encoding is pinned to UTF-8 (the JSON
# interchange encoding) so that non-ASCII symbols in the stored expressions are
# decoded identically on every platform instead of relying on the locale default.
with open("expression_db.json", encoding="utf-8") as f:
    db = json.load(f)

# Reverse lookup: (value[0], value[1]) -> database key. value[0] is the library
# name that the premise filter compares against.
# NOTE(review): value[1] is presumably the theorem name within that library - confirm.
reverse_database = {(value[0], value[1]): key for key, value in db.items()}

In [None]:
# Restrict the analysis to the most frequent libraries. Other libraries that
# were tried at some point: "rich_list", "prim_rec", "set".
selected_libraries = ("pred_set", "relation", "list")

all_premises, premise_labels, premise_idx = [], [], []
for i, (key, value) in enumerate(db.items()):
    library = value[0]
    if library not in selected_libraries:
        continue
    all_premises.append(key)      # database key of the premise
    premise_labels.append(library)  # library the premise belongs to
    premise_idx.append(i)         # position of the premise inside db

In [None]:

# NOTE: pickle.load can execute arbitrary code while deserialising, so these
# files must come from a trusted source.
def _read_pickle(path):
    """Deserialise and return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


original_embeddings = _read_pickle('vanilla_embs.pk')
gnn_embeddings = _read_pickle('gnn_embs.pk')


In [None]:
from sklearn.manifold import TSNE
import plotly.express as px
def plot_reps(reps, labels, n_components, graph_name, random_state=None):
    """Project embeddings with t-SNE and plot them in 2-D (matplotlib) or 3-D (plotly).

    Args:
        reps: torch tensor of embeddings, one row per premise. It is detached,
            moved to the CPU and converted to NumPy before fitting.
        labels: sequence holding one label per row of ``reps``. It is never
            modified; points of interest are marked on a private copy.
        n_components: 2 draws a static scatter plot saved to ``graph_name`` as
            a PDF; 3 writes an interactive scatter to ``graph_name + ".html"``
            and opens it.
        graph_name: output path (2-D) or output path stem (3-D).
        random_state: optional seed forwarded to ``TSNE`` so a layout can be
            reproduced. ``None`` (the default) keeps the unseeded behaviour.

    Returns:
        ``(tsne_result, plot_labels)``: the fitted t-SNE coordinates and the
        list of labels that was actually plotted (a copy of ``labels``, with
        the point-of-interest markers applied in the 3-D case).

    Raises:
        NotImplementedError: if ``n_components`` is neither 2 nor 3.

    The 3-D branch reads the module-level ``all_premises`` and ``db`` to build
    the hover text, so ``reps`` and ``labels`` must be aligned with
    ``all_premises``.
    """
    # Reject unsupported dimensionalities before paying for the t-SNE fit.
    if n_components not in (2, 3):
        raise NotImplementedError

    tsne = TSNE(n_components, random_state=random_state)
    tsne_result = tsne.fit_transform(reps.detach().cpu().numpy())

    # Work on a copy: writing markers into the caller's list would leak them
    # into every later call that reuses the same list (e.g. a 2-D plot drawn
    # after a 3-D one would get extra, mismatched legend classes).
    plot_labels = list(labels)

    if n_components == 3:
        # Hand-picked positions (indices into all_premises) that get a custom
        # label so they stand out in the interactive plot.
        points_of_interest = {
            "POI1": (1255, 1256, 1290, 1291),
            "POI2": (403, 418),
            "POI3": (497, 387),
            "POI4": (1097, 810, 591),
        }
        for marker, positions in points_of_interest.items():
            for position in positions:
                if position < len(plot_labels):
                    plot_labels[position] = marker

        records = [
            {
                'Premise': db[premise][-1],
                'X': tsne_result[i, 0],
                "Y": tsne_result[i, 1],
                "Z": tsne_result[i, 2],
                "Thm": plot_labels[i],
                "Idx": i,
            }
            for i, premise in enumerate(all_premises)
        ]
        df = pd.DataFrame.from_records(records)

        fig = px.scatter_3d(df, x='X', y='Y', z='Z', color='Thm',
                            hover_name="Premise", hover_data=["Thm", "Idx"])
        fig.update_traces(marker_size=3)

        fig.write_html(graph_name + ".html", auto_open=True)
    else:
        # Number the classes in order of first appearance so that the colour
        # assignment (and therefore the legend) is the same on every run,
        # unlike iteration over a set of strings.
        class_names = list(dict.fromkeys(plot_labels))
        color_dict = {name: position for position, name in enumerate(class_names)}

        ax = plt.figure(figsize=(16, 10)).gca()
        scatter = ax.scatter(x=tsne_result[:, 0], y=tsne_result[:, 1],
                             c=[color_dict[label] for label in plot_labels])

        # legend_elements() yields one handle per distinct colour value in
        # ascending order, which matches the numbering of class_names.
        legend1 = ax.legend(scatter.legend_elements()[0], class_names,
                            loc="lower left", title="Classes", fontsize=16)
        ax.add_artist(legend1)
        plt.savefig(graph_name, format='pdf')

        plt.show()

    return tsne_result, plot_labels



In [None]:
# 2-D t-SNE scatter of the original embeddings, saved as a PDF.
vanilla_tsne, v_labs = plot_reps(
    reps=original_embeddings,
    labels=premise_labels,
    n_components=2,
    graph_name="original_premise_embeddings.pdf",
)

In [None]:
# 2-D t-SNE scatter of the GNN embeddings, saved as a PDF.
gnn_tsne, g_labs = plot_reps(
    reps=gnn_embeddings,
    labels=premise_labels,
    n_components=2,
    graph_name="gnn_premise_embeddings.pdf",
)

In [None]:
# Interactive 3-D t-SNE scatter of the original embeddings, written to an HTML file.
vanilla_tsne, v_labs = plot_reps(
    reps=original_embeddings,
    labels=premise_labels,
    n_components=3,
    graph_name="original_premise_embeddings",
)

In [None]:
# Interactive 3-D t-SNE scatter of the GNN embeddings, written to an HTML file.
gnn_tsne, g_labs = plot_reps(
    reps=gnn_embeddings,
    labels=premise_labels,
    n_components=3,
    graph_name="gnn_premise_embeddings",
)

In [None]:
def cos(m):
    """Return the pairwise cosine-similarity matrix between the rows of ``m``.

    Each row is scaled to unit L2 norm (``F.normalize`` defaults), so the
    product with the transpose holds the cosine of every pair of rows.
    """
    unit_rows = F.normalize(m)
    return unit_rows @ unit_rows.t()

In [None]:
# Nearest-neighbour tables, keyed by the premise's expression text
# (db[key][-1]) and holding the expression texts of its closest premises.
# Premises that share the same expression text overwrite each other's entry.
original_cosine_closest = {}
gnn_cosine_closest = {}


def _nearest_premise_indices(embeddings, ind, k):
    """Return the indices of the ``k`` rows most cosine-similar to row ``ind``.

    Only the similarity row for ``ind`` is computed, rather than the full
    pairwise matrix. The query row is excluded explicitly instead of assuming
    it ranks first, which does not hold for duplicate or all-zero embeddings.
    """
    with torch.no_grad():
        unit_rows = F.normalize(embeddings)
        similarities = unit_rows @ unit_rows[ind]
        similarities[ind] = float("-inf")
        # Never ask topk for more neighbours than there are other rows.
        k = max(0, min(k, similarities.numel() - 1))
        return torch.topk(similarities, k).indices.tolist()


def get_closest(ind, k=4):
    """Record the ``k`` nearest premises of premise ``ind`` under both embeddings.

    Args:
        ind: position of the query premise in ``all_premises`` (and row index
            into ``original_embeddings`` / ``gnn_embeddings``).
        k: number of neighbours to keep; the default of 4 matches the number
            previously obtained from a top-5 query with the first hit dropped.

    Side effects:
        Stores the neighbours' expression texts, most similar first, in
        ``original_cosine_closest`` and ``gnn_cosine_closest`` under the query
        premise's expression text.
    """
    query = db[all_premises[ind]][-1]
    original_cosine_closest[query] = [
        db[all_premises[x]][-1]
        for x in _nearest_premise_indices(original_embeddings, ind, k)
    ]
    gnn_cosine_closest[query] = [
        db[all_premises[x]][-1]
        for x in _nearest_premise_indices(gnn_embeddings, ind, k)
    ]


In [None]:
# Fill both nearest-neighbour tables by querying every selected premise in turn.
for premise_position, _ in enumerate(all_premises):
    get_closest(premise_position)
    

In [None]:
# Show, for every premise expression, its nearest neighbours under the GNN embeddings.
gnn_cosine_closest

In [None]:
# Show, for every premise expression, its nearest neighbours under the original embeddings.
original_cosine_closest

In [None]:

# Sample premises in random order and collect, for each one, its single
# closest neighbour under the GNN and under the original embeddings.
examples = []
shuffled_positions = list(range(len(all_premises)))
random.shuffle(shuffled_positions)

num_premises = 20

# only print expressions under 100 characters for readability
short_only = True

for position in shuffled_positions:
    # Stop as soon as the requested number of examples has been collected.
    if len(examples) == num_premises:
        break

    prem = db[all_premises[position]][-1]
    gnn_best = gnn_cosine_closest[prem][0]
    original_best = original_cosine_closest[prem][0]

    # With short_only set, skip any triple containing a long expression.
    if short_only and not all(len(text) < 100 for text in (gnn_best, original_best, prem)):
        continue

    examples.append((prem, gnn_best, original_best))
    print(f"Expression: {prem}\n GNN: {gnn_best}\n Original: {original_best}\n")



In [None]:
## Convert unicode expressions to latex
# from pylatexenc.latexencode import UnicodeToLatexEncoder
#
# u = UnicodeToLatexEncoder(unknown_char_policy='replace')
# print(u.unicode_to_latex('(R1 :α -> β -> bool) ∩ᵣ (R2 :α -> β -> bool) = R2 ∩ᵣ R1'))
#
# latex_examples = [u.unicode_to_latex(r) for r in examples]