In [3]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

os.chdir(proj_root())

datadir = Path(os.getenv("datadir"))
repos_dir = datadir / "SPOT-data/repos"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# experiment configurations

import torch

from spot.data import SrcDataset, get_model_name, load_src_datasets, TokenizedSrc
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelTrainingArgs, ModelWrapper

quicktest = False
drop_comments = True
data_reduction = 1
train_per_file = True
train_R1 = False
max_tokens_per_file = 4000


ctx_args = (
    CtxArgs(
        ctx_size=1024,
        left_margin=256 + 128,
        right_margin=256 - 128,
        types_in_ctx=False,
    )
    if not train_per_file
    else None
)

dec_args = DecodingArgs(
    sampling_batch_size=2 * max_tokens_per_file if train_per_file else 128,
    ctx_args=ctx_args,
    max_workers=20,
)


r0_model_name = get_model_name(
    drop_comments=drop_comments,
    ctx_args=ctx_args,
    data_reduction=data_reduction,
    quicktest=quicktest,
)

src_datasets = load_src_datasets(
    datadir,
    drop_comments=drop_comments,
    spot_round=0,
    data_reduction=data_reduction,
    repos_root=datadir / "SPOT-data/repos",
    quicktest=quicktest,
)

for n in ["train", "valid", "test"]:
    print("Filtering files for", n)
    # tkns = 2 * max_tokens_per_file if n == "test" else max_tokens_per_file
    tkns = max_tokens_per_file
    src_datasets[n].filter_files(max_tokens_per_file=tkns)


Filtering files for train
labels_kept_after_filtering: 0.77591
tkns_kept_after_filtering: 0.68408
Filtering files for valid
labels_kept_after_filtering: 0.75903
tkns_kept_after_filtering: 0.6829
Filtering files for test
labels_kept_after_filtering: 0.78382
tkns_kept_after_filtering: 0.70351


In [5]:
# train the model
from spot.train import train_spot_model

train_args = ModelTrainingArgs(
    train_batch_size=max_tokens_per_file if train_per_file else 12,
    eval_batch_size=2 * max_tokens_per_file if train_per_file else 64,
    max_epochs=3,
)

r0_wrapper, r0_extra = train_spot_model(
    src_datasets,
    r0_model_name,
    dec_args=dec_args,
    train_args=train_args,
    record_batches=train_R1,
    quicktest=quicktest,
    use_small_model=False,
)


  warn(f"Failed to load image Python extension: {e}")


  0%|          | 0/667 [00:00<?, ?it/s]

  0%|          | 0/576 [00:00<?, ?it/s]

  0%|          | 0/9780 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Pushover: (Finished: 'Preparing chunked datasets'.) Time taken: 8.5s


[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
445.764   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Pushover: (Finished: 'Training SPOT-model-R0-per_file-drop_comments'.) Time taken: 5444.9s


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid/loss           0.22865037620067596
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████
train/loss,▅▅▁▂▅▆▃▃▄▄▂▂▁▂▄▁▆█▄▅▄▇▃▅▁▂▄▅▇▄▃▅▄▃▁▂▃▂▂▁
train/lr,████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
valid/loss,█▆▅▃▃▃▃▂▂▁▁▁▁▁▁▁

0,1
epoch,2.0
train/loss,0.41909
train/lr,0.0
trainer/global_step,7592.0
valid/loss,0.22865


Loading best model with score 0.22393393516540527 from: /mnt/data0/jiayi/checkpoints/lit-running/SPOT-model-R0-per_file-drop_comments/epoch=1-step=6092.ckpt


In [6]:
# load trained model
from spot.utils import pickle_load

r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
if train_R1:
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]
    for n, src in r1_src_datasets.items():
        print("Filtering files for", n)
        src.filter_files(max_tokens_per_file=max_tokens_per_file)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
r0_wrapper.to(device)
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=None, sampling_batch_size=8000, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9)


In [7]:
# checking mypy feedbacks

from spot.type_env import MypyFeedback
from spot.utils import groupby, pretty_print_accuracies


def ground_truth_feedback(dataset: SrcDataset):
    preds = {
        f: {i: str(t) for i, t in enumerate(src.types)}
        for f, src in dataset.file2src().items()
    }
    return dataset.add_type_checker_feedback(
        r0_wrapper.tokenizer, preds, max_workers=20, tqdm_args={}
    )


def show_feedback_stats(new_dataset: SrcDataset):
    fb_list: list[list[MypyFeedback]] = new_dataset.extra_stats["mypy_feedbacks"]
    stats = {}
    stats["total_feedbacks"] = sum(len(l) for l in fb_list)
    error_code_counter = Counter[str]()
    for l in fb_list:
        for fb in l:
            error_code_counter[fb.error_code] += 1
    stats["top_feedbacks"] = dict(error_code_counter.most_common(10))
    pretty_print_accuracies(stats)
    df = pd.DataFrame(error_code_counter.most_common(), columns=["error_code", "count"])
    display(px.bar(df, x="error_code", y="count", title="Error code frequencies"))
    fdbk_srcs = [(f, src) for src, fs in zip(new_dataset.all_srcs, fb_list) for f in fs]
    error_groups = groupby(fdbk_srcs, lambda x: x[0].error_code)
    return error_groups


# new_src = ground_truth_feedback(src_datasets["test"])
if train_R1:
    error_groups = show_feedback_stats(r1_src_datasets["train"])


In [8]:
from spot.visualization import visualize_sequence


def add_line_numbers(code: str):
    lines = code.split("\n")
    return "\n".join(f"{i+1:3d}|  {l}" for i, l in enumerate(lines))


if train_R1:
    to_display = []
    for xs in error_groups["name-defined"]:
        to_display.append(
            f"feedback: {xs[0]}\n"
            + "=========code=========\n"
            + add_line_numbers(xs[1].origin_code)
        )
    visualize_sequence(to_display)


In [16]:
# model evaluation

import ipywidgets as widgets
import plotly.express as px

from spot.data import R1_srcs_from_preds, load_src_datasets, pretty_print_accuracies
from spot.utils import PickleCache, assert_eq
from spot.visualization import visualize_sequence


def eval_model_helper(
    r0_wrapper: ModelWrapper,
    r1_wrapper: Optional[ModelWrapper],
    eval_cache: PickleCache,
    r0_srcs: SrcDataset,
    reeval=False,
):
    if reeval:
        eval_cache.clear()
    r0_accs, r0_chunks, r0_preds = eval_cache.cached(
        "r0_eval.pkl",
        lambda: r0_wrapper.eval_on_dataset(r0_srcs, tqdm_args={"leave": False}),
    )
    if r1_wrapper is None:
        return r0_accs

    r1_srcs = eval_cache.cached(
        "r1_srcs.pkl",
        lambda: R1_srcs_from_preds(
            r1_wrapper.tokenizer,
            r0_srcs,
            r0_chunks.chunks_info,
            r0_chunks.files,
            r0_preds,
            max_workers=r0_wrapper.args.max_workers,
        ),
    )

    r1_accs, _, _ = eval_cache.cached(
        "r1_eval.pkl",
        lambda: r1_wrapper.eval_on_dataset(r1_srcs, tqdm_args={"leave": False}),
    )

    return (r0_accs, r1_accs)


def evaluate_model(
    r0_wrapper: ModelWrapper,
    r1_wrapper: Optional[ModelWrapper],
    model_name: str,
    r0_srcs: SrcDataset,
    size_factors=[1, 2],
    reeval=False,
):

    acc_series = []
    accs_seq = []
    if r1_wrapper is not None:
        assert_eq(r0_wrapper.args, r1_wrapper.args)
    with run_long_task(f"Evaluate accuracy vs ctx_size: {model_name}"):
        for factor in size_factors:
            wrapper0 = r0_wrapper.scale_ctx_size(factor)
            wrapper1 = (
                r1_wrapper.scale_ctx_size(factor) if r1_wrapper is not None else None
            )
            eval_cache = PickleCache(
                datadir / f"checkpoints/lit-saved/{model_name}/eval-ctx_size={factor}"
            )
            accs = eval_model_helper(
                wrapper0, wrapper1, eval_cache, r0_srcs, reeval=reeval
            )
            if isinstance(accs, tuple):
                acc_series.append(accs[1])
                accs_seq.extend([("R0", accs[0]), ("R1", accs[1])])
            else:
                accs_seq.extend([("R0", accs)])
                acc_series.append(accs)

    def print_acc(i, expand):
        m, a = accs_seq[i]
        print(f"model={m}, ctx_size_factor={factor}")
        print(f"ctx_args: {wrapper0.args.ctx_args}")
        pretty_print_accuracies(a, max_show_level=100 if expand else 0)

    display(widgets.interactive(print_acc, i=(0, len(accs_seq) - 1), expand=False))

    if len(size_factors) > 1:
        acc_df = pd.DataFrame(
            {
                "ctx_size": size_factors,
                "partial_acc": [x["partial_acc"] for x in acc_series],
                "full_acc": [x["full_acc"] for x in acc_series],
                "full_acc_strict": [x["full_acc_strict"] for x in acc_series],
            }
        )
        display(
            px.line(
                acc_df,
                x="ctx_size",
                y=["partial_acc", "full_acc", "full_acc_strict"],
                title=model_name,
            )
        )
    return acc_series


evaluate_model(
    r0_wrapper,
    None,
    r0_model_name,
    src_datasets["test"][-50:],
    size_factors=[1] if train_per_file else [1, 2],
    reeval=True,
)
None


In [10]:
# R1 experiment configurations

import torch

from spot.data import ChunkedDataset, SrcDataset, get_dataset_name, get_model_name
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelTrainingArgs, ModelWrapper

r1_model_name = get_model_name(
    spot_round=1,
    drop_comments=drop_comments,
    ctx_args=ctx_args,
    data_reduction=data_reduction,
    quicktest=quicktest,
)


In [11]:
train_args.max_epochs = 3

r1_wrapper, r1_extra = train_spot_model(
    r1_src_datasets,
    r1_model_name,
    dec_args=dec_args,
    train_args=train_args,
    record_batches=False,
    quicktest=quicktest,
)


NameError: name 'r1_src_datasets' is not defined

In [None]:
# load trained model and evaluate

r1_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r1_model_name}"
)

r1_wrapper.to(device)

evaluate_model(
    r0_wrapper,
    r1_wrapper,
    r1_model_name,
    src_datasets["test"],
    reeval=False,
    size_factors=[1] if train_per_file else [1, 2],
)
None


  0%|          | 0/12 [00:00<?, ?it/s]

predict:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=1/r0_eval.pkl'


type_check_src:   0%|          | 0/12 [00:00<?, ?it/s]

type_check_success_ratio: 1
feedbacks_per_file:
   mean: 4.167
   median: 0
   min: 0
   max: 20


feedbacks_to_tokenized_src:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=1/r1_srcs.pkl'


  0%|          | 0/12 [00:00<?, ?it/s]

predict:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=1/r1_eval.pkl'


  0%|          | 0/12 [00:00<?, ?it/s]

predict:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=2/r0_eval.pkl'


type_check_src:   0%|          | 0/12 [00:00<?, ?it/s]

type_check_success_ratio: 1
feedbacks_per_file:
   mean: 4.167
   median: 0
   min: 0
   max: 20


feedbacks_to_tokenized_src:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=2/r1_srcs.pkl'


  0%|          | 0/12 [00:00<?, ?it/s]

predict:   0%|          | 0/12 [00:00<?, ?it/s]

[PickleCache] Saving to cache: '/mnt/data0/jiayi/checkpoints/lit-saved/quicktest-SPOT-model-R1-per_file/eval-ctx_size=2/r1_eval.pkl'
Pushover: (Finished: 'Evaluate accuracy vs ctx_size: quicktest-SPOT-model-R1-per_file'.) Time taken: 41.1s


VBox(children=(HBox(children=(IntSlider(value=0, max=3), Label(value='(4 total)'))), Box(children=(Output(),),…

In [None]:
from IPython.display import display

from spot.visualization import visualize_batch, visualize_code_sequence


def visualize_preds_code(
    wrapper: ModelWrapper,
    src_dataset: SrcDataset,
    n_visual_exs: int = 16,
):
    _, visual_data, visual_preds = wrapper.eval_on_dataset(
        src_dataset[:n_visual_exs], tqdm_args={"leave": False}
    )

    display(
        visualize_code_sequence(
            [
                visualize_batch(
                    visual_data,
                    i,
                    visual_preds,
                    wrapper.tokenizer,
                    wrapper.args.ctx_args,
                )
                for i in range(min(n_visual_exs, len(visual_preds)))
            ]
        )
    )


visualize_preds_code(r1_wrapper, r1_src_datasets["test"], n_visual_exs=50)


  0%|          | 0/12 [00:00<?, ?it/s]

predict:   0%|          | 0/12 [00:00<?, ?it/s]

VBox(children=(HBox(children=(IntSlider(value=0, max=11), Label(value='(12 total)'))), Box(children=(Output(la…