In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Work from the project root so that relative paths used later in the notebook
# (e.g. "data/code/...") resolve consistently.
os.chdir(proj_root())

# The data location comes from the `datadir` environment variable. Fail early
# with a clear message instead of the opaque TypeError that Path(None) raises.
if "datadir" not in os.environ:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "it must point to the directory containing SPOT-data."
    )
datadir = Path(os.environ["datadir"])
repos_dir = datadir / "SPOT-data/repos"

# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only load pickles that were produced locally by this project.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Repository split; later cells index it with the "train" and "valid" keys.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [2]:
import torch

from spot.model import ModelSPOT, TokenizerSPOT

# Saved checkpoint directory (under `datadir`) used for both tokenizer and model.
model_path = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)

# Load the model from the checkpoint, then move it onto the selected device.
model: ModelSPOT = ModelSPOT.from_pretrained(model_path)
model = model.to(device)



In [3]:
from IPython.display import display

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs

# Set to True for a quick run on a small subset of repos; this also prefixes
# the model name with "test-".
test_run = False
if test_run:
    test_tag = "test-"
else:
    test_tag = ""

# Used both as the checkpoint sub-directory and as the wandb project name.
model_name = test_tag + "SPOT-DAgger"

# Hyper-parameters handed to the DAgger trainer.
args = DAggerTrainerArgs(
    output_dir=proj_root() / "checkpoints" / model_name,
    max_epochs=2,
    repos_group_size=16,
    ctx_size=512,
    ctx_margin=128,
    sampling_batch_size=300,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)

trainer = DAggerTrainer(model, tokenizer, args)

# Resolve each repo in the split to its directory under `repos_dir`.
train_repos = [repo.repo_dir(repos_dir) for repo in repos_split["train"]]
valid_repos = [repo.repo_dir(repos_dir) for repo in repos_split["valid"]]

# In test mode keep only the first 10 repos of each split.
if test_run:
    train_repos, valid_repos = train_repos[:10], valid_repos[:10]

In [4]:
# Start a wandb run; the project is named after the model and the trainer
# arguments are recorded as the run config.
wandb.init(project=model_name, config=args)

try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    # Send an alert about the failure, then propagate the original exception.
    wandb.alert(title="Training stopped due to exception", text=f"In {model_name}, exception: {e}")
    # Bare `raise` re-raises the active exception without adding an extra
    # traceback entry for this line (unlike `raise e`).
    raise
# Only reached when training completed without raising.
wandb.alert(title="Training finished", text=f"{model_name} has finished.")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


DAgger Training:   0%|          | 0/1146 [00:00<?, ?it/s]

[Epoch 0] R0 stats: {'R0_accuracy_partial': {'total': 0.7511115406920549, 'FuncArg': 0.6980251346499102, 'ClassAtribute': 0.7765064836003052, 'FuncReturn': 0.8344476281439032, 'LocalVar': 0.7176470588235294, 'GlobalVar': 0.8840579710144928}, 'R0_accuracy_full': {'total': 0.6830659191958245, 'FuncArg': 0.6274685816876122, 'ClassAtribute': 0.7063310450038138, 'FuncReturn': 0.7777777777777778, 'LocalVar': 0.611764705882353, 'GlobalVar': 0.6811594202898551}, 'R0_n_labels': 10346}
[Epoch 0] R1 stats: {'R1_accuracy_partial': {'total': 0.7447834645669291, 'FuncArg': 0.6966824644549763, 'ClassAtribute': 0.75592960979342, 'FuncReturn': 0.8295566502463054, 'LocalVar': 0.6782945736434108, 'GlobalVar': 0.875}, 'R1_accuracy_full': {'total': 0.7049212598425196, 'FuncArg': 0.6562158220925993, 'ClassAtribute': 0.7115531752104055, 'FuncReturn': 0.7960591133004926, 'LocalVar': 0.6085271317829457, 'GlobalVar': 0.796875}, 'R1_n_labels': 10160}


In [None]:
# Evaluate the trainer on the validation repos. The returned value is shown as
# the cell output (below: a pair of stats dicts keyed with R0_* and R1_*).
# NOTE(review): silent=False presumably keeps the progress bars visible — confirm.
trainer.eval_on_repos(valid_repos, silent=False)

parsing and masking sources:   0%|          | 0/1953 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/1953 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/5422 [00:00<?, ?it/s]

predict:   0%|          | 0/5 [00:00<?, ?it/s]

reading orginal srcs:   0%|          | 0/938 [00:00<?, ?it/s]

apply file changes:   0%|          | 0/938 [00:00<?, ?it/s]

calling mypy:   0%|          | 0/40 [00:00<?, ?it/s]

generating augmented inputs:   0%|          | 0/938 [00:00<?, ?it/s]

tokenizing sources:   0%|          | 0/938 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/3982 [00:00<?, ?it/s]

predict:   0%|          | 0/5 [00:00<?, ?it/s]

({'R0_accuracy_partial': {'total': 0.7467620336361879,
   'FuncArg': 0.6946140035906643,
   'ClassAtribute': 0.7574370709382151,
   'FuncReturn': 0.8341292581980261,
   'LocalVar': 0.7058823529411765,
   'GlobalVar': 0.927536231884058},
  'R0_accuracy_full': {'total': 0.6792963464140731,
   'FuncArg': 0.6245960502692999,
   'ClassAtribute': 0.6864988558352403,
   'FuncReturn': 0.7780961477236549,
   'LocalVar': 0.596078431372549,
   'GlobalVar': 0.7681159420289855},
  'R0_n_labels': 10346},
 {'R1_accuracy_partial': {'total': 0.7504204174497973,
   'FuncArg': 0.6992064956634065,
   'ClassAtribute': 0.7669753086419753,
   'FuncReturn': 0.8387516254876463,
   'GlobalVar': 0.8823529411764706,
   'LocalVar': 0.6666666666666666},
  'R1_accuracy_full': {'total': 0.707488376694035,
   'FuncArg': 0.6552869533124193,
   'ClassAtribute': 0.7222222222222222,
   'FuncReturn': 0.8033159947984395,
   'GlobalVar': 0.7843137254901961,
   'LocalVar': 0.5767790262172284},
  'R1_n_labels': 10109})

In [None]:
# Render the trainer's timer records as a DataFrame
# (columns in the output: name, count, avg_time, total_time).
display(trainer.timer.as_dataframe())

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235


In [None]:
# NOTE(review): `cst` is already imported from spot.utils in the first cell;
# this import rebinds the name — confirm both refer to libcst.
import libcst as cst

# Parse a sample source file (path is relative to the project root set via os.chdir).
mod = cst.parse_module(Path("data/code/code_with_slash.py").read_text())

In [None]:
# Print the source text held by the parsed module to inspect it.
print(mod.code)

def __init__(
    self, check_interval: int, folder: Path, /) -> None:
    super().__init__(check_interval, "AutoLocker")

    self._autolocked: Dict[Path, int] = {}
    self._lockers: Dict[Path, "DirectEdit"] = {}
    self._to_lock: Items = []

