In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

# Make the directory returned by `proj_root()` (from spot.utils) the working directory.
os.chdir(proj_root())

# All data/checkpoint paths below hang off the `datadir` environment variable.
# Fail early with an actionable message: `Path(None)` would otherwise raise an
# opaque TypeError when the variable is not set.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "point it at the directory that contains 'SPOT-data'."
    )
datadir = Path(_datadir_env)
repos_dir = datadir / "SPOT-data/repos"

In [2]:
from spot.data import SrcDataset

# Load the pickled "train"/"valid"/"test" source datasets and repoint each one at
# the local repository directory.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# pickles that were produced by this project.
src_datasets_path = datadir / "SPOT-data/src_datasets"
src_datasets: dict[str, SrcDataset] = {}
for n in ("train", "valid", "test"):
    pkl_file = src_datasets_path / f"{n}.pkl"
    with pkl_file.open("rb") as f:
        loaded = pickle.load(f)
    # Overwrite the stored root so the dataset refers to `repos_dir` on this machine.
    loaded.repos_root = repos_dir
    src_datasets[n] = loaded


In [3]:
import torch

from spot.model import ModelSPOT, ModelWrapper

# Experiment switches; in this cell they determine the tags embedded in the
# checkpoint names.
with_margin = True
data_reduction = 1

if with_margin:
    margin_tag = "with_margin"
else:
    margin_tag = "no_margin"

if data_reduction == 1:
    data_tag = "data_full"
else:
    data_tag = f"data_1-{data_reduction}"

# Use the first CUDA device when one is available, otherwise run on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Load the saved R0 wrapper, move its model to `device`, and reuse its tokenizer.
r0_model_name = f"SPOT-R0-{margin_tag}-{data_tag}"
r0_checkpoint_dir = datadir / f"checkpoints/saved/{r0_model_name}"
r0_wrapper = ModelWrapper.from_pretrained(r0_checkpoint_dir)
r0_wrapper.model.to(device)
tokenizer = r0_wrapper.tokenizer



In [4]:
# Context-size scale factor for the R0 wrapper; set this to the best-performing value.
best_r0_ctx_factor = 2
# `scale_ctx_size` is defined in spot.model; presumably it returns a wrapper with a
# scaled context size and leaves `r0_wrapper` itself unchanged — TODO confirm.
r0_wrapper_best = r0_wrapper.scale_ctx_size(best_r0_ctx_factor)


In [5]:
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper, TokenizerSPOT
from spot.utils import TaskLoggingMonitor

# True: start R1 from the base checkpoint and train it in a later cell.
# False: load a previously saved R1 checkpoint instead.
train_r1 = True

r1_model_name = f"SPOT-R1-{margin_tag}-{data_tag}"

# Where the R1 weights come from. To fine-tune from R0 instead of the base model,
# use `datadir / f"checkpoints/saved/{r0_model_name}"` for the training case.
r1_model_path = (
    "Salesforce/codet5-base"
    if train_r1
    else datadir / f"checkpoints/saved/{r1_model_name}"
)

r1_model: ModelSPOT = ModelSPOT.from_pretrained(r1_model_path)
r1_model = r1_model.to(device)
r1_monitor = TaskLoggingMonitor("R1")
# NOTE(review): this is the same args object as R0's (no copy is made), so any
# in-place change to one is visible through the other — confirm that is intended.
r1_args = r0_wrapper.args
r1_wrapper = ModelWrapper(r1_model, tokenizer, r1_args, r1_monitor)



In [6]:
import pickle

from spot.data import ChunkedDataset, save_datasets
from spot.utils import PickleCache

# When True, only a 16-element slice of each split is processed (quick smoke test).
test_r1_generation = False
use_file_level_feedback = False

feedback_tag = "iso_file"  # "per_file" if use_file_level_feedback else "per_project"

# NOTE(review): the cache directory name depends only on `test_r1_generation`, not on
# the R0 checkpoint or context factor, so cached entries may be stale after changing
# those — confirm before trusting cached results.
r1_cache_dir = datadir / f"cache/r1_src_datasets-{test_r1_generation}"
r1_cache = PickleCache(r1_cache_dir)

with run_long_task("Generating R1 datasets", notify=False):
    r1_src_datasets = {}
    for name in ("test", "valid", "train"):
        print("Working on:", name)
        r0_src = src_datasets[name]
        if test_r1_generation:
            r0_src = SrcDataset(r0_src.srcs_with_labels()[:16], r0_src.repos_root)

        # Presumably `cached(key, fn)` returns the stored value for `key` and only
        # calls `fn` on a miss — TODO confirm against PickleCache.
        def _eval_r0():
            return r0_wrapper_best.eval_on_dataset(r0_src, tqdm_args={"leave": False})

        _, r0_data, r0_preds = r1_cache.cached(f"eval_r0/{name}", _eval_r0)

        def _make_r1_srcs():
            return r1_wrapper.generate_r1_srcs(r0_src, r0_data, r0_preds)

        r1_src_datasets[name] = r1_cache.cached(
            f"r1_src_datasets/{name}", _make_r1_srcs
        )


Working on: test
Working on: valid
Working on: train
Pushover: (Finished: Generating R1 datasets.) Time taken: 28.3s


In [7]:
import wandb
from spot.model import ModelTrainingArgs

# Training hyperparameters for the R1 model.
r1_train_args = ModelTrainingArgs(
    train_batch_size=8,
    eval_batch_size=64,
    max_epochs=3,
)

if train_r1:
    # Chunk only the splits the trainer needs (validation and training).
    r1_chunks: dict[str, ChunkedDataset] = {}
    with run_long_task("Preparing R1 chunked datasets", notify=False):
        for n in ["valid", "train"]:
            r1_chunks[n] = r1_src_datasets[n].to_chunks(
                tokenizer, r1_wrapper.args.ctx_args, max_workers=20
            )

    # Intermediate checkpoints go under checkpoints/<model name>; the final model is
    # saved separately under checkpoints/saved/ at the end of this cell.
    r1_trainer = r1_wrapper.build_trainer(
        datadir / "checkpoints" / r1_model_name,
        r1_train_args,
        dataset=r1_chunks["train"].data,
        eval_dataset=r1_chunks["valid"].data,
    )

    wandb.init(
        project=r1_model_name,
        dir=str(datadir),
        config={"r1_decoding_args": r1_args, "r1_train_args": r1_train_args},
    )

    # Always close the W&B run: if evaluation or training raises (including a manual
    # interrupt), finish it with a non-zero exit code instead of leaving an open run
    # attached to the kernel, then re-raise the original error.
    try:
        with run_long_task(f"Training {r1_model_name}"):
            init_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
            print("initial performance:", init_perf)
            r1_trainer.train()

        wandb.log({"time_stats": r1_monitor.timer.total_times()})

        final_perf = r1_trainer.evaluate(max_length=r1_args.generation_max_length)
        print("final performance:", final_perf)
    except BaseException:
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

    # Only reached when training and the final evaluation completed successfully.
    r1_wrapper.save_pretrained(datadir / f"checkpoints/saved/{r1_model_name}")


processing chunks:   0%|          | 0/3805 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/53864 [00:00<?, ?it/s]

Using amp half precision backend
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Pushover: (Finished: Preparing R1 chunked datasets.) Time taken: 74.0s


[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


***** Running Evaluation *****
  Num examples = 2786
  Batch size = 64


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
***** Running training *****
  Num examples = 40588
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 15222


initial performance: {'eval_loss': 0.3440489172935486, 'eval_runtime': 75.7023, 'eval_samples_per_second': 36.802, 'eval_steps_per_second': 0.581}


Step,Training Loss,Validation Loss
500,0.2193,0.294917
1000,0.2261,0.282882
1500,0.225,0.288669
2000,0.2118,0.28332
2500,0.2111,0.287434


***** Running Evaluation *****
  Num examples = 2786
  Batch size = 64
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500
Configuration saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/tokenizer_config.json
Special tokens file saved in /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-500/special_tokens_map.json
Deleting older checkpoint [/mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-4500] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 2786
  Batch size = 64
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-R1-with_margin-data_full/checkpoint-1000
Configuration saved in /mnt/data0/jiayi/check

Pushover: (Finished: Training SPOT-R1-with_margin-data_full.) Time taken: 1912.1s


***** Running Evaluation *****
  Num examples = 2786
  Batch size = 64


final performance: {'eval_loss': 0.2828815281391144, 'eval_runtime': 75.7831, 'eval_samples_per_second': 36.763, 'eval_steps_per_second': 0.581, 'epoch': 0.49}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/loss,█▂▁▂▁▂▁
eval/runtime,▁▅█▇▇▇▇
eval/samples_per_second,█▃▁▂▃▂▂
eval/steps_per_second,▁▁▁▁▁▁▁
train/epoch,▁▁▃▃▅▅▆▆████
train/global_step,▁▂▂▄▄▅▅▇▇████
train/learning_rate,█▆▄▃▁
train/loss,▅█▇▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.28288
eval/runtime,75.7831
eval/samples_per_second,36.763
eval/steps_per_second,0.581
train/epoch,0.49
train/global_step,2500.0
train/learning_rate,2e-05
train/loss,0.2111
train/total_flos,2.43583156224e+16
train/train_loss,0.21865


Configuration saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-R1-with_margin-data_full/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-R1-with_margin-data_full/pytorch_model.bin
tokenizer config file saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-R1-with_margin-data_full/tokenizer_config.json
Special tokens file saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-R1-with_margin-data_full/special_tokens_map.json


In [8]:
from spot.data import preds_to_accuracies, pretty_print_accuracies
from spot.visualization import display_code_sequence, visualize_batch

# Evaluate R1 on the test split using a wrapper whose context size is scaled by 2,
# then print the resulting accuracies.
r1_wrapper_test = r1_wrapper.scale_ctx_size(2)
r1_test_eval = r1_wrapper_test.eval_on_dataset(r1_src_datasets["test"])
r1_accs, r1_data, r1_preds = r1_test_eval
pretty_print_accuracies(r1_accs)


processing chunks:   0%|          | 0/2729 [00:00<?, ?it/s]

predict:   0%|          | 0/2162 [00:00<?, ?it/s]

partial_acc: 0.8386
partial_acc_wo_any: 0.8454
partial_accs:
   FuncArg: 0.8436
   FuncReturn: 0.8622
   ClassAtribute: 0.7711
   GlobalVar: 0.7857
   LocalVar: 0.8651
full_acc: 0.7804
full_accs:
   FuncArg: 0.7966
   FuncReturn: 0.8167
   ClassAtribute: 0.6913
   GlobalVar: 0.6071
   LocalVar: 0.631
n_labels: 8419


In [13]:
# Visualize up to 16 chunks starting at chunk index 16.
# The upper bound is clamped to the number of chunks so that the offset index can
# never run past the end of `r1_data.chunks_info`; with fewer than 17 chunks the
# sequence is simply empty.
# Assumes the second argument of `visualize_batch` indexes `r1_data.chunks_info`
# — TODO confirm against spot.visualization.
first_chunk = 16
max_shown = 16
end_chunk = min(first_chunk + max_shown, len(r1_data.chunks_info))
display_code_sequence(
    [
        visualize_batch(
            r1_data,
            chunk_id,
            r1_preds,
            tokenizer,
            r1_wrapper_test.args.ctx_args,
        )
        for chunk_id in range(first_chunk, end_chunk)
    ]
)

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…