In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import *

from typet5.utils import proj_root, get_data_dir

os.chdir(proj_root())

datadir = get_data_dir()

In [2]:
# experiment configurations

from typet5.data import (
    TokenizedSrcSet,
    get_dataset_name,
    load_tokenized_srcsets,
    TypeCheckSettings,
)
from typet5.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from typet5.train import TrainingConfig, TypeCheckArgs

config = TrainingConfig(
    quicktest=False,
    all_labels=True,
    ctx_size=2048,
    left_margin=1024,
    right_margin=1023,
    modifications="no_type_checker",
)
gpu_id = 1
TypeCheckSettings.temp_path = f"DAgger-{gpu_id}"

print(f"quicktest={config.quicktest}")

project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
    imports_in_preamble=config.imports_in_preamble,
)

model_name = "DAgger-model--" + config.as_name()

tk_dataset = load_tokenized_srcsets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    quicktest=config.quicktest,
)

  warn(f"Failed to load image Python extension: {e}")


quicktest=False
Loading datasets:  tk_dataset-all_labels-drop_comments


In [7]:
# load the model
from typet5.model import load_model_spot, DefaultTokenizer
from typet5.model import ModelWrapper
from typet5.dagger import DAggerModel
import torch

dec_args = DecodingArgs(
    sampling_max_tokens=8 * config.ctx_size,
    ctx_args=config.dec_ctx_args(),
    do_sample=True,
    num_beams=None, # try greedy decoding
    top_p=0.9,
)

wrapper = ModelWrapper.from_pretrained(datadir / f"checkpoints/saved/{model_name}")
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
wrapper.to(device)
dmodel = DAggerModel(wrapper)

In [5]:
# evaluate (greedy)
from typet5.utils import pretty_print_dict, pretty_show_dict
from typet5.visualization import visualize_preds_on_code

eval_r = await dmodel.eval_on_data(tk_dataset["test"])
pretty_print_dict(eval_r.accuracies)

compute_preexisting_fdbks: 100%|██████████| 50/50 [00:04<00:00, 11.99it/s]
eval_on_data: 100%|██████████| 16950/16950 [46:45<00:00,  6.04it/s]


partial_acc (ImNone): 69.01% (count=16.9k)
full_acc (ImNone): 64.86% (count=16.9k)
partial_acc: 67.43% (count=16.9k)
ast_acc: 56.81% (count=21.3k)
full_acc: 62.03% (count=16.9k)
partial_acc_by_cat:
   FuncArg: 62.96% (count=8.0k)
   FuncReturn: 78.14% (count=5.7k)
   ClassAtribute: 58.21% (count=2.7k)
   GlobalVar: 75.96% (count=104)
   LocalVar: 64.22% (count=531)
partial_acc_by_pos:
   range(0, 1): 80.39% (count=933)
   range(1, 2): 77.13% (count=870)
   range(2, 4): 77.58% (count=1.5k)
   range(4, 8): 74.05% (count=2.4k)
   range(8, 16): 72.80% (count=3.1k)
   range(16, 32): 67.31% (count=3.2k)
   range(32, 64): 63.86% (count=2.3k)
   range(64, 128): 53.42% (count=1.1k)
   range(128, 256): 40.00% (count=735)
   range(256, 512): 32.89% (count=672)
   range(512, 1024): 52.83% (count=53)
avg_label_size: 1.2589
avg_pred_size: 1.1258


In [10]:
# evaluate
from numpy import roll
from typet5.utils import pretty_print_dict, pretty_show_dict
from typet5.visualization import visualize_preds_on_code

dmodel.wrapper.args = DecodingArgs(
    sampling_max_tokens=8 * config.ctx_size,
    ctx_args=config.dec_ctx_args(),
    do_sample=True,  # use necleus sampling during training
    top_p=0.9,
)

eval_r = await dmodel.eval_on_data(tk_dataset["train"][1:105:10])
pretty_print_dict(eval_r.accuracies)

compute_preexisting_fdbks: 100%|██████████| 6/6 [00:02<00:00,  2.16it/s]
eval_on_data: 100%|██████████| 180/180 [04:57<00:00,  1.65s/it]


partial_acc (ImNone): 63.89% (count=180)
full_acc (ImNone): 61.11% (count=180)
partial_acc: 65.00% (count=180)
ast_acc: 49.59% (count=244)
full_acc: 60.56% (count=180)
partial_acc_by_cat:
   FuncArg: 58.42% (count=101)
   FuncReturn: 71.23% (count=73)
   GlobalVar: 100.00% (count=3)
   LocalVar: 100.00% (count=3)
partial_acc_by_pos:
   range(0, 1): 90.91% (count=11)
   range(1, 2): 100.00% (count=8)
   range(2, 4): 100.00% (count=12)
   range(4, 8): 69.23% (count=13)
   range(8, 16): 54.55% (count=11)
   range(16, 32): 43.75% (count=16)
   range(32, 64): 50.00% (count=32)
   range(64, 128): 60.94% (count=64)
   range(128, 256): 76.92% (count=13)
avg_label_size: 1.3556
avg_pred_size: 1.1


In [12]:
from typet5.data import TokenizedSrcSet
from typet5.visualization import export_preds_on_code

viz_ds = TokenizedSrcSet(tk_dataset["test"].repos_root, eval_r.final_srcs)
viz_preds = eval_r.final_preds

export_preds_on_code(viz_ds, viz_preds, proj_root() / "caches/DAgger-preds-on-code")

Exporting: 100%|██████████| 11/11 [00:00<00:00, 131.04it/s]
Computing accuracies: 100%|██████████| 11/11 [00:00<00:00, 9128.88it/s]
