Commit: Add tests

retr0reg committed Apr 5, 2023
1 parent 8bf19d9 commit 95c8e0e
Showing 3 changed files with 60 additions and 45 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -7,3 +7,5 @@ output/*
 logs/*
 generate_code_segments/vuln/outputs.txt
 generate_code_segments/nvuln/outputs.txt
+
+**/__pycache__
45 changes: 0 additions & 45 deletions test.py

This file was deleted.

58 changes: 58 additions & 0 deletions vaildity_test.py
@@ -0,0 +1,58 @@
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import sys
import random

# Load the fine-tuned model and tokenizer
model = BertForSequenceClassification.from_pretrained("./pwnbert_finetuned")
tokenizer = BertTokenizer.from_pretrained("./pwnbert_finetuned")

# def predict_vulnerability(model, tokenizer, code):
#     inputs = tokenizer(code, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
#     outputs = model(**inputs)
#     logits = outputs.logits
#     probabilities = torch.softmax(logits, dim=-1)
#     label = torch.argmax(probabilities).item()
#     return label

# Evaluate on randomly chosen generated samples; sys.argv[1] gives the sample count
def random_test():
    n = 0
    for i in range(int(sys.argv[1])):
        typies = ['nvuln', 'vuln']
        ram = random.choice(typies)
        with open(f'generate_code_segments/eval/{ram}/program_{i}.c', 'r') as f:
            text = f.read()
        # print(text)
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        outputs = model(**inputs)
        # Use the model output for prediction / evaluation
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)
        label = torch.argmax(probabilities).item()
        if typies[label] == ram:
            print("Correct")
            n += 1
        else:
            print("Wrong")

    print(f"Accuracy: {(n / int(sys.argv[1])) * 100}%")

# Classify a single snippet typed on stdin
def input_test():
    text = input("Input your code ::: ")
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    label = torch.argmax(probabilities).item()
    if label:
        print("\nVuln!")
    else:
        print("\nNot Vuln!")

if __name__ == "__main__":
    input_test()
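A minimal usage sketch, assuming a fine-tuned checkpoint has already been saved to ./pwnbert_finetuned (the path the script loads from):

$ python vaildity_test.py       # default entry point: input_test() prompts for a code snippet on stdin
$ python vaildity_test.py 20    # with __main__ switched to random_test(), scores 20 samples drawn
                                # from generate_code_segments/eval/nvuln/ and .../eval/vuln/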
