In [111]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import json
from IPython.display import HTML, Markdown, display
import re

In [112]:
def format_doc(doc):
    # Try to split sections for better formatting
    doc = re.sub(r'(Parameters:)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(Returns:)', r'<br/><br/>\1', doc)    
    doc = re.sub(r'(Example[s]?:)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(@returns)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(@see)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(@param)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(@since)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(@deprecated)', r'<br/><br/>\1', doc)
    doc = re.sub(r'(>>> )', r'<br/>\1', doc)
    return doc

class DocumentationGenerator:
    def __init__(self, model_name):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_documentation(self, code, lang):
        prompt = (
                f'''Generate concise, formatted documentation with ONLY  description, parameters, and return type, AVOIDING hallucination and examples for the following {lang} code.

            Code:
            {code}

            Documentation:'''
            )
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512)
        outputs = self.model.generate(
            inputs.input_ids,
            max_new_tokens=128,
            num_beams=4,
            early_stopping=True,
            # temperature=0.7,
            # top_p=0.9,
            # repetition_penalty=1.4,  # discourage repetition
        )
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return result.split("Documentation")[-1].strip()
    
    def display_comparison(self, code, expected_doc, generated_doc_baseline, generated_doc_finetuned):
        html_content = f"""
        <h3>Code:</h3>
        <pre>{code}</pre>
        <h3>Expected Documentation:</h3>
        <pre>{expected_doc}</pre>
        <h3>Generated Documentation Baseline:</h3>
        <pre>{generated_doc_baseline}</pre>
        <h3>Generated Documentation Fine-tuned:</h3>
        <pre>{format_doc(generated_doc_finetuned)}</pre>
        """
        display(HTML(html_content))
    
    def display_comparison_md(self, code, expected_doc, generated_doc_baseline, generated_doc_finetuned):
        md = f"""
        ### Code
        ```python
        {code}
        ```
        ### Expected Documentation
        {expected_doc}

        ### Generated Documentation Baseline
        {generated_doc_baseline}

        ### Generated Documentation Fine-tuned
        {generated_doc_finetuned}
        """
        display(Markdown(md))
    

In [113]:
# json_data = [
#     {
#     "code": "protected void notifyReconnectionFailed(Exception exception) { if (isReconnectionAllowed()) { for (ConnectionListener listener : connection.connectionListeners) { listener.reconnectionFailed(exception); } } }",
#     "doc": "Fires listeners when a reconnection attempt has failed. @param exception the exception that occured.",
#     "lang": "python"
#   },
#   {
#     "code": "def write_parquet(cls, path, output, sc, partition_num = 1, bigdl_type=\"float\"): \"\"\" write ImageFrame as parquet file \"\"\" return callBigDlFunc(bigdl_type, \"writeParquet\", path, output, sc, partition_num)",
#     "doc": "write ImageFrame as parquet file",
#     "lang": "python"
#   },
#   {
#     "code": "def list_timezones(self): \"\"\"Return the list of all known timezones.\"\"\" from h2o.expr import ExprNode return h2o.H2OFrame._expr(expr=ExprNode(\"listTimeZones\"))._frame()",
#     "doc": "Return the list of all known timezones.",
#     "lang": "python"
#   },
#   {
#     "code": "def is_valid_email(email): \"\"\" Validates and email address. Note: valid emails must follow the <name>@<domain><.extension> patterns. \"\"\" try: validate_email(email) except ValidationError: return False if simple_email_re.match(email): return True return False",
#     "doc": "Validates and email address. Note: valid emails must follow the <name>@<domain><.extension> patterns.",
#     "lang": "python"
#   },
#   {
#     "code": "def collect_exceptions(rdict_or_list, method='unspecified'): elist = [] if isinstance(rdict_or_list, dict): rlist = rdict_or_list.values() else: rlist = rdict_or_list for r in rlist: if isinstance(r, RemoteError): en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info # Sometimes we could have CompositeError in our list. Just take # the errors out of them and put them in our new list. This # has the effect of flattening lists of CompositeErrors into one # CompositeError if en=='CompositeError': for e in ev.elist: elist.append(e) else: elist.append((en, ev, etb, ei)) if len(elist)==0: return rdict_or_list else: msg = \"one or more exceptions from call to method: %s\" % (method) # This silliness is needed so the debugger has access to the exception # instance (e in this case) try: raise CompositeError(msg, elist) except CompositeError as e: raise e",
#     "doc": "check a result dict for errors, and raise CompositeError if any exist. Passthrough otherwise.",
#     "lang": "python"
#   },
#   {
#     "code": "@CheckReturnValue @SchedulerSupport(SchedulerSupport.NONE) @SuppressWarnings(\"unchecked\") public static <T> Observable<T> empty() { return RxJavaPlugins.onAssembly((Observable<T>) ObservableEmpty.INSTANCE); }",
#     "doc": "Returns an Observable that emits no items to the {@link Observer} and immediately invokes its {@link Observer#onComplete onComplete} method. <p> <img width=\"640\" height=\"190\" src=\"https://raw.github.com/wiki/ReactiveX/RxJava/images/rx-operators/empty.png\" alt=\"\"> <dl> <dt><b>Scheduler:</b></dt> <dd>{@code empty} does not operate by default on a particular {@link Scheduler}.</dd> </dl> @param <T> the type of the items (ostensibly) emitted by the ObservableSource @return an Observable that emits no items to the {@link Observer} but immediately invokes the {@link Observer}'s {@link Observer#onComplete() onComplete} method @see <a href=\"http://reactivex.io/documentation/operators/empty-never-throw.html\">ReactiveX operators documentation: Empty</a>",
#     "lang": "java"
#   },
#   {
#     "code": "public static void runTaskCollection() throws Exception { Map<String, String> propMap = BenchmarkUtils .getSystemProperties(new String[] { Constants.TASK_CLASS_NAME, Constants.WARM_UPS, Constants.RUNS, Constants.TASK_PARAMS }); TaskCollection collection = TaskFactory.createTaskCollection(propMap .get(Constants.TASK_CLASS_NAME), BenchmarkUtils .getCommaSeparatedParameter(Constants.TASK_PARAMS)); RepeatEachTaskRunner runner = new RepeatEachTaskRunner(Integer.valueOf(propMap .get(Constants.WARM_UPS)), Integer.valueOf(propMap .get(Constants.RUNS))); runner.run(collection); }",
#     "doc": "/* A macro method, reads system properties, instantiates the task collection and runs it",
#     "lang": "java",
#   },
#   {
#     "code": "public static void dumpInstanceTaxomomyToFile( final InstanceTaxonomy<? extends ElkEntity, ? extends ElkEntity> taxonomy, final String fileName, final boolean addHash) throws IOException { final FileWriter fstream = new FileWriter(fileName); final BufferedWriter writer = new BufferedWriter(fstream); try { dumpInstanceTaxomomy(taxonomy, writer, addHash); } finally { writer.close(); } }",
#     "doc": "Convenience method for printing an {@link InstanceTaxonomy} to a file at the given location. @see org.semanticweb.elk.reasoner.taxonomy.TaxonomyPrinter#dumpInstanceTaxomomy @param taxonomy @param fileName @param addHash if true, a hash string will be added at the end of the output using comment syntax of OWL 2 Functional Style @throws IOException If an I/O error occurs",
#     "lang": "java"
#   },
#   {
#     "code": "public static void dumpTaxomomy( final Taxonomy<? extends ElkEntity> taxonomy, final Writer writer, final boolean addHash) throws IOException { writer.append(\"Ontology(\\n\"); processTaxomomy(taxonomy, writer); writer.append(\")\\n\"); if (addHash) { writer.append(\"\\n# Hash code: \" + getHashString(taxonomy) + \"\\n\"); } writer.flush(); }",
#     "doc": "Print the contents of the given {@link Taxonomy} to the specified Writer. Expressions are ordered for generating the output, ensuring that the output is deterministic. @param taxonomy @param writer @param addHash if true, a hash string will be added at the end of the output using comment syntax of OWL 2 Functional Style @throws IOException If an I/O error occurs",
#     "lang": "java"
#   },
#   {
#     "code": "public void commitAddedRules() { // commit changes in the context initialization rules ChainableContextInitRule nextContextInitRule; Chain<ChainableContextInitRule> contextInitRuleChain; nextContextInitRule = addedContextInitRules_; contextInitRuleChain = getContextInitRuleChain(); while (nextContextInitRule != null) { nextContextInitRule.addTo(contextInitRuleChain); nextContextInitRule = nextContextInitRule.next(); } // commit changes in rules for IndexedClassExpression ChainableSubsumerRule nextClassExpressionRule; Chain<ChainableSubsumerRule> classExpressionRuleChain; for (ModifiableIndexedClassExpression target : addedContextRuleHeadByClassExpressions_ .keySet()) { LOGGER_.trace(\"{}: committing context rule additions\", target); nextClassExpressionRule = addedContextRuleHeadByClassExpressions_ .get(target); classExpressionRuleChain = target.getCompositionRuleChain(); while (nextClassExpressionRule != null) { nextClassExpressionRule.addTo(classExpressionRuleChain); nextClassExpressionRule = nextClassExpressionRule.next(); } } for (ModifiableIndexedClass target : addedDefinitions_.keySet()) { ModifiableIndexedClassExpression definition = addedDefinitions_ .get(target); ElkAxiom reason = addedDefinitionReasons_.get(target); LOGGER_.trace(\"{}: committing definition addition {}\", target, definition); if (!target.setDefinition(definition, reason)) throw new ElkUnexpectedIndexingException(target); } initAdditions(); }",
#     "doc": "Commits the added rules to the main index and removes them from this {@link DifferentialIndex}.",
#     "lang": "java"
#   }
# ]

In [114]:
json_data = [
    {
      "code": "def greet(name, greeting='Hello'): return f'{greeting}, {name}!'",
      "doc": "Returns a greeting message for the given name. Parameters: name (str): The person's name. greeting (str, optional): The greeting to use. Returns: str: The greeting message.",
      "lang": "python"
    },
    {
      "code": "def safe_divide(a, b): return a / b if b != 0 else None",
      "doc": "Safely divides a by b. Returns None if b is zero. Parameters: a (float): Numerator. b (float): Denominator. Returns: float or None: The result or None if division by zero.",
      "lang": "python"
    },
    {
      "code": "def max_in_list(lst): return max(lst) if lst else None",
      "doc": "Returns the maximum value in a list. Returns None if the list is empty. Parameters: lst (list): List of numbers. Returns: number or None: The maximum value or None.",
      "lang": "python"
    },
    {
      "code": "def to_int(value): try: return int(value) except (ValueError, TypeError): return None",
      "doc": "Converts a value to integer. Returns None if conversion fails. Parameters: value: The value to convert. Returns: int or None: The integer value or None.",
      "lang": "python"
    },
    {
      "code": "def append_item(item, lst=None): if lst is None: lst = []; lst.append(item); return lst",
      "doc": "Appends an item to a list. Creates a new list if none is provided. Parameters: item: Item to append. lst (list, optional): List to append to. Returns: list: The updated list.",
      "lang": "python"
    },
    {
      "code": "public static String greet(String name) { return name == null ? \"Hello, Guest!\" : \"Hello, \" + name + \"!\"; }",
      "doc": "Returns a greeting message. If name is null, greets 'Guest'. Parameters: name (String): The person's name. Returns: String: The greeting message.",
      "lang": "java"
    },
    {
      "code": "public static Double safeDivide(double a, double b) { return b == 0 ? null : a / b; }",
      "doc": "Safely divides a by b. Returns null if b is zero. Parameters: a (double): Numerator. b (double): Denominator. Returns: Double: The result or null if division by zero.",
      "lang": "java"
    },
    {
      "code": "public static Integer maxInArray(int[] arr) { if (arr == null || arr.length == 0) return null; int max = arr[0]; for (int n : arr) if (n > max) max = n; return max; }",
      "doc": "Returns the maximum value in an array. Returns null if array is null or empty. Parameters: arr (int[]): Array of integers. Returns: Integer: The maximum value or null.",
      "lang": "java"
    },
    {
      "code": "public static Integer toInt(String value) { try { return Integer.parseInt(value); } catch (NumberFormatException e) { return null; } }",
      "doc": "Converts a string to integer. Returns null if parsing fails. Parameters: value (String): The value to convert. Returns: Integer: The integer value or null.",
      "lang": "java"
    },
    {
      "code": "public static List<String> appendItem(String item, List<String> list) { if (list == null) list = new ArrayList<>(); list.add(item); return list; }",
      "doc": "Appends an item to a list. Creates a new list if none is provided. Parameters: item (String): Item to append. list (List<String>, optional): List to append to. Returns: List<String>: The updated list.",
      "lang": "java"
    }
]

In [115]:
model_name = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
qwen_baseline = DocumentationGenerator(model_name)

model_name = "../../models/Qwen_Qwen2.5-Coder-0.5B-Instruct-Jarvislabs"
qwen_finetuned = DocumentationGenerator(model_name)

In [116]:
for item in json_data:
    code = item['code']
    lang = item['lang']
    expected_doc = item['doc']
    doc_baseline = qwen_baseline.generate_documentation(code, lang)
    doc_finetuned = qwen_finetuned.generate_documentation(code, lang)
    qwen_finetuned.display_comparison(code, expected_doc, doc_baseline, doc_finetuned)
    # qwen_finetuned.display_comparison_md(code, expected_doc, doc_baseline, doc_finetuned)

    # print(f"\n----------------------------------------------------------------")
    # print(f"\nCode: {code}")
    # print("\nExpected Documentation:")
    # print(expected_doc)
    # # display(HTML(expected_doc))    
    # print("\nGenerated Documentation:")
    # display(doc)
    # # doc = re.sub(r'(@\w+)', r'<br>\1', doc)
    # # display(HTML(doc))

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
