In [None]:
import torch
import math
import matplotlib
import matplotlib.pyplot as plt
import os    
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
from model import BJTModel
from tqdm import tqdm
from time import perf_counter_ns
import csv
from pathlib import Path
import pandas as pd
import numpy as np

# Torch runtime configuration: tensors created from here on default to double
# precision on the CPU, and torch is limited to a single intra-op thread
# (presumably so the real-time ratio measured below reflects one core).
DEFAULT_TYPE = torch.float64
device = 'cpu'
torch.set_default_dtype(DEFAULT_TYPE)
torch.set_default_device(device)
torch.set_num_threads(1)

# NN parameters and model
# These hyper-parameters must agree with the architecture stored in the
# checkpoint below (its file name encodes hDim=64, nLayers=4, activation=elu);
# presumably load_state_dict rejects the weights otherwise.
hidden_dim = 64
num_layers = 4
input_size = 4
output_size = 4
activation = "elu"
model = BJTModel(hidden_dim=hidden_dim, num_layers=num_layers, device=device, input_size=input_size, output_size=output_size, activation=activation)
model_name = 'checkpoints/tp5_dataset13_cauchy=0.05_current_kDom=True_stX=False_stY=False_activation=elu_bS=64_hDim=64_nLayers=4_lr=0.0001_epoch=123_validLoss=9.849173998206772e-08.pth'
# map_location: the checkpoint may have been saved from a run on another device,
# while this notebook simulates on `device`.
# weights_only: a state dict only contains tensors, so refuse to unpickle
# arbitrary Python objects from the file.
model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
model.to(device)
model.eval()

# WD Parameter Definition
n_ports = 23      # number of wave-digital ports (size of Z and S)
f0 = 100          # frequency of the sinusoidal input
fs = 44.1e3       # sampling rate
Ts = 1/fs         # sampling period
start_time = 0
stop_time = 1
# Build the time axis from an explicit integer sample count. np.arange with a
# fractional step (Ts) can return one sample more or fewer than expected
# because of floating-point rounding. The count used here is the same as the
# length of the ground-truth slice int(start_time*fs):int(stop_time*fs).
n_samples = max(int(stop_time*fs) - int(start_time*fs), 0)
t = start_time + np.arange(n_samples, dtype='d') * Ts

# Voltage Sources
# Two constant (DC) source values
V9 = V22 = 9
# Sinusoidal input of amplitude Vin_amp and frequency f0, one value per sample of t
Vin_amp = 0.7
Vin = Vin_amp * np.sin(2*math.pi*f0*t)

# Resistive Parameters (grouped by component number)
R5, R8, R9 = 68e3, 50e3, 22e3
R12, R13, R14, R15 = 470e3, 33e3, 1.5e3, 470e3
R17, R18, R19 = 50e3, 510, 100e3
R21, R22, R23 = 10e3, 1e3, 1e6

# Dynamic Elements
# Capacitances and inductance, then their port resistances:
# Z_C = Ts/C for every capacitor and Z_L = L/Ts for the inductor.
# NOTE(review): these are the backward-Euler port resistances
# (trapezoidal would use Ts/(2C) and 2L/Ts).
C6, C7, C11, C16, C20 = 0.01e-6, 4.7e-6, 0.22e-6, 0.22e-6, 0.01e-6
L10 = 500e-3
Z_C6, Z_C7, Z_C11, Z_C16, Z_C20 = (Ts/cap for cap in (C6, C7, C11, C16, C20))
Z_L10 = L10/Ts

# Impedance Matrix
# Diagonal port resistances in port order; entries 0..3 are placeholders that
# are overwritten below by the BJT's full 4x4 block.
Z = np.diag([0,0,0,0,R5,Z_C6,Z_C7,R8,R9,Z_L10,Z_C11,R12,R13,R14,R15,Z_C16,R17,R18,R19,Z_C20,R21,R22,R23])
# Symmetric 4x4 port-resistance block for the four BJT ports. A plain ndarray is
# used instead of np.matrix, whose use NumPy no longer recommends; the values
# written into Z are identical.
z_bjt = np.array([
    [7.582196517735750e+03,7.072121313302530e+03,-4.665240435820493e+03,0.039470183442808],
    [7.072121313302530e+03,2.399945437144726e+04,-1.354933372697417e+04,-8.884116433041805e+03],
    [-4.665240435820493e+03,-1.354933372697417e+04,3.424869882868201e+04,2.851618556414040e+04],
    [0.039470183442808,-8.884116433041805e+03,2.851618556414040e+04,2.951621842536739e+04]], dtype='d')
Z[0:4,0:4] = z_bjt

# Fundamental Loop Matrix
# Br is the non-identity part of the fundamental loop matrix: 12 rows (one per
# fundamental loop) by 11 columns (the last 11 ports). Entries are +1/-1/0
# (standard loop-matrix convention: port traversed with / against the loop
# orientation, or not in the loop).
Br = np.matrix([
    [1, -1, 0, 0, 0, 1, -1, 0, 0, 0, 0],
    [1, -1, 0, -1, 0, 0, -1, 0, 0, 0, 1],
    [0, 0, -1, 1, 0, 0, 0, 0, 1, 0, -1],
    [0, 0, -1, 1, 0, 0, 0, 0, 0, 1, -1],
    [1, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0],
    [-1, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0],
    [0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1],
    [-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, -1, 1, -1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, -1]], dtype='d')
# B = [I_12 | Br]: 12 loops x 23 ports (= n_ports); each of the first 12 ports
# appears in exactly one loop (identity block).
B = np.concatenate([np.eye(12), Br], axis=-1)

# Scattering Matrix
# S = I - 2 Z B^T (B Z B^T)^{-1} B, computed with solve() instead of an
# explicit inverse. The simulation uses it as a = S @ b (waves reflected by the
# elements -> waves incident on them).
# NOTE(review): because Br is an np.matrix, B and S come out as np.matrix too,
# so S[0:4, :] @ b yields a (1, 4) row. Check the code that consumes S before
# switching these to plain ndarrays.
S = np.eye(n_ports) - 2 * (Z @ B.transpose()) @ np.linalg.solve(B @ Z @ B.transpose(), B)

# Initialization of vectors
# a: waves incident on the elements (produced by the junction, a = S @ b)
# b: waves reflected by the elements (produced by the element update rules)
a = np.zeros(n_ports)
b = np.zeros(n_ports)
a_bjt = torch.zeros(4)

v = np.zeros((n_ports, len(t)))  # port voltages, one column per sample
i = np.zeros((n_ports, len(t)))  # port currents, filled only when COMPUTE_CURRENTS is True

# Port currents are optional: computing them adds a matrix-vector product per
# sample inside the timed loop and would inflate the RTR measured afterwards.
COMPUTE_CURRENTS = False
if COMPUTE_CURRENTS:
    # Z is constant, so invert it once instead of at every sample.
    Z_inv = np.linalg.inv(Z)

# Rows of S that give the waves incident on the four BJT ports (S is constant,
# so slice it once outside the loop).
S_bjt = S[0:4, :]

# WD Simulation
# Note: tqdm's per-iteration bookkeeping runs inside the timed region.
pbar = tqdm(range(len(t)), desc='WD Simulation Loop')
t_start = perf_counter_ns()
with torch.inference_mode():
    for n in pbar:
        # Dynamic elements: capacitors reflect the previous port voltage
        # (a + b)/2; the inductor reflects (b - a)/2 of the previous sample.
        b[5] = (a[5]+b[5])/2
        b[6] = (a[6]+b[6])/2
        b[10] = (a[10]+b[10])/2
        b[15] = (a[15]+b[15])/2
        b[19] = (a[19]+b[19])/2
        b[9] = (b[9]-a[9])/2

        # Voltage sources
        b[4] = Vin[n]
        b[8] = V9
        b[21] = V22

        # Nonlinear Root Scattering: the NN maps the 4 incident BJT waves to the
        # 4 reflected ones. reshape(1, -1) always feeds the model a (1, 4) batch,
        # whether S is an np.matrix or a plain ndarray.
        a_bjt = torch.tensor(np.asarray(S_bjt @ b).reshape(1, -1))
        b[0:4] = np.asarray(model(a_bjt)).reshape(-1)

        # Scattering
        a = np.asarray(S @ b).reshape(-1)

        # Kirchhoff variables
        v[:, n] = 0.5 * (a + b)
        if COMPUTE_CURRENTS:
            i[:, n] = 0.5 * (Z_inv @ (a - b))
t_stop = perf_counter_ns()

# Output signal: the voltage of the last port (index 22, whose resistance is R23)
vout = v[22]

# Real Time Ratio (RTR): wall-clock seconds spent in the loop divided by the
# simulated duration len(t)/fs; values below 1 mean faster than real time.
rtr = 1e-9 * (t_stop - t_start) * fs / len(t)
print(f'RTR: {rtr}', end='\n\n')

# Loading the WD Simulation data
# Every cell of the CSV is read in row-major order into one flat list of
# samples, which is then restricted to the simulated time window.
gt = []
with open('data/groundtruth_ebersmoll_sin100_steady.csv', 'r', newline='') as file:
    for row in csv.reader(file):
        gt.extend(float(el) for el in row)
gt = gt[int(start_time*fs):int(stop_time*fs)]

start_idx = -1000
end_idx = len(t)

# Both series begin at start_time (assuming the CSV's first sample is at t = 0),
# so truncating them to a common length keeps them aligned sample by sample.
# Without this, a length mismatch would make the negative start_idx select
# windows that are shifted in time against each other.
n_common = min(len(t), len(vout), len(gt))
t_plot = t[:n_common]
vout_plot = vout[:n_common]
gt_plot = np.asarray(gt[:n_common])

# Plot
fig, ax = plt.subplots(figsize=(18, 8))
ax.grid()
ax.plot(t_plot[start_idx:end_idx], vout_plot[start_idx:end_idx], label="NN")
ax.plot(t_plot[start_idx:end_idx], gt_plot[start_idx:end_idx], color='C1', linewidth=3, ls='--', label="SSC groundtruth")
ax.ticklabel_format(axis='x', style='sci')
ax.set_xlabel('Time')
ax.set_ylabel('V')
ax.legend()
ax.set_axisbelow(True)
plt.show()

In [None]:
# onnx export
# Rebuild the network from the checkpoint so the exported graph does not depend
# on state left over from the simulation cell.
hidden_dim = 64
num_layers = 4
input_size = 4
output_size = 4
activation = "elu"
model = BJTModel(hidden_dim=hidden_dim, num_layers=num_layers, device=device, input_size=input_size, output_size=output_size, activation=activation)
model_name = 'checkpoints/tp5_dataset13_cauchy=0.05_current_kDom=True_stX=False_stY=False_activation=elu_bS=64_hDim=64_nLayers=4_lr=0.0001_epoch=123_validLoss=9.849173998206772e-08.pth'
# map_location / weights_only: load tensors only, onto the notebook's device.
model.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
model.eval()

# Example input used to trace the graph: a batch of one sample with
# input_size features (the simulation loop calls the model with a (1, 4) batch).
X = torch.zeros(1, input_size, requires_grad=True)
torch_out = model(X)  # reference output; not used by the export itself
# `args` must be the model *input*. Passing torch_out (the output) only worked
# because input_size == output_size happened to give it a compatible shape.
torch.onnx.export(model, (X,), "pretrained.onnx", export_params=True, opset_version=10)