In [1]:
import sys
import numpy as np
import sys
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import dataset
from sklearn.utils import shuffle
import tensorflow as tf
from time import time
import pickle

# ---------------------------------------------------------------------------
# Experiment configuration (everything a reader might want to tune).
# ---------------------------------------------------------------------------
# NOTE(review): train_data_size is not referenced by any of the visible cells;
# confirm whether the training set was meant to be truncated to this size.
train_data_size = 4000
batch_size = 50        # mini-batch size used for both the train and test loaders
n_epochs = 300         # number of passes over the training loader
learning_rate = 0.01   # step size handed to torch.optim.SGD
dropout_rate = 0.4     # drop probability handed to TorchModel

# The original value was the bare GPU index 1, which makes every `.to(DEVICE)`
# fail on machines that have no CUDA device or only a single one. Keep GPU 1 as
# the preferred device but fall back to GPU 0, then to the CPU.
_PREFERRED_GPU = 1
if torch.cuda.is_available():
    _gpu_index = _PREFERRED_GPU if torch.cuda.device_count() > _PREFERRED_GPU else 0
    DEVICE = torch.device('cuda', _gpu_index)
else:
    DEVICE = torch.device('cpu')

# Seed the global RNGs once so weight initialisation and dropout masks are
# repeatable after "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

def compute_error(y_pred, y):
    """Return the fraction of predictions that differ from the labels.

    Args:
        y_pred: tensor of predicted class indices.
        y: tensor of reference labels with the same shape as ``y_pred``.
            It is cast to int64 before the comparison, so float-encoded
            labels such as ``3.0`` are accepted.

    Returns:
        A Python float: ``1.0 - (number of matches / number of elements)``.

    Raises:
        ValueError: if ``y_pred`` and ``y`` have different shapes. The previous
            ``zip``-based loop silently ignored the surplus elements.
        ZeroDivisionError: if the tensors are empty.
    """
    predicted = y_pred.to('cpu')
    expected = y.type(torch.LongTensor).to('cpu')
    if predicted.shape != expected.shape:
        raise ValueError(
            'y_pred and y must have the same shape, got {} and {}'.format(
                tuple(predicted.shape), tuple(expected.shape)))
    # One vectorised comparison instead of a Python loop that converts a
    # 0-d tensor to bool for every single element.
    n_correct = int((predicted == expected).sum().item())
    return 1.0 - n_correct / expected.numel()

def _test_loss_err(model, loss_fn, loader):
    """Compute the average loss and classification error of ``model`` on ``loader``.

    Args:
        model: callable mapping a float32 input batch on ``DEVICE`` to a
            ``(batch, n_classes)`` tensor of scores.
        loss_fn: callable ``loss_fn(scores, labels)`` returning a scalar tensor.
            It is assumed to return the per-batch *mean* (as
            ``CrossEntropyLoss(reduction='mean')`` does), because batch results
            are combined with a size-weighted average.
        loader: iterable yielding ``(x, y)`` batches, where ``y`` holds class
            indices.

    Returns:
        ``(loss, error)`` as two ``numpy.float64`` values. Each is the
        batch-size-weighted average over all batches, so a smaller final batch
        does not skew the result. Both are ``nan`` when ``loader`` yields
        nothing.

    The caller is responsible for putting ``model`` into ``eval()`` mode.
    """
    losses, errs, sizes = [], [], []
    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for (x, y) in loader:
            x = x.float().to(DEVICE)
            # Keep labels and scores on DEVICE: `.type(torch.LongTensor)` and
            # `.type(torch.FloatTensor)` are CPU tensor types and would copy
            # every batch back to the host just to compute the loss.
            y = y.to(DEVICE).long()
            scores = model(x)
            loss_val = loss_fn(scores, y)
            _, predicted = torch.max(scores, 1)
            losses.append(loss_val.item())
            errs.append(compute_error(predicted, y))
            sizes.append(y.shape[0])

    if not sizes:
        # Same outcome the former np.mean([]) produced, minus the warning.
        return np.float64('nan'), np.float64('nan')

    loss = np.average(losses, weights=sizes)
    error = np.average(errs, weights=sizes)

    return loss, error

class Dataset(dataset.Dataset):
    """Map-style dataset that pairs each sample with the label at the same index.

    Attributes are stored exactly as passed in:
      * ``data``   - indexable container exposing ``.shape``; its first
        dimension defines the dataset length.
      * ``labels`` - indexable container addressed with the same indices.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # Length is taken from the samples only; labels are not consulted.
        n_samples = self.data.shape[0]
        return n_samples

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

def print_log(dict_):
    """Write a one-line progress report to stdout, overwriting the current line.

    Every ``key: value`` pair of ``dict_`` is rendered as ``[key:value]`` with
    the value formatted to three decimals; the pairs are emitted in sorted
    order and joined with ``|``. The text is prefixed with a carriage return
    and carries no trailing newline, so successive calls redraw the same
    terminal line.
    """
    fields = []
    for key, value in sorted(dict_.items()):
        fields.append('[' + str(key) + ':' + format(value, '.3f') + ']')
    sys.stdout.write('\r' + '|'.join(fields))
    sys.stdout.flush()
    
class TorchModel(nn.Module):
    """Small convolutional classifier: two conv stages and one linear layer.

    Layers, in registration order:
      * ``conv1``  - Conv2d(3 -> 6, 5x5, padding 0) -> ReLU -> MaxPool2d(2) -> Dropout
      * ``conv2``  - Conv2d(6 -> 12, 3x3) -> ReLU -> MaxPool2d(2) -> Dropout
      * ``logits`` - Linear(432 -> 10)

    432 equals 12 * 6 * 6, the flattened feature-map size these layers produce
    for a 3 x 32 x 32 input.

    Args:
        dropout_rate: drop probability used by both Dropout layers.
        init: optional mapping with the keys ``'conv1'``, ``'conv2'`` and
            ``'logits'`` holding initial weights. When it is truthy these
            values replace the Xavier-uniform initialisation. Biases are
            always zero-initialised.
    """

    def __init__(self, dropout_rate=0.0, init=None):
        super().__init__()

        def install_weights(layer, key, target_shape=None):
            """Give ``layer`` its starting weights (preset or Xavier) and zero its bias."""
            if init:
                preset = init[key]
                if target_shape is not None:
                    preset = np.reshape(preset, target_shape)
                layer.weight = nn.Parameter(torch.FloatTensor(preset))
            else:
                torch.nn.init.xavier_uniform_(layer.weight)
            torch.nn.init.zeros_(layer.bias)

        def conv_stage(conv):
            """Wrap ``conv`` with the ReLU / max-pool / dropout tail shared by both stages."""
            return nn.Sequential(conv,
                                 nn.ReLU(),
                                 nn.MaxPool2d(2),
                                 nn.Dropout(p=dropout_rate))

        # Each layer is built and initialised before the next one is created,
        # so random numbers are drawn in a fixed create-then-initialise order.
        self.conv1 = conv_stage(nn.Conv2d(in_channels=3,
                                          out_channels=6,
                                          kernel_size=5,
                                          padding=0,
                                          bias=True))
        install_weights(self.conv1[0], 'conv1')

        self.conv2 = conv_stage(nn.Conv2d(in_channels=6,
                                          out_channels=12,
                                          kernel_size=3,
                                          bias=True))
        install_weights(self.conv2[0], 'conv2')

        self.logits = nn.Linear(432, 10)
        # NOTE(review): np.reshape keeps the flat element order. If the stored
        # 'logits' preset is laid out as (432, 10) this is not a transpose -
        # confirm the layout of the init file.
        install_weights(self.logits, 'logits', target_shape=[10, 432])

    def forward(self, x):
        """Return the ``(batch, 10)`` class scores for the input batch ``x``."""
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)
        return self.logits(flat)

    def count_params(self):
        """Return the total number of elements in parameters that require gradients."""
        trainable_sizes = [p.numel() for p in self.parameters() if p.requires_grad]
        return sum(trainable_sizes)

def _load_pickle(path):
    """Return the object deserialised from the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE(review): unpickling can execute arbitrary code - only load files that
# come from a trusted source.
# `init` is presumably the set of preset weights accepted by
# TorchModel(init=...); the model built in the training cell does not pass it -
# TODO confirm whether that is intentional.
init = _load_pickle('init.pkl')
# `data` is indexed later with the keys 'x_train', 'x_test', 'y_train', 'y_test'.
data = _load_pickle('data.pkl')


In [2]:
# --- Data -------------------------------------------------------------------
x_train = data['x_train']
x_test = data['x_test']
y_train = data['y_train']
y_test = data['y_test']
# NOTE(review): np.reshape keeps the flat element order. That is only right if
# the images are already stored channel-first (or flattened in that order); if
# they are stored as (N, 32, 32, 3) a transpose is required instead - TODO
# confirm against the code that produced data.pkl.
x_train, x_test = (np.reshape(x_train, [x_train.shape[0], 3, 32, 32]),
                   np.reshape(x_test, [x_test.shape[0], 3, 32, 32]))


def _record_metrics(model, loss_fn, log, test_loader, train_loader):
    """Evaluate ``model`` on the test and train loaders and append the results to ``log``.

    Appends to ``log['test_loss']``, ``log['test_error']``, ``log['train_loss']``
    and ``log['train_error']`` (in that order). The caller must have switched
    the model to ``eval()`` mode.
    """
    for split, loader in (('test', test_loader), ('train', train_loader)):
        loss_val, err_val = _test_loss_err(model, loss_fn, loader)
        log[split + '_loss'].append(float(loss_val))
        log[split + '_error'].append(err_val)


# --- Training ---------------------------------------------------------------
min_errs = []    # best (lowest) test error of each run
time_took = []   # wall-clock seconds of each run, evaluation included
for i in range(1):

    # NOTE(review): the training loader iterates in a fixed order (no
    # shuffle=True); confirm that this is intended for SGD training.
    train_loader = torch.utils.data.DataLoader(Dataset(x_train, y_train),
                                               batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(Dataset(x_test, y_test),
                                              batch_size=batch_size)
    start_time = time()
    model = TorchModel(dropout_rate=dropout_rate)
    model = model.to(DEVICE)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
    loss_fn = loss_fn.to(DEVICE)

    train_log = {'train_loss':[],
                 'test_loss':[],
                 'train_error':[],
                 'test_error':[]}

    # Baseline metrics of the untrained model.
    model.eval()
    _record_metrics(model, loss_fn, train_log, test_loader, train_loader)

    for epoch in range(n_epochs):
        model.train()
        print_log({'epoch':epoch, 'err':train_log['test_error'][-1]})
        for (x, y) in train_loader:
            x = x.float().to(DEVICE)
            # Keep scores and labels on DEVICE. The former
            # `.type(torch.FloatTensor)` / `.type(torch.LongTensor)` casts are
            # CPU tensor types, so every step copied the logits to the host and
            # back-propagated through that copy.
            y = y.to(DEVICE).long()
            y_pred = model(x)
            loss_val = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
        model.eval()
        _record_metrics(model, loss_fn, train_log, test_loader, train_loader)

    # Persist the full learning curves together with the epoch count.
    train_log['n_epochs'] = n_epochs
    with open('torch_log.pkl', 'wb') as fo:
        pickle.dump(train_log, fo, protocol=pickle.HIGHEST_PROTOCOL)
    time_took.append(time() - start_time)
    min_errs.append(min(train_log['test_error']))
print()
print('time_took', np.mean(time_took))
print('min_error', np.mean(min_errs))

[epoch:299.000]|[err:0.654]
time_took 110.5572509765625
min_error 0.6471000000000001
