In [1]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

# Make the project root (as reported by `proj_root()`) the working directory.
os.chdir(proj_root())

# gpu_id = 1
# modeldir = get_model_dir()

# Names of the dataset and of the model whose results are analysed below.
dataset_name = "ManyTypes4Py"
model_name = "model-v5--TrainingConfig(drop_env_types=False)"

In [2]:
from spot.utils import *
from spot.function_dataset import data_project_from_dir

# Load every test project: each sub-directory of `<dataset>/repos/test` is
# treated as one project and handed to `data_project_from_dir` via `pmap`.
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = [entry for entry in repos_dir.iterdir() if entry.is_dir()]
test_projects = pmap(
    data_project_from_dir, test_repo_paths, desc="Loading test projects"
)
# Fail early if the test split is empty instead of producing empty tables later.
assert len(test_projects) > 0


Loading test projects: 100%|██████████| 50/50 [00:20<00:00,  2.43it/s]


In [3]:
from spot.function_decoding import DecodingOrders, EvalResult

# Directory holding the pickled evaluation results for this dataset/model pair.
results_dir = get_eval_dir(dataset_name, model_name)

# Decoding strategies, keyed by the name used in the result file names.
# Only the keys are used in this cell.
decode_orders = {
    "no-neighbors": DecodingOrders.IndependentOrder(),
    "non-incr": DecodingOrders.IndependentOrder(),
    "random": DecodingOrders.RandomOrder(),
    "double-traversal": DecodingOrders.DoubleTraversal(),
    "callee2caller": DecodingOrders.Callee2Caller(),
    "caller2callee": DecodingOrders.Caller2Callee(),
}

# Strategy name -> result object loaded from "<strategy>-EvalResult.pkl".
evals: dict[str, EvalResult] = {}

for oname in tqdm(decode_orders.keys(), desc="Loading evaluation results"):
    evals[oname] = pickle_load(results_dir / (oname + "-EvalResult.pkl"))


Loading evaluation results: 100%|██████████| 6/6 [00:13<00:00,  2.24s/it]


In [12]:
import pandas as pd
from spot.type_env import AccuracyMetric

# Build a per-project accuracy table for a few decoding strategies.
projects = [p.root_dir.name for p in test_projects]
common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)

strategies_to_show = ["no-neighbors", "non-incr", "double-traversal"]
acc_metric = AccuracyMetric(common_type_names)
# Strategy name -> one accuracy entry per project, in the same order as `projects`.
strategy2accs = {
    s: [
        evals[s].error_analysis(pname, acc_metric).accuracies[acc_metric.name]
        for pname in projects
    ]
    for s in strategies_to_show
}

# Per project: the sum of `n_annots()` over the signatures of all its elements.
n_annots = [
    sum(e.get_signature().n_annots() for e in p.all_elems()) for p in test_projects
]
# n_labels = [
#     sum(e.get_signature().n_annotated() for e in p.all_elems()) for p in test_projects
# ]
# Per project: `n_total` of the accuracy entry, taken from the first displayed
# strategy (presumably identical across strategies -- TODO confirm).
n_labels = [x.n_total for x in strategy2accs[strategies_to_show[0]]]

# Ratio of labels to annotations. A project with zero annotations would raise
# ZeroDivisionError and abort the whole table, so report NaN for it instead.
label_rates = [
    n_label / n_annot if n_annot else float("nan")
    for n_label, n_annot in zip(n_labels, n_annots)
]

df = pd.DataFrame(
    {
        "project": projects,
        **{n: [x.acc for x in xs] for n, xs in strategy2accs.items()},
        "label_rate": label_rates,
        "labels": n_labels,
    }
)
# Show the projects with the most labels first.
df.sort_values(by=["labels"], ascending=False)


Unnamed: 0,project,no-neighbors,non-incr,double-traversal,label_rate,labels
17,basilisp-lang__basilisp,0.614854,0.648702,0.646073,0.526926,3043
29,nabla-c0d3__sslyze,0.752391,0.788523,0.797024,0.568924,941
27,kornicameister__axion,0.517385,0.536857,0.495132,0.630149,719
1,scalableminds__webknossos-connect,0.694888,0.769968,0.79393,0.830239,626
30,rakitaj__daily-programmer,0.838889,0.866667,0.883333,0.500928,540
42,seattleflu__id3c,0.757874,0.860236,0.879921,0.442509,508
13,marcosschroh__dataclasses-avroschema,0.544419,0.528474,0.510251,0.427875,439
25,nubark__instark,0.829457,0.868217,0.873385,0.346774,387
48,everyclass__everyclass-server,0.735736,0.804805,0.762763,0.3729,333
35,lucaswerkmeister__tool-quickcategories,0.793548,0.806452,0.796774,0.658174,310


In [13]:
# Write the per-project accuracy table to `data/accs_by_project.csv` under the
# project root.
accs_csv_path = proj_root() / "data" / "accs_by_project.csv"
df.to_csv(accs_csv_path)


In [35]:
# Select the example project by name rather than by list position.
# `test_projects` is built from `Path.iterdir()`, whose order is arbitrary, so
# a fixed index such as 27 may refer to a different project on another
# machine or run.
ex_proj_name = "kornicameister__axion"
ex_proj_matches = [p for p in test_projects if p.root_dir.name == ex_proj_name]
assert len(ex_proj_matches) == 1, (
    f"expected exactly one test project named {ex_proj_name!r}, "
    f"found {len(ex_proj_matches)}"
)
ex_proj = ex_proj_matches[0]
print("Project location:", ex_proj.root_dir)

# All lookups below use the "double-traversal" evaluation results.
ex_eval = evals["double-traversal"]
# Position of the project inside the evaluation result. It is used to index
# the per-project data, and may differ from its position in `test_projects`.
pid = ex_eval.project_roots.index(ex_proj.root_dir)
ex_analysis = ex_eval.error_analysis(pid, acc_metric)
ex_errors = ex_analysis.errors[ex_proj.root_dir.name]
ex_rollout = ex_eval.predictions[pid]

print("Number of errors:", len(ex_errors))


Project location: /mnt/data0/jiayi/SPOT/datasets/ManyTypes4Py/repos/test/kornicameister__axion
Number of errors: 363


In [38]:
from spot.type_env import PythonType, normalize_type, remove_top_optional

# Scratch check: parse a type string, normalise it, then apply
# `remove_top_optional`; the cell displays the result.
remove_top_optional(
    normalize_type(
        PythonType.from_str("t.Optional[t.Dict[str, t.Any]]")
    )
)

ty't.Dict[t.Any]'

In [36]:
# Count how often each type appears among the example project's errors:
# by expected type ("under-predicted") and by predicted type ("over-predicted").
under_pred_types = Counter(err.expected for err in ex_errors)
over_predicted_types = Counter(err.predicted for err in ex_errors)

# Show the ten most frequent types on each side.
print("Under-predicted types:")
display(under_pred_types.most_common(10))
print("Over-predicted types:")
display(over_predicted_types.most_common(10))

Under-predicted types:


[(ty't.Dict[str, t.Any]', 45),
 (ty'str', 30),
 (ty'logging.LogCaptureFixture', 21),
 (ty't.Type[t.Any]', 18),
 (ty't.Optional[t.Dict[str, t.Any]]', 14),
 (ty't.Dict[str, t.Dict[str, t.Any]]', 11),
 (ty't.Tuple[t.Set[exceptions.Error], model.ParamMapping]', 11),
 (ty'oas.OASOperation', 10),
 (ty'ptm.MockFixture', 10),
 (ty'pipeline.Response', 8)]

Over-predicted types:


[(ty't.Any', 66),
 (ty't.Mapping[str, t.Any]', 34),
 (ty'logging.Logger', 20),
 (ty't.Mapping[str, t.Mapping[str, t.Any]]', 17),
 (ty'int', 15),
 (ty't.Tuple[t.Set[str], t.Dict[str, t.Any]]', 10),
 (ty't.Dict[str, t.Any]', 9),
 (ty'ptm.Mock', 9),
 (ty'model.Parameter', 8),
 (ty't.Mapping[str, str]', 7)]

In [21]:
from spot.static_analysis import ProjectPath

# Inspect one entry of the example project's error list, chosen by position.
error_id = 100
pred = ex_errors[error_id]
print(
    "----------------------",
    f"Elem Path: {pred.path!s}",
    f"Type Index: {pred.index!s}",
    sep="\n",
)
# Let the evaluation result print the details for this element.
evals["double-traversal"].inspect_elem(pid, pred.path)


----------------------
Elem Path: tests.schemas.conftest/UnionSchema.lake_trip
Type Index: 0
Expected:  AttrSig(cst'typing.Union[Bus, Car]')
Predicted: AttrSig(cst'...')
Code:
import datetime
import json
import os
import enum
import typing
import uuid
import pytest
from pydantic import Field
from dataclasses_avroschema.avrodantic import AvroBaseModel


# Used:

# Target:
# tests.schemas.conftest
class UnionSchema(AvroBaseModel):
    lake_trip: <extra_id_0> = Field(default_factory=lambda: Bus(engine_name="honda"))

# Users:
# tests.schemas.test_pydantic
def test_pydantic_record_schema_with_unions_type(union_type_schema: typing.Mapping[str, UnionSchema]) -> None:
    class Bus(AvroBaseModel):
        engine_name:...

        class Meta:
            namespace = "types.bus_type"

    class Car(AvroBaseModel):
        engine_name:...

        class Meta:
            namespace = "types.car_type"

    class TripDistance(enum.Enum):
        CLOSE = "Close"
        FAR = "Far"

        class Me