# CIFAR-10 dataset and in-situ training

## Crossbar

In [None]:
"""
crossbar.py
Louis Primeau
University of Toronto Department of Electrical and Computer Engineering
louis.primeau@mail.utoronto.ca
July 29th 2020
"""
import csv
import itertools
import math
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from scipy.io import savemat
from scipy.io import savemat

# Implements sklearn's MinMaxScaler (feature range [0, 1]) for torch Tensors,
# scaling each row (dim=1) independently.
# Taken from a ptrblck post on the PyTorch forums. Love that dude.
class MinMaxScaler(object):
    """Row-wise min-max scaler to [0, 1]; mutates its input tensor in place."""

    def __call__(self, tensor):
        """Scale each row of ``tensor`` to [0, 1] in place and return it.

        Rows whose max equals their min are mapped to 0 (their range is
        treated as 1, as sklearn does) instead of producing NaN.
        """
        self.min = tensor.min(dim=1, keepdim=True)[0]
        value_range = tensor.max(dim=1, keepdim=True)[0] - self.min
        # Avoid 1/0 for constant rows.
        value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
        self.scale = 1.0 / value_range
        tensor.sub_(self.min).mul_(self.scale)
        return tensor

    def inverse_transform(self, tensor):
        """Undo the last ``__call__`` in place, using its stored ``scale`` and ``min``."""
        tensor.div_(self.scale).add_(self.min)
        return tensor


In [None]:
class ticket:
    """Handle to one matrix programmed onto a crossbar; runs simulated bit-serial VMMs.

    Created by ``crossbar.register_linear``. The matrix occupies crossbar rows
    ``row:row + m_rows``. Each logical weight uses a differential pair of
    devices, so it occupies crossbar columns ``2*col:2*(col + m_cols)``, which
    is where ``register_linear`` writes the conductances.

    Args:
        row, col: position of the matrix in (row, logical-column) units.
        m_rows, m_cols: matrix shape.
        matrix: the unscaled weight matrix (used for the offset and decoding terms).
        mat_scale_factor: factor the matrix was divided by before mapping to conductances.
        crossbar: owning crossbar (``V``, ``W`` and ``size`` are read).
        uvect: CODEX encoding vector added to the input in the CODEX paths.
        decode: not used; the decoding term is recomputed from ``uvect`` and ``matrix``.
        inputres: input (DAC) resolution in bits.
        outputres: output (ADC) resolution in bits.
    """

    def __init__(self, row, col, m_rows, m_cols, matrix, mat_scale_factor, crossbar, uvect, decode, inputres, outputres):
        self.row, self.col = row, col
        self.m_rows, self.m_cols = m_rows, m_cols
        self.crossbar = crossbar
        self.mat_scale_factor = mat_scale_factor
        self.matrix = matrix
        self.uvect = uvect
        self.inputres = inputres
        self.adcres = outputres
        # NOTE(review): the `decode` argument is ignored; u^T * matrix is recomputed here.
        self.decode = torch.matmul(self.uvect.t(), self.matrix)

    def prep_vector(self, vector, v_bits):
        """Quantize ``vector`` to unsigned ``v_bits``-bit codes and lay out its bit planes as voltages.

        The vector is shifted by its minimum and scaled so that its maximum maps
        to ``2**v_bits - 1``; each scaled entry is truncated to an integer code.

        Returns:
            pad_vector: (crossbar rows, v_bits) tensor. Column b holds bit
                ``v_bits - 1 - b`` (MSB first) of each code times ``crossbar.V``,
                placed on rows ``row:row + m_rows`` (zeros elsewhere).
            vect_scale_factor: scale applied to the shifted vector (0 if it is constant).
            vect_min: the subtracted minimum.
        """
        # Scale vector to [0, 2**v_bits - 1]
        vect_min = torch.min(vector)
        vector = vector - vect_min
        vect_scale_factor = torch.max(vector) / (2**v_bits - 1)
        vector = vector / vect_scale_factor if vect_scale_factor != 0.0 else vector

        # Decompose each entry by bit, MSB first. The int64 cast truncates
        # toward zero, exactly like int() on these non-negative values.
        codes = vector.reshape(-1).to(torch.int64)
        shifts = torch.arange(v_bits - 1, -1, -1)
        bit_vector = ((codes.unsqueeze(1) >> shifts) & 1).to(torch.get_default_dtype())
        bit_vector *= self.crossbar.V

        # Pad bit vector with unselected (0 V) rows
        pad_vector = torch.zeros(self.crossbar.size[0], v_bits)
        pad_vector[self.row:self.row + self.m_rows, :] = bit_vector

        return pad_vector, vect_scale_factor, vect_min

    def _signed_weights(self):
        """Return G[:, 2k+1] - G[:, 2k] over this ticket's region of ``crossbar.W``."""
        c0 = 2 * self.col
        region = self.crossbar.W[self.row:self.row + self.matrix.shape[0],
                                 c0:c0 + 2 * self.matrix.shape[1]]
        return region[:, 1::2] - region[:, 0::2]

    def _bit_plane_voltages(self, pad_vector, n_inputs):
        """Return the (bit planes x inputs) voltages on this ticket's rows."""
        return torch.transpose(pad_vector[self.row:self.row + n_inputs], 0, 1)

    def _adc_quantize(self, rout):
        """Round raw outputs to multiples of max(rout) / (2**adcres - 1) (output ADC model)."""
        rout_scale_factor = torch.max(rout) / (2**self.adcres - 1)
        if rout_scale_factor == 0:
            # max(rout) == 0, e.g. a constant input gives all-zero bit planes;
            # dividing by the scale would turn every output into NaN.
            return rout
        return torch.round(rout / rout_scale_factor) * rout_scale_factor

    @staticmethod
    def _combine_bit_planes(rout, v_bits):
        """Weight bit-plane row i by 2**(v_bits - i - 1) and sum over the planes."""
        weights = torch.tensor([2.0 ** (v_bits - i - 1) for i in range(rout.size(0))], dtype=rout.dtype)
        return torch.sum(rout * weights.unsqueeze(1), axis=0)

    def vmm(self, vector):
        """Baseline crossbar VMM (no CODEX encoding).

        Args:
            vector: column vector of shape (m_rows, 1).
        Returns:
            Approximation of ``vector^T @ matrix`` as a column of shape (m_cols, 1).
        """
        v_bits = self.inputres
        assert vector.size(1) == 1, "vector wrong shape"

        # Rescale vector and convert to bits.
        pad_vector, vect_scale_factor, vect_min = self.prep_vector(vector, v_bits)

        # One crossbar read per bit plane, each digitised by the output ADC.
        rV = self._bit_plane_voltages(pad_vector, vector.size(0))
        rout = self._adc_quantize(torch.matmul(rV, self._signed_weights()))

        # Shift-and-add the bit planes, undo the voltage/input/matrix scaling and
        # add back the min-shift term vect_min * column sums of the matrix.
        rout = self._combine_bit_planes(rout, v_bits)
        # NOTE(review): the divisor 1.5131 is not derived in this file (the CODEX
        # paths use 1.5231); confirm which constant is intended.
        rout = (rout / self.crossbar.V * vect_scale_factor*self.mat_scale_factor) / 1.5131 + torch.sum(vect_min*self.matrix,axis=0)
        return rout.view(-1,1)

    def _codex_raw(self, xvector):
        """Shared CODEX path: run (x + u) through the crossbar; return the still-encoded output."""
        assert xvector.size(1) == 1, "vector wrong shape"
        v_bits = self.inputres

        # Add encoding vector u to x and quantize with v_bits + 1 bits.
        vector = xvector + self.uvect
        pad_vector, vect_scale_factor, vect_min = self.prep_vector(vector, v_bits + 1)

        rV = self._bit_plane_voltages(pad_vector, vector.size(0))
        # Raw output currents seen by the ADC, then quantized.
        rout = self._adc_quantize(torch.matmul(rV, self._signed_weights()))

        # The v_bits + 1 planes get weights 2**(v_bits-1) ... 2**-1, so the
        # factor 2 below restores the true binary weights 2**v_bits ... 2**0.
        rout = self._combine_bit_planes(rout, v_bits)
        return 2*(rout / self.crossbar.V * vect_scale_factor*self.mat_scale_factor) / 1.5231 + torch.sum(vect_min*self.matrix,axis=0)

    def CODEXvmm(self, xvector):
        """CODEX VMM: encode with u, compute on the crossbar, subtract u^T * matrix.

        Args:
            xvector: column vector of shape (m_rows, 1).
        Returns:
            Column of shape (m_cols, 1).
        """
        return (self._codex_raw(xvector) - self.decode).view(-1, 1)

    def modified_CODEXvmm(self, xvector):
        """Same as ``CODEXvmm`` but without the decoding subtraction.

        Per the original note, decoding is meant to happen only during inference.
        """
        return self._codex_raw(xvector).view(-1, 1)


### Linear

In [None]:
import torch.optim as optim
import random
import copy

class linear(torch.autograd.Function):
    """Custom autograd function for a crossbar-backed linear layer (from Louis).

    The forward pass runs the (noisy, quantized) CODEX crossbar VMM and adds the
    bias. The backward pass uses the exact dense gradients of ``W @ x + b``, so
    training sees ideal gradients.
    """

    @staticmethod
    def forward(ctx, ticket, x, W, b):
        ctx.save_for_backward(x, W, b)
        crossbar_out = ticket.CODEXvmm(x)
        return crossbar_out + b

    @staticmethod
    def backward(ctx, dx):
        saved_x, saved_W, _saved_b = ctx.saved_tensors
        # No gradient for the ticket; d/dx, d/dW and d/db of W @ x + b.
        return None, torch.mm(saved_W.t(), dx), torch.mm(dx, saved_x.t()), dx

class Linear(torch.nn.Module):
    """Fully connected layer whose forward pass can be routed through a crossbar.

    ``W`` has shape (output_size, input_size) and ``b`` has shape (output_size, 1);
    the dense path computes ``W @ x + b``. After ``use_cb(True)`` the forward pass
    instead calls the ``linear`` autograd function on this layer's crossbar ticket.

    Args:
        input_size, output_size: layer dimensions.
        cb: crossbar providing ``register_linear`` and ``clear``.
        uvect: pool (list) of CODEX encoding vectors; ``uvectidx`` selects the active one.
    """

    def __init__(self, input_size, output_size, cb,uvect):
        super(Linear, self).__init__()
        self.W = torch.nn.parameter.Parameter(torch.rand(output_size, input_size))
        self.b = torch.nn.parameter.Parameter(torch.rand(output_size, 1))
        self.cb = cb

        # Instantiate Linear layer with pool of random encoding vectors to sample from
        self.uvectlist = uvect
        self.uvectidx = 0
        # Decoding vector is calculated ideally here off-chip, but calculating decoding vector on-chip is also possible
        self.decode = torch.matmul(self.uvectlist[self.uvectidx].t(),torch.transpose(self.W,0,1)).detach().clone()
        self.ticket = cb.register_linear(torch.transpose(self.W,0,1),self.uvectlist[self.uvectidx],self.decode)
        # autograd.Function subclasses are used through the class (``.apply``);
        # instantiating them is deprecated.
        self.f = linear
        self.cbon = False

    def forward(self, x):
        """Return the crossbar VMM result if enabled, else the exact ``W @ x + b``."""
        return self.f.apply(self.ticket, x, self.W, self.b) if self.cbon else self.W.matmul(x) + self.b

    def remap(self):
        """Clear the crossbar and reprogram it with the current weights.

        Should be called after one or a few update steps.
        """
        self.cb.clear()
        self.ticket = self.cb.register_linear(torch.transpose(self.W,0,1),self.uvectlist[self.uvectidx],self.decode)

    def update_decode(self):
        """Recompute the decoding vector u^T * W^T for the active encoding vector."""
        self.decode = torch.matmul(self.uvectlist[self.uvectidx].t(),torch.transpose(self.W,0,1)).detach().clone()

    def resample(self):
        """Pick a random encoding vector from the pool and reprogram the crossbar with it."""
        self.cb.clear()
        self.uvectidx = random.randint(0, len(self.uvectlist)-1)
        # Keep the decoding vector consistent with the newly selected encoding vector.
        self.update_decode()
        self.ticket = self.cb.register_linear(torch.transpose(self.W,0,1),self.uvectlist[self.uvectidx],self.decode)

    def use_cb(self, state):
        """Enable (True) or disable (False) the crossbar forward path."""
        self.cbon = state


In [None]:
class crossbar:
    """Simulated memristor crossbar array.

    Each device (i, j) has its own on/off conductance, sampled from normal
    distributions over resistance, and ``2**device_resolution - 1``
    programmable conductance states evenly spaced between them. ``W`` holds the
    current conductance of every device. Matrices are programmed with
    ``register_linear``, which returns a ``ticket``. ``solve`` evaluates the
    array tile by tile with the line-resistance nodal model of A. Chen (2013).
    """

    def __init__(self, device_params):

        # Power Supply Voltage
        self.V = device_params["Vdd"]

        # DAC (input) and ADC (output) resolution
        self.input_resolution = device_params["dac_resolution"]
        self.output_resolution = device_params["adc_resolution"]

        # Wordline Resistance
        self.r_wl = torch.Tensor((device_params["r_wl"],))
        # Bitline Resistance
        self.r_bl = torch.Tensor((device_params["r_bl"],))

        # Number of rows, columns
        self.size = device_params["m"], device_params["n"]

        # On (low-resistance) state conductance of each device
        self.g_on = 1 / torch.normal(device_params["r_on_mean"], device_params["r_on_stddev"], size=self.size)

        # Off (high-resistance) state conductance of each device
        self.g_off = 1 / torch.normal(device_params["r_off_mean"], device_params["r_off_stddev"], size=self.size)

        self.g_wl = torch.Tensor((1 / device_params["r_wl"],))
        self.g_bl = torch.Tensor((1 / device_params["r_bl"],))

        # Device programming resolution (bits)
        self.resolution = device_params["device_resolution"]
        # Conductance tensor, m x n x (2**resolution - 1); the odd count gives a
        # state in the middle.
        self.conductance_states = self._build_conductance_states()

        # Bias Scheme
        self.bias_voltage = self.V * device_params["bias_scheme"]

        # Tile size (1x1 = 1T1R, nxm = passive, etc.)
        self.tile_rows = device_params["tile_rows"]
        self.tile_cols = device_params["tile_cols"]
        assert self.size[0] % self.tile_rows == 0, "tile size does not divide crossbar size in row direction"
        assert self.size[1] % self.tile_cols == 0, "tile size does not divide crossbar size in col direction"

        # Resistance of CMOS lines
        self.r_cmos_line = device_params["r_cmos_line"]

        # Conductance Matrix; initialize each memristor at the on resistance
        self.W = torch.ones(self.size) * self.g_on

        # Stuck-on & stuck-off device nonideality
        self.p_stuck_on = device_params["p_stuck_on"]
        self.p_stuck_off = device_params["p_stuck_off"]
        self.devicefaults = False

        # (row, col, m_rows, m_cols) of each registered matrix, in order
        self.mapped = []

        # Nodal matrices keyed by tile coordinates (see hash_M). They depend on W,
        # so every method that rewrites W clears this cache.
        self.saved_tiles = {}

    def _build_conductance_states(self):
        """Return (m, n, 2**resolution - 1) states evenly spaced from g_off to g_on per device."""
        fractions = torch.linspace(0.0, 1.0, 2**self.resolution - 1)
        return torch.lerp(self.g_off.unsqueeze(-1), self.g_on.unsqueeze(-1), fractions)

    def apply_stuck(self, p_stuck_on, p_stuck_off):
        """Randomly pin devices: stuck-on to g_on (prob p_stuck_on), stuck-off to g_off (prob p_stuck_off)."""
        state_dist = torch.distributions.categorical.Categorical(probs=torch.Tensor([p_stuck_on, p_stuck_off, 1 - p_stuck_on - p_stuck_off]))
        state_mask = state_dist.sample(self.size)

        stuck_on = state_mask == 0
        stuck_off = state_mask == 1
        self.W[stuck_on] = self.g_on[stuck_on]
        self.W[stuck_off] = self.g_off[stuck_off]
        self.saved_tiles = {}

        return None

    def map(self, matrix):
        """Program ``matrix`` at the array origin using differential device pairs (cols 2j, 2j+1)."""
        assert not(matrix.size(0) > self.size[0] or matrix.size(1)*2 > self.size[1]), "input too large"
        midpoint = self.conductance_states.size(2) // 2

        for i in range(matrix.size(0)):
            for j in range(matrix.size(1)):

                # Nearest state to the target, measured relative to the middle state
                shifted = self.conductance_states[i,j] - self.conductance_states[i,j,midpoint]
                idx = torch.min(torch.abs(shifted - matrix[i,j]), dim=0)[1]

                # Positive device gets state idx, negative device the mirrored state
                self.W[i,2*j+1] = self.conductance_states[i,j,idx]
                self.W[i,2*j] = self.conductance_states[i,j,midpoint-(idx-midpoint)]
        self.saved_tiles = {}

    def solve(self, voltage):
        """Circuit-level read of the whole array for each column of ``voltage``.

        Args:
            voltage: (m, k) wordline input voltages, one column per read.
        Returns:
            (k, n) tensor of per-column results from ``circuit_solve``.
        """
        output = torch.zeros((voltage.size(1), self.size[1]))
        tr, tc = self.tile_rows, self.tile_cols
        n_tile_cols = self.size[1] // tc
        for i in range(self.size[0] // tr):
            for j in range(n_tile_cols):
                # Tile column extent uses tile_cols (rows use tile_rows).
                coords = (i*tr, (i+1)*tr, j*tc, (j+1)*tc)
                for k in range(voltage.size(1)):
                    vect = voltage[i*tr:(i+1)*tr, k]
                    solution = self.circuit_solve(coords, vect, torch.zeros(self.size[1]), torch.ones(self.size[1]), torch.zeros(self.size[0]))
                    output[k] += torch.cat((torch.zeros(j*tc), solution, torch.zeros((n_tile_cols - j - 1) * tc)))
        return output

    """
    A Comprehensive Crossbar Array Model With Solutions for Line Resistance and Nonlinear Device Characteristics
    An Chen
    IEEE TRANSACTIONS ON ELECTRON DEVICES, VOL. 60, NO. 4, APRIL 2013
    """

    def hash_M(self, a, b, c, d):
        """Cache key for the tile W[a:b, c:d]."""
        return str(a) + "_" + str(b) + "_" + str(c) + "_" + str(d)

    def make_M(self, a, b, c, d):
        """Build the 2mn x 2mn nodal conductance matrix M for the tile W[a:b, c:d].

        Stores M itself in ``saved_tiles`` and returns its inverse. Unknowns are
        the m*n wordline node voltages followed by the m*n bitline node voltages.
        """
        conductances = self.W[a:b,c:d]
        g_wl, g_bl = self.g_wl, self.g_bl
        # Line terminations: 1 S models a driven end, 1e-9 S a floating end.
        g_s_wl_in, g_s_wl_out = torch.ones(self.tile_rows) * 1, torch.ones(self.tile_rows) * 1e-9
        # Bitline terminations are per column, so they are sized by tile_cols.
        g_s_bl_in, g_s_bl_out = torch.ones(self.tile_cols) * 1e-9, torch.ones(self.tile_cols) * 1
        m, n = self.tile_rows, self.tile_cols

        # Wordline KCL: device + wire segments + terminations, one block per row
        A = torch.block_diag(*tuple(torch.diag(conductances[i,:])
                          + torch.diag(torch.cat((g_wl, g_wl * 2 * torch.ones(n-2), g_wl)))
                          + torch.diag(g_wl * -1 *torch.ones(n-1), diagonal = 1)
                          + torch.diag(g_wl * -1 *torch.ones(n-1), diagonal = -1)
                          + torch.diag(torch.cat((g_s_wl_in[i].view(1), torch.zeros(n - 2), g_s_wl_out[i].view(1))))
                                   for i in range(m)))

        B = torch.block_diag(*tuple(-torch.diag(conductances[i,:]) for i in range(m)))

        def makec(j):
            c = torch.zeros(m, m*n)
            for i in range(m):
                c[i,n*(i) + j] = conductances[i,j]
            return c

        C = torch.cat([makec(j) for j in range(n)],dim=0)

        def maked(j):
            # Bitline KCL for 1-based column j; node (r, col) is at index n*r + col.
            dmat = torch.zeros(m, m*n)
            col = j - 1

            # First row: input termination and the row below
            dmat[0, col] = -g_s_bl_in[col] - g_bl - conductances[0, col]
            dmat[0, n + col] = g_bl

            # Last row: the row above and the output termination
            dmat[m - 1, n*(m-2) + col] = g_bl
            dmat[m - 1, n*(m-1) + col] = -g_s_bl_out[col] - conductances[m - 1, col] - g_bl

            # Interior rows: neighbours directly above (r-1) and below (r+1)
            for r in range(1, m - 1):
                dmat[r, n*(r-1) + col] = g_bl
                dmat[r, n*r + col] = -g_bl - conductances[r, col] - g_bl
                dmat[r, n*(r+1) + col] = g_bl

            return dmat

        D = torch.cat([maked(j) for j in range(1,n+1)], dim=0)

        M = torch.cat((torch.cat((A,B),dim=1), torch.cat((C,D),dim=1)), dim=0)

        self.saved_tiles[self.hash_M(a,b,c,d)] = M

        return torch.inverse(M)

    def circuit_solve(self, coords,  v_wl_in, v_bl_in, v_bl_out, v_wl_out):
        """Solve M V = E for one tile and return the per-column sum of (V_bl - V_wl) * G.

        Args:
            coords: (row_start, row_end, col_start, col_end) of the tile.
            v_wl_in, v_wl_out: wordline voltages at the input/output ends, indexed by tile row.
            v_bl_in, v_bl_out: bitline voltages at the input/output ends, indexed by tile column.
        """
        g_s_wl_in, g_s_wl_out = torch.ones(self.tile_rows) * 1, torch.ones(self.tile_rows) * 1e-9
        g_s_bl_in, g_s_bl_out = torch.ones(self.tile_cols) * 1e-9, torch.ones(self.tile_cols) * 1
        m, n = self.tile_rows, self.tile_cols

        # Always solve with the forward matrix M (make_M caches M but returns its inverse).
        key = self.hash_M(*coords)
        if key not in self.saved_tiles:
            self.make_M(*coords)
        M = self.saved_tiles[key]

        E = torch.cat([torch.cat(((v_wl_in[i]*g_s_wl_in[i]).view(1), #EW
                                  torch.zeros(n-2),
                                  (v_wl_out[i]*g_s_wl_out[i]).view(1)))
                                 for i in range(m)] +
                      [torch.cat(((-v_bl_in[i]*g_s_bl_in[i]).view(1), #EB
                                  torch.zeros(m-2),
                                  (-v_bl_out[i]*g_s_bl_out[i]).view(1)))
                                 for i in range(n)]
        ).view(-1, 1)

        V = torch.chunk(torch.linalg.solve(M, E), 2)

        return torch.sum((V[1] - V[0]).view(m,n)*self.W[coords[0]:coords[1],coords[2]:coords[3]],dim=0)

    def register_linear(self, matrix, uvectlist, decode, bias=None):
        """Program ``matrix`` into the next free region and return a ``ticket`` for it.

        Entry (i, j) is mapped to the device pair at columns 2*(col+j) and
        2*(col+j)+1. ``bias`` is accepted but not used.
        """
        row, col = self.find_space(matrix.size(0), matrix.size(1))
        # Need to add checks for bias size and col size

        # Scale matrix so its largest magnitude maps to max(g_on) / 2
        mat_scale_factor = torch.max(torch.abs(matrix)) / torch.max(self.g_on) * 2
        scaled_matrix = matrix / mat_scale_factor

        midpoint = self.conductance_states.size(2) // 2
        for i in range(row, row + scaled_matrix.size(0)):
            for j in range(col, col + scaled_matrix.size(1)):

                shifted = self.conductance_states[i,j] - self.conductance_states[i,j,midpoint]
                idx = torch.min(torch.abs(shifted - scaled_matrix[i-row,j-col]), dim=0)[1]
                self.W[i,2*j+1] = self.conductance_states[i,j,idx]
                self.W[i,2*j] = self.conductance_states[i,j,midpoint-(idx-midpoint)]
        self.saved_tiles = {}

        return ticket(row, col, matrix.size(0), matrix.size(1), matrix, mat_scale_factor, self, uvectlist, decode, self.input_resolution, self.output_resolution)

    def which_tiles(self, row, col, m_row, m_col):
        """Return (tile_row, tile_col) index pairs covering the given region."""
        return itertools.product(range(row // self.tile_rows, (row + m_row) // self.tile_rows + 1),
                                 range(col // self.tile_cols,(col + m_col) // self.tile_cols + 1),
        )

    def find_space(self, m_row, m_col):
        """Place the next matrix diagonally after the previous one; return its (row, col)."""
        if not self.mapped:
            self.mapped.append((0,0,m_row,m_col))
        else:
            self.mapped.append((self.mapped[-1][0] + self.mapped[-1][2], self.mapped[-1][1] + self.mapped[-1][3], m_row, m_col))
        return self.mapped[-1][0], self.mapped[-1][1]

    def clear(self):
        """Forget all mappings and reset every device to its on conductance."""
        self.mapped = []
        self.W = torch.ones(self.size) * self.g_on
        self.saved_tiles = {}

    def conductance_update(self):
        """Rebuild the conductance states from the current g_on / g_off."""
        self.conductance_states = self._build_conductance_states()


## Device parameters

In [None]:
# Key idea: CODEX reduces the range the ADC has to sense, which is what allows
# a higher ADC input resolution to be used.
device_params = dict(
    Vdd=1.8,                 # supply voltage (crossbar.V)
    r_wl=10,                 # wordline segment resistance
    r_bl=10,                 # bitline segment resistance
    m=600,                   # crossbar rows
    n=600,                   # crossbar columns (two per logical weight)
    r_on_mean=1e4,           # on-state resistance mean (presumably ohms)
    r_on_stddev=1e3,         # on-state resistance std-dev
    r_off_mean=1e5,          # off-state resistance mean
    r_off_stddev=1e4,        # off-state resistance std-dev
    dac_resolution=5,        # input bits per VMM
    adc_resolution=8.3,      # fractional: used as 2**adc_resolution - 1 ADC levels
    device_resolution=8,     # 2**8 - 1 programmable conductance states
    bias_scheme=1/3,         # bias voltage as a fraction of Vdd
    tile_rows=4,             # circuit-solve tile height
    tile_cols=4,             # circuit-solve tile width
    r_cmos_line=600,
    r_cmos_transistor=20,    # not read by the crossbar class in this notebook
    p_stuck_on=0.01,         # probability a device is stuck on
    p_stuck_off=0.01,        # probability a device is stuck off
)

## Training + Testing

In [None]:
def network_tester(model, test_loader, test_size, epoch, log = True):
    """Evaluate classification accuracy on roughly the first ``test_size`` samples of a loader.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so this is safe to call mid-training.

    Args:
        model: classifier producing log-probabilities (the loss is ``F.nll_loss``).
        test_loader: iterable of (data, target) batches.
        test_size: stop once ``batch_idx * batch_len`` exceeds this.
        epoch: epoch number, used only in log output.
        log: if True, every 100th batch appends to 'log_baseline_test.csv' and prints progress.

    Returns:
        0-d tensor with the fraction of correctly classified samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    test_loss = 0
    try:
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                if batch_idx * len(data) > test_size:
                    break
                output = model(data)
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()
                total += target.size(0)
                loss = F.nll_loss(output, target)

                test_loss += loss.item()

                if batch_idx % 100 == 0 and log:
                    with open('log_baseline_test.csv', 'a') as f:
                        writer = csv.writer(f)
                        writer.writerow([batch_idx + test_size * 100, test_loss/(batch_idx+1), correct.item()/total])
                    print("Epoch", epoch, 'iteration',batch_idx, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                          % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    return torch.div(correct, float(total))

### Baseline

In [None]:
def net_trainer(model, train_dataloader, test_dataloader, config):
    """Baseline in-situ training: SGD step plus crossbar remap every ``batch_size`` mini-batches.

    Gradients accumulate over mini-batches (each loss is divided by
    ``batch_size``). Every ``batch_size``-th batch, the optimizer steps and
    ``model.fc1`` is remapped onto its crossbar.

    Args:
        model: network with a crossbar-backed ``fc1`` (attributes ``W`` and ``remap()``).
        train_dataloader, test_dataloader: iterables of (data, target) batches.
        config: dict with 'learning_rate', 'momentum', 'num_epochs',
            'test_interval', 'log_interval' and 'batch_size'.

    Returns:
        (loss_per_epoch, update_per_epoch, every_loss)
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'])

    torch.manual_seed(0)
    update_per_epoch = []
    loss_per_epoch = []
    every_loss = []

    for j in range(config['num_epochs']):
        # NOTE(review): this batch is never used. Drawing it may still advance the
        # sampler's RNG when shuffling, so it is kept to preserve batch order.
        data, target = next(iter(train_dataloader))
        optimizer.zero_grad()
        cur_epoch_updates = 0
        for batch_idx, (data, target) in enumerate(train_dataloader):

            if batch_idx % config['test_interval'] == 1:
                print(f"Test accuracy: {network_tester(model, test_dataloader, 100, j)}")
                # network_tester puts the model in eval mode; resume training mode.
                model.train()

            if batch_idx % config['log_interval'] == 1:
                print(f"Epoch: {j + 1}, Loss: {loss.data} Updates: {cur_epoch_updates}/{batch_idx}, Avg Grad: {cur_avg_abs_grad}")
                with open('log_baseline_train.csv', 'a') as f:
                    writer = csv.writer(f)
                    writer.writerow([batch_idx + j * config['log_interval'], loss.data.item(), network_tester(model, train_dataloader, 100, j, False).item()])
                model.train()

            output = model(data)
            loss = F.nll_loss(output, target) / config['batch_size']
            loss.backward()
            every_loss.append(loss.data)

            if batch_idx % config['batch_size'] == 0:
                cur_avg_abs_grad = model.fc1.W.grad.abs().mean()
                # TODO: only fc1 is remapped (hard-coded layer name).
                cur_epoch_updates += 1
                optimizer.step()
                model.fc1.remap()
                optimizer.zero_grad()

        update_per_epoch.append(cur_epoch_updates)
        loss_per_epoch.append(loss.data)
    return loss_per_epoch, update_per_epoch, every_loss

### Input modulation + thresholding

In [None]:
def net_trainer_thresh(model, train_dataloader, test_dataloader, config):
    """In-situ training where crossbar updates are gated by a gradient-magnitude threshold.

    Gradients accumulate over mini-batches. Every ``batch_size``-th batch, the
    mean |grad| of ``model.fc1.W`` is compared with a threshold. The optimizer
    steps (and ``fc1`` is remapped) only when the mean exceeds it, and the
    threshold is recomputed after each update.

    Args:
        model: network with a crossbar-backed ``fc1`` (attributes ``W`` and ``remap()``).
        train_dataloader, test_dataloader: iterables of (data, target) batches.
        config: dict with 'learning_rate', 'momentum', 'num_epochs', 'test_interval',
            'log_interval', 'batch_size', 'gamma' and 'max_thresh_multiplier'.

    Returns:
        (loss_per_epoch, update_per_epoch, every_loss, thr_per_batch)
    """
    model.train()
    # Detached running loss over the current accumulation window (plain float).
    avg_loss = 0.0
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'])
    threshold_update = False

    def local_sched(learning_rate, gamma, cur_avg_abs_grad, epoch_n):
        """Threshold g*(1 + exp(-lr*gamma*n)), capped at g*max_thresh_multiplier."""
        new_thr = cur_avg_abs_grad*(1+(math.exp(-learning_rate*gamma*epoch_n)))
        upper_bound = cur_avg_abs_grad * config['max_thresh_multiplier']

        return min(new_thr, upper_bound)

    torch.manual_seed(0)
    update_per_epoch = []
    loss_per_epoch = []
    every_loss = []
    thr_per_batch = []

    for j in range(config['num_epochs']):
        # Warm-up step on one batch to seed the threshold from a real gradient.
        data, target = next(iter(train_dataloader))
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Keep the crossbar in sync with the updated weights, as after every other step.
        model.fc1.remap()
        cur_avg_abs_grad = model.fc1.W.grad.abs().mean()
        init_thr = local_sched(config['learning_rate'], config['gamma'],
                               cur_avg_abs_grad, j+1)
        thr = init_thr
        # The warm-up gradient has already been applied; don't let it leak
        # into the first accumulation window.
        optimizer.zero_grad()
        cur_epoch_updates = 0
        for batch_idx, (data, target) in enumerate(train_dataloader):

            if batch_idx % config['test_interval'] == 1:
                print(f"Test accuracy: {network_tester(model, test_dataloader, 300, j)}")
                # network_tester puts the model in eval mode; resume training mode.
                model.train()

            if batch_idx % config['log_interval'] == 1:
                print(f"Epoch: {j + 1}, Loss: {loss.data} Updates: {cur_epoch_updates}/{batch_idx}, Avg Grad: {cur_avg_abs_grad}, Threshold: {thr}")
                with open('log_baseline_train.csv', 'a') as f:
                    writer = csv.writer(f)
                    writer.writerow([batch_idx + j * config['log_interval'], loss.data.item(), network_tester(model, train_dataloader, 100, j, False).item()])
                model.train()

            output = model(data)
            loss = F.nll_loss(output, target) / config['batch_size']
            loss.backward()
            every_loss.append(loss.data)

            # Accumulate a detached float. Adding tensors in place here would
            # mutate earlier loss tensors (and the views stored in every_loss).
            avg_loss += loss.item()
            # if avg_loss > config['naive_loss_thr']:
            #     thr = init_thr

            if batch_idx % config['batch_size'] == 0:
                cur_avg_abs_grad = model.fc1.W.grad.abs().mean()
                if threshold_update:
                    thr = local_sched(config['learning_rate'],
                                      config['gamma'], cur_avg_abs_grad, batch_idx+1)
                    threshold_update = False
                # TODO: only fc1 is thresholded/remapped (hard-coded layer name).
                if cur_avg_abs_grad > thr:
                    cur_epoch_updates += 1
                    optimizer.step()
                    model.fc1.remap()
                    optimizer.zero_grad()
                    threshold_update = True
                avg_loss = 0.0
                thr_per_batch.append(thr)

        update_per_epoch.append(cur_epoch_updates)
        loss_per_epoch.append(loss.data)
    return loss_per_epoch, update_per_epoch, every_loss, thr_per_batch

### Manhattan


In [None]:
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, required
from typing import List, Optional
import numpy as np

# Module-level scratch lists. Neither is used by executable code in the
# visible notebook cells. NOTE(review): confirm whether later cells still
# read or append to them before removing.
count_list = []
last_layer = []

class ManhattanSGD(Optimizer):
    r"""SGD variant that applies a sign-only ("Manhattan") update when momentum is enabled.

    With ``momentum != 0`` each parameter moves by a fixed step in the direction
    of ``-s``, where ``s = +1`` where the gradient is positive and ``-1``
    elsewhere (including where it is exactly zero):

    * first step: ``p -= lr * s``
    * later steps: ``p -= lr * momentum * (1 - dampening) * s``

    The previous buffer is not accumulated. With ``nesterov=True`` the step is
    ``grad + momentum * buffer`` instead. With ``momentum == 0`` a plain SGD step
    ``p -= lr * grad`` is taken. ``weight_decay`` adds ``weight_decay * p`` to
    the gradient without modifying ``p.grad``. ``maximize`` negates the gradient.
    ``foreach`` is accepted for signature compatibility with ``torch.optim.SGD``
    but ignored.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(ManhattanSGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)

    @torch.no_grad()
    def step(self, closure=None):
        r"""Perform one optimisation step using the Manhattan learning rule.

        \Delta W(i,j) = sgn(\Delta w(i,j))

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            maximize = group['maximize']
            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = -p.grad if maximize else p.grad

                if weight_decay != 0:
                    # Out of place, so p.grad itself is left untouched.
                    d_p = d_p.add(p, alpha=weight_decay)

                if momentum == 0:
                    # NOTE(review): no sign rule without momentum; this is plain SGD.
                    p.add_(d_p, alpha=-lr)
                    continue

                # +1 where the gradient is positive, -1 elsewhere (including zero).
                sign = (d_p > 0).to(d_p.dtype).mul_(2).sub_(1)
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = sign
                else:
                    buf = sign * (momentum * (1 - dampening))
                param_state['momentum_buffer'] = buf

                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf
                p.add_(d_p, alpha=-lr)

        return loss


In [None]:
def net_trainer_manh(model, train_dataloader, test_dataloader, config):
    """Train ``model`` in situ with ``ManhattanSGD`` using gradient accumulation.

    Every sample's NLL loss is divided by ``config['batch_size']`` and
    back-propagated; once ``batch_size`` samples have been accumulated, one
    optimizer step is taken, ``fc1.remap()`` is called and the gradients are
    cleared. Gradients of a trailing partial batch are discarded at the start
    of the next epoch.

    Args:
        model: network exposing an ``fc1`` crossbar layer (with ``W`` and
            ``remap()``), optionally wrapped in ``nn.DataParallel``.
        train_dataloader: iterable of ``(data, target)`` training pairs; also
            passed to ``network_tester`` for the train-accuracy log.
        test_dataloader: iterable passed to ``network_tester`` for test accuracy.
        config: dict with keys ``learning_rate``, ``momentum``, ``dampening``,
            ``num_epochs``, ``batch_size``, ``test_interval``, ``log_interval``.

    Returns:
        ``(loss_per_epoch, update_per_epoch, every_loss)``: the last
        (batch-size-scaled) loss of each epoch, the number of optimizer steps
        taken in each epoch, and the scaled loss of every sample processed.
    """
    model.train()
    # nn.DataParallel hides the wrapped network's submodules behind ``.module``.
    net = model.module if isinstance(model, nn.DataParallel) else model
    optimizer = ManhattanSGD(model.parameters(), lr=config['learning_rate'],
                             momentum=config['momentum'], dampening=config['dampening'])

    torch.manual_seed(0)
    batch_size = config['batch_size']
    update_per_epoch = []
    loss_per_epoch = []
    every_loss = []
    global_step = 0          # samples processed over all epochs (CSV x-axis)
    loss = None
    cur_avg_abs_grad = None  # mean |grad| of fc1.W at the most recent step

    for j in range(config['num_epochs']):
        # Discard any partial batch of gradients left from the previous epoch.
        optimizer.zero_grad()
        cur_epoch_updates = 0
        for batch_idx, (data, target) in enumerate(train_dataloader):

            if batch_idx % config['test_interval'] == 1:
                print(f"Test accuracy: {network_tester(model, test_dataloader, 400, j)}")
                # NOTE(review): network_tester may switch the model to eval mode;
                # re-assert training mode so BatchNorm keeps training behaviour.
                model.train()

            if batch_idx % config['log_interval'] == 1:
                print(f"Epoch: {j + 1}, Loss: {loss.data} Updates: {cur_epoch_updates}/{batch_idx}, Avg Grad: {cur_avg_abs_grad}")
                train_acc = network_tester(model, train_dataloader, 400, j, False)
                model.train()
                with open('log_baseline_train.csv', 'a', newline='') as f:
                    writer = csv.writer(f)
                    # global_step keeps the iteration column monotonic across epochs.
                    writer.writerow([global_step, loss.data.item(), train_acc.item()])

            output = model(data)
            loss = F.nll_loss(output, target) / batch_size
            loss.backward()
            every_loss.append(loss.data)
            global_step += 1

            # Step only after a full batch of ``batch_size`` samples has been
            # accumulated (``batch_idx + 1``), not after the very first sample.
            if (batch_idx + 1) % batch_size == 0:
                cur_avg_abs_grad = net.fc1.W.grad.abs().mean()
                cur_epoch_updates += 1
                optimizer.step()
                # Presumably re-programs the crossbar from the updated weights.
                net.fc1.remap()
                optimizer.zero_grad()

        update_per_epoch.append(cur_epoch_updates)
        loss_per_epoch.append(loss.data)
    return loss_per_epoch, update_per_epoch, every_loss

## ResNet18

In [None]:
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Two 3x3 conv/BatchNorm layers with a residual connection (ResNet-18/34 block).

    Output: ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    ``shortcut`` is the identity (an empty ``nn.Sequential``) unless ``stride``
    is not 1 or the channel count changes; then a strided 1x1 convolution
    followed by BatchNorm projects ``x`` to the output shape.

    Args:
        in_planes: channels of the input feature map.
        planes: channels produced by the block (times ``expansion``, i.e. 1).
        stride: stride of the first 3x3 convolution and of the projection.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        # Construction order fixes the RNG draws used for weight
        # initialisation and the state_dict key order; keep it stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + identity)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50/101/152 block).

    The first 1x1 convolution reduces to ``planes`` channels, the 3x3
    convolution applies ``stride``, and the last 1x1 convolution widens the
    output to ``expansion * planes`` channels. The residual ``shortcut`` is the
    identity unless the stride or channel count changes, in which case a
    strided 1x1 convolution + BatchNorm projects ``x``.

    Args:
        in_planes: channels of the input feature map.
        planes: bottleneck width; the output has ``4 * planes`` channels.
        stride: stride of the 3x3 convolution and of the projection.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        # Construction order fixes RNG draws and state_dict key order.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        y = x
        # The first two conv/BN stages are each followed by a ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = F.relu(bn(conv(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + identity)


class ResNet(nn.Module):
    """CIFAR-style ResNet whose classifier ``fc1`` is a crossbar-backed ``Linear``.

    The trunk is a 3x3 stem followed by four stages of residual blocks. The
    pooled feature vector is fed to ``fc1``, built on a ``crossbar`` created
    from the global ``device_params``.

    ``forward`` reshapes the pooled features into one ``(feature_dim, 1)``
    column vector, so exactly one image is processed per call (batch size 1).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output classes (output width of ``fc1``).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        # Width of the pooled feature vector: 512 for BasicBlock, 2048 for
        # Bottleneck (the last stage outputs 512 * expansion channels).
        self.feature_dim = 512 * block.expansion

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Ten random column vectors with entries in [-1, 1), handed to the
        # crossbar Linear layer. NOTE(review): their role and count are defined
        # by Linear/crossbar elsewhere — confirm.
        self.uvect = [2 * (torch.rand(self.feature_dim, 1) - 0.5) for _ in range(10)]
        crb1 = crossbar(device_params)
        # One crossbar per crossbar-backed linear layer is the simplest setup.
        self.fc1 = Linear(self.feature_dim, num_classes, crb1, self.uvect)
        self.fc1.use_cb(True)
        self.traincount = 0

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # For 32x32 inputs the last stage is 4x4, so this pools to 1x1.
        out = F.avg_pool2d(out, 4)
        # The crossbar layer consumes a single (features, 1) column vector;
        # this raises if more than one image is passed in.
        out = out.view(self.feature_dim, 1)
        out = self.fc1(out)
        # Transpose the column of class scores into a (1, num_classes) row.
        out = out.t()
        return F.log_softmax(out, dim=1)


def ResNet18():
    """Build ResNet-18: BasicBlock stages of depth 2, 2, 2, 2."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet34():
    """Build ResNet-34: BasicBlock stages of depth 3, 4, 6, 3."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])


def ResNet50():
    """Build ResNet-50: Bottleneck stages of depth 3, 4, 6, 3."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])


def ResNet101():
    """Build ResNet-101: Bottleneck stages of depth 3, 4, 23, 3."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])


def ResNet152():
    """Build ResNet-152: Bottleneck stages of depth 3, 8, 36, 3."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])

## Loading the CIFAR-10 Dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import csv

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy seen so far
start_epoch = 0  # first epoch to run (non-zero only when resuming from a checkpoint)

# ---- Data -------------------------------------------------------------------
print('==> Preparing data..')

# Test-time pipeline: tensor conversion plus per-channel normalisation.
# NOTE(review): (0.2023, 0.1994, 0.2010) are the widely copied CIFAR-10 "std"
# constants; the per-channel pixel std of the training set is about
# (0.247, 0.243, 0.261) — confirm which is intended.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Training pipeline: light augmentation, then the same steps as at test time.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *transform_test.transforms,
])

# Both loaders yield one image per batch (batch_size=1).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=1)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Training Process

In [None]:
# ---- Model ------------------------------------------------------------------
net = ResNet18().to(device)
if device == 'cuda':
    # Wrap for multi-GPU execution and let cuDNN autotune its kernels.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Hyper-parameters consumed by the trainer.
train_config = dict(
    num_epochs=35,
    batch_size=150,
    gamma=1,  # placeholder, overwritten with calc_gamma(...) below
    naive_loss_thr=2,
    learning_rate=0.002,
    log_interval=1000,
    momentum=0.9,
    max_thresh_multiplier=1.5,
    test_interval=10000,
)


# we know for this dataset the max n_epoch = 50,000
# (presumably the CIFAR-10 training-set size, i.e. iterations per epoch at batch_size=1)
def calc_gamma(lr, m_epoch):
    """Return gamma such that exp(-gamma * lr * m_epoch) == lr."""
    return -np.log(lr) / (lr * m_epoch)


train_config['gamma'] = calc_gamma(train_config['learning_rate'], 50000)

# NOTE(review): criterion/optimizer/scheduler are not passed to net_trainer
# below — confirm whether it reads these globals or they are leftovers.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Start fresh CSV logs (truncating earlier runs) containing only a header row.
for log_path, header in (
    ('log_baseline_train.csv', ["iteration", "train_loss", "train_acc"]),
    ('log_baseline_test.csv', ["iteration", "test_loss", "test_acc"]),
):
    with open(log_path, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(header)

c3f1_loss_per_epoch, c3f1_update_per_epoch, c3f1_every_loss, c3f1_thr_per_batch = net_trainer(
    net, trainloader, testloader, train_config)



Epoch 0 iteration 0 Loss: 7.162 | Acc: 0.000% (0/1)
Epoch 0 iteration 100 Loss: 5.991 | Acc: 7.921% (8/101)
Test accuracy: 0.07920791953802109
Epoch: 1, Loss: 0.02601216360926628 Updates: 1/1, Avg Grad: 0.0010897691827267408
Epoch: 1, Loss: 0.010968198999762535 Updates: 7/1001, Avg Grad: 0.004348933696746826
Epoch: 1, Loss: 0.042897649109363556 Updates: 14/2001, Avg Grad: 0.0069520012475550175
Epoch: 1, Loss: 0.03277191147208214 Updates: 21/3001, Avg Grad: 0.004150037653744221
Epoch: 1, Loss: 0.01962302252650261 Updates: 27/4001, Avg Grad: 0.0025889561511576176
Epoch: 1, Loss: 0.014845818281173706 Updates: 34/5001, Avg Grad: 0.003305623773485422
Epoch: 1, Loss: 0.014006637036800385 Updates: 41/6001, Avg Grad: 0.0018860778072848916
Epoch: 1, Loss: 0.012566613964736462 Updates: 47/7001, Avg Grad: 0.0025752042420208454
Epoch: 1, Loss: 0.01774468831717968 Updates: 54/8001, Avg Grad: 0.0026689129881560802
Epoch: 1, Loss: 0.016343511641025543 Updates: 61/9001, Avg Grad: 0.0016055116429924965

KeyboardInterrupt: ignored

In [None]:
# Evaluate the trained network on the first 500 test images, recording each
# image's output and NLL loss, then dump the losses and their running sum.
thresh_outputs, thresh_loss = [], []
# Inference mode: BatchNorm uses its running statistics instead of per-sample
# batch statistics, and the running statistics are not updated by test data.
net.eval()
# No autograd graphs are needed; storing outputs with graphs attached would
# keep every graph alive.
with torch.no_grad():
    # Iterate the loader once. Calling next(iter(testloader)) on every pass
    # restarts the unshuffled loader and returns the same first image each time.
    for inputs, targets in itertools.islice(testloader, 500):
        inputs, targets = inputs.to(device), targets.to(device)
        res = net(inputs)
        thresh_outputs.append(res[0])
        loss_curr = F.nll_loss(res, targets)
        thresh_loss.append(loss_curr.item())

with open("err_codex.txt", 'w') as writefile:
    writefile.write(str(thresh_loss))

err2 = np.cumsum(thresh_loss)
with open("err_codex_cumsum.txt", 'w') as writefile:
    writefile.write(str(err2))



In [None]:
# ---- Model ------------------------------------------------------------------
net = ResNet18().to(device)
if device == 'cuda':
    # Wrap for multi-GPU execution and let cuDNN autotune its kernels.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Hyper-parameters consumed by the thresholded trainer.
train_config = dict(
    num_epochs=35,
    batch_size=200,
    gamma=1,  # placeholder, overwritten with calc_gamma(...) below
    naive_loss_thr=2,
    learning_rate=0.002,
    log_interval=1000,
    momentum=0.9,
    max_thresh_multiplier=1.1,
    test_interval=10000,
)


# we know for this dataset the max n_epoch = 50,000
def calc_gamma(lr, m_epoch):
    """Return gamma such that exp(-gamma * lr * m_epoch) == lr."""
    return -np.log(lr) / (lr * m_epoch)


# NOTE(review): the comment above cites 50,000, but this call passes
# num_epochs (35), giving a gamma ~1400x larger — confirm which is intended.
train_config['gamma'] = calc_gamma(train_config['learning_rate'], train_config['num_epochs'])

# NOTE(review): criterion/optimizer/scheduler are not passed to
# net_trainer_thresh below — confirm whether they are used.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Start fresh CSV logs (truncating earlier runs) containing only a header row.
for log_path, header in (
    ('log_baseline_train.csv', ["iteration", "train_loss", "train_acc"]),
    ('log_baseline_test.csv', ["iteration", "test_loss", "test_acc"]),
):
    with open(log_path, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(header)

c3f1_loss_per_epoch, c3f1_update_per_epoch, c3f1_every_loss, c3f1_thr_per_batch = net_trainer_thresh(
    net, trainloader, testloader, train_config)



Epoch 0 iteration 0 Loss: 3.974 | Acc: 0.000% (0/1)
Epoch 0 iteration 100 Loss: 3.955 | Acc: 12.871% (13/101)
Epoch 0 iteration 200 Loss: 4.027 | Acc: 13.930% (28/201)
Epoch 0 iteration 300 Loss: 4.252 | Acc: 12.292% (37/301)
Test accuracy: 0.1229235902428627
Epoch: 1, Loss: 0.06031562760472298 Updates: 0/1, Avg Grad: 0.15655477344989777, Threshold: 0.17143213748931885
Epoch: 1, Loss: 0.011132776737213135 Updates: 1/1001, Avg Grad: 0.1719636619091034, Threshold: 0.17143213748931885
Epoch: 1, Loss: 0.07827382534742355 Updates: 3/2001, Avg Grad: 0.051185380667448044, Threshold: 0.051185380667448044
Epoch: 1, Loss: 0.1328832507133484 Updates: 6/3001, Avg Grad: 0.019957710057497025, Threshold: 0.011061353608965874
Epoch: 1, Loss: 0.011396464891731739 Updates: 8/4001, Avg Grad: 0.006320871412754059, Threshold: 0.006320871412754059
Epoch: 1, Loss: 0.00982842780649662 Updates: 11/5001, Avg Grad: 0.008499400690197945, Threshold: 0.005156924016773701
Epoch: 1, Loss: 0.012680484913289547 Updates

KeyboardInterrupt: ignored

In [None]:
# Show the crossbar's programmed weights. On CUDA, ``net`` is an
# nn.DataParallel wrapper that exposes the real model as ``.module``.
print(getattr(net, 'module', net).fc1.cb.W)

tensor([[5.2074e-05, 5.2074e-05, 4.9672e-05,  ..., 9.7597e-05, 9.5624e-05,
         9.8300e-05],
        [5.2284e-05, 5.2284e-05, 5.2047e-05,  ..., 9.8921e-05, 8.7966e-05,
         9.0470e-05],
        [5.1894e-05, 5.4662e-05, 5.6103e-05,  ..., 1.1025e-04, 1.0725e-04,
         9.4334e-05],
        ...,
        [1.1123e-04, 9.3991e-05, 9.8402e-05,  ..., 1.0641e-04, 8.2949e-05,
         9.7641e-05],
        [1.1752e-04, 1.0429e-04, 9.8852e-05,  ..., 8.2507e-05, 9.0317e-05,
         8.9049e-05],
        [1.1125e-04, 9.5968e-05, 1.1941e-04,  ..., 9.7233e-05, 1.0448e-04,
         1.0215e-04]])


In [None]:
# ---- Model ------------------------------------------------------------------
net = ResNet18().to(device)
if device == 'cuda':
    # Wrap for multi-GPU execution and let cuDNN autotune its kernels.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Hyper-parameters consumed by net_trainer_manh (which builds its own
# ManhattanSGD optimizer from learning_rate / momentum / dampening).
train_config = dict(
    num_epochs=20,
    batch_size=800,
    gamma=1,  # placeholder, overwritten with calc_gamma(...) below
    naive_loss_thr=2,
    learning_rate=0.0002,  # previously 0.002
    log_interval=1000,
    momentum=0.9,
    max_thresh_multiplier=2,
    test_interval=10000,
    dampening=0.1,
)


# we know for this dataset the max n_epoch = 50,000
def calc_gamma(lr, m_epoch):
    """Return gamma such that exp(-gamma * lr * m_epoch) == lr."""
    return -np.log(lr) / (lr * m_epoch)


train_config['gamma'] = calc_gamma(train_config['learning_rate'], 50000)

# Start fresh CSV logs (truncating earlier runs) containing only a header row.
for log_path, header in (
    ('log_baseline_train.csv', ["iteration", "train_loss", "train_acc"]),
    ('log_baseline_test.csv', ["iteration", "test_loss", "test_acc"]),
):
    with open(log_path, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(header)

c3f1_loss_per_epoch, c3f1_update_per_epoch, c3f1_every_loss = net_trainer_manh(
    net, trainloader, testloader, train_config)



Epoch 0 iteration 0 Loss: 6.252 | Acc: 0.000% (0/1)
Epoch 0 iteration 100 Loss: 5.950 | Acc: 10.891% (11/101)
Epoch 0 iteration 200 Loss: 5.982 | Acc: 8.955% (18/201)
Epoch 0 iteration 300 Loss: 6.140 | Acc: 8.970% (27/301)
Epoch 0 iteration 400 Loss: 6.115 | Acc: 8.978% (36/401)
Test accuracy: 0.08977556228637695
Epoch: 1, Loss: 0.002918248064815998 Updates: 1/1, Avg Grad: 0.0001863234501797706
Epoch: 1, Loss: 0.0011225983034819365 Updates: 2/1001, Avg Grad: 0.0039987629279494286
Epoch: 1, Loss: 0.0012948412913829088 Updates: 3/2001, Avg Grad: 0.0043893177062273026
Epoch: 1, Loss: 0.005354406777769327 Updates: 4/3001, Avg Grad: 0.007110339589416981
Epoch: 1, Loss: 0.01574348285794258 Updates: 6/4001, Avg Grad: 0.03171857073903084
Epoch: 1, Loss: 0.0023952170740813017 Updates: 7/5001, Avg Grad: 0.004082971252501011
Epoch: 1, Loss: 0.002471892163157463 Updates: 8/6001, Avg Grad: 0.023377755656838417
Epoch: 1, Loss: 0.0008532585925422609 Updates: 9/7001, Avg Grad: 0.0041057695634663105
E

KeyboardInterrupt: ignored