In [2]:
import org.mikrograd.diff.Value
import org.mikrograd.diff.div
import org.mikrograd.diff.plus

val a = Value(-4.0)
val b = Value(2.0)
var c = a + b
var d = a * b + b.pow(3.0)
c += c + 1
c += 1.0 + c + (-a)
d += d * 2 + (b + a).relu()
d += d * 3.0 + (b - a).relu()
val e = c - d
val f = e.pow(2.0)
var g = f / 2
g += 10.0 / f
println("$g")
g.backward()
println("${a.grad}") // prints 138.8338, i.e. the numerical value of dg/da
println("${b.grad}") // prints 645.5773, i.e. the numerical value of dg/db


Value(data=24.70408163265306, grad=0.0)
138.83381924198252
645.5772594752186


In [3]:
fun trace(root: Value): Pair<Set<Value>, Set<Pair<Value, Value>>> {
    val nodes = mutableSetOf<Value>()
    val edges = mutableSetOf<Pair<Value, Value>>()

    fun build(v: Value) {
        if (v !in nodes) {
            nodes.add(v)
            for (child in v._children) {
                edges.add(child to v)
                build(child)
            }
        }
    }

    build(root)
    return nodes to edges
}

In [4]:
trace(g)

([Value(data=24.70408163265306, grad=1.0), Value(data=24.5, grad=1.0), Value(data=49.0, grad=0.4958350687213661), Value(data=-7.0, grad=-6.941690962099126), Value(data=-1.0, grad=-6.941690962099126), Value(data=-3.0, grad=-13.883381924198252), Value(data=-2.0, grad=-27.766763848396504), Value(data=-4.0, grad=138.83381924198252), Value(data=2.0, grad=645.5772594752186), Value(data=-1.0, grad=-13.883381924198252), Value(data=1.0, grad=-13.883381924198252), Value(data=2.0, grad=-6.941690962099126), Value(data=-2.0, grad=-6.941690962099126), Value(data=1.0, grad=-6.941690962099126), Value(data=4.0, grad=-6.941690962099126), Value(data=-1.0, grad=27.766763848396504), Value(data=-6.0, grad=-6.941690962099126), Value(data=6.0, grad=6.941690962099126), Value(data=0.0, grad=27.766763848396504), Value(data=0.0, grad=83.30029154518951), Value(data=-8.0, grad=83.30029154518951), Value(data=8.0, grad=83.30029154518951), Value(data=0.0, grad=27.766763848396504), Value(data=0.0, grad=27.7667638483965

In [7]:
import org.markup.dot.Graph
import org.markup.dot.graph


fun drawDot(root: Value, format: String = "svg", rankdir: String = "LR"): Graph {
    require(rankdir in listOf("LR", "TB"))

    val (nodes, edges) = trace(root)

    return graph {
        directed()

        // Add nodes
        for (n in nodes) {
            node(id = n.hashCode().toString()) {
                shape("record")
                // Using '|' to separate data and grad, only data in this case
                label("{ data %.4f | grad %.4f}".format(n.data, n.grad))
            }
            n.op.takeIf { it.isNotEmpty() }?.let {
                val opId = "\"${n.hashCode()}$it\""
                node(id = opId) {
                    label(it)
                }
                edge(from = opId, to = n.hashCode().toString())
            }
        }

        // Add edges
        for ((n1, n2) in edges) {
            val n2Op = n2.op.takeIf { it.isNotEmpty() }?.let { "\"${n2.hashCode()}$it\"" } ?: n2.hashCode().toString()
            edge(from = n1.hashCode().toString(), to = n2Op)
        }
    }
}


In [8]:
drawDot(g).render()

digraph {
912203870 [shape="record", label="{ data 24,7041 | grad 1,0000}"];
"912203870+" [label="+"];
879083749 [shape="record", label="{ data 24,5000 | grad 1,0000}"];
"879083749*" [label="*"];
11149588 [shape="record", label="{ data 49,0000 | grad 0,4958}"];
"11149588**2.0" [label="**2.0"];
690167637 [shape="record", label="{ data -7,0000 | grad -6,9417}"];
"690167637+" [label="+"];
1714522111 [shape="record", label="{ data -1,0000 | grad -6,9417}"];
"1714522111+" [label="+"];
2093551924 [shape="record", label="{ data -3,0000 | grad -13,8834}"];
"2093551924+" [label="+"];
758791480 [shape="record", label="{ data -2,0000 | grad -27,7668}"];
"758791480+" [label="+"];
475158141 [shape="record", label="{ data -4,0000 | grad 138,8338}"];
1983263863 [shape="record", label="{ data 2,0000 | grad 645,5773}"];
1111677748 [shape="record", label="{ data -1,0000 | grad -13,8834}"];
"1111677748+" [label="+"];
1047391673 [shape="record", label="{ data 1,0000 | grad -13,8834}"];
813872567 [shape="r