In [1]:
%load_ext autoreload
%autoreload 2

import asyncio
import os
from typing import *

import torch
import wandb
from spot.data import get_tk_dataset_name
from spot.function_dataset import data_project_from_dir
from spot.model import ModelWrapper
from spot.train import TrainingConfig, PreprocessArgs
from spot.type_env import AccuracyMetric
from spot.utils import (
    PickleCache,
    assert_eq,
    get_dataroot,
    get_dataset_dir,
    get_eval_dir,
    get_gpu_id,
    get_model_dir,
    pickle_dump,
    pmap,
    pretty_print_dict,
    pretty_show_dict,
    proj_root,
    run_long_task,
    write_file,
)
from spot.visualization import string_to_html
from termcolor import colored

def wandb_string(s: str):
    """Convert the plain string `s` to HTML (via `string_to_html`) and wrap it in `wandb.Html`."""
    html = string_to_html(s)
    return wandb.Html(html)


# Run from the project root so relative paths used later in the notebook
# (e.g. the "caches/..." export directory) resolve against it.
os.chdir(proj_root())

  warn(f"Failed to load image Python extension: {e}")


In [6]:
# --- experiment configuration ---
# True -> load only a small slice of the data (quick smoke test).
quicktest = False

# Dataset to evaluate on. Other datasets previously used here:
# "InferTypes4Py", "SPOT-src".
dataset_name = "ManyTypes4Py"

# Preprocessing flags; these mirror the imports_in_preamble / stub_in_preamble
# values that appear in the model name below, so keep the two in sync.
pre_args = PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)
model_name = (
    "model-v6--TrainingConfig("
    "func_only=False, imports_in_preamble=False, stub_in_preamble=False, "
    "left_margin=2048, right_margin=1536)"
)
# Previously evaluated variant:
# "model-v6--TrainingConfig(func_only=False, left_margin=2048, preamble_size=800, right_margin=1536)"
experiment_name = f"{dataset_name}: {model_name}"

# get_gpu_id(2) falls back to GPU 2 when GPU_ID is not set (see cell output).
gpu_id = get_gpu_id(2)
print(colored(f"Use GPU: {gpu_id}", "green"))


GPU_ID not set, using: 2
[32mUse GPU: 2[0m


In [7]:
# --- load the tokenized test split ---
from spot.data import load_tokenized_srcsets, create_tokenized_srcsets

# Set to True to force rebuilding the tokenized source sets even if they exist on disk.
recreate = False

sdata_name = get_tk_dataset_name(dataset_name, pre_args, func_only=False)
sdata_path = get_dataroot() / "TokenizedSrcSets" / sdata_name

# Build the tokenized sets only when missing (or when a rebuild is forced);
# otherwise reuse what is already on disk.
needs_build = recreate or not sdata_path.exists()
if needs_build:
    create_tokenized_srcsets(
        dataset_name,
        sdata_path,
        func_only=False,
        pre_args=pre_args,
    )

# Evaluation only needs the "test" split.
tk_dataset = load_tokenized_srcsets(
    sdata_path,
    quicktest=quicktest,
    sets_to_load=["test"],
)


Loading TokenizedSrcSets:  /mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)
254M	/mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs(imports_in_preamble=False, stub_in_preamble=False)


In [8]:
# model evaluation

from spot.function_decoding import (
    DecodingOrders,
    EvalResult,
    PreprocessArgs,  # NOTE(review): shadows the PreprocessArgs imported from spot.train in the first cell
    RolloutCtx,
)
from spot.function_dataset import sigmap_from_file_predictions
from spot.static_analysis import SignatureErrorAnalysis

# load model and move it to the selected GPU (CPU if CUDA is unavailable)
model = ModelWrapper.from_pretrained(get_model_dir() / model_name)
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
model.to(device)

# decoding settings: deterministic beam search over the full context size
ctx_args = model.args.ctx_args
model.args.sampling_max_tokens = ctx_args.ctx_size
model.args.do_sample = False
model.args.num_beams = 10
model.args.tokens_per_type = 16

# Cache predictions per (dataset, model, preprocessing) directory. The decoding
# settings are also part of the file name. Without them, changing num_beams,
# do_sample or tokens_per_type above would silently return stale cached
# predictions.
decoding_key = (
    f"num_beams={model.args.num_beams}, "
    f"do_sample={model.args.do_sample}, "
    f"tokens_per_type={model.args.tokens_per_type}"
)
eval_cache = PickleCache(get_eval_dir(dataset_name, model_name) / f"{pre_args}")
# eval_cache.clear()  # uncomment to force re-evaluation
pre_r = eval_cache.cached(
    f"DatasetPredResult({decoding_key}).pkl",
    lambda: model.eval_on_dataset(tk_dataset["test"]),
)


In [9]:
# --- signature-level accuracy on the test projects ---
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
# Sort so project loading (and anything order-sensitive downstream) is reproducible.
test_repo_paths = sorted(f for f in repos_dir.iterdir() if f.is_dir())
test_projects = pmap(
    data_project_from_dir,
    test_repo_paths,
    desc="Loading test projects",
)
assert len(test_projects) > 0, f"No test projects found under {repos_dir}"

# Common type names for this model, passed to the default metrics. The output
# reports accuracies on all types and on common types.
common_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
pred_map, label_map = sigmap_from_file_predictions(pre_r, test_projects, repos_dir)
# metric name -> accuracies computed by SignatureErrorAnalysis under that metric
accs = {
    m.name: SignatureErrorAnalysis(pred_map, label_map, m).accuracies
    for m in AccuracyMetric.default_metrics(common_names)
}

from spot.experiments.typet5 import accs_as_table_row

# Prints "&"-separated table rows (see output), then the full accuracy dict.
accs_as_table_row(accs)
pretty_print_dict(accs)

Loading test projects: 100%|██████████| 50/50 [00:26<00:00,  1.86it/s]


Accuracies on all types:
header:  ['full.all', 'calibrated.simple', 'calibrated.complex', 'calibrated.all', 'base.all']
67.07 & 72.12 & 44.05 & 67.47 & 73.44
Accuracies on common types:
header:  ['full.all', 'calibrated.simple', 'calibrated.complex', 'calibrated.all', 'base.all']
76.74 & 82.43 & 53.03 & 78.04 & 82.44


In [10]:
# --- export per-chunk predictions rendered on the source code ---
from spot.utils import decode_tokens, Path
from spot.visualization import export_preds_on_code

# Relative path; it resolves against the project root because of os.chdir in the first cell.
export_to = Path("caches") / "model_predictions" / "eval_file_model" / dataset_name
export_metric = AccuracyMetric(common_names)
export_preds_on_code(pre_r.chunks, pre_r.predictions, export_to, export_metric)

Exporting: 100%|██████████| 1851/1851 [00:19<00:00, 96.51it/s] 
Computing accuracies: 100%|██████████| 1851/1851 [00:00<00:00, 11808.94it/s]
