# Colab Initialization

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/radistoubalidis/JSRepair.git

!pip install pytorch_lightning
!python -m pip install lightning
!pip install datasets
!pip install python-dotenv
!pip install rouge-score

In [None]:
%cd ./JSRepair

# Training

In [1]:
from transformers import (
    RobertaTokenizer,
)
from modules.models import CodeT5
from modules.datasets import CodeT5Dataset
from modules.TrainConfig import init_logger, init_checkpoint, Trainer
from modules.filters import add_labels, bug_type_dist_query
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pytorch_lightning import Trainer as plTrainer
import json
import matplotlib.pyplot as plt
import os
import pandas as pd
import sqlite3
import torch
import numpy as np

In [3]:
HF_DIR = 'Salesforce/codet5-base'
TOKENIZER_MAX_LENGTH = 512 #int(input('Tokenizer Max length: '))
DB_PATH = 'commitpack-datasets.db' if os.path.exists('commitpack-datasets.db') else '/content/drive/MyDrive/Thesis/commitpack-datasets.db'
DB_TABLE = 'commitpackft_classified_train'
if not os.path.exists(DB_PATH):
    raise RuntimeError('sqlite3 path doesnt exist.')
VAL_SIZE = 0.3
LOG_PATH = 'logs' if os.path.exists('logs') else '/content/drive/MyDrive/Thesis/logs'
VERSION = int(input('Training version: '))
LOAD_FROM_CPKT = input("Load from existing model (type cpkt path if true): ")
DEBUG = True if int(input('Debug Run (1,0): ')) == 1 else False
BATCH_SIZE = 8 if DEBUG is True else 32
CPKT_PATH = 'checkpoints' if os.path.exists('checkpoints') else '/content/drive/MyDrive/Thesis/checkpoints'
DROPOUT_RATE = 0.125 #float(input('Type dropout rate for classifier: '))
WITH_MOBILE = False #True if int(input('Consider mobile class (1,0): ')) == 1 else False
WITH_LAYER_NORM = True
WITH_ACTIVATION = True

if WITH_MOBILE:
    classLabels = {
        "mobile" : 0.,
        "functionality" : 0.,
        "ui-ux" : 0.,
        "compatibility-performance" : 0.,
        "network-security" : 0.,
        "general": 0.
    }
else:
    classLabels = {
        "functionality" : 0.,
        "ui-ux" : 0.,
        "compatibility-performance" : 0.,
        "network-security" : 0.,
        "general": 0.
    }

num_classes = len(classLabels.keys())
modelSize = HF_DIR.split('-')[-1]
MODEL_DIR = f"CodeT5_{modelSize}_JS_{num_classes}classes_{TOKENIZER_MAX_LENGTH}MaxL"
con = sqlite3.connect(DB_PATH)

## Create Classification Labels

```json
{
    "mobile" : 0,
    "functionality" : 0,
    "ui-ux" : 0,
    "compatibility-performance" : 0,
    "network-security" : 0,
    "general": 0
}

Ένα δείγμα που κατηγοριοποιήθηκε ως σφάλμα λειτουργικότητας(functionality) και ui-ux θα έχει διάνυσμα ταξινόμησης ->
[0,1,1,0,0,0]
```


In [4]:
def load_ds() -> pd.DataFrame:
    query = f"select * from {DB_TABLE}"
    ds_df = pd.read_sql_query(query, con)
    return ds_df

ds_df = load_ds()

ds_df['class_labels'] = ds_df['bug_type'].apply(lambda bT: add_labels(bT.split(','), classLabels))
if DEBUG:
    ds_df = ds_df.sample(500)

if not WITH_MOBILE:
    ds_df = ds_df[ds_df['bug_type'] != 'mobile']

ds_df.head()

Unnamed: 0,index,commit,old_file,new_file,old_contents,new_contents,subject,message,lang,license,repos,processed_message,is_bug,bug_type,class_labels
1736,29792,09d77f68de6320ac509775b7604e247b721528b5,root/tasks/connect.js,root/tasks/connect.js,/*\n\nSets up a connect server to work from th...,/*\n\nSets up a connect server to work from th...,Fix a bug where query strings would 404,Fix a bug where query strings would 404\n,JavaScript,mit,"seattletimes/newsapp-template,seattletimes/new...",fix bug queri string would 404,1,network-security,"[0.0, 0.0, 0.0, 1.0, 0.0]"
4447,42212,7cf0a253819661c84e593d4a8dade96d1f4f253c,ghost/admin/controllers/forgotten.js,ghost/admin/controllers/forgotten.js,/* jshint unused: false */\r\nimport ajax from...,/* jshint unused: false */\r\nimport ajax from...,Stop validation error notification stack,Stop validation error notification stack\n\ncl...,JavaScript,mit,"TryGhost/Ghost,TryGhost/Ghost,TryGhost/Ghost",stop valid error notif stack close # 3383 call...,1,"ui-ux,compatibility-performance","[0.0, 1.0, 1.0, 0.0, 0.0]"
92,22687,28e8fcd5f08aca3bb0b3abb9663a241bfcf395ce,administrator/templates/hubbasicadmin/js/compo...,administrator/templates/hubbasicadmin/js/compo...,if (typeof(Joomla) == 'undefined')\n{\n\tJooml...,if (typeof(Joomla) == 'undefined')\n{\n\tJooml...,Fix for script that can cause infinite loops u...,[TPL_HUBBASICADMIN] Fix for script that can ca...,JavaScript,mit,"anthonyfuentes/hubzero-cms,zooley/hubzero-cms,...",tplhubbasicadmin fix script caus infinit loop ...,1,compatibility-performance,"[0.0, 0.0, 1.0, 0.0, 0.0]"
6839,20412,ef920cf60564622f99642092e67f82b5d24d0db2,src/map/entities/footstep.js,src/map/entities/footstep.js,var Footstep = Entity.extend({\n alpha: 0.5...,var Footstep = Entity.extend({\n alpha: 0.5...,"Fix graphical issue where they ""flicker"" when ...",[Content] Footstep: Fix graphical issue where ...,JavaScript,mit,burningtomatoes/CabinInTheSnow,content footstep fix graphic issu flicker fade...,1,ui-ux,"[0.0, 1.0, 0.0, 0.0, 0.0]"
9824,26661,85bd7255c3fc4b106f32d45b928c8abfb8fa7ea2,lib/cmd.js,lib/cmd.js,"""use strict"";\n\nconst commands = {\n\t\n\t/* ...","""use strict"";\n\nconst {exec} = require(""child...",Fix missing function reference in Make command,Fix missing function reference in Make command\n,JavaScript,isc,Alhadis/Atom-PhoenixTheme,fix miss function refer make command,1,functionality,"[1.0, 0.0, 0.0, 0.0, 0.0]"


## Filter out outlier samples

In [5]:
def count_comment_lines(sample: str) -> int:
    comment_blocks = []
    start_index = -1
    for i, line in enumerate(sample.splitlines()):
        if line.strip().startswith('/*'):
            start_index = i
        elif line.strip().endswith('*/'):
            comment_blocks.append([start_index, i])
            start_index = -1

    comment_lines_count = sum([c[1]-c[0] for c in comment_blocks])

    for i, line in enumerate(sample.splitlines()):
        if line.strip().startswith('//'):
            comment_lines_count += 1
    return comment_lines_count

ds_df['old_contents_comment_lines_count'] = ds_df['old_contents'].apply(lambda sample: count_comment_lines(sample))
ds_df['new_contents_comment_lines_count'] = ds_df['new_contents'].apply(lambda sample: count_comment_lines(sample))

# Filter out samples where the sum of comment lines increased more than 3 lines
# to prevent excessive masking 
ds_df = ds_df[abs(ds_df['old_contents_comment_lines_count'] - ds_df['new_contents_comment_lines_count']) <= 3]
print(f"Total training samples after filtering: {len(ds_df)}")

Total training samples after filtering: 489


## Concatenate Commit Message with the old contents 
- This way, the commit message is directly provided as additional context, and the models (T5, Bert) can process both the buggy code and the commit message in a unified manner.
- This approach will allow the model to learn the relationship between the commit message and the changes made to the code.

In [6]:
tokenizer = RobertaTokenizer.from_pretrained(HF_DIR)


old_codes = ds_df[['message', 'old_contents', 'class_labels']]
old_codes['input_seq'] = old_codes['message'] + ' ' + tokenizer.sep_token + ' ' + old_codes['old_contents']
new_codes = ds_df[['message', 'new_contents', 'class_labels']]
new_codes['output_seq'] = new_codes['message'] + '  ' + tokenizer.sep_token + ' ' + new_codes['new_contents']

TRAIN_old, VAL_old, TRAIN_new, VAL_new = train_test_split(old_codes, new_codes, test_size=VAL_SIZE, random_state=42)

print(f"Total training samples: {len(ds_df)}")

Total training samples: 489


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  old_codes['input_seq'] = old_codes['message'] + ' ' + tokenizer.sep_token + ' ' + old_codes['old_contents']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_codes['output_seq'] = new_codes['message'] + '  ' + tokenizer.sep_token + ' ' + new_codes['new_contents']


## Types of Bugs distribution in samples

In [6]:
query = bug_type_dist_query(WITH_MOBILE, table='commitpackft_classified_train')

info_df = pd.read_sql_query(query, con)
info_df

Unnamed: 0,count(*),bug_type
0,2862,general
1,3147,ui-ux
2,3159,network-security
3,4396,compatibility-performance
4,4532,functionality


### Dataset

In [7]:
TRAIN_encodings = tokenizer(
    TRAIN_old['input_seq'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    pad_to_max_length=True,
    return_tensors='pt',
    padding='max_length',
    truncation=True
)

VAL_encodings = tokenizer(
    VAL_old['input_seq'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    pad_to_max_length=True,
    return_tensors='pt',
    padding='max_length',
    truncation=True
)

TRAIN_decodings = tokenizer(
    TRAIN_new['output_seq'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    pad_to_max_length=True,
    return_tensors='pt',
    padding='max_length',
    truncation=True
)

VAL_decodings = tokenizer(
    VAL_new['output_seq'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    pad_to_max_length=True,
    return_tensors='pt',
    padding='max_length',
    truncation=True
)

### Convert Class Labels into tensors

In [8]:
TRAIN_classes = torch.tensor(TRAIN_old['class_labels'].tolist())
VAL_classes = torch.tensor(VAL_old['class_labels'].tolist())
TRAIN_classes

tensor([[0., 1., 0., 1., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        ...,
        [0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.]])

#### Compute class weights
$pos\ weight[i] = (Number\ of\ negative\ samples\ for\ class\ i) / (Number\ of\ positive\ samples\ for\ class\ i)$

In [9]:
num_samples = TRAIN_classes.size(0)
num_classes = TRAIN_classes.size(1)

pos_counts = torch.sum(TRAIN_classes, dim=0)
neg_counts = num_samples - pos_counts
class_weights = neg_counts / (pos_counts + 1e-6)
class_weights = class_weights.numpy()

## Initialize Training Settings

In [10]:
logger = init_logger(log_path=LOG_PATH, model_dir=MODEL_DIR, version=VERSION)
checkpoint = init_checkpoint(cpkt_path=CPKT_PATH, model_dir=MODEL_DIR, version=VERSION)
trainer = Trainer(checkpoint,logger,debug=DEBUG, num_epochs=5)

if len(LOAD_FROM_CPKT) > 0 and  os.path.exists(LOAD_FROM_CPKT):
    model = CodeT5.load_from_checkpoint(
        LOAD_FROM_CPKT, 
        class_weights=class_weights, 
        num_classes=num_classes,
        dropout_rate=DROPOUT_RATE,
        with_layer_norm=WITH_LAYER_NORM,
        with_activation=WITH_ACTIVATION
    )
else:
    model = CodeT5(
        class_weights=class_weights, 
        num_classes=num_classes, 
        dropout_rate=DROPOUT_RATE,
        with_layer_norm=WITH_LAYER_NORM,
        with_activation=WITH_ACTIVATION
    )
model.model.train()

TRAIN_dataset = CodeT5Dataset(TRAIN_encodings, TRAIN_decodings, TRAIN_classes)
VAL_dataset = CodeT5Dataset(VAL_encodings, VAL_decodings, VAL_classes)
from transformers import default_data_collator
dataloader = DataLoader(TRAIN_dataset, batch_size=BATCH_SIZE,num_workers=14, shuffle=True, collate_fn=default_data_collator)
val_dataloader = DataLoader(VAL_dataset, batch_size=1, num_workers=14, collate_fn=default_data_collator)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


#### Run Training

In [11]:
trainer.fit(
    model,
    train_dataloaders=dataloader,
    val_dataloaders=val_dataloader
)

/home/disras/miniconda3/envs/thesis/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/disras/projects/JSRepair/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                       | Params | Mode 
--------------------------------------------------------------------
0 | model        | T5ForConditionalGeneration | 222 M  | train
1 | layer_norm   | LayerNorm                  | 1.5 K  | train
2 | hidden_layer | Linear                     | 295 K  | train
3 | activation   | ReLU                       | 0      | train
4 | dropout      | Dropout                    | 0      | train
5 | classifier   | Linear                     | 3.8 K  | train
--------------------------------------------------------------------
223 M     Trainable params
0         Non-trainable params
223 M     Total params
892.731   Total estimated model params size (MB)
546       Modules in train mode
0         

Training: |          | 0/? [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


#### Save Model Config to CSV 

In [None]:
modelConfigsCSV = f"/content/drive/MyDrive/Thesis/model-configs.csv"
if os.path.exists(modelConfigsCSV):
    modelConfig = {
        'name': MODEL_DIR,
        'tokenizer_max_length': TOKENIZER_MAX_LENGTH,
        'num_classes': num_classes,
        'dropout_rate': DROPOUT_RATE,
        'with_activation': WITH_ACTIVATION,
        'with_layer_norm': WITH_LAYER_NORM
    }
    modelConfig_df = pd.DataFrame([modelConfig])
    modelConfig_df.to_csv(modelConfigsCSV, mode='a', index=False, header=False)

# Testing

## Load Test Dataset

In [None]:
test_df = pd.read_sql_query('select * from commitpackft_classified_test', con)
test_df['class_labels'] = ds_df['bug_type'].apply(lambda bT: add_labels(bT, classLabels))
if WITH_MOBILE:
    test_df = test_df[test_df['bug_type'] != 'mobile']

test_df['input_seq'] = test_df['message'] + ' ' + tokenizer.sep_token + ' ' + test_df['old_contents']

if DEBUG:
    test_df = test_df.iloc[:10]

test_df.head()

## Bug Type Distribution in Test Dataset

In [None]:
with open('bug-type-dist-query_test.sql', 'r')as f:
    distQuery = f.read()
f.close()
info_df = pd.read_sql_query(distQuery, con)
info_df

In [None]:
TEST_classes = torch.tensor(ds_df['class_labels'].tolist())
num_samples = TEST_classes.size(0)
num_classes = TEST_classes.size(1)

pos_counts = torch.sum(TEST_classes, dim=0)
neg_counts = num_samples - pos_counts
class_weights = neg_counts / (pos_counts + 1e-6)
class_weights = class_weights.numpy()
class_weights

## Tokenize Data

In [None]:
encoded_samples = model.tokenizer(
    test_df['input_seq'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

encoded_labels = model.tokenizer(
    test_df['new_contents'].tolist(),
    max_length=TOKENIZER_MAX_LENGTH,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

labels = torch.tensor(ds_df['class_labels'].tolist())

## Testing Script

In [None]:
METRICS_PATH = 'metrics' if os.path.exists('metrics') else '/content/drive/MyDrive/Thesis/metrics'
os.environ['METRICS_PATH'] = METRICS_PATH
os.environ['VERSION'] = str(VERSION)
MODEL_NAME = 'CodeT5'
os.environ['MODEL_NAME'] = MODEL_NAME


torch_ds = CodeT5Dataset(encodings=encoded_samples, decodings=encoded_labels, class_labels=labels)
loader = DataLoader(torch_ds, batch_size=1, num_workers=14)

trainer = plTrainer()
trainer.test(model=model, dataloaders=loader)

## Compute Metrics

**ROUGE (Recall-Oriented understudy for Gisting Evaluation**
- A metric for evaluation text generation/sumamrization models.
- It measures the overlap between machine generated text (prediction) and its human generated corresponding text (reference)\ 
- [0,1] { close to 0: poor similarity, close to 1: better similarity}
- n-gram: seq of n words

Variations
- ROUGE-N : μετράει το σύνολο της επικάλυψης *[πόσες φορές εμφανίζετε στο παραγώμενο κείμενο]* το n-gram μεταξύ των προβλέψεων και του πραγματικού κειμένου

- ROUGE-N_recall : num n gram matches / num of n-gram in reference
- ROUGE-N-precision : nummber of n-gram matches / number of n gram in prediction
- ROUGE-L : Βασίζεται στο μάκρος του μεγαλύτερης κοινής υπό-ακολουθίας (Longest Common Sequence -LCS) . Υπολογίζει το μέτρο f-measure
    - ROUGE-L_recall : LCS / num words in reference
    - ROUGE-L_precision : LCS / num words in prediction

In [None]:
from modules.metrics import CodeRouge
import json

rouge = CodeRouge(['rouge7','rouge8','rouge9','rougeL','rougeLsum'])

rouge.compute(predictions=model.generated_codes, references=test_df['new_contents'].tolist())
rouge.calc_averages()

avgs_path = f"{METRICS_PATH}/{MODEL_NAME}_v{VERSION}/rouge.json"
all_path = f"{METRICS_PATH}/{MODEL_NAME}_v{VERSION}/avg_rouge.csv"
with open(avgs_path, 'a') as f:
    json.dump(rouge.avgs, f, indent=4)

all_scores = []
for r in rouge.rouge_types:
    all_scores += rouge.rouge_type_to_list(r)

metrics_df = pd.DataFrame(all_scores)

for m in ['precision','recall','fmeasure']:
    metrics_df[m] = round(metrics_df[m], 3)
metrics_df.to_csv(all_path, index=False)

## Model Comparisons

### Bar Plots

In [None]:
codebert_avgs = rouge.avgs

comparison_model_path = input('Comparison model avg ROUGE-N metrics path: ')
comparison_model = comparison_model_path.split('/')[-2]
if not os.path.exists(comparison_model_path):
    raise RuntimeError('Metrics path does not exist.')

with open(comparison_model_path, 'r') as f:
    codet5_avgs = json.load(f)


plot_data = {
    f"{MODEL_NAME}_{VERSION}": (round(codebert_avgs['avg_rouge7'].fmeasure, 5), round(codebert_avgs['avg_rouge8'].fmeasure, 5), round(codebert_avgs['avg_rouge9'].fmeasure, 5), round(codebert_avgs['avg_rougeL'].fmeasure, 5), round(codebert_avgs['avg_rougeLsum'].fmeasure, 5)),
    comparison_model: (round(codet5_avgs['avg_rouge7'][2], 5), round(codet5_avgs['avg_rouge8'][2], 5), round(codet5_avgs['avg_rouge9'][2], 5), round(codet5_avgs['avg_rougeL'][2], 5), round(codet5_avgs['avg_rougeLsum'][2], 5)),
}

metric_types = ('Rouge-7', 'Rouge-8','Rouge-9', 'Rouge-L', 'Rouge-Lsum')
x = np.arange(len(metric_types))
width = 0.15
multiplier = 0

fix, ax = plt.subplots(layout='constrained')


for model, values in plot_data.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, values, width, label=model)
    ax.bar_label(rects, padding=3)
    multiplier += 1

ax.set_ylabel('Score')
ax.set_title('F-Measure Model Comparison')
ax.set_xticks(x + width, metric_types)
ax.legend(loc='upper left', ncols=4)
ax.set_ylim(0, 1.2)

plt.savefig(f"{METRICS_PATH}/{MODEL_NAME}_{VERSION}_vs_{comparison_model}.png", dpi=300, bbox_inches='tight')
plt.show()

### Chart

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

codebert_avgs = rouge.avgs  # Assuming rouge is a library/function that provides average scores

comparison_model_path = input('Comparison model avg ROUGE-N metrics path: ')
comparison_model = comparison_model = comparison_model_path.split('/')[-2]
if not os.path.exists(comparison_model_path):
    raise RuntimeError('Metrics path does not exist.')

with open(comparison_model_path, 'r') as f:
    codet5_avgs = json.load(f)

# Define metric types (assuming same metrics for both models)
metric_types = ('Rouge-7', 'Rouge-8', 'Rouge-9', 'Rouge-L', 'Rouge-Lsum')

# Create a figure with 3 rows (subplots) and 1 column
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(20, 16))

# Data dictionaries for each metric (assuming data structure from rouge)
precision_data = {
    f"{MODEL_NAME}_{VERSION}": (codebert_avgs['avg_rouge7'].precision, codebert_avgs['avg_rouge8'].precision, codebert_avgs['avg_rouge9'].precision, codebert_avgs['avg_rougeL'].precision, codebert_avgs['avg_rougeLsum'].precision),
    comparison_model: (codet5_avgs['avg_rouge7'][0], codet5_avgs['avg_rouge8'][0], codet5_avgs['avg_rouge9'][0], codet5_avgs['avg_rougeL'][0], codet5_avgs['avg_rougeLsum'][0]),
}
recall_data = {
    f"{MODEL_NAME}_{VERSION}": (codebert_avgs['avg_rouge7'].recall, codebert_avgs['avg_rouge8'].recall, codebert_avgs['avg_rouge9'].recall, codebert_avgs['avg_rougeL'].recall, codebert_avgs['avg_rougeLsum'].recall),
    comparison_model: (codet5_avgs['avg_rouge7'][1], codet5_avgs['avg_rouge8'][1], codet5_avgs['avg_rouge9'][1], codet5_avgs['avg_rougeL'][1], codet5_avgs['avg_rougeLsum'][1]),
}
f1_data = {
    f"{MODEL_NAME}_{VERSION}": (codebert_avgs['avg_rouge7'].fmeasure, codebert_avgs['avg_rouge8'].fmeasure, codebert_avgs['avg_rouge9'].fmeasure, codebert_avgs['avg_rougeL'].fmeasure, codebert_avgs['avg_rougeLsum'].fmeasure),
    comparison_model: (round(codet5_avgs['avg_rouge7'][2], 5), round(codet5_avgs['avg_rouge8'][2], 5), round(codet5_avgs['avg_rouge9'][2], 5), round(codet5_avgs['avg_rougeL'][2], 5), round(codet5_avgs['avg_rougeLsum'][2], 5)),
}


# Plot Precision (ax1)
for model, precision in precision_data.items():
    ax1.plot(metric_types, precision, label=model, marker='s')  # 's' for square marker
ax1.set_xlabel('ROUGE-N')
ax1.set_ylabel('Precision')
ax1.grid(True)

# Plot Recall (ax2)
for model, recall in recall_data.items():
    ax2.plot(metric_types, recall, label=model, marker='s')  # 'o' for circle marker
ax2.set_xlabel('ROUGE-N')
ax2.set_ylabel('Recall')
ax2.grid(True)

# Plot F1 Score (ax3)
for model, f1 in f1_data.items():
    ax3.plot(metric_types, f1, label=model, marker='s')
ax3.set_xlabel('ROUGE-N')
ax3.set_ylabel('F-measure')
ax3.grid(True)

plt.legend(loc='upper left')
plt.tight_layout()

# Save the entire figure as a single PNG
plt.savefig(f"{METRICS_PATH}/{MODEL_NAME}_{VERSION}_vs_{comparison_model}.png", dpi=300, bbox_inches='tight')
ax