-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlp-Fs-gambit.sc
95 lines (78 loc) · 2.23 KB
/
mlp-Fs-gambit.sc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
(include "common-gambit.sc")
;;; Representation for weights:
;;; a list with one element for each layer following the input;
;;; each such element is a list with one element for each unit in that layer;
;;; each unit's element is in turn a list consisting of a bias, followed by
;;; the weights for each unit in the previous layer.
;;; Basic MLP
;;; (sum-activities activities) returns a procedure that takes one unit's
;;; parameter list (bias w1 w2 ...) and returns that unit's net input:
;;; the products w_i * activity_i, combined with d+ via reduce seeded
;;; with the bias.
(define (sum-activities activities)
  (lambda (unit-params)
    (let* ((products (map d* (rest unit-params) activities))
           (accumulate (reduce d+ (first unit-params))))
      (accumulate products))))
;;; Net input of every unit in a layer.  ACTIVITIES are the previous
;;; layer's outputs; WS-LAYER holds one (bias w1 w2 ...) list per unit.
;;; Returns a list with one net input per unit, in WS-LAYER order.
(define (sum-layer activities ws-layer)
  (let ((net-input-of (sum-activities activities)))
    (map net-input-of ws-layer)))
;;; Logistic sigmoid 1 / (exp(-x) + 1), written with the d-prefixed
;;; arithmetic operators (presumably the AD-aware ones from common-gambit.sc).
(define (sigmoid x)
  (let ((exp-neg-x (dexp (d- 0 x))))
    (d/ 1 (d+ exp-neg-x 1))))
;;; (forward-pass ws-layers) returns a procedure mapping an input activity
;;; list IN to the network's output activities.  Layers are applied in
;;; order; each layer's output is the sigmoid of its units' net inputs.
;;; With no layers, IN is returned unchanged.
(define (forward-pass ws-layers)
  (lambda (in)
    (let propagate ((remaining ws-layers) (activities in))
      (if (null? remaining)
          activities
          (propagate (cdr remaining)
                     (map sigmoid
                          (sum-layer activities (first remaining))))))))
;;; (error-on-dataset dataset) returns the training objective as a function
;;; of the weights.  DATASET is a list of (input target) pairs.  For each
;;; example the error is 0.5 * (magnitude-squared of output - target).
;;; The per-example errors are combined with d+ via reduce, seeded with 0.
(define (error-on-dataset dataset)
  (lambda (ws-layers)
    (let* ((net (forward-pass ws-layers))
           (example-error
            (lambda (example)
              (let* ((in (first example))
                     (target (second example))
                     (residual (v- (net in) target)))
                (d* 0.5 (magnitude-squared residual))))))
      ((reduce d+ 0) (map example-error dataset)))))
;;; Scaled structure subtraction
;;; (s-k* x k y) computes x - k*y elementwise over a tree of reals.
;;; Reals are combined directly.  Pairs are traversed recursively (car and
;;; cdr alike), and any other leaf of X (e.g. '()) is returned unchanged.
;;; Y is expected to have the same shape as X.
(define (s-k* x k y)
  (if (real? x)
      (d- x (d* k y))
      (if (pair? x)
          (let ((head (s-k* (car x) k (car y)))
                (tail (s-k* (cdr x) k (cdr y))))
            (cons head tail))
          x)))
;;; Vanilla gradient optimization.
;;; Minimize f by gradient descent, starting at w0, for n iterations via
;;; w(t+1) = w(t) - eta * grad_w f(w(t)).
;;; Returns f evaluated at the final weights.
;;; (weight-gradient f) returns a procedure mapping a weight structure WS to
;;; the gradient of F at WS, laid out in the same nested shape as WS.  WS is a
;;; list of layers, each a list of units, each a list of a bias followed by
;;; weights.  Each partial derivative is computed independently: derivative-F
;;; is applied to F viewed as a function of one scalar entry, with every other
;;; entry held at its value in WS.  This costs one derivative-F call per
;;; scalar entry.
(define (weight-gradient f)
  (lambda (ws)
    ((map-n
      (lambda (li)
        (let ((layer (list-ref ws li)))
          ((map-n
            (lambda (ui)
              (let* ((unit (list-ref layer ui))
                     ;; Partial derivative of F w.r.t. entry WI of this unit.
                     (partial
                      (lambda (wi)
                        (let ((f-along-entry
                               (lambda (x)
                                 (f (replace-ith
                                     ws li
                                     (replace-ith
                                      layer ui
                                      (replace-ith unit wi x)))))))
                          ((derivative-F f-along-entry) (list-ref unit wi))))))
                ((map-n partial) (length unit)))))
           (length layer)))))
     (length ws))))
;;; (vanilla f w0 n eta) runs N steps of plain gradient descent on F from W0
;;; with step size ETA, using weight-gradient for the gradient and s-k* for
;;; the update.  Returns F evaluated at the final weights.
;;; NOTE(review): the loop stops only when dzero? holds after repeated
;;; decrements by 1, so N is presumably a non-negative integer; any other N
;;; would not terminate.
(define (vanilla f w0 n eta)
  (let ((gradient-of (weight-gradient f)))
    (let descend ((w w0) (remaining n))
      (if (dzero? remaining)
          (f w)
          (descend (s-k* w eta (gradient-of w))
                   (d- remaining 1))))))
;;; XOR network
;;; Initial weights for the XOR network.  The hidden layer has two units and
;;; the output layer has one unit.  Each unit is (bias w1 w2), and all biases
;;; start at 0.
(define (xor-ws0)
  (let ((hidden-layer '((0 -0.284227 1.16054)
                        (0 0.617194 1.30467)))
        (output-layer '((0 -0.084395 0.648461))))
    (list hidden-layer output-layer)))
;;; XOR truth table as a training set.  Each example is
;;; (inputs targets), where both are lists.
(define (xor-data)
  (list '((0 0) (0))
        '((0 1) (1))
        '((1 0) (1))
        '((1 1) (0))))
;;; Train the XOR network from xor-ws0 with vanilla gradient descent
;;; (1000000 iterations, step size 0.3) and write the final error with
;;; write-real.
(define (run)
  (let ((loss (error-on-dataset (xor-data)))
        (iterations 1000000)
        (step-size 0.3))
    (write-real (vanilla loss (xor-ws0) iterations step-size))))
(run)