In [2]:
%load_ext autoreload
%autoreload 2

# first, load the trained model
import os
import torch
from typing import *

from spot.model import ModelWrapper
from spot.utils import get_model_dir, proj_root

# Run everything relative to the project root.
os.chdir(proj_root())

gpu_id = 1
modeldir = get_model_dir()

# Alternative checkpoint name, kept for reference:
# model_name="model-v4--TrainingConfig(func_only=True, drop_env_types=False, left_margin=1536, preamble_size=768, right_margin=2048)"
model_name = "model-v5--TrainingConfig(drop_env_types=False)"
wrapper = ModelWrapper.from_pretrained(modeldir / model_name)

# Use the selected GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
wrapper.to(device)
print(f"Model loaded to {device}")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Model loaded to cuda:1


In [2]:
from spot.static_analysis import (
    PythonProject,
    PythonModule,
    UsageAnalysis,
    mask_types,
    remove_comments,
)
from spot.utils import cst, read_file, SpecialNames
from spot.function_dataset import data_project_from_dir
from spot.visualization import show_code_range

# ex_code = read_file("src/spot/function_decoding.py")
# ex_module = mask_types(cst.parse_module(ex_code))
# ex_project = PythonProject.from_modules([PythonModule.from_cst(ex_module, "spot.function_decoding")])

# Names of source files of interest.
# NOTE(review): `src_set` is not referenced by the filter below, which selects
# only "model.py" and "utils.py" -- confirm which selection is intended.
src_set = {
    "data.py",
    "function_dataset.py",
    "function_decoding.py",
    "model.py",
    "static_analysis.py",
    "tokenized_src.py",
    "type_check.py",
    "type_env.py",
    "utils.py",
}

# Build the example project from the files under the project root whose
# name passes the filter.
ex_project = data_project_from_dir(
    proj_root(),
    file_filter=lambda src_file: src_file.name in ("model.py", "utils.py"),
)


In [3]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)
from spot.model import DecodingArgs
from spot.visualization import pretty_print_dict

ctx_args = wrapper.args.ctx_args
wrapper.args = DecodingArgs(
    sampling_max_tokens=ctx_args.ctx_size,
    ctx_args=ctx_args,
    do_sample=False,
    num_beams=16,
    tokens_per_type=16,
    length_penalty=0.2,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.RandomOrder(),
)

pretty_print_dict(evalr.accuracies(None, wrapper.common_type_names))


evaluate_on_projects: 100%|██████████| 134/134 [00:44<00:00,  3.03it/s]

partial_acc: 70.27% (count=185)
full_acc: 68.11% (count=185)
full_acc_by_common:
   rare: 68.11% (count=185)
full_acc_by_cat:
   FuncArg: 63.37% (count=101)
   FuncReturn: 65.45% (count=55)
   ClassAtribute: 89.29% (count=28)
   GlobalVar: 100.00% (count=1)
full_acc_by_pos:
   range(0, 1): 69.79% (count=96)
   range(1, 2): 61.54% (count=52)
   range(2, 4): 75.76% (count=33)
   range(4, 8): 50.00% (count=4)
avg_label_size: 1.1459
avg_pred_size: 1.1135





In [4]:
from spot.utils import decode_tokens
from spot.static_analysis import ProjectPath

# Look at the first rollout from the evaluation above.
rollout = evalr.predictions[0]
elems_to_show = [ProjectPath("spot.model", "DecodingArgs.max_workers")]

# For each selected element, print its path, the predictions recorded for it,
# and the decoded token sequence stored as its model input.
for path in elems_to_show:
    print(path)
    predicted_types = rollout.elem2preds[path]
    print("prediction:", predicted_types)
    input_token_ids = rollout.elem2inputs[path]["input_ids"]
    print("input:", decode_tokens(input_token_ids))


spot.model/DecodingArgs.max_workers
prediction: [ty'int']
input: import random
import numpy as np
from mypy_extensions import mypyc_attr
from.type_env import PythonType
from.utils import *
from copy import copy, deepcopy
from collections import Counter
from typing import NamedTuple, overload
from datasets.arrow_dataset import Dataset
from torch import Tensor
from torch.utils.data import DataLoader, RandomSampler
from transformers.data.data_collator import DataCollatorForSeq2Seq
from.data import (
    ChunkedDataset,
    CtxArgs,
    TokenizedSrcSet,
    output_ids_as_types,
    preds_to_accuracies,
)
@dataclass
class DecodingArgs:
   ...
@dataclass
class DatasetPredResult(Generic[T1]):
   ...
@dataclass
class ModelWrapper:
   ...

# Used:
# spot.utils
DefaultWorkers:... = multiprocessing.cpu_count() // 2

# Target:
# spot.model
@dataclass
class DecodingArgs:
    max_workers: <extra_id_0> = DefaultWorkers

# Users:




In [26]:
from spot.utils import DefaultTokenizer

# Token ids the default tokenizer produces for the "# Target:" marker line
# (no special tokens added); the cell displays the resulting list.
target_marker_ids = DefaultTokenizer.encode("\n# Target:\n", add_special_tokens=False)
target_marker_ids

[203, 7, 5916, 30, 203]

In [28]:
evalr = await rctx.evaluate_on_projects(
    [ex_project], pre_args, DecodingOrders.DoubleTraversal()
)

pretty_print_dict(evalr.accuracies(None, wrapper.common_type_names))


evaluate_on_projects: 1796it [18:50,  1.59it/s]                        


partial_acc: 70.18% (count=1.1k)
full_acc: 65.77% (count=1.1k)
full_acc_by_common:
   rare: 65.77% (count=1.1k)
full_acc_by_cat:
   FuncArg: 60.06% (count=641)
   FuncReturn: 72.76% (count=246)
   ClassAtribute: 74.88% (count=211)
   GlobalVar: 66.67% (count=12)
full_acc_by_pos:
   range(0, 1): 64.26% (count=568)
   range(1, 2): 63.04% (count=257)
   range(2, 4): 71.78% (count=202)
   range(4, 8): 71.23% (count=73)
   range(8, 16): 60.00% (count=10)
avg_label_size: 1.1739
avg_pred_size: 1.1


In [None]:
from spot.function_decoding import (
    RolloutCtx,
    PreprocessArgs,
    DecodingOrders,
    ProcessPoolExecutor,
    ThreadPoolExecutor,
)

rctx = RolloutCtx(
    model=wrapper,
)

pre_args = PreprocessArgs()

with ProcessPoolExecutor(20) as cpu_executor, ThreadPoolExecutor(1) as model_executor:
    result = await rctx.project_rollout(
        ex_project,
        pre_args,
        decode_order=DecodingOrders.Callee2Caller(),
        cpu_executor=cpu_executor,
        model_executor=model_executor,
        progress_cbk=lambda e, p: print(e.path, len(p)),
    )
    print()


spot.function_decoding/DecodingOrders.caller2callee 4
spot.function_decoding/DecodingOrders.random_order 2
spot.function_decoding/RolloutCtx.project_rollout 8
spot.function_decoding/RolloutCtx.model 1
spot.function_decoding/RolloutPrediction.elem2inputs 1
spot.function_decoding/RolloutPrediction.elem2preds 1
spot.function_decoding/RolloutPrediction.assignments 1
spot.function_decoding/construct_model_inputs 7
spot.function_decoding/is_mask_annot 2



In [None]:
from spot.utils import decode_tokens

# For every element whose module name contains "function_decoding", print the
# element's path, its predictions (as strings), and the decoded model input.
for elem, preds in result.elem2preds.items():
    if "function_decoding" not in elem.module:
        continue
    print("-----------------------------------------------------------")
    print(f"{elem.path}: {[str(p) for p in preds]}")
    print("model inputs:")
    # Decode the whole token sequence rather than `["input_ids"][0]`: with a
    # flat list of token ids, indexing [0] would pass a single token id and
    # only that one token would be shown.
    # NOTE(review): assumes `result.elem2inputs` has the same layout as the
    # rollouts in `evalr.predictions`, whose "input_ids" are decoded without
    # indexing elsewhere in this notebook -- confirm for `project_rollout`.
    print(decode_tokens(result.elem2inputs[elem]["input_ids"]))
