In [2]:
# Notebook setup: reload edited modules automatically before each cell runs,
# then switch the working directory to the project root so relative paths
# used later are resolved from there.
%load_ext autoreload
%autoreload 2

import os
import asyncio  # NOTE(review): not used by the cells below; top-level `await` relies on IPython, not this import
from typing import *  # NOTE(review): wildcard import hides where names come from — prefer explicit imports

from spot.utils import proj_root, get_data_dir

os.chdir(proj_root())

# Resolved after the chdir above, in case get_data_dir() depends on the cwd.
datadir = get_data_dir()

In [3]:
# experiment configurations

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
    TypeCheckSettings,
)
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from spot.train import TrainingConfig, TypeCheckArgs

# Experiment configuration. `config.as_name()` (presumably derived from these
# fields) forms `model_name`, which selects the checkpoint loaded later, so
# these values presumably must match the settings of the trained model.
config = TrainingConfig(
    quicktest=False,
    all_labels=True,
    ctx_size=2048,
    left_margin=1024,
    right_margin=512,
)

# GPU to run on. Its id is also embedded in the type checker's temp path,
# presumably so concurrent runs on different GPUs do not share a directory.
gpu_id = 0
TypeCheckSettings.temp_path = f"DAgger-{gpu_id}"

print(f"quicktest={config.quicktest}")

if config.quicktest:
    project_name = "test-SPOT"
else:
    project_name = "SPOT"

# Argument bundles derived from `config`.
# NOTE(review): `project_name`, `train_ctx_args`, `tc_args` and `dec_args` are not
# referenced by the evaluation cells that follow; in particular `dec_args` is never
# passed to the model — confirm whether decoding settings come from the checkpoint.
train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)
dec_args = DecodingArgs(
    sampling_max_tokens=8 * config.ctx_size,  # 8x config.ctx_size tokens
    ctx_args=config.dec_ctx_args(),
)

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)
model_name = "DAgger-model--" + config.as_name()

# Load the source datasets; the evaluation below uses the "test" split.
src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    quicktest=config.quicktest,
)

  warn(f"Failed to load image Python extension: {e}")


quicktest=False
Loading datasets:  src_datasets-all_labels-drop_comments


In [4]:
# load the model
from spot.model import load_model_spot, DefaultTokenizer
from spot.model import ModelWrapper
from spot.dagger import DAggerModel
import torch

# Load the saved DAgger checkpoint named after `config`, move it to the chosen
# device (CPU when CUDA is unavailable), and wrap it for evaluation.
checkpoint_dir = datadir / f"checkpoints/saved/{model_name}"
wrapper = ModelWrapper.from_pretrained(checkpoint_dir)

if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
wrapper.to(device)

dmodel = DAggerModel(wrapper)

In [5]:
# evaluate
from spot.utils import pretty_print_dict, pretty_show_dict
from spot.visualization import visualize_preds_on_code

eval_r = await dmodel.eval_on_data(src_datasets["test"][0:5])
pretty_print_dict(eval_r.accuracies)

compute_preexisting_fdbks: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
eval_on_data: 100%|██████████| 157/157 [01:01<00:00,  2.55it/s]

partial_acc (ImNone): 79.62% (count=157)
full_acc (ImNone): 76.43% (count=157)
partial_acc: 76.43% (count=157)
ast_acc: 62.32% (count=207)
full_acc: 71.34% (count=157)
partial_acc_by_cat:
   FuncArg: 73.17% (count=82)
   FuncReturn: 80.00% (count=60)
   ClassAtribute: 88.89% (count=9)
   LocalVar: 66.67% (count=6)
partial_acc_by_pos:
   range(0, 1): 100.00% (count=5)
   range(1, 2): 100.00% (count=5)
   range(2, 4): 75.00% (count=8)
   range(4, 8): 100.00% (count=16)
   range(8, 16): 83.33% (count=30)
   range(16, 32): 72.09% (count=43)
   range(32, 64): 64.00% (count=50)
avg_label_size: 1.3185
avg_pred_size: 1.1656





In [6]:
from spot.data import SrcDataset
from spot.visualization import export_preds_on_code

# Pair the evaluation's final sources with the test repos root, then export
# the predictions rendered on code under <project root>/caches/.
test_repos_root = src_datasets["test"].repos_root
viz_ds = SrcDataset(test_repos_root, eval_r.final_srcs)
viz_preds = eval_r.final_preds

export_dir = proj_root() / "caches/DAgger-preds-on-code"
export_preds_on_code(viz_ds, viz_preds, export_dir)

Exporting: 100%|██████████| 5/5 [00:00<00:00, 97.71it/s]
Computing accuracies: 100%|██████████| 5/5 [00:00<00:00, 3873.57it/s]
