In [1]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root, get_gpu_id

# Switch the working directory to the project root.
os.chdir(proj_root())

# NOTE(review): the recorded output ("GPU_ID not set, using: 1") suggests the
# argument is the fallback GPU index -- confirm against spot.utils.get_gpu_id.
gpu_id = get_gpu_id(1)
modeldir = get_model_dir()

# Alternative checkpoint name (disabled):
# model_name="model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
model_name = "model-v6--TrainingConfig(drop_env_types=False)"
checkpoint_dir = modeldir / model_name
wrapper = ModelWrapper.from_pretrained(checkpoint_dir)

# Use the selected GPU when CUDA is available; otherwise fall back to the CPU.
device_spec = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
device = torch.device(device_spec)
wrapper.to(device)
print(f"Model loaded to {device}")


GPU_ID not set, using: 1
Model loaded to cuda:1


In [2]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import *
from spot.function_dataset import data_project_from_dir
from spot.visualization import show_code_range


# Directory of the example repository to load (machine-specific absolute path).
ex_project_dir = Path("/mnt/data0/jiayi/SPOT/datasets/ManyTypes4Py/repos/test/nabla-c0d3__sslyze")
ex_project = data_project_from_dir(ex_project_dir)


In [3]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
    AccuracyMetric,
)
from spot.model import DecodingArgs
from spot.visualization import pretty_print_dict

ctx_args = wrapper.args.ctx_args
wrapper.args = DecodingArgs(
    sampling_max_tokens=ctx_args.ctx_size,
    ctx_args=ctx_args,
    do_sample=False,
    num_beams=10,
    tokens_per_type=16,
    # length_penalty=0.2,
)


rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.DoubleTraversal(),
)

metric = AccuracyMetric(wrapper.common_type_names)
pretty_print_dict(evalr.error_analysis(None, metric).accuracies)


evaluate_on_projects: 100%|██████████| 2359/2359 [24:52<00:00,  1.58it/s]

acc: 75.82% (count=947)
acc_by_common:
   rare: 67.86% (count=560)
   common: 87.34% (count=387)
acc_by_cat:
   FuncArg: 80.17% (count=343)
   FuncReturn: 73.60% (count=250)
   ClassAtribute: 73.35% (count=349)
   GlobalVar: 60.00% (count=5)
acc_label_size: 1.2693
acc_pred_size: 1.2555
acc_ignored_labels: 62
n_skipped_types: 0
n_missing_types: 0





In [5]:
# Rebuild the metric and print the accuracy breakdown for the current `evalr`.
metric = AccuracyMetric(wrapper.common_type_names)
accuracy_report = evalr.error_analysis(None, metric).accuracies
pretty_print_dict(accuracy_report)

acc: 76.98% (count=947)
acc_by_common:
   rare: 70.36% (count=560)
   common: 86.56% (count=387)
acc_by_cat:
   FuncArg: 81.34% (count=343)
   FuncReturn: 75.60% (count=250)
   ClassAtribute: 73.93% (count=349)
   GlobalVar: 60.00% (count=5)
acc_label_size: 1.2693
acc_pred_size: 1.2777
acc_ignored_labels: 62
n_skipped_types: 0
n_missing_types: 0


In [4]:
# Find the example project's position among the evaluated project roots.
ex_root = ex_project.root_dir
pid = evalr.project_roots.index(ex_root)

# Per-project analysis; the errors mapping is keyed by the root directory name.
ex_analysis = evalr.error_analysis(pid, metric)
ex_errors = ex_analysis.errors[ex_root.name]
ex_rollout = evalr.predictions[pid]

print("Number of errors:", len(ex_errors))

Number of errors: 229


In [6]:
from spot.static_analysis import ProjectPath

# Choose which recorded error to inspect (an index into `ex_errors`).
error_id = 3
pred = ex_errors[error_id]

print("----------------------")
for label, value in (("Elem Path:", pred.path), ("Type Index:", pred.index)):
    print(label, value)

# Show the details for the selected element.
evalr.inspect_elem(pid, pred.path)


----------------------
Elem Path: tests.scanner_tests.test_mass_scanner/TestMassScannerProducerThread.PluginImplThatTriggersConnectivityError._scan_job_work_function
Type Index: 1
Expected:  MethodSig((arg1: str) -> str)
Predicted: MethodSig((arg1: str) -> None)
Code:
import threading
from unittest import mock
from sslyze.plugins.scan_commands import ScanCommandsRepository
from sslyze.scanner._mass_scanner import MassScannerProducerThread, NoMoreServerScanRequestsSentinel
from sslyze.server_connectivity import ServerConnectivityInfo
from tests.factories import ServerScanRequestFactory, ServerTlsProbingResultFactory
from typing import Optional, List
from queue import Queue
from pathlib import Path
from sslyze import (
    ScanCommand,
    ServerConnectivityStatusEnum,
    ServerScanStatusEnum,
    ScanCommandAttemptStatusEnum,
    ScanCommandsExtraArguments,
    CertificateInfoExtraArgument,
    ScanCommandErrorReasonEnum,
)
from sslyze.errors import TlsHandshakeTimedOut
from sslyze.plu

: 