In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Make the project root the working directory for the rest of the notebook.
os.chdir(proj_root())

# Root of the SPOT data (expected to contain a `SPOT-data/` subdirectory), taken
# from the `datadir` environment variable. Fail loudly with a clear message:
# `Path(None)` would raise an opaque TypeError when the variable is unset, and an
# empty value would silently resolve to the current working directory.
datadir_env = os.getenv("datadir")
if not datadir_env:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; it must point to the "
        "directory that contains SPOT-data/."
    )
datadir = Path(datadir_env)
repos_dir = datadir / "SPOT-data/repos"

# NOTE: these pickles are local project artifacts. pickle.load can execute
# arbitrary code, so only load files from a trusted source.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Repository split keyed by split name (later cells read "train" and "valid").
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [2]:
import torch

from spot.model import ModelSPOT, TokenizerSPOT

# Choose the starting weights: the public `Salesforce/codet5-base` checkpoint
# when training from scratch, otherwise a previously saved SPOT checkpoint.
# (Variable name keeps its original spelling; later cells read it.)
train_from_scrach = True

if train_from_scrach:
    model_path = "Salesforce/codet5-base"
else:
    model_path = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"

# Prefer the first GPU when one is available.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")

tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)
model: ModelSPOT = ModelSPOT.from_pretrained(model_path).to(device)



In [9]:
from IPython.display import display, display_pretty

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs

# Smoke-test switch: prefixes the run name with "test-" and, at the end of this
# cell, trims the train/valid repo lists to a handful of repos.
test_run = True
if test_run:
    test_tag = "test-"
else:
    test_tag = ""

if train_from_scrach:
    scratch_tag = "-scratch"
else:
    scratch_tag = ""
model_name = test_tag + "SPOT-DAgger" + scratch_tag

checkpoint_dir = proj_root() / "checkpoints" / model_name
args = DAggerTrainerArgs(
    output_dir=checkpoint_dir,
    max_epochs=2,
    skip_first_eval=False,
    repos_group_size=16,
    ctx_size=512,
    ctx_margin=128,
    types_in_ctx=False,
    sampling_batch_size=300,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)

trainer = DAggerTrainer(model, tokenizer, args)

# Resolve each split's repos to on-disk directories under `repos_dir`.
train_repos = [repo.repo_dir(repos_dir) for repo in repos_split["train"]]
valid_repos = [repo.repo_dir(repos_dir) for repo in repos_split["valid"]]
if test_run:
    # Keep only the first 10 train / 5 valid repos in smoke-test mode.
    train_repos, valid_repos = train_repos[:10], valid_repos[:5]

In [11]:
# Start a W&B run for this experiment, then train and send an alert on success or failure.
wandb.init(project=model_name, config=args, dir=str(datadir))

try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    # Include the exception type: str(e) alone can be empty or ambiguous
    # (e.g. str(KeyError("x")) == "'x'"). KeyboardInterrupt is not an
    # Exception subclass, so a manual stop sends no alert.
    wandb.alert(
        title="Training stopped due to exception",
        text=f"In {model_name}, exception: {type(e).__name__}: {e}",
    )
    raise  # bare raise re-raises the active exception unchanged
else:
    wandb.alert(title="Training finished", text=f"{model_name} has finished.")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

DAgger Training:   0%|          | 0/20 [00:00<?, ?it/s]

[Epoch 0] R0 stats:


{'R0_accuracy_partial': {'total': 0.5273802115743622,
  'FuncArg': 0.5296147211040828,
  'ClassAtribute': 0.2440087145969499,
  'FuncReturn': 0.6820566631689402,
  'LocalVar': 0.16071428571428573,
  'GlobalVar': 0.42857142857142855},
 'R0_accuracy_full': {'total': 0.41288114499066586,
  'FuncArg': 0.3887291546866015,
  'FuncReturn': 0.5991605456453305,
  'ClassAtribute': 0.16122004357298475,
  'LocalVar': 0.08928571428571429,
  'GlobalVar': 0.14285714285714285},
 'R0_n_labels': 3214}

[Epoch 0] R1 stats:


{'R1_accuracy_partial': {'total': 0.47121070650482416,
  'FuncArg': 0.4723820483314154,
  'ClassAtribute': 0.20697167755991286,
  'FuncReturn': 0.6107030430220357,
  'LocalVar': 0.23214285714285715,
  'GlobalVar': 0.42857142857142855},
 'R1_accuracy_full': {'total': 0.3657018362900716,
  'FuncArg': 0.3411967779056387,
  'FuncReturn': 0.5456453305351522,
  'ClassAtribute': 0.12200435729847495,
  'LocalVar': 0.08928571428571429,
  'GlobalVar': 0.14285714285714285},
 'R1_n_labels': 3213}

[Epoch 1] R0 stats:


{'R0_accuracy_partial': {'total': 0.5721841941505912,
  'FuncArg': 0.562392179413456,
  'ClassAtribute': 0.3899782135076253,
  'FuncReturn': 0.6915005246589717,
  'LocalVar': 0.30357142857142855,
  'GlobalVar': 0.8571428571428571},
 'R0_accuracy_full': {'total': 0.47510889856876165,
  'FuncArg': 0.4485336400230017,
  'ClassAtribute': 0.3289760348583878,
  'FuncReturn': 0.6169989506820567,
  'GlobalVar': 0.7142857142857143,
  'LocalVar': 0.05357142857142857},
 'R0_n_labels': 3214}

[Epoch 1] R1 stats:


{'R1_accuracy_partial': {'total': 0.5798319327731093,
  'FuncArg': 0.5650172612197929,
  'FuncReturn': 0.6988457502623295,
  'ClassAtribute': 0.40522875816993464,
  'LocalVar': 0.42857142857142855,
  'GlobalVar': 0.7142857142857143},
 'R1_accuracy_full': {'total': 0.4774354186118892,
  'FuncArg': 0.4430379746835443,
  'FuncReturn': 0.6253934942287513,
  'ClassAtribute': 0.3464052287581699,
  'GlobalVar': 0.5714285714285714,
  'LocalVar': 0.08928571428571429},
 'R1_n_labels': 3213}



[Epoch 2] R0 stats:


{'R0_accuracy_partial': {'total': 0.5883634100808961,
  'FuncArg': 0.5802185163887291,
  'ClassAtribute': 0.4684095860566449,
  'FuncReturn': 0.6715634837355718,
  'LocalVar': 0.375,
  'GlobalVar': 0.8571428571428571},
 'R0_accuracy_full': {'total': 0.48879900435594276,
  'FuncArg': 0.45255894192064405,
  'FuncReturn': 0.6065057712486883,
  'ClassAtribute': 0.4139433551198257,
  'GlobalVar': 0.2857142857142857,
  'LocalVar': 0.25},
 'R0_n_labels': 3214}

[Epoch 2] R1 stats:


{'R1_accuracy_partial': {'total': 0.6050420168067226,
  'FuncArg': 0.593210586881473,
  'ClassAtribute': 0.5185185185185185,
  'FuncReturn': 0.6810073452256034,
  'LocalVar': 0.375,
  'GlobalVar': 0.7142857142857143},
 'R1_accuracy_full': {'total': 0.4964207905384376,
  'FuncArg': 0.452819332566168,
  'ClassAtribute': 0.4596949891067538,
  'FuncReturn': 0.6065057712486883,
  'GlobalVar': 0.42857142857142855,
  'LocalVar': 0.2857142857142857},
 'R1_n_labels': 3213}



In [30]:
# Evaluate on a single validation repo (index 1) so its predictions can be
# inspected below; only the last two of the four returned values are kept.
inspect_repos = valid_repos[1:2]
_, _, ds, preds = trainer.eval_on_repos(inspect_repos, silent=True)

In [32]:
from spot import PythonType
from spot.data import TypeInfDataset, inline_predictions


def visualize_batch(dataset: TypeInfDataset, preds: list[list[PythonType]], i: int):
    """Render chunk `i` of `dataset` as text, with the predicted types spliced in.

    Each predicted type in `preds[i]` is stringified and encoded with the
    notebook-global `tokenizer`. The encoded types are passed to
    `inline_predictions` together with the chunk's input ids, and the result is
    decoded (special tokens kept).

    Returns:
        A string with the ground-truth labels, the predictions, and the decoded
        code, each section on its own line(s).
    """
    predicted = preds[i]
    encoded_types = [
        tokenizer.encode(str(typ), add_special_tokens=False) for typ in predicted
    ]

    inlined_ids = inline_predictions(dataset.data["input_ids"][i], encoded_types, tokenizer)
    code_text = tokenizer.decode(inlined_ids, skip_special_tokens=False)
    labels = dataset.chunks_info[i].types

    header = f"labels: {labels!s}\npreds: {predicted!s}\n"
    return (
        header
        + "========================== Code =======================\n"
        + code_text
        + "\n"
    )

from spot.visualization import display_code_sequence

# Render the visualizations for chunks 6 through 19 (inclusive).
chunk_views = [visualize_batch(ds, preds, idx) for idx in range(6, 20)]
display_code_sequence(chunk_views)

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…

In [None]:
# Scratch snippet: a method signature copied out of its class. The original cell
# was a SyntaxError (missing ':' after the return annotation), which broke
# Restart & Run All. Annotations are quoted because Attr/Schema are not defined
# anywhere in this notebook; unquoted they would raise NameError at def time.
def add_name(
        self, name_attr: "Attr", space_attr: "Attr", new_schema: "Schema"
    ) -> "Attr":
    return 1

In [None]:
# Per-stage timing breakdown recorded by the trainer's timer.
timing_table = trainer.timer.as_dataframe()
display(timing_table)

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235
