# Run inference on a single code snippet

In [None]:
from df2sql import sqlite2postgres
from modules.models import CodeT5, CodeBertJS
from transformers import RobertaTokenizer
from difflib import unified_diff
from difflib import SequenceMatcher
import sqlite3
import pandas as pd
import torch
import os
import random

# Load Model and Tokenizer

In [None]:
CPKT_PATH = 'checkpoints/CodeT5JS_5classes_650MaxL_v3.ckpt' if os.path.exists('checkpoints/CodeT5JS_5classes_650MaxL_v3.ckpt') else ''
MODEL_NAME = CPKT_PATH.split('/')[-1].split('.')[0].split('_')[0]
if MODEL_NAME == 'CodeT5JS':
    HF_DIR = 'Salesforce/codet5-base'
    model = CodeT5.load_from_checkpoint(
        CPKT_PATH, 
        num_classes=5,
        model_dir=HF_DIR,
        with_activation=True,
        with_layer_norm=True
    )
else:
    HF_DIR = 'microsoft/codebert-base-mlm'
    model = CodeBertJS.load_from_checkpoint(CPKT_PATH)

In [15]:
buggy_code = """
# Write a function to display the Fibonacci sequence using recursion
function fibonacci(n) {
  if (n <= 1) {
    return n;
  } else {
    return fibonacci(n + 1) + fibonacci(n + 2);
  }
}
"""
correct_code = """
# Write a function to display the Fibonacci sequence using recursion
function fibonacci(n) {
  if (n <= 1) {
    return n;
  } else {
    return fibonacci(n - 1) + fibonacci(n - 2);
  }
}
"""

bug_type = 'functionality'
print('------------------------------------------')
print('Buggy Code')
print(buggy_code)
print('------------------------------------------')
print('Correct Code')
print(correct_code)

------------------------------------------
Buggy Code

# Write a function to display the Fibonacci sequence using recursion
function fibonacci(n) {
  if (n <= 1) {
    return n;
  } else {
    return fibonacci(n + 1) + fibonacci(n + 2);
  }
}

------------------------------------------
Correct Code

# Write a function to display the Fibonacci sequence using recursion
function fibonacci(n) {
  if (n <= 1) {
    return n;
  } else {
    return fibonacci(n - 1) + fibonacci(n - 2);
  }
}



In [None]:
DB_TABLE = 'humanevalpack'
DB_PATH = 'humanevalpack.db'
QUERY = f"select * from {DB_TABLE}"
con = sqlite3.connect(DB_PATH)
df = pd.read_sql_query(QUERY, con, index_col='index')
# sqlite2postgres(df, 'humanevalpack')

small_samples = df[df['canonical_solution'].str.len() <= 250]
# sample = small_samples[small_samples['task_id'] == 'JavaScript/4'].iloc[0].to_dict()
sample = small_samples.iloc[random.randint(0,len(small_samples)-1)].to_dict()
buggy_code = sample['declaration'] + sample['buggy_solution']
correct_code = sample['declaration'] + sample['canonical_solution']
desc = sample['prompt']
print(desc)
print('------------------------------------------')
print('Buggy Code')
print(buggy_code)
print('------------------------------------------')
print('Correct Code')
print(correct_code)

# Run inference on buggy code

In [16]:
tokenizer = RobertaTokenizer.from_pretrained(HF_DIR)
encoded_buggy_code = tokenizer(buggy_code, padding=True, truncation=True, return_tensors='pt')
encoded_correct_code = tokenizer(correct_code, padding=True, truncation=True, return_tensors='pt')
batch = {
    'input_ids': encoded_buggy_code['input_ids'],
    'attention_mask': encoded_buggy_code['attention_mask'],
    'labels': encoded_correct_code['input_ids'],
}
model.eval()
model.to('cpu')
with torch.no_grad():
    _, out, bug_class = model.forward(batch)
    probs = torch.softmax(bug_class, dim=1)
    pred_class = model.classes[torch.argmax(probs, dim=1).item()]

generated_code = tokenizer.batch_decode(torch.argmax(out, dim=-1), skip_special_tokens=True)[0]

In [18]:
print(generated_code)


# Write a function to display the Fibonacci sequence using recursion
function fibonacci(n) {
  if (n <= 1) {
    return n;
  } else {
    return fibonacci(n + 1) + fibonacci(n - 2);
  }
}



# Συγκρίσεις

#### Διαφορές : Κώδικας με σφάλματα - Διορθωμένος κώδικας (ground truth)

In [19]:
real_codeDiff = unified_diff(buggy_code.splitlines(), correct_code.splitlines())
print("\n".join(real_codeDiff))

--- 

+++ 

@@ -4,6 +4,6 @@

   if (n <= 1) {
     return n;
   } else {
-    return fibonacci(n + 1) + fibonacci(n + 2);
+    return fibonacci(n - 1) + fibonacci(n - 2);
   }
 }


#### Διαφορες : Κώδικας με σφάλματα - Κώδικας που παρήγαγε το μοντέλο

In [20]:
model_codeDiff = unified_diff(buggy_code.splitlines(), generated_code.splitlines())
print("\n".join(model_codeDiff))

--- 

+++ 

@@ -4,6 +4,6 @@

   if (n <= 1) {
     return n;
   } else {
-    return fibonacci(n + 1) + fibonacci(n + 2);
+    return fibonacci(n + 1) + fibonacci(n - 2);
   }
 }


#### Διαφορές : Κώδικας που παρήγαγε το μοντέλο - Διορθωμένος κώδικας

In [21]:
codeDiff = unified_diff(generated_code.splitlines(), correct_code.splitlines())
print("\n".join(codeDiff))

--- 

+++ 

@@ -4,6 +4,6 @@

   if (n <= 1) {
     return n;
   } else {
-    return fibonacci(n + 1) + fibonacci(n - 2);
+    return fibonacci(n - 1) + fibonacci(n - 2);
   }
 }


### Σύγκριση χαρακτήρων:

#### Σύγκριση χαρακτήρα προς χαρακτήρα μεταξύ του κώδικα με σφάλματα (ακολουθία εισόδου) με τον διορθωμένο κώδικα (ground truth)

In [22]:
sm = SequenceMatcher(None, buggy_code, correct_code)

for opcode, i1,i2,j1,j2 in sm.get_opcodes():
    if opcode != 'equal':
        print(opcode)
        if opcode == 'insert':
            print(generated_code[j1:j2])
        elif opcode == 'replace': 
            print(buggy_code[i1:i2])
            print(generated_code[j1:j2])
        elif opcode == 'delete':
            print(buggy_code[i1:i2])

replace
+
+
replace
+
-


### Σύκγριση Χαρακτήρων:

#### Σύγκριση χαρακτήρα προς χαρακτήρα μεταξύ του κώδικα που παρήγαγε το μοντέλο με τον διορθωμένο κώδικα (ground truth)

In [23]:
sm = SequenceMatcher(None, buggy_code, generated_code)

for opcode, i1,i2,j1,j2 in sm.get_opcodes():
    if opcode != 'equal':
        print(opcode)
        if opcode == 'insert':
            print(generated_code[j1:j2])
        elif opcode == 'replace': 
            print(buggy_code[i1:i2])
            print(generated_code[j1:j2])
        elif opcode == 'delete':
            print(buggy_code[i1:i2])

replace
+
-
