In [1]:
import os
import torch 
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

from PIL import Image
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt 

from tqdm import tqdm


device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
class Config(object):
    def __init__(self):
        self.name = 'Autoencoder'
        self.dataset_name = 'MNIST'
        
        self.save_path = './checkpoint/'+ self.name
        self.model_path = self.save_path + '/models'
        self.decode_path = self.save_path + '/decoded_results'
        self.val_path = self.save_path + '/val_results'
        self.test_path = self.save_path + '/test_results'
        
    
        self.batch_size = 64
        self.max_epochs = 100
        self.lr = 1e-3 
        self.weight_decay = 1e-4
        self.save_every = 1     #epoch

        
    
        os.makedirs(self.save_path, exist_ok=True)
        os.makedirs(self.model_path, exist_ok=True)
        os.makedirs(self.decode_path, exist_ok=True)
        os.makedirs(self.val_path, exist_ok=True)
        os.makedirs(self.test_path, exist_ok=True)
        
opt = Config()

In [3]:
def get_train_valid_loader(batch_size,
                           random_seed,
                           valid_size=0.2,
                           shuffle=True,
                           num_workers=1,
                           pin_memory=True):
    
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transforms
    transform = transforms.Compose([transforms.ToTensor(),normalize])
    
    
    # load the dataset
    train_dataset = torchvision.datasets.MNIST(root='./dataset/MNIST',train=True, download=True,transform=transform)

    valid_dataset = torchvision.datasets.MNIST(root='./dataset/MNIST',train=True, download=True,transform=transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle == True:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=opt.batch_size, sampler=train_sampler, 
                    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=opt.batch_size, sampler=valid_sampler, 
                    num_workers=num_workers, pin_memory=pin_memory)


    
    return (train_loader, valid_loader)
    
def get_test_loader(batch_size,
                    shuffle=True,
                    num_workers=1,
                    pin_memory=True):
    
    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transform
    transform = transforms.Compose([transforms.ToTensor(),normalize])

    dataset = torchvision.datasets.MNIST(root='./dataset/MNIST',train=False,download=True,transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset,batch_size=opt.batch_size, shuffle=shuffle, 
                                              num_workers=num_workers,pin_memory=pin_memory)

    return data_loader,dataset

In [4]:
train_loader,val_loader = get_train_valid_loader(opt.batch_size,random_seed=24)
test_loader,dataset = get_test_loader(opt.batch_size)

In [5]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        #Encoder
        self.enc1 = nn.Linear(in_features=784, out_features=256) # Input image (28*28 = 784)
        self.enc2 = nn.Linear(in_features=256, out_features=128)
        self.enc3 = nn.Linear(in_features=128, out_features=64)
        

        #Decoder 
        self.dec1 = nn.Linear(in_features=64, out_features=128)
        self.dec2 = nn.Linear(in_features=128, out_features=256)
        self.dec3 = nn.Linear(in_features=256, out_features=784) # Output image (28*28 = 784)

    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))

        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))

        return x

In [6]:
def train_autoencoder(**kwargs):
    torch.cuda.empty_cache()
    opt = Config()
    
    autoencoder = Autoencoder()
    autoencoder=autoencoder.to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    criterion = nn.MSELoss()
    
    
    try:
        state = torch.load(os.path.join(opt.model_path, 'autoencoder.pth'))
        autoencoder.load_state_dict(state['state_dict'])
        print("Loaded pre-trained autoencoder with success.")
        e_counter=state['epoch']
        best_valid_loss = state['valid_loss_min']
        print('val loss till yet :',best_valid_loss)
        print('Previously Trained for {} epoches'.format(e_counter))
        e_counter += 1
    except FileNotFoundError:
        print("Pre-trained weights of integrated autoencoder not found. Training from scratch.")
        e_counter=0
        best_valid_loss = float('inf')
        
    train_loss = []
    val_loss=[]
    epoches=[]
    
    for epoch in range(e_counter,opt.max_epochs):
        autoencoder.train()
        print()
        print('==================================================================')
        print('-------------Epoch: {}/{}------------'.format(epoch,opt.max_epochs))
        epoch_loss=0.0
        for idx,(data,_) in enumerate(train_loader,0):
            
            data = data.to(device)
            optimizer.zero_grad()
            
            data = data.view(data.size(0),-1)
            output = autoencoder(data)
            
            loss = criterion(output, data)
            loss.backward()
            optimizer.step()
            
            epoch_loss+=loss.item()
            
            if (idx+1)%200 == 0:
                print('Batch ',idx,'/',len(train_loader),'|| Loss: %.4f '%(epoch_loss/(idx+1)))

                               
        mean_loss = epoch_loss / len(train_loader)
        train_loss.append(mean_loss)
        with open('auto_logs.txt', 'a') as file:
            file.write('epoch: ' + str(epoch) + ', train loss: ' + str(mean_loss) + '\n')
        
    
        print()    
        print('...........................validation....................................')
        autoencoder.eval()
        eval_loss = 0.0

        with torch.no_grad():
            for idx,(data,_) in enumerate(val_loader):
                data = data.to(device)
                data = data.view(data.size(0),-1)
                output = autoencoder.forward(data)
                v_loss = criterion(output, data)
                eval_loss += v_loss.item()
                
                if (idx+1)%90 == 0:
                    print('Batch ',idx,'/',len(val_loader),'|| Loss: %.4f'%(eval_loss/(idx+1)))
                
        valid_loss = eval_loss/len(val_loader)
        val_loss.append(valid_loss)
        epoches.append(epoch)
        print('Train Loss: %.4f || Valid Loss: %.4f '%(mean_loss,valid_loss))
        with open('auto_logs.txt', 'a') as file:
            file.write('epoch: '+ str(epoch)+', validation loss: '+str(valid_loss)+'\n')
        
        state = {
                'epoch': epoch,
                'valid_loss_min': valid_loss,
                'state_dict': autoencoder.state_dict(),
                'optimizer': optimizer.state_dict()
                }
        
        if epoch % opt.save_every == 0 or epoch == opt.max_epochs - 1:
            if valid_loss < best_valid_loss:
                print('Validation loss decreased ({:.4f} --> {:.4f}). Saving autoencoder ...'.format(best_valid_loss,valid_loss))
                torch.save(state, os.path.join(opt.model_path, 'autoencoder.pth'))
                filename = 'fake_%04d.png' %(epoch)
                val_path = os.path.join(opt.val_path, filename)
                output = output.view(output.size(0), 1, 28, 28)
                torchvision.utils.save_image(output.cpu().data, val_path)
                best_valid_loss = valid_loss
                
                
    filepath=os.path.join(opt.save_path, 'autoencoder_losses.png')
    plt.title("Training Curve")
    plt.plot(epoches, train_loss, label="Train")
    plt.plot(epoches, val_loss, label="Validation")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.savefig(filepath)



In [None]:
#train and validate the autoencoder
train_autoencoder()

In [7]:
#classifier Network
# LeNet Model definition

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x



In [8]:
def train_classifier(**kwargs):
    torch.cuda.empty_cache()
    opt = Config()
    print('loading the model...')  
    model=Net()
    model.to(device)
    autoencoder=Autoencoder().to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr = opt.lr,amsgrad=True,weight_decay=5e-4)    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    #loading autoencoder weights
    try:
        state = torch.load(os.path.join(opt.model_path,'autoencoder.pth'))
        autoencoder.load_state_dict(state['state_dict'])
        print("Loaded pre-trained autoencoder with success.")
    except:
        print("Pre-trained autoencoder not found. Train autoencoder first")
        train_autoencoder()
        
        
    
    try:
        state = torch.load(os.path.join(opt.model_path,'classifier.pth'))
        model.load_state_dict(state['state_dict'])
        print("Loaded pre-trained classifier with success.")
        e_counter=state['epoch']
        best_valid_loss = state['valid_loss_min']
        best_valid_acc = state['valid_acc_max']
        print('Previously Trained for {} epoches'.format(e_counter))
        print('Best Validation loss till yet :',best_valid_loss)
        print('Best val Accuracy till yet :',best_valid_acc,'%')
        e_counter+=1
    except FileNotFoundError:
        print("Pre-trained classifier not found. Training from scratch.")
        e_counter=0
        best_valid_loss = float('inf')
        prev_loss=float('inf')
        best_valid_acc = 0.00


    train_loss = []
    train_accuracy = []
    val_loss=[]
    val_accuracy=[]
    epoches=[]
    
    for epoch in range(e_counter,opt.max_epochs):
        model.train()
        print()
        print('==================================================================')
        print('-------------Epoch: {}/{}------------'.format(epoch,opt.max_epochs))
        epoch_loss=0.0
        correct = 0
        total = 0
        train_acc = 0.0
        for idx,batch in enumerate(train_loader):
            image,label=batch
            image,label=image.to(device),label.to(device)
            label=label.squeeze()
            
            optimizer.zero_grad()
            
            image = image.view(image.size(0),-1)
            dec_out = autoencoder(image)
            dec_out = dec_out.view(dec_out.size(0),1,28,28)
            output = model(dec_out)
            
            loss=criterion(output, label)
            loss.backward()
            optimizer.step()            
            epoch_loss+=loss.item()
            
            _,predicted = torch.max(output.data, 1)
            total += label.size(0)
            correct += predicted.eq(label).sum().item()
            train_acc+=100.*correct/total
            
            if idx%200 == 0:
                print('Batch ',idx,'/',len(train_loader),'||Loss: %.5f ||Acc: %.3f%% '%((epoch_loss/(idx+1)),(train_acc/(idx+1))))
        
        
                
        mean_loss = epoch_loss / len(train_loader)
        train_loss.append(mean_loss)

        mean_acc=train_acc/len(train_loader)
        train_accuracy.append(mean_acc)

        print('Train Loss: %.4f || Train Acc: %.3f%%'%(mean_loss,mean_acc)) 
        
        with open('./checkpoint/classifier_train_logs.txt', 'a') as file:
            file.write('epoch: ' + str(epoch) + ',loss: ' + str(mean_loss)+ ',acc: ' + str(mean_acc) +'\n')
            
        print()    
        print('...........................validation....................................')
          
        model.eval()
        eval_loss = 0.0
        val_acc = 0.0
        v_correct = 0
        v_total = 0
        with torch.no_grad():
            for i, (image,label) in enumerate(val_loader):
                image, label = image.to(device),label.to(device)
                label=label.squeeze()
                
                image = image.view(image.size(0),-1)
                dec_out = autoencoder(image)
                dec_out = dec_out.view(dec_out.size(0),1,28,28)
                output = model(dec_out)
                
                
                loss=criterion(output, label)                
                eval_loss += loss.item()
                
                _, predicted = torch.max(output.data, 1)
                v_total += label.size(0)
                v_correct += predicted.eq(label).sum().item()
                val_acc+=100.*v_correct/v_total
                
                if i%100 == 0:
                    print('Batch ',i,'/',len(val_loader), '||Loss: %.5f ||Acc: %.3f%%'%((eval_loss/(i+1)),(val_acc/(i+1))))
                
        valid_loss = eval_loss/len(val_loader)
        valid_acc=val_acc/len(val_loader)
        print('val Loss: %.4f || val Acc: %.3f%%'%(valid_loss,valid_acc))
        val_loss.append(valid_loss)
        val_accuracy.append(valid_acc)
        epoches.append(epoch)

        
        state = {
                'epoch': epoch,
                'valid_loss_min': valid_loss,
                'valid_acc_max': valid_acc,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
                }
        
        if epoch % opt.save_every == 0 or epoch == opt.max_epochs - 1:
            if valid_acc > best_valid_acc:
                print('Validation acc increased ({:.5f} --> {:.5f}). Saving model ...'.format(best_valid_acc,valid_acc))
                torch.save(state, os.path.join(opt.model_path, 'classifier.pth'))  
                
                with open('./checkpoint/classifer_val_logs.txt', 'a') as file:
                    file.write('epoch: ' + str(epoch) + ',loss: ' + str(valid_loss) + ',acc: ' + str(valid_acc) + '\n')
                
                best_valid_acc = valid_acc
                
    filepath=os.path.join(opt.save_path, 'classifier_losses.png')
    plt.title("Training Curve")
    plt.plot(epoches, train_loss, label="Train")
    plt.plot(epoches, val_loss, label="Validation")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.savefig(filepath)
    
    
#     filepath=os.path.join(opt.save_path, 'classifier_accuracy.png')
#     plt.title("Training Curve")
#     plt.plot(epoches, train_acc, label="Train")
#     plt.plot(epoches, val_acc, label="Validation")
#     plt.xlabel("Epochs")
#     plt.ylabel("ACC")
#     plt.legend(loc='best')
#     plt.savefig(filepath)


            


In [9]:
train_classifier()

loading the model...
Loaded pre-trained autoencoder with success.
Pre-trained classifier not found. Training from scratch.

-------------Epoch: 0/100------------
Batch  0 / 750 ||Loss: 2.31255 ||Acc: 6.250% 
Batch  200 / 750 ||Loss: 1.03651 ||Acc: 47.735% 
Batch  400 / 750 ||Loss: 0.75962 ||Acc: 59.661% 
Batch  600 / 750 ||Loss: 0.63766 ||Acc: 65.729% 
Train Loss: 0.5831 || Train Acc: 68.736%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.04635 ||Acc: 100.000%
Batch  100 / 188 ||Loss: 0.15290 ||Acc: 95.654%
val Loss: 0.1558 || val Acc: 95.532%
Validation acc increased (0.00000 --> 95.53189). Saving model ...

-------------Epoch: 1/100------------
Batch  0 / 750 ||Loss: 0.27904 ||Acc: 92.188% 
Batch  200 / 750 ||Loss: 0.33624 ||Acc: 89.757% 
Batch  400 / 750 ||Loss: 0.31629 ||Acc: 90.081% 
Batch  600 / 750 ||Loss: 0.30779 ||Acc: 90.358% 
Train Loss: 0.3004 || Train Acc: 90.518%

...........................validation...................

Batch  0 / 750 ||Loss: 0.16662 ||Acc: 93.750% 
Batch  200 / 750 ||Loss: 0.15405 ||Acc: 95.563% 
Batch  400 / 750 ||Loss: 0.15259 ||Acc: 95.514% 
Batch  600 / 750 ||Loss: 0.15537 ||Acc: 95.485% 
Train Loss: 0.1575 || Train Acc: 95.463%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.05687 ||Acc: 96.875%
Batch  100 / 188 ||Loss: 0.08906 ||Acc: 97.326%
val Loss: 0.0750 || val Acc: 97.534%

-------------Epoch: 15/100------------
Batch  0 / 750 ||Loss: 0.06389 ||Acc: 98.438% 
Batch  200 / 750 ||Loss: 0.15697 ||Acc: 95.361% 
Batch  400 / 750 ||Loss: 0.15439 ||Acc: 95.313% 
Batch  600 / 750 ||Loss: 0.15263 ||Acc: 95.329% 
Train Loss: 0.1515 || Train Acc: 95.359%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.06943 ||Acc: 96.875%
Batch  100 / 188 ||Loss: 0.07609 ||Acc: 97.628%
val Loss: 0.0732 || val Acc: 97.720%

-------------Epoch: 16/100------------
Batch  0 / 750 ||Loss: 0.33666 ||Acc: 

val Loss: 0.0661 || val Acc: 98.069%

-------------Epoch: 29/100------------
Batch  0 / 750 ||Loss: 0.09585 ||Acc: 96.875% 
Batch  200 / 750 ||Loss: 0.13879 ||Acc: 95.662% 
Batch  400 / 750 ||Loss: 0.13824 ||Acc: 95.810% 
Batch  600 / 750 ||Loss: 0.13718 ||Acc: 95.871% 
Train Loss: 0.1358 || Train Acc: 95.917%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.16117 ||Acc: 95.312%
Batch  100 / 188 ||Loss: 0.06977 ||Acc: 97.751%
val Loss: 0.0708 || val Acc: 97.865%

-------------Epoch: 30/100------------
Batch  0 / 750 ||Loss: 0.13289 ||Acc: 95.312% 
Batch  200 / 750 ||Loss: 0.12222 ||Acc: 95.989% 
Batch  400 / 750 ||Loss: 0.13021 ||Acc: 96.094% 
Batch  600 / 750 ||Loss: 0.13169 ||Acc: 96.108% 
Train Loss: 0.1329 || Train Acc: 96.111%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.08938 ||Acc: 96.875%
Batch  100 / 188 ||Loss: 0.06405 ||Acc: 98.031%
val Loss: 0.0688 || val Acc: 98.005%



Batch  0 / 188 ||Loss: 0.01208 ||Acc: 100.000%
Batch  100 / 188 ||Loss: 0.06916 ||Acc: 98.093%
val Loss: 0.0674 || val Acc: 98.093%

-------------Epoch: 44/100------------
Batch  0 / 750 ||Loss: 0.12780 ||Acc: 93.750% 
Batch  200 / 750 ||Loss: 0.12034 ||Acc: 96.196% 
Batch  400 / 750 ||Loss: 0.12505 ||Acc: 96.317% 
Batch  600 / 750 ||Loss: 0.12601 ||Acc: 96.322% 
Train Loss: 0.1265 || Train Acc: 96.319%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.03478 ||Acc: 98.438%
Batch  100 / 188 ||Loss: 0.06144 ||Acc: 97.827%
val Loss: 0.0676 || val Acc: 97.882%

-------------Epoch: 45/100------------
Batch  0 / 750 ||Loss: 0.05935 ||Acc: 98.438% 
Batch  200 / 750 ||Loss: 0.12105 ||Acc: 96.364% 
Batch  400 / 750 ||Loss: 0.12341 ||Acc: 96.314% 
Batch  600 / 750 ||Loss: 0.12474 ||Acc: 96.300% 
Train Loss: 0.1269 || Train Acc: 96.292%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.00797 ||Acc:

Batch  0 / 188 ||Loss: 0.00728 ||Acc: 100.000%
Batch  100 / 188 ||Loss: 0.05457 ||Acc: 98.367%
val Loss: 0.0656 || val Acc: 98.272%
Validation acc increased (98.26424 --> 98.27241). Saving model ...

-------------Epoch: 59/100------------
Batch  0 / 750 ||Loss: 0.13036 ||Acc: 96.875% 
Batch  200 / 750 ||Loss: 0.11703 ||Acc: 96.678% 
Batch  400 / 750 ||Loss: 0.12015 ||Acc: 96.554% 
Batch  600 / 750 ||Loss: 0.11977 ||Acc: 96.493% 
Train Loss: 0.1199 || Train Acc: 96.466%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.05247 ||Acc: 98.438%
Batch  100 / 188 ||Loss: 0.06598 ||Acc: 98.225%
val Loss: 0.0660 || val Acc: 98.190%

-------------Epoch: 60/100------------
Batch  0 / 750 ||Loss: 0.18309 ||Acc: 96.875% 
Batch  200 / 750 ||Loss: 0.11970 ||Acc: 96.948% 
Batch  400 / 750 ||Loss: 0.12157 ||Acc: 96.737% 
Batch  600 / 750 ||Loss: 0.12069 ||Acc: 96.673% 
Train Loss: 0.1212 || Train Acc: 96.645%

...........................validation.......

Batch  0 / 188 ||Loss: 0.02576 ||Acc: 98.438%
Batch  100 / 188 ||Loss: 0.06191 ||Acc: 98.207%
val Loss: 0.0646 || val Acc: 98.186%

-------------Epoch: 74/100------------
Batch  0 / 750 ||Loss: 0.02365 ||Acc: 100.000% 
Batch  200 / 750 ||Loss: 0.11985 ||Acc: 96.531% 
Batch  400 / 750 ||Loss: 0.11577 ||Acc: 96.467% 
Batch  600 / 750 ||Loss: 0.11367 ||Acc: 96.505% 
Train Loss: 0.1146 || Train Acc: 96.520%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.05376 ||Acc: 98.438%
Batch  100 / 188 ||Loss: 0.06052 ||Acc: 98.311%
val Loss: 0.0609 || val Acc: 98.294%
Validation acc increased (98.27241 --> 98.29439). Saving model ...

-------------Epoch: 75/100------------
Batch  0 / 750 ||Loss: 0.13654 ||Acc: 96.875% 
Batch  200 / 750 ||Loss: 0.11005 ||Acc: 96.487% 
Batch  400 / 750 ||Loss: 0.11912 ||Acc: 96.497% 
Batch  600 / 750 ||Loss: 0.11857 ||Acc: 96.466% 
Train Loss: 0.1173 || Train Acc: 96.466%

...........................validation.......

Batch  600 / 750 ||Loss: 0.11320 ||Acc: 96.627% 
Train Loss: 0.1157 || Train Acc: 96.629%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.27348 ||Acc: 95.312%
Batch  100 / 188 ||Loss: 0.06179 ||Acc: 98.207%
val Loss: 0.0623 || val Acc: 98.231%

-------------Epoch: 89/100------------
Batch  0 / 750 ||Loss: 0.03679 ||Acc: 98.438% 
Batch  200 / 750 ||Loss: 0.11449 ||Acc: 96.382% 
Batch  400 / 750 ||Loss: 0.11948 ||Acc: 96.392% 
Batch  600 / 750 ||Loss: 0.11757 ||Acc: 96.425% 
Train Loss: 0.1173 || Train Acc: 96.444%

...........................validation....................................
Batch  0 / 188 ||Loss: 0.08092 ||Acc: 95.312%
Batch  100 / 188 ||Loss: 0.06104 ||Acc: 98.250%
val Loss: 0.0618 || val Acc: 98.248%

-------------Epoch: 90/100------------
Batch  0 / 750 ||Loss: 0.02936 ||Acc: 100.000% 
Batch  200 / 750 ||Loss: 0.10697 ||Acc: 97.177% 
Batch  400 / 750 ||Loss: 0.11113 ||Acc: 96.966% 
Batch  600 / 750 ||Loss: 0.11331 ||Ac

ValueError: x and y must have same first dimension, but have shapes (100,) and (1,)

In [10]:
def test_classifier(**kwargs):
    
    model=Net()
    model.to(device)
    autoencoder=Autoencoder().to(device)
    
    state = torch.load(os.path.join(opt.model_path,'autoencoder.pth'))
    autoencoder.load_state_dict(state['state_dict'])
    print("Loaded pre-trained autoencoder with success.")
    
    state = torch.load(os.path.join(opt.model_path,'classifier.pth'))
    model.load_state_dict(state['state_dict'])
    print("Loaded pre-trained classifier with success.")

    
    model.eval()
    test_acc = 0.0
    v_correct = 0
    v_total = 0
    with torch.no_grad():
        for (image,label) in tqdm(test_loader):
            image, label = image.to(device),label.to(device)
            label=label.squeeze()

            image = image.view(image.size(0),-1)
            dec_out = autoencoder(image)
            dec_out = dec_out.view(dec_out.size(0),1,28,28)
            output = model(dec_out)


            _, predicted = torch.max(output.data, 1)
            v_total += label.size(0)
            v_correct += predicted.eq(label).sum().item()
            test_acc +=100.*v_correct/v_total
        accuracy = test_acc/len(test_loader)
    print('Test Acc: %.3f%%'%(accuracy))

    

In [11]:
test_classifier()

  0%|          | 0/157 [00:00<?, ?it/s]

Loaded pre-trained autoencoder with success.
Loaded pre-trained classifier with success.


100%|██████████| 157/157 [00:02<00:00, 60.03it/s]

Test Acc: 98.584%



