In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

os.chdir(proj_root())

datadir = Path(os.getenv("datadir"))
repos_dir = datadir / "SPOT-data/repos"

useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [2]:
import torch

from spot.model import ModelSPOT, TokenizerSPOT

train_from_scrach = True

model_path = "Salesforce/codet5-base" if train_from_scrach else datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)
model: ModelSPOT = ModelSPOT.from_pretrained(model_path).to(device)



In [3]:
from IPython.display import display, display_pretty

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs

test_run = False
test_tag = 'test-' if test_run else ''

scratch_tag = '-scratch' if train_from_scrach else ''
model_name = f"{test_tag}SPOT-DAgger{scratch_tag}"

args = DAggerTrainerArgs(
    output_dir=proj_root() / "checkpoints" / model_name,
    max_epochs=2,
    skip_first_eval=False,
    repos_group_size=16,
    ctx_size=512,
    ctx_margin=128,
    types_in_ctx=False,
    sampling_batch_size=300,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)


trainer = DAggerTrainer(model, tokenizer, args)
train_repos = [r.repo_dir(repos_dir) for r in repos_split["train"]]
valid_repos = [r.repo_dir(repos_dir) for r in repos_split["valid"]]
if test_run:
    train_repos = train_repos[:10]
    valid_repos = valid_repos[:10]

TypeError: DAggerTrainerArgs.__init__() got an unexpected keyword argument 'skip_first_eval'

In [None]:
wandb.init(project=model_name, config=args)

try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    wandb.alert(title="Training stopped due to exception", text=f"In {model_name}, exception: {e}")
    raise e
wandb.alert(title="Training finished", text=f"{model_name} has finished.")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
loss,▁
step,▁

0,1
epoch,1.0
loss,0.67867
step,16.0


DAgger Training:   0%|          | 0/1146 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from random import randint
rd = {randint(0, 100): randint(1000,2000) for _ in range(10)}
rd

{4: 1875,
 55: 1596,
 74: 1983,
 35: 1762,
 91: 1767,
 79: 1337,
 31: 1082,
 0: 1023,
 12: 1630}

In [None]:
list(rd.keys())

[4, 55, 74, 35, 91, 79, 31, 0, 12]

In [None]:
_, _, ds, preds = trainer.eval_on_repos(valid_repos[1:4], silent=False)

parsing and masking sources:   0%|          | 0/86 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/86 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/646 [00:00<?, ?it/s]

predict:   0%|          | 0/346 [00:00<?, ?it/s]

apply file changes:   0%|          | 0/53 [00:00<?, ?it/s]

calling mypy:   0%|          | 0/3 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/53 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/53 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/654 [00:00<?, ?it/s]

predict:   0%|          | 0/357 [00:00<?, ?it/s]

In [None]:
from spot.data import TypeInfDataset, inline_predictions
from spot import PythonType

def visualize_batch(dataset: TypeInfDataset, preds: list[list[PythonType]], i: int):
    types = preds[i]
    typpes_enc = [tokenizer.encode(str(t), add_special_tokens=False) for t in types]

    code_tks = inline_predictions(dataset.data["input_ids"][i], typpes_enc, tokenizer)
    code_dec = tokenizer.decode(code_tks, skip_special_tokens=False)
    label_dec = dataset.chunks_info[i].types
    return "".join([
        "labels: ", str(label_dec), "\n",
        "preds: ", str(types), "\n",
        "========================== Code =======================\n", code_dec, "\n",
    ])

from spot.visualization import display_code_sequence


display_code_sequence([visualize_batch(ds, preds, i) for i in range(6, 20)])

Tab(children=(HTML(value="<pre style='line-height:1.2;'>labels: [List[Schema]]\npreds: [List[Union[Schema, Md5…

In [None]:
display(trainer.timer.as_dataframe())

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235


In [None]:
code="""from typing import Any # SPOT
"Check health of a baseplate service on localhost."
import argparse
import socket
import sys
import typing
import urllib.parse

import requests

from baseplate.lib._requests import add_unix_socket_support
from baseplate.lib.config import Endpoint
from baseplate.lib.config import EndpointConfiguration
from baseplate.lib.config import InternetAddress
from baseplate.lib.thrift_pool import ThriftConnectionPool
from baseplate.thrift import BaseplateServiceV2
from baseplate.thrift.ttypes import IsHealthyProbe
from baseplate.thrift.ttypes import IsHealthyRequest


TIMEOUT = 30  # seconds


def check_thrift_service(endpoint: Endpoint, probe: str) -> None:
    pool = ThriftConnectionPool(endpoint, size=1, timeout=TIMEOUT)
    with pool.connection() as protocol:
        client = BaseplateServiceV2.Client(protocol)
        assert client.is_healthy(
            request=IsHealthyRequest(probe=probe),
        ), f"service indicated unhealthiness in probe {probe}"


def check_http_service(endpoint: EndpointConfiguration, probe: str) -> InternetAddress:
    if endpoint.family == socket.AF_INET:
        address: None = typing.cast(InternetAddress, endpoint.address)
        url = f"http://{address.host}:{address.port}/health?type={probe}"
    elif endpoint.family == socket.AF_UNIX:
        quoted_path = urllib.parse.quote(typing.cast(str, endpoint.address), safe="")
        url = f"http+unix://{quoted_path}/health?type={probe}"
    else:
        raise ValueError(f"unrecognized socket family {endpoint.family!r}")

    session = requests.Session()
    add_unix_socket_support(session)
    response = session.get(url, timeout=TIMEOUT)
    response.raise_for_status()
    response.json()


CHECKERS = {"thrift": check_thrift_service, "wsgi": check_http_service}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=sys.modules[__name__].__doc__)

    parser.add_argument(
        "type",
        choices=CHECKERS.keys(),
        default="thrift",
        help="The protocol of the service to check.",
    )
    parser.add_argument(
        "endpoint",
        type=Endpoint,
        default=Endpoint("localhost:9090"),
        help="The endpoint to find the service on.",
    )
    parser.add_argument(
        "--probe",
        choices=[probe.lower() for probe in IsHealthyProbe._NAMES_TO_VALUES],
        default="readiness",
        help="The probe to check.",
    )

    return parser.parse_args()


def run_healthchecks() -> None:
    args = parse_args()

    checker = CHECKERS[args.type]
    checker(args.endpoint, IsHealthyProbe._NAMES_TO_VALUES[args.probe.upper()])
    print("OK!")


if __name__ == "__main__":
    run_healthchecks()"""

lines = code.splitlines()

In [None]:
len(lines)

87