In [1]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root, get_data_dir

os.chdir(proj_root())

gpu_id = 1
datadir = get_data_dir()
modeldir = get_model_dir()

model_name="model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
wrapper = ModelWrapper.from_pretrained(
    modeldir / f"checkpoints/lit-saved/{model_name}"
)
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
wrapper.to(device)
print(f"Model loaded to {device}")


Model loaded to cuda:1


In [6]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import cst, read_file, SpecialNames

# Build the example project by scanning the repo root. Only
# `function_decoding.py` is kept, and each module has its comments removed
# and its types masked.
#
# Alternative kept for reference (load one file directly instead of scanning):
#   ex_code = read_file("src/spot/function_decoding.py")
#   ex_module = mask_types(cst.parse_module(ex_code))
#   ex_project = PythonProject.from_modules(
#       [PythonModule.from_cst(ex_module, "spot.function_decoding")]
#   )


def _src_has_no_type_mask(code) -> bool:
    """Accept only sources that do not already contain the type-mask token."""
    return SpecialNames.TypeMask not in code


def _strip_comments_then_mask(m):
    """Remove comments from a module, then mask its types."""
    return mask_types(remove_comments(m))


def _is_function_decoding_file(f) -> bool:
    """Keep only `function_decoding.py`.

    Broader alternative: `"dagger" not in f.name and "critic" not in f.name`.
    """
    return f.name == "function_decoding.py"


ex_project = PythonProject.from_root(
    proj_root(),
    src_filter=_src_has_no_type_mask,
    src_transform=_strip_comments_then_mask,
    file_filter=_is_function_decoding_file,
)


In [7]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

with ProcessPoolExecutor(20) as cpu_executor, ThreadPoolExecutor(1) as model_executor:
    result = await rctx.project_rollout(
        ex_project,
        pre_args,
        decode_order=DecodingOrders.caller2callee,
        cpu_executor=cpu_executor,
        model_executor=model_executor,
        progress_cbk=lambda e, p: print(e.path, len(p)),
    )
    print()


spot.function_decoding/DecodingOrders.caller2callee 4
spot.function_decoding/DecodingOrders.random_order 2
spot.function_decoding/RolloutCtx.project_rollout 8
spot.function_decoding/RolloutCtx.model 1
spot.function_decoding/RolloutPrediction.elem2inputs 1
spot.function_decoding/RolloutPrediction.elem2preds 1
spot.function_decoding/RolloutPrediction.assignments 1
spot.function_decoding/construct_model_inputs 7
spot.function_decoding/is_mask_annot 2



In [10]:
from spot.utils import decode_tokens

# For every rolled-out element whose module name contains "function_decoding",
# show its predictions and the decoded first sequence of its model input.
for elem, preds in result.elem2preds.items():
    if "function_decoding" not in elem.module:
        continue
    print("-----------------------------------------------------------")
    pred_strs = [str(p) for p in preds]
    print(f"{elem.path}: {pred_strs}")
    print("model inputs:")
    first_input_ids = result.elem2inputs[elem]["input_ids"][0]
    print(decode_tokens(first_input_ids))


DecodingOrders.caller2callee: ['UsageAnalysis', 'list[ProjectPath]', 'ProjectPath', 'None']
model inputs:
import asyncio
import copy
import random
import torch
from.type_env import AnnotInfo, collect_user_annotations
from.model import ModelWrapper
from.utils import *
from.static_analysis import (
    ModuleName,
    ProjectPath,
    PythonElem,
    PythonFunction,
    PythonProject,
    PythonVariable,
    UsageAnalysis,
    VariableSingature,
)
from.data import CtxArgs, SrcChunkInfo, src_to_chunks
from.function_dataset import (
    ElemSignature,
    ctx_modules_for_elem,
    mk_preamble,
    reformat_elems,
)
from.tokenized_src import (
    PreprocessArgs,
    TokenSeq,
    TokenizedSrc,
    tokenized_src_from_segs,
)
from.type_check import PythonType
@dataclass
class RolloutPrediction:
   ...
@dataclass
class RolloutCtx:
   ...
class DecodingOrders:
   ...

# BEGIN
# spot.function_decoding
class DecodingOrders:
    @staticmethod
    def caller2callee(analysis: <extra_id_0>) -> <extr