In [28]:
import re
import ast
import math
import json
import random
import pandas as pd
from collections import Counter, defaultdict
from typing import List, Iterable, Tuple, Dict, Optional

In [29]:
# load data
df = pd.read_csv('methods.csv')

In [30]:
df

Unnamed: 0,file,name,start_line,end_line,signature,code_tokens,dataset_split
0,spring-boot/smoke-test/spring-boot-smoke-test-...,configure,52,56,protected SpringApplicationBuilder configure(S...,"[""protected"", ""SpringApplicationBuilder"", ""con...",train
1,spring-boot/buildpack/spring-boot-buildpack-pl...,loadJsonFromDistributionManifestList,33,40,void loadJsonFromDistributionManifestList(),"[""void"", ""loadJsonFromDistributionManifestList...",train
2,spring-boot/smoke-test/spring-boot-smoke-test-...,testLegacyDot,165,172,void testLegacyDot(),"[""void"", ""testLegacyDot"", ""("", "")"", ""{""]",train
3,spring-boot/documentation/spring-boot-docs/src...,setHost,41,52,public void setHost(Host host),"[""public"", ""void"", ""setHost"", ""("", ""Host"", ""ho...",train
4,spring-boot/module/spring-boot-security/src/ma...,healthMatcher,78,81,private ServerWebExchangeMatcher healthMatcher(),"[""private"", ""ServerWebExchangeMatcher"", ""healt...",train
...,...,...,...,...,...,...,...
34298,spring-boot/documentation/spring-boot-docs/src...,setCheckLocation,41,44,public void setCheckLocation(boolean checkLoca...,"[""public"", ""void"", ""setCheckLocation"", ""("", ""b...",test
34299,spring-boot/core/spring-boot-autoconfigure/src...,expressionIsTrue,45,50,void expressionIsTrue(),"[""void"", ""expressionIsTrue"", ""("", "")"", ""{""]",test
34300,spring-boot/module/spring-boot-servlet/src/tes...,dispatcherServlet,251,260,DispatcherServlet dispatcherServlet(),"[""DispatcherServlet"", ""dispatcherServlet"", ""(""...",test
34301,spring-boot/module/spring-boot-actuator-autoco...,runWhenEnabledPropertyIsFalseShouldNotHaveEndp...,60,65,void runWhenEnabledPropertyIsFalseShouldNotHav...,"[""void"", ""runWhenEnabledPropertyIsFalseShouldN...",test


In [11]:
def split_train_test(df):
    """
    Split data into two dataframes for strictly train and strictly test. 
    """
    if "dataset_split" in df.columns and df["dataset_split"].notnull().any():
        train_df = df[df["dataset_split"].str.lower() == "train"]
        test_df = df[df["dataset_split"].str.lower() == "test"]
        return train_df, test_df

In [31]:
train_df, test_df = split_train_test(df)

In [32]:
train_df

Unnamed: 0,file,name,start_line,end_line,signature,code_tokens,dataset_split
0,spring-boot/smoke-test/spring-boot-smoke-test-...,configure,52,56,protected SpringApplicationBuilder configure(S...,"[""protected"", ""SpringApplicationBuilder"", ""con...",train
1,spring-boot/buildpack/spring-boot-buildpack-pl...,loadJsonFromDistributionManifestList,33,40,void loadJsonFromDistributionManifestList(),"[""void"", ""loadJsonFromDistributionManifestList...",train
2,spring-boot/smoke-test/spring-boot-smoke-test-...,testLegacyDot,165,172,void testLegacyDot(),"[""void"", ""testLegacyDot"", ""("", "")"", ""{""]",train
3,spring-boot/documentation/spring-boot-docs/src...,setHost,41,52,public void setHost(Host host),"[""public"", ""void"", ""setHost"", ""("", ""Host"", ""ho...",train
4,spring-boot/module/spring-boot-security/src/ma...,healthMatcher,78,81,private ServerWebExchangeMatcher healthMatcher(),"[""private"", ""ServerWebExchangeMatcher"", ""healt...",train
...,...,...,...,...,...,...,...
24995,spring-boot/module/spring-boot-micrometer-trac...,shouldSupplyNoopTracer,42,50,void shouldSupplyNoopTracer(),"[""void"", ""shouldSupplyNoopTracer"", ""("", "")"", ""{""]",train
24996,spring-boot/core/spring-boot/src/test/java/org...,processNoMatchesReturnsNullContribution,61,65,void processNoMatchesReturnsNullContribution(),"[""void"", ""processNoMatchesReturnsNullContribut...",train
24997,spring-boot/module/spring-boot-actuator-autoco...,createWhenHasStatusAggregatorBeanReturnsInstan...,115,130,void createWhenHasStatusAggregatorBeanReturnsI...,"[""void"", ""createWhenHasStatusAggregatorBeanRet...",train
24998,spring-boot/buildpack/spring-boot-buildpack-pl...,decodeBase64,98,103,private static byte[] decodeBase64(String cont...,"[""private"", ""static"", ""byte"", ""["", ""]"", ""decod...",train


Preprocessing helper functions

In [62]:
JAVA_SINGLE_LINE_COMMENT = re.compile(r"//.*?$", re.MULTILINE)
JAVA_MULTI_LINE_COMMENT  = re.compile(r"/\*.*?\*/", re.DOTALL)
JAVA_STRING_LITERAL      = re.compile(r"\"(?:\\.|[^\"\\])*\"")  # dealing with things in string. match string that starts with double quote, contains any sequence of characters and ends with double quote
JAVA_CHAR_LITERAL        = re.compile(r"'(?:\\.|[^'\\])'")      # dealing with char. match single quote string literals
IMPORT_OR_PACKAGE_LINE   = re.compile(r"^\s*(import|package)\b.*?$", re.MULTILINE) # deal with getting import/package declarations in code

# Token pattern: identifiers, numbers, multi-char ops, and single-char punctuation/operators.
TOKEN_PATTERN = re.compile(
    r"""
    [A-Za-z_][A-Za-z_0-9]*        |  # identifiers/keywords
    \d+\.\d+|\d+                  |  # numbers
    ==|!=|<=|>=|&&|\|\||::|->     |  # multi-char ops
    [{}()\[\].,;:+\-*/%<>!=&|^~?]    # single-char ops/punctuations
    """,
    re.VERBOSE
)
# adding these special tokens
BOS = "<BOS>"
EOS = "<EOS>"

In [63]:
def clean_java_code(text):
    """
    Remove comments, imports/package, and whitespace. Also format string/char literals to be a placeholder token so that we reduce noise.
    """
    # Remove comments
    text = re.sub(JAVA_MULTI_LINE_COMMENT, " ", text)
    text = re.sub(JAVA_SINGLE_LINE_COMMENT, " ", text)
    # remove import/packages
    text = re.sub(IMPORT_OR_PACKAGE_LINE, " ", text)
    # Replace string and char literals with placeholders
    text = re.sub(JAVA_STRING_LITERAL, "\"<STR>\"", text)
    text = re.sub(JAVA_CHAR_LITERAL, "'<CHR>'", text)
    # remove whitespace
    text = re.sub(r"\s+", " ", text).strip()
    return text

In [65]:
def tokenize_java(text):
    """
    Tokenize text
    """
    return re.findall(TOKEN_PATTERN, text)

In [66]:
def parse_tokens_field(val):
    """
    Interpret a 'tokens' cell as either:
    - JSON-like list ('["public","class","Foo"]')
    - Space-separated string ('public class Foo { }')
    """
    if val is None:
        return None
    if isinstance(val, list):
        return [str(t) for t in val] #all strings
    s = str(val).strip() #list of strings
    if not s:
        return None
    # JSON-like list?
    if s.startswith("[") and s.endswith("]"): # check if list
        try:
            arr = ast.literal_eval(s)
            if isinstance(arr, list):
                return [str(t) for t in arr]
        except Exception:
            pass
    #space-separated tokens
    if " " in s or "\t" in s:
        return s.split()
    # single token string if nothing else works
    return [s]


In [67]:
def extract_token_sequences(df):
    """
    Given a DataFrame row-wise, produce a list of token sequences.
    """
    cols = set(c.lower() for c in df.columns)
    tokenized = []

    tokens_col = None
    for cand in ["code_tokens"]:
        if cand in cols:
            tokens_col = [c for c in df.columns if c.lower() == cand][0]
            break

    if tokens_col:
        for _, row in df.iterrows():
            seq = parse_tokens_field(row[tokens_col])
            if seq:
                tokenized.append([BOS] * 2 + seq + [EOS])  # 2 BOS for >= trigram padding
        return tokenized

In [89]:
class NgramLM:
    """
    N-gram Language Model with:
      - Count-based
      - Add-k smoothing
    """
    def __init__(self, n = 3, smoothing = "add-k", k = 1.0):
        assert n >= 1
        self.n = n
        self.smoothing = smoothing
        self.k = k

        # Counts
        self.ngram_counts = [Counter() for _ in range(self.n)]  # 1..n
        self.context_counts = [Counter() for _ in range(self.n)]

        # vocal
        self.vocab = set()

    def fit(self, sequences):
        """
        Build n-gram and context counts.
        """
        for seq in sequences:
            self.vocab.update(seq) # add to vocab
            L = len(seq)
            for m in range(1, self.n + 1):
                for i in range(L - m + 1):
                    ngram = tuple(seq[i:i+m])
                    self.ngram_counts[m-1][ngram] += 1
                    context = ngram[:-1]  # empty for unigrams
                    self.context_counts[m-1][context] += 1
    @ property
    def V(self):
        return len(self.vocab)

    def prob(self, context, token):
        """
        P(token | context) using add-k smoothing.
        """
        # Iterate from full-order down to unigrams
        for m in range(self.n, 0, -1):
            ctx = context[-(m-1):] if m > 1 else tuple()
            ngram = ctx + (token,)
            num = self.ngram_counts[m-1][ngram] # numerator
            den = self.context_counts[m-1][ctx] # denominator
            if den > 0 or m == 1:
                if self.smoothing == "add-k":
                    return (num + self.k) / (den + self.k * self.V) # dealing with 0 case
                else:
                    # default to no smoothing
                    if den == 0:
                        # backoff to lower order so continue if doesn't work when context unseen; if we reach m=1, force tiny prob
                        continue
                    return num / den
        # If everything failed, return tiny prob
        return 1.0 / (self.V * 1e3 + 1)

    def next_token_dist(self, context_tokens):
        """
        Return a sorted distribution over next possible tokens for a given context.
        """
        context = tuple(context_tokens[-(self.n-1):]) if self.n > 1 else tuple()
        dist = []
        for tok in self.vocab:
            p = self.prob(context, tok)
            dist.append((tok, p))
        # sort by prob
        dist.sort(key=lambda x: x[1], reverse=True)
        return dist

    def sample(self, seed_tokens, max_len = 50):
        """
        Sample a full sequence continuation example from the model, starting with seed_tokens.
        Stops at EOS or when reached the max_len.
        """
        tokens = seed_tokens[:]
        for _ in range(max_len):
            dist = self.next_token_dist(tokens)
            # draw according to probabilities
            r = random.random()
            cum = 0.0
            for tok, p in dist:
                cum += p
                if r <= cum:
                    tokens.append(tok)
                    break
            if tokens[-1] == EOS:
                break
        return tokens

perplexity - accuracy evaluation

In [81]:
def sequence_logprob(model, seq):
    """
    Compute log probability of a single token sequence under the model.
    """
    logp = 0.0
    for i in range(len(seq)):
        context = tuple(seq[max(0, i - (model.n - 1)):i]) if model.n > 1 else tuple()
        tok = seq[i]
        p = model.prob(context, tok)
        # handling too small p
        p = max(p, 1e-12)
        logp += math.log(p)
    return logp


def perplexity(model, sequences):
    """
    Perplexity = exp( - (1/total_tokens) * sum_log_probs )
    """
    total_tokens = 0
    total_logp = 0.0
    for seq in sequences:
        total_logp += sequence_logprob(model, seq)
        total_tokens += len(seq)
    if total_tokens == 0: # handling 0 denom
        return float("inf")
    return math.exp(- total_logp / total_tokens)


In [82]:
# Build token sequences
train_sequences = extract_token_sequences(train_df)
test_sequences  = extract_token_sequences(test_df)


In [83]:
train_sequences

[['<BOS>',
  '<BOS>',
  'protected',
  'SpringApplicationBuilder',
  'configure',
  '(',
  'SpringApplicationBuilder',
  'application',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'void',
  'loadJsonFromDistributionManifestList',
  '(',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>', '<BOS>', 'void', 'testLegacyDot', '(', ')', '{', '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'public',
  'void',
  'setHost',
  '(',
  'Host',
  'host',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'private',
  'ServerWebExchangeMatcher',
  'healthMatcher',
  '(',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'public',
  'RepositoryDetectionStrategies',
  'getDetectionStrategy',
  '(',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'void',
  'bootBuildImageWithDockerHostMinikube',
  '(',
  ')',
  '{',
  '<EOS>'],
 ['<BOS>',
  '<BOS>',
  'final',
  'protected',
  'ConfigurableApplicationContext',
  'createContext',
  '(',
  'String',
  'driverClassName',
  ',',
  'Class',
  '<',
  '?',
  '>',
  '.',
  '.',
  '

In [96]:
# Try varying N size and get each perplexity from them
results = []
for N in [2, 3, 5, 7]:
    lm = NgramLM(n=N, smoothing="add-k", k=1.0)
    lm.fit(train_sequences)
    ppl = perplexity(lm, test_sequences)
    results.append((N, ppl, lm.V))
    print(f"N={N} | Vocab={lm.V} | Test Perplexity={ppl}")

N=2 | Vocab=21209 | Test Perplexity=46.659859797654114
N=3 | Vocab=21209 | Test Perplexity=77.84711852208532
N=5 | Vocab=21209 | Test Perplexity=159.74469082646425
N=7 | Vocab=21209 | Test Perplexity=169.44302466198585


Demo Sampled completion

In [98]:
demo_seq = ['static']
# Use first 20 tokens (without trailing EOS) as context
demo_prefix = [t for t in demo_seq if t != EOS][:20]

# Use best (so lowest) perplexity model among tried ones
best_N, _, _ = min(results, key=lambda x: x[1])
print('best N')
print(best_N) # this should be 2 for checking
best_model = NgramLM(n=best_N, smoothing="add-k", k=1.0)
best_model.fit(train_sequences)

dist = best_model.next_token_dist(demo_prefix)
top10 = dist[:10]

print("Context prefix:", " ".join(demo_prefix[-(best_N-1):]))
print("Top-10 candidates:")
for tok, p in top10:
    print(f"{tok}  p={p}")

# Sampling continuation
sampled = best_model.sample(demo_prefix, max_len=40)
print("Sampled continuation:")
print(" ".join(sampled))

best N
2
Context prefix: static
Top-10 candidates:
void  p=0.015486527616148956
String  p=0.004699668785247516
boolean  p=0.0016560737624205532
Stream  p=0.0012980037597350282
List  p=0.0012532450093993375
int  p=0.0007608987557067407
T  p=0.0004475875033569063
PropertySourcesPlaceholderConfigurer  p=0.0004028287530212156
SslBundle  p=0.0004028287530212156
ImageReference  p=0.0004028287530212156
Sampled continuation:
static bindWhenCollectionParameterWithEmptyDefaultValueShouldReturnEmptyInstance setTlsVerify setTrustStore SortHandlerMethodArgumentResolverCustomizer usesMatchersBasedOffConfiguredActuatorBasePath generatedBuildInfoUsesCustomBuildTime matchWhenHasNoEndpoints setSourceType toHierarchicalName setBatchSize withRunImage ofWhenHasIllegalCharacterThrowsException publicDefineClass getBindConstructorWhenIsTypeWithPrivateConstructorReturnsNull singleWebEndpointCanBeExposed validateWithWrappedExceptionMessageWhenInvalidThrowsException warnIfNotHttps testPasswordEncoding graphQlWeb