In [2]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root, get_data_dir

# Switch the working directory to whatever `proj_root()` returns
# (presumably so relative paths resolve from the repository root -- confirm).
os.chdir(proj_root())

# Index of the CUDA device to use when CUDA is available.
gpu_id = 1
modeldir = get_model_dir()

# Name of the saved model directory inside `modeldir`.
model_name = "model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
wrapper = ModelWrapper.from_pretrained(modeldir / model_name)

# Use the selected GPU if CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")

wrapper.to(device)
print(f"Model loaded to {device}")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Model loaded to cuda:1


In [7]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import cst, read_file, SpecialNames
from spot.function_dataset import data_project_from_dir
from spot.visualization import show_code_range

# ex_code = read_file("src/spot/function_decoding.py")
# ex_module = mask_types(cst.parse_module(ex_code))
# ex_project = PythonProject.from_modules([PythonModule.from_cst(ex_module, "spot.function_decoding")])

# File names of interest.
# NOTE(review): `src_set` is not referenced by the file filter below, which
# only accepts "model.py" -- confirm whether the filter was meant to use it.
src_set = set(
    [
        "function_decoding.py",
        "data.py",
        "type_env.py",
        "function_dataset.py",
        "tokenized_src.py",
        "type_check.py",
        "model.py",
        "utils.py",
        "static_analysis.py",
    ]
)


def _keep_file(f):
    """File filter: accept only files whose `name` is "model.py"."""
    return f.name in {"model.py"}


# Build the example project from the project root, keeping only accepted files.
ex_project = data_project_from_dir(proj_root(), file_filter=_keep_file)


In [10]:
# Print every element of the example project followed by its source location.
for elem in ex_project.all_elems():
    location = ex_project.get_elem_location(elem.path)
    src_file, code_span = location
    print(elem)
    print(f"\t{src_file}: {show_code_range(code_span)}")

GlobalFunction(path=spot.model/dynamic_dataloader)
	src/spot/model.py: [231:1--255:6]
ClassAttribute(path=spot.model/DecodingArgs.ctx_args)
	src/spot/model.py: [26:5--26:22]
ClassAttribute(path=spot.model/DecodingArgs.sampling_max_tokens)
	src/spot/model.py: [27:5--27:29]
ClassAttribute(path=spot.model/DecodingArgs.max_workers)
	src/spot/model.py: [28:5--28:38]
ClassAttribute(path=spot.model/DecodingArgs.tokens_per_type)
	src/spot/model.py: [29:5--29:30]
ClassAttribute(path=spot.model/DecodingArgs.slack_tokens)
	src/spot/model.py: [30:5--30:27]
ClassAttribute(path=spot.model/DecodingArgs.do_sample)
	src/spot/model.py: [31:5--31:28]
ClassAttribute(path=spot.model/DecodingArgs.top_p)
	src/spot/model.py: [32:5--32:23]
ClassAttribute(path=spot.model/DecodingArgs.num_beams)
	src/spot/model.py: [33:5--33:36]
ClassAttribute(path=spot.model/DecodingArgs.num_beam_groups)
	src/spot/model.py: [34:5--34:42]
ClassAttribute(path=spot.model/DecodingArgs.length_penalty)
	src/spot/model.py: [35:5--35:3

In [27]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)
from spot.model import DecodingArgs
from spot.visualization import pretty_print_dict

ctx_args = wrapper.args.ctx_args
wrapper.args = DecodingArgs(
    sampling_max_tokens=ctx_args.ctx_size,
    ctx_args=ctx_args,
    do_sample=False,
    num_beams=16,
    tokens_per_type=16,
    length_penalty=0.2,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.Callee2Caller(), common_type_names=set()
)

pretty_print_dict(evalr.accuracies)


evaluate_on_projects: 100%|██████████| 898/898 [09:27<00:00,  1.58it/s]

partial_acc: 69.46% (count=1.1k)
full_acc: 64.50% (count=1.1k)
full_acc_by_common:
   rare: 64.50% (count=1.1k)
full_acc_by_cat:
   FuncArg: 59.75% (count=641)
   FuncReturn: 71.95% (count=246)
   ClassAtribute: 70.14% (count=211)
   GlobalVar: 66.67% (count=12)
full_acc_by_pos:
   range(0, 1): 62.32% (count=568)
   range(1, 2): 62.26% (count=257)
   range(2, 4): 71.78% (count=202)
   range(4, 8): 72.60% (count=73)
   range(8, 16): 40.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.091





In [28]:
evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.DoubleTraversal(), common_type_names=set()
)

pretty_print_dict(evalr.accuracies)


evaluate_on_projects: 1796it [18:50,  1.59it/s]                        


partial_acc: 70.18% (count=1.1k)
full_acc: 65.77% (count=1.1k)
full_acc_by_common:
   rare: 65.77% (count=1.1k)
full_acc_by_cat:
   FuncArg: 60.06% (count=641)
   FuncReturn: 72.76% (count=246)
   ClassAtribute: 74.88% (count=211)
   GlobalVar: 66.67% (count=12)
full_acc_by_pos:
   range(0, 1): 64.26% (count=568)
   range(1, 2): 63.04% (count=257)
   range(2, 4): 71.78% (count=202)
   range(4, 8): 71.23% (count=73)
   range(8, 16): 60.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.1


In [26]:
evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.RandomOrder(), common_type_names=set()
)

pretty_print_dict(evalr.accuracies)


evaluate_on_projects: 100%|██████████| 898/898 [09:25<00:00,  1.59it/s]


partial_acc: 68.20% (count=1.1k)
full_acc: 62.97% (count=1.1k)
full_acc_by_common:
   rare: 62.97% (count=1.1k)
full_acc_by_cat:
   FuncArg: 57.41% (count=641)
   FuncReturn: 71.14% (count=246)
   ClassAtribute: 71.09% (count=211)
   GlobalVar: 50.00% (count=12)
full_acc_by_pos:
   range(0, 1): 60.56% (count=568)
   range(1, 2): 63.42% (count=257)
   range(2, 4): 69.31% (count=202)
   range(4, 8): 67.12% (count=73)
   range(8, 16): 30.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.1


In [29]:
# ablation: no incremental decoding
evalr = await rctx.evaluate_on_projects(
    [ex_project],
    pre_args,
    DecodingOrders.IndependentOrder(),
    common_type_names=set(),
)

pretty_print_dict(evalr.accuracies)


evaluate_on_projects: 100%|██████████| 898/898 [09:14<00:00,  1.62it/s]


partial_acc: 66.13% (count=1.1k)
full_acc: 60.45% (count=1.1k)
full_acc_by_common:
   rare: 60.45% (count=1.1k)
full_acc_by_cat:
   FuncArg: 57.10% (count=641)
   FuncReturn: 70.73% (count=246)
   ClassAtribute: 59.24% (count=211)
   GlobalVar: 50.00% (count=12)
full_acc_by_pos:
   range(0, 1): 58.10% (count=568)
   range(1, 2): 61.48% (count=257)
   range(2, 4): 64.36% (count=202)
   range(4, 8): 67.12% (count=73)
   range(8, 16): 40.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.082


In [None]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

with ProcessPoolExecutor(20) as cpu_executor, ThreadPoolExecutor(1) as model_executor:
    result = await rctx.project_rollout(
        ex_project,
        pre_args,
        decode_order=DecodingOrders.Callee2Caller(),
        cpu_executor=cpu_executor,
        model_executor=model_executor,
        progress_cbk=lambda e, p: print(e.path, len(p)),
    )
    print()


spot.function_decoding/DecodingOrders.caller2callee 4
spot.function_decoding/DecodingOrders.random_order 2
spot.function_decoding/RolloutCtx.project_rollout 8
spot.function_decoding/RolloutCtx.model 1
spot.function_decoding/RolloutPrediction.elem2inputs 1
spot.function_decoding/RolloutPrediction.elem2preds 1
spot.function_decoding/RolloutPrediction.assignments 1
spot.function_decoding/construct_model_inputs 7
spot.function_decoding/is_mask_annot 2



In [None]:
from spot.utils import decode_tokens

# For each element whose module mentions "function_decoding", print its
# predictions and the decoded first entry of its "input_ids".
for elem, preds in result.elem2preds.items():
    if "function_decoding" not in elem.module:
        continue
    print("-----------------------------------------------------------")
    print(f"{elem.path}: {list(map(str, preds))}")
    print("model inputs:")
    first_input_ids = result.elem2inputs[elem]["input_ids"][0]
    print(decode_tokens(first_input_ids))
