diff --git a/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/Learn.scala b/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/Learn.scala
index 0b8ee1045..b9c3862d1 100644
--- a/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/Learn.scala
+++ b/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/Learn.scala
@@ -21,7 +21,7 @@ package io.github.mandar2812.dynaml.tensorflow
 import _root_.io.github.mandar2812.dynaml.pipes.DataPipe
 import _root_.io.github.mandar2812.dynaml.models.TFModel
 import io.github.mandar2812.dynaml.tensorflow.layers.{DynamicTimeStepCTRNN, FiniteHorizonCTRNN, FiniteHorizonLinear}
-import org.platanios.tensorflow.api.learn.{Mode, StopCriteria}
+import org.platanios.tensorflow.api.learn.{ClipGradients, Mode, NoClipGradients, StopCriteria}
 import org.platanios.tensorflow.api.learn.layers.{Input, Layer}
 import org.platanios.tensorflow.api.ops.NN.SameConvPadding
 import org.platanios.tensorflow.api.ops.io.data.Dataset
@@ -492,7 +492,8 @@ private[tensorflow] object Learn {
     stopCriteria: StopCriteria,
     stepRateFreq: Int = 5000,
     summarySaveFreq: Int = 5000,
-    checkPointFreq: Int = 5000)(
+    checkPointFreq: Int = 5000,
+    clipGradients: ClipGradients = NoClipGradients)(
     training_data: Dataset[
       (IT, TT), (IO, TO),
       (ID, TD), (IS, TS)],
@@ -503,7 +504,8 @@ private[tensorflow] object Learn {
     val model = tf.learn.Model.supervised(
       input, architecture,
       target, processTarget,
-      loss, optimizer)
+      loss, optimizer,
+      clipGradients)
 
     println("\nTraining model.\n")
 
@@ -597,9 +599,10 @@ private[tensorflow] object Learn {
     optimizer: Optimizer,
     summariesDir: java.nio.file.Path,
     stopCriteria: StopCriteria,
-    stepRateFreq: Int,
-    summarySaveFreq: Int,
-    checkPointFreq: Int)(
+    stepRateFreq: Int = 5000,
+    summarySaveFreq: Int = 5000,
+    checkPointFreq: Int = 5000,
+    clipGradients: ClipGradients = NoClipGradients)(
     training_data: Dataset[IT, IO, ID, IS],
     inMemory: Boolean): UnsupModelPair[IT, IO, ID, IS, I] = {
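
Note: a minimal usage sketch of the new clipGradients parameter. The enclosing builder method name (dtflearn.build_tf_model) and the surrounding argument order are assumptions for illustration, not taken from this diff; ClipGradientsByGlobalNorm is one of the ClipGradients variants that tensorflow_scala provides alongside NoClipGradients.

    import org.platanios.tensorflow.api.learn.ClipGradientsByGlobalNorm

    // Hypothetical call: clip gradients to a global norm of 4.0
    // instead of relying on the default NoClipGradients.
    val (model, estimator) = dtflearn.build_tf_model(
      architecture, input, target, processTarget,
      loss, optimizer, summariesDir, stopCriteria,
      clipGradients = ClipGradientsByGlobalNorm(4.0f))(
      training_data)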