# Siren Exploration

This is a colab to explore properties of the Siren MLP, proposed in our work [Implicit Neural Representations with Periodic Activation Functions](https://vsitzmann.github.io/siren).


We will first implement a streamlined version of Siren for fast experimentation. This lacks the code to easily do baseline comparisons - please refer to the main code for that - but will greatly simplify the code!

**Make sure that you have enabled the GPU under Edit -> Notebook Settings!**

We will then reproduce the following results from the paper: 
* [Fitting an image](#section_1)
* [Fitting an audio signal](#section_2)
* [Solving Poisson's equation](#section_3)
* [Initialization scheme & distribution of activations](#activations)
* [Distribution of activations is shift-invariant](#shift_invariance)

We will also explore Siren's [behavior outside of the training range](#out_of_range).

Let's go! First, some imports, and a function to quickly generate coordinate grids.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os

from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np

def get_mgrid(sidelen, dim=2):
    """Generate a flattened grid of (x, y, ...) coordinates in the range [-1, 1].

    Args:
        sidelen (int): Number of samples along each axis.
        dim (int): Number of axes of the grid.

    Returns:
        torch.Tensor: Tensor of shape (sidelen ** dim, dim). Rows are in
        matrix ('ij') order, i.e. the last coordinate varies fastest.
    """
    axis = torch.linspace(-1, 1, steps=sidelen)
    # Request 'ij' indexing explicitly. Leaving it as None selects the same
    # layout but goes through the deprecated default path of torch.meshgrid,
    # which emits a UserWarning and may change in a future release.
    mgrid = torch.stack(torch.meshgrid(*([axis] * dim), indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)

Now, we code up the sine layer, which will be the basic building block of SIREN. This is a much more concise implementation than the one in the main code, as here, we aren't concerned with the baseline comparisons.

In [20]:
from collections import OrderedDict
class ModulatedSineLayer(nn.Module):
    """Linear layer followed by a sine activation, with an optional learnable gain.

    The layer computes ``sin(omega_0 * linear(input) * modulation)`` where
    ``modulation`` is a per-output-unit parameter initialised to one. When
    ``modulation=False`` the gain is omitted and ``self.modulation`` is None.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the linear map has a bias term.
        is_first: Selects the first-layer weight initialisation.
        omega_0: Frequency factor applied inside the sine.
        modulation: Whether to create the learnable gain vector.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, modulation=True):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        # Learnable per-unit modulation (gain), or None when disabled.
        if modulation:
            self.modulation = nn.Parameter(torch.ones(out_features))
        else:
            self.modulation = None

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights uniformly from [-bound, bound].

        The first layer uses ``1 / in_features``; every other layer uses
        ``sqrt(6 / in_features) / omega_0``. The bias keeps nn.Linear's default.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input) * modulation)``."""
        pre_activation = self.linear(input)

        # Apply the gain before the frequency scaling.
        if self.modulation is not None:
            pre_activation = pre_activation * self.modulation

        return torch.sin(self.omega_0 * pre_activation)

    def forward_with_intermediate(self, input):
        """Return ``(activation, pre_activation)`` for activation analysis."""
        scaled = self.omega_0 * self.linear(input)

        # Apply the gain after the frequency scaling.
        if self.modulation is not None:
            scaled = scaled * self.modulation

        return torch.sin(scaled), scaled

class ModulatedSiren(nn.Module):
    """MLP built from ModulatedSineLayer blocks.

    The network is one first layer, ``hidden_layers`` hidden sine layers and a
    final layer that is either a plain ``nn.Linear`` (``outermost_linear=True``)
    or another sine layer.

    Args:
        in_features: Dimensionality of the input coordinates.
        hidden_features: Width of every hidden layer.
        hidden_layers: Number of hidden sine layers after the first layer.
        out_features: Dimensionality of the output.
        outermost_linear: Use a linear (non-sine) final layer.
        omega: Frequency factor passed to every sine layer.
        modulation: Whether sine layers carry a learnable gain.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 omega=30, modulation=True):
        super().__init__()
        self.omega = omega
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features

        def sine_layer(n_in, n_out, first=False):
            return ModulatedSineLayer(n_in, n_out,
                                      is_first=first, omega_0=omega, modulation=modulation)

        # First layer uses the is_first initialisation.
        layers = [sine_layer(in_features, hidden_features, first=True)]

        # Hidden layers.
        layers.extend(sine_layer(hidden_features, hidden_features) for _ in range(hidden_layers))

        # Final layer.
        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
            bound = np.sqrt(6 / hidden_features) / omega
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
        else:
            head = sine_layer(hidden_features, out_features)
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network on a detached, grad-enabled copy of ``coords``."""
        # Cloning and re-enabling grad allows derivatives w.r.t. this copy of the input.
        coords = coords.clone().detach().requires_grad_(True)
        return self.net(coords)

    def forward_with_activations(self, coords, retain_grad=False):
        """Return an OrderedDict of the input and all intermediate activations.

        For every sine layer both the pre-activation and the sine output are
        recorded (two consecutive entries); any other layer contributes its
        output only. Keys are ``'<layer class>_<running index>'`` and ``'input'``.
        Only used for visualizing activations.
        """
        activations = OrderedDict()

        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x

        count = 0
        for layer in self.net:
            prefix = str(layer.__class__)
            if isinstance(layer, ModulatedSineLayer):
                x, pre_activation = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    pre_activation.retain_grad()

                activations[f"{prefix}_{count}"] = pre_activation
                count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            # The layer output is recorded for every layer type.
            activations[f"{prefix}_{count}"] = x
            count += 1

        return activations

In [21]:
def laplace(y, x):
    """Return the Laplacian of ``y`` w.r.t. ``x`` as the divergence of its gradient."""
    return divergence(gradient(y, x), x)


def divergence(y, x):
    """Return the divergence of the vector field ``y`` w.r.t. ``x``.

    For each index ``i`` of the last axis of ``y`` the derivative of
    ``y[..., i]`` w.r.t. ``x[..., i]`` is taken (keeping the autograd graph)
    and the terms are summed, giving a tensor with a trailing axis of size 1.
    """
    def partial(i):
        component = y[..., i]
        (grad_i,) = torch.autograd.grad(component, x, torch.ones_like(component), create_graph=True)
        return grad_i[..., i:i + 1]

    return sum((partial(i) for i in range(y.shape[-1])), 0.)


def gradient(y, x, grad_outputs=None):
    """Return d``y``/d``x`` via autograd, keeping the graph for higher derivatives.

    ``grad_outputs`` weights the elements of ``y``; it defaults to all ones.
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return dy_dx

# Experiments

For the image fitting and poisson experiments, we'll use the classic cameraman image.

In [22]:
from scipy.io import wavfile
class AudioFile(Dataset):
    """Single-item dataset holding one WAV file as a mono float32 waveform.

    Attributes set by the constructor:
        rate: Sampling rate returned by ``scipy.io.wavfile.read``.
        data: 1-D ``np.float32`` array of samples (channels averaged).
        file_length: Number of samples in ``data``.

    Args:
        filename: Path of the WAV file to load.
    """

    def __init__(self, filename):
        super().__init__()
        self.rate, self.data = wavfile.read(filename)
        # Downmix every multi-channel layout to mono by averaging the channel
        # axis. Handling only the 2-channel case would leave other layouts
        # 2-D, while file_length counts frames, so the coordinate grid built
        # from file_length would no longer match the number of samples.
        if self.data.ndim > 1:
            self.data = np.mean(self.data, axis=1)
        self.data = self.data.astype(np.float32)
        self.file_length = len(self.data)
        print(f"Rate: {self.rate} Hz, Length: {self.file_length} samples")

    def __len__(self):
        """The dataset always holds exactly one item (the whole waveform)."""
        return 1

    def get_length(self):
        """Return the number of samples in the waveform."""
        return self.file_length

    def get_rate(self):
        """Return the sampling rate of the file."""
        return self.rate

    def __getitem__(self, idx):
        """Return ``(rate, data)``; ``idx`` is ignored."""
        return self.rate, self.data

## Fitting an audio signal
<a id='section_2'></a>

Here, we'll use Siren to parameterize an audio signal - i.e., we seek to parameterize an audio waveform $f(t)$ at time points $t$ by a SIREN $\Phi$.

That is, we seek the function $\Phi$ such that $\mathcal{L} = \int_\Omega \lVert \Phi(t) - f(t) \rVert \mathrm{d}t$ is minimized, where $\Omega$ is the domain of the waveform.

For the audio, we'll use the bach sonata:

In [23]:
# ImplicitAudioWrapperクラス: 音声データを格納し、座標データを生成
class ImplicitAudioWrapper(Dataset):
    """Pairs an audio dataset with a fixed 1-D coordinate grid.

    The grid holds ``dataset.file_length`` evenly spaced float32 coordinates
    from -100 to 100, shaped ``(file_length, 1)``.

    Args:
        dataset: Object exposing ``file_length`` and returning ``(rate, data)``
            from ``dataset[idx]``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # One coordinate per sample, evenly spaced over [-100, 100].
        grid = np.linspace(start=-100, stop=100, num=dataset.file_length).astype(np.float32)
        self.grid = torch.Tensor(grid).view(-1, 1)

    def get_num_samples(self):
        """Return the number of grid coordinates."""
        return self.grid.shape[0]

    def __len__(self):
        """The wrapper exposes a single item (the whole signal)."""
        return 1

    def __getitem__(self, idx):
        """Return ``({'idx', 'coords'}, {'func', 'rate', 'scale'})``.

        ``func`` is the waveform divided by its peak absolute value and shaped
        ``(-1, 1)``; ``scale`` is that peak value (1 for an all-zero signal).
        """
        rate, data = self.dataset[idx]
        scale = np.max(np.abs(data))
        # A silent (all-zero) signal has a zero peak; dividing by it would turn
        # every target into NaN. Use a unit scale so the zeros pass through.
        if scale == 0:
            scale = type(scale)(1)
        data = torch.Tensor(data / scale).view(-1, 1)
        return {'idx': idx, 'coords': self.grid}, {'func': data, 'rate': rate, 'scale': scale}

Let's build a little dataset that computes coordinates for audio files:

In [None]:
# Load the audio file into an AudioFile dataset.
audio_dataset = AudioFile(filename="sirenvoice3.wav")
# Wrap it so each item also carries the coordinate grid.
coord_dataset = ImplicitAudioWrapper(audio_dataset)
# Build the DataLoader (one item per batch).
dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
audio_samples = audio_dataset.get_length()
audio_rate = audio_dataset.get_rate()

# Pull the batches out of the loader and report their shapes.
for batch in dataloader:
    coords = batch[0]['coords']
    func = batch[1]['func']
    rate = batch[1]['rate']
    scale = batch[1]['scale']
    print(f"Coords: {coords.shape}, Audio Data: {func.shape}, Rate: {rate}, Scale: {scale}")

We now fit Siren in a simple training loop. Within only hundreds of iterations, the image and its gradients are approximated well.

In [25]:
import copy
import torch.optim as optim
import matplotlib.pyplot as plt

def _plot_siren_fit(coords, target, prediction, stage, epoch_number, loss_value):
    """Plot the ground-truth waveform against the current SIREN output."""
    plt.figure(figsize=(12, 6))
    plt.plot(coords, target, label='Ground Truth')
    plt.plot(coords, prediction, label='SIREN Output')
    plt.legend()
    plt.title(f'Stage: {stage}, Epoch: {epoch_number}, Loss: {loss_value:.6f}')
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.show()


def train_siren(dataloader, hidden_features, hidden_layers, omega, epochs):
    """Fit a ModulatedSiren to the audio signal in the first batch of ``dataloader``.

    Training runs in two stages of ``epochs`` full-batch Adam steps each:
    stage 0 trains a freshly initialised model with lr=1e-4, stage 1 continues
    from the best weights seen so far with lr=1e-5. The weights with the lowest
    MSE observed in either stage are returned.

    Args:
        dataloader (DataLoader): Yields ``({'coords': ...}, {'func': ...})``
            batches; only the first batch is used.
        hidden_features (int): Width of the hidden layers.
        hidden_layers (int): Number of hidden layers.
        omega (float): Frequency factor of the sine layers.
        epochs (int): Number of optimisation steps per stage.

    Returns:
        ModulatedSiren: The model, in eval mode, loaded with the best weights.
    """
    batch = next(iter(dataloader))

    # Use the GPU when one is present, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Drop the batch dimension added by the DataLoader.
    model_input = batch[0]['coords'].squeeze(0).to(device)
    ground_truth = batch[1]['func'].squeeze(0).to(device)

    # The model is created once. Re-creating it (and its optimizer) inside the
    # epoch loop would discard all progress after every single step.
    siren = ModulatedSiren(in_features=1, out_features=1, hidden_features=hidden_features,
                           hidden_layers=hidden_layers, outermost_linear=True, omega=omega)
    siren.to(device)
    criterion = nn.MSELoss()

    best_loss = float('inf')
    best_state = copy.deepcopy(siren.state_dict())

    # Plot roughly five times per stage; integer arithmetic keeps the modulo
    # test meaningful for every value of ``epochs``.
    plot_every = max(1, epochs // 5)
    plot_coords = model_input.detach().cpu().numpy().flatten()
    plot_truth = ground_truth.detach().cpu().numpy().flatten()

    for stage, learning_rate in enumerate((1e-4, 1e-5)):
        # Each stage starts from the best weights found so far with a fresh
        # optimizer that then lives for the whole stage.
        siren.load_state_dict(best_state)
        optimizer = optim.Adam(siren.parameters(), lr=learning_rate)
        siren.train()

        for epoch in range(epochs):
            optimizer.zero_grad()
            model_output = siren(model_input)
            loss = criterion(model_output, ground_truth)
            loss.backward()
            loss_value = loss.item()

            # Snapshot BEFORE stepping: the loss belongs to the current weights.
            # A state_dict copy is independent of the live model, so later
            # steps cannot overwrite the stored best.
            if loss_value < best_loss:
                best_loss = loss_value
                best_state = copy.deepcopy(siren.state_dict())

            optimizer.step()

            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Stage: {stage}, Epoch: {epoch+1}/{epochs}, Loss: {loss_value:.6f}")

            if (epoch + 1) % plot_every == 0:
                # detach() is required: the output is attached to the autograd
                # graph and cannot be converted to NumPy directly.
                prediction = model_output.detach().cpu().numpy().flatten()
                _plot_siren_fit(plot_coords, plot_truth, prediction, stage, epoch + 1, loss_value)

    siren.load_state_dict(best_state)
    siren.eval()
    return siren

In [None]:
from IPython.display import Audio
# Positional arguments: hidden_features=28, hidden_layers=1, omega=20, epochs=1000000.
siren_voice = train_siren(dataloader, 28, 1, 20, 1000000) 
# Alternative configuration: hidden_features=32, hidden_layers=1, omega=20, epochs=9000.
#siren_voice = train_siren(dataloader, 32, 1, 20, 9000)

Generate an audio waveform using the trained SIREN model and save it as a WAV file.

In [27]:
import torchaudio
def save_audio(model, dataloader, num_samples, sample_rate, path):
    """Render the model on the training coordinates, plot it and save it as WAV.

    The waveform is the model output for the coordinates of the first batch of
    ``dataloader``. It is rescaled to [-1, 1] only when its peak exceeds 1,
    reshaped to ``[1, samples]`` and written with ``torchaudio.save``. A
    failure while saving is reported on stdout rather than raised.

    Args:
        model: Trained network; evaluated on the device its parameters live on.
        dataloader (DataLoader): Yields ``({'coords': ...}, ...)`` batches.
        num_samples (int): Unused. The number of output samples is determined
            by the coordinates in the batch. Kept for backward compatibility.
        sample_rate (int): Sampling rate written to the WAV file.
        path (str): Destination path of the WAV file.

    Raises:
        ValueError: If the model output is neither 1-D nor shaped ``[N, 1]``.
    """
    model.eval()

    # Evaluate on whatever device the model already lives on.
    device = next(model.parameters()).device

    batch = next(iter(dataloader))
    coords = batch[0]['coords'].squeeze(0)  # shape [N, 1]

    with torch.no_grad():
        waveform = model(coords.to(device)).detach().cpu().numpy()
    print(f"Waveform shape before processing: {waveform.shape}")

    plt.plot(coords.cpu().detach().numpy().flatten(), waveform, label='SIREN Output')
    plt.legend()
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.show()

    # Rescale only when the signal would otherwise clip.
    max_val = np.max(np.abs(waveform))
    if max_val > 1:
        waveform = waveform / max_val
        print(f"maxval: {max_val}")

    waveform = waveform.astype(np.float32)

    # Bring the waveform to shape [1, samples].
    if waveform.ndim == 2 and waveform.shape[1] == 1:
        waveform = waveform.T
        print(f"Waveform shape after transposing: {waveform.shape}")
    elif waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=0)
        print(f"Waveform shape after expanding dims: {waveform.shape}")
    else:
        raise ValueError(f"Unexpected waveform shape: {waveform.shape}")

    print(f"Final waveform shape: {waveform.shape}")

    waveform_tensor = torch.from_numpy(waveform)

    # Saving is best-effort: report the failure instead of aborting the notebook.
    try:
        torchaudio.save(path, waveform_tensor, sample_rate)
        print(f"Audio saved successfully at {path}")
    except Exception as e:
        print(f"Failed to save audio: {e}")

In [None]:
# Render the trained model on the dataloader's coordinates and write the result to siren_best.wav.
save_audio(siren_voice, dataloader, audio_samples, audio_rate, "siren_best.wav" )

Generates the shader code.

In [None]:
import re

def dump_data(dat):
    """Return the tensor ``dat`` as a NumPy array (detached, on the CPU)."""
    return dat.detach().cpu().numpy()

def print_vec4(ws):
    """Format ``ws`` as a GLSL ``vec4(...)`` literal and return it.

    Components use four significant digits; a leading ``0.`` is shortened
    to ``.``.
    """
    components = ",".join(f"{w:.4g}" for w in ws)
    return re.sub(r"\b0\.", ".", f"vec4({components})")

def print_mat4(ws):
    """Format the matrix ``ws`` as a GLSL ``mat4(...)`` literal and return it.

    Entries are listed from the transposed matrix in row order with three
    significant digits; a leading ``0.`` is shortened to ``.``.
    """
    entries = ",".join(f"{w:.3g}" for w in np.transpose(ws).ravel())
    return re.sub(r"\b0\.", ".", f"mat4({entries})")

def _sine_layer_arrays(layer):
    """Return (weight, bias, modulation) of a sine layer as NumPy arrays."""
    weight = dump_data(layer.linear.weight)
    bias = dump_data(layer.linear.bias)
    if layer.modulation is None:
        # No learnable gain: a gain of one leaves the emitted expression unchanged.
        gain = np.ones(weight.shape[0], dtype=weight.dtype)
    else:
        gain = dump_data(layer.modulation)
    return weight, bias, gain


def serialize_to_shadertoy(siren, varname):
    """Print GLSL (Shadertoy) source that evaluates a trained ModulatedSiren.

    Hidden activations are emitted as ``vec4`` variables named
    ``{varname}{layer}_{chunk}``, four units per variable, and the scalar
    output as ``float {varname}``. Every sine layer is emitted as
    ``sin((omega*W*x + omega*b) * modulation)``, matching the layer's forward
    pass. The final layer is read via ``siren.net[-1].weight`` / ``.bias``,
    i.e. the model is expected to end in a plain linear layer.

    Args:
        siren: Model exposing ``omega``, ``hidden_features``, ``hidden_layers``
            and ``net``.
        varname (str): Prefix of the generated GLSL variable names.

    Raises:
        ValueError: If ``hidden_features`` is not a multiple of 4, or if the
            first layer has an input size other than 1 or 2.
    """
    omega = siren.omega
    hidden_features = siren.hidden_features
    hidden_layers = siren.hidden_layers
    # Units are packed four per vec4. A width that is not a multiple of four
    # would silently drop the trailing units and emit a wrong network.
    if hidden_features % 4 != 0:
        raise ValueError(f"hidden_features must be a multiple of 4, got {hidden_features}")
    chunks = hidden_features // 4
    print(f"//Chunksds: {chunks}, omega: {omega}, hidden_features: {hidden_features}, hidden_layers: {hidden_layers}")

    # First layer: weight has shape [hidden_features, in_features].
    in_w, in_bias, modulation = _sine_layer_arrays(siren.net[0])
    in_features = in_w.shape[1]

    for row in range(chunks):
        rows = slice(row * 4, (row + 1) * 4)
        bias = in_bias[rows] * omega
        mod_vec = modulation[rows]
        if in_features == 2:
            x_vec = in_w[rows, 1] * omega
            y_vec = in_w[rows, 0] * -1 * omega
            shader_line = (
                f"vec4 {varname}0_{row} = sin((uv.x * {print_vec4(x_vec)} + "
                f"uv.y * {print_vec4(y_vec)} + {print_vec4(bias)}) * {print_vec4(mod_vec)});"
            )
        elif in_features == 1:
            x_vec = in_w[rows, 0] * omega
            shader_line = (
                f"vec4 {varname}0_{row} = sin((t * {print_vec4(x_vec)} + {print_vec4(bias)}) * {print_vec4(mod_vec)});"
            )
        else:
            raise ValueError(f"Unsupported number of input features: {in_features}")

        print(shader_line)

    # Hidden layers: weight has shape [hidden_features, hidden_features].
    for layer in range(hidden_layers):
        layer_w, layer_bias, modulation = _sine_layer_arrays(siren.net[layer + 1])

        for row in range(chunks):
            rows = slice(row * 4, (row + 1) * 4)
            terms = "".join(
                f"    {print_mat4(layer_w[rows, col*4:(col+1)*4] * omega)} * {varname}{layer}_{col} +\n"
                for col in range(chunks)
            )
            bias = layer_bias[rows] * omega
            mod_vec = modulation[rows]
            # The whole pre-activation sum is parenthesised before the gain is
            # applied; without the parentheses GLSL precedence would multiply
            # the modulation into the bias term only.
            line = (
                f"vec4 {varname}{layer+1}_{row} = sin((\n"
                f"{terms}"
                f"    {print_vec4(bias)}) * {print_vec4(mod_vec)});"
            )
            print(line)

    # Output layer: weight has shape [1, hidden_features], bias has shape [1].
    out_w = dump_data(siren.net[-1].weight)
    out_bias = dump_data(siren.net[-1].bias)

    line = f"float {varname} = "
    for row in range(chunks):
        vec = out_w[0, row*4:(row+1)*4]
        line += f"dot({varname}{hidden_layers}_{row}, {print_vec4(vec)}) +\n    "
    line += f"{out_bias[0]};"
    print(line)

# Emit a header comment, then the GLSL for the trained audio model using "f" as the variable prefix.
print("//luma network")
serialize_to_shadertoy(siren_voice, "f")