Skip to content

Using scala to implement tiny LSTM, mainly focusing on the BPTT process of training the network.

License

Notifications You must be signed in to change notification settings

xuanyuansen/scalaLSTM

Repository files navigation

###深入理解LSTM的BPTT算法 ####LSTM网络结构 关于LSTM网络的结构可以阅读这篇文章:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

这里需要注意文章最后提及的LSTM两种变形,第一种是加入peephole,使得gate layer能够回溯前一个cell的状态,这增加了一些复杂度;第二种是GRU,将gate layer和forget layer合并为一个update layer,降低了复杂度。 ####LSTM网络的训练 LSTM的训练使用了BPTT算法,需要重要理解的一点是BPTT算法相当于BP算法扩展到序列(时序)数据,另一个需要理解的点是LSTM是recurrent neural network(这里注意理解recurrent neural network和recursive neural network的区别),BPTT算法在计算中要注意这一点。

####LSTM的计算图Compute Graph

  • LSTM的BPTT算法可以参考这篇文章http://nicodjimenez.github.io/2014/08/08/lstm.html

  • 讲述很清晰,注意这篇文章里面最后的输出h(t)没有加入tanh变换。

  • 为了理解LSTM的recursive特性,可以参考下图。

  • 从LSTM的结构可以看到,当前cell的状态会受到前一个cell状态的影响,这体现了LSTM的recursive特性。同时在误差反向传播计算时,可以发现h(t)的误差不仅仅包含当前时刻T的误差,也包括T时刻后所有时刻的误差,即back propagation through time的含义。这样每个时刻变量的误差都可以经由h(t)和c(t+1)迭代计算。

  • 为了使整个直观计算过程,在参考神经网络计算图分解的基础上,LSTM的计算图如下图所示,从计算图上面可以直观地看出LSTM的forward propagation和back propagation过程。

  • 从图中可以看出,H(t-1)的误差由H(t)决定,且要对所有的gate layer求和,c(t-1)由c(t)决定,而c(t)的误差由两部分,一部分是h(t),领一部分是c(t+1)。

  • 如果所示,在计算的时候,需要传入h(t)和c(t+1),h(t)在更新的时候需要加上h(t+1)。

####SCALA实现

  • breeze库

####利用SPARK实现minibatch方式的训练

####几种常见的LSTM结构

  • 1、原始LSTM
  • 2、peephole LSTM
  • 3、GRU

About

Using scala to implement tiny LSTM, mainly focusing on the BPTT process of training the network.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published