Skip to content

Commit

Permalink
Improve example test testPatcher and increase caching speed
Browse files Browse the repository at this point in the history
  • Loading branch information
oSumAtrIX committed Mar 20, 2022
1 parent 81e0220 commit 5d146c3
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 60 deletions.
7 changes: 3 additions & 4 deletions src/main/kotlin/net/revanced/patcher/Patcher.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.Jar2ASM
import java.io.InputStream
import java.io.OutputStream
import java.util.jar.JarOutputStream

/**
* The patcher. (docs WIP)
Expand All @@ -20,12 +19,12 @@ class Patcher(
private val input: InputStream,
signatures: Array<Signature>,
) {
val cache = Cache()
var cache: Cache
private val patches: MutableList<Patch> = mutableListOf()

init {
cache.classes.putAll(Jar2ASM.jar2asm(input))
cache.methods.putAll(MethodResolver(cache.classes.values.toList(), signatures).resolve())
val classes = Jar2ASM.jar2asm(input);
cache = Cache(classes, MethodResolver(classes, signatures).resolve())
}

fun addPatches(vararg patches: Patch) {
Expand Down
10 changes: 5 additions & 5 deletions src/main/kotlin/net/revanced/patcher/cache/Cache.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package net.revanced.patcher.cache

import org.objectweb.asm.tree.ClassNode

class Cache {
val classes: MutableMap<String, ClassNode> = mutableMapOf()
val methods: MethodMap = MethodMap()
}
class Cache (
val classes: List<ClassNode>,
val methods: MethodMap
)

class MethodMap : LinkedHashMap<String, PatchData>() {
override fun get(key: String): PatchData {
return super.get(key) ?: throw MethodNotFoundException("Method $key not found in method cache")
return super.get(key) ?: throw MethodNotFoundException("Method $key was not found in the method cache")
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/main/kotlin/net/revanced/patcher/cache/PatchData.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.MethodNode

data class PatchData(
val cls: ClassNode,
val declaringClass: ClassNode,
val method: MethodNode,
val sd: ScanData
val scanData: PatternScanData
)

data class ScanData(
data class PatternScanData(
val startIndex: Int,
val endIndex: Int
)
17 changes: 9 additions & 8 deletions src/main/kotlin/net/revanced/patcher/resolver/MethodResolver.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package net.revanced.patcher.resolver

import mu.KotlinLogging
import net.revanced.patcher.cache.MethodMap
import net.revanced.patcher.cache.PatchData
import net.revanced.patcher.cache.ScanData
import net.revanced.patcher.cache.PatternScanData
import net.revanced.patcher.signature.Signature
import net.revanced.patcher.util.ExtraTypes
import org.objectweb.asm.Type
Expand All @@ -13,13 +14,13 @@ import org.objectweb.asm.tree.MethodNode
private val logger = KotlinLogging.logger("MethodResolver")

internal class MethodResolver(private val classList: List<ClassNode>, private val signatures: Array<Signature>) {
fun resolve(): MutableMap<String, PatchData> {
val patchData = mutableMapOf<String, PatchData>()
fun resolve(): MethodMap {
val methodMap = MethodMap()

for ((classNode, methods) in classList) {
for (method in methods) {
for (signature in signatures) {
if (patchData.containsKey(signature.name)) { // method already found for this sig
if (methodMap.containsKey(signature.name)) { // method already found for this sig
logger.debug { "Sig ${signature.name} already found, skipping." }
continue
}
Expand All @@ -30,10 +31,10 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
continue
}
logger.debug { "Method for sig ${signature.name} found!" }
patchData[signature.name] = PatchData(
methodMap[signature.name] = PatchData(
classNode,
method,
ScanData(
PatternScanData(
// sadly we cannot create contracts for a data class, so we must assert
sr.startIndex!!,
sr.endIndex!!
Expand All @@ -44,11 +45,11 @@ internal class MethodResolver(private val classList: List<ClassNode>, private va
}

for (signature in signatures) {
if (patchData.containsKey(signature.name)) continue
if (methodMap.containsKey(signature.name)) continue
logger.error { "Could not find method for sig ${signature.name}!" }
}

return patchData
return methodMap
}

private fun cmp(method: MethodNode, signature: Signature): Pair<Boolean, ScanResult?> {
Expand Down
20 changes: 11 additions & 9 deletions src/main/kotlin/net/revanced/patcher/util/Jar2ASM.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,20 @@ import java.util.jar.JarInputStream
import java.util.jar.JarOutputStream

object Jar2ASM {
fun jar2asm(input: InputStream): Map<String, ClassNode> {
return buildMap {
val jar = JarInputStream(input)
fun jar2asm(input: InputStream) = mutableListOf<ClassNode>().apply {
val jar = JarInputStream(input)
while (true) {
val e = jar.nextJarEntry ?: break
if (e.name.endsWith(".class")) {
val classNode = ClassNode()
ClassReader(jar.readAllBytes()).accept(classNode, ClassReader.EXPAND_FRAMES)
this[e.name] = classNode
this.add(classNode)
}
jar.closeEntry()
}
}
}
fun asm2jar(input: InputStream, output: OutputStream, structure: Map<String, ClassNode>) {

fun asm2jar(input: InputStream, output: OutputStream, classes: List<ClassNode>) {
val jis = JarInputStream(input)
val jos = JarOutputStream(output)

Expand All @@ -33,10 +32,13 @@ object Jar2ASM {
val next = jis.nextJarEntry ?: break
val e = JarEntry(next) // clone it, to not modify the input (if possible)
jos.putNextEntry(e)
if (structure.containsKey(e.name)) {

val clazz = classes.singleOrNull {
clazz -> clazz.name == e.name
};
if (clazz != null) {
val cw = ClassWriter(ClassWriter.COMPUTE_MAXS or ClassWriter.COMPUTE_FRAMES)
val cn = structure[e.name]!!
cn.accept(cw)
clazz.accept(cw)
jos.write(cw.toByteArray())
} else {
jos.write(jis.readAllBytes())
Expand Down
71 changes: 40 additions & 31 deletions src/test/kotlin/net/revanced/patcher/PatcherTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ import net.revanced.patcher.util.ExtraTypes
import net.revanced.patcher.writer.ASMWriter.setAt
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.tree.LdcInsnNode
import java.io.ByteArrayOutputStream
import org.objectweb.asm.tree.*
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

internal class PatcherTest {
private val testSigs: Array<Signature> = arrayOf(
Expand Down Expand Up @@ -46,14 +43,24 @@ internal class PatcherTest {
patcher.addPatches(
Patch ("TestPatch") {
// Get the method from the resolver cache
val main = patcher.cache.methods["mainMethod"]
val mainMethod = patcher.cache.methods["mainMethod"]
// Get the instruction list
val insn = main.method.instructions!!
val instructions = mainMethod.method.instructions!!
// Let's modify it, so it prints "Hello, ReVanced!"
// Get the start index of our signature
// Get the start index of our opcode pattern
// This will be the index of the LDC instruction
val startIndex = main.sd.startIndex
insn.setAt(startIndex, LdcInsnNode("Hello, ReVanced!"))
val startIndex = mainMethod.scanData.startIndex
// Create a new Ldc node and replace the LDC instruction
val stringNode = LdcInsnNode("Hello, ReVanced!");
instructions.setAt(startIndex, stringNode)
// Now lets print our string to the console output
// First create a list of instructions
val printCode = InsnList();
printCode.add(LdcInsnNode("Hello, ReVanced!"))
printCode.add(MethodInsnNode(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V"))
// Add the list after the second instruction by our pattern
instructions.insert(instructions[startIndex + 1], printCode)

// Finally, tell the patcher that this patch was a success.
// You can also return PatchResultError with a message.
// If an exception is thrown inside this function,
Expand All @@ -62,38 +69,40 @@ internal class PatcherTest {
}
)

// Apply all patches loaded in the patcher
val result = patcher.applyPatches()
// You can check if an error occurred
for ((s, r) in result) {
if (r.isFailure) {
throw Exception("Patch $s failed", r.exceptionOrNull()!!)
}
}

// TODO Doesn't work, needs to be fixed.
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
// )
//
// out.close()
//val out = ByteArrayOutputStream()
//patcher.saveTo(out)
//assertTrue(
// // 8 is a random value, it's just weird if it's any lower than that
// out.size() > 8,
// "Output must be at least 8 bytes"
//)
//
//out.close()
testData.close()
}

// TODO Doesn't work, needs to be fixed.
// @Test
// fun noChanges() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
// val patcher = Patcher(testData, testSigs)
//
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertEquals(available, out.size())
//
// out.close()
// testData.close()
// }
//@Test
//fun noChanges() {
// val testData = PatcherTest::class.java.getResourceAsStream("/test1.jar")!!
// val available = testData.available()
// val patcher = Patcher(testData, testSigs)
//
// val out = ByteArrayOutputStream()
// patcher.saveTo(out)
// assertEquals(available, out.size())
//
// out.close()
// testData.close()
//}
}

0 comments on commit 5d146c3

Please sign in to comment.