In [1]:
# Re-import edited project modules (e.g. `spot.*`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from spot.utils import cst, read_file, write_file, seq_flatten, proj_root, parallel_map_unordered
import os
from spot.type_env import (
    collect_annotations, MypyChecker, AnnotPath, mypy_checker, 
    TypeInfEnv, TypeInfState, TypeInfAction, SelectAnnotations)
from spot.data_prepare import GitRepo
import shutil
import pickle
from pathlib import Path
import pandas as pd
import plotly.express as px
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import time
from typing import *

# Run everything relative to the project root so relative paths like `data/...` resolve.
os.chdir(proj_root())

# Root of the dataset; must be supplied via the `datadir` environment variable.
# Fail loudly here instead of with an opaque `Path(None)` TypeError.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable `datadir` is not set; "
        "point it at the directory that contains `SPOT-data`."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

In [3]:
# Recreate an empty inference directory and copy the example environment code into it.
inference_dir = "data/code_output/inference"
if os.path.exists(inference_dir):
    shutil.rmtree(inference_dir)
# `makedirs` also creates missing parents (e.g. `data/code_output`) on a fresh checkout.
os.makedirs(inference_dir, exist_ok=True)
write_file(f"{inference_dir}/env_code_1.py", read_file("data/code/env_code_1.py"))

In [4]:
# Type-check through the dmypy binary from the project's virtualenv, scoped to the inference dir.
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)

# Environment over the copied example file; every annotation slot is a candidate.
env = TypeInfEnv(
    inf_checker,
    f"{inference_dir}/env_code_1.py",
    select_annotations=SelectAnnotations.select_all_paths,
)

env.reset()
print(env.state)

Daemon started

num_errors: 0
num_to_annot: 11
to_annotate: [AnnotPath('fib.n'), AnnotPath('fib.<return>'), AnnotPath('foo.bar'), AnnotPath('foo.<return>'), AnnotPath('int_add.a'), AnnotPath('int_add.b'), AnnotPath('int_add.<return>'), AnnotPath('int_tripple_add.a'), AnnotPath('int_tripple_add.b'), AnnotPath('int_tripple_add.c'), AnnotPath('int_tripple_add.<return>')]
------------------------ code -------------------------------
# Env example 1: no existing annotations

from typing import Any  # [added by SPOT]
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

def foo(bar):
    return fib(bar)

def int_add(a, b):
    return a + b + "c"

def int_tripple_add(a, b, c):
    return a + b + c



In [5]:
import ipywidgets as widgets

def display_code_sequence(texts: list[str], titles=None):
    """Show each text in its own tab as preformatted HTML.

    Args:
        texts: strings to display, one tab per string.
        titles: optional iterable of tab titles (converted with ``str``);
            defaults to ``0, 1, ..., len(texts) - 1``.

    Returns:
        A ``widgets.Tab`` holding one ``widgets.HTML`` child per text.
    """
    from html import escape

    if titles is None:
        titles = range(len(texts))
    # Escape the text so code containing `<`, `>` or `&` (comparisons, `->`
    # return annotations, ...) is shown verbatim instead of parsed as markup.
    outputs = [
        widgets.HTML(value=f"<pre style='line-height:1.2;'>{escape(s)}</pre>")
        for s in texts
    ]

    tab = widgets.Tab(outputs)
    for i, t in enumerate(titles):
        tab.set_title(i, str(t))
    return tab


def test_policy(env: TypeInfEnv, pi: Callable[[TypeInfState], TypeInfAction]):
    """Roll out policy `pi` on `env` from a fresh reset and visualize it.

    The environment is stepped with `pi(env.state)` until `env.state.to_annot`
    is empty. The string form of every state visited, including the initial one,
    is collected and rendered via `display_code_sequence`, one tab per state.
    """
    env.reset()
    snapshots = []
    while True:
        snapshots.append(str(env.state))
        if len(env.state.to_annot) == 0:
            break
        env.step(pi(env.state))
    return display_code_sequence(snapshots)

In [47]:
# Policy: always annotate the first remaining slot as `str`.
test_policy(env, pi=lambda state: TypeInfAction(state.to_annot[0], cst.Name("str")))

Tab(children=(HTML(value='<pre style=\'line-height:1.2;\'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Ann…

In [48]:
# Policy: always annotate the first remaining slot as `int`.
test_policy(env, pi=lambda state: TypeInfAction(state.to_annot[0], cst.Name("int")))

Tab(children=(HTML(value='<pre style=\'line-height:1.2;\'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Ann…

In [50]:
import random

# Policy: annotate the first remaining slot with a random choice of `int` / `str`.
# A dedicated seeded RNG keeps the rollout reproducible across re-runs of this cell.
policy_rng = random.Random(0)

test_policy(env, lambda s:
    TypeInfAction(s.to_annot[0], cst.Name(policy_rng.choice(["int", "str"]))))

Tab(children=(HTML(value='<pre style=\'line-height:1.2;\'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Ann…

In [58]:
from spot.type_env import test_inference_performance
from spot.utils import parallel_map_unordered

# Benchmark mypy-daemon re-check throughput on a few small repos.
N_REPOS = 2          # how many of `useful_repos` to consider
MAX_LOC = 10000      # skip repos at least this large
# NOTE: the size filter is applied after slicing, so fewer than N_REPOS dirs may be tested.
test_dirs = [r.repo_dir(repos_dir) for r in useful_repos[:N_REPOS] if r.lines_of_code < MAX_LOC]

wall_start = time.perf_counter()
with ProcessPoolExecutor(max_workers=10) as executor:
    results = parallel_map_unordered(test_inference_performance, test_dirs, executor)
wall_time = time.perf_counter() - wall_start

n_checks = sum(r["n_checks"] for r in results)
# Sum of the per-repo "time" values reported by the workers (presumably each
# worker's own elapsed time). With parallel workers this can exceed wall-clock
# time, so n_checks / total_time is a per-worker throughput.
total_time = sum(r["time"] for r in results)
print(f"{n_checks} checks in {total_time} seconds")
if total_time > 0:
    print(f"{n_checks / total_time} checks/second")
else:
    print("no timing data (no repositories were tested)")
print(f"wall-clock time: {wall_time:.2f} seconds")



[A[A

Daemon started
Daemon started


  0%|          | 0/10 [16:34<?, ?it/s]
 20%|██        | 3/15 [09:53<39:32, 197.72s/it]


[A[A

Daemon stopped


  0%|          | 0/10 [16:58<?, ?it/s]
 20%|██        | 3/15 [10:17<41:11, 205.98s/it]


100%|██████████| 2/2 [00:52<00:00, 26.03s/it]

Daemon stopped
479 checks in 69.5186333656311 seconds
6.890239016649167 checks/second



