In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

import brevitas.nn as qnn

In [None]:
# Training split: sin(t) sampled on a 0.5-step grid over [0, 100) -> 200 points.
X_train = np.arange(start=0, stop=100, step=0.5)
y_train = np.sin(X_train)

In [None]:
# Held-out split: the same signal continued over [100, 200) -> 200 points.
X_test = np.arange(start=100, stop=200, step=0.5)
y_test = np.sin(X_test)

In [None]:
# Overview of the two splits: training on t in [0, 100), testing on t in [100, 200).
fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(X_train, y_train, lw=3, label="train")
ax.plot(X_test, y_test, lw=3, c="purple", label="test")
ax.set_xlabel("t")
ax.set_ylabel("sin(t)")
ax.set_title("Sine series: train / test split")
ax.legend(loc="lower left")
plt.show()

In [None]:
# Convert the float64 numpy signals into float32 torch tensors.
train_series, test_series = (
    torch.from_numpy(signal).float() for signal in (y_train, y_test)
)

In [None]:
# Build supervised samples with a sliding window: each sample holds `look_back`
# consecutive values and its label is the value immediately after the window.
# All windows are packed into one batch, shaped (1, n_windows, look_back); the
# QuantLinear layers act on the last dimension (like nn.Linear).
# With 200 training points and look_back = 20 this gives inputs (1, 180, 20)
# and labels (1, 180, 1).
look_back = 20

n_windows = len(train_series) - look_back
train_dataset = torch.stack(
    [train_series[i:i + look_back] for i in range(n_windows)]
).unsqueeze(0)
# Label for window i is train_series[i + look_back], i.e. the tail of the series.
train_labels = train_series[look_back:].reshape(1, n_windows, 1)

assert train_dataset.shape == (1, n_windows, look_back)
assert train_labels.shape == (1, n_windows, 1)

In [None]:
class QuantNet(nn.Module):
    """Two-layer quantized regressor: ``input_shape`` -> ``n_neurons`` -> 1.

    Both layers are Brevitas ``QuantLinear`` with bias and
    ``weight_scaling_per_output_channel=True``. There is no activation between
    them, so apart from quantization effects the network computes an affine
    function of its input window.

    Args:
        n_neurons: width of the hidden layer.
        input_shape: number of input features per sample (the window length).
    """

    def __init__(self, n_neurons, input_shape):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.fc1 = qnn.QuantLinear(
            input_shape,
            n_neurons,
            bias=True,
            weight_scaling_per_output_channel=True,
        )
        self.fc = qnn.QuantLinear(
            n_neurons,
            1,
            bias=True,
            weight_scaling_per_output_channel=True,
        )

    def forward(self, x):
        """Map input windows (last dim = ``input_shape``) to one value each."""
        hidden = self.fc1(x)
        return self.fc(hidden)

In [None]:
# Hidden-layer width of the quantized network.
n_neurons = 4

# Seed torch's RNG before constructing the model so the weight initialization
# (and hence the deterministic full-batch training below) is reproducible
# across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

model = QuantNet(n_neurons, look_back)

In [None]:
# Full-batch training: each epoch runs every window through the model at once
# and takes a single Adam step on the MSE against the next-value labels.
N_EPOCHS = 300
LEARNING_RATE = 0.001

loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

loss_curve = []
for epoch in range(N_EPOCHS):
    # The optimizer owns all of model.parameters(), so this clears every grad.
    optimizer.zero_grad()
    predictions = model(train_dataset)
    loss = loss_function(predictions, train_labels)
    loss.backward()
    optimizer.step()
    loss_curve.append(loss.item())

In [None]:
# Training loss per epoch (one full-batch optimizer step per epoch).
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(loss_curve, lw=2)
ax.set(xlabel="Epoch", ylabel="Training Loss (MSE)")
plt.show()

In [None]:
# Sliding windows over the *test* series, shaped like the training inputs:
# (1, n_test_windows, look_back) -> (1, 180, 20) for 200 test points and
# look_back = 20. The prediction for window i estimates test_series[i + look_back].
n_test_windows = len(test_series) - look_back
test_dataset = torch.stack(
    [test_series[i:i + look_back] for i in range(n_test_windows)]
).unsqueeze(0)

with torch.no_grad():
    test_predictions = model(test_dataset).squeeze()

In [None]:
# Window i predicts test_series[i + look_back], so the predictions line up with
# the test time stamps starting at X_test[look_back] (t = 110 for look_back = 20).
x = X_test[look_back:]
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(X_train, y_train, lw=2, label='train data')
ax.plot(X_test, y_test, lw=3, c='purple', label='test data')
ax.plot(x, test_predictions, lw=3, c='cyan', linestyle=':', label='predictions')
ax.set_xlabel("t")
ax.set_ylabel("sin(t)")
ax.set_title("One-step-ahead predictions on the test series")
ax.legend(loc="lower left")
plt.show();

In [None]:
# Closed-loop (autoregressive) extrapolation: seed with the first `look_back`
# test values, then repeatedly append the model's own prediction as the newest
# element of the window and drop the oldest one.
n_extrapolation_steps = 400

extrapolation = []
seed_batch = test_series[:look_back].reshape(1, 1, look_back)
current_batch = seed_batch
with torch.no_grad():
    for _ in range(n_extrapolation_steps):
        predicted_value = model(current_batch)  # (1, 1, 1): one output feature
        extrapolation.append(predicted_value.item())
        current_batch = torch.cat((current_batch[:, :, 1:], predicted_value), dim=2)


In [None]:
# The first extrapolated value corresponds to X_test[look_back]; subsequent
# values continue on the same sampling grid, running past the end of the data.
dt = X_test[1] - X_test[0]
x = X_test[look_back] + dt * np.arange(len(extrapolation))
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(X_train, y_train, lw=2, label='train data')
ax.plot(X_test, y_test, lw=3, c='purple', label='test data')
ax.plot(x, extrapolation, lw=3, c='cyan', linestyle=':', label='extrapolation')
ax.set_xlabel("t")
ax.set_ylabel("sin(t)")
ax.set_title("Closed-loop extrapolation (Brevitas model)")
ax.legend(loc="lower left")
plt.show();

In [None]:
from brevitas.export import export_onnx_qcdq, export_onnx_qop
# Export to ONNX (QCDQ form). The example input is a single window of
# `look_back` values, shape (1, look_back) -- the same 2-D layout the ONNX
# Runtime loop later feeds to the exported graph.
export_onnx_qcdq(model, args=torch.randn(1, look_back, dtype=torch.float32), export_path='simplenet_qcdq.onnx');

In [None]:
import netron
import time
from IPython.display import IFrame

def show_netron(model_path, port, address="0.0.0.0", host="intelnuc-i7"):
    """Serve ``model_path`` with Netron and return an IFrame embedding the viewer.

    Args:
        model_path: path of the model file (e.g. an ONNX export) to visualize.
        port: TCP port the Netron server listens on.
        address: interface the server binds to. The default ``"0.0.0.0"``
            listens on ALL interfaces, so the viewer is reachable from other
            machines on the network; pass ``"localhost"`` to keep it local.
        host: hostname the browser uses to reach the server in the IFrame URL.
            The default is a machine-specific name; override it (for example
            with ``"localhost"``) when running the notebook elsewhere.

    Returns:
        An IPython ``IFrame`` pointing at ``http://{host}:{port}/``.
    """
    # NOTE(review): this delay runs *before* the server is started; its purpose
    # is not evident from the code here -- confirm intent.
    time.sleep(3.)
    netron.start(model_path, address=(address, port), browse=False)
    return IFrame(src=f"http://{host}:{port}/", width="100%", height=400)

In [None]:
# Render the exported QCDQ graph inline.
show_netron(model_path="simplenet_qcdq.onnx", port=8889)


In [None]:
import onnxruntime as ort

In [None]:
# Load the exported graph with ONNX Runtime on CPU. Passing `providers`
# explicitly pins execution to CPU and avoids the ValueError that GPU-enabled
# ORT builds (since 1.9) raise when no provider list is given.
sess = ort.InferenceSession(
    "simplenet_qcdq.onnx",
    ort.SessionOptions(),
    providers=["CPUExecutionProvider"],
)
input_name = sess.get_inputs()[0].name

In [None]:
# Run the same closed-loop extrapolation through the exported ONNX graph
# (ONNX Runtime) and through the Brevitas model, to check the export is faithful.
# The ONNX graph takes a (1, look_back) window; the Brevitas model is fed
# (1, 1, look_back) as in the earlier extrapolation cell.
n_steps = 400

extrapolation_onnx = []
extrapolation_brevitas = []
seed_batch = test_series[:look_back].reshape(1, 1, look_back)
current_batch_onnx = seed_batch.reshape(1, look_back)
current_batch_brev = seed_batch

with torch.no_grad():
    for _ in range(n_steps):
        predicted_value_onnx = sess.run(None, {input_name: current_batch_onnx.numpy()})[0]
        predicted_value_onnx_ort = torch.from_numpy(predicted_value_onnx)  # (1, 1)
        predicted_value_brevitas = model(current_batch_brev)                 # (1, 1, 1)

        # Store plain floats in both lists. The ONNX output is a (1, 1) array,
        # so `extend(...tolist())` would have produced nested one-element lists.
        extrapolation_onnx.append(predicted_value_onnx_ort.item())
        extrapolation_brevitas.append(predicted_value_brevitas.item())

        current_batch_onnx = torch.cat((current_batch_onnx[:, 1:], predicted_value_onnx_ort), dim=1)
        current_batch_brev = torch.cat((current_batch_brev[:, :, 1:], predicted_value_brevitas), dim=2)

# Quantitative agreement check between the two backends.
max_abs_diff = np.max(np.abs(np.asarray(extrapolation_onnx) - np.asarray(extrapolation_brevitas)))
print(f"max |ONNX - Brevitas| over {n_steps} steps: {max_abs_diff:.3e}")

In [None]:
# Overlay the ONNX Runtime and Brevitas extrapolations; the curves should
# coincide if the export is faithful. The first extrapolated value sits at
# X_test[look_back], on the same sampling grid as the data.
dt = X_test[1] - X_test[0]
x = X_test[look_back] + dt * np.arange(len(extrapolation_brevitas))
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(X_train, y_train, lw=2, label='train data')
ax.plot(X_test, y_test, lw=3, c='purple', label='test data')
ax.plot(x, extrapolation_onnx, lw=3, c='r', linestyle=':', label='extrapolation (ONNX Runtime)')
ax.plot(x, extrapolation_brevitas, lw=2, c='cyan', linestyle=':', label='extrapolation (Brevitas)')
ax.set_xlabel("t")
ax.set_ylabel("sin(t)")
ax.set_title("Closed-loop extrapolation: ONNX Runtime vs Brevitas")
ax.legend(loc="lower left")
plt.show();