In [1]:
import pandas as pd
import torch
from transformers import BertTokenizer
from torch.utils.data import DataLoader, SequentialSampler
from pys.functions import CustomBertModel, create_dataset, test
from pys.data import filtered_labels_at_least_5_list
from pys.params import batch_size

In [2]:
# batch_size is used exactly as imported from pys.params; nothing to rebind here.
labels_list = filtered_labels_at_least_5_list

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Two-way lookup between label values and their integer position in labels_list.
label_mapping = dict((label, position) for position, label in enumerate(labels_list))
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))

In [5]:
# Input data and the file the predictions are written to.
cve_data_csv = "../csv/cve_data.csv"
output_csv = "../csv/cve_data_predictions.csv"

# Rows whose 'Artifact Id' is 'd3f:System Software' are excluded from evaluation.
# The explicit .copy() makes the filtered frame independent of the frame read
# from disk, so adding columns to it later does not raise SettingWithCopyWarning.
cve_data_df = pd.read_csv(cve_data_csv)
cve_data_df = cve_data_df[cve_data_df['Artifact Id'] != 'd3f:System Software'].copy()

In [8]:
def cve_test(model_path):
    """Run a saved model over the CVE data and write its predictions to CSV.

    Builds a ``CustomBertModel`` with one output per entry of
    ``filtered_labels_at_least_5_list``, loads the state dict stored at
    ``model_path`` and passes the model to ``test`` together with a loader that
    walks ``cve_data_df`` in row order. The returned predictions are translated
    through ``reverse_label_mapping`` and stored on the module-level
    ``cve_data_df`` in a column named ``"<model_path> prediction"``. That
    column is then written, next to ``'Artifact Id'``, to ``output_csv``.

    Args:
        model_path: Path to a checkpoint file containing the model state dict.

    Returns:
        None. The results are the new column on ``cve_data_df`` and the CSV
        file at ``output_csv`` (overwritten on every call).
    """
    _model = CustomBertModel(num_labels=len(filtered_labels_at_least_5_list))

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written from CUDA tensors still loads on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    _model.load_state_dict(state_dict)
    _model.to(device)
    _model.eval()

    dataset = create_dataset(cve_data_df, tokenizer, label_mapping)
    # SequentialSampler keeps the predictions in the same order as the rows of
    # cve_data_df, which the positional column assignment below relies on.
    data_loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)

    # Only the first value returned by test() is used here.
    predictions, _ = test(_model, data_loader, device)
    predictions = [reverse_label_mapping[pred] for pred in predictions]

    prediction_column_name = f"{model_path} prediction"

    # NOTE: this adds a column to the shared module-level frame.
    cve_data_df[prediction_column_name] = predictions

    # Save the predictions to CSV
    cve_data_df[['Artifact Id', prediction_column_name]].to_csv(output_csv, index=False)
    print(f"Predictions saved to {output_csv}")

In [10]:
# Checkpoint evaluated in this run; cve_test writes its predictions to output_csv.
model_path = "../models/model_2024-12-16_14-53-35.pth"
cve_test(model_path=model_path)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

F1 Score (Weighted): 0.2737
Accuracy: 19.44%
Predictions saved to cve_data_predictions.csv
