In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from spot.utils import cst, read_file, write_file, seq_flatten, proj_root
import os
from spot.type_env import (
    collect_annotations, MypyChecker, AnnotPath, mypy_checker, 
    TypeInfEnv, TypeInfState, TypeInfAction, SelectAnnotations)
from spot.data_prepare import GitRepo
import shutil
import pickle
from pathlib import Path
import pandas as pd
import plotly.express as px
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import time
from typing import *

# Work from the project root so that the relative paths used by later cells
# (e.g. "data/..." and ".venv/bin/dmypy") resolve.
os.chdir(proj_root())

# `os.getenv` returns None when the variable is unset, and `Path(None)` then
# fails with an unhelpful TypeError; fail early with an actionable message.
if os.getenv("datadir") is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "it must point to the directory that contains 'SPOT-data'."
    )
datadir = Path(os.environ["datadir"])
repos_dir = datadir / "SPOT-data/repos"

# NOTE: unpickling can execute arbitrary code; only load a `useful_repos.pkl`
# that comes from a trusted source.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

In [9]:
import ipywidgets as widgets
import html

def display_code_sequence(texts: list[str], titles=None):
    """Show each text as an HTML-escaped ``<pre>`` block in its own tab.

    Args:
        texts: The strings to display, one per tab.
        titles: Optional iterable of tab titles; every entry is converted with
            ``str``. Defaults to the tab indices ``0 .. len(texts) - 1``.

    Returns:
        The ``widgets.Tab`` holding one ``widgets.HTML`` child per text.
    """
    pages = []
    for text in texts:
        # Escape the text so code containing `<`, `>` or `&` renders literally.
        markup = f"<pre style='line-height:1.2;'>{html.escape(text)}</pre>"
        pages.append(widgets.HTML(value=markup))

    tabs = widgets.Tab(pages)
    labels = range(len(texts)) if titles is None else titles
    for index, label in enumerate(labels):
        tabs.set_title(index, str(label))
    return tabs


def test_policy(env: TypeInfEnv, pi: Callable[[TypeInfState], TypeInfAction]):
    """Roll out policy ``pi`` on ``env`` until nothing is left to annotate.

    The environment is reset first. On every iteration the policy is asked for
    an action for the current state and that action is passed to ``env.step``;
    a truthy result from ``step`` is reported as a rejected action.

    Args:
        env: The environment to drive.
        pi: Maps the current state to the next action.

    Returns:
        The widget built by ``display_code_sequence`` from the string form of
        the initial state and of the state after every step.
    """
    env.reset()
    snapshots = [str(env.state)]

    while len(env.state.to_annot) > 0:
        action = pi(env.state)
        rejected = env.step(action)
        if rejected:
            # The type is rendered through the module of the post-step state.
            rejected_type = env.state.module.code_for_node(action.type)
            print(f"Action rejected: {rejected_type} @ {str(action.path)}")
        snapshots.append(str(env.state))

    return display_code_sequence(snapshots)

In [4]:
# Start from an empty scratch directory: delete `inference_dir` if it already
# exists, then recreate it (creating missing parent directories as needed).
inference_dir = Path("data/code_output/inference")
if inference_dir.exists():
    shutil.rmtree(inference_dir)
inference_dir.mkdir(parents=True)
# Copy the example module into the scratch directory.
write_file(inference_dir/"env_code_1.py", read_file("data/code/env_code_1.py"))

In [5]:
# Create the checker for the scratch directory (".venv/bin/dmypy" is
# presumably the mypy daemon executable -- confirm against `MypyChecker`).
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)
# Environment over the copied example, using the `select_all_paths` selector.
env = TypeInfEnv(inf_checker, inference_dir/"env_code_1.py", select_annotations=SelectAnnotations.select_all_paths)
env.reset()
# Print the initial state.
print(env.state)

Daemon started

num_errors: 0
num_to_annot: 11
to_annotate: [AnnotPath('fib.n'), AnnotPath('fib.<return>'), AnnotPath('foo.bar'), AnnotPath('foo.<return>'), AnnotPath('int_add.a'), AnnotPath('int_add.b'), AnnotPath('int_add.<return>'), AnnotPath('int_tripple_add.a'), AnnotPath('int_tripple_add.b'), AnnotPath('int_tripple_add.c'), AnnotPath('int_tripple_add.<return>')]
------------------------ code -------------------------------
# Env example 1: no existing annotations

from typing import Any  # [added by SPOT]
def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

def foo(bar):
    return fib(bar)

def int_add(a, b):
    return a + b + "c"

def int_tripple_add(a, b, c):
    return a + b + c



In [10]:
# Baseline: always propose `str` for the first path still left to annotate.
test_policy(env, lambda s: TypeInfAction(s.to_annot[0], cst.Name("str")))

Action rejected: str @ 'fib.n'
Action rejected: str @ 'fib.<return>'


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [11]:
def int_policy(s: "TypeInfState") -> "TypeInfAction":
    """Baseline policy: propose ``int`` for the first path left to annotate.

    Args:
        s: Current state; only ``s.to_annot[0]`` is read.

    Returns:
        A ``TypeInfAction`` for that path with the type expression ``int``.
    """
    return TypeInfAction(s.to_annot[0], cst.Name("int"))
# Roll out `int_policy` on `env`; the returned widget is the cell's output.
test_policy(env, int_policy)

Action rejected: int @ 'int_add.b'


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [12]:
# Checkpoint directory (relative path) used for both tokenizer and model.
model_dir = "./checkpoints/saved/SPOT-CodeT5-fine-tune/checkpoint-1500"

import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration, DataCollatorForSeq2Seq
# NOTE(review): `T5ForConditionalGeneration` is already imported on the line
# above, and `DataCollatorForSeq2Seq` is not used in this cell.
from transformers.models.t5 import T5ForConditionalGeneration

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and the model from `model_dir`; the model is moved to `device`.
tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_dir)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)



In [13]:
from transformers.models.t5 import T5ForConditionalGeneration
from transformers import RobertaTokenizer
from spot.type_env import apply_annotations
from spot.utils import join_str
import re

def greedy_policy_from_model(model: T5ForConditionalGeneration, tokenizer: RobertaTokenizer):
    """Build a policy that asks ``model`` for one annotation at a time.

    The returned policy takes the first path in ``s.to_annot``, applies a
    placeholder annotation at that path, swaps the placeholder for the
    ``<extra_id_0>`` token and lets the model generate with beam search
    (16 beams, ``max_length=20``). The decoded text (special tokens skipped) is
    parsed as an expression; if parsing fails, ``Any`` is proposed instead.

    Args:
        model: Model used for generation; inputs are moved to ``model.device``.
        tokenizer: Tokenizer used to encode the input and decode the output.

    Returns:
        A function mapping a ``TypeInfState`` to a ``TypeInfAction``.
    """
    mask_name = "SPOT_TYPE_MASK"

    def pi(s: TypeInfState) -> TypeInfAction:
        path = s.to_annot[0]
        placeholder = cst.Annotation(cst.Name(mask_name))
        masked_module = apply_annotations(s.module, {path: placeholder})
        code_input = masked_module.code.replace(mask_name, "<extra_id_0>")

        input_ids = tokenizer.encode(code_input, return_tensors='pt').to(model.device)
        generated = model.generate(input_ids, max_length=20, num_beams=16)
        # Only the first returned sequence is used.
        pred = tokenizer.decode(generated[0], skip_special_tokens=True)
        print(f"Prediction for {str(path)}:", pred)

        try:
            type_ex = cst.parse_expression(pred)
        except Exception:
            # Unparseable prediction: fall back to `Any`.
            print("Failed to parse:", pred)
            type_ex = cst.Name("Any")
        return TypeInfAction(path, type_ex)

    return pi

def planner_policy_from_model(model: T5ForConditionalGeneration, tokenizer: RobertaTokenizer):
    """Build a policy that masks every remaining annotation in a single query.

    The returned policy applies a placeholder annotation at every path in
    ``s.to_annot``, replaces the k-th placeholder occurrence in the code with
    ``<extra_id_k>`` and generates with beam search (16 beams,
    ``max_length=56``). Only the text following ``<extra_id_0>`` in the decoded
    output is used, and it is proposed for ``s.to_annot[0]``.

    ``Any`` is proposed when the decoded output does not match the expected
    pattern or when the extracted text cannot be parsed as an expression.

    Args:
        model: Model used for generation; inputs are moved to ``model.device``.
        tokenizer: Tokenizer used to encode the input and decode the output.

    Returns:
        A function mapping a ``TypeInfState`` to a ``TypeInfAction``.
    """
    def pi(s: TypeInfState) -> TypeInfAction:
        path = s.to_annot[0]
        annot = cst.Annotation(cst.Name("SPOT_TYPE_MASK"))
        masked_module = apply_annotations(s.module, {p: annot for p in s.to_annot})
        # NOTE(review): assumes the first placeholder in the code belongs to
        # `s.to_annot[0]`, i.e. `to_annot` follows source order -- confirm.
        code_segs = masked_module.code.split("SPOT_TYPE_MASK")
        mask_tokens = [f"<extra_id_{i}>" for i in range(len(s.to_annot))]
        code_input = join_str(code_segs, mask_tokens)

        output_ids = model.generate(
            tokenizer.encode(code_input, return_tensors='pt').to(model.device),
            max_length=56,
            num_beams=16,
        )[0]
        # Special tokens are kept because the patterns below rely on them.
        decoded = tokenizer.decode(output_ids)

        # The first prediction ends at `<extra_id_1>`, or at `</s>` when no
        # further mask token follows it.
        mr = re.match(r".+<extra_id_0>(.+)<extra_id_1>.+", decoded)
        if mr is None:
            mr = re.match(r".+<extra_id_0>(.+)</s>", decoded)
        if mr is None:
            print(f"Failed to parse model output: {decoded}")
            return TypeInfAction(path, cst.Name("Any"))

        type_str = mr.group(1)
        print(f"Prediction for {str(path)}:", type_str)
        try:
            type_ex = cst.parse_expression(type_str)
        except Exception:
            # Fall back to `Any` instead of re-raising: the re-raise made this
            # fallback unreachable and differed from the greedy policy.
            print("Failed to parse as type:", type_str)
            type_ex = cst.Name("Any")
        return TypeInfAction(path, type_ex)

    return pi

# Instantiate both model-backed policies with the loaded model and tokenizer.
greedy_policy = greedy_policy_from_model(model, tokenizer)
planner_policy = planner_policy_from_model(model, tokenizer)

In [14]:
# Roll out the one-mask-per-query policy on `env`.
test_policy(env, greedy_policy)

Prediction for 'fib.n': Any
Prediction for 'fib.<return>': int
Prediction for 'foo.bar': Any
Prediction for 'foo.<return>': int
Prediction for 'int_add.a': int
Prediction for 'int_add.b': int
Action rejected: int @ 'int_add.b'
Prediction for 'int_add.<return>': str
Prediction for 'int_tripple_add.a': int
Prediction for 'int_tripple_add.b': Any
Prediction for 'int_tripple_add.c': Any
Prediction for 'int_tripple_add.<return>': str


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [15]:
# Roll out the all-masks-per-query policy on `env`.
test_policy(env, planner_policy)

Prediction for 'fib.n': int
Prediction for 'fib.<return>': int
Prediction for 'foo.bar': Any
Prediction for 'foo.<return>': int
Prediction for 'int_add.a': int
Prediction for 'int_add.b': int
Action rejected: int @ 'int_add.b'
Prediction for 'int_add.<return>': str
Prediction for 'int_tripple_add.a': int
Prediction for 'int_tripple_add.b': Any
Prediction for 'int_tripple_add.c': Any
Prediction for 'int_tripple_add.<return>': str


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [88]:
# Copy `src/spot/utils.py` into the scratch directory as a second input file.
test_src = "src/spot/utils.py"
write_file(inference_dir/"utils.py", read_file(test_src))
# Close the current checker.
inf_checker.close()

Daemon stopped


In [93]:
from spot.type_env import type_inf_env


# New checker over the scratch directory.
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)
# Step `int_policy` on `utils.py` until no paths remain.
# NOTE(review): in the recorded output this cell stopped with an
# AssertionError raised from `env.step` for the action on
# 'write_file.<return>'.
with type_inf_env(inf_checker, inference_dir / "utils.py", SelectAnnotations.select_annotated, check_any=True, print_mypy_output=True) as env:
    while env.state.to_annot:
        env.step(int_policy(env.state))

Daemon is still alive


{"validate_meta_time": 9.5367431640625e-07, "files_parsed": 1, "modules_parsed": 1, "stubs_parsed": 0, "parse_time": 0.0006725788116455078, "find_module_time": 0.0019850730895996094, "find_module_calls": 17, "semanal_time": 0.0012164115905761719, "typecheck_time": 0.0046253204345703125, "finish_passes_time": 0.000457763671875, "load_fg_deps_time": 1.430511474609375e-06, "update_isolated_time": 0.009349346160888672, "propagate_time": 0.00028395652770996094, "find_changes_time": 0.003499269485473633, "fg_update_time": 0.009707450866699219, "files_changed": 1}Daemon is up and running
{"/home/jiayi/Projects/SPOT/data/code_output/inference/env_code_1.py": [1651261265.6999123, 399, "2168858f3ebddd243021eb78ac3ff1d85a19c6a7c4852a5afbb80053a4c69f03"], "/home/jiayi/Projects/SPOT/data/code_output/inference/utils.py": [1651264940.4920394, 1210, "0470849f22faa12bb878d718caf653f6458dbc388c915659a1dec327bd2fe5fd"]}{"find_changes_time": 0.0045735836029052734, "fg_update_time": 9.298324584960938e-06, 

Traceback (most recent call last):
  File "/home/jiayi/Projects/SPOT/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3369, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_19000/2311706083.py", line 7, in <cell line: 5>
    env.step(int_policy(env.state))
  File "/home/jiayi/Projects/SPOT/src/spot/type_env.py", line 446, in step
    assert out.num_errors == state.num_errors, (
AssertionError: Adding Any should not trigger more type errors.
original errors: 0, new errors: 3
action: TypeInfAction('write_file.<return>' : int)
mypy output: utils.py:28: error: Incompatible return value type (got "str", expected "int")
utils.py:31: error: Missing return statement
utils.py:34: error: Argument 1 to "write" of "TextIOBase" has incompatible type "int"; expected "str"
Found 3 errors in 1 file (checked 2 source files)

---------Code---------
 from typing import Any  # [added by SPOT]
from concurrent.futures import ProcessPoolExecut

In [30]:
from spot.type_env import type_inf_env

# New checker over the scratch directory.
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)
# Roll out a policy on `utils.py` and keep the resulting widget in `test_r`.
# NOTE(review): in the recorded output this cell stopped with an
# AssertionError for the action on 'write_file.content'.
with type_inf_env(inf_checker, inference_dir / "utils.py", check_any=True, print_mypy_output=True) as env:
    # Alternative policy to try instead of `int_policy`:
    # test_r=test_policy(env, planner_policy)
    test_r=test_policy(env, int_policy)
# Display the widget as the cell's output.
test_r

Daemon started
action:  TypeInfAction('read_file.<return>' : int)
mypy output: Success: no issues found in 1 source file

action:  TypeInfAction('write_file.content' : int)
mypy output: utils.py:28: error: Incompatible return value type (got "str", expected "int")
utils.py:34: error: Argument 1 to "write" of "TextIOBase" has incompatible type "int"; expected "str"
Found 2 errors in 1 file (checked 1 source file)



AssertionError: Adding Any should not trigger more type errors.
original errors: 0, new errors: 1
action: TypeInfAction('write_file.content' : int)
mypy output: utils.py:28: error: Incompatible return value type (got "str", expected "int")
Found 1 error in 1 file (checked 1 source file)

---------Code---------
 from typing import Any  # [added by SPOT]
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import (
    Callable,
    Sequence,
    Optional,
    TypeVar,
    Union,
    Generator,
)
from typing import cast
import libcst as cst
import os
from pathlib import Path
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map


class SpecialNames:
    Return = "<return>"
    Missing = "<missing>"
    Lambda = "<lambda>"


def read_file(path) -> int:
    """read file content as string."""
    with open(path, "r") as f:
        return f.read()


def write_file(path, content: Any) -> Any:
    """write content to file."""
    with open(path, "w") as f:
        f.write(content)


def proj_root() -> Any:
    return Path(__file__).parent.parent.parent


T1 = TypeVar("T1")
T2 = TypeVar("T2")


def seq_flatten(xs: Any) -> Any:
    return (item for sublist in xs for item in sublist)


def join_str(segs: Any, seps: Any) -> Any:
    assert len(seps) == len(segs) - 1, f"{len(seps)} != {len(segs) - 1}"
    all_segs = [segs[0]]
    for s, sep in zip(segs[1:], seps):
        all_segs.append(sep)
        all_segs.append(s)
    return "".join(all_segs)


In [58]:
from spot.type_env import test_inference_performance
from spot.utils import parallel_map_unordered

# Benchmark the repositories among the first two that have fewer than 10000
# lines of code.
test_dirs = [r.repo_dir(repos_dir) for r in useful_repos[:2] if r.lines_of_code < 10000]
with ProcessPoolExecutor(max_workers=10) as executor:
    # Materialize the results: they are traversed once per aggregate below.
    results = list(parallel_map_unordered(test_inference_performance, test_dirs, executor))
n_checks = sum(r["n_checks"] for r in results)
# NOTE(review): this adds up the per-result `time` values, so with several
# workers it is presumably summed worker time rather than wall-clock time.
total_time = sum(r["time"] for r in results)
print(f"{n_checks} checks in {total_time} seconds")
# The size filter can leave nothing to benchmark; avoid dividing by zero then.
if total_time > 0:
    print(f"{n_checks / total_time} checks/second")
else:
    print("No timing data collected; skipping checks/second.")



[A[A

Daemon started
Daemon started


  0%|          | 0/10 [16:34<?, ?it/s]
 20%|██        | 3/15 [09:53<39:32, 197.72s/it]


[A[A

Daemon stopped


  0%|          | 0/10 [16:58<?, ?it/s]
 20%|██        | 3/15 [10:17<41:11, 205.98s/it]


100%|██████████| 2/2 [00:52<00:00, 26.03s/it]

Daemon stopped
479 checks in 69.5186333656311 seconds
6.890239016649167 checks/second



