## A training ground for sparse-coded autoencoders:

**I am inheriting these materials from an archived private repo.**


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import  DataLoader
#
import dataTools as D
import tools as T
from datetime import datetime
import os
#
import matplotlib.pyplot as plt
%matplotlib inline 
%precision 5

In [None]:
# Set some global constants:
num_epoch = 200   # number of passes over the training split
batch_size = 100  # mini-batch size for both dataloaders
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Select your database:

In [None]:
# Name of the dataset to train on; must match one of the branches below.
database_name = 'CelebA'
############################
# Per-dataset settings:
#   root               -- directory the image files are read from
#   img_names_list_*   -- split files for train/valid (presumably one image name per line)
#   img_size           -- (channels, ...) passed to the dataset reader
if database_name == 'CelebA':
    batch_size = 100
    root = '/ndata/CelebA/128_crop/'
    img_names_list_train = './dataset_splits/CelebA/CelebA_train.txt'
    img_names_list_valid = './dataset_splits/CelebA/CelebA_valid.txt'
    img_size = (3, 128, 128)
elif database_name == 'CYale':
    root = '/ndata/ferdowsi/CYale/'
    img_names_list_train = './dataset_splits/CYale/CYale_train.txt'
    img_names_list_valid = './dataset_splits/CYale/CYale_valid.txt'
    img_size = (1, 168, 192)
else:
    # Fail here with a clear message instead of a NameError on `root`/`img_size` later.
    raise ValueError(
        f"Unknown database_name {database_name!r}; expected 'CelebA' or 'CYale'.")
num_channel = img_size[0]

### Initialize the database classes for train and valid splits:

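Make sure to adjust ``num_workers`` appropriately for your data and machine: it sets how many worker processes the ``DataLoader`` uses to load batches (``0`` loads them in the main process).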

In [None]:
# Datasets for the train and validation splits, built from the split files chosen above.
dataset_train = D.imgRead_fromList(root, img_names_list_train, img_size)
dataset_valid = D.imgRead_fromList(root, img_names_list_valid, img_size)
# Initialize the mini-batch dataloaders. Only the training split is shuffled:
# shuffling the validation split does not help evaluation. It only makes the
# epoch's averaged validation loss and the last validation batch (shown in the
# visualization cell) depend on a random order.
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
dataloader_valid = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=0)

### Initializing the model:

In [None]:
# Architecture hyper-parameters. num_filts, scale_factor, num_codes, neck_dim
# and k are also written into the saved weight-file name further down.
num_blocks = 6
num_filts = [40, 40, 40, 40, 40, 10]   # one entry per block (len == num_blocks)
scale_factor = [1, 2, 1, 2, 1, 2]      # one entry per block (len == num_blocks)
num_codes, neck_dim, k = 20, 512, 256
############
from models import Autoencoder

autoencoder_args = (img_size, num_blocks, num_filts, scale_factor,
                    num_codes, neck_dim, k)
net = Autoencoder(*autoencoder_args)
net = net.to(device)  # Module.to moves parameters and returns the same module
# print(net)

### Define the loss-function and optimizer:

In [None]:
def loss_function(x_rec, x):
    """Return the mean binary cross-entropy between a reconstruction and its target.

    Args:
        x_rec: reconstruction with values in [0, 1]; the training loop applies a
            sigmoid to the network output before calling this.
        x: target tensor of the same shape, expected in [0, 1].

    Returns:
        A scalar tensor: the element-wise BCE averaged over all elements.
    """
    # Functional equivalent of nn.BCELoss() with its default reduction='mean';
    # avoids constructing a new loss module on every call.
    return F.binary_cross_entropy(x_rec, x)
#############################################
# Adam on all model parameters, no L2 weight decay.
optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=0)
# Multiply the learning rate by 0.99 whenever the monitored loss has not improved
# (by the default relative threshold) for more than 100 consecutive
# scheduler.step() calls. The rate never goes below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.99,
    patience=100,
    min_lr=1e-6,
    verbose=True,
)

### Main training loop:

In [None]:
# Train for num_epoch epochs; after each epoch report the average train loss and
# evaluate the current parameters on the validation split.
for i_epoch in range(num_epoch):
    # (Re-)enter training mode every epoch, since validation below switches to eval mode.
    net.train()
    loss_train = 0.0
    print('---------------- epoch = ', i_epoch + 1, '/',num_epoch, ' ----------')
    for i_batch, batch_train in enumerate(dataloader_train):

        inp_train = batch_train['image'].to(device)
        out_train, code_train = net(inp_train)
        # Map the raw output into (0, 1) as BCE requires. This is done out-of-place on
        # purpose: an in-place sigmoid_ on a graph tensor makes backward() fail if the
        # op that produced it saved its output. Remember to apply it also on valid/test
        # sets, or move it into the model.
        out_train = torch.sigmoid(out_train)

        optimizer.zero_grad()
        loss = loss_function(out_train, inp_train)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()  # single host sync per batch
        # The plateau scheduler is stepped on every mini-batch loss, so its
        # patience is counted in batches, not epochs.
        scheduler.step(loss_value)
        loss_train += loss_value
        if i_epoch < 2:
            print(f"b({i_batch}): l = {loss_value:.3f}", end=" | ")
    if i_epoch < 2:
        print()  # terminate the per-batch progress line
    print('Avg train loss = ', loss_train/len(dataloader_train))

    # Evaluation mode, so layers such as dropout/batch-norm (if the model has any)
    # behave deterministically and do not update running statistics.
    net.eval()
    with torch.no_grad():
        loss_valid = 0.0
        for batch_valid in dataloader_valid:
            inp_valid = batch_valid['image'].to(device)
            out_valid, code_valid = net(inp_valid)
            out_valid = torch.sigmoid(out_valid)

            loss_valid += loss_function(out_valid, inp_valid).item()
        print('Avg validation loss = ', loss_valid/len(dataloader_valid))

# Note that average validation error for each epoch uses the most recent parameters, while
# the average training error is taking all updates into account.

### To save the model:

In [None]:
# Stamp this run with the current POSIX time (seconds since the epoch, as a string).
now = str(datetime.now().timestamp())

In [None]:
# Choosing an informative name for the model: the dataset plus the architecture
# hyper-parameters, so a weight file can be matched to the configuration that built it.
net_name = (database_name
            + '_filts' + D.list2str(num_filts)
            + '_scale' + D.list2str(scale_factor)
            + '_codes' + str(num_codes)
            + '_dim' + str(neck_dim)
            + '_k' + str(k)
            + '.pth')
#########################
net_root = './weights'
net_path = D.pathStamper(os.path.join(net_root, net_name), now)
print(net_path)
# torch.save does not create missing directories; make sure the target folder exists.
os.makedirs(os.path.dirname(net_path) or '.', exist_ok=True)
torch.save(net.state_dict(), net_path)

### Some basic visualization and evaluation of reconstruction performance:

In [None]:
def relative_sq_error(x, x_hat):
    """Return ||x - x_hat||^2 / ||x||^2 as a Python float, without building an autograd graph."""
    with torch.no_grad():
        return (torch.norm(x - x_hat).pow(2) / torch.norm(x).pow(2)).item()


idx = 0
# Show sample `idx` of the last validation batch next to its reconstruction.
D.imShow(inp_valid, idx=idx)
D.imShow(out_valid, idx=idx)

# Relative squared reconstruction error on the last train and validation batches.
# out_train still carries autograd history, hence the no_grad inside the helper.
print(relative_sq_error(inp_train, out_train))
print(relative_sq_error(inp_valid, out_valid))