In [1]:
// Magics to load test framework and imports.
%load djl_imports

In [2]:
// NDManager is a factory class to produce n-dimensional arrays.
// Think of these as PyTorch tensors.
NDManager manager = NDManager.newBaseManager();

In [3]:
// Classical merge sort
public class MergeSort {
    public static int[] sort(int[] nums) {
        if (nums.length == 1) {
            return nums;
        }

        int mid = nums.length / 2;
        int[] left = new int[mid];
        int[] right = new int[nums.length - mid];

        for (int i = 0; i < left.length; i++) {
            left[i] = nums[i];
        }
        for (int i = 0; i < right.length; i++) {
            right[i] = nums[mid + i];
        }

        left = sort(left);
        right = sort(right);

        return merge(left, right);

        }

    public static int[] merge(int[] left, int[] right) {
        int[] result = new int[left.length + right.length];

        int leftInd = 0;
        int rightInd = 0;
        int resultInd = 0;

        while (leftInd < left.length || rightInd < right.length) {
            if (leftInd < left.length && rightInd < right.length) {
                if (left[leftInd] < right[rightInd]) {
                    result[resultInd++] = left[leftInd++];
                } else {
                    result[resultInd++] = right[rightInd++];
                }
            } else if (leftInd < left.length) {
                result[resultInd++] = left[leftInd++];
            } else if (rightInd < right.length) {
                result[resultInd++] = right[rightInd++];
            }
        }

        return result;
    }
}

In [4]:
int[] t1In       = new int[]{974, 100, 34, 700, -1, 555, 832};
int[] t1Expected = new int[]{-1, 34, 100, 555, 700, 832, 974};

int[] result = MergeSort.sort(t1In);
Assert.assertArrayEquals(result, t1Expected);


In [31]:
// Now try to soften it!
public class DifferentiableMergeSort {
    public static NDArray softMerge(NDArray left, NDArray right) {
        NDArray merged = manager.create(new float[(int)left.size() + (int)right.size()]);
        int i = 0;
        int j = 0;
        int k = 0;
        
        while (i < left.size() && j < right.size()) {
            NDArray scores = manager.create(new float[]{left.getFloat(i), right.getFloat(j)});
            NDArray weights = scores.softmax(/*axis=*/0).toType(left.getDataType(), false);

            float mergedElement = weights.getFloat(0) * left.getFloat(i) + weights.getFloat(1) * right.getFloat(j);
            merged.setScalar(new NDIndex(k++), mergedElement);

            i += (weights.getFloat(0) > 0.5) ? 1 : 0;
            j += (weights.getFloat(1) > 0.5) ? 1 : 0;
        }

        if (i < left.size()) {
            merged.set(new NDIndex(k + ":" + (k + left.size() - i)), left.get(new NDIndex(i + ":")));
        }
        if (j < right.size()) {
           merged.set(new NDIndex(k + ":" + (k + right.size() - j)), right.get(new NDIndex(j + ":")));
        }

        return merged;
    }

    public static NDArray sort(NDArray arr) {
        if (arr.size() <= 1) {
            return arr;
        }

        int mid = (int) arr.size() / 2;
        NDArray leftHalf = sort(arr.get(new NDIndex(":" + mid)));
        NDArray rightHalf = sort(arr.get(new NDIndex(mid + ":")));

        return softMerge(leftHalf, rightHalf);
    }
}


In [33]:
import org.apache.commons.lang3.tuple.Pair;
import java.util.ArrayList;
import ai.djl.training.optimizer.*;

float learningRate = 0.1f;
NDArray t1In       = manager.create(new float[]{10.f, 234.f, 1.f, 83.f, 5.f});
NDArray t1Expected = manager.create(new float[]{234.f, 83.f, 10.f, 5.f, 1.f});

// Only one test case for now.
ArrayList<Pair<NDArray, NDArray>> testCases = new ArrayList();
testCases.add(Pair.of(t1In, t1Expected));

// TODO: Try to introduce a bug in the merge sort implementation above. We may have to make everything a parameter.
// Then perhaps set up a neural net to generate test cases and tune it on the normal implementation.
for (Pair<NDArray, NDArray> pair : testCases) {
    NDArray test = pair.getLeft();
    NDArray result = manager.create(test.getShape());
    test.setRequiresGradient(true);

    // Training loop
    for (int i = 0; i < 10; i++) {
        result = DifferentiableMergeSort.sort(test);
        result.setRequiresGradient(true);
        // Create a new gradient collector - these need to be made each training iteration
        try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
            NDArray loss = Loss.l1Loss().evaluate(new NDList(/*expected=*/pair.getRight()), new NDList(result));

            // TODO: Verify that this is actually autodiffing through the merge sort.
            // It seems like we might be just changing the input.
            gc.backward(loss);

            print("LOSS: " + loss.getFloat());
        } catch(Exception e) {
            e.printStackTrace();
        }
        // Update the loss via result params (gradient descent)
        // This actually is increasing the loss since we are doing SGD
        // on the input itself... something is wrong.
        test.subi(result.getGradient().mul(learningRate)); 

        // Clear gradients otherwise memory creeps
        test.getGradient().close();
        result.getGradient().close();
    }
    print("=================");
    print("RESULT: ");
    print(result);
}

LOSS: 0.020709705
LOSS: 0.028915882
LOSS: 0.045237254
LOSS: 0.057474326
LOSS: 0.0741292
LOSS: 0.08631287
LOSS: 0.098490715
LOSS: 0.11066332
LOSS: 0.122830965
LOSS: 0.13868037
RESULT: 
ND: (5) cpu() float32 hasGradient
[233.84  ,  83.18  ,  10.1149,   4.7815,   0.98  ]

