In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from random import randint
import time

In [25]:
data_path='dataset/tensorized_data/'

train_data=torch.load(data_path+'training_data_bw.pt').squeeze()
train_label=torch.load(data_path+'training_labels_bw.pt').squeeze()
test_data=torch.load(data_path+'testing_data_bw.pt').squeeze()
test_label=torch.load(data_path+'testing_labels_bw.pt').squeeze()

In [26]:
print(train_data.size(), train_label.size())
print(test_data.size(), test_label.size())

torch.Size([3861, 28, 28]) torch.Size([3861])
torch.Size([1644, 28, 28]) torch.Size([1644])


In [27]:
class one_layer_net(nn.Module):

    def __init__(self, input_size, output_size):
        super(one_layer_net , self).__init__()
        self.linear_layer = nn.Linear( input_size, output_size , bias=False)
        
    def forward(self, x):
        scores = self.linear_layer(x)
        return scores

In [28]:
net=one_layer_net(784,2)
print(net)
criterion = nn.CrossEntropyLoss()
bs=20

one_layer_net(
  (linear_layer): Linear(in_features=784, out_features=2, bias=False)
)


In [29]:
def get_error( scores , labels ):

    bs=scores.size(0)
    predicted_labels = scores.argmax(dim=1)
    indicator = (predicted_labels == labels)
    num_matches=indicator.sum()
    
    return 1-num_matches.float()/bs  

In [30]:
start = time.time()

lr = 0.5 # initial learning rate

for epoch in range(100):
    
    optimizer=torch.optim.SGD( net.parameters() , lr=lr )
        
    running_loss=0
    running_error=0
    num_batches=0
    
    shuffled_indices=torch.randperm(3860)
 
    for count in range(0,3860,bs):
    
        optimizer.zero_grad()
        
        indices=shuffled_indices[count:count+bs]
        minibatch_data =  train_data[indices]
        minibatch_label= train_label[indices]

        inputs = minibatch_data.view(bs,784)

        inputs.requires_grad_()

        scores=net( inputs ) 

        loss =  criterion( scores , minibatch_label) 
        
        loss.backward()

        optimizer.step()
        
        
        running_loss += loss.detach().item()
               
        error = get_error( scores.detach() , minibatch_label)
        running_error += error.item()
        
        num_batches+=1
    

    total_loss = running_loss/num_batches
    total_error = running_error/num_batches
    elapsed_time = time.time() - start
    
    print(' ')

    print('epoch=',epoch, ' time=', elapsed_time,
          ' loss=', total_loss , ' error=', total_error*100 ,'percent lr=', lr)
        

 
epoch= 0  time= 0.06846976280212402  loss= 14.426535283654465  error= 46.68393768177131 percent lr= 0.5
 
epoch= 1  time= 0.1715714931488037  loss= 15.508302903545953  error= 49.0673576611929 percent lr= 0.5
 
epoch= 2  time= 0.3342597484588623  loss= 13.469617922071349  error= 46.16580305939511 percent lr= 0.5
 
epoch= 3  time= 0.4438323974609375  loss= 14.191947457703902  error= 46.91709842088926 percent lr= 0.5
 
epoch= 4  time= 0.5743160247802734  loss= 14.515047986581536  error= 47.305699457158696 percent lr= 0.5
 
epoch= 5  time= 0.6535308361053467  loss= 14.168422817566235  error= 46.78756483478249 percent lr= 0.5
 
epoch= 6  time= 0.7303924560546875  loss= 12.811581693165042  error= 45.1036270109483 percent lr= 0.5
 
epoch= 7  time= 0.8344266414642334  loss= 12.980259151656393  error= 44.76683936588505 percent lr= 0.5
 
epoch= 8  time= 0.9382438659667969  loss= 12.629431607500877  error= 44.792746262229166 percent lr= 0.5
 
epoch= 9  time= 1.110175371170044  loss= 13.39456973

epoch= 77  time= 8.485054969787598  loss= 10.296515761261777  error= 40.36269441169778 percent lr= 0.5
 
epoch= 78  time= 8.677163362503052  loss= 9.982423386734384  error= 41.01036277459693 percent lr= 0.5
 
epoch= 79  time= 8.777518272399902  loss= 11.024846993579766  error= 42.694300691080834 percent lr= 0.5
 
epoch= 80  time= 8.85106086730957  loss= 11.11388525691057  error= 42.38341957176288 percent lr= 0.5
 
epoch= 81  time= 8.924722909927368  loss= 11.15911765234458  error= 42.668393732970245 percent lr= 0.5
 
epoch= 82  time= 8.998829126358032  loss= 11.030391127952022  error= 42.150258956177865 percent lr= 0.5
 
epoch= 83  time= 9.075461864471436  loss= 11.310380237398988  error= 42.797927504376425 percent lr= 0.5
 
epoch= 84  time= 9.149147272109985  loss= 10.911327360207553  error= 42.53886033216288 percent lr= 0.5
 
epoch= 85  time= 9.25443410873413  loss= 10.157785226952845  error= 40.155440538040715 percent lr= 0.5
 
epoch= 86  time= 9.413625478744507  loss= 10.4938423318

In [32]:
net.eval()
with torch.no_grad():
    data = test_data.view(1644,784)
    labels = test_label
    scores = net(data) 
    print("Test error={} percent".format(get_error(scores,labels)*100))

Test error=41.970802307128906 percent
