In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

os.chdir(proj_root())

# All data used below (SPOT-data/repos, checkpoints/...) lives under the
# directory named by the `datadir` environment variable. Fail early with a
# clear message instead of the opaque TypeError that `Path(None)` raises when
# the variable is unset (or silently using "." when it is empty).
_datadir_env = os.getenv("datadir")
if not _datadir_env:
    raise RuntimeError(
        "Environment variable `datadir` is not set; point it at the directory "
        "that contains `SPOT-data/` and `checkpoints/`."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from copy import copy
from spot.train import TrainingConfig

# Top-level knobs for this run; everything below derives from `config`.
config = TrainingConfig(quicktest=False, all_labels=True)
# When True, R0 training records batches and a second-round (R1) model is
# trained afterwards from the R1 datasets stored alongside the R0 checkpoint.
train_R1: bool = True
gpu_id = 0

train_ctx_args = config.train_ctx_args()

max_tokens_per_file = config.ctx_size
dec_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=config.dec_ctx_args(),
    max_workers=20,
)


datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

r0_model_name = "R0-model--" + config.as_name()

# Reuse `repos_dir` from the setup cell instead of rebuilding the same path,
# so the repos location is defined in exactly one place.
src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    repos_root=repos_dir,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


In [3]:
# train the model
from spot.train import ModelTrainingArgs, train_spot_model
import wandb

# Round-0 training. The same `train_args` object is reused for R1 later on.
train_args = ModelTrainingArgs(
    train_ctx_args,
    dec_args,
    max_epochs=3,
    train_max_tokens=max_tokens_per_file,
    eval_max_tokens=max_tokens_per_file * 2,
    check_in_isolation=config.check_in_isolation,
)

# Quick-test runs are logged to a separate W&B project.
if config.quicktest:
    project_name = "test-SPOT"
else:
    project_name = "SPOT"

wandb.init(
    name=r0_model_name,
    project=project_name,
    dir=str(datadir),
    config=config.as_dict(),
)
r0_wrapper, r0_extra = train_spot_model(
    src_datasets,
    r0_model_name,
    train_args=train_args,
    gpus=[gpu_id],
    record_batches=train_R1,
    use_small_model=config.use_small_model,
    quicktest=config.quicktest,
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# load trained model
from spot.utils import pickle_load, pickle_dump

# Reload R0 from its saved checkpoint directory; `extra.pkl` sits in the same place.
r0_save_dir = datadir / f"checkpoints/lit-saved/{r0_model_name}"
r0_wrapper = ModelWrapper.from_pretrained(r0_save_dir)
if train_R1:
    # NOTE(review): unpickling can execute arbitrary code; only load
    # checkpoints produced by your own runs.
    r0_extra = pickle_load(r0_save_dir / "extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_id}")
else:
    device = torch.device("cpu")
r0_wrapper.to(device)
# Turn sampling off for the decoding done in the evaluation cells below.
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(left=2048, window=1024, right=1024, max_labels=16), sampling_max_tokens=32768, max_workers=20, max_tokens_per_type=10, do_sample=False, top_p=0.9, num_beams=None)


In [None]:
# checking mypy feedbacks

from spot.type_env import MypyFeedback
from spot.utils import groupby


def ground_truth_feedback(
    dataset: SrcDataset,
    in_isolation: Optional[bool] = None,
    max_workers: int = 20,
):
    """Type-check `dataset` using its own ground-truth labels as the predictions.

    Every label slot of every file is filled with `str(t)` of the recorded
    type, then the dataset's type-checker feedback pipeline is run on that.

    Args:
        dataset: source dataset whose `types` are plugged in as predictions.
        in_isolation: forwarded to `add_type_checker_feedback`. Defaults to
            `config.check_in_isolation` so it matches the training/eval setup.
        max_workers: parallelism forwarded to `add_type_checker_feedback`.

    Returns:
        Whatever `dataset.add_type_checker_feedback` returns.
    """
    if in_isolation is None:
        # There is no global `check_in_isolation`; the setting lives on `config`.
        in_isolation = config.check_in_isolation
    # file -> {label index -> ground-truth type string}
    preds = {
        file: {idx: str(ty) for idx, ty in enumerate(src.types)}
        for file, src in dataset.file2src().items()
    }
    return dataset.add_type_checker_feedback(
        r0_wrapper.tokenizer,
        preds,
        in_isolation=in_isolation,
        max_workers=max_workers,
        tqdm_args={},
    )


def show_feedback_stats(new_dataset: SrcDataset):
    """Summarize, plot, and group the mypy feedback attached to `new_dataset`.

    Prints the dataset's `feedbacks_per_file` / `type_check_success_ratio`
    stats together with the total feedback count and the 10 most common error
    codes, then displays a bar chart of all error-code frequencies.

    Returns:
        `groupby` over `(feedback, src)` pairs, keyed by `feedback.error_code`.
    """
    # Use the real class rather than the deprecated `typing.Counter` alias
    # that only reached this scope through `from typing import *`.
    from collections import Counter

    # One list of feedbacks per source file, positionally aligned with `all_srcs`.
    fb_list: list[list[MypyFeedback]] = new_dataset.extra_stats["mypy_feedbacks"]
    stats = {
        key: new_dataset.extra_stats[key]
        for key in ["feedbacks_per_file", "type_check_success_ratio"]
    }
    stats["total_feedbacks"] = sum(len(file_fbs) for file_fbs in fb_list)
    error_code_counter = Counter(
        fb.error_code for file_fbs in fb_list for fb in file_fbs
    )
    stats["top_feedbacks"] = dict(error_code_counter.most_common(10))
    # NOTE(review): `pretty_print_accuracies` is not imported in any cell
    # visible here; it only resolves if some other (possibly deleted) cell
    # imported it. Add an explicit import once its module is confirmed.
    pretty_print_accuracies(stats)

    freq_df = pd.DataFrame(
        error_code_counter.most_common(), columns=["error_code", "count"]
    )
    display(px.bar(freq_df, x="error_code", y="count", title="Error code frequencies"))

    # Pair every individual feedback with the source file it was reported on.
    fdbk_srcs = [
        (fb, src)
        for src, file_fbs in zip(new_dataset.all_srcs, fb_list)
        for fb in file_fbs
    ]
    return groupby(fdbk_srcs, lambda pair: pair[0].error_code)


if train_R1:
    error_groups = show_feedback_stats(r1_src_datasets["test"])


feedbacks_per_file:
   mean: 1.0909
   median: 0
   min: 0
   max: 6
type_check_success_ratio: 1
total_feedbacks: 12
top_feedbacks:
   attr-defined: 4
   return: 3
   return-value: 2
   name-defined: 1
   func-returns-value: 1
   syntax: 1


In [None]:
# visualize feedback samples

from spot.utils import seq_flatten, add_line_numbers
from spot.visualization import code_inline_type_masks, visualize_sequence


if train_R1:
    # One page per feedback: the feedback itself, then the line-numbered source
    # rendered through `code_inline_type_masks`.
    to_display = []
    for pair in seq_flatten(error_groups.values()):
        fdbk, fdbk_src = pair[0], pair[1]
        masked_code = code_inline_type_masks(fdbk_src.origin_code, fdbk_src.types)
        page = f"feedback: {fdbk}\n=========code=========\n{add_line_numbers(masked_code)}"
        to_display.append(page)
    display(visualize_sequence(to_display))


VBox(children=(HBox(children=(IntSlider(value=0, max=11), Label(value='(12 total)'))), Box(children=(Output(la…

In [None]:
# model evaluation

import plotly.express as px

from spot.train import evaluate_model, visualize_accuracies

# Evaluate R0 on its own over the test split. The second positional slot takes
# an R1 wrapper in the later R1 evaluation cell; it is None here.
r0_eval = evaluate_model(
    r0_wrapper, None, r0_model_name, src_datasets["test"],
    reeval=False,
    check_in_isolation=config.check_in_isolation,
    datadir=datadir,
)


chunk_srcs_per_file:   0%|          | 0/11 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/18 [00:00<?, ?it/s]

predict:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
from spot.utils import pretty_show_dict
import wandb

def wandb_string(s: str) -> "wandb.Html":
    """Wrap plain text in a `wandb.Html` panel that preserves its whitespace.

    The text is HTML-escaped first, so characters such as `<`, `>` and `&`
    are shown literally instead of being interpreted as markup.

    Args:
        s: text to display. Non-str values are converted with `str()`.

    Returns:
        A `wandb.Html` object wrapping a `<div>` styled with
        `white-space: pre-wrap`.
    """
    from html import escape

    c = f"<div style='white-space: pre-wrap;'>{escape(str(s), quote=False)}</div>"
    return wandb.Html(c)

# Log each evaluation round's accuracies as its own whitespace-preserving HTML panel.
for round_idx, round_eval in enumerate(r0_eval):
    accs_text = pretty_show_dict(round_eval[1].accuracies)
    wandb.log({f"test/R{round_idx}": wandb_string(accs_text)})
# Close the R0 run; the R1 training cell starts a new one.
wandb.finish()
visualize_accuracies(r0_eval)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.558282…

Tab(children=(VBox(children=(Output(), Tab(children=(Output(), Output()), _titles={'0': 'Compressed', '1': 'Ex…

In [None]:
# R1 training

import torch

from spot.data import SrcDataset, get_dataset_name
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper

# `r1_src_datasets` is only loaded when `train_R1` is set. Check that before
# `wandb.init` so a misconfigured run fails clearly and leaves no dangling W&B run.
if not train_R1:
    raise RuntimeError(
        "R1 training requires `train_R1 = True`: the R1 datasets are produced by the R0 run."
    )

r1_model_name = "R1-model--" + config.as_name()

wandb.init(
    project=project_name, name=r1_model_name, config=config.as_dict(), dir=str(datadir)
)

r1_wrapper, r1_extra = train_spot_model(
    r1_src_datasets,
    r1_model_name,
    train_args=train_args,
    gpus=[gpu_id],
    record_batches=False,
    quicktest=config.quicktest,
    use_small_model=config.use_small_model,
)


chunk_srcs_per_file:   0%|          | 0/11 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/11 [00:00<?, ?it/s]

chunk_srcs_per_file:   0%|          | 0/7 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/7 [00:00<?, ?it/s]


There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
445.764   Total estimated model params size (MB)


Pushover: (Finished: 'Preparing chunked datasets'.) Time taken: 0.1s


Sanity Checking: 0it [00:00, ?it/s]


Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 1: 'valid/loss' reached 3.27613 (best 3.27613), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=0-step=1.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 0, global step 2: 'valid/loss' reached 2.76775 (best 2.76775), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=0-step=2.ckpt' as top 3
Metric valid/loss improved. New best score: 2.768


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 3: 'valid/loss' reached 2.76775 (best 2.76775), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=1-step=3.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 4: 'valid/loss' reached 2.68372 (best 2.68372), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=1-step=4.ckpt' as top 3
Metric valid/loss improved by 0.084 >= min_delta = 0.0. New best score: 2.684


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 5: 'valid/loss' reached 2.66538 (best 2.66538), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=2-step=5.ckpt' as top 3


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 6: 'valid/loss' reached 2.64813 (best 2.64813), saving model to '/mnt/data0/jiayi/checkpoints/lit-running/R1-model--quicktest=True/epoch=2-step=6.ckpt' as top 3
Metric valid/loss improved by 0.036 >= min_delta = 0.0. New best score: 2.648


Pushover: (Finished: 'Training R1-model--quicktest=True'.) Time taken: 30.4s


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid/loss            2.648132562637329
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
# load trained model and evaluate
from spot.train import evaluate_model

r1_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r1_model_name}"
)
r1_wrapper.to(device)
# Match the R0 setup: R0 had sampling turned off before evaluation, so do the
# same here to keep the two rounds comparable.
r1_wrapper.args.do_sample = False

# Two-round evaluation: R0 predictions first, then R1 on the test split.
r1_eval = evaluate_model(
    r0_wrapper,
    r1_wrapper,
    r1_model_name,
    src_datasets["test"],
    datadir=datadir,
    check_in_isolation=config.check_in_isolation,
    reeval=False,
)


chunk_srcs_per_file:   0%|          | 0/11 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/18 [00:00<?, ?it/s]

predict:   0%|          | 0/18 [00:00<?, ?it/s]

map type_check_src_in_project:   0%|          | 0/11 [00:00<?, ?it/s]

type_check_success_ratio: 1
feedbacks_per_file:
   mean: 2.1818
   median: 2
   min: 0
   max: 6


feedbacks_to_tokenized_src:   0%|          | 0/11 [00:00<?, ?it/s]

chunk_srcs_per_file:   0%|          | 0/11 [00:00<?, ?it/s]

verify_labels:   0%|          | 0/18 [00:00<?, ?it/s]

predict:   0%|          | 0/18 [00:00<?, ?it/s]

In [None]:
from spot.train import evaluate_model, visualize_accuracies, visualize_conf_matrix

# Log each round's accuracies as readable text, formatted the same way as for R0.
# Passing the raw dict would log its Python repr instead of the pretty layout.
for i, e in enumerate(r1_eval):
    wandb.log({f"test/R{i}": wandb_string(pretty_show_dict(e[1].accuracies))})
wandb.finish()
visualize_accuracies(r1_eval)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▃▃▆▆█
trainer/global_step,▁▂▃▅▆▇█
valid/loss,█▂▂▁▁▁▁

0,1
epoch,3.0
trainer/global_step,6.0
valid/loss,2.64813


Tab(children=(VBox(children=(Output(), Tab(children=(Output(), Output()), _titles={'0': 'Compressed', '1': 'Ex…

In [None]:
# Interactive confusion-matrix viewer over the R1 evaluation results (has a round slider).
visualize_conf_matrix(r1_eval)


interactive(children=(IntSlider(value=1, description='round', max=1), IntSlider(value=10, continuous_update=Fa…

In [None]:
from IPython.display import display

from spot.visualization import visualize_preds_on_code

# Which evaluation round of `r1_eval` to inspect. Named `eval_round` so the
# builtin `round()` is not shadowed for the rest of the kernel session.
eval_round = 1
round_result = r1_eval[eval_round][1]
pred_dataset = round_result.chunks
visualize_preds_on_code(pred_dataset, round_result.predictions)


interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=1093), Output()), _dom_…

Box(children=(Output(),), layout=Layout(overflow='scroll'))

Box(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…