In [1]:
%load_ext autoreload
%autoreload 2

import os
from typing import *

from spot.utils import proj_root, get_data_dir

# Work from the project root so relative paths resolve the same way no matter
# where the notebook kernel was started.
os.chdir(proj_root())

# Base data directory; datasets, checkpoints and caches below are all
# resolved relative to it.
datadir = get_data_dir()
# Location of the cloned source repositories.
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

import torch

from spot.data import (
    SrcDataset,
    get_dataset_name,
    load_src_datasets,
)
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper
from copy import copy
from spot.train import TrainingConfig, TypeCheckArgs

# --- experiment switches ---
config = TrainingConfig(quicktest=False, all_labels=True)
# Passed as `record_batches` to R0 training and gates every R1 / critic cell.
train_R1: bool = True
# When True the training cells are skipped; later cells load the saved
# checkpoints from disk either way.
load_trained: bool = True
gpu_id = 0

project_name = "test-SPOT" if config.quicktest else "SPOT"
train_ctx_args = config.train_ctx_args()
tc_args = TypeCheckArgs(check_in_isolation=config.check_in_isolation)

max_tokens_per_file = config.ctx_size
dec_args = DecodingArgs(
    sampling_max_tokens=8 * max_tokens_per_file,
    ctx_args=config.dec_ctx_args(),
    max_workers=20,
)


datasets_name = get_dataset_name(
    drop_comments=config.drop_comments,
    all_labels=config.all_labels,
)

r0_model_name = "R0-model--" + config.as_name()

src_datasets = load_src_datasets(
    datadir,
    datasets_name,
    data_reduction=config.data_reduction,
    # Reuse the path defined in the setup cell instead of re-spelling it, so
    # the two cannot drift apart.
    repos_root=repos_dir,
    quicktest=config.quicktest,
)


  warn(f"Failed to load image Python extension: {e}")


Loading datasets:  src_datasets-all_labels-drop_comments


In [3]:
# train the model
from spot.train import ModelTrainingArgs, train_spot_model, TypeCheckArgs
import wandb

# R0 training arguments: training chunks use `max_tokens_per_file` tokens and
# evaluation chunks use twice that.
train_args = ModelTrainingArgs(
    train_ctx_args,
    dec_args,
    tc_args=tc_args,
    max_epochs=2,
    train_max_tokens=max_tokens_per_file,
    eval_max_tokens=2 * max_tokens_per_file,
)

# Train only when not reusing a saved checkpoint. The wandb run opened here is
# not finished in this cell; the cell that logs the test accuracies calls
# wandb.finish().
if not load_trained:
    wandb.init(
        project=project_name,
        name=r0_model_name,
        dir=str(datadir),
        config=config.as_dict(),
    )
    r0_wrapper, r0_extra = train_spot_model(
        src_datasets,
        r0_model_name,
        train_args=train_args,
        gpus=[gpu_id],
        record_batches=train_R1,
        use_small_model=config.use_small_model,
        quicktest=config.quicktest,
    )


In [4]:
# load trained model
from spot.utils import pickle_load, pickle_dump

# Always load R0 from its saved checkpoint directory, including right after
# training it in this session.
r0_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r0_model_name}"
)
if train_R1:
    # The extras pickled next to the checkpoint carry the R1 source datasets.
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]
# Use the configured GPU, falling back to CPU when CUDA is unavailable.
device = torch.device("cpu" if not torch.cuda.is_available() else f"cuda:{gpu_id}")
r0_wrapper.to(device)
# Turn off sampling for evaluation decoding.
r0_wrapper.args.do_sample = False
print(r0_wrapper.args)


DecodingArgs(ctx_args=CtxArgs(ctx_size=4096, left_margin=2048, right_margin=1024), sampling_max_tokens=32768, max_workers=20)


In [6]:
# model evaluation

import plotly.express as px

from spot.train import evaluate_model
from spot.utils import PickleCache
from spot.visualization import visualize_dicts, pretty_print_dict

# Evaluate R0 on the test split (None = no second-round model). Results are
# stored in a PickleCache under the checkpoint directory.
r0_cache = PickleCache(datadir / f"checkpoints/lit-saved/{r0_model_name}/eval_cache")
r0_eval = evaluate_model(
    r0_wrapper,
    None,
    src_datasets["test"],
    tc_args=train_args.tc_args,
    eval_cache=r0_cache,
)
for acc in [entry[1].accuracies for entry in r0_eval]:
    pretty_print_dict(acc)


partial_acc (ImNone): 75.42% (count=16.9k)
full_acc (ImNone): 72.22% (count=16.9k)
partial_acc: 73.96% (count=16.9k)
ast_acc: 67.52% (count=21.3k)
full_acc: 69.25% (count=16.9k)
partial_acc_by_cat:
   FuncArg: 68.85% (count=8.0k)
   FuncReturn: 83.87% (count=5.7k)
   ClassAtribute: 65.95% (count=2.7k)
   GlobalVar: 86.54% (count=104)
   LocalVar: 82.11% (count=531)
partial_acc_by_pos:
   range(0, 1): 74.83% (count=1.6k)
   range(1, 2): 77.19% (count=1.6k)
   range(2, 4): 75.92% (count=2.8k)
   range(4, 8): 73.75% (count=4.6k)
   range(8, 16): 72.23% (count=6.3k)
avg_label_size: 1.2589
avg_pred_size: 1.2327


In [7]:
# close wandb
from spot.utils import pretty_show_dict
from spot.visualization import string_to_html
import wandb


def wandb_string(s: str):
    """Wrap plain text `s` as a `wandb.Html` object (rendered via `string_to_html`)."""
    html = string_to_html(s)
    return wandb.Html(html)


# Only when R0 was trained in this session (and its wandb run is open): log
# each round's accuracies, then close the run.
if not load_trained:
    for round_idx, round_eval in enumerate(r0_eval):
        summary = pretty_show_dict(round_eval[1].accuracies)
        wandb.log({f"test/R{round_idx}": wandb_string(summary)})
    wandb.finish()


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▅▅▅▅█
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.k.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.o.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.q.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.v.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.0.layer_norm.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.k.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.o.weight_epoch,▁█
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.q.weight_epoch,▁█

0,1
epoch,2.0
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.k.weight_epoch,0.1291
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.o.weight_epoch,0.8254
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.q.weight_epoch,0.15374
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight_epoch,0.10869
grad_2.0_norm/model.decoder.block.0.layer.0.SelfAttention.v.weight_epoch,1.17765
grad_2.0_norm/model.decoder.block.0.layer.0.layer_norm.weight_epoch,0.47247
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.k.weight_epoch,0.70331
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.o.weight_epoch,1.57787
grad_2.0_norm/model.decoder.block.0.layer.1.EncDecAttention.q.weight_epoch,0.71179


In [18]:
# export the code with inlined predictions as HTML

from spot.visualization import export_preds_on_code, display_persist, proj_root

# Export every 10th R0 test chunk with its predictions inlined as HTML.
eval_to_viz = r0_eval[0][1]
sub_ids = range(0, len(eval_to_viz.chunks), 10)
export_preds_on_code(
    eval_to_viz.chunks[sub_ids],
    [eval_to_viz.predictions[idx] for idx in sub_ids],
    {},
    export_to=proj_root() / "R0_predictions",
)


Exporting: 100%|██████████| 164/164 [00:01<00:00, 132.31it/s]


In [8]:
# train the critic
from spot.critic import CriticModel, ModelSPOT, train_critic_model, CriticTrainArgs
from spot.utils import pickle_load, run_long_task, PickleCache
from spot.train import R1_srcs_from_extra
import wandb

critic_no_feedback = True

# The critic name is also needed by the "load trained critic" cell, which runs
# regardless of train_R1, so define it unconditionally.
feedback_tag = "no_feedback-" if critic_no_feedback else ""
critic_name = "critic-model--" + feedback_tag + config.as_name()

if train_R1:
    with run_long_task("Training Critic", notify=not load_trained):
        critic_train_args = CriticTrainArgs(
            ctx_args=train_ctx_args,
            train_max_tokens=max_tokens_per_file,
            eval_max_tokens=2 * max_tokens_per_file,
            max_epochs=1,
        )

        critic_tc_args = tc_args._replace(no_feedback=critic_no_feedback)
        # Critic training data: R1 sources built from the R0 extras, with the
        # predictions inlined via `inline_predictions`. Cached on disk; call
        # `critic_cache.remove("src_datasets")` to rebuild it.
        critic_cache = PickleCache(
            datadir / f"checkpoints/lit-saved/CriticData-{critic_name}"
        )
        critic_src_datasets: dict[str, SrcDataset] = critic_cache.cached(
            "src_datasets",
            lambda: {
                k: v.inline_predictions(as_comment=False)
                for k, v in R1_srcs_from_extra(
                    r0_wrapper,
                    src_datasets,
                    extra=pickle_load(
                        datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl"
                    ),
                    tc_args=critic_tc_args,
                ).items()
            },
        )

        if not load_trained:
            wandb.init(
                project=project_name,
                name=critic_name,
                config=config.as_dict(),
                dir=str(datadir),
            )
            try:
                critic, critic_extra = train_critic_model(
                    critic_src_datasets,
                    critic_train_args,
                    critic_name,
                    gpus=[gpu_id],
                    quicktest=config.quicktest,
                    use_early_stop=False,
                    use_small_model=config.use_small_model,
                )
            finally:
                # Close the run even if training raises, so it is not left open.
                wandb.finish()


Starting task: Training Critic


predict:   9%|▉         | 182/2002 [01:37<16:11,  1.87it/s]


Generating R1 dataset: train


chunk_srcs_per_file: 100%|██████████| 933/933 [00:14<00:00, 63.36it/s] 
verify_labels: 100%|██████████| 1637/1637 [00:00<00:00, 11356.96it/s]
chunk_srcs_per_file: 100%|██████████| 1087/1087 [00:17<00:00, 61.75it/s]
verify_labels: 100%|██████████| 2002/2002 [00:00<00:00, 11629.19it/s]
chunk_srcs_per_file: 100%|██████████| 16281/16281 [04:16<00:00, 63.36it/s] 
verify_labels: 100%|██████████| 29269/29269 [00:01<00:00, 16372.90it/s]


type_check_success_ratio: 1
feedbacks_per_file:
   mean: 0
   median: 0
   min: 0
   max: 0


feedbacks_to_tokenized_src:   0%|          | 0/12225 [00:00<?, ?it/s]

Generating R1 dataset: valid


predict: 100%|██████████| 2002/2002 [08:32<00:00,  3.90it/s]


type_check_success_ratio: 1
feedbacks_per_file:
   mean: 0
   median: 0
   min: 0
   max: 0


feedbacks_to_tokenized_src:   0%|          | 0/1087 [00:00<?, ?it/s]

Generating R1 dataset: test


predict: 100%|██████████| 1637/1637 [06:28<00:00,  4.21it/s]


type_check_success_ratio: 1
feedbacks_per_file:
   mean: 0
   median: 0
   min: 0
   max: 0


feedbacks_to_tokenized_src:   0%|          | 0/933 [00:00<?, ?it/s]

inline_predictions: 100%|██████████| 12225/12225 [08:07<00:00, 25.09it/s] 
inline_predictions: 100%|██████████| 1087/1087 [00:39<00:00, 27.58it/s]
inline_predictions: 100%|██████████| 933/933 [01:12<00:00, 12.90it/s]


chunk_srcs_per_file: 100%|██████████| 1087/1087 [00:12<00:00, 84.07it/s]
verify_labels: 100%|██████████| 1465/1465 [00:00<00:00, 11335.04it/s]
chunk_srcs_per_file: 100%|██████████| 933/933 [00:10<00:00, 93.02it/s] 
verify_labels: 100%|██████████| 1165/1165 [00:00<00:00, 10628.82it/s]
chunk_srcs_per_file: 100%|██████████| 12225/12225 [02:19<00:00, 87.54it/s] 
verify_labels: 100%|██████████| 15438/15438 [00:01<00:00, 11623.82it/s]
  rank_zero_warn(


pos_weight = 1.0


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        valid/F1            0.33512205387205385
     valid/accuracy         0.3752533491522072
       valid/loss           0.7158970236778259
     valid/pos_rate         0.20045478641986847
     valid/precision        0.7854500616522812
      valid/recall          0.21300073563833344
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]

  | Name  | Type        | Params
--------------------------------------
0 | model | CriticModel | 109 M 
--------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
219.216   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         test/F1             0.832843137254902
      test/accuracy         0.7585840707964602
        test/loss           0.5557802319526672
      test/pos_rate         0.7882006168365479
     test/precision         0.7630239520958084
       test/recall           0.916726618705036
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Pushover: (Failed: Training Critic.) 'ZMQDisplayPublisher' object has no attribute '_orig_publish'


AttributeError: 'ZMQDisplayPublisher' object has no attribute '_orig_publish'

In [4]:
# load trained critic
from spot.utils import pickle_load, pickle_dump
from spot.critic import CriticModel

critic = CriticModel.load(datadir / f"checkpoints/lit-saved/{critic_name}")
# r1_src_datasets may already exist from the R0 loading cell; only reload the
# pickled R0 extras when it is missing.
if train_R1 and "r1_src_datasets" not in globals():
    r0_extra = pickle_load(datadir / f"checkpoints/lit-saved/{r0_model_name}/extra.pkl")
    r1_src_datasets: dict[str, SrcDataset] = r0_extra["R1-src_datasets"]

device = torch.device("cpu" if not torch.cuda.is_available() else f"cuda:{gpu_id}")
critic.to(device)
print("Critic loaded.")


Critic loaded.


In [10]:
# show critic performance

from spot.visualization import visualize_preds_on_code, pretty_print_dict

device = torch.device("cpu" if not torch.cuda.is_available() else f"cuda:{gpu_id}")
critic.to(device)
r1_testset = critic_src_datasets["test"]
critic_eval = critic.eval_on_src_dataset(
    r1_testset, train_ctx_args, dec_args.sampling_max_tokens
)
# critic_eval[1] holds per-source numeric predictions (shown as percentages)
# and critic_eval[2] a metrics dict -- structure inferred from usage here.
nicer_preds = [[f"{p:.1%}" for p in row] for row in critic_eval[1]]
pretty_print_dict(critic_eval[2])


chunk_srcs_per_file: 100%|██████████| 933/933 [00:10<00:00, 92.38it/s] 
verify_labels: 100%|██████████| 1165/1165 [00:00<00:00, 10572.36it/s]
predict: 100%|██████████| 1165/1165 [02:11<00:00,  8.86it/s]

accuracy: 0.75858
F1: 0.83284
precision: 0.76302
recall: 0.91673
pos_rate: 0.7882





In [13]:
# The performance achieved by always predicting true or random values

from spot.utils import not_none, pretty_print_dict
from spot.type_check import normalize_type
from spot.critic import CriticModel
import random


def dummy_performance(dataset: SrcDataset, pred_f):
    """Score a trivial critic baseline on `dataset`.

    A target is True when a type slot's entry in `prev_types` (presumably the
    previous round's prediction) equals the ground-truth type after
    `normalize_type`. `pred_f` is called once per target, with no arguments,
    to produce the baseline's guess; guesses and targets are scored with
    `CriticModel.compute_metrics`.
    """
    targets: list[bool] = [
        normalize_type(label) == normalize_type(prev)
        for src in dataset.srcs_with_labels()
        for prev, label in zip(not_none(src.prev_types).values(), src.types)
    ]
    preds = [pred_f() for _ in targets]
    return CriticModel.compute_metrics(preds, targets)


# Baselines for the critic: always predicting True, and uniform random guesses.
# A dedicated seeded RNG makes the random baseline reproducible across runs.
baseline_rng = random.Random(0)
pretty_print_dict(dummy_performance(r1_testset, lambda: True))
pretty_print_dict(
    dummy_performance(r1_testset, lambda: baseline_rng.choice([True, False]))
)


accuracy: 0.65605
F1: 0.7923
precision: 0.65605
recall: 1
pos_rate: 1
accuracy: 0.49876
F1: 0.56807
precision: 0.65345
recall: 0.50243
pos_rate: 0.50442


In [17]:
from spot.utils import DefaultTokenizer, decode_tokens, np

def chunk_has_fdbk(tks):
    """Return whether the decoded tokens contain the `/* error:` feedback marker."""
    decoded = decode_tokens(tks)
    return "/* error:" in decoded


test_chunks = r1_src_datasets["test"].to_chunks(DefaultTokenizer, dec_args.ctx_args)
# Share of R1 test chunks whose input contains at least one feedback marker.
fraction_chunks_with_fdbk = np.mean(
    [chunk_has_fdbk(input_ids) for input_ids in test_chunks.data["input_ids"]]
)
print("Fraction of chunks with feedback:", fraction_chunks_with_fdbk)

chunk_srcs_per_file: 100%|██████████| 933/933 [00:14<00:00, 65.79it/s] 
verify_labels: 100%|██████████| 1656/1656 [00:00<00:00, 17364.38it/s]


0.4945652173913043

In [16]:
# checking mypy feedbacks
from spot.visualization import show_feedback_stats

# Summarize the type-checker feedback attached to the R1 test sources.
# `error_groups` (used by the visualization cell below) is only defined when
# train_R1 is True.
if train_R1:
    error_groups = show_feedback_stats(r1_src_datasets["test"])


feedbacks_per_file:
   mean: 2.2851
   median: 0
   min: 0
   max: 291
type_check_success_ratio: 1
total_feedbacks: 2132
feedbacks_per_label: 0.12578
fraction_files_with_feedbacks: 0.29582
top_feedbacks:
   name-defined: 784
   arg-type: 372
   attr-defined: 352
   return-value: 195
   assignment: 101
   misc: 61
   override: 41
   index: 39
   operator: 35
   union-attr: 31


In [7]:
# visualize feedback samples

from spot.utils import seq_flatten, add_line_numbers
from spot.visualization import code_inline_type_masks, visualize_sequence, display


# Browse the type-checker feedback of one error code next to the source it
# refers to. To browse every code, iterate seq_flatten(error_groups.values()).
feedback_code = "return-value"

if train_R1:
    to_display = []
    # .get: an error code that never occurs in the test set has no group.
    for entry in error_groups.get(feedback_code, []):
        src = entry[1]
        code = code_inline_type_masks(src.origin_code, src.types)
        to_display.append(
            f"feedback: {entry[0]}\n" + "=========code=========\n" + add_line_numbers(code)
        )
    if to_display:
        display(visualize_sequence(to_display))

interactive(children=(IntSlider(value=0, description='i', max=194), Output()), _dom_classes=('widget-interact'…

In [15]:
# Print one feedback sample in plain text (index 3 looks like an arbitrary pick;
# needs at least 4 entries in `to_display` from the visualization cell).
print(to_display[3])

feedback: MypyFeedback('[return-value]225:20: Incompatible return value type (got "Tuple[ReturnCodes, Tuple[Any, Any]]", expected "Tuple[int, Tuple[str, str]]") )'
  1|  from typing import Any, List, Tuple, Dict, Set # SPOT
  2|  import importlib
  3|  import os
  4|  import subprocess
  5|  import sys
  6|  import tempfile
  7|  from configparser import ConfigParser
  8|  from pathlib import Path
  9|  from typing import (
 10|      TYPE_CHECKING,
 11|      Any,
 12|      Callable,
 13|      Dict,
 14|      List,
 15|      Optional,
 16|      Tuple,
 17|      Union,
 18|      no_type_check,
 19|  )
 20|  
 21|  import py
 22|  import pytest
 23|  from _pytest._code import ExceptionInfo
 24|  from _pytest._code.code import ReprEntry, ReprFileLocation, TerminalRepr
 25|  from _pytest._io import TerminalWriter
 26|  from _pytest.config import Config
 27|  from mypy import build
 28|  from mypy.fscache import FileSystemCache
 29|  from mypy.main import process_options
 30|  
 31|  if TYPE

In [19]:
# R1 training

import torch
import wandb
from spot.data import SrcDataset, get_dataset_name
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper

r1_model_name = "R1-model--" + config.as_name()

# Train R1 only when not reusing a saved checkpoint. The wandb run stays open
# until the cell that logs the R1 test accuracies calls wandb.finish().
if not load_trained:
    wandb.init(
        project=project_name,
        name=r1_model_name,
        dir=str(datadir),
        config=config.as_dict(),
    )

    # Same settings as R0 but a single epoch; only a top-level field changes,
    # so a shallow copy suffices.
    r1_train_args = copy(train_args)
    r1_train_args.max_epochs = 1

    r1_wrapper, r1_extra = train_spot_model(
        r1_src_datasets,
        r1_model_name,
        train_args=r1_train_args,
        gpus=[gpu_id],
        record_batches=False,
        use_early_stop=False,
        use_small_model=config.use_small_model,
        quicktest=config.quicktest,
    )


In [23]:
# load trained model and evaluate
from spot.train import evaluate_model
from spot.visualization import visualize_dicts

r1_wrapper = ModelWrapper.from_pretrained(
    datadir / f"checkpoints/lit-saved/{r1_model_name}"
)
r1_wrapper.to(device)

r1_cache = PickleCache(datadir / f"checkpoints/lit-saved/{r1_model_name}/eval_cache")
# Cached results only go stale when the models were (re)trained in this
# session; otherwise reuse them instead of repeating the full two-round
# evaluation. Clear the cache manually after changing evaluation code.
if not load_trained:
    r1_cache.clear()
r1_eval = evaluate_model(
    r0_wrapper,
    r1_wrapper,
    src_datasets["test"],
    tc_args=tc_args,
    eval_cache=r1_cache,
)
visualize_dicts([x[1].accuracies for x in r1_eval])


chunk_srcs_per_file: 100%|██████████| 933/933 [00:14<00:00, 64.89it/s]
verify_labels: 100%|██████████| 1637/1637 [00:00<00:00, 15724.40it/s]
predict: 100%|██████████| 1637/1637 [06:27<00:00,  4.22it/s]
map type_check_src_in_project: 100%|██████████| 933/933 [02:04<00:00,  7.48it/s]


type_check_success_ratio: 1
feedbacks_per_file:
   mean: 2.2851
   median: 0
   min: 0
   max: 291


feedbacks_to_tokenized_src: 100%|██████████| 933/933 [00:06<00:00, 151.15it/s]
inline_predictions: 100%|██████████| 933/933 [00:35<00:00, 25.92it/s]
chunk_srcs_per_file: 100%|██████████| 933/933 [00:14<00:00, 64.02it/s]
verify_labels: 100%|██████████| 1659/1659 [00:00<00:00, 15586.34it/s]
predict: 100%|██████████| 1659/1659 [07:12<00:00,  3.83it/s]


VBox(children=(Tab(children=(Tab(children=(HTML(value="<div style='white-space: pre-wrap; line-height: 1.2; fo…

In [24]:
from spot.visualization import export_preds_on_code, display_persist, proj_root

# Export every 10th R1 test chunk with its predictions inlined as HTML.
eval_to_viz = r1_eval[1][1]
sub_ids = range(0, len(eval_to_viz.chunks), 10)
export_preds_on_code(
    eval_to_viz.chunks[sub_ids],
    [eval_to_viz.predictions[idx] for idx in sub_ids],
    {},
    export_to=proj_root() / "R1_predictions",
)


Exporting: 100%|██████████| 166/166 [00:01<00:00, 100.50it/s]
Computing accuracies: 100%|██████████| 166/166 [00:00<00:00, 11120.50it/s]


In [None]:
from spot.visualization import visualize_conf_matrix

# Label the evaluation rounds in order: the first is R0, the second R1.
visualize_conf_matrix(dict(zip(["R0", "R1"], (x[1] for x in r1_eval))))


interactive(children=(IntSlider(value=1, description='round', max=1), IntSlider(value=10, continuous_update=Fa…

In [None]:
from spot.utils import pretty_show_dict

# Only when R1 was trained in this session (its wandb run is still open):
# log each round's accuracies, then close the run.
if not load_trained:
    for round_idx, round_eval in enumerate(r1_eval):
        summary = pretty_show_dict(round_eval[1].accuracies)
        wandb.log({f"test/R{round_idx}": wandb_string(summary)})
    wandb.finish()


In [None]:
from IPython.display import display

from spot.visualization import visualize_preds_on_code

# Evaluation round to inspect (0 = R0, 1 = R1). Named `round_idx` rather than
# `round` so the built-in round() is not shadowed for later cells.
round_idx = 1
pred_dataset = r1_eval[round_idx][1].chunks
visualize_preds_on_code(pred_dataset, r1_eval[round_idx][1].predictions)


interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=1204), Output()), _dom_…

Box(children=(Output(),), layout=Layout(overflow='scroll'))

Box(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…