In [1]:
%load_ext autoreload
%autoreload 2

# Setup: imports and configuration (which model's evaluation results to analyse, on which dataset)
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

# Work from the repository root so that relative paths resolve against it.
os.chdir(proj_root())

# The trained model whose evaluation results are analysed, and the dataset it was evaluated on.
model_name = "model-v6--TrainingConfig(drop_env_types=False)"
dataset_name = "ManyTypes4Py"  # alternative: "SPOT-src"

In [2]:
from spot.utils import *
from spot.function_dataset import data_project_from_dir

# Each sub-directory of the dataset's `repos/test` folder is one test project;
# map `data_project_from_dir` over all of them with `pmap` (progress labelled by `desc`).
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = list(filter(lambda entry: entry.is_dir(), repos_dir.iterdir()))
test_projects = pmap(data_project_from_dir, test_repo_paths, desc="Loading test projects")
assert len(test_projects) > 0


Loading test projects: 100%|██████████| 50/50 [00:19<00:00,  2.53it/s]


In [4]:
from spot.function_decoding import DecodingOrders, EvalResult
import re

results_dir = get_eval_dir(dataset_name, model_name)

# Every decoding order has one result file named "<order>-EvalResult.pkl".
# `fullmatch` (not `match`) is required: `match` only anchors at the start, so a
# stray file such as "x-EvalResult.pkl.bak" would otherwise yield the order "x".
eval_file_re = re.compile(r"(.+)-EvalResult\.pkl")
decode_orders = sorted(
    m.group(1)
    for f in results_dir.iterdir()
    if (m := eval_file_re.fullmatch(f.name)) is not None
)
# Fail loudly instead of silently producing empty comparison tables below.
assert len(decode_orders) > 0, f"No *-EvalResult.pkl files found in {results_dir}"

# Decoding order name -> its evaluation result.
evals = dict[str, EvalResult]()

for oname in tqdm(decode_orders, desc="Loading evaluation results"):
    evals[oname] = pickle_load(results_dir / f"{oname}-EvalResult.pkl")


Loading evaluation results: 100%|██████████| 3/3 [00:08<00:00,  2.70s/it]


In [5]:
from prettytable import PrettyTable
import prettytable as pt
from spot.type_env import AccuracyMetric

# Overall accuracy (all projects, hence `None`) of every decoding order under each
# of the default accuracy metrics, rendered as a right-aligned text table.
common_type_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
metrics = AccuracyMetric.default_metrics(common_type_names)

results_table = PrettyTable()
results_table.field_names = ["order"] + [metric.name for metric in metrics]
results_table.align = "r"
results_table.set_style(pt.SINGLE_BORDER)
results_table.float_format = ".4"

for order_name in decode_orders:
    order_result = evals[order_name]
    table_row = [order_name]
    for metric in metrics:
        table_row.append(order_result.error_analysis(None, metric).accuracies[metric.name].acc)
    results_table.add_row(table_row)

print(results_table)

┌──────────────────┬───────────┬────────┬──────────┐
│            order │ plain_acc │    acc │ base_acc │
├──────────────────┼───────────┼────────┼──────────┤
│ double-traversal │    0.7158 │ 0.7092 │   0.7717 │
│         non-incr │    0.7022 │ 0.6951 │   0.7583 │
│     random-twice │    0.6849 │ 0.6756 │   0.7569 │
└──────────────────┴───────────┴────────┴──────────┘


In [7]:

# Per-project accuracy of each decoding strategy, together with per-project label statistics.
projects = [p.root_dir.name for p in test_projects]
acc_metric = AccuracyMetric(common_type_names)

strategies_to_show = ["non-incr", "random-twice", "double-traversal"]
# Run every (strategy, project) error analysis exactly once and reuse it; the label
# statistics below are read from the same analyses as the accuracies (previously the
# reference strategy was analysed twice per project).
strategy2analyses = {
    s: [evals[s].error_analysis(pname, acc_metric) for pname in projects]
    for s in strategies_to_show
}
strategy2accs = {
    s: [analysis.accuracies[acc_metric.name] for analysis in analyses]
    for s, analyses in strategy2analyses.items()
}

# Per project: sum of `n_annots()` over the signatures of all its elements
# (presumably the number of annotatable positions — TODO confirm).
n_annots = [
    sum(e.get_signature().n_annots() for e in p.all_elems()) for p in test_projects
]
# Label statistics are taken from the first (reference) strategy: its `n_total`
# is used as the project's label count.
ref_strategy = strategies_to_show[0]
n_labels = [x.n_total for x in strategy2accs[ref_strategy]]
label_sizes = [
    analysis.accuracies[f"{acc_metric.name}_label_size"]
    for analysis in strategy2analyses[ref_strategy]
]

# Labels per annotatable position; NaN (instead of ZeroDivisionError) for a
# project with no annotatable positions.
label_rates = [a / b if b else float("nan") for a, b in zip(n_labels, n_annots)]

df = pd.DataFrame(
    {
        "project": projects,
        **{n: [x.acc for x in xs] for n, xs in strategy2accs.items()},
        "label_size": label_sizes,
        "label_rate": label_rates,
        "labels": n_labels,
    }
)
# Largest projects (by label count) first.
df_sorted = df.sort_values(by=["labels"], ascending=False)
df_sorted


Unnamed: 0,project,non-incr,random-twice,double-traversal,label_size,label_rate,labels
17,basilisp-lang__basilisp,0.60871,0.492796,0.617551,1.242633,0.525826,3054
29,nabla-c0d3__sslyze,0.777191,0.761352,0.769799,1.269271,0.566049,947
27,kornicameister__axion,0.488473,0.465418,0.481268,1.644092,0.625789,694
1,scalableminds__webknossos-connect,0.723147,0.736762,0.760968,1.242057,0.830402,661
35,lucaswerkmeister__tool-quickcategories,0.750455,0.772313,0.790528,1.510018,0.605292,549
42,seattleflu__id3c,0.709324,0.698355,0.716636,1.20841,0.449466,547
30,rakitaj__daily-programmer,0.738889,0.75,0.766667,1.527778,0.500928,540
25,nubark__instark,0.887728,0.890339,0.877285,1.172324,0.364762,383
16,brettkromkamp__topic-db,0.72619,0.693452,0.723214,1.145833,0.59893,336
48,everyclass__everyclass-server,0.777778,0.81982,0.783784,1.201201,0.3729,333


In [8]:
# Unsorted per-project table, in `test_projects` order (the index is each project's position in that list).
df


Unnamed: 0,project,non-incr,random-twice,double-traversal,label_size,label_rate,labels
0,srittau__FakeSMTPd,0.87069,0.922414,0.931034,1.232759,0.348348,116
1,scalableminds__webknossos-connect,0.723147,0.736762,0.760968,1.242057,0.830402,661
2,flopp__unicode-explorer,0.777778,0.863248,0.863248,1.145299,0.576355,117
3,eirannejad__calcatime,0.90625,0.90625,0.90625,1.9375,0.8,32
4,jelford__webwatcher,0.765957,0.744681,0.744681,1.319149,0.218605,47
5,road-master__video-archiver,0.861789,0.861789,0.869919,1.162602,0.416949,123
6,flopp__GpxTrackPoster,0.686347,0.701107,0.682657,1.247232,0.654589,271
7,boompig__book-classics,0.929412,0.929412,0.929412,1.270588,0.586207,85
8,dropbox__sqlalchemy-stubs,0.690476,0.690476,0.690476,1.261905,0.608696,42
9,reddit__baseplate.py-upgrader,0.829268,0.829268,0.847561,1.25,0.438503,164


In [9]:
# Persist the per-project table (sorted by label count). The CSV keeps the DataFrame
# index, i.e. each project's position in `test_projects`.
accs_csv_path = proj_root() / "data" / "accs_by_project.csv"
# `to_csv` does not create missing directories; make sure `data/` exists first.
accs_csv_path.parent.mkdir(parents=True, exist_ok=True)
df_sorted.to_csv(accs_csv_path)


In [10]:
# Pick one project for detailed error inspection. Select it by name: `test_projects`
# follows `Path.iterdir()` order, which is filesystem-dependent, so a positional
# index (previously `test_projects[29]`) is not stable across machines.
ex_proj_name = "nabla-c0d3__sslyze"
ex_proj = next((p for p in test_projects if p.root_dir.name == ex_proj_name), None)
assert ex_proj is not None, f"Test project {ex_proj_name!r} not found"
print("Project location:", ex_proj.root_dir)

# Errors and predictions of the "double-traversal" decoding order for this project.
evalr = evals["double-traversal"]
pid = evalr.project_roots.index(ex_proj.root_dir)
ex_analysis = evalr.error_analysis(pid, acc_metric)
ex_errors = ex_analysis.errors[ex_proj.root_dir.name]
ex_rollout = evalr.predictions[pid]

print("Number of errors:", len(ex_errors))


Project location: /mnt/data0/jiayi/SPOT/datasets/ManyTypes4Py/repos/test/nabla-c0d3__sslyze
Number of errors: 218


In [32]:
from spot.utils import *

# Sanity check: token ids produced for consecutive "<mask>" strings followed by a digit,
# without special tokens (recorded output: [4, 4, 21], i.e. each "<mask>" encodes to one id).
DefaultTokenizer.encode("<mask><mask>1", add_special_tokens=False)

[4, 4, 21]

In [11]:
# Over the example project's errors, count (after `acc_metric.process_type` normalisation)
# the expected types that were missed ("under-predicted") and the types that were
# predicted instead ("over-predicted"); show the 10 most frequent of each.
under_pred_types = Counter(acc_metric.process_type(err.expected) for err in ex_errors)
over_predicted_types = Counter(acc_metric.process_type(err.predicted) for err in ex_errors)

for heading, type_counts in [
    ("Under-predicted types:", under_pred_types),
    ("Over-predicted types:", over_predicted_types),
]:
    print(heading)
    display(type_counts.most_common(10))


Under-predicted types:


[(ty'ScanCommandExtraArgument', 15),
 (ty'str', 13),
 (ty'int', 12),
 (ty'_MozillaTlsConfigurationAsJson', 6),
 (ty'CipherSuitesScanAttemptAsJson', 6),
 (ty'CipherSuitesScanAttempt', 6),
 (ty'_Base64EncodedBytes', 5),
 (ty'AllScanCommandsAttempts', 4),
 (ty'ScanCommandResult', 4),
 (ty'ClassVar[str]', 4)]

Over-predicted types:


[(ty'str', 31),
 (ty'Any', 21),
 (ty'ScanCommandsExtraArguments', 19),
 (ty'...', 6),
 (ty'ServerScanResult', 5),
 (ty'ScanCommandResult', 5),
 (ty'float', 4),
 (ty'Dict[str, str]', 4),
 (ty'int', 4),
 (ty'bool', 4)]

In [29]:
from spot.static_analysis import ProjectPath

# Examine one error of the example project: its element path, which type slot it
# concerns, and the expected vs. predicted signature together with code context.
error_id = 100
pred = ex_errors[error_id]
print("----------------------")
for label, value in [("Elem Path:", pred.path), ("Type Index:", pred.index)]:
    print(label, value)
evalr.inspect_elem(pid, pred.path)
# To inspect an arbitrary element instead, pass an explicit path, e.g.:
# evalr.inspect_elem(pid, ProjectPath.from_str("tests.factories/AllScanCommandsAttemptsFactory.create"))


----------------------
Elem Path: sslyze.plugins.openssl_cipher_suites.json_output/CipherSuitesScanResultAsJson.from_orm
Type Index: 2
Expected:  MethodSig((cls, scan_result: CipherSuitesScanResult) -> "CipherSuitesScanResultAsJson")
Predicted: MethodSig((cls: Any, scan_result: CipherSuitesScanResult) -> Any)
Code:
import pydantic
from base64 import b64encode
from typing import List, Optional
from nassl.ephemeral_key_info import EphemeralKeyInfo, EcDhEphemeralKeyInfo, NistEcDhKeyExchangeInfo, DhEphemeralKeyInfo
from sslyze.json.scan_attempt_json import ScanCommandAttemptAsJson
from sslyze.plugins.openssl_cipher_suites.implementation import (
    CipherSuitesScanResult,
    CipherSuiteAcceptedByServer,
)
class _BaseModelWithOrmMode(pydantic.BaseModel):
   ...
class _CipherSuiteAsJson(_BaseModelWithOrmMode):
   ...
_Base64EncodedBytes = str
class _EphemeralKeyInfoAsJson(_BaseModelWithOrmMode):
   ...
class _CipherSuiteAcceptedByServerAsJson(_BaseModelWithOrmMode):
   ...
class _CipherSui