In [6]:
import os
import sys

sys.path.append('../')
from src.Linear import Linear
from src.ReLU import ReLU
from src.Model import Model
from src.Criterion import Criterion
from src.Optimizer import SGDOptimizer
from src.utils import load_file, save_file
from src.ProgressBar import ProgressBar

import torch
import torchfile
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

In [7]:
# --- Model output location ---------------------------------------------------
# The directory is created before anything else in this cell runs;
# `exist_ok=True` makes re-running the cell harmless.
model_dir = '../outputs/m1'
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, 'model.bin')

# --- Training set ------------------------------------------------------------
# Samples and labels live in two separate files and are read with the
# project's `load_file` helper.
data_path = '../data/train/data.bin'
label_path = '../data/train/labels.bin'
train_data = load_file(data_path)
train_labels = load_file(label_path)

In [8]:
class TrainDataLoader():
    """Sequential mini-batch reader over an in-memory dataset.

    The constructor works on deep copies of ``data`` and ``labels`` so the
    caller's objects are never modified.  Each sample is flattened to a single
    row (``data`` becomes ``(num_data, -1)``, ``labels`` becomes
    ``(num_data, 1)``) and, when ``shuffle`` is true, both are reordered with
    one shared random permutation so sample/label pairs stay aligned.

    Batches are handed out front to back by ``get_batch``; ``done`` becomes
    True once the read position has reached ``num_data``.

    Args:
        data: array-like exposing ``.shape``, ``.reshape`` and indexing by an
            integer permutation array; the first axis indexes samples.
        labels: array-like with the same number of entries along its first
            axis as ``data``.
        batch_size: number of samples per batch; must be positive.
        shuffle: whether to shuffle once at construction time (uses NumPy's
            global random state).

    Raises:
        ValueError: if ``batch_size`` is zero or negative.
        AssertionError: if ``data`` and ``labels`` disagree on sample count.
    """

    def __init__(self, data, labels, batch_size=4, shuffle=True):
        # With a non-positive batch size the read position never reaches the
        # end of the data, so `done` would stay False forever and any
        # `while not loader.done` loop would never terminate.
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be a positive integer, got %r" % (batch_size,))

        self.data = deepcopy(data)
        self.labels = deepcopy(labels)
        self.num_data = data.shape[0]
        self.batch_size = batch_size

        assert self.data.shape[0] == self.labels.shape[0]

        # Flatten every sample to one row; labels become a column vector.
        self.data = self.data.reshape(self.num_data, -1)
        self.labels = self.labels.reshape(self.num_data, 1)

        if shuffle:
            self._shuffle_data()

        self.curr_pos = 0
        self.done = False

    def _shuffle_data(self):
        """Reorder data and labels in lockstep using one random permutation."""
        order = np.random.permutation(self.num_data)
        self.data = self.data[order]
        self.labels = self.labels[order]

    def get_batch(self):
        """Return the next ``(data, labels)`` slice and advance the cursor.

        The slice covers ``batch_size`` rows starting at the current position;
        with NumPy-style slicing the final batch is simply shorter when fewer
        rows remain.  ``done`` is set once the cursor reaches ``num_data``.
        """
        window = slice(self.curr_pos, self.curr_pos + self.batch_size)
        batch = (self.data[window], self.labels[window])

        self.curr_pos += self.batch_size
        if self.curr_pos >= self.num_data:
            self.done = True

        return batch

# Disabled scratch code, kept inside a string literal so it is never executed.
# Because the literal is the last expression of the cell, its repr is echoed
# as the cell output.  It refers to a `dataloader` variable that is not
# defined in the cells shown here.
# NOTE(review): consider deleting this block once it is no longer needed.
"""
count = 0
index = 0
while count < 10:
    if dataloader.labels[index] == 5:
        plt.imshow(dataloader.data[index])
        plt.show()
        count += 1
    index += 1

count = 0
while not dataloader.done:
    dataloader.get_batch()[1]
    count += 1
count
"""

'\ncount = 0\nindex = 0\nwhile count < 10:\n    if dataloader.labels[index] == 5:\n        plt.imshow(dataloader.data[index])\n        plt.show()\n        count += 1\n    index += 1\n\ncount = 0\nwhile not dataloader.done:\n    dataloader.get_batch()[1]\n    count += 1\ncount\n'

In [22]:
def run_step(model, optimizer, loss, batch_data):
    """Perform one training update on a single batch and return its loss.

    Args:
        model: object providing ``clearGradParam()``, ``forward(inp)`` and
            ``backward(inp, grad)``.
        optimizer: object providing ``step()``.
        loss: criterion providing ``forward(out, target)`` and
            ``backward(out, target)``.
        batch_data: ``(inputs, targets)`` pair.

    Returns:
        Whatever ``loss.forward`` returned for this batch.
    """
    features, labels = batch_data

    # Start from clean gradients so this batch does not accumulate onto the
    # previous one.
    model.clearGradParam()

    # Forward pass through the network, then through the criterion.
    predictions = model.forward(features)
    batch_loss = loss.forward(predictions, labels)

    # Backward pass: the criterion's gradient w.r.t. the predictions is pushed
    # back through the network.
    model.backward(features, loss.backward(predictions, labels))

    # Apply the parameter update.
    optimizer.step()

    return batch_loss
    

def run_epoch(model, optimizer, loss, batch_size=4, log_every=100):
    """Train ``model`` for one pass over the training set, then save it.

    A fresh ``TrainDataLoader`` is built over the notebook-level
    ``train_data`` / ``train_labels``, every batch it yields is passed to
    ``run_step``, and finally ``model.save(model_path)`` is called.

    Args:
        model: network to train; must also provide ``save(path)``.
        optimizer: optimizer forwarded to ``run_step``.
        loss: criterion forwarded to ``run_step``.
        batch_size: samples per batch; must be positive.
        log_every: print the batch loss every this many steps (step 0
            included).  ``None``, zero or a negative value disables logging.

    Raises:
        ValueError: if ``batch_size`` is zero or negative.
    """
    # A non-positive batch size would keep the loader from ever finishing,
    # so reject it before any data is copied.
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, got %r" % (batch_size,))

    dataloader = TrainDataLoader(train_data,
                                 train_labels,
                                 batch_size=batch_size)

    # Decided once: guards against a modulo-by-zero in the loop below.
    logging_enabled = bool(log_every) and log_every > 0

    step = 0
    while not dataloader.done:
        loss_value = run_step(model, optimizer, loss, dataloader.get_batch())

        if logging_enabled and step % log_every == 0:
            print("[step %d] loss: %f" % (step, loss_value))

        step += 1

    model.save(model_path)

    
# ---------------------
# Fully connected network described by its layer widths: a Linear layer
# between each pair of consecutive widths, with a ReLU between Linear layers
# (none after the last one).
# NOTE(review): 11664 must equal the flattened sample size produced by
# TrainDataLoader (11664 == 108 * 108) and 6 is presumably the number of
# classes -- confirm against the dataset.
layer_widths = [11664, 100, 50, 50, 6]

layers = [Linear(layer_widths[0], layer_widths[1])]
for fan_in, fan_out in zip(layer_widths[1:-1], layer_widths[2:]):
    layers += [ReLU(), Linear(fan_in, fan_out)]

model = Model(layers)
optimizer = SGDOptimizer(model, lr=1e-6, momentum=0.9, decay=0.0)
loss = Criterion()

# One pass over the training data; run_epoch saves the model when it finishes.
run_epoch(model, optimizer, loss, batch_size=10)

[step 0] loss: 82.868899
[step 100] loss: 14.226470
[step 200] loss: 7.008267
[step 300] loss: 3.688465
[step 400] loss: 3.736008
[step 500] loss: 3.793732
[step 600] loss: 9.840537
[step 700] loss: 7.321252
[step 800] loss: 3.105856
[step 900] loss: 2.699834
[step 1000] loss: 6.189293
[step 1100] loss: 5.485415
[step 1200] loss: 3.707459
[step 1300] loss: 3.548865
[step 1400] loss: 3.346508
[step 1500] loss: 3.827556
[step 1600] loss: 2.921995
[step 1700] loss: 3.523190
[step 1800] loss: 3.363484
[step 1900] loss: 3.428165
[step 2000] loss: 1.719549
[step 2100] loss: 3.002167
[step 2200] loss: 4.200880
[step 2300] loss: 2.884273
[step 2400] loss: 4.854297
[step 2500] loss: 2.183768
[step 2600] loss: 1.952941
[step 2700] loss: 2.054574
[step 2800] loss: 2.996145
[step 2900] loss: 1.578286
