In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import *

from typet5.utils import proj_root, get_data_dir

# Make the project root the working directory for every later cell.
os.chdir(proj_root())

# Directory returned by `get_data_dir()`; later cells pass it to the dataset
# loader and build checkpoint/cache paths under it.
datadir = get_data_dir()

In [2]:
# experiment configurations

import torch

from typet5.data import (
    TokenizedSrcSet,
    get_dataset_name,
    load_tokenized_srcsets,
    TypeCheckSettings,
)
from typet5.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from copy import copy
from typet5.train import TrainingConfig, TypeCheckArgs

# Central experiment configuration; most later cells read their settings from it.
config = TrainingConfig(quicktest=False, all_labels=True)
# Notebook switches:
#   train_R1    - passed as `record_batches` when training R0 and used to guard
#                 the R1-related cells further down.
#   load_R0     - when False, the next cell trains the R0 model from scratch.
#   load_critic - when False, the critic cell trains a new critic.
train_R1: bool = True
load_R0: bool = True
load_critic: bool = False
# Index of the CUDA device used for training and inference.
gpu_id = 0
# Class-level setting, keyed by GPU id.
# NOTE(review): presumably a per-GPU scratch location for the type checker so
# that concurrent runs do not collide - confirm in `TypeCheckSettings`.
TypeCheckSettings.temp_path = f"GPU-{gpu_id}"

# wandb project name; quick-test runs are logged to a separate project.
project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

# Token budget taken from the configured context size.
max_tokens_per_file = config.ctx_size
dec_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=config.dec_ctx_args(),
    max_workers=20,
)


# Name of the tokenized dataset variant matching the config flags.
datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

# Checkpoint / wandb run name of the first-round (R0) model.
r0_model_name = "R0-model--" + config.as_name()

# Indexed by split name in later cells (e.g. `tk_dataset["test"]`).
tk_dataset = load_tokenized_srcsets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  tk_dataset-all_labels-drop_comments


In [3]:
# train the model
from typet5.train import ModelTrainingArgs, train_spot_model, TypeCheckArgs
import wandb

# Training hyper-parameters for the R0 model. The R1 training cell later
# copies this object and overrides `max_epochs`.
train_args = ModelTrainingArgs(
    train_ctx_args,
    dec_args,
    train_max_tokens=max_tokens_per_file,
    eval_max_tokens=2 * max_tokens_per_file,
    max_epochs=2,
    tc_args=tc_args,
)

# Train only when no previously saved R0 model should be reused.
# When `load_R0` is True this cell defines neither `r0_wrapper` nor `r0_extra`;
# the following cell loads the wrapper from disk instead.
if not load_R0:
    # Start a wandb run named after the model; it is closed by the
    # "close wandb" cell further down.
    wandb.init(
        project=project_name,
        name=r0_model_name,
        config=config.as_dict(),
        dir=str(datadir),
    )
    r0_wrapper, r0_extra = train_spot_model(
        tk_dataset,
        r0_model_name,
        train_args=train_args,
        record_batches=train_R1,
        gpus=[gpu_id],
        quicktest=config.quicktest,
        use_small_model=config.use_small_model,
    )


In [3]:
# load trained model
from typet5.utils import pickle_load, pickle_dump

# Load the saved R0 checkpoint. This runs unconditionally, so a wrapper that
# was just trained in the previous cell is replaced by the version on disk.
r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
# Disabled here: the R1 dataset is loaded in the "load trained critic" cell instead.
# if train_R1:
    # r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    # r1_tk_dataset: dict[str, TokenizedSrcSet] = r0_extra["R1-tk_dataset"]
# Fall back to CPU when CUDA is unavailable.
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)
# Turn sampling off on the wrapper's decoding args by default.
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(ctx_size=4096, left_margin=2048, right_margin=1024), sampling_max_tokens=32768, max_workers=20)


In [12]:
# test DAgger
from typet5.dagger import DAggerModel
from typet5.utils import print_limited, display, pretty_print_dict

# Wrap the R0 model for DAgger-style evaluation.
dmodel = DAggerModel(r0_wrapper)

# Smoke test on a small slice of the test split. The top-level `await`
# requires IPython's autoawait support.
# NOTE(review): the recorded output reports `loss: nan` together with numpy
# mean warnings - confirm whether the loss is expected to be undefined here.
metrics = await dmodel.eval_on_data(tk_dataset["test"][1:10], concurrency=8)
pretty_print_dict(metrics)

# Timing table collected by the model's logger (columns name/count/avg_time/total_time).
display(dmodel.t_logger.as_dataframe())

compute_preexisting_fdbks: 100%|██████████| 3/3 [00:02<00:00,  1.25it/s]
Evaluating: 100%|██████████| 219/219 [01:03<00:00,  3.44it/s]


loss: nan
acc: 0.73973


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0,name,count,avg_time,total_time
1,type checking,219,0.56488,123.708648
0,predict next type,219,0.306826,67.19479
2,generate new src,219,0.206394,45.200321


In [13]:
# train the DAgger model
from typet5.dagger import DAggerModel, DAggerArgs
from typet5.utils import display, pretty_print_dict

dmodel = DAggerModel(r0_wrapper)
dmodel.t_logger.clear()
await dmodel.train_on_data(tk_dataset, DAggerArgs())
display(dmodel.t_logger.as_dataframe())

compute_preexisting_fdbks: 100%|██████████| 612/612 [00:59<00:00, 10.33it/s]
  if SIMPLE_WHITESPACE_RE.fullmatch(self.value) is None:
  return tuple(f for f in fields.values() if f._field_type is _FIELD)
  self._stack.append(CodePosition(self.line, self.column))
  self._codegen_impl(state, **kwargs)
  self.gen = func(*args, **kwds)
  start = CodePosition(self.line, self.column)
Training:   0%|          | 316/295457 [01:40<25:59:28,  3.15it/s]


CancelledError: 

In [None]:
# Show the timing table of the current `dmodel`; useful when the training cell
# above was interrupted before it could display the table itself.
display(dmodel.t_logger.as_dataframe())

In [None]:
# model evaluation

import plotly.express as px

from typet5.train import evaluate_model
from typet5.utils import PickleCache
from typet5.visualization import display_persist, dict_widget

# Evaluation cache stored next to the R0 checkpoint.
r0_cache = PickleCache(datadir / f"checkpoints/lit-saved/{r0_model_name}/eval_cache")
# Evaluate R0 alone: the second positional argument is where the R1 evaluation
# cell passes `r1_wrapper`; here it is None.
r0_eval = evaluate_model(
    r0_wrapper,
    None,
    tk_dataset["test"],
    eval_cache=r0_cache,
    tc_args=train_args.tc_args,
)
# Each entry of `r0_eval` is indexed as `entry[1].accuracies`; take the first round.
r0_accs = r0_eval[0][1].accuracies
display_persist(dict_widget(r0_accs))


In [None]:
# close wandb
from typet5.utils import pretty_show_dict
from typet5.visualization import string_to_html
import wandb


def wandb_string(s: str):
    """Convert ``s`` with ``string_to_html`` and wrap the result in ``wandb.Html``."""
    html_text = string_to_html(s)
    return wandb.Html(html_text)


# A wandb run for R0 is only open when the model was trained in this session,
# so only then log the per-round accuracies and close the run.
if not load_R0:
    for i, e in enumerate(r0_eval):
        wandb.log({f"test/R{i}": wandb_string(pretty_show_dict(e[1].accuracies))})
    wandb.finish()


In [None]:
# export the code with inlined predictions as HTML

from typet5.visualization import export_preds_on_code, display_persist, proj_root

# Set to True to write the HTML export; off by default.
export_preds = False

if export_preds:
    # Prediction result of the first evaluation round.
    pr = r0_eval[0][1]
    # Export every 10th chunk only.
    sub_ids = range(0, len(pr.chunks), 10)
    # NOTE(review): `pr.chunks` is indexed with a `range` object - this assumes
    # the chunk container supports that kind of indexing; confirm.
    export_preds_on_code(
        pr.chunks[sub_ids],
        [pr.predictions[i] for i in sub_ids],
        {},
        export_to=proj_root() / "R0_predictions",
    )


In [None]:
# train the critic
from typet5.critic import (
    CriticModel,
    ModelSPOT,
    train_critic_model,
    CriticTrainArgs,
    get_critic_name,
)
from typet5.utils import pickle_load, run_long_task, PickleCache
from typet5.train import R1_srcs_from_extra, R1_srcs_from_model
import wandb

# critic_new_data:    when True, build the critic dataset from fresh R0
#                     predictions; otherwise reuse the data recorded in the R0
#                     model's `extra.pkl`.
# critic_no_feedback: forwarded to the type-check args as `no_feedback`.
critic_new_data = True
critic_no_feedback = False
critic_name = get_critic_name(critic_no_feedback, critic_new_data, config)

with run_long_task(f"Training Critic: {critic_name}", notify=not load_critic):
    critic_train_args = CriticTrainArgs(
        ctx_args=train_ctx_args,
        train_max_tokens=max_tokens_per_file,
        eval_max_tokens=2 * max_tokens_per_file,
        max_epochs=1,
    )

    critic_tc_args = tc_args._replace(no_feedback=critic_no_feedback)
    critic_cache = PickleCache(
        datadir / f"checkpoints/lit-saved/CriticData-{critic_name}"
    )
    # critic_cache.remove("tk_dataset")
    critic_tk_dataset: dict[str, TokenizedSrcSet]

    # Remember the current decoding settings so that the sampling overrides
    # below do not leak into later cells that reuse `r0_wrapper` (e.g. the
    # R0-vs-R1 evaluation).
    r0_saved_decoding = {
        field: getattr(r0_wrapper.args, field)
        for field in ("do_sample", "top_p")
        if hasattr(r0_wrapper.args, field)
    }

    if critic_new_data:
        # use sampling to increase example diversity
        r0_wrapper.args.do_sample = True
        r0_wrapper.args.top_p = 0.9

    try:
        # The lambda only runs on a cache miss; on a hit the stored dataset is
        # returned and the R0 model is not queried.
        critic_tk_dataset = critic_cache.cached(
            "tk_dataset",
            lambda: {
                k: v.inline_predictions(as_comment=False)
                for k, v in (
                    R1_srcs_from_model(
                        r0_wrapper,
                        tk_dataset,
                        critic_tc_args,
                    )
                    if critic_new_data
                    else R1_srcs_from_extra(
                        r0_wrapper,
                        tk_dataset,
                        extra=pickle_load(
                            datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl"
                        ),
                        tc_args=critic_tc_args,
                    )
                ).items()
            },
        )
    finally:
        # Put the previous decoding settings back, even if data generation failed.
        for field, value in r0_saved_decoding.items():
            setattr(r0_wrapper.args, field, value)

    # Train a new critic only when no saved one should be reused; the next
    # cell loads the critic from disk either way.
    if not load_critic:
        wandb.init(
            project=project_name,
            name=critic_name,
            config=config.as_dict(),
            dir=str(datadir),
        )
        critic, critic_extra = train_critic_model(
            critic_tk_dataset,
            critic_train_args,
            critic_name,
            gpus=[gpu_id],
            quicktest=config.quicktest,
            use_early_stop=False,
            use_small_model=config.use_small_model,
        )
        # critic.save_pretrained("CriticSaved")
        wandb.finish()


In [None]:
# load trained critic
from typet5.utils import pickle_load, pickle_dump
from typet5.critic import CriticModel

# Load the saved critic checkpoint; `critic_name` comes from the critic-training cell.
critic = CriticModel.load(datadir / f"checkpoints/lit-saved/{critic_name}")
# Load the R1 dataset recorded alongside the R0 model, unless a previous run
# of this cell already defined it in the kernel.
# NOTE(review): `pickle_load` executes arbitrary code from the file - only use
# it on checkpoints produced by trusted runs.
if train_R1 and ("r1_tk_dataset" not in globals()):
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_tk_dataset: dict[str, TokenizedSrcSet] = r0_extra["R1-tk_dataset"]

# Fall back to CPU when CUDA is unavailable.
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
critic.to(device)
print("Critic loaded.")


In [None]:
# show critic performance

from typet5.visualization import visualize_preds_on_code, pretty_print_dict

# Recomputed here so this cell can run without the "load trained critic" cell's `device`.
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
critic.to(device)
# The critic is evaluated on the test split of the dataset it was trained from.
r1_testset = critic_tk_dataset["test"]
critic_eval = critic.eval_on_src_dataset(
    r1_testset, train_ctx_args, dec_args.sampling_max_tokens
)
# Element 1 of the result is formatted as percentages, element 2 is printed as a dict.
# NOTE(review): `nicer_preds` is not used by any cell visible here - confirm
# whether it is still needed.
nicer_preds = [[f"{x:.1%}" for x in xs] for xs in critic_eval[1]]
pretty_print_dict(critic_eval[2])


In [None]:
# The performance achieved by always predicting true or random values

from typet5.utils import not_none, pretty_print_dict
from typet5.type_check import normalize_type
from typet5.critic import CriticModel
import random


def dummy_performance(dataset: TokenizedSrcSet, pred_f):
    """Score a baseline that ignores its input.

    For every source in ``dataset.all_srcs`` the values of its ``prev_types``
    are paired with its ``types``; a target is True when the two types are
    equal after ``normalize_type``. ``pred_f`` is called once per target
    (with no arguments) to produce the baseline prediction, and the result of
    ``CriticModel.compute_metrics(preds, targets)`` is returned.
    """
    targets = [
        normalize_type(label) == normalize_type(prev)
        for src in dataset.all_srcs
        for prev, label in zip(not_none(src.prev_types).values(), src.types)
    ]
    preds = [pred_f() for _ in targets]
    return CriticModel.compute_metrics(preds, targets)


# Baseline 1: always predict True.
pretty_print_dict(dummy_performance(r1_testset, lambda: True))
# Baseline 2: uniformly random guess (unseeded, so the numbers vary between runs).
pretty_print_dict(dummy_performance(r1_testset, lambda: random.choice([True, False])))


In [None]:
from typet5.utils import DefaultTokenizer, decode_tokens, np


def chunk_has_fdbk(tks):
    """Return True when the decoded text of ``tks`` contains the marker ``/* error:``."""
    decoded = decode_tokens(tks)
    return "/* error:" in decoded


# In the cells shown here `r1_tk_dataset` is only loaded when `train_R1` is
# set, so guard on the same flag as the other R1 inspection cells instead of
# failing with a NameError.
if train_R1:
    # Chunk the R1 test split with the decoding context settings and measure
    # how many chunks contain a feedback marker.
    test_chunks = r1_tk_dataset["test"].to_chunks(DefaultTokenizer, dec_args.ctx_args)
    fraction_chunks_with_fdbk = np.mean(
        [chunk_has_fdbk(tks) for tks in test_chunks.data["input_ids"]]
    )
    print("Fraction of chunks with feedback:", fraction_chunks_with_fdbk)


In [None]:
# checking mypy feedbacks
from typet5.visualization import show_feedback_stats

# `r1_tk_dataset` is only loaded when `train_R1` is set.
# The returned groups are indexed by error code in the next cell.
if train_R1:
    error_groups = show_feedback_stats(r1_tk_dataset["test"])


In [None]:
# visualize feedback samples

from typet5.utils import seq_flatten, add_line_numbers
from typet5.visualization import code_inline_type_masks, visualize_sequence, display


# Render the sources that triggered a "return-value" feedback, one entry per
# (feedback, source) pair, with the type masks inlined into the code.
if train_R1:
    to_display = []
    # Use `.get` so a test set without any "return-value" feedback produces an
    # empty display instead of a KeyError.
    # To show every group instead, iterate over seq_flatten(error_groups.values()).
    for xs in error_groups.get("return-value", ()):
        src = xs[1]
        code = code_inline_type_masks(src.origin_code, src.types)
        to_display.append(
            f"feedback: {xs[0]}\n" + "=========code=========\n" + add_line_numbers(code)
        )
    if len(to_display) > 0:
        display(visualize_sequence(to_display))


In [None]:
# R1 training

import torch
import wandb
from typet5.data import TokenizedSrcSet, get_dataset_name
from typet5.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper

# When False, train the R1 model in this cell; the next cell loads the saved
# checkpoint from disk either way.
load_R1 = False
r1_model_name = "R1-model--" + config.as_name()

if not load_R1:
    # Start a wandb run for R1; it is closed by the wandb logging cell below.
    wandb.init(
        project=project_name,
        name=r1_model_name,
        config=config.as_dict(),
        dir=str(datadir),
    )

    # Shallow copy of the R0 training args with a shorter schedule, so
    # `train_args.max_epochs` itself is left untouched.
    r1_train_args = copy(train_args)
    r1_train_args.max_epochs = 1

    # `r1_tk_dataset` is defined by the "load trained critic" cell, and only
    # when `train_R1` is set.
    r1_wrapper, r1_extra = train_spot_model(
        r1_tk_dataset,
        r1_model_name,
        train_args=r1_train_args,
        gpus=[gpu_id],
        record_batches=False,
        quicktest=config.quicktest,
        use_early_stop=False,
        use_small_model=config.use_small_model,
    )


In [None]:
# load trained model and evaluate
from typet5.train import evaluate_model
from typet5.visualization import visualize_dicts

# Load the saved R1 checkpoint; this replaces any wrapper returned by the
# training cell above.
r1_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r1_model_name}"
)
# `device` and `PickleCache` come from earlier cells.
r1_wrapper.to(device)

r1_cache = PickleCache(datadir / f"checkpoints/lit-saved/{r1_model_name}/eval_cache")
# The cache is cleared first, so every run of this cell re-evaluates from scratch.
r1_cache.clear()
# Evaluate R0 and R1 together on the test split. R0 decodes with whatever
# settings are currently stored on `r0_wrapper.args`.
r1_eval = evaluate_model(
    r0_wrapper,
    r1_wrapper,
    tk_dataset["test"],
    tc_args=tc_args,
    eval_cache=r1_cache,
)
# One accuracy dict per entry in `r1_eval`.
visualize_dicts([x[1].accuracies for x in r1_eval])


In [None]:
from typet5.visualization import export_preds_on_code, display_persist, proj_root

# Prediction result of the second entry in `r1_eval`.
eval_to_viz = r1_eval[1][1]
# Export every 10th chunk only.
sub_ids = range(0, len(eval_to_viz.chunks), 10)
# NOTE(review): `chunks` is indexed with a `range` object - this assumes the
# chunk container supports that kind of indexing; confirm.
export_preds_on_code(
    eval_to_viz.chunks[sub_ids],
    [eval_to_viz.predictions[i] for i in sub_ids],
    {},
    export_to=proj_root() / "caches/R1_predictions",
)


In [None]:
from typet5.visualization import visualize_conf_matrix

# Label the first two entries of `r1_eval` as "R0" and "R1"; `zip` stops at the
# shorter input, so any further entries are ignored.
visualize_conf_matrix({n: x[1] for n, x in zip(["R0", "R1"], r1_eval)})


In [None]:
from typet5.utils import pretty_show_dict

# A wandb run for R1 is only open when the model was trained in this session.
# `wandb` and `wandb_string` come from earlier cells.
if not load_R1:
    for i, e in enumerate(r1_eval):
        wandb.log({f"test/R{i}": wandb_string(pretty_show_dict(e[1].accuracies))})
    wandb.finish()


In [None]:
from IPython.display import display

from typet5.visualization import visualize_preds_on_code

# Index into `r1_eval` of the round to visualize (an earlier cell labels
# entries 0 and 1 as "R0" and "R1"). Named `round_idx` rather than `round` so
# the built-in `round()` is not shadowed for the rest of the kernel session.
round_idx = 1
round_result = r1_eval[round_idx][1]
pred_dataset = round_result.chunks
# Last expression of the cell, so the returned visualization is displayed.
visualize_preds_on_code(pred_dataset, round_result.predictions, dict())
