Skip to content

Commit

Permalink
Add new Java rule to check for a weak key (#419)
Browse files Browse the repository at this point in the history
Checks for RSA and DSA keys where the lengths are too short.

Signed-off-by: Eric Brown <eric.brown@securesauce.dev>
  • Loading branch information
ericwb committed Apr 10, 2024
1 parent e8fe980 commit 79449c2
Show file tree
Hide file tree
Showing 13 changed files with 359 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
|----|------|-------------|
| JAV001 | [javax.crypto — weak cipher](rules/java/stdlib/javax-crypto-weak-cipher.md) | Use of a Broken or Risky Cryptographic Algorithm in `javax.crypto` Package |
| JAV002 | [java.security — weak hash](rules/java/stdlib/java-security-weak-hash.md) | Reversible One Way Hash in `java.security` Package |
| JAV003 | [java.security — weak key](rules/java/stdlib/java-security-weak-key.md) | Inadequate Encryption Strength Using Weak Keys in `java.security` Package |

## Python Standard Library

Expand Down
3 changes: 3 additions & 0 deletions docs/rules/java/stdlib/java-security-weak-key.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# java.security — weak key

::: precli.rules.java.stdlib.java_security_weak_key
2 changes: 1 addition & 1 deletion precli/parsers/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def visit_call_expression(self, nodes: list[Node]):
args=func_call_args,
)

self.analyze_node(self.context["node"].type, call=call)
self.analyze_node(tokens.CALL_EXPRESSION, call=call)

if call.var_node is not None:
symbol = self.current_symtab.get(call.var_node.text.decode())
Expand Down
145 changes: 119 additions & 26 deletions precli/parsers/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,46 +66,125 @@ def _get_func_ident(self, node: Node) -> Node:
if node.type == tokens.IDENTIFIER:
return node

def visit_method_invocation(self, nodes: list[Node]):
# TODO: field_access "." identifier argument_list
# or
# identifier "." identifier argument_list
if nodes[0] != tokens.IDENTIFIER:
def visit_local_variable_declaration(self, nodes: list[Node]):
# type_identifier variable_declarator
if nodes[0].type != tokens.TYPE_IDENTIFIER:
return

if nodes[2] != tokens.IDENTIFIER:
if nodes[1].type != tokens.VARIABLE_DECLARATOR:
return

obj_name = self.resolve(nodes[0])
method = nodes[2].text.decode()
if None in (obj_name, method):
var_nodes = nodes[1].named_children

if (
len(var_nodes) > 1
and var_nodes[0].type == tokens.IDENTIFIER
and var_nodes[1].type
in (
tokens.METHOD_INVOCATION,
tokens.ATTRIBUTE,
tokens.IDENTIFIER,
tokens.STRING_LITERAL,
tokens.CHARACTER_LITERAL,
tokens.DECIMAL_INTEGER_LITERAL,
tokens.HEX_INTEGER_LITERAL,
tokens.OCTAL_INTEGER_LITERAL,
tokens.DECIMAL_FLOATING_POINT_LITERAL,
tokens.BINARY_INTEGER_LITERAL,
tokens.TRUE,
tokens.FALSE,
tokens.NULL_LITERAL,
)
):
left_hand = self.resolve(var_nodes[0], default=var_nodes[0])
right_hand = self.resolve(var_nodes[1], default=var_nodes[1])

# This is in case a variable is reassigned
self.current_symtab.put(
var_nodes[0].text.decode(), tokens.IDENTIFIER, right_hand
)

# This is to help full resolution of an attribute/call.
# This results in two entries in the symtab for this assignment.
self.current_symtab.put(left_hand, tokens.IDENTIFIER, right_hand)

if var_nodes[1].type == tokens.METHOD_INVOCATION:
meth_invoke = var_nodes[1]
if (
meth_invoke.children[1].type == "."
and meth_invoke.children[2].type == tokens.IDENTIFIER
):
# (field_access | identifier) "." identifier argument_list
obj_node = meth_invoke.children[0]
method_node = meth_invoke.children[2]
else:
# identifier argument_list
obj_node = meth_invoke.children[0]
method_node = meth_invoke.children[0]

arg_list_node = self.child_by_type(
meth_invoke, tokens.ARGUMENT_LIST
)
call_args = self.get_func_args(arg_list_node)

call = Call(
node=var_nodes[1],
name=right_hand,
name_qual=right_hand,
# func_node=func_node, # no equivalent for Java
var_node=obj_node,
ident_node=method_node,
arg_list_node=arg_list_node,
args=call_args,
)
symbol = self.current_symtab.get(left_hand)
symbol.push_call(call)

self.visit(nodes)

def visit_method_invocation(self, nodes: list[Node]):
meth_invoke = self.context["node"]
if nodes[0].type not in (tokens.FIELD_ACCESS, tokens.IDENTIFIER):
return

func_call_qual = ".".join([obj_name, method])
func_call_args = self.get_func_args(nodes[3])
if nodes[1].type == "." and nodes[2].type == tokens.IDENTIFIER:
# (field_access | identifier) "." identifier argument_list
obj_node = nodes[0]
method_node = nodes[2]

obj_name = self.resolve(obj_node)
method = method_node.text.decode()
if None in (obj_name, method):
return

# (field_access | identifier) "." identifier argument_list
func_call_qual = ".".join([obj_name, method])
else:
# identifier argument_list
obj_node = nodes[0]
method_node = nodes[0]
method_name = self.resolve(method_node)
if method_name is None:
return
func_call_qual = method_name

arg_list_node = self.child_by_type(meth_invoke, tokens.ARGUMENT_LIST)
func_call_args = self.get_func_args(arg_list_node)

call = Call(
node=self.context["node"],
node=meth_invoke,
name=func_call_qual,
name_qual=func_call_qual,
# func_node=func_node, # no equivalent for Java
var_node=nodes[0],
ident_node=nodes[2],
arg_list_node=nodes[3],
var_node=obj_node,
ident_node=method_node,
arg_list_node=arg_list_node,
args=func_call_args,
)

self.analyze_node(self.context["node"].type, call=call)
self.analyze_node(tokens.METHOD_INVOCATION, call=call)

if call.var_node is not None:
symbol = self.current_symtab.get(call.var_node.text.decode())
if symbol is not None and symbol.type == tokens.IDENTIFIER:
symbol.push_call(call)
else:
# TODO: why is var_node None?
pass
symbol = self.current_symtab.get(call.var_node.text.decode())
if symbol is not None and symbol.type == tokens.IDENTIFIER:
symbol.push_call(call)

self.visit(nodes)

Expand Down Expand Up @@ -139,10 +218,24 @@ def resolve(self, node: Node, default=None):
try:
match node.type:
case tokens.METHOD_INVOCATION:
nodetext = node.children[0].text.decode()
if (
node.children[1].type == "."
and node.children[2].type == tokens.IDENTIFIER
):
# (field_access | identifier) "." identifier
# argument_list
part1 = node.children[0].text.decode()
part2 = node.children[2].text.decode()
nodetext = ".".join([part1, part2])
else:
# identifier argument_list
nodetext = node.children[0].text.decode()
symbol = self.get_qual_name(node.children[0])
if symbol is not None:
value = self.join_symbol(nodetext, symbol)
case tokens.FIELD_ACCESS:
# TODO
pass
case tokens.IDENTIFIER:
symbol = self.get_qual_name(node)
if symbol is not None:
Expand Down
4 changes: 2 additions & 2 deletions precli/parsers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def visit_call(self, nodes: list[Node]):
self.current_symtab.remove(identifier)
self.current_symtab.put(identifier, tokens.IMPORT, module)

self.analyze_node(self.context["node"].type, call=call)
self.analyze_node(tokens.CALL, call=call)

if call.var_node is not None:
symbol = self.current_symtab.get(call.var_node.text.decode())
Expand All @@ -219,7 +219,7 @@ def visit_call(self, nodes: list[Node]):
self.visit(nodes)

def visit_assert(self, nodes: list[Node]):
self.analyze_node(self.context["node"].type)
self.analyze_node(tokens.ASSERT)
self.visit(nodes)

def visit_with_item(self, nodes: list[Node]):
Expand Down
3 changes: 3 additions & 0 deletions precli/parsers/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RELATIVE_IMPORT = "relative_import"
WILDCARD_IMPORT = "wildcard_import"
CALL = "call"
CALL_EXPRESSION = "call_expression"
METHOD_INVOCATION = "method_invocation"
ARGUMENT_LIST = "argument_list"
KEYWORD_ARGUMENT = "keyword_argument"
Expand All @@ -21,12 +22,14 @@
TYPE_IDENTIFIER = "type_identifier"
VARIABLE_DECLARATOR = "variable_declarator"
COMPARISON_OPERATOR = "comparison_operator"
ASSERT = "assert"
AS_PATTERN = "as_pattern"
SELECTOR_EXPRESSION = "selector_expression"
FIELD_IDENTIFIER = "field_identifier"
IDENTIFIER = "identifier"
TYPE = "type"
ATTRIBUTE = "attribute"
FIELD_ACCESS = "field_access"
DICTIONARY = "dictionary"
LIST = "list"
TUPLE = "tuple"
Expand Down
136 changes: 136 additions & 0 deletions precli/rules/java/stdlib/java_security_weak_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2024 Secure Saurce LLC
r"""
# Inadequate Encryption Strength Using Weak Keys in `java.security` Package
Using weak key sizes for cryptographic algorithms like RSA and DSA can
compromise the security of your encryption and digital signatures. Here's a
brief overview of the risks associated with weak key sizes for these
algorithms:
RSA (Rivest-Shamir-Adleman):
RSA is widely used for both encryption and digital signatures. Weak key sizes
in RSA can be vulnerable to factorization attacks, such as the famous RSA-129
challenge, which was factored in 1994 after 17 years of effort. Using small
key sizes makes it easier for attackers to factor the modulus and recover
the private key.
It's generally recommended to use RSA key sizes of 2048 bits or more for
security in the present day, with 3072 bits or higher being increasingly
preferred for long-term security.
DSA (Digital Signature Algorithm):
DSA is used for digital signatures and relies on the discrete logarithm
problem. Using weak key sizes in DSA can make it susceptible to attacks that
involve solving the discrete logarithm problem, like the GNFS (General
Number Field Sieve) algorithm.
For DSA, key sizes of 2048 bits or more are recommended for modern security.
Note that DSA is not as commonly used as RSA or ECC for new applications, and
ECDSA (Elliptic Curve Digital Signature Algorithm) is often preferred due to
its efficiency and strong security properties.
## Example
```java
import java.security.*;
public class KeyPairGeneratorRSA {
public static void main(String[] args) {
try {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(512);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
} catch (NoSuchAlgorithmException e) {
System.err.println("RSA algorithm not available.");
}
}
}
```
## Remediation
Its recommended to increase the key size to at least 2048 for DSA and RSA
algorithms.
```java
import java.security.*;
public class KeyPairGeneratorRSA {
public static void main(String[] args) {
try {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
KeyPair keyPair = keyPairGenerator.generateKeyPair();
} catch (NoSuchAlgorithmException e) {
System.err.println("RSA algorithm not available.");
}
}
}
```
## See also
- [KeyPairGenerator (Java SE & JDK)](https://docs.oracle.com/en/java/javase/21/docs/api/java.base/java/security/KeyPairGenerator.html)
- [CWE-326: Inadequate Encryption Strength](https://cwe.mitre.org/data/definitions/326.html)
_New in version 0.5.0_
""" # noqa: E501
from precli.core.call import Call
from precli.core.level import Level
from precli.core.location import Location
from precli.core.result import Result
from precli.rules import Rule


class KeyPairGeneratorWeakKey(Rule):
def __init__(self, id: str):
super().__init__(
id=id,
name="inadequate_encryption_strength",
description=__doc__,
cwe_id=326,
message="Using '{0}' key sizes less than '{1}' bits is considered "
"vulnerable to attacks.",
wildcards={
"java.security.*": [
"KeyPairGenerator",
],
},
)

def analyze_method_invocation(self, context: dict, call: Call) -> Result:
if call.name_qualified not in [
"java.security.KeyPairGenerator.getInstance.initialize"
]:
return

argument = call.get_argument(position=0)
keysize = argument.value

symbol = context["symtab"].get(call.var_node.text.decode())
if "getInstance" not in [
x.identifier_node.text.decode() for x in symbol.call_history
]:
return

get_instance_call = symbol.call_history[0]
algorithm = get_instance_call.get_argument(position=0).value_str
if algorithm is None or algorithm.upper() not in ("DSA", "RSA"):
return

if isinstance(keysize, int) and keysize < 2048:
fixes = Rule.get_fixes(
context=context,
deleted_location=Location(node=argument.node),
description="Use a minimum key size of 2048 for RSA keys.",
inserted_content="2048",
)

return Result(
rule_id=self.id,
location=Location(node=argument.node),
level=Level.ERROR if keysize <= 1024 else Level.WARNING,
message=self.message.format("RSA", 2048),
fixes=fixes,
)
2 changes: 2 additions & 0 deletions precli/rules/java/stdlib/javax_crypto_weak_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def analyze_method_invocation(self, context: dict, call: Call) -> Result:

argument = call.get_argument(position=0)
transformation = argument.value_str
if transformation is None:
return

# DES/CBC/PKCS5Padding
cipher, *mode_padding = transformation.split("/")
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ precli.rules.java =
# precli/rules/java/stdlib/java_security_weak_hash.py
JAV002 = precli.rules.java.stdlib.java_security_weak_hash:MessageDigestWeakHash

# precli/rules/java/stdlib/java_security_weak_key.py
JAV003 = precli.rules.java.stdlib.java_security_weak_key:KeyPairGeneratorWeakKey

precli.rules.python =
# precli/rules/python/stdlib/assert.py
PY001 = precli.rules.python.stdlib.assert:Assert
Expand Down
Loading

0 comments on commit 79449c2

Please sign in to comment.