# Import libraries

In [3]:
%load_ext autoreload

# Check PyTorch version
import torch
print('Torch', torch.__version__)

import fairseq
print('fairseq', fairseq.__version__)

import sys
import json, os, re

# Put the local `utils` directory at the front of the module search path;
# change the path if the helper modules live elsewhere.
# NOTE(review): the project imports in this notebook use the `utils.` package
# prefix, which resolves from the directory that contains `utils/`, so this
# entry presumably only serves bare imports made inside those modules --
# confirm it is still needed.
sys.path.insert(0, 'utils')

Torch 2.3.1
fairseq 0.12.2


In [9]:
# Root directory (relative to the notebook) that the evaluation helpers and
# the per-round configuration cells build their paths from.
DATA_DIR = 'data'
# Also exported as an environment variable
# (presumably read by the `utils` helpers -- TODO confirm).
os.environ["DATA_DIR"] = DATA_DIR

In [10]:
%autoreload 2

from tqdm import tqdm
from copy import deepcopy
from collections import defaultdict, OrderedDict
import hashlib
from pathlib import Path
import os

from utils.code_error_checker import check_paren_error, check_ast_error
from utils.code_utils import preprocess_unk, code_toks_to_code_string, get_diff_metric, tokenize_python_code
from utils.fairseq_utils import parse_fairseq_preds, fairseq_preprocess, fairseq_generate, fairseq_train


# Common functions

In [None]:
def eval_one_pred_obj(pred_obj):
    """Evaluate every candidate fix attached to one prediction record.

    ``pred_obj`` must provide the keys 'code_toks_raw', 'anonymize_dict',
    'src', 'pred', 'progid' and 'orig_err_obj'. The source and each candidate
    are in token format (``' '.join(code_toks)``) and are converted to string
    format with ``code_toks_to_code_string``.

    Side effect: when ``pred_obj['anonymize_dict']`` is not None, that dict is
    updated in place -- its '<unk>', '<STRING>' and '<COMMENT>' entries are
    overwritten -- and the same object is returned under 'anonymize_dict'.
    When it is None, a temporary dict is used and None is returned there.

    Returns:
        dict with keys 'progid', 'orig_err_obj', 'anonymize_dict',
        'src' (``{'tok_format', 'string_format'}``) and 'pred', a list with
        one ``{'tok_format', 'string_format', 'err_obj', 'diff_metric'}``
        entry per candidate.
    """
    # Second return value of preprocess_unk is stored under '<unk>'.
    _, unk_map = preprocess_unk(pred_obj['code_toks_raw'])
    lookup = pred_obj['anonymize_dict']
    if lookup is None:
        lookup = {}
    lookup.update({'<unk>': unk_map, '<STRING>': [], '<COMMENT>': []})

    src_toks = pred_obj['src']  # token format
    src_code = code_toks_to_code_string(src_toks, lookup)  # string format

    def _score(candidate):
        # `candidate` is in token format; `code` is its string format.
        code = code_toks_to_code_string(candidate, lookup)
        if pred_obj['orig_err_obj']['msg'] == 'unbalanced (){}[]':
            # Bracket errors are re-checked on the token sequence itself.
            err = check_paren_error(candidate.split())
        else:
            err = check_ast_error(code)
        return {'tok_format': candidate,
                'string_format': code,
                'err_obj': err,
                'diff_metric': get_diff_metric(src_toks, candidate)}

    return {
        'progid': pred_obj['progid'],
        'orig_err_obj': pred_obj['orig_err_obj'],
        'anonymize_dict': pred_obj['anonymize_dict'],
        'src': {'tok_format': src_toks, 'string_format': src_code},
        'pred': [_score(candidate) for candidate in pred_obj['pred']],
    }

def eval_one_split(pred_dir_prefix, split, pred_fname, n_workers=80):
    """Evaluate the fairseq predictions of one split and save them as JSON.

    Parses ``{pred_dir_prefix}{split}/{pred_fname}`` with
    ``parse_fairseq_preds``, attaches to every prediction its program id,
    original error object, raw code tokens and anonymization dict (read from
    ``{DATA_DIR}/orig_bad_code/orig.{split}.id`` and
    ``{DATA_DIR}/orig_bad_code/orig.bad.json``), runs ``eval_one_pred_obj`` on
    each of them and writes the resulting list next to the prediction file as
    ``<pred_fname stem>.evaluated.json``.

    Args:
        pred_dir_prefix: String prefix of the prediction directories; the
            split is appended to it to form the directory name.
        split: Split identifier appended to ``pred_dir_prefix`` and used in
            the ``orig.{split}.id`` file name.
        pred_fname: File name of the fairseq prediction file.
        n_workers: Unused -- evaluation runs serially. Kept so that existing
            callers passing it keep working.

    Raises:
        AssertionError: If the number of parsed predictions differs from the
            number of program ids listed for the split.
    """
    pred_dir = f'{pred_dir_prefix}{split}'
    pred_path = Path(f'{pred_dir}/{pred_fname}')
    preds = parse_fairseq_preds(str(pred_path))

    # Load program ids (one per line); close the file deterministically.
    data_dir = DATA_DIR
    with open(f'{data_dir}/orig_bad_code/orig.{split}.id') as f:
        progids = [line.strip() for line in f]
    assert len(preds) == len(progids), \
        f'{len(preds)} predictions vs {len(progids)} program ids for split {split}'

    # Load the original (bad) programs, keyed by program id.
    with open(f'{data_dir}/orig_bad_code/orig.bad.json') as f:
        bads = json.load(f)

    for pred, progid in zip(preds, progids):
        bad = bads[progid]
        code_toks_raw = bad['code_toks_joined'].split()
        anonymize_dict = bad['anonymize_dict']
        if 'window_span' in bad:
            # Windowed example: keep only the tokens inside the window and
            # drop the anonymization dict.
            start, end = bad['window_span'][0], bad['window_span'][1]
            code_toks_raw = code_toks_raw[start:end]
            anonymize_dict = None
        pred['progid'] = progid
        pred['orig_err_obj'] = bad['err_obj']
        pred['code_toks_raw'] = code_toks_raw
        pred['anonymize_dict'] = anonymize_dict

    print ('len(preds)', len(preds))
    # `total` lets tqdm show a real progress fraction for the lazy `map`.
    res = list(tqdm(map(eval_one_pred_obj, preds), total=len(preds)))

    # res: list of {'progid': , 'orig_err_obj': , 'anonymize_dict': ,
    #               'src': {'tok_format': , 'string_format': },
    #               'pred': [{'tok_format': , 'string_format': ,
    #                         'err_obj': , 'diff_metric': }, ...]}
    with open(f'{pred_path.parent}/{pred_path.stem}.evaluated.json', 'w') as f:
        json.dump(res, f, indent=2)

def get_test_result(pred_dir_prefix, pred_fname):
    """Print top-1 repair accuracy on the held-out splits (3 and 4) and save it.

    For each test split, reads ``<pred_fname stem>.evaluated.json`` from the
    directory ``{pred_dir_prefix}{split}``. A prediction counts as a success
    when its 'err_obj' equals 0 and its 'diff_metric' satisfies
    ``0 < diff_metric <= 4``. Accuracy is the percentage of programs whose
    first prediction (k == 0) is a success, reported overall and per original
    error message; messages containing 'indent' are grouped as
    'indentation error'.

    Every error group seen in the data is reported, including groups with no
    successful prediction (accuracy 0.0).

    Args:
        pred_dir_prefix: String prefix of the prediction directories; the
            split number is appended to it.
        pred_fname: File name of the fairseq prediction file whose
            ``.evaluated.json`` companion is read.

    Side effects:
        Prints the statistics and writes them to ``stats.json`` in
        ``Path(pred_dir_prefix).parent``.
    """
    # Held-out test splits; a tuple keeps the iteration order explicit.
    test_splits = (3, 4)

    def collate_eval():
        # Collect (split, progid, k) for every successful prediction, plus
        # the number of evaluated programs, overall and per error group.
        success = []
        denom = 0
        success_by_group = defaultdict(list)
        denom_by_group = defaultdict(int)
        for split in test_splits:
            print ('split', split)
            pred_dir = Path(f'{pred_dir_prefix}{split}')
            pred_path = pred_dir/pred_fname
            pred_eval_path = f'{pred_path.parent}/{pred_path.stem}.evaluated.json'
            with open(pred_eval_path) as f:
                eval_objs = json.load(f)
            for eval_obj in eval_objs:
                progid = eval_obj['progid']
                orig_err_type = eval_obj['orig_err_obj']['msg']
                if 'indent' in orig_err_type:
                    orig_err_type = 'indentation error'
                denom += 1
                denom_by_group[orig_err_type] += 1
                for k, pred_obj in enumerate(eval_obj['pred']):
                    pred_err_obj = pred_obj['err_obj']
                    diff_metric = pred_obj['diff_metric']
                    if (pred_err_obj == 0) and (0 < diff_metric <= 4):
                        # Keep the parts as a tuple rather than a '-'-joined
                        # string, so program ids containing '-' are safe.
                        hit = (split, progid, k)
                        success.append(hit)
                        success_by_group[orig_err_type].append(hit)
        return success, denom, success_by_group, denom_by_group

    def print_stats(hits, _denom):
        # Top-1 accuracy: distinct test programs whose first prediction hit.
        top1 = {(split, progid) for split, progid, k in hits
                if split in test_splits and k == 0}
        # An empty evaluation set yields 0.0 instead of ZeroDivisionError.
        acc = len(top1)/float(_denom)*100 if _denom else 0.0
        print ('   acc: {} ({:.1f}%) | denom {}'.format(len(top1), acc, _denom))
        return acc

    success, denom, success_by_group, denom_by_group = collate_eval()
    acc_dict = {}
    print ('Total')
    acc_dict['total'] = print_stats(success, denom)
    print ('-'*50)
    # Iterate over the denominators so groups without any successful
    # prediction are still reported instead of being silently dropped.
    for err_type in denom_by_group:
        print (f'{err_type.capitalize()}')
        acc_dict[err_type] = print_stats(success_by_group[err_type],
                                         denom_by_group[err_type])
    with open(Path(pred_dir_prefix).parent/'stats.json', 'w') as f:
        json.dump(acc_dict, f, indent=2)

# Evaluate Round 0


## Prepare environment

In [19]:
# All Round 0 artifacts are located under <DATA_DIR>/round_0.
data_dir = Path(DATA_DIR)
round_dir = data_dir.joinpath('round_0')

# Fixer checkpoint, and the root of the per-split directories that are
# passed to fairseq_generate.
model_dir = round_dir.joinpath('model-fixer')
model_path = model_dir.joinpath('checkpoint.pt')
destdir_root = round_dir.joinpath('orig_bad')

# Number of splits iterated by the generation and evaluation loops.
n_splits = 5

# Evaluation reads predictions from the same tree; the split index is
# appended to `pred_dir_prefix` to obtain each split's directory.
pred_dir_root = round_dir.joinpath('orig_bad')
pred_dir_prefix = str(pred_dir_root.joinpath('fairseq_preprocess__orig_bad.'))
pred_fname = 'model-fixer.pred.txt'

## Fix errors

In [28]:
# Run the Round 0 fixer on every split and store its predictions next to the
# split's data.
# NOTE(review): the output recorded for this cell lists paths under
# data\round_1_fixer_only rather than round_0, so it was presumably produced
# while `round_dir` pointed at the Round 1 directory -- restart the kernel and
# re-run before trusting the Round 0 numbers.
# NOTE(review): beam/nbest differ between the rounds in this notebook
# (10/10 here, 5/5 and 1/1 later) -- confirm that is intended.
for split in range(n_splits):
    destdir = destdir_root.joinpath(f'fairseq_preprocess__orig_bad.{split}')
    pred_path = destdir.joinpath('model-fixer.pred.txt')
    fairseq_generate(
        str(destdir), str(model_path), str(pred_path),
        src='bad', tgt='good', gen_subset='test',
        beam=10, nbest=10,
        max_len_a=1, max_len_b=50, max_tokens=7000,
    )

fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.0         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 10         --beam 10 --max-tokens 7000 
fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.1         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 10         --beam 10 --max-tokens 7000 
fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.2         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 10         --beam 10 --max-tokens 7000 
fairseq-ge

## Evaluate the fix

In [29]:
# Score the predictions of every split (writes one *.evaluated.json each) ...
for split in range(n_splits):
    eval_one_split(pred_dir_prefix=pred_dir_prefix, split=split,
                   pred_fname=pred_fname, n_workers=10)

# ... then report accuracy; get_test_result only reads splits 3 and 4.
get_test_result(pred_dir_prefix=pred_dir_prefix, pred_fname=pred_fname)

len(preds) 7528


7528it [00:09, 792.21it/s]


len(preds) 7528


7528it [00:09, 807.18it/s]


len(preds) 7528


7528it [00:09, 779.94it/s]


len(preds) 7528


7528it [00:09, 786.61it/s]


len(preds) 7527


7527it [00:09, 788.20it/s]


split 3
split 4
Total
   acc: 13064 (86.8%) | denom 15055
--------------------------------------------------
Unbalanced (){}[]
   acc: 3732 (93.3%) | denom 3999
Invalid syntax
   acc: 4317 (90.9%) | denom 4749
Indentation error
   acc: 5015 (79.5%) | denom 6307


# Evaluate FixerOnly - Round 1

## Prepare environment

In [34]:
# All FixerOnly Round 1 artifacts are located under
# <DATA_DIR>/round_1_fixer_only.
data_dir = Path(DATA_DIR)
round_dir = data_dir.joinpath('round_1_fixer_only')

# Fixer checkpoint, and the root of the per-split directories that are
# passed to fairseq_generate.
model_dir = round_dir.joinpath('model-fixer')
model_path = model_dir.joinpath('checkpoint.pt')
destdir_root = round_dir.joinpath('orig_bad')

# Number of splits iterated by the generation and evaluation loops.
n_splits = 5

# Evaluation reads predictions from the same tree; the split index is
# appended to `pred_dir_prefix` to obtain each split's directory.
pred_dir_root = round_dir.joinpath('orig_bad')
pred_dir_prefix = str(pred_dir_root.joinpath('fairseq_preprocess__orig_bad.'))
pred_fname = 'model-fixer.pred.txt'

## Fix errors on bad dataset

In [37]:
# Run the FixerOnly Round 1 fixer on every split and store its predictions
# next to the split's data.
# NOTE(review): beam=5/nbest=5 here differs from the other rounds in this
# notebook (10/10 and 1/1) -- confirm that is intended.
for split in range(n_splits):
    destdir = destdir_root.joinpath(f'fairseq_preprocess__orig_bad.{split}')
    pred_path = destdir.joinpath('model-fixer.pred.txt')
    fairseq_generate(
        str(destdir), str(model_path), str(pred_path),
        src='bad', tgt='good', gen_subset='test',
        beam=5, nbest=5,
        max_len_a=1, max_len_b=50, max_tokens=7000,
    )

fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.0         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 5         --beam 5 --max-tokens 7000 
fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.1         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 5         --beam 5 --max-tokens 7000 
fairseq-generate             data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad.2         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_fixer_only\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 5         --beam 5 --max-tokens 7000 
fairseq-generate

## Evaluate

In [39]:
# Show which prediction files are about to be evaluated.
print(pred_dir_prefix, pred_fname)

# Score the predictions of every split (writes one *.evaluated.json each) ...
for split in range(n_splits):
    eval_one_split(pred_dir_prefix=pred_dir_prefix, split=split,
                   pred_fname=pred_fname, n_workers=10)

# ... then report accuracy; get_test_result only reads splits 3 and 4.
get_test_result(pred_dir_prefix=pred_dir_prefix, pred_fname=pred_fname)

data\round_1_fixer_only\orig_bad\fairseq_preprocess__orig_bad. model-fixer.pred.txt
len(preds) 7528


7528it [00:05, 1384.20it/s]


len(preds) 7528


7528it [00:05, 1444.85it/s]


len(preds) 7528


7528it [00:05, 1402.32it/s]


len(preds) 7528


7528it [00:05, 1427.16it/s]


len(preds) 7527


7527it [00:05, 1356.43it/s]


split 3
split 4
Total
   acc: 13067 (86.8%) | denom 15055
--------------------------------------------------
Unbalanced (){}[]
   acc: 3730 (93.3%) | denom 3999
Invalid syntax
   acc: 4321 (91.0%) | denom 4749
Indentation error
   acc: 5016 (79.5%) | denom 6307


# Evaluate BIFI - Round 1

## Prepare environment

In [43]:
# All BIFI Round 1 artifacts are located under <DATA_DIR>/round_1_bifi.
data_dir = Path(DATA_DIR)
round_dir = data_dir.joinpath('round_1_bifi')

# Fixer checkpoint, and the root of the per-split directories that are
# passed to fairseq_generate.
model_dir = round_dir.joinpath('model-fixer')
model_path = model_dir.joinpath('checkpoint.pt')
destdir_root = round_dir.joinpath('orig_bad')

# Number of splits iterated by the generation and evaluation loops.
n_splits = 5

# Evaluation reads predictions from the same tree; the split index is
# appended to `pred_dir_prefix` to obtain each split's directory.
pred_dir_root = round_dir.joinpath('orig_bad')
pred_dir_prefix = str(pred_dir_root.joinpath('fairseq_preprocess__orig_bad.'))
pred_fname = 'model-fixer.pred.txt'

## Fix errors

In [48]:
# Run the BIFI Round 1 fixer on every split and store its predictions next to
# the split's data.
# NOTE(review): beam=1/nbest=1 here differs from the other rounds in this
# notebook (10/10 and 5/5) -- confirm that is intended.
for split in range(n_splits):
    destdir = destdir_root.joinpath(f'fairseq_preprocess__orig_bad.{split}')
    pred_path = destdir.joinpath('model-fixer.pred.txt')
    fairseq_generate(
        str(destdir), str(model_path), str(pred_path),
        src='bad', tgt='good', gen_subset='test',
        beam=1, nbest=1,
        max_len_a=1, max_len_b=50, max_tokens=7000,
    )

fairseq-generate             data\round_1_bifi\orig_bad\fairseq_preprocess__orig_bad.0         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_bifi\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 1         --beam 1 --max-tokens 7000 
fairseq-generate             data\round_1_bifi\orig_bad\fairseq_preprocess__orig_bad.1         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_bifi\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 1         --beam 1 --max-tokens 7000 
fairseq-generate             data\round_1_bifi\orig_bad\fairseq_preprocess__orig_bad.2         --source-lang bad --target-lang good         --gen-subset test         --path data\round_1_bifi\model-fixer\checkpoint.pt         --max-len-a 1         --max-len-b 50         --nbest 1         --beam 1 --max-tokens 7000 
fairseq-generate             data\round_1_bifi\orig_

## Evaluate

In [49]:
# Show which prediction files are about to be evaluated.
print(pred_dir_prefix, pred_fname)

# Score the predictions of every split (writes one *.evaluated.json each) ...
for split in range(n_splits):
    eval_one_split(pred_dir_prefix=pred_dir_prefix, split=split,
                   pred_fname=pred_fname, n_workers=10)

# ... then report accuracy; get_test_result only reads splits 3 and 4.
get_test_result(pred_dir_prefix=pred_dir_prefix, pred_fname=pred_fname)

data\round_1_bifi\orig_bad\fairseq_preprocess__orig_bad. model-fixer.pred.txt
len(preds) 7528


7528it [00:01, 5761.37it/s]


len(preds) 7528


7528it [00:01, 5734.75it/s]


len(preds) 7528


7528it [00:01, 5672.66it/s]


len(preds) 7528


7528it [00:01, 5573.60it/s]


len(preds) 7527


7527it [00:01, 6041.65it/s]


split 3
split 4
Total
   acc: 13171 (87.5%) | denom 15055
--------------------------------------------------
Unbalanced (){}[]
   acc: 3757 (93.9%) | denom 3999
Invalid syntax
   acc: 4335 (91.3%) | denom 4749
Indentation error
   acc: 5079 (80.5%) | denom 6307
