In [51]:
import pandas as pd
import time
from setfit import SetFitModel, SetFitTrainer
from datasets import Dataset, DatasetDict, load_dataset
from tqdm.auto import tqdm
import numpy as np
import torch

# Enable tqdm's pandas integration.
# NOTE(review): no progress_apply / progress_map call is visible in this notebook;
# this line may be removable — confirm before deleting.
tqdm.pandas()

In [52]:
# Per-language category names. Their order must match the column order of the
# multi-hot `labels` vectors: the evaluation cell pairs labels[lan][i] with column i.
labels = {
    'java': ['summary', 'Ownership', 'Expand', 'usage', 'Pointer', 'deprecation', 'rational'],
    'python': ['Usage', 'Parameters', 'DevelopmentNotes', 'Expand', 'Summary'],
    'pharo': ['Keyimplementationpoints', 'Example', 'Responsibilities', 'Classreferences', 'Intent', 'Keymessages', 'Collaborators'],
}
# Languages to process, taken from the label map in insertion order (java, python, pharo).
langs = list(labels)

# Dataset from the Hugging Face Hub: one `<lang>_train` / `<lang>_test` split pair per language.
ds = load_dataset('NLBSE/nlbse25-code-comment-classification')
ds

DatasetDict({
    java_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 7614
    })
    java_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1725
    })
    python_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1884
    })
    python_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 406
    })
    pharo_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 1298
    })
    pharo_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels'],
        num_rows: 289
    })
})

In [53]:
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer

In [55]:
# Tokenizer used by `preprocess` to build input_ids / attention_mask.
# NOTE(review): training uses the paraphrase-MiniLM-L3-v2 checkpoint and maps the raw
# `combo` text, so these tokenized columns are presumably not consumed by SetFit — verify.
TOKENIZER_NAME = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# NOTE(review): the `labels` column already looks multi-hot (the evaluation cell compares
# it per category against 0/1 predictions), so a label binarizer may be unnecessary.
mlb = MultiLabelBinarizer()

In [56]:
# Tokenize sentences and encode labels for each split
def preprocess(examples):
    """Tokenize a batch of comment sentences and attach integer label vectors.

    Intended for ``Dataset.map(preprocess, batched=True)``.

    Args:
        examples: batch dict with a ``'combo'`` column (list of str) and a
            ``'labels'`` column holding one multi-hot 0/1 list per example.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (padded/truncated to 128
        tokens) and ``binary_labels`` (each example's label vector as a list of int).
    """
    # Tokenize the 'combo' column (sentence + class name)
    tokenized = tokenizer(
        examples['combo'], padding="max_length", truncation=True, max_length=128, return_tensors="np"
    )

    # `labels` is already multi-hot (one 0/1 entry per category). Running it through
    # MultiLabelBinarizer.fit_transform treated the values 0/1 themselves as class names
    # (producing a "contains 0" / "contains 1" matrix) and refit on every batch, so the
    # output was never the label vector. Pass the rows through as plain ints instead.
    binary_labels = [[int(value) for value in row] for row in examples['labels']]

    # Return tokenized inputs and binary labels
    return {
        'input_ids': tokenized['input_ids'],
        'attention_mask': tokenized['attention_mask'],
        'binary_labels': binary_labels
    }

In [57]:
# Run `preprocess` over every language's train split, then its test split.
for lang in langs:
    print(f"Processing {lang}")
    train_split, test_split = f"{lang}_train", f"{lang}_test"
    for split_name in (train_split, test_split):
        ds[split_name] = ds[split_name].map(preprocess, batched=True)

Processing java
Processing python


Map:   0%|          | 0/1884 [00:00<?, ? examples/s]

Processing pharo


In [58]:
# Re-inspect the DatasetDict: each split should now also carry the columns returned
# by `preprocess` (input_ids, attention_mask, binary_labels).
ds

DatasetDict({
    java_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels', 'input_ids', 'attention_mask', 'binary_labels'],
        num_rows: 7614
    })
    java_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels', 'input_ids', 'attention_mask', 'binary_labels'],
        num_rows: 1725
    })
    python_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels', 'input_ids', 'attention_mask', 'binary_labels'],
        num_rows: 1884
    })
    python_test: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels', 'input_ids', 'attention_mask', 'binary_labels'],
        num_rows: 406
    })
    pharo_train: Dataset({
        features: ['index', 'class', 'comment_sentence', 'partition', 'combo', 'labels', 'input_ids', 'attention_mask', 'binary_labels'],
        num_rows: 1298
    })
    pharo_test

In [None]:

# Fine-tune one multi-output SetFit model per language and save it under ./models/<lang>.
# Use the GPU when one is available instead of toggling commented-out lines.
device = "cuda" if torch.cuda.is_available() else "cpu"
for lan in langs:
    model = SetFitModel.from_pretrained(
        "paraphrase-MiniLM-L3-v2", multi_target_strategy="multi-output", device=device
    )
    # NOTE(review): the recorded output shows a warning pointing at this SetFitTrainer call
    # and the run failed with "AttributeError: 'CallbackHandler' object has no attribute
    # 'tokenizer'", which looks like a setfit/transformers version mismatch.
    # TODO: pin compatible setfit/transformers versions.
    trainer = SetFitTrainer(
        model=model,
        train_dataset=ds[f'{lan}_train'],
        eval_dataset=ds[f'{lan}_test'],
        # Train directly on the dataset's multi-hot `labels` column, the same column the
        # evaluation cell uses as ground truth (independent of any derived label column).
        column_mapping={"combo": "text", "labels": "label"},
        # Presumably fewer epochs for java because its train split is by far the largest.
        num_epochs=5 if lan == 'java' else 10,
        batch_size=32,
    )
    trainer.train()
    trainer.model.save_pretrained(f'./models/{lan}')

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
  trainer = SetFitTrainer(
Applying column mapping to the training dataset
Applying column mapping to the evaluation dataset


AttributeError: 'CallbackHandler' object has no attribute 'tokenizer'

In [62]:
# Benchmark the per-language models on the test splits. FLOPs and wall-clock time are
# accumulated over N_RUNS repeated predictions per language; per-category
# precision / recall / F1 are computed from the last prediction.
N_RUNS = 10


def _binary_counts(truth, pred):
    """Return (tp, fp, fn) for two equal-length sequences of 0/1 values."""
    pairs = list(zip(truth, pred))
    tp = sum(t == p == 1 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    return tp, fp, fn


total_flops = 0
total_time = 0
scores = []
for lan in langs:
    # Released models from the Hub; use SetFitModel.from_pretrained(f'./models/{lan}')
    # to evaluate locally trained models instead.
    model = SetFitModel.from_pretrained(f"NLBSE/nlbse25_{lan}")
    with torch.profiler.profile(with_flops=True) as p:
        start = time.time()
        for _ in range(N_RUNS):
            y_pred = model(ds[f'{lan}_test']['combo']).numpy().T
        elapsed = time.time() - start
        total_time = total_time + elapsed
    total_flops = total_flops + (sum(k.flops for k in p.key_averages()) / 1e9)

    # After transposing, row c holds every test example's 0/1 value for category c.
    y_true = np.array(ds[f'{lan}_test']['labels']).T
    for cat_idx in range(len(y_pred)):
        truth, pred = y_true[cat_idx], y_pred[cat_idx]
        assert(len(pred) == len(truth))
        tp, fp, fn = _binary_counts(truth, pred)
        scores.append({
            'lan': lan,
            'cat': labels[lan][cat_idx],
            'precision': tp / (tp + fp),
            'recall': tp / (tp + fn),
            'f1': (2 * tp) / (2 * tp + fp + fn),
        })

print("Compute in GFLOPs:", total_flops / N_RUNS)
print("Avg runtime in seconds:", total_time / N_RUNS)
generated_scores = pd.DataFrame(scores)
generated_scores

Compute in GFLOPs: 26.670382080000003
Avg runtime in seconds: 3.1714226961135865


Unnamed: 0,lan,cat,precision,recall,f1
0,java,summary,0.87839433,0.83408072,0.85566417
1,java,Ownership,1.0,1.0,1.0
2,java,Expand,0.32352941,0.43137255,0.3697479
3,java,usage,0.9250646,0.83062645,0.87530562
4,java,Pointer,0.79017857,0.96195652,0.86764706
5,java,deprecation,0.81818182,0.6,0.69230769
6,java,rational,0.17647059,0.30882353,0.22459893
7,python,Usage,0.7007874,0.73553719,0.71774194
8,python,Parameters,0.79389313,0.8125,0.8030888
9,python,DevelopmentNotes,0.24390244,0.48780488,0.32520325


In [63]:
# Reference per-category baseline scores (path relative to the working directory).
BASELINE_CSV = "baseline_results_summary.csv"
baseline_results = pd.read_csv(BASELINE_CSV)
baseline_results

Unnamed: 0,index,lan,cat,precision,recall,f1
0,0,java,summary,0.87338501,0.82944785,0.85084959
1,1,java,Ownership,1.0,1.0,1.0
2,2,java,Expand,0.32352941,0.44444444,0.37446809
3,3,java,usage,0.91104294,0.81818182,0.86211901
4,4,java,Pointer,0.73825503,0.94017094,0.82706767
5,5,java,deprecation,0.81818182,0.6,0.69230769
6,6,java,rational,0.16216216,0.29508197,0.20930233
7,7,python,Usage,0.7007874,0.73553719,0.71774194
8,8,python,Parameters,0.79389313,0.8125,0.8030888
9,9,python,DevelopmentNotes,0.24390244,0.48780488,0.32520325


In [64]:
# Render floats with 8 decimal places so small score differences stay visible.
pd.set_option("display.float_format", "{:.8f}".format)

In [65]:
# Pair each generated score row with the baseline row for the same (language, category).
# validate="one_to_one" raises on duplicate keys instead of silently fanning rows out.
comparison = pd.merge(
    generated_scores,
    baseline_results,
    on=['lan', 'cat'],
    suffixes=('_generated', '_baseline'),
    validate="one_to_one",
)

# The default inner merge silently drops categories that have no baseline row;
# fail loudly so the comparison table can't look complete when it isn't.
missing_keys = (
    set(zip(generated_scores['lan'], generated_scores['cat']))
    - set(zip(comparison['lan'], comparison['cat']))
)
assert not missing_keys, f"no baseline row for: {sorted(missing_keys)}"

# Positive diff = generated model scores higher than the baseline.
for metric in ('precision', 'recall', 'f1'):
    comparison[f'{metric}_diff'] = comparison[f'{metric}_generated'] - comparison[f'{metric}_baseline']

# Display comparison with differences
comparison[['lan', 'cat', 'precision_diff', 'recall_diff', 'f1_diff']]

Unnamed: 0,lan,cat,precision_diff,recall_diff,f1_diff
0,java,summary,0.00500932,0.00463286,0.00481458
1,java,Ownership,0.0,0.0,0.0
2,java,Expand,0.0,-0.0130719,-0.00472019
3,java,usage,0.01402165,0.01244463,0.01318661
4,java,Pointer,0.05192354,0.02178558,0.04057939
5,java,deprecation,0.0,0.0,0.0
6,java,rational,0.01430843,0.01374156,0.0152966
7,python,Usage,0.0,0.0,0.0
8,python,Parameters,0.0,0.0,0.0
9,python,DevelopmentNotes,0.0,0.0,0.0
