In [1]:
import argparse
import os
import time
import pathlib
import sys

from dataclasses import dataclass, field, asdict
from typing import Optional, List, Tuple, Union

import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoTokenizer, Trainer, DataCollatorWithPadding, TrainingArguments, \
    EvalPrediction, HfArgumentParser

# _project_root = str(pathlib.Path(__file__).resolve().parents[1])
_project_root = '.'
sys.path.insert(0, _project_root)

from dataset.classifier_dataset import ClassifierDataSet
from model.model import BertFNNClassifier
from model.focal_model import FocalBertFNNClassifier
from model.customized_loss import FocalLoss
from model.bert_fnn import BasicBertFNN

@dataclass
class ModelArguments:
    """Model / loss options parsed by ``HfArgumentParser`` (one CLI flag per field)."""

    # Backbone identifier and the local directory its weights are loaded from.
    base_model_name: str = field(default="Bert")
    base_model_path: Optional[str] = field(default='/data/chentianyu/libminer/input/bert-base-uncased_raw')
    # Presumably: skip training and only evaluate `to_evaluate_checkpoint` (not used in the visible cells).
    evaluate_only: bool = field(default=False)
    # Defaults to None, so the annotation has to be Optional.
    to_evaluate_checkpoint: Optional[str] = field(default=None)
    # Loss selection and its hyper-parameters; epsilon/alpha/gamma are passed to FocalLoss
    # in the setup cell.
    loss_fct: str = field(default="LabelSmoothingLoss")
    loss_epsilon: float = field(default=0)
    loss_alpha: float = field(default=0.95)
    loss_gamma: float = field(default=2)
    # A label count is an integer; with a float annotation argparse would turn
    # `--num_labels 2` into 2.0.
    num_labels: int = field(default=1)


  from .autonotebook import tqdm as notebook_tqdm


[2023-11-09 11:28:00,883] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# --- Argument setup -----------------------------------------------------------
# Parse ModelArguments from an explicit empty argv. Inside Jupyter, sys.argv holds the
# kernel launcher's own flags, which HfArgumentParser would reject. Passing args=[]
# avoids overwriting the process-wide sys.argv (the previous `sys.argv = [...]` hack).
hf_parser = HfArgumentParser(ModelArguments)
model_args = hf_parser.parse_args_into_dataclasses(args=[])[0]

# Hyper-parameters forwarded to the focal loss.
# NOTE(review): model_args.loss_fct ("LabelSmoothingLoss" by default) is not consulted;
# FocalLoss is always constructed here.
loss_args = {
    "epsilon": model_args.loss_epsilon,
    "alpha": model_args.loss_alpha,
    "gamma": model_args.loss_gamma,
}
loss_fct = FocalLoss(**loss_args)

# Keyword arguments for `<Model>.from_pretrained`. num_labels comes from ModelArguments
# instead of a second hardcoded copy. It is cast to int because the transformers config
# builds its label maps with range(num_labels).
from_pretrained_args = {
    "pretrained_model_name_or_path": model_args.base_model_path,
    "use_cache": True,
    "num_labels": int(model_args.num_labels),
    "torch_dtype": torch.float32,
}

# Data / output locations. parse_known_args([]) ignores the kernel's command line,
# so the defaults below are what the notebook uses.
parser = argparse.ArgumentParser(description="train bert fnn",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--data_url', type=str, help='the training data', default="/data/chentianyu/libminer/input/")
parser.add_argument('--train_url', type=str, help='the path model saved', default="/efs_data/chentianyu/VulLibMiner/")
parser.add_argument('--sep_token', type=str, help='sep token of lib corpus', default=" ")
parser.add_argument('--mask_rate', type=float, help='rate of mask lib corpus', default=0)

args, _ = parser.parse_known_args([])
input_path = args.data_url
output_path = args.train_url
bert_base_path = os.path.join(input_path, "bert-base-uncased_raw")


In [37]:
# Stock Hugging Face BertForSequenceClassification on the base BERT checkpoint.
# Uses `bert_base_path` from the setup cell instead of repeating the absolute path.
model = BertForSequenceClassification.from_pretrained(bert_base_path)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data/chentianyu/libminer/input/bert-base-uncased_raw/ and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [78]:
# Project BasicBertFNN on the base BERT checkpoint (`bert_base_path` from the setup cell
# instead of a duplicated absolute path); its classifier head is newly initialised.
model = BasicBertFNN.from_pretrained(bert_base_path)

Some weights of BasicBertFNN were not initialized from the model checkpoint at /data/chentianyu/libminer/input/bert-base-uncased_raw/ and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [95]:
import importlib

# Pick up edits to the project modules without restarting the kernel.
# The previous approach (deleting the sys.modules entries) had two problems:
# - it raised KeyError when the cell was re-run;
# - it left the already-imported classes bound to the stale code.
# reload() re-executes the modules, and the names below are rebound to the fresh classes.
_model_module = importlib.reload(importlib.import_module('model.model'))
_dataset_module = importlib.reload(importlib.import_module('dataset.classifier_dataset'))
BertFNNClassifier = _model_module.BertFNNClassifier
ClassifierDataSet = _dataset_module.ClassifierDataSet

In [3]:
# Custom BERT + FNN classifier built from the base checkpoint configured in
# `from_pretrained_args`; per the load warning, classifier.weight/bias start untrained.
model = BertFNNClassifier.from_pretrained(
    **from_pretrained_args
)

Some weights of BertFNNClassifier were not initialized from the model checkpoint at /data/chentianyu/libminer/input/bert-base-uncased_raw and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [41]:
# Which class is currently bound to `model`? (Useful after running cells out of order.)
model.__class__

transformers.models.bert.modeling_bert.BertForSequenceClassification

In [5]:
data_base_path = os.path.join(input_path, "dataset_v1_4")
# Train and validation splits share the same length settings.
# NOTE(review): the meaning of the four entries is defined by ClassifierDataSet — confirm there.
split_lengths = (128, 80, 128, 256)
train_data_set = ClassifierDataSet(
    os.path.join(data_base_path, "train.json"),
    args.sep_token,
    args.mask_rate,
    bert_base_path,
    split_lengths,
)
valid_data_set = ClassifierDataSet(
    os.path.join(data_base_path, "validate.json"),
    args.sep_token,
    args.mask_rate,
    bert_base_path,
    split_lengths,
)

In [None]:
# Test split; the first length entry is 2048 here instead of 128 for train/validate.
# NOTE(review): presumably that entry is the number of candidates per CVE — confirm in
# ClassifierDataSet.
test_data_set = ClassifierDataSet(
    os.path.join(data_base_path, "test.json"),
    args.sep_token,
    args.mask_rate,
    bert_base_path,
    (2048, 80, 128, 256),
)

In [22]:
data_base_path = os.path.join(input_path, "dataset_v1_4")
# Variant of the test split read from test_to_4096_new.json.
# NOTE(review): the file name mentions 4096 while the first length entry is 2048 —
# confirm this is intended.
test_data_set = ClassifierDataSet(
    os.path.join(data_base_path, "test_to_4096_new.json"),
    args.sep_token,
    args.mask_rate,
    bert_base_path,
    (2048, 80, 128, 256),
)

In [5]:
# Sum of the training labels, i.e. the number of positives if labels are 0/1.
sum(item['labels'] for item in train_data_set)

tensor(1973)

In [5]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, matthews_corrcoef, accuracy_score
import numpy as np

def sigmoid(z: np.ndarray):
    """Element-wise logistic function ``1 / (1 + exp(-z))``; maps logits to (0, 1) scores."""
    denominator = np.exp(-z) + 1
    return 1 / denominator

def np_divide(a, b):
    """Element-wise ``a / b`` that yields 0 wherever ``b == 0``.

    Used for F1 = 2PR / (P + R), where P + R is 0 at some thresholds.

    Args:
        a: numerator (array-like or scalar).
        b: denominator (array-like or scalar), broadcastable against ``a``.

    Returns:
        float64 ndarray with the broadcast shape of ``a`` and ``b``.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Allocate a float buffer of the broadcast shape. np.zeros_like(a) was an integer
    # buffer for integer inputs (np.divide then refuses to cast its float result into it),
    # and it had the wrong shape when b broadcast to a larger shape than a.
    out = np.zeros(np.broadcast(a, b).shape)
    return np.divide(a, b, out=out, where=(b != 0))

def f1_score(p, r):
    """Harmonic mean of precision ``p`` and recall ``r``; 0.0 when both are (near) zero."""
    if p + r < 1e-5:
        return 0.0
    return 2 * p * r / (p + r)

def modified_topk(predictions, labels, k):
    """Top-``k`` precision / recall / F1 for one CVE's candidate list.

    Candidates are ranked by prediction score only, and ties keep their original order.
    Sorting the ``(score, label)`` tuples directly would break ties by label and put
    positives first, which inflates the metric with ground-truth information.

    Args:
        predictions: score of each candidate of one CVE.
        labels: relevance label per candidate (assumed 0/1), aligned with ``predictions``.
        k: number of top-ranked candidates to inspect.

    Returns:
        ``(precision, recall, f1)``, or ``None`` when the list has no positive candidate.
        Precision is normalised by ``min(k, #positives)``, so a perfect ranking scores 1.
    """
    num_positive = sum(labels)
    if num_positive == 0:
        return None
    ranked = sorted(zip(predictions, labels), key=lambda pair: pair[0], reverse=True)
    hits = sum(label for _, label in ranked[:k])
    prec = hits / min(k, num_positive)
    rec = hits / num_positive
    return prec, rec, f1_score(prec, rec)

def compute_metrics(eval_pred: EvalPrediction, window_size: int = 128, k: int = 1):
    """Metrics callback for ``Trainer``: global ranking metrics plus per-CVE top-``k``.

    Args:
        eval_pred: logits (``predictions``) and binary targets (``label_ids``) for every
            (CVE, candidate) pair, flattened in dataset order.
        window_size: number of consecutive candidates that belong to one CVE. It must match
            the candidate count the dataset was built with (128 for train/validate here);
            otherwise candidates of different CVEs are grouped together.
        k: cut-off for the per-CVE top-k metrics, reported under ``*_{k}`` keys.

    Returns:
        dict with:
        - ``auc``;
        - the best F1 over all thresholds (``f1_best`` / ``f1_threshold``);
        - the mean per-CVE top-k precision/recall and their F1;
        - ``num_samples``.
        Also prints the three top-k numbers.
    """
    logits = eval_pred.predictions.reshape(-1)
    labels = eval_pred.label_ids.reshape(-1)
    scores = sigmoid(logits)

    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # P + R can be 0 at some thresholds; np_divide maps those points to F1 = 0.
    f1 = np_divide(2 * precision * recall, precision + recall)
    # The final (P=1, R=0) point has no threshold, so only search points that have one.
    f1_idx = int(np.argmax(f1[:-1]))
    f1_best = f1[f1_idx]

    auc = roc_auc_score(y_true=labels, y_score=scores)

    # Per-CVE top-k: one row of `window_size` candidates per CVE. Rows without a positive
    # candidate are skipped (modified_topk returns None for them). Ranking by logits is
    # equivalent to ranking by sigmoid scores.
    per_cve = [
        modified_topk(cve_logits, cve_labels, k)
        for cve_logits, cve_labels in zip(logits.reshape(-1, window_size),
                                          labels.reshape(-1, window_size))
    ]
    per_cve = [item for item in per_cve if item is not None]
    if per_cve:
        mean_p = sum(item[0] for item in per_cve) / len(per_cve)
        mean_r = sum(item[1] for item in per_cve) / len(per_cve)
        top_k_f1 = f1_score(mean_p, mean_r)
    else:
        # No row contains a positive: report NaN instead of raising ZeroDivisionError.
        mean_p = mean_r = top_k_f1 = float("nan")
    print(mean_p, mean_r, top_k_f1)

    return {"auc": auc,
            "f1_best": f1_best,
            "f1_threshold": thresholds[f1_idx],
            f"precision_{k}": mean_p,
            f"recall_{k}": mean_r,
            f"f1_{k}": top_k_f1,
            "num_samples": len(labels)}


In [6]:
# Hugging Face training configuration. Each run writes checkpoints to its own
# timestamped directory under `output_path`; logs go to a shared `logs` folder.
run_output_dir = os.path.join(
    output_path,
    time.strftime("BCE_Weight_%Y_%m%d_%H_%M", time.localtime()),
)
training_args = TrainingArguments(
    output_dir=run_output_dir,
    logging_dir=os.path.join(output_path, 'logs'),
    logging_steps=10,
    # Schedule: 20 epochs, evaluating and checkpointing once per epoch; eval_delay=1
    # postpones the first evaluation.
    num_train_epochs=20,
    evaluation_strategy="epoch",
    eval_delay=1,
    save_strategy="epoch",
    # Per-device batch sizes.
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # Optimiser.
    optim="adamw_torch",
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=500,
    # disable_tqdm=True,  # uncomment to silence progress bars
)

In [62]:
# Trainer for fine-tuning on the training split; the validation split is evaluated
# according to `training_args` (once per epoch).
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_set,
    eval_dataset=valid_data_set,
)

In [None]:
# Fine-tune `model` with `training_args`; a checkpoint is saved every epoch (save_strategy="epoch").
trainer.train()

The following cells are test scripts: they reload trained checkpoints and evaluate them on the test split.

In [None]:
# Load a fine-tuned checkpoint from the training output directory (`output_path`)
# instead of repeating the absolute prefix.
model = BertFNNClassifier.from_pretrained(
    os.path.join(output_path, 'BCE_Weight_2023_1019_10_16', 'checkpoint-2325'))

In [29]:
# Load a fine-tuned checkpoint from the training output directory (`output_path`)
# instead of repeating the absolute prefix.
model = BertFNNClassifier.from_pretrained(
    os.path.join(output_path, 'BCE_Weight_2023_1019_18_03', 'checkpoint-13020'))

In [26]:
# Load a fine-tuned checkpoint from the training output directory (`output_path`)
# instead of repeating the absolute prefix.
model = BertFNNClassifier.from_pretrained(
    os.path.join(output_path, 'BCE_Weight_2023_1019_18_03', 'checkpoint-13950'))

In [140]:
# Load a fine-tuned checkpoint from the training output directory (`output_path`)
# instead of repeating the absolute prefix.
# NOTE(review): this is a Focal_v1 run loaded into BertFNNClassifier rather than
# FocalBertFNNClassifier — confirm that is intended.
model = BertFNNClassifier.from_pretrained(
    os.path.join(output_path, 'Focal_v1_2023_1018_18_38', 'checkpoint-11625'))

In [103]:
# Load a fine-tuned checkpoint from the training output directory (`output_path`)
# instead of repeating the absolute prefix.
model = BertFNNClassifier.from_pretrained(
    os.path.join(output_path, 'BCE_Weight_2023_1019_13_04', 'checkpoint-11160'))

In [17]:
# Test rows per CVE. 633 is presumably the number of test CVEs (the prediction arrays
# reshape to 633 rows further down) — better taken from len(label_nums) than hardcoded.
len(test_data_set)/633

4096.0

In [None]:
%%time
trainer = Trainer(
    model=model,
    args=training_args,
    # train_dataset=train_data_set,
    # eval_dataset=valid_data_set,
    compute_metrics=compute_metrics
)
test_res_bce_4096 = trainer.predict(test_data_set)

In [10]:
%%time
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data_set,
    eval_dataset=valid_data_set,
    compute_metrics=compute_metrics
)
test_res_bce = trainer.predict(test_data_set)

0.7598039215686274 0.6368230625583565 0.6928989139926128
CPU times: user 19min 13s, sys: 6min 5s, total: 25min 19s
Wall time: 17min 22s


In [30]:
# Score `test_data_set` with the loaded `model`. The printed line is the top-1
# precision / recall / F1 reported by compute_metrics.
# NOTE(review): compute_metrics groups predictions in windows of 128, while the test-set
# cells build 2048 candidates per CVE — confirm the printed numbers are really per CVE.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_set,
    eval_dataset=valid_data_set,
)
test_res_bce = trainer.predict(test_data_set)

0.7636022514071295 0.6252434557312606 0.6875311030160254


In [43]:
# Score `test_data_set` with the loaded `model`. The printed line is the top-1
# precision / recall / F1 reported by compute_metrics.
# NOTE(review): compute_metrics groups predictions in windows of 128, while the test-set
# cells build 2048 candidates per CVE — confirm the printed numbers are really per CVE.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_set,
    eval_dataset=valid_data_set,
)
test_res_bce = trainer.predict(test_data_set)

0.7598039215686274 0.6368230625583565 0.6928989139926128


In [36]:
# Score `test_data_set` with the loaded `model`. The printed line is the top-1
# precision / recall / F1 reported by compute_metrics.
# NOTE(review): compute_metrics groups predictions in windows of 128, while the test-set
# cells build 2048 candidates per CVE — confirm the printed numbers are really per CVE.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data_set,
    eval_dataset=valid_data_set,
)
test_res_bce = trainer.predict(test_data_set)

0.7654784240150094 0.626145805414098 0.6888369635941427


In [20]:
def f1_score(p, r):
    """Harmonic mean of precision ``p`` and recall ``r``; 0.0 when both are (near) zero.

    Same definition as in the metrics cell, repeated so this cell is self-contained.
    """
    if p + r < 1e-5:
        return 0.0
    return 2 * p * r / (p + r)

def modified_metrics_simple(predictions, labels, label_num, k, max_candidates=512):
    """Top-``k`` precision / recall / F1 for one CVE against its full ground truth.

    Recall is measured against ``label_num`` (the length of the CVE's ``labels`` list in
    the raw test JSON) rather than against the positives present among the candidates,
    so ground-truth libraries that never reached the candidate list count as misses.

    Args:
        predictions: candidate scores for one CVE.
        labels: relevance label per candidate (assumed 0/1), aligned with ``predictions``.
        label_num: number of ground-truth libraries of the CVE.
        k: number of top-ranked candidates to inspect.
        max_candidates: only the first ``max_candidates`` candidates, in their given order
            (before re-ranking), are considered. The default 512 keeps the previous
            hardcoded behaviour.

    Returns:
        ``(precision, recall, f1)``, or ``None`` when ``label_num`` is 0.
    """
    if label_num == 0:
        return None
    window = list(zip(predictions, labels))[:max_candidates]
    # Rank by score only; ties keep their original order. Sorting the (score, label)
    # tuples would break ties by label and move positives up, inflating the result.
    window.sort(key=lambda pair: pair[0], reverse=True)
    hits = sum(label for _, label in window[:k])
    prec = hits / min(k, label_num)
    rec = hits / label_num
    return prec, rec, f1_score(prec, rec)

In [11]:
import json

# Raw records of the test_to_4096 variant. label_nums[i] is the length of the i-th
# record's 'labels' list. The file is decoded as UTF-8 explicitly, because open()'s
# default is the locale encoding.
with open(os.path.join(data_base_path, "test_to_4096.json"), 'r', encoding='utf-8') as f:
    test_set_raw = json.load(f)
label_nums = [len(vuln['labels']) for vuln in test_set_raw]

In [13]:
import json

# Raw test records. label_nums[i] is the length of the i-th record's 'labels' list.
# The file is decoded as UTF-8 explicitly, because open()'s default is the locale encoding.
with open(os.path.join(data_base_path, "test.json"), 'r', encoding='utf-8') as f:
    test_set_raw = json.load(f)
label_nums = [len(vuln['labels']) for vuln in test_set_raw]

In [79]:
k, ws = 1, 128
# NOTE(review): `test_res` is not produced by any cell in this notebook (the predict cells
# bind test_res_bce / test_res_bce_4096); this relies on stale kernel state.
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res.predictions.reshape(-1, ws)
label_rows = test_res.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.5929203539823009 0.4731879477454699 0.526330702092298


In [111]:
k, ws = 1, 128
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.647787610619469 0.49761004482243426 0.5628536437529342


In [142]:
k, ws = 1, 128
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.6460176991150443 0.49795138489828755 0.5624022754436476


In [151]:
k, ws = 1, 128
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.6460176991150443 0.5011498678312838 0.5644366069204285


In [27]:
k, ws = 1, 256
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.6407079646017699 0.4966092020074321 0.5595298837690412


In [59]:
k, ws = 3, 2048
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.7244837758112096 0.7049076734474966 0.7145616732524707


In [41]:
k, ws = 3, 256
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.7094395280235991 0.6900109182852546 0.6995903591878746


In [81]:
# ws was declared as 128 while the arrays were reshaped by 256. ws now matches the
# reshape that produced the recorded output.
k, ws = 1, 256
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.6707964601769911 0.5222413132590122 0.5872699625987495


In [15]:
k, ws = 1, 2048
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce.predictions.reshape(-1, ws)
label_rows = test_res_bce.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.6690265486725664 0.5204566524920508 0.5854631954608753


In [21]:
k, ws = 1, 4096
# One row of `ws` candidates per CVE. The rows must line up with label_nums, otherwise
# zip() silently truncates and pairs CVEs with other CVEs' candidates.
pred_rows = test_res_bce_4096.predictions.reshape(-1, ws)
label_rows = test_res_bce_4096.label_ids.reshape(-1, ws)
assert len(pred_rows) == len(label_nums), \
    f"{len(pred_rows)} prediction rows vs {len(label_nums)} CVEs: ws does not match the test set"
res = [modified_metrics_simple(predictions, labels, label_num, k)
       for predictions, labels, label_num in zip(pred_rows, label_rows, label_nums)]
res = [item for item in res if item is not None]  # CVEs without ground truth are skipped
ps, rs, fs = [v[0] for v in res], [v[1] for v in res], [v[2] for v in res]
mean_p, mean_r = sum(ps) / len(ps), sum(rs) / len(rs)
print(mean_p, mean_r, f1_score(mean_p, mean_r))

0.5681415929203539 0.4400879209286291 0.4959828074599471


In [None]:
# Peek at the 'top_k' candidate list stored in the first raw test record.
test_set_raw[0]['top_k']

In [10]:
def combine_output(test_set_raw, test_res_bce, num):
    """Attach re-ranking scores to the raw test records for export.

    Args:
        test_set_raw: raw test records. Each record has 'cve_id', 'desc', 'labels' and
            'top_k', and every 'top_k' entry has a 'lib_name'.
        test_res_bce: prediction output exposing a flat ``predictions`` array, ordered
            record by record with ``num`` scores per record.
        num: number of scores per record.

    Returns:
        One dict per record with 'cve_id', 'desc', 'labels' and 'top_k'. 'top_k' pairs the
        i-th raw candidate with the i-th score as ``{'lib_name', 're_rank_score'}``; the
        shorter of the two lists decides its length.

    Raises:
        ValueError: if the number of score rows differs from the number of records. That
            means ``num`` does not match the dataset the predictions came from; previously
            the mismatch was silently truncated.
    """
    score_rows = test_res_bce.predictions.reshape(-1, num)
    if len(score_rows) != len(test_set_raw):
        raise ValueError(
            f"{len(score_rows)} score rows for {len(test_set_raw)} records; "
            f"num={num} does not match the predictions")
    return [
        {'cve_id': vuln['cve_id'], 'desc': vuln['desc'], 'labels': vuln['labels'],
         'top_k': [{'lib_name': lib['lib_name'], 're_rank_score': float(score)}
                   for lib, score in zip(vuln['top_k'], scores)]}
        for vuln, scores in zip(test_set_raw, score_rows)
    ]

In [12]:
# Merge the 4096-wide prediction run with the raw test records for export.
output = combine_output(test_set_raw, test_res_bce_4096, num=4096)

In [61]:
# Merge the 2048-wide prediction run with the raw test records for export.
output = combine_output(test_set_raw, test_res_bce, num=2048)

In [14]:
# Export the re-ranked test results as JSON.
# The single-argument os.path.join was a no-op and has been dropped. The target
# directory is created first so the dump does not fail on a fresh machine.
export_path = '/data/chentianyu/libminer/output/eval_result_v1_bce_weight_4096/test.json'
os.makedirs(os.path.dirname(export_path), exist_ok=True)
with open(export_path, 'w') as f:
    json.dump(output, f)

In [94]:
# Recall ceiling of the candidate windows. For every CVE with ground truth, this is the
# sum of the window's label ids divided by the number of ground-truth libraries (the
# share of true libraries present in the window, if labels are 0/1).
# NOTE(review): `test_res` is not created by any visible cell — confirm which run it is.
rec_pre = [
    sum(window_labels) / num
    for window_labels, num in zip(test_res.label_ids.reshape(-1, 128), label_nums)
    if num > 0
]
sum(rec_pre) / len(rec_pre)

0.7590006895759108

In [89]:
# NOTE(review): ad-hoc ratio; what 717 and 1003 count is not recorded in the notebook —
# document the source of these numbers or remove the cell.
717/1003

0.7148554336989033

In [55]:
# Sum of the first 512 label ids (the number of positives if labels are 0/1).
# NOTE(review): `test_res` is not created by any visible cell.
sum(test_res.label_ids[:512])

3

In [44]:
# Shape check: rows of 128 candidate scores (the output shows 633 rows).
test_res.predictions.reshape(-1, 128).shape

(633, 128)

In [None]:
# Print only the first (scores, labels) row pair and the sum of its labels, then stop.
for preds_row, labels_row in zip(test_res.predictions.reshape(-1, 128),
                                 test_res.label_ids.reshape(-1, 128)):
    print((preds_row, labels_row), sum(labels_row))
    break

In [None]:
torch.nn.BCEWithLogitsLoss??

In [28]:
# Scratch check of CrossEntropyLoss on two confidently scored examples.
# The loss gets its own name so it does not clobber `loss_fct` (the FocalLoss built in
# the setup cell). The former .view(-1, 2) / .view(-1) calls were no-ops and are gone.
ce_loss_fct = torch.nn.CrossEntropyLoss()
ce_logits = torch.tensor([[-500.0, 103.0], [-100.0, 100.0]])  # (batch=2, classes=2)
ce_targets = torch.tensor([0, 1])
loss = ce_loss_fct(ce_logits, ce_targets)

In [None]:
torch.nn.functional.binary_cross_entropy_with_logits??

In [None]:
BertForSequenceClassification??

In [131]:
# Complement of a 0/1 label tensor, i.e. 1 - y.
torch.tensor([1, 1, 0]).neg() + 1

tensor([0, 0, 1])