forked from karpathy/micrograd
-
Notifications
You must be signed in to change notification settings - Fork 5
/
mnist.py
127 lines (102 loc) · 3.48 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import functools
import itertools
import random
import struct
import sys
import time
from micrograd import nn as nn_interp
from micrograd.engine import Max
from tqdm import tqdm
# Disable tqdm's background monitor thread (no watchdog adjusting bar refresh).
tqdm.monitor_interval = 0
# Fixed seed so shuffling and weight initialization are reproducible.
random.seed(1337)
# Value.backward() recurses through the autograd expression graph; with a
# 784-input MLP the graph is far deeper than Python's default limit allows.
sys.setrecursionlimit(20000)
class image:
    """A single MNIST sample: the digit label plus its raw pixel bytes."""

    def __init__(self, label, pixels):
        self.label, self.pixels = label, pixels
# MNIST images are 28x28 grayscale, one byte per pixel.
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
# Length in bytes of one flattened (row-major) image.
PIXEL_LENGTH = IMAGE_HEIGHT * IMAGE_WIDTH
# Input dimensionality of the MLP: one input per pixel.
DIM = PIXEL_LENGTH
class images:
    """Streaming reader for a paired MNIST image/label file (IDX format).

    Opens both files, validates their headers, then yields `image`
    instances one at a time when iterated. The file handles stay open for
    the lifetime of the object (the script iterates it once at startup).
    """

    def __init__(self, images_filename, labels_filename):
        self.images = open(images_filename, "rb")
        self.labels = open(labels_filename, "rb")
        self.idx = 0  # number of samples consumed so far
        self.read_magic()

    def read_magic(self):
        """Validate both IDX headers and record the advertised counts."""
        images_magic = self.images.read(4)
        # 0x00000803: unsigned-byte data, 3 dimensions (count, rows, cols).
        assert images_magic == b"\x00\x00\x08\x03"
        labels_magic = self.labels.read(4)
        # 0x00000801: unsigned-byte data, 1 dimension (count).
        assert labels_magic == b"\x00\x00\x08\x01"
        (self.num_images,) = struct.unpack(">L", self.images.read(4))
        (self.num_labels,) = struct.unpack(">L", self.labels.read(4))
        assert self.num_images == self.num_labels
        nrows = self.images.read(4)
        assert struct.unpack(">L", nrows) == (IMAGE_HEIGHT,)
        ncols = self.images.read(4)
        assert struct.unpack(">L", ncols) == (IMAGE_WIDTH,)

    def read_image(self):
        """Read and return the next sample as an `image`.

        Raises AssertionError if either file is truncated relative to the
        count advertised in its header.
        """
        label_bytes = self.labels.read(1)
        # Exact-length check: a bare truthiness test would accept a short
        # read at EOF and misreport truncation.
        assert len(label_bytes) == 1
        label = int.from_bytes(label_bytes, "big")
        pixels = self.images.read(PIXEL_LENGTH)
        # Likewise require the full 784 bytes, not just a non-empty read.
        assert len(pixels) == PIXEL_LENGTH
        self.idx += 1
        return image(label, pixels)

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.num_images:
            raise StopIteration
        return self.read_image()

    def num_left(self):
        """Number of samples not yet read."""
        return self.num_images - self.idx
def timer(lam, msg=""):
    """Invoke the zero-argument callable `lam`, report the elapsed wall
    time, and return `lam`'s result. `msg` is printed first (no newline)."""
    print(msg, end=" ")
    start = time.perf_counter()
    value = lam()
    elapsed = time.perf_counter() - start
    print(f"({elapsed:.2f} s)")
    return value
def grouper(n, iterable, fillvalue=None):
    "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    # n references to one shared iterator: zip_longest pulls from them in
    # round-robin order, which slices the iterable into n-sized tuples,
    # padding the last tuple with fillvalue.
    chunks = [iter(iterable)] * n
    return itertools.zip_longest(*chunks, fillvalue=fillvalue)
def stable_softmax(output):
    """Numerically stable softmax over a list of autograd values: shift by
    the maximum before exponentiating so exp() never overflows."""
    peak = functools.reduce(Max, output)
    exps = [(o - peak).exp() for o in output]
    total = sum(exps)
    return [e / total for e in exps]
# Ten output classes, one logit per digit.
NUM_DIGITS = 10
# Two-layer MLP: 784 inputs -> 50 hidden units -> 10 logits.
model = timer(lambda: nn_interp.MLP(DIM, [50, NUM_DIGITS]), "Building model...")
def loss_of(model, image):
    """Cross-entropy loss of `model` on a single sample.

    Runs the raw pixel bytes through the model, softmaxes the logits, and
    takes the negative log-likelihood of the true label. The +0.0001 inside
    the log guards against log(0).
    """
    probs = stable_softmax(model(image.pixels))
    onehot = [1. if digit == image.label else 0. for digit in range(NUM_DIGITS)]
    return -sum(want * (got + 0.0001).log() for want, got in zip(onehot, probs))
print("Training...")
num_epochs = 100
# Load the entire training set into memory up front.
db = list(images("train-images-idx3-ubyte", "train-labels-idx1-ubyte"))
batch_size = 1000
for epoch in range(num_epochs):
    epoch_loss = 0.
    before = time.perf_counter()
    # Shuffle a copy each epoch so `db` keeps its original order.
    shuffled = db.copy()
    random.shuffle(shuffled)
    for batch_idx, batch in tqdm(enumerate(grouper(batch_size, shuffled))):
        # Zero the gradients accumulated by the previous step.
        for p in model.parameters():
            p.grad = 0.0
        # grouper pads the final batch with None when len(db) is not an
        # exact multiple of batch_size; skip the padding instead of
        # crashing inside loss_of.
        loss = sum(loss_of(model, im) for im in tqdm(batch) if im is not None)
        loss.backward()
        epoch_loss += loss.data
        # Plain SGD update, learning rate 0.1.
        for p in model.parameters():
            p.data -= 0.1 * p.grad
    after = time.perf_counter()
    delta = after - before
    # Report mean per-sample loss for the epoch.
    epoch_loss /= len(db)
    print(f"...epoch {epoch:4d} loss {epoch_loss:.2f} (took {delta} sec)")