In [39]:
%load_ext autoreload
%autoreload 2

import os
import pickle
import shutil
import time
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from typet5.data import GitRepo
from typet5.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from typet5.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Make the project root the working directory; later cells use relative paths
# such as "data/code_output/inference" and ".venv/bin/dmypy".
os.chdir(proj_root())

# Index os.environ directly so a missing `datadir` variable fails with
# `KeyError: 'datadir'` instead of the opaque TypeError raised by `Path(None)`.
datadir = Path(os.environ["datadir"])
repos_dir = datadir / "SPOT-data/repos"

# NOTE: unpickling can execute arbitrary code; only load this file from a
# trusted checkout of the project.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
from typet5.visualization import display_code_sequence


def test_policy(env: TypeInfEnv, pi: Callable[[TypeInfState], TypeInfAction]):
    """Roll out policy `pi` on `env` and visualize every intermediate state.

    The environment is reset, then `pi` is queried once for each entry that
    `env.state.to_annot` holds right after the reset. Any action for which
    `env.step` returns a truthy value is reported as rejected. The string form
    of the state after the reset and after each step is collected and passed
    to `display_code_sequence`, whose result is returned.
    """
    env.reset()
    snapshots = [str(env.state)]
    # One step per annotation pending at the start; the counter is not needed.
    for _ in tqdm(range(len(env.state.to_annot))):
        act = pi(env.state)
        rejected = env.step(act)
        if rejected:
            type_str = env.state.module.code_for_node(act.type)
            print(f"Action rejected: [{str(act.path)}: {type_str}]")
        snapshots.append(str(env.state))

    return display_code_sequence(snapshots)

In [46]:
# Close the checker left over from an earlier run of this notebook, if that
# name is still bound in the kernel.
if "inf_checker" in globals():
    inf_checker.close()

# Start from an empty `inference_dir` (removed first if it exists), then copy
# the example module into it.
inference_dir = Path("data/code_output/inference")
if inference_dir.exists():
    shutil.rmtree(inference_dir)
inference_dir.mkdir(parents=True)
write_file(inference_dir / "env_code_1.py", read_file("data/code/env_code_1.py"))


Daemon stopped


In [55]:
from IPython.display import display

from typet5.type_env import type_inf_env

# Create the checker from the venv's dmypy path and the inference directory.
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)

# Baseline rollout: a trivial policy that always proposes `str` for the first
# pending annotation. `select_all_paths` presumably selects every annotation
# path in the file — TODO confirm.
with type_inf_env(
    inf_checker,
    inference_dir / "env_code_1.py",
    SelectAnnotations.select_all_paths,
    print_mypy_output=False,
) as env:
    display(test_policy(env, lambda s: TypeInfAction(s.to_annot[0], cst.Name("str"))))

Daemon is still alive


  0%|          | 0/11 [00:00<?, ?it/s]

Action rejected: ['fib.n': str]
Action rejected: ['fib.<return>': str]


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [7]:
# Checkpoint directory, resolved under `datadir`.
model_dir = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin/"

import torch
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
# NOTE(review): T5ForConditionalGeneration is already imported just above;
# this second import looks redundant.
from transformers.models.t5 import T5ForConditionalGeneration

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Tokenizer and model are loaded from the same checkpoint directory; the model
# is then moved to `device`.
tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_dir)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    model_dir
).to(device)



In [8]:
import re

from transformers import RobertaTokenizer
from transformers.models.t5 import T5ForConditionalGeneration

from typet5.type_env import apply_annotations
from typet5.utils import join_str


def greedy_policy_from_model(
    model: T5ForConditionalGeneration, tokenizer: RobertaTokenizer
):
    """Build a policy that predicts one annotation at a time.

    The returned `pi(state)` masks only the first pending annotation
    (`state.to_annot[0]`) with `<extra_id_0>`, runs beam search on the resulting
    source, and parses the decoded text as the type expression. If the text
    cannot be parsed, the action falls back to `Any`.

    Args:
        model: seq2seq model used for generation.
        tokenizer: tokenizer used to encode the masked source and decode the
            generated ids.

    Returns:
        A callable mapping a `TypeInfState` to a `TypeInfAction`.
    """

    def pi(s: TypeInfState) -> TypeInfAction:
        path = s.to_annot[0]
        # Annotate with a placeholder identifier first, then swap it for the
        # sentinel token in the rendered code.
        annot = cst.Annotation(cst.Name("SPOT_TYPE_MASK"))
        m1 = apply_annotations(s.module, {path: annot})
        code_input = m1.code.replace("SPOT_TYPE_MASK", "<extra_id_0>")
        dec = model.generate(
            tokenizer.encode(code_input, return_tensors="pt").to(model.device),
            max_length=20,
            num_beams=16,
        )[0]
        pred = tokenizer.decode(dec, skip_special_tokens=True)
        print(f"Prediction for {str(path)}:", pred)
        try:
            type_ex = cst.parse_expression(pred)
        except Exception:
            # Best effort: an unparsable prediction degrades to `Any`.
            print("Failed to parse:", pred)
            type_ex = cst.Name("Any")
        return TypeInfAction(path, type_ex)

    return pi


def planner_policy_from_model(
    model: T5ForConditionalGeneration, tokenizer: RobertaTokenizer
):
    """Build a policy that masks every pending annotation at once.

    The returned `pi(state)` replaces each path in `state.to_annot` with a
    numbered sentinel (`<extra_id_0>`, `<extra_id_1>`, ...), runs beam search
    once, and keeps only the text generated for `<extra_id_0>` as the
    prediction for `state.to_annot[0]`. When that text is missing or cannot be
    parsed as an expression, the action falls back to `Any`.

    Args:
        model: seq2seq model used for generation.
        tokenizer: tokenizer used to encode the masked source and decode the
            generated ids.

    Returns:
        A callable mapping a `TypeInfState` to a `TypeInfAction`.
    """

    def pi(s: TypeInfState) -> TypeInfAction:
        path = s.to_annot[0]
        annot = cst.Annotation(cst.Name("SPOT_TYPE_MASK"))
        m1 = apply_annotations(s.module, {p: annot for p in s.to_annot})
        # Split on the placeholder and re-join with numbered sentinels.
        # NOTE(review): sentinels are numbered in textual order, so this assumes
        # `s.to_annot[0]` is the first masked position in the source — confirm.
        code_segs = m1.code.split("SPOT_TYPE_MASK")
        mask_tokens = [f"<extra_id_{i}>" for i in range(len(s.to_annot))]
        code_input = join_str(code_segs, mask_tokens)
        dec = model.generate(
            tokenizer.encode(code_input, return_tensors="pt").to(model.device),
            max_length=56,
            num_beams=16,
        )[0]
        # Special tokens are kept here because the sentinels delimit the answer.
        dec = tokenizer.decode(dec)
        # Non-greedy capture ending at the first terminator: the next sentinel
        # when several positions are masked, or end-of-sequence when only one
        # is. `search` does not require any text before `<extra_id_0>` or after
        # the terminator.
        mr = re.search(r"<extra_id_0>(.+?)(?:<extra_id_1>|</s>)", dec)
        if mr is not None:
            # Surrounding whitespace is not part of the predicted type.
            type_str = mr.group(1).strip()
            print(f"Prediction for {str(path)}:", type_str)
            try:
                type_ex = cst.parse_expression(type_str)
            except Exception:
                # Best effort: an unparsable prediction degrades to `Any`.
                print("Failed to parse as type:", type_str)
                type_ex = cst.Name("Any")
        else:
            print(f"Failed to parse model output: {dec}")
            type_ex = cst.Name("Any")
        return TypeInfAction(path, type_ex)

    return pi


# Bind both policies to the loaded model and tokenizer for the cells below.
greedy_policy = greedy_policy_from_model(model, tokenizer)
planner_policy = planner_policy_from_model(model, tokenizer)

In [56]:
# Roll out the greedy (one mask per step) policy on the example module and
# display the resulting sequence of states.
with type_inf_env(
    inf_checker,
    inference_dir / "env_code_1.py",
    SelectAnnotations.select_all_paths,
    print_mypy_output=False,
) as env:
    display(test_policy(env, greedy_policy))


  0%|          | 0/11 [00:00<?, ?it/s]

Prediction for 'fib.n': int
Prediction for 'fib.<return>': int
Prediction for 'foo.bar': int
Prediction for 'foo.<return>': int
Prediction for 'int_add.a': int
Prediction for 'int_add.b': int
Action rejected: ['int_add.b': int]
Prediction for 'int_add.<return>': str
Prediction for 'int_tripple_add.a': int
Prediction for 'int_tripple_add.b': Annotated[Any, int]
Prediction for 'int_tripple_add.c': int
Prediction for 'int_tripple_add.<return>': int


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 0\nnum_to_annot: 11\nto_annotate: [Annot…

In [57]:
# Roll out the planner (all positions masked at once) policy on the same
# example module and display the resulting sequence of states.
with type_inf_env(
    inf_checker,
    inference_dir / "env_code_1.py",
    SelectAnnotations.select_all_paths,
    print_mypy_output=False,
) as env:
    display(test_policy(env, planner_policy))


  0%|          | 0/11 [00:00<?, ?it/s]

Prediction for 'fib.n': int
Prediction for 'fib.<return>': int
Prediction for 'foo.bar': int
Prediction for 'foo.<return>': int
Prediction for 'int_add.a': int
Prediction for 'int_add.b': int, c : int
Failed to parse as type: int, c : int
Prediction for 'int_add.<return>': str
Prediction for 'int_tripple_add.a': int
Prediction for 'int_tripple_add.b': Any
Prediction for 'int_tripple_add.c': Any
Prediction for 'int_tripple_add.<return>': str


In [58]:
# Copy a project module into the inference directory under the same file name,
# replacing the text "[added by SPOT]" with "[MASK]" on the way.
# NOTE(review): this notebook imports from `typet5`, but the path below points
# at `src/spot/` — confirm the file still exists at this location.
test_src = proj_root() / "src/spot/utils.py"
write_file(
    inference_dir / test_src.name,
    read_file(test_src).replace("[added by SPOT]", "[MASK]"),
)


In [65]:
# NOTE(review): this rebinds `inf_checker` without calling `close()` on the
# previous instance (the cleanup cell does close it) — confirm this is intended.
inf_checker = MypyChecker(".venv/bin/dmypy", inference_dir)
# Greedy rollout on the copied module. `select_annotated` presumably limits the
# paths to positions that already carry annotations in the source — TODO confirm.
with type_inf_env(
    inf_checker,
    inference_dir / test_src.name,
    SelectAnnotations.select_annotated,
    print_mypy_output=False,
) as env:
    display(test_policy(env, greedy_policy))


Daemon is still alive


  0%|          | 0/12 [00:00<?, ?it/s]

Prediction for 'read_file.<return>': str
Prediction for 'write_file.content': str
Prediction for 'write_file.<return>': None
Prediction for 'proj_root.<return>': Path
Prediction for 'seq_flatten.xs': Sequence[Any]
Prediction for 'seq_flatten.<return>': Sequence[Any]
Action rejected: ['seq_flatten.<return>': Sequence[Any]]
Prediction for 'join_str.segs': Sequence[Any]
Prediction for 'join_str.seps': Sequence[Any]
Prediction for 'join_str.<return>': str
Prediction for 'accuracy_by_labels.y_preds': Sequence[Any]
Prediction for 'accuracy_by_labels.y_true': Sequence[Any]
Prediction for 'accuracy_by_labels.top_k': Optional[int]


Tab(children=(HTML(value="<pre style='line-height:1.2;'>\nnum_errors: 1\nnum_to_annot: 12\nto_annotate: [Annot…

In [None]:
from typet5.type_env import test_inference_performance
from typet5.utils import parallel_map_unordered

# Benchmark on at most the first two useful repos, keeping only those with
# fewer than 10,000 lines of code.
test_dirs = [r.repo_dir(repos_dir) for r in useful_repos[:2] if r.lines_of_code < 10000]
with ProcessPoolExecutor(max_workers=10) as executor:
    results = parallel_map_unordered(test_inference_performance, test_dirs, executor)
# Totals are summed over the per-repo results.
# NOTE(review): with parallel workers the summed "time" is presumably not
# wall-clock time — confirm what the reported rate is meant to express.
n_checks = sum(r["n_checks"] for r in results)
total_time = sum(r["time"] for r in results)
print(f"{n_checks} checks in {total_time} seconds")
# `total_time` is 0 when no repo passed the filter; avoid ZeroDivisionError.
if total_time > 0:
    print(f"{n_checks / total_time} checks/second")
else:
    print("No time recorded; checks/second is undefined.")




[A[A

Daemon started
Daemon started


  0%|          | 0/10 [16:34<?, ?it/s]
 20%|██        | 3/15 [09:53<39:32, 197.72s/it]


[A[A

Daemon stopped


  0%|          | 0/10 [16:58<?, ?it/s]
 20%|██        | 3/15 [10:17<41:11, 205.98s/it]


100%|██████████| 2/2 [00:52<00:00, 26.03s/it]

Daemon stopped
479 checks in 69.5186333656311 seconds
6.890239016649167 checks/second





In [8]:
# Scratch check: each call builds a fresh iterator, so both expressions yield
# the first inserted key instead of advancing through the dict.
ex_dict = {"c": 1, "b": 2, "0": 5}
next(iter(ex_dict))
next(iter(ex_dict))


'c'