In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from spot.utils import proj_root

os.chdir(proj_root())

datadir = Path(os.getenv("datadir"))
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from copy import copy
from spot.train import TrainingConfig

config = TrainingConfig(quicktest=False, all_labels=True)
train_R1: bool = True
load_trained: bool = False
gpu_id = 1

project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()

max_tokens_per_file = config.ctx_size
dec_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=config.dec_ctx_args(),
    max_workers=20,
)


datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

r0_model_name = "R0-model--" + config.as_name()

src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    repos_root=datadir / "SPOT-data/repos",
    quicktest=config.quicktest,
)
testset = src_datasets["test"]


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  src_datasets-all_labels-drop_comments


In [3]:
# load trained model
from spot.utils import pickle_load, pickle_dump

r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
if train_R1:
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(left=2048, window=1024, right=1024, max_labels=16), sampling_max_tokens=32768, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9, num_beams=None)


In [5]:
# compute results
from spot.decode import sample_then_select
from spot.train import evaluate_model, visualize_accuracies
from copy import deepcopy
from spot.utils import pretty_print_dict

r0_wrapper.args.do_sample = False
r0_wrapper.args.top_p = 0.9

results = evaluate_model(r0_wrapper, None, testset, check_in_isolation=False)

r0_wrapper.args.do_sample = True
sample_eval, sample_stats = sample_then_select(r0_wrapper, testset, n_samples=32)
results.append((deepcopy(r0_wrapper.args), sample_eval))


chunk_srcs_per_file:   0%|          | 0/933 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

Sampling:   0%|          | 0/32 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/52384 [00:00<?, ?it/s]

n_errors: 81733


VBox(children=(Tab(children=(VBox(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2;'>…

In [12]:
# compute results
from spot.decode import sample_then_select
from spot.train import evaluate_model, visualize_accuracies
from copy import deepcopy
from spot.utils import pretty_print_dict

r0_wrapper.args.do_sample = True
r0_wrapper.args.top_p = 0.9

results.extend(evaluate_model(r0_wrapper, None, testset, check_in_isolation=False))

chunk_srcs_per_file:   0%|          | 0/933 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

In [5]:
from spot.decode import collect_type_errors_from_predictions
from spot.model import DatasetPredResult
from spot.type_check import PythonType
from spot.data import SrcDataset


def collect_base_errors(dataset: SrcDataset):
    chunks = dataset.to_chunks(
        r0_wrapper.tokenizer, r0_wrapper.args.ctx_args, tqdm_args={"disable": True}
    )
    dummy_preds = [
        [PythonType(("Any",)) for _ in info.types] for info in chunks.chunks_info
    ]
    pred_r = DatasetPredResult(chunks, dummy_preds)
    return collect_type_errors_from_predictions(dataset, pred_r, max_workers=30)


base_errors = collect_base_errors(testset)
greedy_errors = collect_type_errors_from_predictions(testset, results[0][1], max_workers=30)
sample_errors = collect_type_errors_from_predictions(testset, results[1][1], max_workers=30)

errors0 = len(greedy_errors) - len(base_errors)
errors1 = len(sample_errors) - len(base_errors)


map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

In [13]:
sample_once_errors = collect_type_errors_from_predictions(testset, results[2][1], max_workers=30)
errors2 = len(sample_once_errors) - len(base_errors)

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

In [16]:
from spot.utils import pickle_dump

pickle_dump(Path("inference_spot.pkl"), 
    {"results": results, "base_errors": base_errors, "greedy_errors": greedy_errors, "sample_errors": sample_errors, "sample_once_errors": sample_once_errors})

In [4]:
from spot.utils import pickle_load

match pickle_load(Path("inference_spot.pkl")):
    case {"results": results, "base_errors": base_errors, "greedy_errors": greedy_errors, "sample_errors": sample_errors, "sample_once_errors": sample_once_errors}:
        print("Loaded!")

Loaded!


In [15]:
from spot.visualization import pretty_display_dict, visualize_dicts, display
display(pretty_display_dict(
    {"greedy_errors": errors0, "sampling_errors": errors1, "sample_once_errors": errors2, "base_errors": len(base_errors)}
))
visualize_dicts([r[1].accuracies for r in results])

Tab(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2; font-family: monospace, monospa…

VBox(children=(Tab(children=(Tab(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2; fo…

In [12]:
from spot.train import visualize_conf_matrix
visualize_conf_matrix([x[1] for x in results])

interactive(children=(IntSlider(value=1, description='round', max=1), IntSlider(value=10, continuous_update=Fa…

In [7]:
from spot.visualization import visualize_type_distribution, seq_flatten
from ipywidgets import interact

def show_type_distr(round: int, recursive: bool, top_k: int):
    display(visualize_type_distribution(seq_flatten(results[round][1].predictions), top_k=top_k, recursive=recursive))


interact(show_type_distr, round=(0, 1), recursive=True, top_k=10)

interactive(children=(IntSlider(value=0, description='round', max=1), Checkbox(value=True, description='recurs…

<function __main__.show_type_distr(round: int, recursive: bool, top_k: int)>

In [11]:
from spot.visualization import visualize_counts
from spot.utils import Counter

base_error_codes = Counter(e.error_code for _, e in base_errors)
greedy_error_codes = Counter(e.error_code for _, e in greedy_errors)
sample_error_codes = Counter(e.error_code for _, e in sample_errors)

for k, v in base_error_codes.items():
    greedy_error_codes[k] = greedy_error_codes[k] - v
    sample_error_codes[k] = sample_error_codes[k] - v
display(visualize_counts(greedy_error_codes, "Error", title="Error distribution (Greedy)"))
display(visualize_counts(sample_error_codes, "Error", title="Error distribution (Sampling)"))