# LeNet5 Analog Hardware-aware Training Example
 Training the LeNet5 neural network with hardware aware training using the Inference RPU Config to optimize inference on a PCM device.

<a href="https://colab.research.google.com/github/IBM/aihwkit/blob/master/notebooks//analog_training_LeNet5_hwa.ipynb" target="_parent">
    <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
</a>

## IBM Analog Hardware Acceleration Kit

IBM Analog Hardware Acceleration Kit (AIHWKIT) is an open source Python toolkit for exploring and using the capabilities of in-memory computing devices in the context of artificial intelligence.
The pytorch integration consists of a series of primitives and features that allow using the toolkit within PyTorch. 
The github repository can be found at: https://github.com/IBM/aihwkit

There are two possible scenarios for using Analog AI, one where the Analog accelerator targets training of DNN and one where the Analog accelerator aims at accelerating the inference of a DNN.
Employing Analog accelerator for training scenarios requires innovations on the algorithms used during the backpropagation (BP) phase and we will see an example of such innovation in the Tiki-Taka notebook.
Employing Analog accelerator for inference scenarios allows the use of a digital accelerator for the training part and then transfering the weights to the analog hardware for the inference, which we will explore in this notebook.


## Hardware Aware Training with Analog AI

When performing inference on Analog hardware, direct transfer of the weights trained on a digital chip to analog hardware would results in reduced network accuracies due to non-idealities of the non-volatile memory technologies used in the analog chip. 
For example, a common non-ideality of Phase Change Memory is the resistance drift. This drift is due to structural relaxation of the physical material composing the memory and causes an evolution of the memory resistance over time, which translates in a change in the weight stored in the memory array. 
This would eventually result in a decrease network accuracy over time.
What one can do is to train the network on digital accelerator but in a way that is aware of the hardware characteristics that will be used in the inference pass, we refer to this as Hardware Aware Training (HWA).

<center><img src="img/hwa.jpg" style="width:50%; height:50%"/></center> 

In hardware-aware training, we add many of the device non-idealities in the forward pass, so that the network itself, while learning the features to achieve high accuracy, also learns how to be resiliant to these non-idealities.
Examples of the non-idealities inserted during the forward pass include quantization noise and thermal noise related to the Digital-to-Analog converter (DAC) and Analog-to-Digital converter (ADC) for the digital-to-analog signal conversion (and vice versa), as well as non-idealities of the NVM technology in use, which for Phase Change Memory are the state dependent weight noise, the weight drift among others.

Hardware-aware training greatly improves analog inference accuracy. For more information, please refer to:

https://ieeexplore.ieee.org/document/8993573

https://ieeexplore.ieee.org/document/8776519

https://www.nature.com/articles/s41467-020-16108-9

In this notebook we will use the AIHWKIT to do HWA training of a LeNet5 inspired network with the MNIST dataset. We will then evaluate the accuracy of the analog hardware and its evolution in time.

The first thing to do is to install the AIHKIT and dependencies in your environment. The preferred way to install this package is by using the Python package index (please uncomment this line to install in your environment if not previously installed):

In [None]:
# To install the cpu-only enabled kit, un-comment the line below
# %pip install aihwkit

# To install the GPU-enabled wheel go to https://aihwkit.readthedocs.io/en/latest/advanced_install.html#install-the-aihwkit-using-pip
# and copy the option on GPU options that best suits your enviroment and paste it below and run the cell. For example, Python 3.10 and CUDA 12.1:
# !wget https://aihwkit-gpu-demo.s3.us-east.cloud-object-storage.appdomain.cloud/aihwkit-0.9.2+cuda121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# %pip install aihwkit-0.9.2+cuda121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl


If the library was installed correctly, you can use the following snippet for creating an analog layer and predicting the output:

In [13]:
from torch import Tensor
from aihwkit.nn import AnalogLinear

model = AnalogLinear(2, 2)
model(Tensor([[0.1, 0.2], [0.3, 0.4]]))

tensor([[-0.5974,  0.2536],
        [-0.5692,  0.4325]], grad_fn=<AddBackward0>)

Now that the package is installed and running, we can start working on creating the LeNet5 network.

AIHWKIT offers different Analog layers that can be used to build a network, including AnalogLinear and AnalogConv2d which will be the main layers used to build the present network. 
In addition to the standard input that are expected by the PyTorch layers (in_channels, out_channels, etc.) the analog layers also expect a rpu_config input which defines various settings of the RPU tile. Through the rpu_config parameter the user can specify many of the hardware specs such as: device used in the cross-point array, bit used by the ADC/DAC converters, noise values and many other. Additional details on the RPU configuration can be found at https://aihwkit.readthedocs.io/en/latest/using_simulator.html#rpu-configurations
For this particular case we will use two device per cross-point which will effectively allow us to enable the weight transfer needed to implement the Tiki-Taka algorithm.

In [14]:
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.simulator.configs.utils import BoundManagementType, WeightNoiseType
from aihwkit.inference import PCMLikeNoiseModel, ReRamCMONoiseModel
from aihwkit.simulator.parameters.io import IOParametersIRDropT
from aihwkit.inference import GlobalDriftCompensation
from aihwkit.simulator.configs.utils import (
    WeightModifierType,
    BoundManagementType,
    WeightClipType,
    NoiseManagementType,
    WeightRemapType,
)
PCM=False
def create_rpu_config():
    input_prec = 6
    output_prec = 8
    my_rpu_config = InferenceRPUConfig()
    my_rpu_config.mapping.digital_bias = True # do the bias of the MVM digitally
    my_rpu_config.mapping.max_input_size = 256
    my_rpu_config.mapping.max_output_size = 256

    my_rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)
    #my_rpu_config.noise_model = ReRamCMONoiseModel(g_max=88.19, g_min=9.0,
                                                #acceptance_range=2.0)
    #my_rpu_config.drift_compensation = None # by default is GlobalCompensation from PCM

    #my_rpu_config.drift_compensation = None
    my_rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT
    my_rpu_config.forward.w_noise = 0.02
    my_rpu_config.forward = IOParametersIRDropT()
    my_rpu_config.forward.inp_res = 1 / (2**input_prec - 2)
    my_rpu_config.forward.out_res = 1 / (2**output_prec - 2)
    my_rpu_config.forward.is_perfect = False
    #my_rpu_config.forward.out_noise = 0.0 # Output on the current addition (?)
    my_rpu_config.forward.ir_drop_g_ratio = 1.0 / 0.35 / 25e-6 # change to 25w-6 when using PCM
    my_rpu_config.forward.ir_drop = 1.0 # TODO set to 1.0 when activating IR drop effects
    my_rpu_config.forward.ir_drop_rs = 0.35 # Default: 0.15
    my_rpu_config.forward.noise_management = NoiseManagementType.ABS_MAX # Rescale back the output with the scaling for normalizing the input
    my_rpu_config.forward.bound_management = BoundManagementType.NONE # No learning of the ranges
    my_rpu_config.forward.out_bound = 10.0  # quite restrictive
    return my_rpu_config

We can now use this rpu_config as input of the network model:

In [15]:
from torch.nn import Tanh, MaxPool2d, LogSoftmax, Flatten
from aihwkit.nn import AnalogConv2d, AnalogLinear, AnalogSequential

def create_analog_network(rpu_config):
    
    channel = [16, 32, 512, 128]
    model = AnalogSequential(
        AnalogConv2d(in_channels=1, out_channels=channel[0], kernel_size=5, stride=1,
                        rpu_config=rpu_config),
        Tanh(),
        MaxPool2d(kernel_size=2),
        AnalogConv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=5, stride=1,
                        rpu_config=rpu_config),
        Tanh(),
        MaxPool2d(kernel_size=2),
        Tanh(),
        Flatten(),
        AnalogLinear(in_features=channel[2], out_features=channel[3], rpu_config=rpu_config),
        Tanh(),
        AnalogLinear(in_features=channel[3], out_features=10, rpu_config=rpu_config),
        LogSoftmax(dim=1)
    )

    return model

We will use the cross entropy to calculate the loss and the Stochastic Gradient Descent (SGD) as optimizer:

In [16]:
from torch.nn import CrossEntropyLoss

criterion = CrossEntropyLoss()


from aihwkit.optim import AnalogSGD

def create_analog_optimizer(model):
    """Create the analog-aware optimizer.

    Args:
        model (nn.Module): model to be trained

    Returns:
        Optimizer: created analog optimizer
    """
    
    optimizer = AnalogSGD(model.parameters(), lr=0.01) # we will use a learning rate of 0.01 as in the paper
    optimizer.regroup_param_groups(model)

    return optimizer

We can now write the train function which will optimize the network over the MNIST train dataset. The train_step function will take as input the images to train on, the model to train and the criterion and optimizer to train with:

In [17]:
from torch import device, cuda
from tqdm import tqdm
DEVICE = device('cuda' if cuda.is_available() else 'cpu')
print('Running the simulation on: ', DEVICE)

def train_step(train_data, model, criterion, optimizer):
    """Train network.

    Args:
        train_data (DataLoader): Validation set to perform the evaluation
        model (nn.Module): Trained model to be evaluated
        criterion (nn.CrossEntropyLoss): criterion to compute loss
        optimizer (Optimizer): analog model optimizer

    Returns:
        train_dataset_loss: epoch loss of the train dataset
    """
    total_loss = 0

    model.train()

    for images, labels in tqdm(train_data):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()

        # Add training Tensor to the model (input).
        output = model(images)
        loss = criterion(output, labels)

        # Run training (backward propagation).
        loss.backward()

        # Optimize weights.
        optimizer.step()
        total_loss += loss.item() * images.size(0)
    train_dataset_loss = total_loss / len(train_data.dataset)

    return train_dataset_loss

Running the simulation on:  cuda


Since training can be quite time consuming it is nice to see the evolution of the training process by testing the model capabilities on a set of images that it has not seen before (test dataset). So we write a test_step function:

In [18]:
def test_step(validation_data, model, criterion):
    """Test trained network

    Args:
        validation_data (DataLoader): Validation set to perform the evaluation
        model (nn.Module): Trained model to be evaluated
        criterion (nn.CrossEntropyLoss): criterion to compute loss

    Returns: 
        test_dataset_loss: epoch loss of the train_dataset
        test_dataset_error: error of the test dataset
        test_dataset_accuracy: accuracy of the test dataset
    """
    total_loss = 0
    predicted_ok = 0
    total_images = 0

    model.eval()

    for images, labels in validation_data:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        pred = model(images)
        loss = criterion(pred, labels)
        total_loss += loss.item() * images.size(0)

        _, predicted = torch.max(pred.data, 1)
        total_images += labels.size(0)
        predicted_ok += (predicted == labels).sum().item()
        test_dataset_accuracy = predicted_ok/total_images*100
        test_dataset_error = (1-predicted_ok/total_images)*100

    test_dataset_loss = total_loss / len(validation_data.dataset)

    return test_dataset_loss, test_dataset_error, test_dataset_accuracy

To reach satisfactory accuracy levels, the train_step will have to be repeated mulitple time so we will implement a loop over a certain number of epochs:

In [19]:
def training_loop(model, criterion, optimizer, train_data, validation_data, epochs=15, print_every=1):
    """Training loop.

    Args:
        model (nn.Module): Trained model to be evaluated
        criterion (nn.CrossEntropyLoss): criterion to compute loss
        optimizer (Optimizer): analog model optimizer
        train_data (DataLoader): Validation set to perform the evaluation
        validation_data (DataLoader): Validation set to perform the evaluation
        epochs (int): global parameter to define epochs number
        print_every (int): defines how many times to print training progress

    """
    train_losses = []
    valid_losses = []
    test_error = []

    # Train model
    for epoch in range(0, epochs):
        # Train_step
        train_loss = train_step(train_data, model, criterion, optimizer)
        train_losses.append(train_loss)

        if epoch % print_every == (print_every - 1):
            # Validate_step
            with torch.no_grad():
                valid_loss, error, accuracy = test_step(validation_data, model, criterion)
                valid_losses.append(valid_loss)
                test_error.append(error)

            print(f'Epoch: {epoch}\t'
                  f'Train loss: {train_loss:.4f}\t'
                  f'Valid loss: {valid_loss:.4f}\t'
                  f'Test error: {error:.2f}%\t'
                  f'Test accuracy: {accuracy:.2f}%\t')

In [20]:
def test_inference(model, criterion, test_data, file="reram"):
    
    from numpy import logspace, log10
    
    total_loss = 0
    predicted_ok = 0
    total_images = 0
    accuracy_pre = 0
    error_pre = 0
    
    # Create the t_inference_list using inference_time.
    # Generate the 9 values between 0 and the inference time using log10
    max_inference_time = 1e8
    n_times = 9
    t_inference_list = [0, 1, 3600, 3600 * 24, 3600 * 24 * 365 * 10]
    errors = torch.zeros(size=(len(t_inference_list),1))
    accuracy = torch.zeros(size=(len(t_inference_list),1))
    # Simulation of inference pass at different times after training.
    for i,t_inference in enumerate(t_inference_list):
        model.drift_analog_weights(t_inference)

        time_since = t_inference
        accuracy_post = 0
        error_post = 0
        predicted_ok = 0
        total_images = 0

        for images, labels in test_data:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            pred = model(images)
            loss = criterion(pred, labels)
            total_loss += loss.item() * images.size(0)

            _, predicted = torch.max(pred.data, 1)
            total_images += labels.size(0)
            predicted_ok += (predicted == labels).sum().item()
        accuracy_post = predicted_ok/total_images*100
        error_post = (1-predicted_ok/total_images)*100
        errors[i] = error_post
        accuracy[i] = accuracy_post
        print(f'Error after inference: {error_post:.2f}\t'
              f'Accuracy after inference: {accuracy_post:.2f}%\t'
              f'Drift t={time_since: .2e}\t')
    torch.save(errors, file+"_error_hwa.th")
    torch.save(accuracy, file+"_accuracy_hwa.th")

We will now download the MNIST dataset and prepare the images for the training and test:

In [21]:
import os
from torchvision import datasets, transforms
PATH_DATASET = os.path.join('data', 'DATASET')
os.makedirs(PATH_DATASET, exist_ok=True)

def load_images():
    """Load images for train from torchvision datasets."""

    transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST(PATH_DATASET, download=True, train=True, transform=transform)
    test_set = datasets.MNIST(PATH_DATASET, download=True, train=False, transform=transform)
    train_data = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True)
    test_data = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=False)

    return train_data, test_data

In [22]:
import torch.nn as nn
def create_digital_network():
    channel = [16, 32, 512, 128]
    model = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=channel[0], kernel_size=5, stride=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=5, stride=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(in_features=channel[2], out_features=channel[3]),
        nn.ReLU(),
        nn.Linear(in_features=channel[3], out_features=10),
        nn.LogSoftmax(dim=1)
    )

    return model

Put together all the code above to train

In [23]:
import torch
from aihwkit.nn.conversion import convert_to_analog
torch.manual_seed(1)

#load the dataset
train_data, test_data = load_images()
rpu_config = create_rpu_config()
#create the rpu_config
model = create_analog_network(rpu_config).to(DEVICE)

#define the analog optimizer
optimizer = create_analog_optimizer(model)

training_loop(model, criterion, optimizer, train_data, test_data)
#define the analog optimizer
#optimizer = create_analog_optimizer(model)

#training_loop(model, criterion, optimizer, train_data, test_data)




100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 91.54it/s]


Epoch: 0	Train loss: 0.3706	Valid loss: 0.1125	Test error: 3.40%	Test accuracy: 96.60%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 91.57it/s]


Epoch: 1	Train loss: 0.0950	Valid loss: 0.0708	Test error: 2.26%	Test accuracy: 97.74%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 91.71it/s]


Epoch: 2	Train loss: 0.0688	Valid loss: 0.0548	Test error: 1.69%	Test accuracy: 98.31%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 92.16it/s]


Epoch: 3	Train loss: 0.0562	Valid loss: 0.0482	Test error: 1.55%	Test accuracy: 98.45%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:20<00:00, 93.10it/s]


Epoch: 4	Train loss: 0.0472	Valid loss: 0.0448	Test error: 1.45%	Test accuracy: 98.55%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 91.97it/s]


Epoch: 5	Train loss: 0.0419	Valid loss: 0.0437	Test error: 1.37%	Test accuracy: 98.63%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:20<00:00, 92.71it/s]


Epoch: 6	Train loss: 0.0378	Valid loss: 0.0417	Test error: 1.32%	Test accuracy: 98.68%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:25<00:00, 87.99it/s]


Epoch: 7	Train loss: 0.0344	Valid loss: 0.0408	Test error: 1.27%	Test accuracy: 98.73%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:23<00:00, 89.70it/s]


Epoch: 8	Train loss: 0.0331	Valid loss: 0.0421	Test error: 1.27%	Test accuracy: 98.73%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:24<00:00, 88.26it/s]


Epoch: 9	Train loss: 0.0316	Valid loss: 0.0415	Test error: 1.18%	Test accuracy: 98.82%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:26<00:00, 86.91it/s]


Epoch: 10	Train loss: 0.0301	Valid loss: 0.0408	Test error: 1.21%	Test accuracy: 98.79%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 92.12it/s]


Epoch: 11	Train loss: 0.0292	Valid loss: 0.0425	Test error: 1.22%	Test accuracy: 98.78%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:20<00:00, 93.14it/s]


Epoch: 12	Train loss: 0.0277	Valid loss: 0.0394	Test error: 1.15%	Test accuracy: 98.85%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:21<00:00, 91.76it/s]


Epoch: 13	Train loss: 0.0261	Valid loss: 0.0401	Test error: 1.19%	Test accuracy: 98.81%	


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7500/7500 [01:23<00:00, 90.32it/s]


Epoch: 14	Train loss: 0.0243	Valid loss: 0.0386	Test error: 1.21%	Test accuracy: 98.79%	
