In [1]:
%load_ext autoreload
%autoreload 2

import asyncio
import os
from typing import *

import torch
import wandb
from spot.data import get_tk_dataset_name, load_tokenized_srcsets
from spot.function_dataset import data_project_from_dir
from spot.model import ModelWrapper
from spot.train import TrainingConfig, PreprocessArgs
from spot.type_env import AccuracyMetric
from spot.utils import (
    PickleCache,
    assert_eq,
    get_dataroot,
    get_dataset_dir,
    get_eval_dir,
    get_gpu_id,
    get_model_dir,
    pickle_dump,
    pmap,
    pretty_print_dict,
    pretty_show_dict,
    proj_root,
    run_long_task,
    write_file,
)
from spot.visualization import string_to_html
from termcolor import colored

# Make the directory returned by proj_root() the working directory, so that any
# relative paths used later in this notebook are resolved against it.
os.chdir(proj_root())


def wandb_string(s: str):
    """Wrap plain text so it can be logged to Weights & Biases.

    The text is rendered with `string_to_html`, and the resulting HTML is
    wrapped in a `wandb.Html` object, which is returned.
    """
    rendered = string_to_html(s)
    return wandb.Html(rendered)

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# experiment configurations

# Which dataset / trained model to evaluate.
dataset_name = "ManyTypes4Py"  # alternative: "SPOT-src"
model_name = "model-v6--TrainingConfig(func_only=False, left_margin=2048, preamble_size=800, right_margin=1536)"
experiment_name = f"{dataset_name}: {model_name}"

# Preprocessing settings; also used to locate the tokenized dataset directory.
pre_args = PreprocessArgs()
quicktest = False  # forwarded to load_tokenized_srcsets(quicktest=...)

# Falls back to GPU 1 when GPU_ID is not set (see the cell output).
gpu_id = get_gpu_id(1)
print(colored(f"Use GPU: {gpu_id}", "green"))


GPU_ID not set, using: 1
[32mUse GPU: 1[0m


In [3]:
# load model
# model = ModelWrapper.from_pretrained(get_model_dir() / model_name)
# device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print(f"Model loaded to {device}")

# load test projects
# Path.iterdir() yields entries in arbitrary (filesystem-dependent) order, so
# sort the repo paths to keep `test_projects` in a reproducible order.
repos_dir = get_dataset_dir(dataset_name) / "repos" / "test"
test_repo_paths = sorted(f for f in repos_dir.iterdir() if f.is_dir())
test_projects = pmap(
    data_project_from_dir,
    test_repo_paths,
    desc="Loading test projects",
)
assert len(test_projects) > 0, f"No test projects found under {repos_dir}"

# load the tokenized test split that matches the preprocessing settings
sdata_name = get_tk_dataset_name(dataset_name, pre_args, func_only=False)
sdata_path = get_dataroot() / "TokenizedSrcSets" / sdata_name
tk_dataset = load_tokenized_srcsets(
    sdata_path,
    quicktest=quicktest,
    sets_to_load=["test"],
)


Loading test projects: 100%|██████████| 50/50 [00:20<00:00,  2.48it/s]


Loading TokenizedSrcSets:  /mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs()
258M	/mnt/nas/jiayi/SPOT/TokenizedSrcSets/ManyTypes4Py-v5-PreprocessArgs()


In [4]:
# model evaluation

from spot.function_decoding import (
    DecodingOrders,
    EvalResult,
    PreprocessArgs,
    RolloutCtx,
    sigmap_from_file_predictions,
)
from spot.static_analysis import SignatureErrorAnalysis


def _load_model_and_eval_test():
    """Load `model_name` from disk and evaluate it on the tokenized test split.

    Passed to `eval_cache.cached` as the compute function, so the model is only
    loaded when the predictions have to be (re)computed. Loading it here,
    instead of reading a global `model` that no earlier cell defines, keeps this
    cell runnable on a fresh kernel even when the cache is empty.

    Returns:
        Whatever `model.eval_on_dataset` returns for `tk_dataset["test"]`.
    """
    model = ModelWrapper.from_pretrained(get_model_dir() / model_name)
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Model loaded to {device}")
    # Optional decoding overrides, kept disabled as in the original notebook;
    # enabling them may change the predictions:
    # ctx_args = model.args.ctx_args
    # model.args.sampling_max_tokens = ctx_args.ctx_size
    # model.args.do_sample = False
    # model.args.num_beams = 10
    # model.args.tokens_per_type = 16
    return model.eval_on_dataset(tk_dataset["test"])


eval_cache = PickleCache(get_eval_dir(dataset_name, model_name) / "eval_cache")
# eval_cache.clear()  # uncomment to force re-evaluation
pre_r = eval_cache.cached("dataset_pred.pkl", _load_model_and_eval_test)


In [5]:
# Score the predicted signatures against the labels under each default metric.
common_names = ModelWrapper.load_common_type_names(get_model_dir() / model_name)
pred_map, label_map = sigmap_from_file_predictions(pre_r, test_projects, repos_dir)

accs = dict(
    (metric.name, SignatureErrorAnalysis(pred_map, label_map, metric).accuracies)
    for metric in AccuracyMetric.default_metrics(common_names)
)

pretty_print_dict(accs)

plain_acc:
   plain_acc: 67.46% (count=15.7k)
   plain_acc_by_common:
      rare: 52.61% (count=5.6k)
      common: 75.62% (count=10.1k)
   plain_acc_by_cat:
      FuncArg: 65.21% (count=8.0k)
      FuncReturn: 75.88% (count=5.8k)
      ClassAtribute: 50.87% (count=1.8k)
      GlobalVar: 66.36% (count=107)
   plain_acc_label_size: 1.4194
   plain_acc_pred_size: 1.3838
   plain_acc_ignored_labels: 0
   n_skipped_types: 0
   n_missing_types: 53
acc:
   acc: 67.47% (count=13.2k)
   acc_by_common:
      rare: 55.59% (count=5.6k)
      common: 76.11% (count=7.6k)
   acc_by_cat:
      FuncArg: 66.93% (count=6.7k)
      FuncReturn: 68.01% (count=4.9k)
      ClassAtribute: 67.73% (count=1.5k)
      GlobalVar: 72.73% (count=99)
   acc_label_size: 1.3155
   acc_pred_size: 1.2906
   acc_ignored_labels: 2521
   n_skipped_types: 0
   n_missing_types: 53
base_acc:
   base_acc: 74.71% (count=13.2k)
   base_acc_by_common:
      rare: 62.27% (count=4.8k)
      common: 81.79% (count=8.4k)
   base_acc_by