In [1]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

# Switch the working directory to whatever `proj_root()` returns, so the
# relative paths used in this notebook are resolved from there.
os.chdir(proj_root())

# Preferred CUDA device index. It is only honored when that device is actually
# visible to this process (see the device selection below).
gpu_id = 1
modeldir = get_model_dir()

# model_name="model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
model_name="model-v5--TrainingConfig(drop_env_types=False)"
wrapper = ModelWrapper.from_pretrained(
    modeldir / model_name
)

# Choose the device defensively: `torch.cuda.is_available()` only says that
# *some* GPU exists, so a hard-coded index of 1 would fail with an invalid
# device ordinal on a single-GPU host. Fall back to GPU 0 in that case, and to
# the CPU when CUDA is not available at all.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    n_visible_gpus = torch.cuda.device_count()
    if gpu_id >= n_visible_gpus:
        print(
            f"GPU {gpu_id} is not available ({n_visible_gpus} visible); "
            "falling back to GPU 0"
        )
        gpu_id = 0
    device = torch.device(f"cuda:{gpu_id}")
wrapper.to(device)
print(f"Model loaded to {device}")


Model loaded to cuda:1


In [2]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import *
from spot.function_dataset import data_project_from_dir
from spot.visualization import show_code_range

# ex_code = read_file("src/spot/function_decoding.py")
# ex_module = mask_types(cst.parse_module(ex_code))
# ex_project = PythonProject.from_modules([PythonModule.from_cst(ex_module, "spot.function_decoding")])

# File names of a selection of this project's own source modules.
src_set = {
    "function_decoding.py", "data.py", "type_env.py",
    "function_dataset.py", "tokenized_src.py", "type_check.py",
    "model.py", "utils.py", "static_analysis.py",
}

# ex_project = data_project_from_dir(proj_root(), file_filter=lambda f: f.name in {"model.py", "utils.py"})

# Example project to decode: one repository from the ManyTypes4Py test split.
# NOTE(review): this is an absolute, machine-specific path.
ex_project_dir = Path(
    "/mnt/data0/jiayi/SPOT/datasets/ManyTypes4Py/repos/test/lucaswerkmeister__tool-quickcategories"
)
ex_project = data_project_from_dir(ex_project_dir)


In [3]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
    AccuracyMetric,
)
from spot.model import DecodingArgs
from spot.visualization import pretty_print_dict

ctx_args = wrapper.args.ctx_args
wrapper.args = DecodingArgs(
    sampling_max_tokens=ctx_args.ctx_size,
    ctx_args=ctx_args,
    do_sample=False,
    num_beams=16,
    tokens_per_type=16,
    # length_penalty=0.2,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.DoubleTraversal(),
)

metric = AccuracyMetric(wrapper.common_type_names)
pretty_print_dict(evalr.error_analysis(None, metric).accuracies)


evaluate_on_projects: 100%|██████████| 1201/1201 [13:03<00:00,  1.53it/s]

acc: 81.24% (count=549)
acc_by_common:
   rare: 79.29% (count=169)
   common: 82.11% (count=380)
acc_by_cat:
   FuncArg: 80.90% (count=199)
   FuncReturn: 81.72% (count=290)
   ClassAtribute: 78.95% (count=57)
   GlobalVar: 100.00% (count=3)
acc_label_size: 1.51
acc_pred_size: 1.2368
acc_ignored_labels: 241
n_skipped_types: 0
n_missing_types: 0





In [4]:
# Find where the example project sits in the evaluation result, then pull out
# its rollout and its per-project error list (keyed by the root directory name).
ex_root = ex_project.root_dir
pid = evalr.project_roots.index(ex_root)
ex_rollout = evalr.predictions[pid]
ex_analysis = evalr.error_analysis(pid, metric)
ex_errors = ex_analysis.errors[ex_root.name]

print("Number of errors:", len(ex_errors))

Number of errors: 115


In [5]:
from spot.static_analysis import ProjectPath

# Index into `ex_errors` of the mistake to look at; change it to browse others.
error_id = 0
pred = ex_errors[error_id]

print("----------------------")
for _label, _value in (("Elem Path:", pred.path), ("Type Index:", pred.index)):
    print(_label, _value)

# Kept as the final statement of the cell so its output follows the header above.
evalr.inspect_elem(pid, pred.path)


----------------------
Elem Path: command/Command.apply
Type Index: 2
Expected:  MethodSig((wikitext: str, category_info: CategoryInfo) -> Tuple[str, List[Tuple[Action, bool]]])
Predicted: MethodSig((wikitext: str, category_info: CategoryInfo) -> Tuple[str, List[Action]])
Code:
from dataclasses import dataclass
import datetime
from action import Action
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
from page import Page
from siteinfo import CategoryInfo
@dataclass(...)
class Command:
   ...
@dataclass(...)
class CommandRecord(ABC):
   ...
class CommandPlan(CommandRecord):
   ...
class CommandPending(CommandRecord):
   ...
class CommandFinish(CommandRecord):
   ...
class CommandSuccess(CommandFinish):
   ...
@dataclass(...)
class CommandEdit(CommandSuccess):
   ...
@dataclass(...)
class CommandNoop(CommandSuccess):
   ...
class CommandFailure(CommandFinish):
   ...
@dataclass(...)
class CommandPageMissing(CommandFailure):
   ...
@dataclass(...)
class