In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import *

from spot.utils import proj_root, get_data_dir

# Make the project root (as reported by proj_root()) the working directory.
os.chdir(proj_root())

# Base data directory; later cells load the datasets and model checkpoint from it.
datadir = get_data_dir()

In [2]:
# experiment configurations

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
    TypeCheckSettings,
)
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from spot.train import TrainingConfig, TypeCheckArgs

# Hyper-parameters for this run. `model_name` further down is derived from
# them through `config.as_name()`.
config = TrainingConfig(
    quicktest=False, all_labels=True, ctx_size=2048, left_margin=1024, right_margin=512
)

# The GPU index is also embedded in the type checker's temp path
# (presumably so runs on different GPUs do not share a directory — confirm).
gpu_id = 1
TypeCheckSettings.temp_path = f"DAgger-{gpu_id}"

print(f"quicktest={config.quicktest}")

# Quick-test runs get their own project name.
if config.quicktest:
    project_name = "test-SPOT"
else:
    project_name = "SPOT"

train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

# Name of the preprocessed dataset matching this configuration.
datasets_name = get_dataset_name(
    drop_comments=config.drop_comments, all_labels=config.all_labels
)

model_name = "DAgger-model--" + config.as_name()

# Mapping from split name to dataset (indexed by split name in later cells).
src_datasets = load_src_datasets(
    datadir, datasets_name, data_reduction=config.data_reduction, quicktest=config.quicktest
)

  warn(f"Failed to load image Python extension: {e}")


quicktest=False
Loading datasets:  src_datasets-all_labels-drop_comments


In [5]:
# load the model
from spot.model import load_model_spot, DefaultTokenizer
from spot.model import ModelWrapper
from spot.dagger import DAggerModel
import torch

# Decoding configuration: sampling disabled, 8 beams, and a token budget of
# 8x the configured context size.
# NOTE(review): `dec_args` is not passed to `wrapper` or `DAggerModel` within
# this cell — confirm whether it is consumed elsewhere or meant to be applied.
token_budget = 8 * config.ctx_size
dec_args = DecodingArgs(
    sampling_max_tokens=token_budget,
    ctx_args=config.dec_ctx_args(),
    do_sample=False,
    num_beams=8,
)

# Restore the saved checkpoint whose directory name is derived from the config.
checkpoint_dir = datadir / f"checkpoints/saved/{model_name}"
wrapper = ModelWrapper.from_pretrained(checkpoint_dir)

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
wrapper.to(device)

dmodel = DAggerModel(wrapper)

In [12]:
# evaluate
from spot.utils import pretty_print_dict, pretty_show_dict
from spot.visualization import visualize_preds_on_code

eval_r = await dmodel.eval_on_data(src_datasets["train"][1:105:10])
pretty_print_dict(eval_r.accuracies)

compute_preexisting_fdbks: 100%|██████████| 6/6 [00:02<00:00,  2.15it/s]
eval_on_data: 100%|██████████| 180/180 [03:38<00:00,  1.21s/it]


partial_acc (ImNone): 64.44% (count=180)
full_acc (ImNone): 61.11% (count=180)
partial_acc: 63.33% (count=180)
ast_acc: 47.13% (count=244)
full_acc: 60.00% (count=180)
partial_acc_by_cat:
   FuncArg: 58.42% (count=101)
   FuncReturn: 67.12% (count=73)
   GlobalVar: 100.00% (count=3)
   LocalVar: 100.00% (count=3)
partial_acc_by_pos:
   range(0, 1): 90.91% (count=11)
   range(1, 2): 87.50% (count=8)
   range(2, 4): 100.00% (count=12)
   range(4, 8): 76.92% (count=13)
   range(8, 16): 45.45% (count=11)
   range(16, 32): 50.00% (count=16)
   range(32, 64): 56.25% (count=32)
   range(64, 128): 57.81% (count=64)
   range(128, 256): 53.85% (count=13)
avg_label_size: 1.3556
avg_pred_size: 1.0778


In [14]:
from spot.data import SrcDataset
from spot.visualization import export_preds_on_code

# `eval_r` was produced from the "train" split in the evaluation cell, so root
# the visualization dataset at that same split's `repos_root` (previously the
# "test" split's root was used, which does not match the evaluated sources).
viz_ds = SrcDataset(src_datasets["train"].repos_root, eval_r.final_srcs)
viz_preds = eval_r.final_preds

# Export the final predictions for the evaluated sources under
# <proj_root>/caches/DAgger-preds-on-code.
export_preds_on_code(viz_ds, viz_preds, proj_root() / "caches/DAgger-preds-on-code")

Exporting: 100%|██████████| 11/11 [00:00<00:00, 97.24it/s]
Computing accuracies: 100%|██████████| 11/11 [00:00<00:00, 4176.84it/s]
