### SEED GATHERING GET CONTENT

In [1]:
from tree_sitter_parser import LANGUAGE, make_parser, node_to_string
import datasets
import os
import signal
from multiprocessing import Pool
import os
import boto3
import smart_open
from datasets import load_dataset,Dataset
from botocore import UNSIGNED
from botocore.config import Config
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))

def download_contents(blob_id, src_encoding):
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": s3}) as fin:
        content = fin.read().decode(src_encoding)
    
    return content

In [2]:
JAVA_METHOD_QUERY = LANGUAGE.query("""
(
  (method_declaration
    name: (identifier) @method.name
    (modifiers)? @method.modifiers
    (type_identifier)? @method.return_type
    parameters: (formal_parameters) @method.parameters) @method.declaration
)
""")



# JAVA_METHOD_QUERY = LANGUAGE.query("""
# (method_declaration
#     name: (identifier) @method.name
#     body: (block) @method.body
# ) @method.declaration
# """)


def get_methods(src, tree):
    captures = JAVA_METHOD_QUERY.captures(tree.root_node)
    res = []
    for capture in captures:
        node, ty = capture
        if ty != "method.declaration":
            continue
        # Filter for top-level methods (starting column 0)
        _, col = node.start_point
        if col != 0:
            continue
        res.append(node_to_string(src, node))
    return res

def parse_ex_java(parser, ex):
    ex = download_contents(ex["blob_id"], ex["src_encoding"])
    try:
        buf = bytes(ex, "utf8")
        tree = parser.parse(buf)
        return get_methods(buf, tree)
    except:
        return []

def process_chunk_java(idx_and_chunk):
    assert PARSERS is not None
    idx, chunk = idx_and_chunk
    parser = PARSERS[idx]
    chunk_new_methods = set()
    for ex in chunk:
        chunk_new_methods.update(parse_ex_java(parser, ex))
    return chunk_new_methods

def main_java(args):
    global PARSERS
    ds = datasets.load_dataset(
        args.dataset,
        data_dir=args.data_dir,
        split="train",
    )
    methods = set()
    PARSERS = [make_parser() for _ in range(args.num_workers)]
    total_len = len(ds)
    CHUNK_SIZE = 1000 * args.num_workers

    print(f"Total length: {total_len}")
    print(f"Chunk size: {CHUNK_SIZE}")

    chunk = []
    p = Pool(args.num_workers)
    for i, ex in enumerate(ds):
        if i % (total_len // 100) == 0:
            print(f"{i}/{total_len}")
        try:
            chunk.append(ex)
            if len(chunk) == CHUNK_SIZE or i == total_len - 1:
                print(f"Processing chunk {i // CHUNK_SIZE}")
                subchunk_size = len(chunk) // args.num_workers
                subchunks = [chunk[i:i + subchunk_size]
                             for i in range(0, len(chunk), subchunk_size)]
                new_methods_iter = p.imap(
                    process_chunk_java, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
                print("Getting new methods")
                len_before = len(methods)
                while True:
                    try:
                        def timeout_handler(_, __):
                            raise KeyboardInterrupt
                        signal.signal(signal.SIGALRM, timeout_handler)
                        signal.alarm(60)
                        methods.update(next(new_methods_iter))
                        signal.alarm(0)
                    except KeyboardInterrupt:
                        signal.alarm(0)
                        print("Keyboard interrupt. Terminating pool")
                        p.terminate()
                        p = Pool(args.num_workers)
                        break
                    except StopIteration:
                        break
                    except Exception as e:
                        print(e)

                signal.alarm(0)

                PARSERS = [make_parser() for _ in range(args.num_workers)]

                print(
                    f"Done processing chunk {i // CHUNK_SIZE}. Got {len(methods) - len_before} new methods")

                chunk = []
        except Exception as e:
            print(e)
            chunk = []

        if i == total_len - 1:
            break

    p.close()

    new_ds_dict = {
        "content": list(methods),
        "id": list(range(len(methods)))
    }

    new_ds = datasets.Dataset.from_dict(new_ds_dict)

In [4]:
code = """
public class Example {
    public void sayHello() {
        System.out.println("Hello, world!");
    }
    public int add(int a, int b) {
        return a + b;
    }
}
"""

parser = make_parser()  # Use your make_parser function
tree = parser.parse(bytes(code, "utf8"))

captures = JAVA_METHOD_QUERY.captures(tree.root_node)
for node, ty in captures:
    print(f"Type: {ty}, Code: {node_to_string(bytes(code, 'utf8'), node)}")


Type: method.declaration, Code: public void sayHello() {
        System.out.println("Hello, world!");
    }
Type: method.name, Code: sayHello
Type: method.parameters, Code: ()
Type: method.declaration, Code: public int add(int a, int b) {
        return a + b;
    }
Type: method.name, Code: add
Type: method.parameters, Code: (int a, int b)


In [5]:
NUMWORKERS = os.cpu_count()

In [67]:
ds = datasets.load_dataset("bigcode/the-stack-v2-dedup", "Java", cache_dir="../cache", streaming=True, split="train")

data = []
n =  10000
for i, sample in enumerate(ds):
    data.append(sample)
    if i >= n:  # Stop after collecting 2000 samples
        break

map_style_dataset = Dataset.from_list(data)

# Verify the Dataset
print(map_style_dataset)

map_style_dataset.save_to_disk(f"sampled_dataset_{n}")

Resolving data files:   0%|          | 0/757 [00:00<?, ?it/s]

Dataset({
    features: ['blob_id', 'directory_id', 'path', 'content_id', 'detected_licenses', 'license_type', 'repo_name', 'snapshot_id', 'revision_id', 'branch_name', 'visit_date', 'revision_date', 'committer_date', 'github_id', 'star_events_count', 'fork_events_count', 'gha_license_id', 'gha_event_created_at', 'gha_created_at', 'gha_language', 'src_encoding', 'language', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename'],
    num_rows: 10001
})


Saving the dataset (0/1 shards):   0%|          | 0/10001 [00:00<?, ? examples/s]

In [42]:
n =  10000
loaded_dataset = Dataset.load_from_disk(f"sampled_dataset_{n}")
ds = loaded_dataset

In [43]:
funs = set()
PARSERS = [make_parser() for _ in range(NUMWORKERS)]
total_len = len(ds)
CHUNK_SIZE = 1000 * NUMWORKERS

print(f"Total length: {total_len}")
print(f"Chunk size: {CHUNK_SIZE}")

chunk = []
p = Pool(NUMWORKERS)

Total length: 10001
Chunk size: 20000


In [44]:
for i, ex in enumerate(iter(ds)):
    # if i % (total_len // 100) == 0:
    #     print(f"{i}/{total_len}")
    try:
        chunk.append(ex)
        if len(chunk) == CHUNK_SIZE or i == total_len - 1:
            print(f"Processing chunk {i // CHUNK_SIZE}")
            # divide the chunk into NUM_WORKERS chunks
            subchunk_size = len(chunk) // NUMWORKERS
            subchunks = [chunk[i:i + subchunk_size]
                         for i in range(0, len(chunk), subchunk_size)]
            new_funs_iter = p.imap(
                process_chunk_java, [(i, subchunk) for i, subchunk in enumerate(subchunks)])
            print("Getting new functions")
            len_before = len(funs)
            while True:
                try:
                    def timeout_handler(_, __):
                        raise KeyboardInterrupt  # it's fineeeeeee
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(60)
                    funs.update(next(new_funs_iter))
                    signal.alarm(0)
                except KeyboardInterrupt:
                    signal.alarm(0)
                    print("Keyboard interrupt. Terminating pool")
                    p.terminate()
                    p = Pool(NUMWORKERS)
                    break
                except StopIteration:
                    break
                except Exception as e:
                    print(e)

            signal.alarm(0)

            PARSERS = [make_parser() for _ in range(NUMWORKERS)]

            print(
                f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions")

            chunk = []
    except Exception as e:
        print(e)
        chunk = []

    if i == total_len - 1:
        break


p.close()
new_ds_dict = {
    "content": list(funs),
    "id": list(range(len(funs)))
}

new_ds = datasets.Dataset.from_dict(new_ds_dict)
new_ds

Processing chunk 0
Getting new functions
list index out of range
Done processing chunk 0. Got 267 new functions


Dataset({
    features: ['content', 'id'],
    num_rows: 267
})

In [46]:
ds = new_ds
print(ds['content'][1])

public void setYear(int Yr)
{
if (Yr < 2000 || Yr > 2017)
   Year =0;
else
   Year = Yr;
}


### SEED GATHERING HIGH-QUALITY SUBSET

In [47]:
ds = Dataset.load_from_disk("sampled_dataset_10000")

In [48]:
ds = ds.map(
    lambda row: {'content': download_contents(row['blob_id'], row['src_encoding'])},
    num_proc=os.cpu_count()
)

Map (num_proc=20):   0%|          | 0/10001 [00:00<?, ? examples/s]

In [49]:
import subprocess
import tempfile
import signal
import hashlib
import os
from typing import List, Dict
from tqdm import tqdm
from tree_sitter_parser import LANGUAGE, global_parser

# Query to find return statements in Java code
RETURN_QUERY = LANGUAGE.query("""
(return_statement) @return
""")

def does_have_return(src: str) -> bool:
    """
    Check if the given Java code contains a return statement with a value.
    """
    tree = global_parser.parse(bytes(src, "utf8"))
    root = tree.root_node
    captures = RETURN_QUERY.captures(root)
    for node, _ in captures:
        # If it doesn't have an argument, it's not a return with a value
        if len(node.children) <= 1:  # Includes "return" itself
            continue
        else:
            return True
    return False


In [50]:
print("Filtering to only functions with return statements")
ds = ds.filter(lambda ex: does_have_return(
    ex["content"]), num_proc=os.cpu_count())
ds


Filtering to only functions with return statements


Filter (num_proc=20):   0%|          | 0/10001 [00:00<?, ? examples/s]

Dataset({
    features: ['blob_id', 'directory_id', 'path', 'content_id', 'detected_licenses', 'license_type', 'repo_name', 'snapshot_id', 'revision_id', 'branch_name', 'visit_date', 'revision_date', 'committer_date', 'github_id', 'star_events_count', 'fork_events_count', 'gha_license_id', 'gha_event_created_at', 'gha_created_at', 'gha_language', 'src_encoding', 'language', 'is_vendor', 'is_generated', 'length_bytes', 'extension', 'filename', 'content'],
    num_rows: 5835
})

In [51]:
def run_javac(directory: str) -> Dict[str, int]:
    """
    Runs the `javac` command in the given directory and parses the output to count errors for each file.
    """
    try:
        result = subprocess.run(
            ["javac", "*.java"],
            cwd=directory,
            capture_output=True,
            timeout=120,
            text=True,
        )
    except Exception as e:
        print(f"Error running javac: {e}")
        return {}

    file_error_map = {}
    error_lines = result.stderr.split("\n")
    for line in error_lines:
        if line.strip():
            parts = line.split(":")
            if len(parts) >= 2:
                file_name = parts[0].strip()
                if file_name not in file_error_map:
                    file_error_map[file_name] = 0
                if "error" in line:
                    file_error_map[file_name] += 1

    return file_error_map

def typecheck_batch(files: List[str]) -> Dict[str, str]:
    """
    Type-checks a batch of Java files and filters out files with compilation errors.
    """
    filemap: Dict[str, str] = {}
    with tempfile.TemporaryDirectory() as tempdir:
        for content in files:
            # Generate a unique filename using SHA-1 hash
            hash_object = hashlib.sha1(bytes(content, "utf8"))
            hex_dig = hash_object.hexdigest()
            filemap[hex_dig] = content
            file_path = os.path.join(tempdir, hex_dig + ".java")
            with open(file_path, "w") as f:
                f.write(content)

        # Run javac in the temporary directory
        error_map = run_javac(tempdir)
        print(error_map)

        if not error_map:
            return {}

        for file_name, error_count in error_map.items():
            no_java = file_name.replace(".java", "")
            if error_count > 0 and no_java in filemap:
                del filemap[no_java]

        print(f"Pass rate: {len(filemap)}/{len(files)}")
        return filemap


In [52]:
batch = []
max_i = len(ds) - 1

new_ds = {
    "content": [],
    "sha1": [],
    "id": [],
}

e_id = 0
for i, ex in enumerate(tqdm(ds, total=len(ds))):
    try:
        code = ex["content"]

        batch.append(code)

        if len(batch) == 250 or i == max_i:
            filemap = typecheck_batch(batch)
            for sha1, contents in filemap.items():
                new_ds["content"].append(contents)
                new_ds["sha1"].append(sha1)
                new_ds["id"].append(e_id)
                e_id += 1
            batch = []
            
    except Exception as e:
        print(f"There was an error: {e}")
        continue

new_ds_hf = datasets.Dataset.from_dict(new_ds)

  4%|▍         | 250/5835 [00:00<00:09, 560.03it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


  9%|▊         | 500/5835 [00:00<00:06, 776.95it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 13%|█▎        | 750/5835 [00:00<00:05, 878.17it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 17%|█▋        | 1000/5835 [00:01<00:05, 938.32it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 21%|██▏       | 1250/5835 [00:01<00:04, 928.66it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 26%|██▌       | 1500/5835 [00:01<00:04, 981.49it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 30%|██▉       | 1750/5835 [00:01<00:04, 1017.61it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 34%|███▍      | 2000/5835 [00:02<00:03, 999.81it/s] 

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 39%|███▊      | 2250/5835 [00:02<00:03, 958.69it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 43%|████▎     | 2500/5835 [00:02<00:03, 995.76it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 47%|████▋     | 2750/5835 [00:02<00:03, 1016.10it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 51%|█████▏    | 3000/5835 [00:03<00:02, 1027.67it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 56%|█████▌    | 3250/5835 [00:03<00:02, 1038.44it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 60%|█████▉    | 3500/5835 [00:03<00:02, 1004.85it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 64%|██████▍   | 3750/5835 [00:03<00:02, 992.73it/s] 

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 69%|██████▊   | 4000/5835 [00:04<00:01, 1001.97it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 73%|███████▎  | 4250/5835 [00:04<00:01, 1060.89it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 81%|████████  | 4739/5835 [00:04<00:00, 1121.95it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 84%|████████▎ | 4879/5835 [00:05<00:00, 974.28it/s] 

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 86%|████████▌ | 5000/5835 [00:05<00:00, 852.07it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 90%|████████▉ | 5250/5835 [00:05<00:00, 937.27it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


 94%|█████████▍| 5500/5835 [00:05<00:00, 1012.60it/s]

{'error': 1, 'Usage': 0}
Pass rate: 250/250


100%|██████████| 5835/5835 [00:06<00:00, 963.18it/s] 

{'error': 1, 'Usage': 0}
Pass rate: 250/250
{'error': 1, 'Usage': 0}
Pass rate: 85/85





In [53]:
new_ds_hf

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 5835
})

In [54]:
save_dir = "../datasets/Dec14"
new_ds_hf.save_to_disk(save_dir)

Saving the dataset (0/1 shards):   0%|          | 0/5835 [00:00<?, ? examples/s]

In [55]:
new_ds_hf.to_json("../datasets/Dec14/sample_from_2000.json")

Creating json from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

28491434

### SEED GATHERING FILTER DATASET

In [56]:
ds = Dataset.load_from_disk("../datasets/Dec14")

In [57]:
import datasets
import os
from tree_sitter_parser import global_parser, LANGUAGE, make_parser
import benchmark_data
from tqdm import tqdm
import torch
import argparse
from vllm import LLM, SamplingParams
import random

In [70]:
def unindent(s):
    """
    Remove leading indentation from a multi-line string.
    """
    lines = s.splitlines()
    non_blank_lines = [line for line in lines if line.strip()]
    min_indent = min(len(line) - len(line.lstrip())
                     for line in non_blank_lines) if non_blank_lines else 0
    unindented_lines = [line[min_indent:] if len(
        line) >= min_indent else line for line in lines]
    return '\n'.join(unindented_lines)


def java_extract_javadoc(code):
    """
    Extract the first Javadoc-style comment (`/* ... */`) from Java code
    and return the comment along with the remaining code.

    Args:
        code (str): The Java code as a string.

    Returns:
        tuple: The extracted Javadoc comment (str) and the remaining code (str).
    """
    # Find the opening of the Javadoc comment
    first_comment_start = code.find("/*")
    if first_comment_start == -1:
        raise ValueError("No Javadoc comment found in the code.")

    # Find the closing of the Javadoc comment
    first_comment_end = code.find("*/", first_comment_start)
    if first_comment_end == -1:
        raise ValueError("Javadoc comment is not properly closed.")

    # Extract the comment
    comment = code[first_comment_start + 3:first_comment_end]  # Skip `/*` and include content
    comment = unindent(comment).strip()  # Unindent and clean up the comment

    # Remove the Javadoc comment from the code
    remaining_code = code[:first_comment_start] + code[first_comment_end + 2:]

    return comment, remaining_code.strip()


In [87]:
FN_BLOCK_QUERY = LANGUAGE.query("""
(method_declaration
    body: (block) @method-body)
""")


def template_few_shot(code, answer, rationale):
    doc, code = java_extract_javadoc(code)
    assert answer == "No" or answer == "Yes"
    prompt = f"""<issue_start>username_0: I have a function in Java and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```Java
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is: {answer}

{rationale}

Upvotes: 200"""
    return prompt


FEW_SHOTS = [
    (
        """
        public List<String> simpleScanNetwork() {
            /*
             * Do a simple network scan, which only works if your network configuration
             * is 192.168.1.x
             */
            String baseIp = "192.168.1.";
            List<String> addresses = new ArrayList<>();
            addresses.add("127.0.0.1");

            for (int index = 1; index < 255; index++) {
                addresses.add(baseIp + index);
            }

            return addresses;
        }
        """,
        "No",
        "The simpleScanNetwork method you have provided seems to generate addresses that then would be used for a network scan, but does not actually perform it, unlike the method claims."
    ),
    (
        """
        import java.util.*;
        
        public class DataFrameUtils {
            public static DataFrame coerceInteger(DataFrame df) {
                /*
                 * Loop through the columns of a DataFrame. If it is numeric,
                 * convert it to integer and fill missing values with zeros.
                 * This is somewhat heavy-handed in an attempt to force
                 * systems to recognize sparse columns as integers.
                 */
                List<String> except = Arrays.asList("latitude", "longitude", "zipCode");

                df.forEachColumn((name, series) -> {
                    if (series.isNumeric() && !except.contains(name)) {
                        series.fillNaN(0).toInteger();
                    }
                });

                return df;
            }
        }
        """,
        "Yes",
        "The docstring does seem to match the implementation! The method loops through the columns of a DataFrame and coerces them as explained."
    ),
    (
        """
        public class NameTransformer {
            /*
             * Converts a DataFrame to a dictionary.
             *
             * @param data The input DataFrame.
             * @return A map containing transformed names.
             */
            public static Map<String, Map<String, String>> transformDataFrameToDict(DataFrame data) {
                data.setColumn("en_name", data.getColumn("en_name").toUpperCase());
                data.setColumn("en_name_f", data.getColumn("en_name").split(" ")[0]);
                data.setColumn("en_name_l", data.getColumn("en_name").split(" ")[1]);
                data.setColumn("jp_name_f", data.getColumn("jp_name").split("・")[0]);
                data.setColumn("jp_name_l", data.getColumn("jp_name").split("・")[1]);

                Map<String, String> fullNameMap = data.zipToMap("en_name", "jp_name");
                Map<String, String> firstNameMap = data.zipToMap("en_name_f", "jp_name_f");
                Map<String, String> lastNameMap = data.zipToMap("en_name_l", "jp_name_l");

                return Map.of(
                    "fullNameMap", fullNameMap,
                    "firstNameMap", firstNameMap,
                    "lastNameMap", lastNameMap
                );
            }
        }
        """,
        "No",
        "The transformDataFrameToDict method does indeed convert a DataFrame into a dictionary, but it transforms various columns that were not described in the docstring. For instance, nowhere in the docstring is it mentioned handling Japanese characters or the column names."
    ),
    (
        """
        public double inchesToMeters(double inches) {
            /*
             * Convert inches to meters.
             */
            return inches * 0.0254;
        }
        """,
        "Yes",
        "inchesToMeters is a very simple method. The docstring explains concisely its purpose, which is converting inches to meters."
    ),
    (
        """
        public BufferedImage squareCrop(BufferedImage image, Integer targetSize) {
            /*
             * Crop the image to `targetSize`. If targetSize is null, the image
             * is cropped to the smallest dimension, making it square.
             */
            int width = image.getWidth();
            int height = image.getHeight();

            if (targetSize == null) {
                targetSize = Math.min(width, height);
            }

            int dx = (width - targetSize) / 2;
            int dy = (height - targetSize) / 2;

            return image.getSubimage(dx, dy, targetSize, targetSize);
        }
        """,
        "Yes",
        "Following the standard description for docstrings for methods, the squareCrop method description tells exactly what the method does."
    ),
    (
        """
        public Map<String, String> setupMotifFiles(Args args) {
            /*
             * Convenience method, ensures the setup is the same across
             * multiplicity/orientation/spacing workflows.
             */
            Map<String, String> motifFiles = new HashMap<>();
            motifFiles.put("early", String.format("%s/%s/ggr.scanmotifs.h5",
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_dir"),
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_early_dir")
            ));
            motifFiles.put("mid", String.format("%s/%s/ggr.scanmotifs.h5",
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_dir"),
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_mid_dir")
            ));
            motifFiles.put("late", String.format("%s/%s/ggr.scanmotifs.h5",
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_dir"),
                args.getInput("inference").get(args.getCluster()).get("scanmotifs_late_dir")
            ));
            return motifFiles;
        }
        """,
        "No",
        "The docstring for setupMotifFiles just says this is a convenience method. There is definitely not enough information to re-implement this method from the docstring alone."
    ),
    (
        """
        public double trip(double[] u, double[] v) {
            /*
             * Returns the scalar triple product of vectors u and v and z-axis.
             * The convention is z dot (u cross v). Dotting with the z-axis simplifies
             * it to the z component of the u cross v.
             *
             * The product is:
             * - positive if v is to the left of u, that is,
             *   the shortest right-hand rotation from u to v is ccw.
             * - negative if v is to the right of u, that is,
             *   the shortest right-hand rotation from u to v is cw.
             * - zero if v is collinear with u.
             *
             * Essentially, trip is the z component of the cross product of u x v.
             */
            return (u[0] * v[1] - u[1] * v[0]);
        }
        """,
        "Yes",
        "The docstring for the trip method is very detailed and describes the method's purpose and the mathematical formula used to calculate the scalar triple product."
    )
]



def prompt_fmt_java(code):
    doc, code = java_extract_javadoc(code)
    random.shuffle(FEW_SHOTS)
    buf = ""
    for few in FEW_SHOTS:
        buf += template_few_shot(*few)
    buf += f"""<issue_start>username_0: I have a function in Java and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```java
{code}
```

Here is my description of this program:
```
{doc}
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.
Upvotes: 100<issue_comment>username_1: Sure, no problem. I will be able to help.
My answer is:"""
    return buf


def auto_dtype():
    if torch.cuda.is_bf16_supported():
        return "bfloat16"
    return "auto"


def chunkify(lst, n):
    chunks = []
    for i in range(0, len(lst), n):
        chunk = []
        for j in range(n):
            if i + j < len(lst):
                chunk.append(lst[i + j])
        chunks.append(chunk)
    return chunks


In [60]:
dataset = ds
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 5835
})

In [61]:
print(f"Loaded {len(dataset)} examples. Running pre-filtering...")

BAD_WORDS = ["todo", "fixme", "bug"]
BAD_IMPORTS = [
    "java.util.Scanner", 
    "java.lang.Runtime", 
    "java.lang.ProcessBuilder", 
    "javax.swing", 
    "java.awt"
]
BAD_IMPORTS = [f"import {b};" for b in BAD_IMPORTS]
# BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS
BAD_SUBSTRINGS = BAD_WORDS

bench_filter = benchmark_data.filter_out()
all_bench = bench_filter["human_eval_docstrings"] + \
    bench_filter["human_eval_solutions"] + \
    bench_filter["mbpp_docstrings"] + \
    bench_filter["mbpp_solutions"]

Loaded 5835 examples. Running pre-filtering...
num strings from mbpp_docstrings: 120
num strings from mbpp_solutions: 120
num strings from human_eval_docstrings: 164
num strings from human_eval_solutions: 161


In [None]:
RETURN_QUERY = LANGUAGE.query("""
(return_statement) @return
""")

def does_have_return(src: str) -> bool:
    """
    Check if the given Java code contains a return statement with a value.
    """
    tree = global_parser.parse(bytes(src, "utf8"))
    root = tree.root_node
    captures = RETURN_QUERY.captures(root)
    for node, _ in captures:
        # If it doesn't have an argument, it's not a return with a value
        if len(node.children) <= 1:  # Includes "return" itself
            continue
        else:
            return True
    return False

def pre_filtering_java(ex):
    # """
    # Pre-filter Java code examples based on specific quality criteria.
    # """
    code = ex["content"]
    code_bytes = code.encode('utf-8')

    # # Filter out bad substrings
    lower = code.lower()
    for word in BAD_SUBSTRINGS:
        if word in lower:
            return False

    # Too many lines of code -- say 150
    lines = code.split("\n")
    if len(lines) > 150:
        return False

    # # Exclude methods without meaningful parameters
    for line in lines:
        # Look for method declarations
        if line.strip().startswith(("public", "private", "protected")) and "()" in line:
            return False

    # # Filter out methods with no return statement
    parser = make_parser()
    if not does_have_return(code):
        return False

    # try:
    #     # Parse the Java code with Tree-sitter
    #     tree = global_parser.parse(code_bytes)
    #     block, _ = FN_BLOCK_QUERY.captures(tree.root_node)[0]

    #     # Get the Javadoc, filter if not a valid Javadoc
    #     preceding_comments = block.prev_sibling
    #     if not preceding_comments or preceding_comments.type != "comment":
    #         return False

    #     # Extract and validate the Javadoc content
    #     docstring_text = preceding_comments.text.decode('utf-8').strip()
    #     # if not docstring_text.startswith("/*") or not docstring_text.endswith("*/"):
    #     if not docstring_text.startswith("/*") or not docstring_text.endswith("*/"):
    #         return False
    # except Exception as e:
    #     print(f"Error in filtering: {e}")
    #     return False

    return True  # Passes all checks


# threads = os.cpu_count() - 1  # type: ignore
dataset = ds.filter(pre_filtering_java, num_proc=os.cpu_count())
dataset

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Filter (num_proc=20):   0%|          | 0/5835 [00:00<?, ? examples/s]

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 1289
})

In [64]:
dataset.save_to_disk("../datasets/Dec14/Java_after_pre_filtering")

Saving the dataset (0/1 shards):   0%|          | 0/1289 [00:00<?, ? examples/s]

In [2]:
dataset = Dataset.load_from_disk("../datasets/Dec14/Java_after_pre_filtering")


NameError: name 'Dataset' is not defined

In [None]:
model = LLM(f"../../../StarCoder", dtype=auto_dtype(),
            gpu_memory_utilization=0.95, tensor_parallel_size=1)


In [34]:
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-15b"
# CHECKPOINT = "bigcode/starcoderbase-1b"
# CHECKPOINT = "bigcode/starcoder2-3b"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)



In [None]:
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# to use 4bit use `load_in_4bit=True` instead
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

checkpoint = "bigcode/starcoder2-15b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=quantization_config).to("cuda")
model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=quantization_config).to("cuda")


In [66]:
dataset

Dataset({
    features: ['content', 'sha1', 'id'],
    num_rows: 1289
})

In [67]:
print(f"Now running stage 3 filtering on {len(dataset)} examples...")

Now running stage 3 filtering on 1289 examples...


In [73]:
print(dataset[7]['content'])

package com.example.aichan2;

import java.util.ArrayList;

import android.app.Activity;
import android.content.Intent;
import android.speech.RecognizerIntent;
import android.util.Log;

public class Mike {
	/*
	 *  音声認識用インテントの発行
	 */
	public void start(Activity caller, int requestCode) {
		Intent intent = new Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH);
		intent.putExtra(RecognizerIntent.EXTRA_LANGUAGE_MODEL, RecognizerIntent.LANGUAGE_MODEL_FREE_FORM);
		intent.putExtra(RecognizerIntent.EXTRA_PROMPT, "なに？");
		
		caller.startActivityForResult(intent, requestCode);
	}
	
	/*
	 *  onActivityResultが受け取ったdataから文字列のArrayListを取り出す
	 */
	public ArrayList<String> getStringArrayList(Intent data) {
		return data.getStringArrayListExtra(RecognizerIntent.EXTRA_RESULTS);
	}
	
	/*
	 *  onActivityResultが受け取ったdataから一番最初の文字列を取り出す
	 */
	public String getString(Intent data) {
		ArrayList<String> result = data.getStringArrayListExtra(RecognizerIntent.EXTRA_RESULTS);
		for(int i=0; i<result.size(); i++)

In [74]:
# Define a dummy Java function
dummy_java = """
public class Dummy {
    /*
     * Dummy method for testing.
     */
    public void dummy() {
        // This is a dummy function
    }
}
"""

# Apply formatting
dummy_java_prompt = prompt_fmt_java(dummy_java)

# Get tokens
few_shot_toks = len(tokenizer.encode(dummy_java_prompt)) - len(tokenizer.encode(dummy_java))

# Output the token overhead
print(f"Few-shot prompt has {few_shot_toks} tokens")


Few-shot prompt has 2726 tokens


In [15]:
import subprocess
import tempfile
import signal
import hashlib
import os
from typing import List, Dict
from tqdm import tqdm
from tree_sitter_parser import LANGUAGE, global_parser


prompts = []
for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"):
    try:
        code = ex["content"]
        toks = len(tokenizer.encode(code)) + few_shot_toks
        if toks > 16380:
            print(f"Skipping example with {toks} tokens")
            # to skip, just add dummy prompt
            prompts.append(dummy_java_prompt)
            continue
        p = prompt_fmt_java(code)
        prompts.append(p)
    except:
        continue
# responses = []
# for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
#     outs = model.generate(chunk, SamplingParams(
#         temperature=0.0, stop="\n", max_tokens=5))
#     contents = [o.outputs[0].text for o in outs]
#     for c in contents:
#         yes_count = c.lower().count("yes")
#         no_count = c.lower().count("no")
#         if yes_count > no_count:
#             responses.append(True)
#         elif yes_count < no_count:
#             responses.append(False)
#         else:
#             # default to No
#             responses.append(False)



NameError: name 'dataset' is not defined

In [89]:
print(prompts[0])

<issue_start>username_0: I have a function in Java and I'd like someone to check my description of this function.
I'm doing this so that I can write a good docstring for this function.

Here is the code for the function:
```Java
public List<String> simpleScanNetwork() {
            
            String baseIp = "192.168.1.";
            List<String> addresses = new ArrayList<>();
            addresses.add("127.0.0.1");

            for (int index = 1; index < 255; index++) {
                addresses.add(baseIp + index);
            }

            return addresses;
        }
```

Here is my description of this program:
```
* Do a simple network scan, which only works if your network configuration
* is 192.168.1.x
```

Do not attempt to execute the function or to judge its correctness.
Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function.
Also, answer with "No" if the description does not match the function.<issue_comment>user

In [91]:
from tqdm import tqdm

# Initialize responses and prompts
responses = []
prompts = []

# Sampling parameters
sampling_params = {"temperature": 0.0, "stop": "\n", "max_tokens": 5}

# Ensure sampling_params is a valid dictionary
if sampling_params is None or not isinstance(sampling_params, dict):
    sampling_params = {"temperature": 0.0, "stop": "\n", "max_tokens": 5}

# Generate prompts and responses
for ex in tqdm(dataset, total=len(dataset), desc="Processing dataset"):
    try:
        # Generate prompt
        code = ex["content"]
        toks = len(tokenizer.encode(code)) + few_shot_toks
        if toks > 16380:
            print(f"Skipping example with {toks} tokens")
            # Add dummy prompt for skipped examples
            prompts.append(dummy_java_prompt)
            continue

        p = prompt_fmt_java(code)
        prompts.append(p)

        # Generate responses in chunks of 512 prompts
        for chunk in tqdm(chunkify(prompts, 512), desc="Generating responses"):
            if chunk is None:
                print("Warning: Encountered a None chunk. Skipping...")
                continue

            # Generate responses for the chunk
            outs = model.generate(chunk, sampling_params.copy())
            contents = [o.outputs[0].text for o in outs]

            for c in contents:
                if c is None:
                    print("Warning: Encountered None content. Treating as default response.")
                    responses.append(False)
                    continue

                # Count "yes" and "no" in the response
                yes_count = c.lower().count("yes")
                no_count = c.lower().count("no")

                if yes_count > no_count:
                    responses.append(True)
                elif yes_count < no_count:
                    responses.append(False)
                else:
                    # Default to No
                    responses.append(False)
    except Exception as e:
        # Log any unexpected errors and skip to the next example
        # print(f"Error processing example: {e}")
        prompts.append(dummy_java_prompt)
        responses.append(False)


Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]103.33it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]88.28it/s] 
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating responses:   0%|          | 0/1 [00:00<?, ?it/s]
Generating respon

In [9]:
ds.save_to_disk("../datasets/Dec17_Final_Step1_2")

Saving the dataset (0/1 shards):   0%|          | 0/1289 [00:00<?, ? examples/s]

In [10]:
ds = Dataset.load_from_disk("../datasets/Dec17_Final_Step1_2")

In [11]:
ds

Dataset({
    features: ['seed', 'sha1', 'id'],
    num_rows: 1289
})

In [14]:
print(ds['seed'][0])

import java.util.Arrays;

public class BestTimeBuyAndSellStock {
    //--------------  Solution 1 -------------------//
    // brute force (Time Limit Exceeded)
    public int maxProfit(int[] prices) {
        int N = prices.length;
        int res = 0;  // min, buy and sell on the same day
        for (int i = 0; i < N; i++) {
            for (int j = i; j < N; j++) {
                res = Math.max(res, prices[j] - prices[i]);
            }
        }
        return res;
    }

    //--------------  Solution 2  -------------------//
    // most intuitive solution
    public int maxProfit2(int[] prices) {
        // input validation
        if (prices == null || prices.length <= 1) {
            return 0;
        }

        // record the min up util now
        int min = Integer.MAX_VALUE;
        int res = 0;
        for (int i : prices) {
            min = Math.min(min, i);
            res = Math.max(res, i - min);
        }
        return res;
    }

    //--------------  Solution 3 

In [52]:
subset.save_to_disk("../datasets/Dec14_prompt")

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

In [50]:
new_ds = subset.filter(  # horrible hack!
    lambda ex, i: responses[i] and "def dummy()" not in ex["content"], with_indices=True)
print(f"Filtered {len(dataset) - len(new_ds)} examples")

Filter:   0%|          | 0/10 [00:00<?, ? examples/s]

Filtered 4156 examples
