In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Make the project root the working directory for the rest of the notebook.
os.chdir(proj_root())

# The data root is taken from the `datadir` environment variable. Fail early
# with an actionable message: os.getenv would return None when the variable is
# missing, and Path(None) only raises an opaque TypeError.
if "datadir" not in os.environ:
    raise RuntimeError(
        "The 'datadir' environment variable is not set; "
        "point it at the directory that contains 'SPOT-data'."
    )
datadir = Path(os.environ["datadir"])
repos_dir = datadir / "SPOT-data/repos"

# NOTE(security): pickle.load can execute arbitrary code while deserializing.
# Only load these files if they were produced by this project's own scripts.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Pickled repo split; later cells index it with the "train" and "valid" keys.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [2]:
import torch
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
from transformers.models.t5 import T5ForConditionalGeneration

from spot.model import ModelSPOT, TokenizerSPOT

# Location of the saved checkpoint under the data root.
model_path = datadir.joinpath("checkpoints/saved/SPOT-CodeT5-with_margin")

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Tokenizer and model are both restored from the same checkpoint directory;
# the model is moved onto the selected device.
tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)
model: ModelSPOT = ModelSPOT.from_pretrained(model_path).to(device)

# Sequence-length limits. max_target_length is passed on as the trainer's
# generation_max_length in a later cell.
# NOTE(review): 1028 is unusual -- possibly meant to be 1024; confirm.
max_target_length = 128
max_input_length = 1028



In [3]:
import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs

# Settings for the DAgger training run; the same object is also handed to
# wandb as the run config.
args = DAggerTrainerArgs(
    # Where the trainer writes its output.
    output_dir=proj_root().joinpath("checkpoints", "DAgger"),
    # Training schedule and repo grouping.
    max_epochs=2,
    repos_group_size=10,
    # Context window settings.
    ctx_size=512,
    ctx_margin=128,
    # Batch sizes for sampling and for training.
    sampling_batch_size=64,
    train_batch_size=32,
    # Generation settings; the length limit reuses max_target_length.
    generation_max_length=max_target_length,
    generation_num_beams=1,
)

# Start a Weights & Biases run that records the arguments above.
wandb.init(
    project="test-DAgger-SPOT-CodeT5",
    config=args,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: mrvplusone. Use `wandb login --relogin` to force relogin


In [6]:
from IPython.display import display

# Build the trainer around the loaded model and tokenizer, using one worker.
trainer = DAggerTrainer(model, tokenizer, args, max_workers=1)

# Resolve each selected repo to its directory under repos_dir. Only a small
# slice of each split is used: train[0:3] and valid[2:4].
train_repos = [
    repo.repo_dir(repos_dir) for repo in repos_split["train"][0:3]
]
valid_repos = [
    repo.repo_dir(repos_dir) for repo in repos_split["valid"][2:4]
]

# Run training on the train slice with the valid slice for evaluation.
trainer.train(train_repos, valid_repos)

DAgger Training:   0%|          | 0/6 [00:00<?, ?it/s]

[Epoch 0] R0 stats: {'R0_accuracy_partial': {'total': 0.7835497835497836, 'FuncArg': 0.7642276422764228, 'FuncReturn': 0.9076923076923077, 'ClassAtribute': 0.6451612903225806, 'LocalVar': 0.6666666666666666}, 'R0_accuracy_full': {'total': 0.7748917748917749, 'FuncArg': 0.7642276422764228, 'FuncReturn': 0.8769230769230769, 'ClassAtribute': 0.6451612903225806, 'LocalVar': 0.6666666666666666}, 'R0_n_labels': 231}
[Epoch 0] R1 stats: {'R1_accuracy_partial': {'total': 0.8051948051948052, 'FuncArg': 0.8373983739837398, 'FuncReturn': 0.8923076923076924, 'ClassAtribute': 0.6451612903225806, 'LocalVar': 0.4166666666666667}, 'R1_accuracy_full': {'total': 0.7748917748917749, 'FuncArg': 0.8292682926829268, 'FuncReturn': 0.8, 'ClassAtribute': 0.6451612903225806, 'LocalVar': 0.4166666666666667}, 'R1_n_labels': 231}
[Epoch 1] R0 stats: {'R0_accuracy_partial': {'total': 0.7922077922077922, 'FuncArg': 0.7723577235772358, 'FuncReturn': 0.9230769230769231, 'ClassAtribute': 0.6774193548387096, 'LocalVar':

In [7]:
# Render the trainer's timer data as a DataFrame (the recorded output lists
# name, count, avg_time and total_time per timed section).
display(trainer.timer.as_dataframe())

Unnamed: 0,name,count,avg_time,total_time
4,training,2,19.481645,38.96329
3,training > model fitting,2,8.284251,16.568503
8,evaluating,2,8.220471,16.440942
2,training > type checking,2,4.995415,9.99083
7,evaluating > type checking,2,3.839049,7.678098
1,training > model prediction,2,3.096355,6.192711
0,training > preparing data,4,1.547501,6.190005
6,evaluating > model prediction,4,1.451308,5.805234
5,evaluating > preparing data,4,0.734775,2.939101
