In [1]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

# Switch the process working directory to whatever proj_root() returns.
# NOTE(review): presumably done so that relative paths used by later cells
# resolve against the repository root -- confirm against spot.utils.
os.chdir(proj_root())

# gpu_id = 1
# modeldir = get_model_dir()

# --- Which trained model to analyse -----------------------------------------
# Exactly one `model_name` assignment should be active; the older checkpoint
# is kept commented out for quick switching.
# model_name="model-v6--TrainingConfig(drop_env_types=False)"
model_name = "model-v7--TrainingConfig(drop_env_types=False, add_implicit_rel_imports=True)"

# --- Which dataset to analyse -----------------------------------------------
# Alternatives are kept commented out; uncomment one to switch.
# dataset_name = "ManyTypes4Py"
# dataset_name = "SPOT-src"
dataset_name = "InferTypes4Py"

In [2]:
from spot.utils import *
from spot.function_dataset import data_project_from_dir

# Load the test projects: every sub-directory of <dataset>/repos/test is
# treated as one project and handed to data_project_from_dir via pmap.
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = list(filter(lambda path: path.is_dir(), repos_dir.iterdir()))
test_projects = pmap(data_project_from_dir, test_repo_paths, desc="Loading test projects")

# Fail right here if nothing was loaded, instead of producing empty tables later.
assert len(test_projects) > 0


Loading test projects: 100%|██████████| 3/3 [00:11<00:00,  3.86s/it]


In [4]:
from spot.function_decoding import DecodingOrders, EvalResult
import re

# Directory holding one "<order>-EvalResult.pkl" file per decoding order.
results_dir = get_eval_dir(dataset_name, "(implicit_imports) " + model_name)

# Discover the available decoding orders from the result file names.
# fullmatch (rather than match) requires the name to END in "-EvalResult.pkl",
# so stray files such as "x-EvalResult.pkl.bak" or "x-EvalResult.pkl.tmp" are
# no longer mistaken for a result of order "x".
decode_orders = sorted(
    m.group(1)
    for f in results_dir.iterdir()
    if (m := re.fullmatch(r"(.+)-EvalResult\.pkl", f.name)) is not None
)

# Maps decoding-order name -> the evaluation result loaded from its file.
evals: dict[str, EvalResult] = {}

# NOTE(review): pickle_load presumably unpickles the file; unpickling can run
# arbitrary code, so only point results_dir at files you produced yourself.
for oname in tqdm(decode_orders, desc="Loading evaluation results"):
    evals[oname] = pickle_load(results_dir / f"{oname}-EvalResult.pkl")


Loading evaluation results: 100%|██████████| 6/6 [00:01<00:00,  4.00it/s]


In [5]:
from prettytable import PrettyTable
import prettytable as pt
from spot.type_env import AccuracyMetric

# Build one accuracy metric per default setting, using the common type names
# stored alongside the trained model.
common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
metrics = AccuracyMetric.default_metrics(common_type_names)

# Table layout: one row per decoding order, one column per metric.
results_table = PrettyTable()
results_table.field_names = ["order"] + [m.name for m in metrics]
results_table.align = "r"
results_table.set_style(pt.SINGLE_BORDER)
results_table.float_format = ".4"

for oname in decode_orders:
    # NOTE(review): passing None as the first argument presumably aggregates
    # over all projects -- confirm against EvalResult.error_analysis.
    accs = []
    for metric in metrics:
        analysis = evals[oname].error_analysis(None, metric)
        accs.append(analysis.accuracies[metric.name].acc)
    results_table.add_row([oname] + accs)

print(results_table)
# write_file(results_dir / "comparison.txt", results_table.get_string())

┌──────────────────┬───────────┬────────┬──────────┐
│            order │ plain_acc │    acc │ base_acc │
├──────────────────┼───────────┼────────┼──────────┤
│    callee2caller │    0.7118 │ 0.7164 │   0.7891 │
│    caller2callee │    0.6754 │ 0.6870 │   0.7630 │
│ double-traversal │    0.7122 │ 0.7168 │   0.7869 │
│     no-neighbors │    0.6529 │ 0.6576 │   0.7274 │
│         non-incr │    0.6838 │ 0.6907 │   0.7682 │
│           random │    0.6896 │ 0.6980 │   0.7663 │
└──────────────────┴───────────┴────────┴──────────┘


In [7]:

# Per-project comparison of a few decoding strategies.
# Produces `df` (one row per test project) and `df_sorted` (same rows ordered
# by number of labels, largest first).

projects = [p.root_dir.name for p in test_projects]
acc_metric = AccuracyMetric(common_type_names)

strategies_to_show = ["non-incr", "random", "double-traversal"]
# The first listed strategy supplies the per-project label counts and sizes.
ref_strategy = strategies_to_show[0]

# strategy name -> list (aligned with `projects`) of that strategy's accuracy
# entries under `acc_metric`.
strategy2accs = {
    s: [
        evals[s].error_analysis(pname, acc_metric).accuracies[acc_metric.name]
        for pname in projects
    ]
    for s in strategies_to_show
}

# Total n_annots() over all elements of each project.
n_annots = [
    sum(e.get_signature().n_annots() for e in p.all_elems()) for p in test_projects
]
# n_labels = [
#     sum(e.get_signature().n_annotated() for e in p.all_elems()) for p in test_projects
# ]
# Number of evaluated labels per project, taken from the reference strategy.
n_labels = [x.n_total for x in strategy2accs[ref_strategy]]
label_sizes = [
    evals[ref_strategy]
    .error_analysis(pname, acc_metric)
    .accuracies[f"{acc_metric.name}_label_size"]
    for pname in projects
]

# Labels per annotation slot. A project without any annotation slot would make
# the plain division raise ZeroDivisionError; report NaN for it instead.
label_rates = [
    n_label / n_annot if n_annot else float("nan")
    for n_label, n_annot in zip(n_labels, n_annots)
]

df = pd.DataFrame(
    {
        "project": projects,
        **{n: [x.acc for x in xs] for n, xs in strategy2accs.items()},
        "label_size": label_sizes,
        "label_rate": label_rates,
        "labels": n_labels,
    }
)
df_sorted = df.sort_values(by=["labels"], ascending=False)
display(df_sorted)


Unnamed: 0,project,non-incr,random,double-traversal,label_size,label_rate,labels
0,SPOT,0.609185,0.622184,0.636915,1.540728,0.682436,1154
1,typilus,0.757455,0.760437,0.817097,1.883698,0.47882,1006
2,type4py,0.738434,0.741993,0.701068,1.224199,0.313441,562


In [9]:
# Save the per-project accuracy table under <project root>/data.
# NOTE(review): assumes proj_root() returns a pathlib.Path, as the `/` joins suggest.
accs_csv_path = proj_root() / "data" / "accs_by_project.csv"
# DataFrame.to_csv does not create missing directories, so make sure the
# target folder exists before writing (no-op when it is already there).
accs_csv_path.parent.mkdir(parents=True, exist_ok=True)
df_sorted.to_csv(accs_csv_path)


In [8]:
# Select the example project by directory name instead of by list position:
# test_projects follows the order of Path.iterdir(), which is arbitrary, so a
# fixed index may point at a different project on another machine or checkout.
EX_PROJ_NAME = "type4py"
ex_proj = next((p for p in test_projects if p.root_dir.name == EX_PROJ_NAME), None)
if ex_proj is None:
    raise ValueError(
        f"No test project named {EX_PROJ_NAME!r}; "
        f"available: {[p.root_dir.name for p in test_projects]}"
    )
print("Project location:", ex_proj.root_dir)

# Inspect the errors this project gets under the double-traversal order.
evalr = evals["double-traversal"]
# Position of the example project inside the evaluation result.
pid = evalr.project_roots.index(ex_proj.root_dir)
ex_analysis = evalr.error_analysis(pid, acc_metric)
ex_errors = ex_analysis.errors[ex_proj.root_dir.name]
ex_rollout = evalr.predictions[pid]

print("Number of errors:", len(ex_errors))


Project location: /mnt/data0/jiayi/SPOT/datasets/InferTypes4Py/repos/test/type4py
Number of errors: 168


In [9]:
# Tally, over all errors of the example project, how often each processed
# type shows up as the expected type ("under-predicted") and as the predicted
# type ("over-predicted").
under_pred_types, over_predicted_types = Counter(), Counter()
for pred in ex_errors:
    under_pred_types.update([acc_metric.process_type(pred.expected)])
    over_predicted_types.update([acc_metric.process_type(pred.predicted)])

# Show the ten most frequent entries of each tally.
for heading, tally in (
    ("Under-predicted types:", under_pred_types),
    ("Over-predicted types:", over_predicted_types),
):
    print(heading)
    display(tally.most_common(10))


Under-predicted types:


[(ty'pd.DataFrame', 18),
 (ty'np.array', 14),
 (ty'Dict', 10),
 (ty'List', 7),
 (ty'Set', 7),
 (ty'cst.Name', 6),
 (ty'str', 6),
 (ty'cst.ClassDef', 5),
 (ty'cst.FunctionDef', 5),
 (ty'List[List]', 4)]

Over-predicted types:


[(ty'Dict', 20),
 (ty'str', 15),
 (ty'List[str]', 15),
 (ty'cst.CSTNode', 14),
 (ty'Union[cst.Module, str]', 11),
 (ty'Path', 8),
 (ty'np.ndarray', 7),
 (ty'Dict[str, str]', 6),
 (ty'List', 6),
 (ty'Union[cst.Expr, str]', 3)]

In [13]:
from spot.static_analysis import ProjectPath

# Index into ex_errors of the single prediction error to look at.
error_id = 35
pred = ex_errors[error_id]

# Short header identifying the element, followed by the detailed view that
# evalr.inspect_elem produces for it.
print(
    "----------------------",
    f"Elem Path: {pred.path!s}",
    f"Type Index: {pred.index!s}",
    sep="\n",
)
evalr.inspect_elem(pid, pred.path)
# evalr.inspect_elem(pid, ProjectPath.from_str("tests.factories/AllScanCommandsAttemptsFactory.create"))


----------------------
Elem Path: libsa4py.cst_transformers/TypeApplier.__get_cls_vars
Type Index: 1
Expected:  MethodSig((var_name: str) -> dict)
Predicted: MethodSig((var_name: str) -> str)
Code:
from collections import Counter
from libsa4py import PY_TYPING_MOD, PY_COLLECTION_MOD
from itertools import chain
from libsa4py.nl_preprocessing import NLPreprocessor
import libcst as cst
import libcst.matchers as match
import re
import regex
from typing import Union, Dict, Tuple, List, Optional
class CommentAndDocStringRemover(cst.CSTTransformer):
   ...
class StringRemover(cst.CSTTransformer):
   ...
class NumberRemover(cst.CSTTransformer):
   ...
class TypeAnnotationRemover(cst.CSTTransformer):
   ...
class TypeAdder(cst.CSTTransformer):
   ...
class SpaceAdder(cst.CSTTransformer):
   ...
class TypeQualifierResolver(cst.CSTTransformer):
   ...
class ParametricTypeDepthReducer(cst.CSTTransformer):
   ...
class TypeApplier(cst.CSTTransformer):
   ...
# libsa4py.cst_transformers
class TypeAp