In [1]:
import sys, subprocess, os

# Define function x that given a command string, runs it with subprocess and streams the output
def x(cmd):
    return subprocess.run(cmd.split(" ")).returncode

DATA_PATH = 'data/deepjit_summary_data'
ADDITIONAL_DATA_PATHS = ['data/benchmarks/java.json', 'data/Go-22k/test.json', 'data/Python-22k/test.json', 'data/JS-22k/test.json', 'data/Java-22k/test.json', 'data/benchmarks/python.json', 'data/benchmarks/go.json', 'data/benchmarks/js.json']

BATCH_SIZE = 2
ACCUM_ITERS = 64 // BATCH_SIZE
MAX_EPOCHS = 0
# model type can be 'bert' or 'longformer'
MODEL_TYPE = 'bert'

MY_REPO_URL = 'https://github.com/pelmers/llms-for-code-comment-consistency.git'

# If running as a script, allow os.environ to overwrite these options
if __name__ == '__main__':
    import os
    for k, v in os.environ.items():
        if k in globals():
            # First check if v is a boolean or a number and convert to the right type
            if v.lower() == 'true':
                v = True
            elif v.lower() == 'false':
                v = False
            elif v.isnumeric():
                v = int(v)
            # Or a float
            elif '.' in v and v.replace('.', '').isnumeric():
                v = float(v)
            # Or a list
            elif v.startswith('[') and v.endswith(']'):
                v = v[1:-1].split(',')
                v = [a.strip() for a in v]
            globals()[k] = v

# If we're not in my repo (../.git exists?), then clone it
if os.path.exists('../.git'):
    data_rel = '../'
else:
    if not os.path.exists('llms-for-code-comment-consistency'):
        assert x(f'git clone {MY_REPO_URL}') == 0
    data_rel = 'llms-for-code-comment-consistency/'

FOLDER_NAME = 'coco-bert-longformer'
if not os.path.exists(FOLDER_NAME):
    x('git clone https://github.com/pelmers/coco-bert-longformer.git --branch replication')

DATA_FOLDER = os.path.join(FOLDER_NAME, 'data', 'summary')

DATA_ARCHIVE = os.path.join(data_rel, DATA_PATH) + '.tar.gz'

def ensure_data_archive(data_path):
    # If the data folder does not exist, extract it from the .tar.gz file
    folder_path = os.path.join(data_rel, data_path)
    if not os.path.exists(folder_path):
        print(f'Ensuring data exists at {folder_path}.tar.gz')
        archive_path = folder_path + '.tar.gz'
        archive_data_path = data_path + '.tar.gz'
        if not os.path.exists(archive_path):
            print(f'Downloading data for {archive_data_path} from server')
            assert x(f'wget -O {archive_path} https://file2.pelmers.com/{archive_data_path}') == 0

for data_path in [DATA_PATH] + [os.path.dirname(p) for p in ADDITIONAL_DATA_PATHS]:
    ensure_data_archive(data_path)

def extract_archive(archive_path):
    print(f'Extracting data from {archive_path} to {DATA_FOLDER}...')
    x('mkdir -p {}'.format(DATA_FOLDER))
    assert x('tar -xvzf {} -C {}'.format(archive_path, DATA_FOLDER)) == 0

extract_archive(DATA_ARCHIVE)

RUN_LANGUAGE = 'java'
RUN_LANGUAGE = 'go' if 'go' in DATA_PATH.lower() else RUN_LANGUAGE
RUN_LANGUAGE = 'py' if 'python' in DATA_PATH.lower() else RUN_LANGUAGE
RUN_LANGUAGE = 'js' if 'javascript' in DATA_PATH.lower() or 'js' in DATA_PATH.lower() else RUN_LANGUAGE

model_save_path = f'{FOLDER_NAME}/models/{MODEL_TYPE}-{RUN_LANGUAGE}-{DATA_PATH.split("/")[-1]}-trained.pt'

x('pip install transformers scikit-learn pandas')

Ensuring data exists at ../notebooks/data/benchmarks.tar.gz
Ensuring data exists at ../notebooks/data/Go-22k.tar.gz
Ensuring data exists at ../notebooks/data/Python-22k.tar.gz
Ensuring data exists at ../notebooks/data/Java-22k.tar.gz
Ensuring data exists at ../notebooks/data/benchmarks.tar.gz
Ensuring data exists at ../notebooks/data/benchmarks.tar.gz
Ensuring data exists at ../notebooks/data/benchmarks.tar.gz
Extracting data from ../notebooks/data/deepjit_summary_data.tar.gz to coco-bert-longformer/data/summary...
deepjit_summary_data/metadata.json
deepjit_summary_data/._train.json
deepjit_summary_data/._test.json
deepjit_summary_data/train.json
deepjit_summary_data/test.json
deepjit_summary_data/valid.json
deepjit_summary_data/._valid.json


0

In [2]:
sys.path.append(FOLDER_NAME)

import constants
constants.set_constants(batch_size=BATCH_SIZE, accum_iters=ACCUM_ITERS, max_epochs=MAX_EPOCHS)

import train

import torch
import numpy as np

import random

# Seed options taken from original repository
seed = random.choice([12, 17, 22])

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

print(f'Using random seed {seed}')
if torch.cuda.is_available():
    x('nvidia-smi')

  from .autonotebook import tqdm as notebook_tqdm


Using random seed 17


In [3]:
NEGATIVE_TO_POSITIVE_RATIO = {
    'java': 19,
    'go': 19,
    'py': 19,
    'js': 19,
}

import pandas as pd

def load_file(json_path):
    return pd.read_json(json_path)

In [4]:
# Pick device based on whether gpu is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Start training {MODEL_TYPE} model, will save to {model_save_path}')
data_extracted_folder = os.path.join(DATA_FOLDER, DATA_PATH.split('/')[-1])
train_df = load_file(os.path.join(data_extracted_folder, 'train.json'))
valid_df = load_file(os.path.join(data_extracted_folder, 'valid.json'))
model = train.train(device, MODEL_TYPE, model_save_path, train_df, valid_df, negative_class_weight=NEGATIVE_TO_POSITIVE_RATIO[RUN_LANGUAGE])

Start training bert model, will save to coco-bert-longformer/models/bert-trained.pt
Total number of parameters: 109483778


In [5]:
print('Evaluating model')

from torch.utils.data import DataLoader

from eval import test, CocoDataset

test_df = load_file(os.path.join(data_extracted_folder, 'test.json'))
test_data = CocoDataset(test_df, MODEL_TYPE)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)

test(model, test_loader, device, negative_class_weight=NEGATIVE_TO_POSITIVE_RATIO[RUN_LANGUAGE])

Evaluating model


100%|██████████| 533/533 [23:47<00:00,  2.68s/it]

test_loss: 0.798 test_precision: 0.500 test_recall: 0.998 test_f1: 0.666 test_acc: 0.499
test_weighted_f1: 0.103





In [6]:
# Now test with the best model by validation f1

print('Evaluating best model')
model.load_state_dict(torch.load(model_save_path))
test(model, test_loader, device, negative_class_weight=NEGATIVE_TO_POSITIVE_RATIO[RUN_LANGUAGE])

Evaluating best model


100%|██████████| 533/533 [23:43<00:00,  2.67s/it]

test_loss: 0.710 test_precision: 0.692 test_recall: 0.017 test_f1: 0.033 test_acc: 0.505
test_weighted_f1: 0.033





In [7]:
# Then test on additional test files

print('Evaluating additional test paths')
for path in ADDITIONAL_DATA_PATHS:
    pfne = os.path.basename(path).split('.')[0].lower()
    print(f'Evaluating {path}')
    archive_path = os.path.join(data_rel, os.path.dirname(path) + '.tar.gz')
    extract_archive(archive_path)
    data_extracted_folder = os.path.join(DATA_FOLDER, os.path.dirname(path).split('/')[-1])
    test_df = load_file(os.path.join(data_extracted_folder, os.path.basename(path)))
    test_data = CocoDataset(test_df, MODEL_TYPE)
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)
    test_language = 'py' if 'python' in pfne else 'go' if 'go' in pfne else 'js' if 'js' in pfne else 'java'

    print(f'Using test language {test_language}')
    test(model, test_loader, device, negative_class_weight=NEGATIVE_TO_POSITIVE_RATIO[test_language])

Evaluating additional test paths
Evaluating data/benchmarks/java.json
Extracting data from ../notebooks/data/benchmarks.tar.gz to coco-bert-longformer/data/summary...
benchmarks/js.json
benchmarks/python.json
benchmarks/java.json
benchmarks/go.json


100%|██████████| 25/25 [00:44<00:00,  1.79s/it]


test_loss: 0.699 test_precision: 0.667 test_recall: 0.080 test_f1: 0.143 test_acc: 0.520
test_weighted_f1: 0.077
Evaluating data/Go-22k/test.json
Extracting data from ../notebooks/data/Go-22k.tar.gz to coco-bert-longformer/data/summary...
Go-22k/train.json
Go-22k/metadata.json
Go-22k/extras.json
Go-22k/valid.json
Go-22k/test.json


 75%|███████▌  | 829/1100 [36:30<11:55,  2.64s/it]


KeyboardInterrupt: 