-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlp-R-haskell-ad.hs
71 lines (56 loc) · 1.74 KB
/
mlp-R-haskell-ad.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
import Common_Haskell_AD
-- import Crumple (crumple2, uncrumple2)
import Data.List (foldl')
import Data.Reflection (Reifies)
import Numeric.AD (AD)
import Numeric.AD.Internal.Reverse (Tape, Reverse)
-- | Weighted input sum for one neuron: the bias plus the dot product of the
-- input weights with the incoming activities.  The weight list layout is
-- (bias : per-input weights) and must be non-empty (head pattern is partial).
-- Uses a strict left fold so summing long activity lists cannot build a
-- chain of lazy (+) thunks.
sum_activities activities (bias:ws) =
  foldl' (+) bias (zipWith (*) ws activities)
-- | Weighted input sums for every unit in a single layer; each element of
-- the layer is one neuron's (bias : weights) list.
sum_layer acts layer = [sum_activities acts unit_ws | unit_ws <- layer]
-- | Standard logistic activation, squashing any real input into (0, 1).
sigmoid z =
  let e = exp (0 - z)
  in recip (e + 1)
-- | Feed the input vector through every layer in order: at each layer,
-- compute the weighted sums and pass them through the sigmoid.  With no
-- layers, the input is returned unchanged.
forward_pass ws_layers in1 = foldl step in1 ws_layers
  where
    step acts layer = map sigmoid (sum_layer acts layer)
-- | Total squared-error loss over a dataset of (input, target) pairs:
-- sum of 0.5 * |forward_pass(input) - target|^2 for each pair.
error_on_dataset dataset ws_layers =
  sum
    [ 0.5 * magnitude_squared (vminus (forward_pass ws_layers in1) target)
    | (in1, target) <- dataset
    ]
-- | One gradient-descent update: elementwise w' = w - k * g over the
-- three-level nested weight structure (layers / units / weights).
s_kstar ws k grads = zipWith (zipWith (zipWith step)) ws grads
  where
    step w g = w - k * g
-- | Gradient of the objective with respect to the nested weight structure.
-- The weights are flattened (uncrumple2) so reverse-mode AD can run over a
-- flat list, then the resulting gradient is re-nested to match the original
-- shape (crumple2).  The first argument to crumple2/uncrumple2 supplies
-- that shape.
weight_gradient ::
  Num a =>
  (forall s. Reifies s Tape => [[[Reverse s a]]] -> Reverse s a)
  -> [[[a]]] -> [[[a]]]
weight_gradient objective weights =
  let flat = uncrumple2 weights weights
      grad = gradient_R (objective . crumple2 weights) flat
  in crumple2 weights grad
-- | Plain (vanilla) gradient descent: take n steps of size eta from the
-- initial weights w, then return the objective evaluated at the final
-- weights (via the flatten/re-nest round trip used throughout).
vanilla ::
  Num a =>
  (forall s. Reifies s Tape => [[[Reverse s a]]] -> Reverse s a)
  -> [[[a]]] -> Int -> a -> a
vanilla f w n eta
  | n == 0    = lower_fs_R (f . crumple2 w) (uncrumple2 w w)
  | otherwise = vanilla f stepped (n - 1) eta
  where
    stepped = s_kstar w eta (weight_gradient f w)
-- | Initial weights for the XOR network: two layers (two units, then one
-- unit); each inner list is one unit's (bias : input weights), so every
-- bias starts at 0.
xor_ws0 =
  [ [ [0, -0.284227, 1.16054]
    , [0, 0.617194, 1.30467]
    ]
  , [ [0, -0.084395, 0.648461]
    ]
  ]
-- | The XOR truth table as (input, target) training pairs.
xor_data =
  [ ([0, 0], [0])
  , ([0, 1], [1])
  , ([1, 0], [1])
  , ([1, 1], [0])
  ]
-- | Train the XOR network with one million vanilla gradient-descent steps
-- at learning rate 0.3, returning the final dataset error.
run = vanilla objective xor_ws0 steps eta
  where
    objective = error_on_dataset xor_data
    steps = 1000000
    eta = 0.3
-- | Entry point: run the training loop and show the final error.
main = putStrLn (show run)