In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from pathlib import Path
from typing import *

import pandas as pd
import plotly.express as px

from spot.utils import cst, proj_root, run_long_task, tqdm

# Run the rest of the notebook from the directory returned by `proj_root()`
# (presumably the repository root) so relative paths do not depend on where
# the kernel was started.
os.chdir(proj_root())

# Base directory for datasets and checkpoints, taken from the `datadir`
# environment variable. `os.getenv` returns None when the variable is unset,
# and `Path(None)` would then fail with an opaque TypeError, so fail early
# with an actionable message instead.
_datadir_env = os.getenv("datadir")
if _datadir_env is None:
    raise RuntimeError(
        "Environment variable 'datadir' is not set; "
        "point it at the directory that contains 'SPOT-data'."
    )
datadir = Path(_datadir_env)
# Location of the repositories handed to the loaded datasets as `repos_root`.
repos_dir = datadir / "SPOT-data/repos"

In [2]:
# experiment configurations

from spot.data import SrcDataset, get_dataset_name
import torch

from spot.data import ChunkedDataset, get_model_name
from spot.model import CtxArgs, DecodingArgs, ModelSPOT, ModelWrapper, ModelTrainingArgs

# Forwarded to both `get_dataset_name` and `get_model_name`; presumably
# selects the comment-stripped variant of the dataset -- confirm in spot.data.
drop_comments = True

# Directory holding the pickled train/valid/test source datasets.
_dataset_dirname = get_dataset_name(drop_comments=drop_comments)
src_datasets_path = datadir / "SPOT-data" / _dataset_dirname
src_datasets_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): presumably 1 means "use the full dataset" -- confirm.
data_reduction = 1

# Context layout used to build the model name (and, later, for training).
ctx_args = CtxArgs(
    ctx_size=1024,
    left_margin=384,  # 256 + 128
    right_margin=128,  # 256 - 128
    types_in_ctx=False,
)

# Name under which the R0 model is stored / looked up.
r0_model_name = get_model_name(
    ctx_args=ctx_args,
    drop_comments=drop_comments,
    data_reduction=data_reduction,
)

print("loading from: ", src_datasets_path)

# Load each pickled split and point it at the local repository checkout.
# The pickles are read from the local `datadir`; never load pickles obtained
# from an untrusted source.
src_datasets: dict[str, SrcDataset] = {}
for split in ("train", "valid", "test"):
    with (src_datasets_path / f"{split}.pkl").open("rb") as pkl_file:
        dataset = pickle.load(pkl_file)
    dataset.repos_root = repos_dir
    src_datasets[split] = dataset


loading from:  /mnt/data0/jiayi/SPOT-data/src_datasets-drop_comments


In [9]:
# train the model
from spot.train import train_r0_model

# Settings for the training run: at most 3 epochs, batch size 8 for training
# and 64 for evaluation.
train_args = ModelTrainingArgs(
    max_epochs=3,
    train_batch_size=8,
    eval_batch_size=64,
)

# Train the R0 model using the dataset / context configuration defined in the
# experiment-configuration cell; the returned wrapper is used for evaluation.
r0_wrapper = train_r0_model(
    ctx_args=ctx_args,
    train_args=train_args,
    drop_comments=drop_comments,
    data_reduction=data_reduction,
)


R0 model name:  SPOT-model-R0-(384, 512, 128)-drop_comments


processing chunks:   0%|          | 0/3158 [00:00<?, ?it/s]

processing chunks:   0%|          | 0/43679 [00:00<?, ?it/s]

Pushover: (Finished: Preparing chunked datasets.) Time taken: 72.8s


Using amp half precision backend
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmrvplusone[0m. Use [1m`wandb login --relogin`[0m to force relogin


***** Running Evaluation *****
  Num examples = 2394
  Batch size = 64


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
***** Running training *****
  Num examples = 34182
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 6411


initial eval loss: {'eval_loss': 1.7518380880355835, 'eval_runtime': 44.3886, 'eval_samples_per_second': 53.933, 'eval_steps_per_second': 0.428}


Step,Training Loss,Validation Loss
500,0.4614,0.341228
1000,0.3651,0.332701
1500,0.331,0.311589
2000,0.3103,0.306835
2500,0.2706,0.306942
3000,0.2504,0.306813
3500,0.2463,0.305475
4000,0.2415,0.302821
4500,0.2264,0.309899
5000,0.206,0.307135


***** Running Evaluation *****
  Num examples = 2394
  Batch size = 64
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-500
Configuration saved in /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-500/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-500/pytorch_model.bin
tokenizer config file saved in /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-500/tokenizer_config.json
Special tokens file saved in /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-500/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2394
  Batch size = 64
Saving model checkpoint to /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_comments/checkpoint-1000
Configuration saved in /mnt/data0/jiayi/checkpoints/SPOT-model-R0-(384, 512, 128)-drop_commen

Pushover: (Finished: Training SPOT-model-R0-(384, 512, 128)-drop_comments.) Time taken: 5052.0s


***** Running Evaluation *****
  Num examples = 2394
  Batch size = 64


final eval loss: {'eval_loss': 0.30282124876976013, 'eval_runtime': 43.0988, 'eval_samples_per_second': 55.547, 'eval_steps_per_second': 0.441, 'epoch': 2.57}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/loss,█▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,█▂▁▁▁▁▁▁▁▁▁▁▁
eval/samples_per_second,▁▇▇██████████
eval/steps_per_second,▁▇███████████
train/epoch,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▆▆▇▇▇▇████
train/global_step,▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇████
train/learning_rate,█▇▇▆▅▅▄▃▂▂▁
train/loss,█▅▄▄▃▂▂▂▂▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.30282
eval/runtime,43.0988
eval/samples_per_second,55.547
eval/steps_per_second,0.441
train/epoch,2.57
train/global_step,5500.0
train/learning_rate,0.0
train/loss,0.2081
train/total_flos,1.071522304229376e+17
train/train_loss,0.28337


Configuration saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-model-R0-(384, 512, 128)-drop_comments/config.json
Model weights saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-model-R0-(384, 512, 128)-drop_comments/pytorch_model.bin
tokenizer config file saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-model-R0-(384, 512, 128)-drop_comments/tokenizer_config.json
Special tokens file saved in /mnt/data0/jiayi/checkpoints/saved/SPOT-model-R0-(384, 512, 128)-drop_comments/special_tokens_map.json


In [5]:
# or load trained model
# Prefer the first CUDA device, falling back to the CPU when CUDA is absent.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Restore the wrapper from the directory of the saved R0 model.
saved_model_dir = datadir / f"checkpoints/saved/{r0_model_name}"
r0_wrapper = ModelWrapper.from_pretrained(saved_model_dir)
r0_wrapper.model.to(device)

# Turn sampling off; presumably this makes decoding deterministic -- confirm
# against how ModelWrapper consumes `do_sample`.
r0_wrapper.args.do_sample = False

# Last expression of the cell: display the effective decoding arguments.
r0_wrapper.args


positional arguments and argument "destination" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.



DecodingArgs(ctx_args=CtxArgs(left=384, window=512, right=128), sampling_batch_size=128, max_workers=20, generation_max_length=128, do_sample=False, top_p=0.9)

In [6]:
# evaluate the model

from spot.data import pretty_print_accuracies

# Context-size scale factors to evaluate the trained model at.
size_factors = [1, 2, 3]

# One accuracy record per entry of `size_factors`, in the same order.
acc_series = []
with run_long_task("Evaluate R0 accuracy vs ctx_size"):
    for factor in size_factors:
        # Evaluate a copy of the wrapper whose context is scaled by `factor`.
        wrapper = r0_wrapper.scale_ctx_size(factor)
        eval_result = wrapper.eval_on_dataset(
            src_datasets["test"], tqdm_args={"leave": False}
        )
        # Only the first element (the accuracies) is needed here.
        accs = eval_result[0]
        acc_series.append(accs)

        print(f"===ctx_size factor: {factor}===")
        print(f"ctx_args: {wrapper.args.ctx_args}")
        pretty_print_accuracies(accs)

import plotly.express as px

# Tabulate accuracy against the scale factor. Note that the "ctx_size" column
# holds the scale factors from `size_factors`, not absolute context sizes.
acc_columns = {"ctx_size": size_factors}
for metric in ("partial_acc", "full_acc"):
    acc_columns[metric] = [accs_at_size[metric] for accs_at_size in acc_series]
acc_df = pd.DataFrame(acc_columns)

# Last expression of the cell: the line chart of both accuracy metrics.
px.line(acc_df, x="ctx_size", y=["partial_acc", "full_acc"], title=r0_model_name)


processing chunks:   0%|          | 0/2282 [00:00<?, ?it/s]

predict:   0%|          | 0/1884 [00:00<?, ?it/s]

===ctx_size factor: 1===
ctx_args: CtxArgs(left=384, window=512, right=128)
partial_acc: 0.8415
partial_acc_wo_any: 0.8458
partial_accs:
   FuncArg: 0.8378
   FuncReturn: 0.8699
   ClassAtribute: 0.7867
   GlobalVar: 0.8214
   LocalVar: 0.8735
full_acc: 0.7821
full_accs:
   FuncArg: 0.7847
   FuncReturn: 0.8255
   ClassAtribute: 0.7062
   GlobalVar: 0.6607
   LocalVar: 0.6838
n_labels: 8421


processing chunks:   0%|          | 0/2282 [00:00<?, ?it/s]

predict:   0%|          | 0/1870 [00:00<?, ?it/s]

===ctx_size factor: 2===
ctx_args: CtxArgs(left=1280, window=512, right=256)
partial_acc: 0.8578
partial_acc_wo_any: 0.8614
partial_accs:
   FuncArg: 0.8617
   FuncReturn: 0.8776
   ClassAtribute: 0.8024
   GlobalVar: 0.8393
   LocalVar: 0.873
full_acc: 0.8085
full_accs:
   FuncArg: 0.8215
   FuncReturn: 0.8385
   ClassAtribute: 0.7263
   GlobalVar: 0.6964
   LocalVar: 0.7302
n_labels: 8418


processing chunks:   0%|          | 0/2282 [00:00<?, ?it/s]

predict:   0%|          | 0/1879 [00:00<?, ?it/s]

===ctx_size factor: 3===
ctx_args: CtxArgs(left=2176, window=512, right=384)
partial_acc: 0.8584
partial_acc_wo_any: 0.8637
partial_accs:
   FuncArg: 0.8687
   FuncReturn: 0.8744
   ClassAtribute: 0.7957
   GlobalVar: 0.8214
   LocalVar: 0.8611
full_acc: 0.8111
full_accs:
   FuncArg: 0.83
   FuncReturn: 0.8367
   ClassAtribute: 0.7218
   GlobalVar: 0.6964
   LocalVar: 0.7302
n_labels: 8413
Pushover: (Finished: Evaluate R0 accuracy vs ctx_size.) Time taken: 1219.5s


In [10]:
from spot.visualization import display_code_sequence, visualize_batch

# Upper bound on how many examples to evaluate and render.
n_visual_exs = 16

# Run the model, with its context scaled by 2, on a small slice of the test set.
wrapper = r0_wrapper.scale_ctx_size(2)
visual_subset = src_datasets["test"][:n_visual_exs]
_, visual_data, visual_preds = wrapper.eval_on_dataset(
    visual_subset, tqdm_args={"leave": False}
)

# Render at most `n_visual_exs` examples, never more than there are predictions.
n_shown = min(n_visual_exs, len(visual_preds))
rendered_examples = []
for i in range(n_shown):
    rendered_examples.append(
        visualize_batch(
            visual_data,
            i,
            visual_preds,
            wrapper.tokenizer,
            wrapper.args.ctx_args,
        )
    )

# Last expression of the cell so that the resulting view is displayed.
display_code_sequence(rendered_examples)

processing chunks:   0%|          | 0/33 [00:00<?, ?it/s]

predict:   0%|          | 0/24 [00:00<?, ?it/s]

Tab(children=(HTML(value="<pre style='line-height: 1.2; padding: 10px; color: rgb(212,212,212); background-col…