In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from spot.utils import proj_root, get_data_dir

# Make the directory returned by `proj_root()` the working directory; later
# cells use relative paths such as "caches/inference_spot.pkl".
os.chdir(proj_root())

# Base data directory; used below to locate datasets and model checkpoints.
datadir = get_data_dir()
# Same location that is passed as `repos_root` when the source datasets are loaded.
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from copy import copy
from spot.train import TrainingConfig, TypeCheckArgs

# Configuration of the training run whose checkpoints are loaded below. Its
# fields determine the dataset name, the model names and the context sizes.
config = TrainingConfig(quicktest=False, all_labels=True)
train_R1: bool = True
# Index of the CUDA device to use when CUDA is available.
gpu_id = 1

project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()

# Per-file token budget; the decoding arguments below are expressed in terms of it.
max_tokens_per_file = config.ctx_size

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

# Checkpoint directory name of the R0 model.
r0_model_name = "R0-model--" + config.as_name()

src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    # Reuse the path defined in the first cell instead of rebuilding
    # `datadir / "SPOT-data/repos"` here, so the two cannot drift apart.
    repos_root=repos_dir,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  src_datasets-all_labels-drop_comments


In [3]:
# load trained model
from spot.utils import pickle_load, pickle_dump
from spot.model import ModelWrapper


# Restore the R0 model from the checkpoint directory named after `r0_model_name`.
r0_checkpoint_dir = datadir / f"checkpoints/lit-saved/{r0_model_name}"
r0_wrapper = ModelWrapper.from_pretrained(r0_checkpoint_dir)
# if train_R1:
#     r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
#     r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

# Use the configured GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
r0_wrapper.to(device)
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(ctx_size=4096, left_margin=2048, right_margin=1024), sampling_max_tokens=32768, max_workers=20)


In [4]:
# load the critic

from spot.critic import CriticModel

# Selects which critic checkpoint to load: the "no_feedback-" variant or the regular one.
critic_no_feedback = False
if critic_no_feedback:
    feedback_tag = "no_feedback-"
else:
    feedback_tag = ""
critic_name = "critic-model--" + feedback_tag + config.as_name()

critic_checkpoint_dir = datadir / f"checkpoints/lit-saved/{critic_name}"
critic = CriticModel.load(critic_checkpoint_dir)
critic.to(device)
print("Critic loaded.")

Critic loaded.


In [11]:
# compute results
from spot.decode import sample_candidates, select_candidates_by_type_errors, select_candidates_using_oracle, select_candidates_using_critic
from spot.train import evaluate_model
from spot.model import DatasetPredResult
from spot.utils import pretty_print_dict, run_long_task, PickleCache
from spot.model import CtxArgs, DecodingArgs, ModelSPOT

# Evaluate on a small strided slice (`0:50:10`) of the test split.
testset = src_datasets["test"][0:50:10]

# used for inference
n_samples = 16


def _decoding_args(token_budget: int, **strategy) -> DecodingArgs:
    """Build `DecodingArgs` with the context args and worker count shared by every strategy in this cell.

    `token_budget` becomes `sampling_max_tokens`; `strategy` carries the
    strategy-specific keyword arguments.
    """
    return DecodingArgs(
        sampling_max_tokens=token_budget,
        ctx_args=config.dec_ctx_args(),
        max_workers=28,
        **strategy,
    )


# No sampling and no beams.
greedy_args = _decoding_args(8 * max_tokens_per_file, do_sample=False)

# Sampling with top_p=0.9.
sample_args = _decoding_args(8 * max_tokens_per_file, do_sample=True, top_p=0.9)

# Beam search with `n_samples` beams.
bs_args = _decoding_args(max_tokens_per_file, do_sample=False, num_beams=n_samples)

# Beam search with `n_samples` beams split into as many groups, plus a diversity penalty.
dbs_args = _decoding_args(
    max_tokens_per_file,
    do_sample=False,
    num_beams=n_samples,
    num_beam_groups=n_samples,
    diversity_penalty=0.5,
)

# Evaluation results, keyed by a human-readable strategy label.
results: dict[str, DatasetPredResult] = {}

eval_cache_dir = proj_root() / "caches" / "inference_spot" / r0_model_name
eval_cache = PickleCache(eval_cache_dir)
eval_cache.clear()

# Run under `run_long_task`, which announces the start of the task and reports
# the time taken when it finishes (see the recorded cell output).
with run_long_task("Computing results"):
    # Alternative strategies, kept for reference:
    # r0_wrapper.args = greedy_args
    # results["greedy"] = evaluate_model(r0_wrapper, None, testset, eval_cache=eval_cache, tc_args=tc_args)[0][1]

    # r0_wrapper.args = bs_args
    # results["BS"] = evaluate_model(r0_wrapper, None, testset, eval_cache=eval_cache, tc_args=tc_args)[0][1]

    # r0_wrapper.args = dbs_args
    # results["DBS"] = evaluate_model(r0_wrapper, None, testset, check_in_isolation=False)[0][1]

    # Produce `n_samples` candidates using the beam-search decoding arguments.
    r0_wrapper.args = bs_args
    test_chunks, pred_candidates = sample_candidates(r0_wrapper, testset, n_samples=n_samples)

    # sample_eval = select_candidates_by_type_errors(testset, test_chunks, pred_candidates)
    # results["BS + feedback"] = sample_eval

    # oracle_eval = select_candidates_using_oracle(test_chunks, pred_candidates)
    # results["BS + oracle"] = oracle_eval

    # Let the critic choose among the sampled candidates. The unused
    # `bad_pred_candidates` list and its imports were removed from this cell:
    # nothing here consumed them, and the next cell imports and builds them itself.
    sample_eval = select_candidates_using_critic(critic, testset, test_chunks, pred_candidates, dec_args=greedy_args)
    results["BS + critic"] = sample_eval




Starting task: Computing results


predict:   0%|          | 0/6 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/96 [00:00<?, ?it/s]

Type checking success rate: 100.00%
Average number of feedbacks per check: 4.50


map to_critic_inputs:   0%|          | 0/96 [00:00<?, ?it/s]

predict:   0%|          | 0/96 [00:00<?, ?it/s]

Average number of errors after selection: 4.5
Pushover: (Finished: 'Computing results'.) Time taken: 28.5s


In [6]:
from spot.type_check import parse_type_str
from spot.utils import with_default_workers

# Sanity check for the critic: keep the nesting of `pred_candidates` but replace
# every predicted type with a freshly parsed placeholder type "BadType".
bad_pred_candidates = [
    [[parse_type_str("BadType") for _ in inner] for inner in outer]
    for outer in pred_candidates
]

with with_default_workers(1):
    sample_eval = select_candidates_using_critic(
        critic, testset, test_chunks, bad_pred_candidates, dec_args=greedy_args
    )

# Store the sanity-check outcome under its own key. Writing it to "BS + critic"
# would silently replace the real critic result computed in the previous cell,
# and every later plot/pickle would then report the placeholder predictions.
results["BS + critic (BadType candidates)"] = sample_eval

NameError: name 'pred_candidates' is not defined

In [None]:
from spot.visualization import visualize_dicts

# Show the accuracies of every entry in `results`, titled by its key.
strategy_names = list(results.keys())
strategy_accuracies = [results[name].accuracies for name in strategy_names]
visualize_dicts(strategy_accuracies, titles=strategy_names)

In [None]:
from spot.utils import not_none
from spot.visualization import visualize_preds_on_code

# Inspect the critic-selected predictions together with the critic's own outputs.
critic_eval = results["BS + critic"]

# `extra_info` must be present here; `not_none` enforces that before iterating.
critic_extra_infos = not_none(critic_eval.extra_info)
critic_preds = [info["critic_preds"] for info in critic_extra_infos]
preds_extra = {"critic_preds": critic_preds}

visualize_preds_on_code(critic_eval.chunks, critic_eval.predictions, preds_extra)

In [None]:
from spot.decode import collect_type_errors_from_predictions
from spot.model import DatasetPredResult
from spot.type_check import PythonType, MypyFeedback
from spot.data import SrcDataset


def _collect_errors_with_preds(dataset: SrcDataset, preds_for_chunk, max_workers: int):
    """Type-check `dataset` with predictions derived from each chunk's info.

    The dataset is chunked with the R0 wrapper's tokenizer and context args;
    `preds_for_chunk` maps one entry of `chunks.chunks_info` to that chunk's
    list of predicted types. The resulting predictions are handed to
    `collect_type_errors_from_predictions`, whose return value is returned.
    """
    chunks = dataset.to_chunks(
        r0_wrapper.tokenizer, r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    preds = [preds_for_chunk(info) for info in chunks.chunks_info]
    pred_r = DatasetPredResult(chunks, preds)
    return collect_type_errors_from_predictions(dataset, pred_r, max_workers=max_workers)


def collect_base_errors(dataset: SrcDataset, max_workers: int = 30):
    """Collect the type errors triggered by replacing all labels with `Any`.

    Args:
        dataset: the source dataset to type-check.
        max_workers: forwarded to `collect_type_errors_from_predictions`.
    """
    return _collect_errors_with_preds(
        dataset,
        # One `Any` per label of the chunk.
        lambda info: [PythonType(("Any",)) for _ in info.types],
        max_workers,
    )


def collect_gold_errors(dataset: SrcDataset, max_workers: int = 30):
    """Collect the type errors triggered by ground-truth labels.

    Args:
        dataset: the source dataset to type-check.
        max_workers: forwarded to `collect_type_errors_from_predictions`.
    """
    return _collect_errors_with_preds(dataset, lambda info: info.types, max_workers)

# Total number of labels across the labelled sources of the test set.
labeled_srcs = testset.srcs_with_labels()
num_labels = sum(len(src.types) for src in labeled_srcs)
print("Total number of labels: ", num_labels)

# Type errors per setting: the two reference settings first, then one entry
# for every evaluated strategy in `results`.
type_errors: dict[str, list[tuple[Path, MypyFeedback]]] = {}
type_errors["default"] = collect_base_errors(testset)
type_errors["gold"] = collect_gold_errors(testset)
for strategy, pred_result in results.items():
    type_errors[strategy] = collect_type_errors_from_predictions(
        testset, pred_result, max_workers=30
    )

from spot.visualization import pretty_display_dict, display

# Summarise as "setting -> number of type errors".
error_totals = {setting: len(errors) for setting, errors in type_errors.items()}
display(pretty_display_dict(error_totals))

In [None]:
from spot.utils import pickle_dump

# Persist the computed results and type errors so they can be reloaded later.
inference_snapshot = {"results": results, "type_errors": type_errors}
pickle_dump(Path("caches/inference_spot.pkl"), inference_snapshot)

In [None]:
from spot.utils import pickle_load

# Disabled by default: change the condition to True to restore `results` and
# `type_errors` from "caches/inference_spot.pkl" (the path the preceding cell
# writes to) instead of recomputing them.
# NOTE(review): `pickle_load` presumably unpickles the file; only use it on trusted snapshots — confirm.
if False:
    match pickle_load(Path("caches/inference_spot.pkl")):
        # A mapping containing both keys rebinds `results` and `type_errors`;
        # any other loaded value matches no case and is ignored silently.
        case {"results": results, "type_errors": type_errors}:
            print("Loaded!")

In [None]:
from spot.visualization import seq_flatten, visualize_counts
from ipywidgets import interact
from spot.utils import Counter
from spot.type_check import count_type_frequency

def show_type_distr(recursive: bool, top_k: int, names: list[str] | None = None):
    """Display the distribution of predicted types for selected entries of `results`.

    Args:
        recursive: forwarded to `count_type_frequency`.
        top_k: forwarded to `visualize_counts`.
        names: keys of `results` to include. When omitted, the strategies
            "greedy" and "BS + feedback" are used if they are present;
            if neither is, every entry currently in `results` is shown.
            Previously the two names were hard-coded, which raised `KeyError`
            whenever those strategies had not been evaluated.

    Raises:
        KeyError: if an explicitly requested name is not in `results`.
    """
    if names is None:
        preferred = [name for name in ("greedy", "BS + feedback") if name in results]
        names = preferred or list(results)
    else:
        missing = [name for name in names if name not in results]
        if missing:
            raise KeyError(
                f"No results recorded for {missing}; available: {list(results)}"
            )

    counts = dict[str, Counter]()
    for name in names:
        types = seq_flatten(results[name].predictions)
        counts[name] = count_type_frequency(types, recursive=recursive)

    display(visualize_counts(counts, x_name="Predicted Type", top_k=top_k))


show_type_distr(recursive=True, top_k=15)

In [None]:
from spot.visualization import visualize_counts, visualize_sequence_tabs, display
from spot.utils import Counter

# Error-code frequencies of the "default" setting, used as the baseline.
default_counts = Counter(err.error_code for _, err in type_errors["default"])

# For each selected setting, report error-code counts relative to the baseline.
error_counts: dict[str, Counter] = {}
for name in ["gold"]: #["greedy", "BS + feedback"]:
    relative_counts = Counter(err.error_code for _, err in type_errors[name])
    for code, baseline in default_counts.items():
        # Subtract item by item so codes absent from this setting show up as negative.
        relative_counts[code] -= baseline
    error_counts[name] = relative_counts
display(visualize_counts(error_counts, "Error"))

In [None]:
from spot.visualization import visualize_conf_matrix
# Pass every collected entry of `results` to the confusion-matrix visualisation
# (presumably one view per strategy — confirm against `spot.visualization`).
visualize_conf_matrix(results)

In [None]:
import torch
# Ask PyTorch's CUDA caching allocator to release cached memory blocks that are not in use.
# NOTE(review): memory held by live tensors is not freed by this call; objects such as
# `r0_wrapper` and `critic` would have to be deleted first if they still occupy the GPU.
torch.cuda.empty_cache()