Skip to content

Commit

Permalink
Merge pull request #41 from vanderhe/leakyReLUTransfer
Browse files Browse the repository at this point in the history
Add more modern activation functions
  • Loading branch information
vanderhe committed Oct 15, 2021
2 parents 6c01b6f + 6bcdcc2 commit 150ce1f
Show file tree
Hide file tree
Showing 25 changed files with 5,891 additions and 2,815 deletions.
690 changes: 690 additions & 0 deletions doc/recipes/docs/_figures/transfer/atan.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
540 changes: 540 additions & 0 deletions doc/recipes/docs/_figures/transfer/bent.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,105 changes: 562 additions & 543 deletions doc/recipes/docs/_figures/transfer/gaussian.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
981 changes: 504 additions & 477 deletions doc/recipes/docs/_figures/transfer/heaviside.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
668 changes: 346 additions & 322 deletions doc/recipes/docs/_figures/transfer/linear.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
546 changes: 546 additions & 0 deletions doc/recipes/docs/_figures/transfer/lrelu.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
100 changes: 97 additions & 3 deletions doc/recipes/docs/_figures/transfer/plot_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ def main():
'''Main driver routine.'''

plot_tanh()
plot_atan()
plot_sigmoid()
plot_softplus()
plot_gaussian()
plot_linear()
plot_relu()
plot_lrelu()
plot_bent()
plot_heaviside()


Expand All @@ -46,6 +50,30 @@ def plot_tanh():
plt.show()


def plot_atan():
'''Plots the arcus tangent in given interval.'''

xx = np.linspace(-10.0, 10.0, 1000)
atan = np.arctan(xx)

plt.figure(1, figsize=[7, 5])
plt.title('Arcus Tangent')
plt.xlabel(r'$x$')
plt.ylabel(r'$\arctan(x)$')

plt.xticks((-10.0, 0.0, 10.0))
plt.ylim((-np.pi / 2.0, np.pi / 2.0))
plt.yticks((-np.pi / 2.0, 0.0, np.pi / 2.0),
(r'$-\frac{\pi}{2}$', '0', r'$\frac{\pi}{2}$'))

plt.hlines(0.0, np.min(xx), np.max(xx), color='gray', linestyle='dashed')
plt.plot(xx, atan)
plt.tight_layout()

plt.savefig('atan.svg', dpi=900, format='svg')
plt.show()


def plot_sigmoid():
'''Plots the sigmoid function in given interval.'''

Expand All @@ -67,6 +95,26 @@ def plot_sigmoid():
plt.show()


def plot_softplus():
'''Plots the SoftPlus function in given interval.'''

xx = np.linspace(-3.0, 3.0, 1000)
softplus = np.log(1.0 + np.exp(xx))

plt.figure(1, figsize=[7, 5])
plt.title('SoftPlus Function')
plt.xlabel(r'$x$')
plt.ylabel(r'SoftPlus')

plt.yticks((-1.0, 0.0, 1.0, 2.0, 3.0))

plt.plot(xx, softplus)
plt.tight_layout()

plt.savefig('softplus.svg', dpi=900, format='svg')
plt.show()


def plot_gaussian():
'''Plots gaussian function in given interval.'''

Expand All @@ -76,7 +124,7 @@ def plot_gaussian():
plt.figure(1, figsize=[7, 5])
plt.title('Gaussian Function')
plt.xlabel(r'$x$')
plt.ylabel(r'$S(x)$')
plt.ylabel(r'$G(x)$')

plt.yticks((0.0, 1.0))

Expand Down Expand Up @@ -111,15 +159,15 @@ def plot_linear():


def plot_relu():
'''Plots relu function in given interval.'''
'''Plots ReLU function in given interval.'''

xx = np.linspace(-1.0, 1.0, 1000)
relu = xx * (xx > 0.0)

plt.figure(1, figsize=[7, 5])
plt.title('ReLU Function')
plt.xlabel(r'$x$')
plt.ylabel(r'$R(x)$')
plt.ylabel('ReLU')

plt.xticks((-1.0, 0.0, 1.0))
plt.yticks((-1.0, 0.0, 1.0))
Expand All @@ -131,6 +179,52 @@ def plot_relu():
plt.show()


def plot_lrelu():
'''Plots leaky ReLU function in given interval.'''

xx = np.linspace(-1.0, 1.0, 1000)
lrelu = np.array([max(0.01 * val, val) for val in xx], dtype=float)

plt.figure(1, figsize=[7, 5])
plt.title('Leaky ReLU Function')
plt.xlabel(r'$x$')
plt.ylabel('Leaky ReLU')

plt.xticks((-1.0, 0.0, 1.0))
plt.yticks((-1.0, 0.0, 1.0))

plt.hlines(0.0, np.min(xx), np.max(xx), color='gray', linestyle='dashed')

plt.plot(xx, lrelu)
plt.tight_layout()

plt.savefig('lrelu.svg', dpi=900, format='svg')
plt.show()


def plot_bent():
'''Plots Bent identity function in given interval.'''

xx = np.linspace(-1.0, 1.0, 1000)
bent = (np.sqrt(xx**2 + 1.0) - 1.0) / 2.0 + xx

plt.figure(1, figsize=[7, 5])
plt.title('Bent Identity Function')
plt.xlabel(r'$x$')
plt.ylabel('Bent identity')

plt.xticks((-2.0, -1.0, 0.0, 1.0))
plt.yticks((-1.0, 0.0, 1.0))

plt.hlines(0.0, np.min(xx), np.max(xx), color='gray', linestyle='dashed')

plt.plot(xx, bent)
plt.tight_layout()

plt.savefig('bent.svg', dpi=900, format='svg')
plt.show()


def plot_heaviside():
'''Plots heaviside function in given interval.'''

Expand Down

0 comments on commit 150ce1f

Please sign in to comment.