In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

from spot.utils import proj_root, get_data_dir

# Make the project root the working directory so that relative paths used later in
# the notebook (e.g. "inference_spot.pkl") resolve against it.
os.chdir(proj_root())

# Base data directory; later cells load the source datasets and model checkpoints from it.
datadir = get_data_dir()
# Same sub-path that the dataset-loading cell passes as `repos_root`.
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from copy import copy
from spot.train import TrainingConfig

# Experiment switches. The remaining names in this cell are derived from `config`.
config = TrainingConfig(quicktest=False, all_labels=True)
# When True, the model-loading cell also reads the R1 source datasets from extra.pkl.
train_R1: bool = True
# NOTE(review): not read by any cell visible in this notebook — confirm it is still needed.
load_trained: bool = False
# CUDA device index used by the model-loading cell when CUDA is available.
gpu_id = 1

project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()

max_tokens_per_file = config.ctx_size

datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

# Name of the checkpoint directory the R0 model is loaded from.
r0_model_name = "R0-model--" + config.as_name()

src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    # Reuse the path defined next to `datadir` instead of re-spelling the literal,
    # so the two cannot drift apart.
    repos_root=repos_dir,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  src_datasets-all_labels-drop_comments


In [3]:
# load trained model
from spot.utils import pickle_load, pickle_dump
from spot.model import ModelWrapper


# Restore the R0 model from its saved checkpoint directory.
r0_wrapper = ModelWrapper.from_pretrained(datadir / f"checkpoints/lit-saved/{r0_model_name}")

if train_R1:
    # extra.pkl lives inside the checkpoint directory and holds the R1 source datasets.
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

# Use the configured GPU when CUDA is present; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
r0_wrapper.to(device)
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(left=2048, window=1024, right=1024, max_labels=16), sampling_max_tokens=32768, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9, num_beams=None, num_beam_groups=None, diversity_penalty=None)


In [4]:
# compute results
from spot.decode import sample_then_select
from spot.train import evaluate_model
from spot.model import DatasetPredResult
from copy import deepcopy
from spot.utils import pretty_print_dict, run_long_task, PickleCache
from spot.model import CtxArgs, DecodingArgs, ModelSPOT

# used for inference
n_samples = 16
# The greedy and sampling presets get 8x the per-file token budget; the beam presets use 1x.
long_token_budget = 8 * max_tokens_per_file

# Deterministic decoding without beams.
greedy_args = DecodingArgs(
    ctx_args=config.dec_ctx_args(),
    sampling_max_tokens=long_token_budget,
    max_workers=20,
    do_sample=False,
)

# Sampling with top_p=0.9.
sample_args = DecodingArgs(
    ctx_args=config.dec_ctx_args(),
    sampling_max_tokens=long_token_budget,
    max_workers=20,
    do_sample=True,
    top_p=0.9,
)

# Deterministic decoding with `n_samples` beams.
bs_args = DecodingArgs(
    ctx_args=config.dec_ctx_args(),
    sampling_max_tokens=max_tokens_per_file,
    max_workers=20,
    do_sample=False,
    num_beams=n_samples,
)

# Same as `bs_args`, plus one beam group per beam and a diversity penalty.
dbs_args = DecodingArgs(
    ctx_args=config.dec_ctx_args(),
    sampling_max_tokens=max_tokens_per_file,
    max_workers=20,
    do_sample=False,
    num_beams=n_samples,
    num_beam_groups=n_samples,
    diversity_penalty=0.5,
)

testset = src_datasets["test"]
results = dict[str, DatasetPredResult]()

eval_cache = PickleCache(proj_root() / "inference_spot-eval_cache")
# eval_cache.clear()


def _evaluate_with(args: DecodingArgs, **eval_kwargs):
    """Install `args` on `r0_wrapper`, evaluate on `testset`, and return entry [0][1] of the result."""
    r0_wrapper.args = args
    return evaluate_model(
        r0_wrapper, None, testset, check_in_isolation=False, **eval_kwargs
    )[0][1]


with run_long_task("Computing results"):
    # Only the greedy run is given the evaluation cache.
    results["greedy"] = _evaluate_with(greedy_args, eval_cache=eval_cache)

    results["BS"] = _evaluate_with(bs_args)

    # results["DBS"] = _evaluate_with(dbs_args)

    # Runs with whatever args were installed last on `r0_wrapper` (here: `bs_args`).
    sample_eval = sample_then_select(r0_wrapper, testset, n_samples=n_samples)
    results["BS + feedback"] = sample_eval


chunk_srcs_per_file:   0%|          | 0/933 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

chunk_srcs_per_file:   0%|          | 0/933 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

predict:   0%|          | 0/1637 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/26192 [00:00<?, ?it/s]

Pushover: (Finished: 'Computing results'.) Time taken: 9201.5s


In [5]:
from spot.visualization import visualize_dicts

# Visualize the accuracy dict of every entry in `results`, titled by its key.
visualize_dicts(
    [pred_result.accuracies for pred_result in results.values()],
    titles=list(results),
)

VBox(children=(Tab(children=(Tab(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2; fo…

In [10]:
from spot.decode import collect_type_errors_from_predictions
from spot.model import DatasetPredResult
from spot.type_check import PythonType, MypyFeedback
from spot.data import SrcDataset


def collect_base_errors(dataset: SrcDataset):
    """Collect the type errors reported when every label in `dataset` is predicted as `Any`.

    The dataset is chunked with the tokenizer and context args currently set on the
    module-level `r0_wrapper`. A placeholder `Any` prediction is built for each entry
    of every chunk's `types`, and those predictions are passed to
    `collect_type_errors_from_predictions`, whose result is returned unchanged.
    """
    chunked = dataset.to_chunks(
        r0_wrapper.tokenizer,
        r0_wrapper.args.ctx_args,
        tqdm_args={"disable": True},
    )
    any_preds = []
    for chunk_info in chunked.chunks_info:
        # One fresh `Any` prediction per label in this chunk.
        any_preds.append([PythonType(("Any",)) for _ in chunk_info.types])
    baseline_result = DatasetPredResult(chunked, any_preds)
    return collect_type_errors_from_predictions(dataset, baseline_result, max_workers=30)

# Type-checker feedback per strategy; "default" is the all-`Any` baseline.
type_errors = dict[str, list[tuple[Path, MypyFeedback]]]()
type_errors["default"] = collect_base_errors(testset)
type_errors.update(
    (name, collect_type_errors_from_predictions(testset, pred_result, max_workers=30))
    for name, pred_result in results.items()
)

from spot.visualization import pretty_display_dict, display
# Show only the number of errors collected for each entry.
error_totals = {name: len(errors) for name, errors in type_errors.items()}
display(pretty_display_dict(error_totals))

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/50 [00:00<?, ?it/s]

Tab(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2; font-family: monospace, monospa…

In [7]:
from spot.utils import pickle_dump

# Persist the prediction results and the collected type errors together in one pickle.
saved_state = {"results": results, "type_errors": type_errors}
pickle_dump(Path("inference_spot.pkl"), saved_state)

In [11]:
from spot.utils import pickle_load

match pickle_load(Path("inference_spot.pkl")):
    case {"results": results, "type_errors": type_errors}:
        print("Loaded!")

Loaded!


In [42]:
from spot.visualization import seq_flatten, visualize_counts
from ipywidgets import interact
from spot.utils import Counter
from spot.type_check import count_type_frequency

def show_type_distr(recursive: bool, top_k: int):
    """Display the predicted-type frequencies of the "greedy" and "BS + feedback" results.

    Args:
        recursive: forwarded to `count_type_frequency`.
        top_k: forwarded to `visualize_counts`.
    """
    counts: dict[str, Counter] = {
        name: count_type_frequency(
            seq_flatten(results[name].predictions), recursive=recursive
        )
        for name in ("greedy", "BS + feedback")
    }
    display(visualize_counts(counts, x_name="Predicted Type", top_k=top_k))


# `interact` returns the wrapped function; the trailing semicolon stops the notebook
# from echoing its repr below the widget.
interact(show_type_distr, recursive=True, top_k=15);

interactive(children=(Checkbox(value=True, description='recursive'), IntSlider(value=15, description='top_k', …

<function __main__.show_type_distr(recursive: bool, top_k: int)>

In [41]:
from spot.visualization import visualize_counts, visualize_sequence_tabs, display
from spot.utils import Counter
import ipywidgets as widgets

# Error-code histogram of the "default" entry of `type_errors`.
default_counts = Counter(fb.error_code for _, fb in type_errors["default"])

# For each strategy, store its per-code counts minus the default counts: a positive
# value means more errors of that code than the default run, a negative value fewer.
error_counts = dict[str, Counter]()
for strategy in ("greedy", "BS + feedback"):
    delta = Counter(fb.error_code for _, fb in type_errors[strategy])
    for code, baseline_count in default_counts.items():
        delta[code] -= baseline_count
    error_counts[strategy] = delta
display(visualize_counts(error_counts, "Error"))

In [28]:
import plotly.express as px
import pandas as pd

greedy_counts = Counter(e.error_code for _, e in type_errors["greedy"])
# Plot the ten most frequent error codes of the greedy run. Slicing `.keys()` directly
# would pick whichever ten codes happened to be seen first. The sort is stable, so
# ties keep their first-seen order.
keys = sorted(greedy_counts, key=greedy_counts.get, reverse=True)[:10]
data = pd.DataFrame(
    {
        "Error": keys,
        "Greedy": [greedy_counts.get(k, 0) for k in keys],
        # `default_counts` comes from the error-count cell above.
        "Default": [default_counts.get(k, 0) for k in keys],
    }
)

px.bar(data, x="Error", y=["Greedy", "Default"])