In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

# Work from the project root so that relative paths resolve the same way
# regardless of where the notebook kernel was started.
os.chdir(proj_root())

# The data location is supplied through the `datadir` environment variable.
# Fail early with a readable message: Path(None) would otherwise raise an
# opaque "expected str, bytes or os.PathLike object, not NoneType" TypeError.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "it must point at the directory that contains 'SPOT-data'."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

# NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by this project / a trusted source.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# Pickled split object; later cells index it with the "train" and "valid" keys.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split = pickle.load(f)

In [2]:
import torch
from transformers import (
    DataCollatorForSeq2Seq,
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
from transformers.models.t5 import T5ForConditionalGeneration

from spot.model import ModelSPOT, TokenizerSPOT

# Checkpoint directory (under the data directory) used for both tokenizer and model.
model_path = datadir / "checkpoints/saved/SPOT-CodeT5-with_margin"

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instantiate tokenizer and model from the same checkpoint, then move the
# model onto the selected device.
tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(model_path)
model: ModelSPOT = ModelSPOT.from_pretrained(model_path).to(device)



In [3]:
from IPython.display import display

import wandb
from spot.training import DAggerTrainer, DAggerTrainerArgs

# Hyper-parameters / settings handed to the DAgger trainer and logged to wandb.
args = DAggerTrainerArgs(
    output_dir=proj_root() / "checkpoints" / "DAgger",
    max_epochs=2,
    repos_group_size=16,
    ctx_size=512,
    ctx_margin=128,
    sampling_batch_size=256,
    train_batch_size=42,
    generation_max_length=128,
    max_workers=16,
)

# Flip to True for a quick smoke run on the first 10 repos of each split.
test_run = False

trainer = DAggerTrainer(model, tokenizer, args)

# Resolve every repo in the train/valid splits to its directory under `repos_dir`.
train_repos = [repo.repo_dir(repos_dir) for repo in repos_split["train"]]
valid_repos = [repo.repo_dir(repos_dir) for repo in repos_split["valid"]]
if test_run:
    train_repos, valid_repos = train_repos[:10], valid_repos[:10]

# Test runs are logged under a separate, "test-"-prefixed wandb project name.
if test_run:
    test_tag = "test-"
else:
    test_tag = ""
exp_name = f"{test_tag}DAgger-SPOT-CodeT5"
wandb.init(project=exp_name, config=args)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Run training and send a wandb alert on both outcomes: one if training aborts
# with an exception, another once it completes normally.
try:
    trainer.train(train_repos, valid_repos)
except Exception as e:
    # Include the exception class in the alert: str(e) alone can be empty
    # (e.g. an exception raised without a message), which made the alert useless.
    wandb.alert(
        title="Training stopped due to exception",
        text=f"In {exp_name}, exception: {type(e).__name__}: {e}",
    )
    # Bare `raise` is the idiomatic way to re-raise the active exception unchanged.
    raise
wandb.alert(title="Training finished", text=f"{exp_name} has finished.")

DAgger Training:   0%|          | 0/1148 [00:00<?, ?it/s]

AssertionError: ty: None has incorrect type <class 'NoneType'>

In [None]:
# Show the timing statistics gathered by the trainer's timer as a DataFrame.
display(trainer.timer.as_dataframe())

Unnamed: 0,name,count,avg_time,total_time
3,training > model fitting,7,153.695309,1075.867161
1,training > model prediction,8,84.873097,678.984775
2,training > type checking,7,66.09076,462.635319
0,training > preparing data,15,12.201682,183.025235


In [7]:
# Scratch check: the None singleton does not compare equal to the string 'None'.
None == 'None'

False