In [5]:
MODEL_PATH = "../src/public/models/"

FFT_BINS = 24

In [6]:
import torch
import torch.nn as nn
import torch.onnx

In [7]:
def ewa(new, old, alpha):
    return new * alpha + old * (1 - alpha)

def scaleToAsymptote(x, asymptote, threshold):
    return torch.minimum(torch.tensor(asymptote), x * asymptote/threshold)

def pw_linear(x, breakpoints_values):
    # Extract x_i and y_i from input
    xp, yp = zip(*breakpoints_values)  # unzip into two lists
    xp = torch.tensor(xp, dtype=x.dtype, device=x.device)
    yp = torch.tensor(yp, dtype=x.dtype, device=x.device)

    # Clamp x to the range
    x_clamped = torch.clamp(x, xp[0], xp[-1])

    # Find the interval index i such that xp[i] <= x < xp[i+1]
    indices = torch.bucketize(x_clamped, xp, right=False)
    indices = torch.clamp(indices, 1, len(xp) - 1)

    x0 = xp[indices - 1]
    x1 = xp[indices]
    y0 = yp[indices - 1]
    y1 = yp[indices]

    # Linear interpolation
    t = (x_clamped - x0) / (x1 - x0 + 1e-8)
    y = y0 + t * (y1 - y0)
    return y

In [8]:
class SmoothedPeak(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, h):   
        peak = (x**1.5).sum(axis=1) * 10
        h_new = ewa(peak, h, 0.025)
        y = torch.tanh(h_new**2 * 2) 
        return y, h_new


# Example usage and export
model = SmoothedPeak()

# Dummy inputs: batch size 1
x = torch.randn(1, FFT_BINS)
h = torch.zeros(1, 1)

# Export to ONNX
torch.onnx.export(
    model,
    (x, h),
    MODEL_PATH + "smooth_peak.onnx",
    input_names=["x", "h"],
    output_names=["output", "h"],
    dynamic_axes={
        "x": {0: "batch"},
        "h": {0: "batch"},
        "output": {0: "batch"},
    },
    opset_version=12,
)

In [23]:
class Playing(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, h_signal, h_binary):
        signal = (x.max(axis=-1)[0] != 0) * 1
        h_signal_new = ewa(signal, h_signal, 0.01)
        binary = (h_signal_new > 0.15) * 1.1 - 0.05
        h_binary_new = ewa(binary, h_binary, 0.05)
        y = torch.clamp(h_binary_new, torch.tensor(0), torch.tensor(1))

        return y, h_signal_new, h_binary_new


# Example usage and export
model = Playing()

# Dummy inputs: batch size 1
x = torch.randn(1, FFT_BINS)
h = torch.zeros(1, 1)
# Export to ONNX
torch.onnx.export(
    model,
    (x, h, h),
    MODEL_PATH + "playing.onnx",
    input_names=["x", "h_signal", "h_binary"],
    output_names=["output", "h_signal", "h_binary"],
    dynamic_axes={
        "x": {0: "batch"},
        "h_signal": {0: "batch"},
        "h_binary": {0: "batch"},
        "output": {0: "batch"},
    },
    opset_version=12,
)

  y = torch.clamp(h_binary_new, torch.tensor(0), torch.tensor(1))


In [9]:
class Momentum(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, h_peak, h_silence, h):
        peak = (x**0.1).sum(axis=-1) * 0.5
        h_peak_new = ewa(peak, h_peak, 0.4)
        is_prominent = (h_peak_new > 0.65) * 1
        
        silence = (x.max(axis=-1)[0] == 0) * 1
        h_silence_new = ewa(silence, h_silence, 0.01)
        is_silent = (h_silence_new > 0.85) * 1

        scale = 0.1
        h_new = torch.clamp(h + is_prominent * 1e-2 - is_silent * 2e-1 - 2e-3,
                            torch.tensor(0), torch.tensor(2/scale))
        s = h_new * scale
        y = pw_linear(s, [(0, 0), (0.3, 0.4), (0.7, 0.6), (1, 0.9), (2, 1)])
        return y, h_peak_new, h_silence_new, h_new


# Example usage and export
model = Momentum()

# Dummy inputs: batch size 1
x = torch.randn(1, FFT_BINS)
h = torch.zeros(1, 1)
# Export to ONNX
torch.onnx.export(
    model,
    (x, h, h, h),
    MODEL_PATH + "momentum.onnx",
    input_names=["x", "h_peak", "h_silence", "h"],
    output_names=["output", "h_peak", "h_silence", "h"],
    dynamic_axes={
        "x": {0: "batch"},
        "h": {0: "batch"},
        "h_peak": {0: "batch"},
        "h_silence": {0: "batch"},
        "output": {0: "batch"},
    },
    opset_version=12,
)

  torch.tensor(0), torch.tensor(2/scale))
  xp = torch.tensor(xp, dtype=x.dtype, device=x.device)
  yp = torch.tensor(yp, dtype=x.dtype, device=x.device)
  indices = torch.clamp(indices, 1, len(xp) - 1)
