Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CBC BC, salting and decryption routine to the AES implementation #456

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,26 @@ package com.tari.android.wallet.infrastructure.security.encryption

import java.io.InputStream
import java.io.OutputStream
import java.security.MessageDigest
import javax.crypto.Cipher
import javax.crypto.CipherOutputStream
import java.security.SecureRandom
import java.util.*
import javax.crypto.*
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.SecretKeySpec


interface SymmetricEncryptionAlgorithm {

fun encrypt(
key: CharArray,
inputStreamProvider: () -> InputStream,
outputStreamProvider: () -> OutputStream
password: CharArray,
sourceStreamProvider: () -> InputStream,
destinationStreamProvider: () -> OutputStream
)

fun decrypt(
password: CharArray,
sourceStreamProvider: () -> InputStream,
destinationStreamProvider: () -> OutputStream
)

companion object {
Expand All @@ -53,40 +62,108 @@ interface SymmetricEncryptionAlgorithm {

}

private class AES(
private val digest: MessageDigest = MessageDigest.getInstance(DIGEST_ALGORITHM)
) : SymmetricEncryptionAlgorithm {
private class AES(private val random: SecureRandom = SecureRandom()) :
SymmetricEncryptionAlgorithm {

// TODO add salt?
// TODO add padding and initialization vector?
override fun encrypt(
key: CharArray,
password: CharArray,
sourceStreamProvider: () -> InputStream,
destinationStreamProvider: () -> OutputStream
) {
val (salt, keySpec) = defineKeySpec(password)
val cbcInitializationVector = ByteArray(IV_SIZE).apply(random::nextBytes)
val cipher = Cipher.getInstance(CIPHER_AES_TRANSFORMATION)
.apply { init(Cipher.ENCRYPT_MODE, keySpec, IvParameterSpec(cbcInitializationVector)) }
createEncryptedFile(
sourceStreamProvider,
destinationStreamProvider,
cipher,
cbcInitializationVector,
salt
)
}

private fun defineKeySpec(key: CharArray): Pair<ByteArray, SecretKeySpec> {
val salt = ByteArray(SALT_SIZE).apply(random::nextBytes)
val spec = PBEKeySpec(key, salt, KEY_HASHING_TIMES, BIT_KEY_SIZE)
val secret: SecretKey =
SecretKeyFactory.getInstance(SECRET_KEYGEN_ALGORITHM).generateSecret(spec)
spec.clearPassword()
Arrays.fill(key, NULL_BYTE.toChar())
return Pair(salt, SecretKeySpec(secret.encoded, ALGORITHM_AES))
}

private fun createEncryptedFile(
inputStreamProvider: () -> InputStream,
outputStreamProvider: () -> OutputStream
outputStreamProvider: () -> OutputStream,
cipher: Cipher,
cbcInitializationVector: ByteArray,
salt: ByteArray
) {
val input = inputStreamProvider()
val output = outputStreamProvider()
val encryptionKey: ByteArray = key.toString().toByteArray(Charsets.UTF_8)
.let { digest.digest(it).copyOf(KEY_SIZE_128_BIT) }
val sks = SecretKeySpec(encryptionKey, ALGORITHM_AES)
val cipher =
Cipher.getInstance(CIPHER_AES_TRANSFORMATION).apply { init(Cipher.ENCRYPT_MODE, sks) }
output.write(salt)
output.write(cbcInitializationVector)
val cos = CipherOutputStream(output, cipher)
val buffer = ByteArray(8)
var bytesRead: Int
while (input.read(buffer).also { bytesRead = it } != -1) {
cos.write(buffer, 0, bytesRead)
}
val buffer = ByteArray(STREAM_BUFFER_SIZE)
generateSequence { input.read(buffer) }
.takeWhile { it != EOF }
.forEach { cos.write(buffer, NO_OFFSET, it) }
cos.flush()
cos.close()
input.close()
}

override fun decrypt(
password: CharArray,
sourceStreamProvider: () -> InputStream,
destinationStreamProvider: () -> OutputStream
) = sourceStreamProvider().use { readStream ->
destinationStreamProvider().use { writeStream ->
val salt: ByteArray = ByteArray(SALT_SIZE).apply { readStream.read(this) }
val iv: ByteArray = ByteArray(IV_SIZE).apply { readStream.read(this) }
val key = deriveKey(password, salt)
decipherFile(key, iv, readStream, writeStream)
}
}

private fun decipherFile(
sks: SecretKeySpec,
iv: ByteArray,
sourceStream: InputStream,
destinationStream: OutputStream
) {
val cipher = Cipher.getInstance(CIPHER_AES_TRANSFORMATION)
.apply { init(Cipher.DECRYPT_MODE, sks, IvParameterSpec(iv)) }
val input = CipherInputStream(sourceStream, cipher)
val buffer = ByteArray(STREAM_BUFFER_SIZE)
generateSequence { input.read(buffer) }
.takeWhile { it != EOF }
.forEach { destinationStream.write(buffer, 0, it) }
destinationStream.flush()
}

private fun deriveKey(password: CharArray, salt: ByteArray): SecretKeySpec {
val spec = PBEKeySpec(password, salt, KEY_HASHING_TIMES, BIT_KEY_SIZE)
val secret: SecretKey =
SecretKeyFactory.getInstance(SECRET_KEYGEN_ALGORITHM).generateSecret(spec)
return SecretKeySpec(secret.encoded, ALGORITHM_AES)
}

private companion object {
private const val DIGEST_ALGORITHM = "SHA-1"
private const val ALGORITHM_AES = "AES"
private const val CIPHER_AES_TRANSFORMATION = ALGORITHM_AES
private const val KEY_SIZE_128_BIT = 16
private const val CIPHER_AES_TRANSFORMATION = "AES/CBC/PKCS5Padding"
private const val SECRET_KEYGEN_ALGORITHM = "PBKDF2withHmacSHA1"
private const val KEY_SIZE = 32
private const val BYTE_SIZE = 8
private const val BIT_KEY_SIZE = KEY_SIZE * BYTE_SIZE
private const val STREAM_BUFFER_SIZE = 1024
private const val EOF = -1
private const val NO_OFFSET = 0
private const val NULL_BYTE: Byte = 0x00
private const val IV_SIZE = 16
private const val SALT_SIZE = 16
private const val KEY_HASHING_TIMES = 10000
}

}