In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.data import GitRepo
from spot.type_env import (
    AnnotPath,
    MypyChecker,
    SelectAnnotations,
    TypeInfAction,
    TypeInfEnv,
    TypeInfState,
    collect_annotations,
    mypy_checker,
)
from spot.utils import cst, proj_root, read_file, seq_flatten, tqdm, write_file

os.chdir(proj_root())

# All data and checkpoints live under `$datadir`. Fail early with a clear message instead
# of the opaque `TypeError` that `Path(None)` raises when the variable is unset (an empty
# value would otherwise silently resolve to the current directory).
_datadir_env = os.getenv("datadir")
if not _datadir_env:
    raise RuntimeError(
        "Environment variable `datadir` is not set; point it at the directory "
        "containing `SPOT-data/` and `checkpoints/`."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

# NOTE: `pickle.load` can execute arbitrary code; only load these files from trusted sources.
useful_repos_path = proj_root() / "scripts" / "useful_repos.pkl"
with useful_repos_path.open("rb") as f:
    useful_repos: list[GitRepo] = pickle.load(f)

# NOTE(review): a later cell rebinds `repos_split` from `load_datasets(...)`; confirm
# whether this initial load is still needed.
repos_split_path = datadir / "SPOT-data/repos-processed-with_margin/repos_split.pkl"
with repos_split_path.open("rb") as f:
    repos_split: dict[str, list[GitRepo]] = pickle.load(f)

In [3]:
from datasets import Dataset
from spot.model import ModelWrapper, DecodingArgs, CtxArgs
from spot.utils import TaskLoggingMonitor
import torch
from spot.model import ModelSPOT, TokenizerSPOT
import numpy as np
from spot.data import load_datasets

# --- Experiment switches ---
train_r0 = False  # True: start from the base CodeT5 checkpoint (to be trained); False: load a saved R0
with_margin = True  # picks the "with_margin" vs "no_margin" preprocessed dataset directory
data_reduction = 1  # divisor applied to the training-set size below; 1 means "full data"

margin_tag = "no_margin" if not with_margin else "with_margin"

r0_datasets, repos_split = load_datasets(datadir / f"SPOT-data/repos-processed-{margin_tag}")

data_tag = f"data_1-{data_reduction}" if data_reduction != 1 else "data_full"
n_train = len(r0_datasets["train"].data) // data_reduction

# The model name doubles as the checkpoint directory name and the wandb project name.
r0_model_name = f"SPOT-{margin_tag}-{data_tag}"
r0_model_path = (
    "Salesforce/codet5-base"
    if train_r0
    else datadir / f"checkpoints/saved/{r0_model_name}"
)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
tokenizer: TokenizerSPOT = TokenizerSPOT.from_pretrained(r0_model_path)

r0_model: ModelSPOT = ModelSPOT.from_pretrained(r0_model_path).to(device)
r0_monitor = TaskLoggingMonitor("R0")
# R0 decodes with the surrounding types visible in its context window.
r0_args = DecodingArgs(
    ctx_args=CtxArgs(ctx_size=512, ctx_margin=128, types_in_ctx=True),
    sampling_batch_size=512,
    max_workers=20,
)
r0_wrapper = ModelWrapper(r0_model, tokenizer, r0_args, r0_monitor)



In [4]:
import wandb
from spot.model import ModelTrainingArgs

# Build the R0 trainer; when `train_r0` is set, fine-tune it with wandb logging.
r0_train_args = ModelTrainingArgs(
    train_batch_size=42,
    eval_batch_size=256,
    max_epochs=3,
)

# `r0_model_name` encodes `data_tag` (how much training data is used), so actually train
# on the first `n_train` examples. This is a no-op when `data_reduction == 1`.
r0_train_data = r0_datasets["train"].data
if n_train < len(r0_train_data):
    # assumes `.data` is a Hugging Face `datasets.Dataset` (supports `.select`) — TODO confirm
    r0_train_data = r0_train_data.select(range(n_train))

r0_trainer = r0_wrapper.build_trainer(
    datadir / "checkpoints" / r0_model_name,
    r0_train_args,
    dataset=r0_train_data,
    eval_dataset=r0_datasets["valid"].data,
)

if train_r0:
    wandb.init(
        project=r0_model_name,
        dir=str(datadir),
        config={"r0_decoding_args": r0_args, "r0_train_args": r0_train_args},
    )

    try:
        init_perf = r0_trainer.evaluate(max_length=r0_args.generation_max_length)
        print("initial eval loss:", init_perf)
        r0_trainer.train()
    except Exception as e:
        wandb.alert(
            title="Training stopped due to exception",
            text=f"In {r0_model_name}, exception: {e}",
        )
        # Close the run (marked as failed) instead of leaving it open for the next wandb.init.
        wandb.finish(exit_code=1)
        raise  # re-raise the original exception unchanged
    wandb.alert(title="Training finished", text=f"{r0_model_name} has finished.")
    wandb.log({"time_stats": r0_monitor.timer.total_times()})

    final_perf = r0_trainer.evaluate(max_length=r0_args.generation_max_length)
    print("final eval loss:", final_perf)
    wandb.finish()


Using amp half precision backend


In [10]:
from spot.data import preds_to_accuracies, pretty_print_accuracies

# Self-contained R0 test-set evaluation. Compute the predictions here instead of relying on
# an `r0_preds` left over from another cell, and score them against the same split they
# were predicted on.
r0_preds = r0_wrapper.predict(r0_datasets["test"], tqdm_args={})
pretty_print_accuracies(preds_to_accuracies(r0_preds, r0_datasets["test"]))

partial_acc: 75.54%
partial_acc_wo_any: 76.01%
partial_accs:
   FuncArg: 72.52%
   FuncReturn: 82.34%
   ClassAtribute: 69.17%
   GlobalVar: 79.44%
   LocalVar: 78.06%
full_acc: 68.15%
full_accs:
   FuncArg: 66.50%
   FuncReturn: 76.40%
   ClassAtribute: 58.97%
   GlobalVar: 52.34%
   LocalVar: 51.80%
n_labels: 17756


In [9]:
from spot.data import preds_to_accuracies, pretty_print_accuracies

# Decode the whole R0 test split, then print the accuracy breakdown for those predictions.
r0_preds = r0_wrapper.predict(r0_datasets["test"], tqdm_args=dict())
r0_test_accs = preds_to_accuracies(r0_preds, r0_datasets["test"])

pretty_print_accuracies(r0_test_accs)

predict:   0%|          | 0/3974 [00:00<?, ?it/s]

NameError: name 'r0_datasets' is not defined

In [5]:
train_r1 = False  # True: start from the base CodeT5 checkpoint (to be trained); False: load a saved R1

# Same naming scheme as R0, with an "R1" marker.
r1_model_name = f"SPOT-R1-{margin_tag}-{data_tag}"
r1_model_path = (
    "Salesforce/codet5-base"
    if train_r1
    else datadir / f"checkpoints/saved/{r1_model_name}"
)

# R1 reuses the tokenizer loaded for R0. Unlike R0, it decodes without types in context.
r1_model: ModelSPOT = ModelSPOT.from_pretrained(r1_model_path).to(device)
r1_monitor = TaskLoggingMonitor("R1")
r1_args = DecodingArgs(
    ctx_args=CtxArgs(ctx_size=512, ctx_margin=128, types_in_ctx=False),
    sampling_batch_size=512,
    max_workers=20,
)
r1_wrapper = ModelWrapper(r1_model, tokenizer, r1_args, r1_monitor)


loading configuration file /mnt/data0/jiayi/checkpoints/saved/SPOT-R1-with_margin-data_full/config.json
Model config T5Config {
  "_name_or_path": "Salesforce/codet5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "bos_token_id": 1,
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 2,
  "feed_forward_proj": "relu",
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length"

In [25]:
from spot.data import TypeInfDataset, save_datasets

# Smoke-test mode: when True, only the first `r1_test_n_examples` R0 examples of each split
# are processed. The result always goes to a separate `test-` directory and is regenerated
# on every run.
test_r1_parsing = True
r1_test_n_examples = 16

r1_data_path = datadir / f"SPOT-data/{'test-' if test_r1_parsing else ''}repos-processed-R1-{margin_tag}"
r1_datasets: Dict[str, TypeInfDataset]
if r1_data_path.exists() and not test_r1_parsing:
    print(f"Loading R1 datasets from {r1_data_path}...")
    r1_datasets, _ = load_datasets(r1_data_path)
else:
    # Build R1 inputs from R0 predictions, split by split.
    r1_datasets = {}
    for name in ["valid", "test", "train"]:
        print("Processing dataset:", name)
        split_repos = [r.repo_dir(repos_dir) for r in repos_split[name]]
        split_r0_data = r0_datasets[name]
        if test_r1_parsing:
            split_r0_data = split_r0_data[:r1_test_n_examples]
        # Split-local names keep the test-set `r0_preds` computed in other cells from being
        # overwritten with predictions for whichever split happens to run last.
        split_r0_preds = r0_wrapper.predict(split_r0_data, tqdm_args={"leave": False})
        r1_datasets[name] = r1_wrapper.generate_r1_inputs(
            split_repos, split_r0_data, split_r0_preds, tqdm_args={"leave": False}
        )
    save_datasets(r1_datasets, repos_split, r1_data_path)

Processing dataset: valid


predict:   0%|          | 0/16 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs
[R1] Current task: get_type_checked_inputs > Collect type checker feedback


Collect type checker feedback:   0%|          | 0/40 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs > Augment inputs


generating augmented inputs:   0%|          | 0/8 [00:00<?, ?it/s]

[R1] Current task: chunk_masked_code


tokenizing sources:   0%|          | 0/8 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/22 [00:00<?, ?it/s]

Processing dataset: test


predict:   0%|          | 0/16 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs
[R1] Current task: get_type_checked_inputs > Collect type checker feedback


Collect type checker feedback:   0%|          | 0/50 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs > Augment inputs


generating augmented inputs:   0%|          | 0/2 [00:00<?, ?it/s]

[R1] Current task: chunk_masked_code


tokenizing sources:   0%|          | 0/2 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/29 [00:00<?, ?it/s]

Processing dataset: train


predict:   0%|          | 0/16 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs
[R1] Current task: get_type_checked_inputs > Collect type checker feedback


Collect type checker feedback:   0%|          | 0/573 [00:00<?, ?it/s]

[R1] Current task: get_type_checked_inputs > Augment inputs


generating augmented inputs:   0%|          | 0/9 [00:00<?, ?it/s]

[R1] Current task: chunk_masked_code


tokenizing sources:   0%|          | 0/9 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/25 [00:00<?, ?it/s]

Deleting old datasets at: /mnt/data0/jiayi/SPOT-data/test-repos-processed-R1-with_margin
308K	/mnt/data0/jiayi/SPOT-data/test-repos-processed-R1-with_margin


In [11]:
# Build the R1 trainer; when `train_r1` is set, fine-tune it with wandb logging.
# NOTE: this aliases (does not copy) the R0 training args; mutating one mutates both.
r1_train_args = r0_train_args

# `r1_model_name` encodes `data_tag`, so honour `data_reduction` for R1 as well.
# This is a no-op when `data_reduction == 1`.
r1_train_data = r1_datasets["train"].data
r1_n_train = len(r1_train_data) // data_reduction
if r1_n_train < len(r1_train_data):
    # assumes `.data` is a Hugging Face `datasets.Dataset` (supports `.select`) — TODO confirm
    r1_train_data = r1_train_data.select(range(r1_n_train))

r1_trainer = r1_wrapper.build_trainer(
    datadir / "checkpoints" / r1_model_name,
    r1_train_args,
    dataset=r1_train_data,
    eval_dataset=r1_datasets["valid"].data,
)

if train_r1:
    wandb.init(
        project=r1_model_name,
        dir=str(datadir),
        config={"r1_decoding_args": r1_args, "r1_train_args": r1_train_args},
    )

    try:
        init_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
        print("initial performance:", init_perf)
        r1_trainer.train()
    except Exception as e:
        wandb.alert(
            title="Training stopped due to exception",
            text=f"In {r1_model_name}, exception: {e}",
        )
        # Close the run (marked as failed) instead of leaving it open for the next wandb.init.
        wandb.finish(exit_code=1)
        raise  # re-raise the original exception unchanged
    wandb.alert(title="Training finished", text=f"{r1_model_name} has finished.")
    wandb.log({"time_stats": r1_monitor.timer.total_times()})

    # Final evaluation and run teardown belong to the training run, as in the R0 cell.
    # Outside this branch, `wandb.finish()` would run without a matching `wandb.init()`.
    final_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
    print("final performance:", final_perf)
    wandb.finish()


PyTorch: setting up devices
Using amp half precision backend


***** Running Evaluation *****
  Num examples = 5312
  Batch size = 256


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
***** Running training *****
  Num examples = 80433
  Num Epochs = 3
  Instantaneous batch size per device = 42
  Total train batch size (w. parallel, distributed & accumulation) = 42
  Gradient Accumulation steps = 1
  Total optimization steps = 5748


initial performance: {'eval_loss': 3.0191705226898193, 'eval_runtime': 45.5838, 'eval_samples_per_second': 116.533, 'eval_steps_per_second': 0.461}


Step,Training Loss,Validation Loss
500,0.5683,0.543072
1000,0.4574,0.488604
1500,0.426,0.47042
2000,0.3946,0.464758
2500,0.3562,0.459326
3000,0.3483,0.452277
3500,0.3479,0.453843
4000,0.3312,0.452933
4500,0.3116,0.452866


***** Running Evaluation *****
  Num examples = 5312
  Batch size = 256
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500
Configuration saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/tokenizer_config.json
Special tokens file saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 5312
  Batch size = 256
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-1000
Configuration saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-1000/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_

final performance: {'eval_loss': 0.4522767961025238, 'eval_runtime': 45.7371, 'eval_samples_per_second': 116.142, 'eval_steps_per_second': 0.459, 'epoch': 2.35}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/loss,█▁▁▁▁▁▁▁▁▁▁
eval/runtime,▁▇▇▇▇▇▇█▇█▇
eval/samples_per_second,█▂▂▂▂▂▂▁▂▁▂
eval/steps_per_second,█▁▁▁▁▁▁▁▁▁▁
train/epoch,▁▁▂▂▃▃▄▄▄▄▅▅▆▆▇▇████
train/global_step,▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇█████
train/learning_rate,█▇▆▅▄▄▃▂▁
train/loss,█▅▄▃▂▂▂▂▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.45228
eval/runtime,45.7371
eval/samples_per_second,116.142
eval/steps_per_second,0.459
train/epoch,2.35
train/global_step,4500.0
train/learning_rate,0.0
train/loss,0.3116
train/total_flos,1.1504554260037632e+17
train/train_loss,0.3935


In [26]:
from spot.data import preds_to_accuracies, pretty_print_accuracies

# Score R1 on its own test split (the R1 inputs built from R0 predictions).
r1_test_data = r1_datasets["test"]
r1_preds = r1_wrapper.predict(r1_test_data, tqdm_args=dict())

pretty_print_accuracies(preds_to_accuracies(r1_preds, r1_test_data))

predict:   0%|          | 0/20 [00:00<?, ?it/s]

partial_acc: 69.62%
partial_acc_wo_any: 69.62%
partial_accs:
   FuncArg: 69.57%
   FuncReturn: 72.41%
   LocalVar: 50.00%
full_acc: 53.16%
full_accs:
   FuncArg: 50.00%
   FuncReturn: 62.07%
   LocalVar: 25.00%
n_labels: 79


In [27]:
from spot.visualization import display_code_sequence, visualize_batch

# Render R1 test entries 0-2 (indices passed to `visualize_batch`, presumably batch indices)
# alongside R1's predictions. The cell's last expression is the rendered sequence.
r1_test_batch_views = [
    visualize_batch(r1_datasets["test"], batch_idx, r1_preds, tokenizer, r1_args.ctx_args)
    for batch_idx in range(3)
]
display_code_sequence(r1_test_batch_views)

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…