In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

# Make the project root (as reported by `proj_root()`) the working directory.
os.chdir(proj_root())

# Fail fast with an actionable message: when the variable is unset, `Path(None)`
# would otherwise surface as an opaque `TypeError`.
if os.getenv("datadir") is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; point it at the directory "
        "that holds 'SPOT-data' and 'checkpoints'."
    )
datadir = Path(os.environ["datadir"])
# Location of the repositories under the data directory.
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import SrcDataset, get_model_name, load_src_datasets
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelTrainingArgs, ModelWrapper

# Experiment-wide switches; they are forwarded to `get_model_name`,
# `load_src_datasets` and `train_spot_model` in the cells below.
data_reduction = 1
drop_comments = False
quicktest = True


# Context layout handed to the decoder configuration.
ctx_args = CtxArgs(
    types_in_ctx=False,
    right_margin=128,  # 256 - 128
    left_margin=384,  # 256 + 128
    ctx_size=1024,
)

# Decoding settings shared by training and evaluation.
dec_args = DecodingArgs(ctx_args=ctx_args, max_workers=20, sampling_batch_size=128)


# Name of the round-0 model; later cells use it to build the checkpoint path
# under `checkpoints/lit-saved/`.
r0_model_name = get_model_name(
    drop_comments=drop_comments,
    ctx_args=ctx_args,
    data_reduction=data_reduction,
    quicktest=quicktest,
)

# Round-0 source datasets. `repos_dir` (defined in the first cell) is reused
# instead of re-deriving the same `datadir / "SPOT-data/repos"` path here, so
# the repository location has a single source of truth.
src_datasets = load_src_datasets(
    datadir,
    drop_comments=drop_comments,
    spot_round=0,
    data_reduction=data_reduction,
    repos_root=repos_dir,
    quicktest=quicktest,
)


In [3]:
# train the model
from spot.train import train_spot_model

# Training hyper-parameters for round 0 (the same object is reused for round 1).
train_args = ModelTrainingArgs(max_epochs=1, train_batch_size=12, eval_batch_size=64)

# Train the round-0 model; the call returns the trained wrapper plus an
# "extra" payload that is unpacked here.
r0_wrapper, r0_extra = train_spot_model(
    src_datasets,
    r0_model_name,
    train_args=train_args,
    dec_args=dec_args,
    quicktest=quicktest,
    record_batches=True,
)


  warn(f"Failed to load image Python extension: {e}")


processing chunks:   0%|          | 0/3675 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/2615 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/51966 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Pushover: (Finished: 'Preparing chunked datasets'.) Time taken: 70.2s


[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
445.764   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Pushover: (Finished: 'Training SPOT-model-R0-(384, 512, 128)'.) Time taken: 3142.3s


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid/loss           0.32140225172042847
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Generating R1 dataset: train


  0%|          | 0/9 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/4200 [00:00<?, ?it/s]

predict:   0%|          | 0/1237 [00:00<?, ?it/s]

type_check_src:   0%|          | 0/14300 [00:00<?, ?it/s]

num_type_checked: 14285
errors_per_file:
   mean: 1.787
   median: 0
   min: 0
   max: 243


feedbacks_to_tokenized_src:   0%|          | 0/14300 [00:00<?, ?it/s]

Generating R1 dataset: valid


predict:   0%|          | 0/2674 [00:00<?, ?it/s]

type_check_src:   0%|          | 0/998 [00:00<?, ?it/s]

num_type_checked: 998
errors_per_file:
   mean: 2.314
   median: 0
   min: 0
   max: 84


feedbacks_to_tokenized_src:   0%|          | 0/998 [00:00<?, ?it/s]

Generating R1 dataset: test


predict:   0%|          | 0/2068 [00:00<?, ?it/s]

type_check_src:   0%|          | 0/866 [00:00<?, ?it/s]

num_type_checked: 866
errors_per_file:
   mean: 1.768
   median: 0
   min: 0
   max: 118


feedbacks_to_tokenized_src:   0%|          | 0/866 [00:00<?, ?it/s]

Pushover: (Finished: 'Generating R1 datasets'.) Time taken: 3849.6s


In [3]:
# load trained model
from spot.utils import pickle_load

# Pick the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Restore the saved round-0 wrapper and the extra payload stored next to it.
# NOTE(review): `pickle_load` presumably unpickles the file - only point it at
# checkpoints produced by this project.
r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

r0_wrapper.to(device)
# Turn sampling off on the wrapper's decoding args before printing them.
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(left=384, window=512, right=128), sampling_batch_size=128, max_workers=20, generation_max_length=128, do_sample=False, top_p=0.9)


In [7]:
# checking mypy feedbacks

from spot.utils import pretty_print_accuracies, groupby
from spot.type_env import MypyFeedback


def ground_truth_feedback(dataset: SrcDataset):
    """Collect type-checker feedback using each source's own `types` as predictions.

    For every `(file, src)` pair in `dataset.file2src()`, a mapping from label
    index to `str(type)` is built from `src.types`. The mappings are passed to
    `dataset.add_type_checker_feedback` together with the tokenizer of the
    notebook-level `r0_wrapper`, and that call's result is returned unchanged.
    """
    preds = {}
    for file_key, src in dataset.file2src().items():
        preds[file_key] = {idx: str(ty) for idx, ty in enumerate(src.types)}
    return dataset.add_type_checker_feedback(
        r0_wrapper.tokenizer, preds, max_workers=20, tqdm_args={}
    )


def show_feedback_stats(new_dataset: SrcDataset):
    """Summarize the mypy feedback stored in `new_dataset` and group it by error code.

    Reads `new_dataset.extra_stats["mypy_feedbacks"]` (one feedback list per
    source), prints the total count and the ten most common error codes, and
    displays a bar chart of all error-code counts.

    Returns the result of the project's `groupby` helper applied to
    `(feedback, source)` pairs keyed by `feedback.error_code`.
    """
    fb_list: list[list[MypyFeedback]] = new_dataset.extra_stats["mypy_feedbacks"]
    error_code_counter = Counter[str](
        fb.error_code for per_src in fb_list for fb in per_src
    )
    stats = {
        "total_feedbacks": sum(map(len, fb_list)),
        "top_feedbacks": dict(error_code_counter.most_common(10)),
    }
    pretty_print_accuracies(stats)

    df = pd.DataFrame(error_code_counter.most_common(), columns=["error_code", "count"])
    display(px.bar(df, x="error_code", y="count"))

    # Pair every feedback with the source it was reported for.
    fdbk_srcs = []
    for src, per_src in zip(new_dataset.all_srcs, fb_list):
        fdbk_srcs.extend((fb, src) for fb in per_src)
    return groupby(fdbk_srcs, lambda pair: pair[0].error_code)


# Disabled alternative: compute feedback from the labels stored in the test sources.
#   new_src = ground_truth_feedback(src_datasets["test"])
# Summarize the feedback recorded for the R1 training set. The returned grouping
# is indexed by error code in the next cell (e.g. `error_groups["attr-defined"]`).
error_groups = show_feedback_stats(r1_src_datasets["train"])


type_check_src:   0%|          | 0/866 [00:00<?, ?it/s]

type_check_success_ratio: 1
feedbacks_per_file:
   mean: 0.1374
   median: 0
   min: 0
   max: 38


feedbacks_to_tokenized_src:   0%|          | 0/866 [00:00<?, ?it/s]

In [9]:
from spot.visualization import visualize_texts

def add_line_numbers(code: str):
    """Return `code` with every line prefixed by its 1-based line number.

    Lines are split on "\\n"; each number is right-aligned in a field of width 3
    and followed by "|  ". An empty string yields a single numbered empty line.
    """
    numbered = []
    for lineno, text in enumerate(code.split("\n"), start=1):
        numbered.append(f"{lineno:3d}|  {text}")
    return "\n".join(numbered)


# One text panel per "attr-defined" entry: the feedback itself, a separator,
# then the line-numbered original code of the source it belongs to.
to_display = []
for xs in error_groups["attr-defined"]:
    to_display.append(
        "".join(
            (
                f"feedback: {xs[0]}\n",
                "=========code=========\n",
                add_line_numbers(xs[1].origin_code),
            )
        )
    )
visualize_texts(to_display)


total_feedbacks: 119
top_feedbacks:
   attr-defined: 35
   arg-type: 30
   assignment: 15
   index: 8
   return-value: 7
   name-defined: 7
   misc: 5
   has-type: 4
   override: 2
   return: 2


VBox(children=(HBox(children=(IntSlider(value=0, max=34), Label(value='(35 total)'))), Box(children=(Output(la…

In [5]:
# evaluate the model

import plotly.express as px

from spot.data import load_src_datasets, pretty_print_accuracies
from spot.utils import PickleCache


def evaluate_model(
    wrapper: ModelWrapper,
    model_name: str,
    testset: SrcDataset,
    size_factors=(1, 2),
    reeval=False,
):
    """Evaluate `wrapper` on `testset` at several context-size scale factors.

    For each factor, the wrapper's args are temporarily replaced by
    `wrapper.args.scale_ctx_size(factor)`, and the first element returned by
    `wrapper.eval_on_dataset` is taken as the accuracy record. Records are
    cached under `checkpoints/lit-saved/{model_name}/eval`, one pickle per
    factor, and printed as they become available. The wrapper's original args
    are restored even if evaluation fails.

    Args:
        wrapper: Model wrapper to evaluate.
        model_name: Used for the cache directory, the task name and the plot title.
        testset: Dataset to evaluate on.
        size_factors: Iterable of scale factors to try (materialized once).
        reeval: When true, clear the evaluation cache before evaluating.

    Returns:
        A list with one accuracy record per factor, in the order given.
    """
    # Materialize once so any iterable is accepted and reused safely below.
    size_factors = list(size_factors)
    eval_cache = PickleCache(datadir / f"checkpoints/lit-saved/{model_name}/eval")
    if reeval:
        eval_cache.clear()

    with run_long_task(f"Evaluate accuracy vs ctx_size: {model_name}"):
        wrapper_args = wrapper.args
        acc_series = []
        try:
            for factor in size_factors:
                wrapper.args = wrapper_args.scale_ctx_size(factor)
                accs = eval_cache.cached(
                    f"ctx_size_factor={factor}.pkl",
                    lambda: wrapper.eval_on_dataset(
                        testset, tqdm_args={"leave": False}
                    )[0],
                )
                acc_series.append(accs)
                print(f"===ctx_size factor: {factor}===")
                print(f"ctx_args: {wrapper.args.ctx_args}")
                pretty_print_accuracies(accs)
        finally:
            # Always put the caller's decoding args back.
            wrapper.args = wrapper_args

    acc_df = pd.DataFrame(
        {
            "ctx_size": size_factors,
            "partial_acc": [x["partial_acc"] for x in acc_series],
            "full_acc": [x["full_acc"] for x in acc_series],
        }
    )
    # The figure used to be built and silently dropped; render it explicitly,
    # since a figure created in the middle of a function is never auto-displayed.
    fig = px.line(acc_df, x="ctx_size", y=["partial_acc", "full_acc"], title=model_name)
    fig.show()
    return acc_series


# Evaluate the round-0 model on the round-0 test split.
evaluate_model(r0_wrapper, r0_model_name, src_datasets["test"])
# A trailing `None` keeps the notebook from echoing the returned accuracy list.
None


processing chunks:   0%|          | 0/2615 [00:00<?, ?it/s]

predict:   0%|          | 0/2068 [00:00<?, ?it/s]

[PickleCache] Saving to cache: %s /mnt/data0/jiayi/checkpoints/lit-saved/SPOT-model-R0-(384, 512, 128)/eval/ctx_size_factor=1.pkl
===ctx_size factor: 1===
ctx_args: CtxArgs(left=384, window=512, right=128)
partial_acc: 0.8329
partial_acc_wo_any: 0.8395
partial_accs:
   FuncArg: 0.8274
   FuncReturn: 0.8579
   ClassAtribute: 0.7905
   GlobalVar: 0.8393
   LocalVar: 0.8617
full_acc: 0.7762
full_accs:
   FuncArg: 0.7763
   FuncReturn: 0.8171
   ClassAtribute: 0.7189
   GlobalVar: 0.6429
   LocalVar: 0.6482
n_labels: 8421


processing chunks:   0%|          | 0/2615 [00:00<?, ?it/s]

predict:   0%|          | 0/2070 [00:00<?, ?it/s]

[PickleCache] Saving to cache: %s /mnt/data0/jiayi/checkpoints/lit-saved/SPOT-model-R0-(384, 512, 128)/eval/ctx_size_factor=2.pkl
===ctx_size factor: 2===
ctx_args: CtxArgs(left=1280, window=512, right=256)
partial_acc: 0.8535
partial_acc_wo_any: 0.8602
partial_accs:
   FuncArg: 0.8541
   FuncReturn: 0.8787
   ClassAtribute: 0.7964
   GlobalVar: 0.7857
   LocalVar: 0.881
full_acc: 0.8004
full_accs:
   FuncArg: 0.8121
   FuncReturn: 0.8357
   ClassAtribute: 0.7174
   GlobalVar: 0.625
   LocalVar: 0.7024
n_labels: 8418
Pushover: (Finished: 'Evaluate accuracy vs ctx_size: SPOT-model-R0-(384, 512, 128)'.) Time taken: 594.1s


In [10]:
# R1 experiment configurations

import torch

from spot.data import ChunkedDataset, SrcDataset, get_dataset_name, get_model_name
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelTrainingArgs, ModelWrapper

# Round-1 model name: same settings as round 0, plus `spot_round=1`.
r1_model_name = get_model_name(
    quicktest=quicktest,
    data_reduction=data_reduction,
    ctx_args=ctx_args,
    drop_comments=drop_comments,
    spot_round=1,
)

# NOTE(review): this mutates the shared `train_args` object in place, so
# re-running the round-0 training cell afterwards would also use 3 epochs.
train_args.max_epochs = 3
r1_wrapper, r1_extra = train_spot_model(
    r1_src_datasets,
    r1_model_name,
    train_args=train_args,
    dec_args=dec_args,
    quicktest=quicktest,
    record_batches=False,
)


processing chunks:   0%|          | 0/3817 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/2734 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/52932 [00:00<?, ?it/s]

Pushover: (Finished: 'Preparing chunked datasets'.) Time taken: 85.5s


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
445.764   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Pushover: (Finished: 'Training SPOT-model-R1-(384, 512, 128)'.) Time taken: 5644.1s


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid/loss           0.3151528835296631
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█
train/loss,▄▄█▁▃▄▁▃▂▃▃▄▃▁▄▃▁▅▂▂▃▄▃▄▂▂▃▁▂▃▃▂▂▅▂▂▂▃▁▃
train/lr,███████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
valid/loss,█▅▃▃▂▃▂▁▁▁▂▂

0,1
epoch,2.0
train/loss,0.35081
train/lr,0.0
trainer/global_step,5835.0
valid/loss,0.31515


Loading best model from:  /mnt/data0/jiayi/checkpoints/lit-running/SPOT-model-R1-(384, 512, 128)/epoch=1-step=4335.ckpt


In [11]:
# Move the round-1 wrapper to the device chosen when the round-0 model was loaded.
r1_wrapper.to(device)
# `reeval=True` clears the evaluation cache first, so the metrics are recomputed.
evaluate_model(r1_wrapper, r1_model_name, r1_src_datasets["test"], reeval=True)
# A trailing `None` keeps the notebook from echoing the returned accuracy list.
None




processing chunks:   0%|          | 0/2734 [00:00<?, ?it/s]

predict:   0%|          | 0/2176 [00:00<?, ?it/s]

[PickleCache] Saving to cache: %s /mnt/data0/jiayi/checkpoints/lit-saved/SPOT-model-R1-(384, 512, 128)/eval/ctx_size_factor=1.pkl
===ctx_size factor: 1===
ctx_args: CtxArgs(left=384, window=512, right=128)
partial_acc: 0.8436
partial_acc_wo_any: 0.8483
partial_accs:
   FuncArg: 0.836
   FuncReturn: 0.8742
   ClassAtribute: 0.7919
   GlobalVar: 0.8393
   LocalVar: 0.8933
full_acc: 0.7834
full_accs:
   FuncArg: 0.7865
   FuncReturn: 0.8302
   ClassAtribute: 0.7032
   GlobalVar: 0.6429
   LocalVar: 0.668
n_labels: 8422


processing chunks:   0%|          | 0/2734 [00:00<?, ?it/s]

predict:   0%|          | 0/2178 [00:00<?, ?it/s]

[PickleCache] Saving to cache: %s /mnt/data0/jiayi/checkpoints/lit-saved/SPOT-model-R1-(384, 512, 128)/eval/ctx_size_factor=2.pkl
===ctx_size factor: 2===
ctx_args: CtxArgs(left=1280, window=512, right=256)
partial_acc: 0.8571
partial_acc_wo_any: 0.8624
partial_accs:
   FuncArg: 0.8594
   FuncReturn: 0.8762
   ClassAtribute: 0.8084
   GlobalVar: 0.7857
   LocalVar: 0.881
full_acc: 0.8008
full_accs:
   FuncArg: 0.8121
   FuncReturn: 0.8325
   ClassAtribute: 0.7286
   GlobalVar: 0.6429
   LocalVar: 0.6865
n_labels: 8418
Pushover: (Finished: 'Evaluate accuracy vs ctx_size: SPOT-model-R1-(384, 512, 128)'.) Time taken: 601.2s


In [14]:
from IPython.display import display

from spot.visualization import visualize_code_sequence, visualize_batch


def visualize_preds_code(
    wrapper: ModelWrapper,
    src_dataset: SrcDataset,
    n_visual_exs: int = 16,
):
    """Display the wrapper's predictions on the first `n_visual_exs` dataset entries.

    The slice `src_dataset[:n_visual_exs]` is evaluated with
    `wrapper.eval_on_dataset`; its second and third results are rendered with
    `visualize_batch`, one panel per index, for at most `n_visual_exs` panels.
    The panels are combined with `visualize_code_sequence` and displayed.
    Nothing is returned.
    """
    subset = src_dataset[:n_visual_exs]
    _, visual_data, visual_preds = wrapper.eval_on_dataset(
        subset, tqdm_args={"leave": False}
    )

    # Never index past the number of predictions actually produced.
    n_panels = min(n_visual_exs, len(visual_preds))
    panels = []
    for i in range(n_panels):
        panels.append(
            visualize_batch(
                visual_data,
                i,
                visual_preds,
                wrapper.tokenizer,
                wrapper.args.ctx_args,
            )
        )
    display(visualize_code_sequence(panels))


# Render predictions of the round-1 model for the first 16 entries of the R1 test split.
visualize_preds_code(r1_wrapper, r1_src_datasets["test"], n_visual_exs=16)


processing chunks:   0%|          | 0/74 [00:00<?, ?it/s]

predict:   0%|          | 0/52 [00:00<?, ?it/s]

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…