In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

os.chdir(proj_root())

# Root data directory: it must contain `SPOT-data/` (and holds `checkpoints/`).
# Fail early with an actionable message instead of the opaque `TypeError`
# that `Path(None)` raises when the variable is unset (an empty value would
# otherwise silently resolve to the current directory).
_datadir_env = os.getenv("datadir")
if not _datadir_env:
    raise RuntimeError(
        "The `datadir` environment variable is not set; "
        "it must point to the directory that contains `SPOT-data/`."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

# NOTE: pickle.load can execute arbitrary code; only load these files from
# trusted, project-generated locations.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Repo split, indexed by "train" / "valid" / "test" in later cells.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [4]:
import torch

from spot.model import ModelSPOT, TokenizerSPOT

# R0: a saved SPOT-CodeT5 checkpoint; its tokenizer is reused for both models.
r0_model_path = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"
# R1: starts from the public pretrained CodeT5-base weights.
r1_model_path = "Salesforce/codet5-base"

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(r0_model_path)
r0_model: ModelSPOT = ModelSPOT.from_pretrained(r0_model_path).to(device)
r1_model: ModelSPOT = ModelSPOT.from_pretrained(r1_model_path).to(device)



In [5]:
from IPython.display import display, display_pretty

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs, CtxArgs

# Smoke-test mode: prefixes the run name with "test-" and trains/validates on
# only a handful of repos.
test_run = True
test_tag = "test-" if test_run else ""
model_name = f"{test_tag}SPOT-DAgger-r1"

args = DAggerTrainerArgs(
    output_dir=datadir / "checkpoints" / model_name,
    max_epochs=3,
    skip_first_eval=False,
    repos_group_size=16,
    ctx_args=CtxArgs(ctx_size=512, ctx_margin=128, types_in_ctx=False),
    sampling_batch_size=256,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)

trainer = DAggerTrainer(r0_model, r1_model, tokenizer, args)

# Resolve the on-disk directory of every repo in the train and valid splits.
train_repos, valid_repos = (
    [repo.repo_dir(repos_dir) for repo in repos_split[split]]
    for split in ("train", "valid")
)
if test_run:
    train_repos, valid_repos = train_repos[:10], valid_repos[:5]

In [7]:
wandb.init(project=model_name, config=args, dir=str(datadir))

# Reported to W&B when the run is closed; only cleared once training completes,
# so crashes and interrupts show up as failed runs instead of dangling ones.
exit_code = 1
try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    wandb.alert(title="Training stopped due to exception", text=f"In {model_name}, exception: {e}")
    raise  # bare raise keeps the original traceback untouched
else:
    wandb.alert(title="Training finished", text=f"{model_name} has finished.")
    wandb.log({"time_stats": trainer.timer.total_times()})
    exit_code = 0
finally:
    # Always close the run, including on KeyboardInterrupt.
    wandb.finish(exit_code=exit_code)

VBox(children=(Label(value='0.000 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
R0_full_acc,▁
R0_n_labels,▁
R0_partial_acc,▁
R0_partial_acc_no_any,▁
R1_full_acc,▁
R1_n_labels,▁
R1_partial_acc,▁
R1_partial_acc_no_any,▁
epoch,▁▁
step,▁▁

0,1
R0_full_acc,0.53049
R0_n_labels,3214.0
R0_partial_acc,0.61045
R0_partial_acc_no_any,0.61814
R1_full_acc,0.0498
R1_n_labels,3213.0
R1_partial_acc,0.05478
R1_partial_acc_no_any,0.03577
epoch,0.0
step,0.0


DAgger Training:   0%|          | 0/30 [00:00<?, ?it/s]

parsing and masking sources:   0%|          | 0/351 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/351 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1353 [00:00<?, ?it/s]

predict:   0%|          | 0/730 [00:00<?, ?it/s]

Collect type checker feedback:   0%|          | 0/5 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/249 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/249 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1318 [00:00<?, ?it/s]

predict:   0%|          | 0/800 [00:00<?, ?it/s]

[Epoch 0] R0 stats:

R0_partial_acc: 7738.02%%

R0_partial_acc_wo_any: 7851.21%%

R0_partial_accs:
   
FuncArg: 7429.56%%
   
FuncReturn: 8488.98%%
   
ClassAtribute: 7668.85%%
   
GlobalVar: 10000.00%%
   
LocalVar: 4821.43%%

R0_full_acc: 6947.73%%

R0_full_accs:
   
FuncArg: 6457.73%%
   
FuncReturn: 7995.80%%
   
ClassAtribute: 7015.25%%
   
GlobalVar: 5714.29%%
   
LocalVar: 3928.57%%

R0_n_labels: 3214
[Epoch 0] R1 stats:

R1_partial_acc: 463.74%%

R1_partial_acc_wo_any: 281.06%%

R1_partial_accs:
   
FuncArg: 632.91%%
   
FuncReturn: 398.74%%
   
LocalVar: 178.57%%

R1_full_acc: 426.39%%

R1_full_accs:
   
FuncArg: 621.40%%
   
FuncReturn: 304.30%%

R1_n_labels: 3213


parsing and masking sources:   0%|          | 0/351 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/351 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1353 [00:00<?, ?it/s]

predict:   0%|          | 0/730 [00:00<?, ?it/s]

Collect type checker feedback:   0%|          | 0/5 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/249 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/249 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1322 [00:00<?, ?it/s]

predict:   0%|          | 0/794 [00:00<?, ?it/s]

[Epoch 1] R0 stats:

R0_partial_acc: 7825.14%%

R0_partial_acc_wo_any: 7953.38%%

R0_partial_accs:
   
FuncArg: 7573.32%%
   
FuncReturn: 8562.43%%
   
ClassAtribute: 7668.85%%
   
GlobalVar: 7142.86%%
   
LocalVar: 4464.29%%

R0_full_acc: 7022.40%%

R0_full_accs:
   
FuncArg: 6555.49%%
   
FuncReturn: 8153.20%%
   
ClassAtribute: 6906.32%%
   
GlobalVar: 2857.14%%
   
LocalVar: 3750.00%%

R0_n_labels: 3214
[Epoch 1] R1 stats:

R1_partial_acc: 8814.19%%

R1_partial_acc_wo_any: 8834.24%%

R1_partial_accs:
   
FuncArg: 9108.17%%
   
FuncReturn: 8698.85%%
   
ClassAtribute: 8474.95%%
   
GlobalVar: 7142.86%%
   
LocalVar: 4642.86%%

R1_full_acc: 8555.87%%

R1_full_accs:
   
FuncArg: 8768.70%%
   
FuncReturn: 8614.90%%
   
ClassAtribute: 8126.36%%
   
GlobalVar: 7142.86%%
   
LocalVar: 4642.86%%

R1_n_labels: 3213




parsing and masking sources:   0%|          | 0/351 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/351 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1353 [00:00<?, ?it/s]

predict:   0%|          | 0/730 [00:00<?, ?it/s]

Collect type checker feedback:   0%|          | 0/5 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/249 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/249 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1318 [00:00<?, ?it/s]

predict:   0%|          | 0/805 [00:00<?, ?it/s]

[Epoch 2] R0 stats:

R0_partial_acc: 7703.80%%

R0_partial_acc_wo_any: 7825.67%%

R0_partial_accs:
   
FuncArg: 7389.30%%
   
FuncReturn: 8583.42%%
   
ClassAtribute: 7342.05%%
   
GlobalVar: 10000.00%%
   
LocalVar: 5178.57%%

R0_full_acc: 6981.95%%

R0_full_accs:
   
FuncArg: 6486.49%%
   
FuncReturn: 8153.20%%
   
ClassAtribute: 6753.81%%
   
GlobalVar: 7142.86%%
   
LocalVar: 4285.71%%

R0_n_labels: 3214
[Epoch 2] R1 stats:

R1_partial_acc: 9137.88%%

R1_partial_acc_wo_any: 9134.46%%

R1_partial_accs:
   
FuncArg: 9418.87%%
   
FuncReturn: 8992.65%%
   
ClassAtribute: 9019.61%%
   
GlobalVar: 4285.71%%
   
LocalVar: 4464.29%%

R1_full_acc: 8969.81%%

R1_full_accs:
   
FuncArg: 9229.00%%
   
FuncReturn: 8908.71%%
   
ClassAtribute: 8779.96%%
   
GlobalVar: 2857.14%%
   
LocalVar: 4285.71%%

R1_n_labels: 3213




parsing and masking sources:   0%|          | 0/351 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/351 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1353 [00:00<?, ?it/s]

predict:   0%|          | 0/730 [00:00<?, ?it/s]

Collect type checker feedback:   0%|          | 0/5 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/249 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/249 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/1321 [00:00<?, ?it/s]

predict:   0%|          | 0/807 [00:00<?, ?it/s]

[Epoch 3] R0 stats:

R0_partial_acc: 7812.69%%

R0_partial_acc_wo_any: 7921.46%%

R0_partial_accs:
   
FuncArg: 7561.82%%
   
FuncReturn: 8551.94%%
   
ClassAtribute: 7625.27%%
   
GlobalVar: 10000.00%%
   
LocalVar: 4285.71%%

R0_full_acc: 7044.18%%

R0_full_accs:
   
FuncArg: 6572.74%%
   
FuncReturn: 8142.71%%
   
ClassAtribute: 7015.25%%
   
GlobalVar: 2857.14%%
   
LocalVar: 3750.00%%

R0_n_labels: 3214
[Epoch 3] R1 stats:

R1_partial_acc: 9389.98%%

R1_partial_acc_wo_any: 9383.58%%

R1_partial_accs:
   
FuncArg: 9660.53%%
   
FuncReturn: 9150.05%%
   
ClassAtribute: 9411.76%%
   
GlobalVar: 8571.43%%
   
LocalVar: 5000.00%%

R1_full_acc: 9265.48%%

R1_full_accs:
   
FuncArg: 9533.95%%
   
FuncReturn: 9097.59%%
   
ClassAtribute: 9128.54%%
   
GlobalVar: 8571.43%%
   
LocalVar: 5000.00%%

R1_n_labels: 3213




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
R0_full_acc,▁▆▃█
R0_n_labels,▁▁▁▁
R0_partial_acc,▃█▁▇
R0_partial_acc_wo_any,▂█▁▆
R1_full_acc,▁▇██
R1_n_labels,▁▁▁▁
R1_partial_acc,▁███
R1_partial_acc_wo_any,▁███
epoch,▁▁▃▃▃▆▆▆███
loss,█▂▁

0,1
R0_full_acc,0.70442
R0_n_labels,3214.0
R0_partial_acc,0.78127
R0_partial_acc_wo_any,0.79215
R1_full_acc,0.92655
R1_n_labels,3213.0
R1_partial_acc,0.939
R1_partial_acc_wo_any,0.93836
epoch,3.0
loss,0.06567


In [48]:
from spot.data import pretty_print_accuracies

test_repos = [repo.repo_dir(repos_dir) for repo in repos_split["test"]]

# Compare R0 and R1 on a small slice (first three repos) of the test split.
eval_results = trainer.eval_on_repos(test_repos[:3])
(r0_stats, r0_data, r0_preds), (r1_stats, r1_data, r1_preds) = eval_results
for stats in (r0_stats, r1_stats):
    pretty_print_accuracies(stats, max_show_level=0)

parsing and masking sources:   0%|          | 0/74 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/74 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/503 [00:00<?, ?it/s]

predict:   0%|          | 0/234 [00:00<?, ?it/s]

Collect type checker feedback:   0%|          | 0/3 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/44 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/44 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/485 [00:00<?, ?it/s]

predict:   0%|          | 0/260 [00:00<?, ?it/s]

R0_partial_acc: 66.36%
R0_partial_acc_wo_any: 67.75%
R0_partial_accs:
   ...
R0_full_acc: 58.21%
R0_full_accs:
   ...
R0_n_labels: 871
R1_partial_acc: 63.26%
R1_partial_acc_wo_any: 64.38%
R1_partial_accs:
   ...
R1_full_acc: 55.91%
R1_full_accs:
   ...
R1_n_labels: 871


In [49]:
from spot import PythonType
from spot.data import TypeInfDataset, inline_predictions
from spot.visualization import display_code_sequence, code_inline_extra_ids

ctx_margin = args.ctx_args.ctx_margin


def visualize_batch(dataset: TypeInfDataset, preds: list[list[PythonType]], i: int):
    """Render chunk ``i`` of ``dataset`` with the predicted types inlined.

    The decoded code is split by two separator lines into three regions: the
    first ``ctx_margin`` tokens, the middle, and the last ``ctx_margin`` tokens.
    The separators label the two margins as context.

    Relies on the notebook globals ``tokenizer``, ``ctx_margin`` and ``datadir``.

    Args:
        dataset: a dataset such as those returned by ``trainer.eval_on_repos``.
        preds: per-chunk predicted types; ``preds[i]`` belongs to chunk ``i``.
        i: index of the chunk to render.

    Returns:
        A string listing the label types, the predicted types, the source files
        (relative to ``datadir``) the chunk came from, and the decoded code.
    """
    pred_types = preds[i]
    label_types = dataset.chunks_info[i].types
    preds_enc = [tokenizer.encode(str(t), add_special_tokens=False) for t in pred_types]
    code_tks = inline_predictions(dataset.data["input_ids"][i], preds_enc, tokenizer)

    # Compute the margin boundaries explicitly instead of slicing with
    # `-ctx_margin`. With ctx_margin == 0, `x[-0:]` is the whole list, which
    # would put all the code after the second separator. With chunks shorter
    # than 2 * ctx_margin, the slices would overlap and duplicate tokens.
    n_tks = len(code_tks)
    left = min(ctx_margin, n_tks)
    right = max(left, n_tks - ctx_margin)
    sep_1 = tokenizer.encode("\n---------⬆context⬆---------\n", add_special_tokens=False)
    sep_2 = tokenizer.encode("\n---------⬇context⬇---------\n", add_special_tokens=False)
    code_tks = code_tks[:left] + sep_1 + code_tks[left:right] + sep_2 + code_tks[right:]

    code_dec = tokenizer.decode(code_tks, skip_special_tokens=False)
    code_dec = code_inline_extra_ids(code_dec, label_types)
    src_ids = sorted(set(dataset.chunks_info[i].src_ids))
    files = [dataset.files[src_id].relative_to(datadir) for src_id in src_ids]
    return "".join([
        "labels: ", str(label_types), "\n",
        "preds: ", str(pred_types), "\n",
        "files: ", str(files), "\n",
        "========================== Code =======================\n", code_dec, "\n",
    ])


# Show up to the first 20 chunks; clamp so small evaluation sets don't raise IndexError.
display_code_sequence([visualize_batch(r1_data, r1_preds, i) for i in range(min(20, len(r1_preds)))])

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…

In [50]:
# Show up to the first 20 chunks; clamp so small evaluation sets don't raise IndexError.
display_code_sequence([visualize_batch(r0_data, r0_preds, i) for i in range(min(20, len(r0_preds)))])

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…

In [63]:
from spot.data import mask_type_annots
from spot.utils import join_str
# Sanity-check type-annotation masking on a single file from the downloaded repos.
# The path is built from the configurable `repos_dir` instead of a hardcoded,
# user-specific absolute path.
collect_path = repos_dir / "downloaded/typeddjango__pytest-mypy-plugins/pytest_mypy_plugins/collect.py"
src_collect = read_file(collect_path)
mask_r = mask_type_annots(src_collect)
# Re-join the masked code segments with the extracted types (stringified).
print(join_str(mask_r["code_segs"], [str(t) for t in mask_r["types"]]))

import os
import pathlib
import platform
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional, Set

import pkg_resources
import pytest
import yaml
from _pytest.config.argparsing import Parser
from _pytest.nodes import Node
from py._path.local import LocalPath

from pytest_mypy_plugins import utils

if TYPE_CHECKING:
    from pytest_mypy_plugins.item import YamlTestItem


class File:
    def __init__(self, path: str, content: str) -> None:
        self.path = path
        self.content = content


def parse_test_files(test_files: List[Dict[str, Any]]) -> List[File]:
    files: List[File] = []
    for test_file in test_files:
        path = test_file.get("path", "main.py")
        file = File(path=path, content=test_file.get("content", ""))
        files.append(file)
    return files


def parse_environment_variables(env_vars: List[str]) -> Dict[str, str]:
    parsed_vars: Dict[str, str] = {}
    for env_var in env_vars:
        name, _

In [62]:
import libcst as cst
# List the annotation paths that collect_annotations finds in the parsed source.
collect_module = cst.parse_module(src_collect)
annot_paths = [str(annot.path) for annot in collect_annotations(collect_module)]
print(annot_paths)

["'File.__init__.path'", "'File.__init__.content'", "'File.__init__.<return>'", "'parse_test_files.test_files'", "'parse_test_files.files'", "'parse_test_files.<return>'", "'parse_environment_variables.env_vars'", "'parse_environment_variables.parsed_vars'", "'parse_environment_variables.<return>'", "'parse_parametrized.params'", "'parse_parametrized.parsed_params'", "'parse_parametrized.known_params'", "'parse_parametrized.<return>'", "'SafeLineLoader.construct_mapping.node'", "'SafeLineLoader.construct_mapping.deep'", "'SafeLineLoader.construct_mapping.<return>'", "'YamlTestFile.collect.<return>'", "'YamlTestFile._eval_skip.skip_if'", "'YamlTestFile._eval_skip.<return>'", "'pytest_collect_file.file_path'", "'pytest_collect_file.parent'", "'pytest_collect_file.<return>'", "'pytest_collect_file[1].path'", "'pytest_collect_file[1].parent'", "'pytest_collect_file[1].<return>'", "'pytest_addoption.parser'", "'pytest_addoption.<return>'"]


In [None]:
# Per-stage timing statistics recorded by the trainer during the run.
timer_df = trainer.timer.as_dataframe()
display(timer_df)

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235
