In [7]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

os.chdir(proj_root())

# gpu_id = 1
# modeldir = get_model_dir()

model_name="model-v5--TrainingConfig(drop_env_types=False)"
# dataset_name = "ManyTypes4Py"
dataset_name = "SPOT-src"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from spot.utils import *
from spot.function_dataset import data_project_from_dir

# load test projects
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = [f for f in repos_dir.iterdir() if f.is_dir()]
test_projects = pmap(
    data_project_from_dir,
    test_repo_paths,
    desc="Loading test projects",
)
assert len(test_projects) > 0


Loading test projects: 100%|██████████| 1/1 [00:06<00:00,  6.32s/it]


In [9]:
from spot.function_decoding import DecodingOrders, EvalResult

results_dir = get_eval_dir(dataset_name, model_name)

decode_orders = {
    "no-neighbors": DecodingOrders.IndependentOrder(),
    "non-incr": DecodingOrders.IndependentOrder(),
    "random": DecodingOrders.RandomOrder(),
    "double-traversal": DecodingOrders.DoubleTraversal(),
    "callee2caller": DecodingOrders.Callee2Caller(),
    "caller2callee": DecodingOrders.Caller2Callee(),
}

evals = dict[str, EvalResult]()

for oname in tqdm(decode_orders, desc="Loading evaluation results"):
    evals[oname] = pickle_load(results_dir / f"{oname}-EvalResult.pkl")


Loading evaluation results: 100%|██████████| 6/6 [00:00<00:00,  8.11it/s]


In [10]:
from prettytable import PrettyTable
import prettytable as pt
from spot.type_env import AccuracyMetric

common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
metrics = AccuracyMetric.default_metrics(common_type_names)
results_table = PrettyTable()
results_table.field_names = ["order", *(m.name for m in metrics)]
results_table.align = "r"
results_table.set_style(pt.SINGLE_BORDER)
results_table.float_format = ".4"

for oname in decode_orders:
    accs = [evals[oname].error_analysis(None, metric).accuracies[metric.name].acc for metric in metrics]
    results_table.add_row([oname, *accs])

print(results_table)
write_file(results_dir / "comparison.txt", results_table.get_string())

┌──────────────────┬───────────┬────────┬──────────┐
│            order │ plain_acc │    acc │ base_acc │
├──────────────────┼───────────┼────────┼──────────┤
│     no-neighbors │    0.5046 │ 0.5117 │   0.6017 │
│         non-incr │    0.5905 │ 0.6092 │   0.7048 │
│           random │    0.5932 │ 0.6139 │   0.7151 │
│ double-traversal │    0.6060 │ 0.6326 │   0.7451 │
│    callee2caller │    0.6060 │ 0.6317 │   0.7488 │
│    caller2callee │    0.5850 │ 0.6111 │   0.7095 │
└──────────────────┴───────────┴────────┴──────────┘


In [None]:
import pandas as pd
from spot.type_env import AccuracyMetric


common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
metrics = AccuracyMetric.default_metrics(common_type_names)
strategy2acc = {
    s: evals[s].error_analysis(None, acc_metric).accuracies[acc_metric.name]
    for s in evals
}

In [23]:

projects = [p.root_dir.name for p in test_projects]
acc_metric = AccuracyMetric(common_type_names)

strategies_to_show = ["no-neighbors", "non-incr", "double-traversal"]
strategy2accs = {
    s: [
        evals[s].error_analysis(pname, acc_metric).accuracies[acc_metric.name]
        for pname in projects
    ]
    for s in strategies_to_show
}

n_annots = [
    sum(e.get_signature().n_annots() for e in p.all_elems()) for p in test_projects
]
# n_labels = [
#     sum(e.get_signature().n_annotated() for e in p.all_elems()) for p in test_projects
# ]
n_labels = [x.n_total for x in strategy2accs[strategies_to_show[0]]]
label_sizes = [
    evals[strategies_to_show[0]]
    .error_analysis(pname, acc_metric)
    .accuracies[f"{acc_metric.name}_label_size"]
    for pname in projects
]

label_rates = [a / b for a, b in zip(n_labels, n_annots)]

df = pd.DataFrame(
    {
        "project": projects,
        **{n: [x.acc for x in xs] for n, xs in strategy2accs.items()},
        "label_size": label_sizes,
        "label_rate": label_rates,
        "labels": n_labels,
    }
)
df_sorted = df.sort_values(by=["labels"], ascending=False)
df_sorted


Unnamed: 0,project,no-neighbors,non-incr,double-traversal,label_size,label_rate,labels
17,basilisp-lang__basilisp,0.588235,0.623398,0.619454,1.243181,0.526926,3043
29,nabla-c0d3__sslyze,0.712009,0.752391,0.763018,1.269926,0.568924,941
27,kornicameister__axion,0.49096,0.50904,0.454798,1.624478,0.630149,719
1,scalableminds__webknossos-connect,0.675719,0.750799,0.776358,1.209265,0.830239,626
30,rakitaj__daily-programmer,0.737037,0.742593,0.766667,1.527778,0.500928,540
42,seattleflu__id3c,0.738189,0.832677,0.852362,1.206693,0.442509,508
13,marcosschroh__dataclasses-avroschema,0.505695,0.503417,0.476082,1.473804,0.427875,439
25,nubark__instark,0.816537,0.857881,0.857881,1.180879,0.346774,387
48,everyclass__everyclass-server,0.723724,0.78979,0.747748,1.201201,0.3729,333
35,lucaswerkmeister__tool-quickcategories,0.758065,0.787097,0.764516,1.4,0.658174,310


In [24]:
df


Unnamed: 0,project,no-neighbors,non-incr,double-traversal,label_size,label_rate,labels
0,srittau__FakeSMTPd,0.836207,0.87931,0.896552,1.232759,0.348348,116
1,scalableminds__webknossos-connect,0.675719,0.750799,0.776358,1.209265,0.830239,626
2,flopp__unicode-explorer,0.820513,0.854701,0.923077,1.145299,0.576355,117
3,eirannejad__calcatime,0.46875,0.53125,0.5,1.9375,0.8,32
4,jelford__webwatcher,0.765957,0.765957,0.765957,1.319149,0.218605,47
5,road-master__video-archiver,0.853659,0.910569,0.918699,1.162602,0.416949,123
6,flopp__GpxTrackPoster,0.549815,0.656827,0.715867,1.247232,0.654589,271
7,boompig__book-classics,0.905882,0.941176,0.941176,1.270588,0.586207,85
8,dropbox__sqlalchemy-stubs,0.595238,0.666667,0.642857,1.261905,0.608696,42
9,reddit__baseplate.py-upgrader,0.754839,0.76129,0.812903,1.264516,0.430556,155


In [27]:
df_sorted.to_csv(proj_root() / "data" / "accs_by_project.csv")


In [37]:
ex_proj = test_projects[17]
print("Project location:", ex_proj.root_dir)

evalr = evals["double-traversal"]
pid = evalr.project_roots.index(ex_proj.root_dir)
ex_analysis = evalr.error_analysis(pid, acc_metric)
ex_errors = ex_analysis.errors[ex_proj.root_dir.name]
ex_rollout = evalr.predictions[pid]

print("Number of errors:", len(ex_errors))


Project location: /mnt/data0/jiayi/SPOT/datasets/ManyTypes4Py/repos/test/basilisp-lang__basilisp
Number of errors: 1158


In [38]:
# under predicted types sorted by frequency
under_pred_types = Counter()
over_predicted_types = Counter()
for pred in ex_errors:
    under_pred_types[acc_metric.process_type(pred.expected)] += 1
    over_predicted_types[acc_metric.process_type(pred.predicted)] += 1

print("Under-predicted types:")
display(under_pred_types.most_common(10))
print("Over-predicted types:")
display(over_predicted_types.most_common(10))


Under-predicted types:


[(ty'GeneratorContext', 81),
 (ty'runtime.Namespace', 63),
 (ty'sym.Symbol', 59),
 (ty'AnalyzerContext', 58),
 (ty'ISeq', 50),
 (ty'IPersistentVector[LispForm]', 46),
 (ty'Sequence[kw.Keyword]', 45),
 (ty'IPersistentMap', 34),
 (ty'str', 31),
 (ty'atom.Atom[NamespaceMap]', 17)]

Over-predicted types:


[(ty'vec.PersistentVector', 99),
 (ty'SpecialFormNode', 97),
 (ty'str', 70),
 (ty'ast.Context', 62),
 (ty'int', 52),
 (ty'ast.AST', 44),
 (ty'Node', 37),
 (ty'T', 30),
 (ty'sym.Symbol', 28),
 (ty'NamespaceMap', 18)]

In [36]:
from spot.static_analysis import ProjectPath

error_id = 0
pred = ex_errors[error_id]
print("----------------------")
print("Elem Path:", pred.path)
print("Type Index:", pred.index)
evalr.inspect_elem(pid, pred.path)
# evalr.inspect_elem(pid, ProjectPath.from_str("runner/Runner.run_command"))


----------------------
Elem Path: command/Command.apply
Type Index: 2
Expected:  MethodSig((wikitext: str, category_info: CategoryInfo) -> Tuple[str, List[Tuple[Action, bool]]])
Predicted: MethodSig((wikitext: str, category_info: CategoryInfo) -> Tuple[str, List[Action]])
Code:
import datetime
from page import Page
from siteinfo import CategoryInfo
from abc import ABC, abstractmethod
from dataclasses import dataclass
from action import Action
from typing import List, Optional, Tuple, Union
@dataclass(...)
class Command:
   ...
@dataclass(...)
class CommandRecord(ABC):
   ...
class CommandPlan(CommandRecord):
   ...
class CommandPending(CommandRecord):
   ...
class CommandFinish(CommandRecord):
   ...
class CommandSuccess(CommandFinish):
   ...
@dataclass(...)
class CommandEdit(CommandSuccess):
   ...
@dataclass(...)
class CommandNoop(CommandSuccess):
   ...
class CommandFailure(CommandFinish):
   ...
@dataclass(...)
class CommandPageMissing(CommandFailure):
   ...
@dataclass(...)
class