In [2]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root, get_data_dir

os.chdir(proj_root())

gpu_id = 1
datadir = get_data_dir()
modeldir = get_model_dir()

model_name="model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
wrapper = ModelWrapper.from_pretrained(
    modeldir / f"checkpoints/lit-saved/{model_name}"
)
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
wrapper.to(device)
print(f"Model loaded to {device}")


Model loaded to cuda:1


In [16]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import cst, read_file, SpecialNames

# To analyze a single module instead of the whole repository, build the project
# from one masked module, e.g.
#   PythonProject.from_modules([PythonModule.from_cst(
#       mask_types(cst.parse_module(read_file("src/spot/function_decoding.py"))),
#       "spot.function_decoding")])

# Source files of interest. To restrict the analysis to them, pass
# `file_filter=lambda f: f.name in src_set` to `PythonProject.from_root`
# (or e.g. `lambda f: "dagger" not in f.name and "critic" not in f.name`
# to skip files whose names contain those words).
src_set = {
    "function_decoding.py",
    "data.py",
    "type_env.py",
    "function_dataset.py",
    "tokenized_src.py",
    "type_check.py",
    "model.py",
    "utils.py",
    "static_analysis.py",
}


def handle_mask(m: cst.Module) -> cst.Module:
    """Strip comments from `m` and neutralize literal `SPOT_TYPE_MASK` occurrences.

    The module is round-tripped through source text. Comments are removed, every
    `SPOT_TYPE_MASK` substring is replaced by `Not_A_Mask`, and the result is
    re-parsed into a new module.

    NOTE(review): presumably this keeps the analyzed repository's own references
    to the mask token from being mistaken for masked type annotations. Confirm.
    """
    new_code = remove_comments(m).code.replace("SPOT_TYPE_MASK", "Not_A_Mask")
    return cst.parse_module(new_code)


# Backward-compatible alias for the original (misspelled) name.
handle_msak = handle_mask

ex_project = PythonProject.from_root(
    proj_root(),
    src_transform=handle_mask,
)


In [26]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)
from spot.model import DecodingArgs
from spot.visualization import pretty_print_dict

ctx_args = wrapper.args.ctx_args
wrapper.args = DecodingArgs(
    sampling_max_tokens=ctx_args.ctx_size,
    ctx_args=ctx_args,
    do_sample=False,
    num_beams=16,
    tokens_per_type=16,
    length_penalty=0.2,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.random_order, common_type_names=set()
)

pretty_print_dict(evalr.accuracies)

evaluate_on_projects: 100%|██████████| 898/898 [09:25<00:00,  1.59it/s]


partial_acc: 68.20% (count=1.1k)
full_acc: 62.97% (count=1.1k)
full_acc_by_common:
   rare: 62.97% (count=1.1k)
full_acc_by_cat:
   FuncArg: 57.41% (count=641)
   FuncReturn: 71.14% (count=246)
   ClassAtribute: 71.09% (count=211)
   GlobalVar: 50.00% (count=12)
full_acc_by_pos:
   range(0, 1): 60.56% (count=568)
   range(1, 2): 63.42% (count=257)
   range(2, 4): 69.31% (count=202)
   range(4, 8): 67.12% (count=73)
   range(8, 16): 30.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.1


In [None]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

with ProcessPoolExecutor(20) as cpu_executor, ThreadPoolExecutor(1) as model_executor:
    result = await rctx.project_rollout(
        ex_project,
        pre_args,
        decode_order=DecodingOrders.caller2callee,
        cpu_executor=cpu_executor,
        model_executor=model_executor,
        progress_cbk=lambda e, p: print(e.path, len(p)),
    )
    print()


spot.function_decoding/DecodingOrders.caller2callee 4
spot.function_decoding/DecodingOrders.random_order 2
spot.function_decoding/RolloutCtx.project_rollout 8
spot.function_decoding/RolloutCtx.model 1
spot.function_decoding/RolloutPrediction.elem2inputs 1
spot.function_decoding/RolloutPrediction.elem2preds 1
spot.function_decoding/RolloutPrediction.assignments 1
spot.function_decoding/construct_model_inputs 7
spot.function_decoding/is_mask_annot 2



In [None]:
from spot.utils import decode_tokens

for elem, preds in result.elem2preds.items():
    # Only show elements whose module name contains "function_decoding".
    if "function_decoding" not in elem.module:
        continue
    print("-----------------------------------------------------------")
    print(f"{elem.path}: {list(map(str, preds))}")
    print("model inputs:")
    # Decode the first row of the stored `input_ids` back to text.
    print(decode_tokens(result.elem2inputs[elem]["input_ids"][0]))
