<a href="https://colab.research.google.com/github/timsetsfire/wandb-examples/blob/main/colab/wandb_with_scala_via_wandb_java_client.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# WandB with Scala

This notebook is meant to be run in google colab

In [None]:
%%sh
curl https://pjreddie.com/media/files/mnist_train.csv > mnist_train.csv
curl https://pjreddie.com/media/files/mnist_test.csv > mnist_test.csv
pip install wandb[grpc]==0.10.32 -q --upgrade
git clone https://github.com/wandb/client-java.git
wandb login WANDB_API_KEY
apt-get install maven &> /dev/null
cd client-java 

In [None]:
%%sh
cd client-java
make install
make build

## Install the Scala Kernel
If you get a "scala" kernel not recognized warning when loading up the notebook for the first time, start by running the two cells below. Once you are done **reload the page** to load the notebook in the installed Scala kernel.

In [None]:
%%shell
SCALA_VERSION=2.12.8 ALMOND_VERSION=0.3.0+16-548dc10f-SNAPSHOT
curl -Lo coursier https://git.io/coursier-cli
chmod +x coursier
./coursier bootstrap \
    -r jitpack -r sonatype:snapshots \
    -i user -I user:sh.almond:scala-kernel-api_$SCALA_VERSION:$ALMOND_VERSION \
    sh.almond:scala-kernel_$SCALA_VERSION:$ALMOND_VERSION \
    --sources --default=true \
    -o almond-snapshot --embed-files=false &> /dev/null
rm coursier
./almond-snapshot --install --global --force &> /dev/null
rm almond-snapshot

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 42577  100 42577    0     0  67905      0 --:--:-- --:--:-- --:--:-- 67905




In [None]:
%%shell
echo "{
  \"language\" : \"scala\",
  \"display_name\" : \"Scala\",
  \"argv\" : [
    \"bash\",
    \"-c\",
    \"env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libpython3.6m.so:\$LD_PRELOAD java -jar /usr/local/share/jupyter/kernels/scala/launcher.jar --connection-file {connection_file}\"
  ]
}" > /usr/local/share/jupyter/kernels/scala/kernel.json



## Refresh your browser at this point

In [1]:
val clientJar = "/content/client-java/target/client-ng-java-1.0-SNAPSHOT-jar-with-dependencies.jar"
val path = java.nio.file.FileSystems.getDefault().getPath(clientJar)
val x = ammonite.ops.Path(path)
interp.load.cp(x)

[36mclientJar[39m: [32mString[39m = [32m"/content/client-java/target/client-ng-java-1.0-SNAPSHOT-jar-with-dependencies.jar"[39m
[36mpath[39m: [32mjava[39m.[32mnio[39m.[32mfile[39m.[32mPath[39m = /content/client-java/target/client-ng-java-1.0-SNAPSHOT-jar-with-dependencies.jar
[36mx[39m: [32mos[39m.[32mPath[39m = root/[32m'content[39m/[32m"client-java"[39m/[32m'target[39m/[32m"client-ng-java-1.0-SNAPSHOT-jar-with-dependencies.jar"[39m

## Download Necessary Libraries

Using import as below we download all libraries necessary to create the MNist Model.  This may take a while.

In [None]:
import $ivy.`org.nd4j:nd4j-native-platform:1.0.0-M2`
import $ivy.`org.deeplearning4j:deeplearning4j-datasets:1.0.0-M2`
import $ivy.`org.deeplearning4j:deeplearning4j-core:1.0.0-M2`
import $ivy.`org.nd4j:nd4s_2.11:1.0.0-beta7`
import $ivy.`org.nd4j:nd4j-api:1.0.0-M2`
import $ivy.`org.nd4j:nd4j-native-platform:1.0.0-M2`

In [None]:
  import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  import org.deeplearning4j.nn.conf.layers.DenseLayer;
  import org.deeplearning4j.nn.conf.layers.OutputLayer;
  import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  import org.deeplearning4j.nn.weights.WeightInit;
  import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  import org.nd4j.evaluation.classification.Evaluation;
  import org.nd4j.linalg.activations.Activation;
  import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  import org.nd4j.linalg.learning.config.{Nadam, Adam};
  import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  import org.nd4j.linalg.dataset.DataSet;
  import org.nd4j.linalg.factory.Nd4j
  import java.io.{File, PrintWriter}
  import org.datavec.api.records.reader.RecordReader;
  import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
  import org.datavec.api.split.FileSplit;
  import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;

In [5]:
def loadData(path: String, batchSize: Int, labelIndex: Int, numClasses: Int): DataSetIterator = {
    val file = new java.io.File(path)
    val reader: RecordReader = new CSVRecordReader()
    reader.initialize(new FileSplit( file ))
    new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numClasses); //reader, batch size, label index, num of label classes
}
{{
  val trainIter = loadData("mnist_train.csv", 256, 0, 10)
  val testIter = loadData("mnist_test.csv", 256, 0, 10)
}}

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.


defined [32mfunction[39m [36mloadData[39m

In [None]:
def loadData2(path: String, labelIndex: Int, numClasses: Int) = {
  val data = Nd4j.readNumpy(path, ",")
  val labels = data.getColumn(labelIndex)
  val Array(m,n) = data.shape
  val features = data.getColumns(  (0 until n.toInt).filter( _ != labelIndex):_*)
  val oneHotLabels = Nd4j.zeros(m,numClasses) 
  labels.dup.data.asInt.zipWithIndex.foreach { 
    case (k,v) => oneHotLabels.putScalar(v,k,1.0)
  }
  val dataset = new DataSet(features, oneHotLabels)
  dataset
}
// wrap in braces to suppress all stdout

val train = loadData2("mnist_train.csv", 0, 10)
val test = loadData2("mnist_test.csv", 0, 10)

In [15]:
import scala.collection.JavaConverters._
import com.wandb.client._
import org.json.JSONObject;


val rate=0.0015
val numEpochs=1
val numColumns=28
val randomSeed=123
val numRows=28
val outputNum=10

val config = new JSONObject()

val configMap = Map(
  "numRows" -> numRows, "numColumns" -> numColumns, "outputNum" -> outputNum,
  "randomSeed" -> randomSeed, "numEpochs" -> numEpochs, "rate" -> rate
  )

configMap.foreach{ case (k,v) => config.put(k,v) }

val tags = List("scala", "dl4j", "client-java").asJava

[32mimport [39m[36mscala.collection.JavaConverters._
[39m
[32mimport [39m[36mcom.wandb.client._
[39m
[32mimport [39m[36morg.json.JSONObject;


[39m
[36mrate[39m: [32mDouble[39m = [32m0.0015[39m
[36mnumEpochs[39m: [32mInt[39m = [32m1[39m
[36mnumColumns[39m: [32mInt[39m = [32m28[39m
[36mrandomSeed[39m: [32mInt[39m = [32m123[39m
[36mnumRows[39m: [32mInt[39m = [32m28[39m
[36moutputNum[39m: [32mInt[39m = [32m10[39m
[36mconfig[39m: [32mJSONObject[39m = {"randomSeed":123,"numRows":28,"rate":0.0015,"numEpochs":1,"outputNum":10,"numColumns":28}
[36mconfigMap[39m: [32mMap[39m[[32mString[39m, [32mAnyVal[39m] = [33mMap[39m(
  [32m"rate"[39m -> [32m0.0015[39m,
  [32m"numEpochs"[39m -> [32m1[39m,
  [32m"numColumns"[39m -> [32m28[39m,
  [32m"randomSeed"[39m -> [32m123[39m,
  [32m"numRows"[39m -> [32m28[39m,
  [32m"outputNum"[39m -> [32m10[39m
)
[36mtags[39m: [32mjava[39m.[32mutil[39m.[32mList[39m[[32mString

In [16]:
val modelConf = new NeuralNetConfiguration.Builder()
            .seed(randomSeed) //include a random seed for reproducibility
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam())
            .l2(rate * 0.0001) // regularize learning model
            .list()
            .layer(new DenseLayer.Builder() //create the first input layer.
                    .nIn(numRows * numColumns)
                    .nOut(40)
                    .build())
            .layer(new DenseLayer.Builder() //create the second input layer
                    .nIn(40).activation(Activation.IDENTITY)
                    .nOut(10)
                    .build())
            .layer(new OutputLayer.Builder(LossFunction.MCXENT) //create hidden layer
                    .activation(Activation.SOFTMAX)
                    .nOut(outputNum)
                    .build())
            .build();

[36mmodelConf[39m: [32mMultiLayerConfiguration[39m = {
  "backpropType" : "Standard",
  "cacheMode" : "NONE",
  "confs" : [ {
    "cacheMode" : "NONE",
    "dataType" : "FLOAT",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : {
      "@class" : "org.deeplearning4j.nn.conf.layers.DenseLayer",
      "activationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationReLU",
        "max" : null,
        "negativeSlope" : null,
        "threshold" : null
      },
      "biasInit" : 0.0,
      "biasUpdater" : null,
      "constraints" : null,
      "gainInit" : 1.0,
      "gradientNormalization" : "None",
      "gradientNormalizationThreshold" : 1.0,
      "hasBias" : true,
      "hasLayerNorm" : false,
      "idropout" : null,
      "iupdater" : {
        "@class" : "org.nd4j.linalg.learning.config.Adam",
        "beta1" : 0.9,
        "beta2" : 0.999,
        "epsilon" : 1.0E-8,
        "learningRate" : 0.001
      },
      "layerName" : "layer0",
      "ni

In [17]:
val model = new MultiLayerNetwork(modelConf);

[36mmodel[39m: [32mMultiLayerNetwork[39m = org.deeplearning4j.nn.multilayer.MultiLayerNetwork@561c0025

In [18]:

import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.dataset.DataSet;

import org.deeplearning4j.optimize.api.BaseTrainingListener;
import java.io.Serializable;
import com.wandb.client._
import org.json.JSONObject;

case class WandBListener(logIteration: Int = 10,
                         testDataset: DataSet,
                         run: WandbRun) extends BaseTrainingListener with Serializable {

  if(logIteration <= 0) throw new Exception(s"Iteration must be greater than 0")

  override def iterationDone(model: Model, iteration: Int, epoch: Int): Unit = { 
    if(iteration % logIteration == 0) {
      val trainingScore = model.score();
      val testScore = model.asInstanceOf[MultiLayerNetwork].score(testDataset)
      val data = new JSONObject() 
      data.put("epoch", epoch)
      data.put("iteration", iteration)
      data.put("train_loss", trainingScore)
      data.put("test_loss", testScore)
      run.log(data)
      println(s"Score on train dataset at iteration $iteration is $trainingScore")
      println(s"Score on test dataset at iteration $iteration is $testScore")
    }
  }
}

[32mimport [39m[36morg.deeplearning4j.nn.api.Model;
[39m
[32mimport [39m[36morg.nd4j.linalg.dataset.DataSet;

[39m
[32mimport [39m[36morg.deeplearning4j.optimize.api.BaseTrainingListener;
[39m
[32mimport [39m[36mjava.io.Serializable;
[39m
[32mimport [39m[36mcom.wandb.client._
[39m
[32mimport [39m[36morg.json.JSONObject;

[39m
defined [32mclass[39m [36mWandBListener[39m

In [22]:
val runBuilder = new WandbRun.Builder()
runBuilder.withConfig(config).withProject("dl4j-wandb-java-client").setTags(tags).setJobType("training")
val run = runBuilder.build

model.setListeners(new WandBListener(10, test, run));  
model.init();

(1 to 50).foreach{
 i => model.fit(train)
 if(i % 5 == 0){
    val testLogLoss = model.score(test)
    val trainingLogLoss = model.score(train)
 }
}

run.finish()

Score on train dataset at iteration 50 is 2.319589261350473
Score on test dataset at iteration 50 is 2.2378570343817743
Score on train dataset at iteration 60 is 1.8830322212611519
Score on test dataset at iteration 60 is 1.824831048102856
Score on train dataset at iteration 70 is 1.6155475707597593
Score on test dataset at iteration 70 is 1.564484157617173
Score on train dataset at iteration 80 is 1.4439225542252019
Score on test dataset at iteration 80 is 1.4044416607112717
Score on train dataset at iteration 90 is 1.325087383083493
Score on test dataset at iteration 90 is 1.3071131304422248


[36mrunBuilder[39m: [32mWandbRun[39m.[32mBuilder[39m = com.wandb.client.WandbRun$Builder@25130b53
[36mres21_1[39m: [32mWandbRun[39m.[32mBuilder[39m = com.wandb.client.WandbRun$Builder@25130b53
[36mrun[39m: [32mWandbRun[39m = com.wandb.client.WandbRun@2aaf4c14