In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import os, glob, random
# from IPython.display import Audio, display # Keep this if in Colab/Jupyter

# --- Constants ---
# Audio / spectrogram settings used by wav_to_mel_spec and mel_spec_to_audio.
TARGET_SR = 44100  # Sample rate (Hz) every file is resampled to
N_FFT = 1024 # Window size for STFT
HOP_LENGTH = 256 # Hop length for STFT
N_MELS = 128 # Number of Mel frequency bins
LATENT_DIM = 64  # Size of the VAE latent vector
FIXED_FRAMES = 512 # Target number of frames in the spectrogram
# Training schedule: training runs in chunks, with a preview after each chunk.
TOTAL_EPOCHS = 3000
CHUNK_SIZE = 100  # Epochs per chunk
NUM_CHUNKS = TOTAL_EPOCHS // CHUNK_SIZE  # Floor division: a remainder of epochs would be dropped
KL_TARGET = 0.1  # Final weight of the KL term in the loss
KL_WARMUP = 500  # Epochs over which the KL weight ramps linearly from 0 to KL_TARGET

# --- Data Loading and Preprocessing ---
# Point this at the folder holding the training WAV files.
folder_path = '/content/drive/MyDrive/Neural Drum Machine/Samples/01. Bass Drum' # UPDATE THIS

# When the folder is missing or empty, fall back to a local folder holding one
# synthetic one-second 440 Hz sine so the rest of the notebook can still run.
if not (os.path.exists(folder_path) and os.listdir(folder_path)):
    print(f"Warning: Folder '{folder_path}' is empty or not found.")
    print("Attempting to use/create dummy data.")

    dummy_data_dir = "dummy_audio_data_mel"
    os.makedirs(dummy_data_dir, exist_ok=True)
    dummy_file_path = os.path.join(dummy_data_dir, "dummy_kick_mel.wav")

    if not os.path.exists(dummy_file_path):
        try:
            from scipy.io.wavfile import write as write_wav

            sample_rate = TARGET_SR
            duration = 1
            frequency = 440

            t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
            note = np.sin(frequency * t * 2 * np.pi)
            # Scale the sine to the full int16 range before writing 16-bit PCM.
            audio_data = (note * (32767 / np.max(np.abs(note)))).astype(np.int16)

            write_wav(dummy_file_path, sample_rate, audio_data)
            print(f"Created dummy WAV: {dummy_file_path}")
        except Exception as e:
            print(f"Error creating dummy WAV: {e}")

    # Only switch folders if the dummy file actually exists on disk.
    if not os.path.exists(dummy_file_path):
        print("Could not create/find dummy audio.")
    else:
        folder_path = dummy_data_dir

wav_pattern = os.path.join(folder_path, '*.wav')
files_list = glob.glob(wav_pattern)
print(f"Found {len(files_list)} samples in '{folder_path}'.")

def wav_to_mel_spec(filename, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, target_frames=FIXED_FRAMES):
    """Convert a WAV file to a peak-normalized log-Mel spectrogram.

    The audio is resampled to ``sr`` if needed, trimmed of leading/trailing
    silence (30 dB threshold), forced to exactly one second (``sr`` samples),
    converted to a Mel spectrogram, compressed with ``log1p``, divided by its
    own maximum, and finally padded with zeros or truncated along the frame
    axis to ``target_frames``.

    Args:
        filename: Path of the audio file to load.
        sr: Sample rate the audio is resampled to.
        n_fft: STFT window size.
        hop_length: STFT hop length.
        n_mels: Number of Mel bins (rows of the result).
        target_frames: Number of frames (columns of the result).

    Returns:
        An array of shape ``(n_mels, target_frames)`` with values in [0, 1],
        or ``None`` when the file could not be processed. ``None`` lets callers
        drop the file instead of training on an all-zero spectrogram that is
        indistinguishable from valid data.
    """
    try:
        y, file_sr = librosa.load(filename, sr=None)
        if file_sr != sr:
            y = librosa.resample(y, orig_sr=file_sr, target_sr=sr)

        y, _ = librosa.effects.trim(y, top_db=30)

        # Force exactly one second of audio: cut long clips, zero-pad short ones.
        target_length = sr
        if len(y) > target_length:
            y = y[:target_length]
        else:
            y = np.pad(y, (0, target_length - len(y)), mode='constant')

        # Compute Mel spectrogram
        S_mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)

        # Convert to log scale (log(1 + S_mel) for stability)
        log_S_mel = np.log1p(S_mel)

        # Normalize by the per-file peak; the peak itself is not kept, so the
        # absolute level of the clip cannot be recovered afterwards.
        peak = log_S_mel.max()
        if peak > 0:
            log_S_mel = log_S_mel / peak
        else:
            log_S_mel = np.zeros_like(log_S_mel)

        # Pad or truncate to target_frames.
        # NOTE(review): one second of audio yields roughly sr / hop_length
        # frames, so with the defaults most of the 512 columns are padding.
        n_frames = log_S_mel.shape[1]
        if n_frames < target_frames:
            pad_width = target_frames - n_frames
            log_S_mel = np.pad(log_S_mel, ((0, 0), (0, pad_width)), mode='constant', constant_values=0)
        else:
            log_S_mel = log_S_mel[:, :target_frames]

        return log_S_mel
    except Exception as e:
        # Best-effort loader: report the failure and signal it with None so the
        # caller's ``is not None`` filter can skip this file.
        print(f"Error processing {filename} for Mel spec: {e}")
        return None

# Convert every file that still exists, keeping only results that are present
# and have the expected (N_MELS, FIXED_FRAMES) shape.
SAMPLES_LIST = [
    s
    for s in (wav_to_mel_spec(f) for f in files_list if os.path.exists(f))
    if s is not None and s.shape == (N_MELS, FIXED_FRAMES)
]

# Stack into one array; with nothing usable, fall back to a single random
# "spectrogram" so the downstream cells can still execute.
if SAMPLES_LIST:
    SAMPLES = np.stack(SAMPLES_LIST)
else:
    print(f"Could not load valid Mel spectrograms. Using random noise ({N_MELS}x{FIXED_FRAMES}).")
    SAMPLES = np.random.rand(1, N_MELS, FIXED_FRAMES).astype(np.float32)

print(f"SAMPLES (Mel Spectrograms) shape: {SAMPLES.shape}") # Expected: (N, N_MELS, FIXED_FRAMES)

class DrumDataset(torch.utils.data.Dataset):
    """In-memory dataset that serves one spectrogram per index."""

    def __init__(self, specs):
        # torch.tensor copies the input and casts it to float32.
        self.specs = torch.tensor(specs, dtype=torch.float32)

    def __len__(self):
        """Number of spectrograms (size of the first axis)."""
        return len(self.specs)

    def __getitem__(self, idx):
        """Return the spectrogram stored at position ``idx``."""
        return self.specs[idx]

# Shuffled mini-batch loader over all samples; stays None when there is no data.
loader = None
if SAMPLES.shape[0] > 0:
    loader = torch.utils.data.DataLoader(DrumDataset(SAMPLES), batch_size=16, shuffle=True)
else:
    print("No samples for DataLoader.")

# --- Simple VAE (Encoder and Decoder adjusted for N_MELS) ---
class SimpleEncoder(nn.Module):
    """Conv1d encoder mapping a Mel spectrogram to Gaussian latent parameters.

    The Mel bins are treated as Conv1d channels and the convolutions run along
    the frame axis. Three stride-2 convolutions each halve the frame count, so
    the flattened feature size is ``32 * (fixed_frames // 8)``.

    Args:
        latent_dim: Size of the latent vector.
        n_input_channels: Number of input channels (Mel bins).
        fixed_frames: Number of frames in each input spectrogram. Defaults to
            the module-level FIXED_FRAMES, which was previously the only
            supported value.
    """

    def __init__(self, latent_dim=LATENT_DIM, n_input_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        # Input: (B, n_input_channels, fixed_frames); shapes below are for the
        # defaults, i.e. (B, 128, 512).
        self.conv = nn.Sequential(
            nn.Conv1d(n_input_channels, 128, kernel_size=4, stride=2, padding=1), # (B, 128, 256)
            nn.ReLU(),
            nn.Conv1d(128, 64, kernel_size=4, stride=2, padding=1),            # (B, 64, 128)
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=4, stride=2, padding=1),             # (B, 32, 64)
            nn.ReLU()
        )
        self.flatten = nn.Flatten()
        # kernel=4, stride=2, padding=1 gives floor(L / 2) output frames, so
        # three layers leave fixed_frames // 8 frames.
        self.final_conv_output_frames = fixed_frames // 8
        flat_features = 32 * self.final_conv_output_frames
        self.mu = nn.Linear(flat_features, latent_dim)
        self.logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (B, latent_dim)."""
        h = self.conv(x)
        h = self.flatten(h)
        return self.mu(h), self.logvar(h)

class SimpleDecoder(nn.Module):
    """Upsample + Conv1d decoder mapping a latent vector to a Mel spectrogram.

    A linear layer expands the latent vector to ``64 x (fixed_frames // 4)``
    features, two nearest-neighbour 2x upsampling stages restore the frame
    count, and a final 1x1 convolution followed by ReLU produces
    ``n_output_channels`` non-negative channels.

    Args:
        latent_dim: Size of the latent vector.
        n_output_channels: Number of output channels (Mel bins).
        fixed_frames: Target number of output frames. The output actually has
            ``4 * (fixed_frames // 4)`` frames, so use a multiple of 4.
            Defaults to the module-level FIXED_FRAMES, which was previously
            the only supported value.
    """

    def __init__(self, latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        self.n_output_channels = n_output_channels
        # Two 2x upsampling stages follow, so start at a quarter of the frames.
        self.initial_frames = fixed_frames // 4
        self.initial_channels = 64

        self.fc = nn.Linear(latent_dim, self.initial_channels * self.initial_frames)

        self.decode_layers = nn.Sequential(
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(self.initial_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, self.n_output_channels, kernel_size=1), # Output N_MELS channels
            nn.ReLU()
        )

    def forward(self, z):
        """Decode ``z`` of shape (B, latent_dim) into (B, n_output_channels, frames)."""
        x = self.fc(z)
        x = x.view(-1, self.initial_channels, self.initial_frames)
        recon_mel_spec = self.decode_layers(x)
        return recon_mel_spec

class SimpleVAE(nn.Module):
    """Variational autoencoder built from SimpleEncoder and SimpleDecoder."""

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.encoder = SimpleEncoder(latent_dim=latent_dim, n_input_channels=N_MELS)
        self.decoder = SimpleDecoder(latent_dim=latent_dim, n_output_channels=N_MELS)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and return ``(recon_x, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        recon_x = self.decoder(self.reparameterize(mu, logvar))
        return recon_x, mu, logvar

# --- Setup for Training ---
# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Model on the chosen device, optimized with Adam.
vae = SimpleVAE(latent_dim=LATENT_DIM)
vae.to(device)
opt = optim.Adam(params=vae.parameters(), lr=1e-3)

def kl_divergence_loss(mu, logvar):
    """Per-sample KL divergence of N(mu, exp(logvar)) from N(0, I).

    The per-dimension terms are summed over dim 1, so the result has one
    value per row of the inputs.
    """
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_dim.sum(dim=1)

def mel_spec_to_audio(log_mel_spec_norm, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Approximate audio from a normalized log-Mel spectrogram.

    Accepts a NumPy array or a torch tensor (detached and moved to the CPU).
    A spectrogram whose maximum is at most 1e-6 is treated as silence and
    yields ``sr`` zeros. Otherwise the values are multiplied by their own
    maximum, passed through ``expm1`` (the inverse of ``log1p``), mapped back
    to an STFT magnitude, and inverted with 32 Griffin-Lim iterations.

    The per-file peak removed in wav_to_mel_spec is not stored, so the scaling
    by the spectrogram's own maximum does not restore the original level.
    """
    spec = log_mel_spec_norm
    if isinstance(spec, torch.Tensor):
        spec = spec.detach().cpu().numpy()

    peak = spec.max()
    if peak <= 1e-6: # Effectively silent
        return np.zeros(sr)

    # Scale by the spectrogram's own maximum, then undo the log1p compression.
    mel_approx = np.expm1(spec * peak)

    # Mel -> linear-frequency STFT magnitude (an approximation).
    magnitude_approx = librosa.feature.inverse.mel_to_stft(
        M=mel_approx, sr=sr, n_fft=n_fft
    )

    # Griffin-Lim estimates the missing phase and returns the waveform.
    return librosa.griffinlim(magnitude_approx, n_iter=32, hop_length=hop_length, n_fft=n_fft)

# --- Training Loop ---
# Trains the VAE in NUM_CHUNKS chunks of CHUNK_SIZE epochs. After each chunk it
# plots an original, a reconstruction and a sample decoded from a random latent
# vector, and plays the three as audio when IPython is available.
if loader is not None and SAMPLES.shape[0] > 0:
    start_epoch = 0
    for chunk_idx in range(NUM_CHUNKS):
        print(f"\n--- Chunk {chunk_idx+1}/{NUM_CHUNKS} (Epochs {start_epoch+1}-{start_epoch+CHUNK_SIZE}) ---")
        for epoch in range(start_epoch, start_epoch + CHUNK_SIZE):
            vae.train()
            total_recon_loss_epoch, total_kl_loss_epoch = 0, 0
            # Linear KL warm-up: the weight grows with the epoch index and is
            # capped at KL_TARGET once epoch >= KL_WARMUP.
            kl_weight = min(KL_TARGET, KL_TARGET * (epoch / KL_WARMUP) if KL_WARMUP > 0 else KL_TARGET)

            for batch_specs in loader:
                batch_specs = batch_specs.to(device)
                opt.zero_grad()
                recon_specs, mu, logvar = vae(batch_specs)

                # L1 reconstruction error summed over all elements, averaged per sample.
                recon_loss = F.l1_loss(recon_specs, batch_specs, reduction='sum') / batch_specs.size(0)
                kl_loss_batch = kl_divergence_loss(mu, logvar).mean()
                loss = recon_loss + kl_weight * kl_loss_batch

                loss.backward()
                opt.step()

                # Re-weight the per-sample means by batch size so a smaller
                # final batch does not skew the epoch average.
                total_recon_loss_epoch += recon_loss.item() * batch_specs.size(0)
                total_kl_loss_epoch += kl_loss_batch.item() * batch_specs.size(0)

            avg_recon_loss = total_recon_loss_epoch / len(SAMPLES)
            avg_kl_loss = total_kl_loss_epoch / len(SAMPLES)
            avg_total_loss = avg_recon_loss + kl_weight * avg_kl_loss

            # Log every 50 epochs and on the last epoch of the chunk.
            if (epoch + 1) % 50 == 0 or epoch == start_epoch + CHUNK_SIZE -1:
                print(f"Epoch {epoch+1:03d} | Total Loss: {avg_total_loss:.4f} | Recon: {avg_recon_loss:.4f} | KL: {avg_kl_loss:.4f} (W: {kl_weight:.4f})")

        # --- Preview ---
        vae.eval()
        with torch.no_grad():
            # Pick one random sample from a freshly drawn (shuffled) batch.
            preview_batch = next(iter(loader)).to(device)
            idx = random.randint(0, preview_batch.size(0) - 1)
            real_spec_tensor = preview_batch[idx:idx+1]
            recon_spec_tensor, _, _ = vae(real_spec_tensor)
            # Decode a latent vector drawn from a standard normal.
            random_latent_z = torch.randn(1, LATENT_DIM).to(device)
            generated_spec_tensor = vae.decoder(random_latent_z)

            real_np = real_spec_tensor.squeeze().cpu().numpy()
            recon_np = recon_spec_tensor.squeeze().cpu().numpy()
            generated_np = generated_spec_tensor.squeeze().cpu().numpy()

            plt.figure(figsize=(18, 5))
            plt.suptitle(f"VAE (Mel Spec Decoder) - Chunk {chunk_idx+1} — Epoch {start_epoch+CHUNK_SIZE}", fontsize=16)

            # NOTE(review): the plotted values are the normalized log-Mel
            # values produced by wav_to_mel_spec, not decibels, so the "dB"
            # colorbar label looks inaccurate — confirm.
            plt.subplot(1, 3, 1); librosa.display.specshow(real_np, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='mel'); plt.title("Original Mel Spec"); plt.colorbar(format='%+2.0f dB')
            plt.subplot(1, 3, 2); librosa.display.specshow(recon_np, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='mel'); plt.title("Reconstructed Mel Spec"); plt.colorbar(format='%+2.0f dB')
            plt.subplot(1, 3, 3); librosa.display.specshow(generated_np, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='mel'); plt.title("Generated Mel Spec"); plt.colorbar(format='%+2.0f dB')

            plt.tight_layout(rect=[0, 0.05, 1, 0.95]); plt.show()

            # Audio playback is optional: it is skipped when IPython cannot be imported.
            try:
                from IPython.display import Audio, display
                print("\nOriginal Kick (from Mel Spec):"); display(Audio(mel_spec_to_audio(real_np), rate=TARGET_SR, normalize=False))
                print("Reconstructed Kick (from Mel Spec):"); display(Audio(mel_spec_to_audio(recon_np), rate=TARGET_SR, normalize=False))
                print("Generated Kick (from Mel Spec):"); display(Audio(mel_spec_to_audio(generated_np), rate=TARGET_SR, normalize=False))
            except ImportError: print("\nIPython.display.Audio not available. Skipping audio playback.")
        start_epoch += CHUNK_SIZE
else:
    print("Skipping training: No data or loader.")
print("\n--- Script Finished ---")


In [None]:
import torch
import torch.nn as nn
import time
import numpy as np

# --- Constants (ensure these match your VAE's decoder architecture) ---
LATENT_DIM = 64    # Size of the latent vector fed to the decoder
N_MELS = 128       # Number of Mel bins the decoder outputs
FIXED_FRAMES = 512 # Number of time frames the decoder outputs

# --- Simplified Decoder Definition (from your VAE for Mel Spectrograms) ---
class SimpleDecoder(nn.Module):
    """Decoder half of the VAE: latent vector -> spectrogram via Upsample + Conv1d."""

    def __init__(self, latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        self.n_output_channels = n_output_channels
        # The two 2x upsampling stages below multiply the frame count by 4,
        # so the FC output starts at a quarter of it (e.g. 512 / 4 = 128).
        self.initial_frames = fixed_frames // 4
        self.initial_channels = 64  # Tunable hyperparameter

        # Projects the latent vector onto the (channels x frames) grid the
        # convolutional stack starts from.
        self.fc = nn.Linear(latent_dim, self.initial_channels * self.initial_frames)

        stages = [
            # Activation applied to the FC output.
            nn.ReLU(),
            # Stage 1: frames x2, channels initial_channels -> 32.
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(self.initial_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Stage 2: frames x2 again, channels 32 -> 16.
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # 1x1 convolution to the requested number of output channels,
            # followed by a ReLU that keeps the output non-negative.
            nn.Conv1d(16, self.n_output_channels, kernel_size=1),
            nn.ReLU(),
        ]
        self.decode_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Decode ``z`` of shape (B, latent_dim) into (B, n_output_channels, frames)."""
        grid = self.fc(z).view(-1, self.initial_channels, self.initial_frames)
        return self.decode_layers(grid)

def test_decoder_inference_speed(decoder_model, latent_dimension, current_device, num_runs=100, warm_up_runs=10):
    """
    Benchmark single-vector inference latency of a decoder and print statistics.

    The model is put in eval mode and moved to ``current_device``. One untimed
    pass determines the output shape that is reported, so the printed shape
    reflects the model actually passed in rather than a module-level constant.

    Args:
        decoder_model (nn.Module): The decoder model instance.
        latent_dimension (int): The dimension of the latent space.
        current_device (torch.device): The device to run the model on (e.g., 'cuda' or 'cpu').
        num_runs (int): Number of inference runs to average for timing.
        warm_up_runs (int): Number of initial runs to discard.

    Returns:
        None. All results are printed.
    """
    decoder_model.eval() # Set the model to evaluation mode
    decoder_model.to(current_device) # Ensure model is on the correct device

    on_cuda = current_device.type == 'cuda'

    # Dummy input created once; batch size 1 for a single-latent-vector test.
    z_input = torch.randn(1, latent_dimension).to(current_device)

    # Untimed probe pass: read the real output shape from the model.
    with torch.no_grad():
        probe_output = decoder_model(z_input)
    if hasattr(probe_output, "shape"):
        output_shape_text = str(tuple(probe_output.shape))
    else:
        output_shape_text = type(probe_output).__name__

    print("\n--- Starting Decoder Inference Speed Test ---")
    print(f"Device: {current_device}")
    print(f"Input latent dimension: {latent_dimension}")
    print(f"Output spectrogram shape (expected): {output_shape_text}")
    print(f"Warm-up runs: {warm_up_runs}")
    print(f"Timed runs: {num_runs}")

    inference_times = []
    for i in range(warm_up_runs + num_runs):
        if on_cuda:
            torch.cuda.synchronize() # Wait for all preceding GPU ops to finish

        start_time = time.perf_counter()

        with torch.no_grad(): # Disable gradient calculations for inference
            _ = decoder_model(z_input) # Perform inference

        if on_cuda:
            torch.cuda.synchronize() # Wait for decoder op to finish

        end_time = time.perf_counter()

        if i >= warm_up_runs: # Only record times after warm-up
            inference_times.append((end_time - start_time) * 1000)

    if not inference_times:
        print("No timed runs were recorded. Check num_runs.")
        return

    # --- Calculate and print statistics ---
    avg_time = np.mean(inference_times)
    std_time = np.std(inference_times)
    min_time = np.min(inference_times)
    max_time = np.max(inference_times)

    print("\n--- Decoder Speed Test Results ---")
    print(f"Average Inference Time (per latent vector): {avg_time:.4f} ms")
    print(f"Standard Deviation: {std_time:.4f} ms")
    print(f"Min Time: {min_time:.4f} ms, Max Time: {max_time:.4f} ms")

    fps = 1000 / avg_time if avg_time > 0 else float('inf')
    print(f"Approximate Throughput: {fps:.2f} latent vectors/second")
    print("--- End of Decoder Speed Test ---")


# The name matches pytest's ``test_*`` pattern but the function needs
# arguments; opt out of collection so a test runner does not try to run it.
test_decoder_inference_speed.__test__ = False

if __name__ == '__main__':
    # Benchmark on the GPU when PyTorch can see one, otherwise on the CPU.
    gpu_available = torch.cuda.is_available()
    selected_device = torch.device("cuda" if gpu_available else "cpu")

    # fixed_frames is forwarded because the decoder sizes its FC layer from it.
    decoder = SimpleDecoder(
        latent_dim=LATENT_DIM,
        n_output_channels=N_MELS,
        fixed_frames=FIXED_FRAMES,
    )

    # --- IMPORTANT ---
    # To benchmark a TRAINED model, load its weights into `decoder` first, e.g.:
    #     model_path = 'path_to_your_trained_vae.pth'  # checkpoint of the full VAE
    #     full_vae_state_dict = torch.load(model_path, map_location=selected_device)
    #     decoder_state_dict = {k.replace('decoder.', ''): v
    #                           for k, v in full_vae_state_dict.items()
    #                           if k.startswith('decoder.')}
    #     decoder.load_state_dict(decoder_state_dict)
    # Without that step the timings come from randomly initialized weights.
    print("Using initialized (untrained) decoder for speed test.")

    # More timed runs give a more stable average.
    test_decoder_inference_speed(
        decoder_model=decoder,
        latent_dimension=LATENT_DIM,
        current_device=selected_device,
        num_runs=200,
        warm_up_runs=20,
    )


Using initialized (untrained) decoder for speed test.

--- Starting Decoder Inference Speed Test ---
Device: cuda
Input latent dimension: 64
Output spectrogram shape (expected): (1, 128, 512)
Warm-up runs: 20
Timed runs: 200

--- Decoder Speed Test Results ---
Average Inference Time (per latent vector): 0.4897 ms
Standard Deviation: 0.0652 ms
Min Time: 0.4547 ms, Max Time: 1.2477 ms
Approximate Throughput: 2041.97 latent vectors/second
--- End of Decoder Speed Test ---


In [None]:
import torch
import torch.nn as nn
import numpy as np
import os # For loading model

# --- Constants (from your VAE decoder) ---
LATENT_DIM = 64     # Size of the latent vector fed to the decoder
N_MELS = 128        # Number of output channels passed to the decoder
FIXED_FRAMES = 512  # Frame count; the decoder's FC layer is sized from FIXED_FRAMES // 4

# --- SimpleDecoder Definition ---
class SimpleDecoder(nn.Module):
    """Partial decoder used for the fixed-point study.

    Only the fully connected layer and the first ReLU are defined; the
    upsampling and convolution stages of the full decoder are intentionally
    left out until they are simulated as well.
    """

    def __init__(self, latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        self.n_output_channels = n_output_channels
        self.initial_frames = fixed_frames // 4
        self.initial_channels = 64

        self.fc_out_features = self.initial_channels * self.initial_frames
        self.fc = nn.Linear(latent_dim, self.fc_out_features)
        self.relu1 = nn.ReLU()

        # Later stages (upsample1, conv1, ...) will be added here as the
        # simulation is extended.

    def forward_fc_only(self, z):
        """Return the raw output of the fc layer."""
        return self.fc(z)

    def forward_fc_relu1(self, z):
        """Return ``(fc output, ReLU(fc output))`` so both can be analysed."""
        pre_activation = self.fc(z)
        return pre_activation, self.relu1(pre_activation)

    def forward(self, z):
        """Placeholder forward pass: ReLU(fc(z)) reshaped to (B, channels, frames).

        The remaining decoder stages are not implemented in this class, so the
        result is the tensor that would feed the first upsampling stage.
        """
        activated = self.relu1(self.fc(z))
        return activated.view(-1, self.initial_channels, self.initial_frames)


# --- Fixed-Point Simulation Parameters ---
# YOU MUST DETERMINE THESE BASED ON YOUR RANGE ANALYSIS OF THE TRAINED MODEL
# Activation format (also used for the input vector and the bias below).
TOTAL_BITS = 16  # Example: 16-bit fixed-point for data and weights
FRAC_BITS = 8   # Example: 8 fractional bits for data and weights
# This implies: Integer Bits = TOTAL_BITS - FRAC_BITS - 1 (for sign) = 7

# Parameters for weights (could be different from activations)
WEIGHT_TOTAL_BITS = 16
WEIGHT_FRAC_BITS = 10 # Example: more fractional bits for weights if their range is small (e.g. [-2, 2])

# --- Quantization and Dequantization Functions ---
def quantize_value(value, total_bits, frac_bits):
    """Convert floats to signed fixed-point integers.

    The value is multiplied by ``2**frac_bits``, rounded with ``np.round``,
    saturated to the signed ``total_bits`` range
    ``[-2**(total_bits-1), 2**(total_bits-1) - 1]`` and returned as int64.
    """
    half_range = 2.0 ** (total_bits - 1)
    rounded = np.round(value * (2.0 ** frac_bits))
    saturated = np.clip(rounded, -half_range, half_range - 1)
    return saturated.astype(np.int64)

def dequantize_value(quantized_value_scaled, frac_bits):
    """Map fixed-point integers back to floats by dividing by ``2**frac_bits``."""
    return quantized_value_scaled / (2.0 ** frac_bits)

# --- Main Fixed-Point Analysis Script ---
# --- Main Fixed-Point Analysis Script ---
# Quantizes the decoder's fc weights, bias and one random latent vector,
# simulates fc + ReLU with integer arithmetic, and compares the dequantized
# result with the floating-point reference (MSE / MAE).
if __name__ == '__main__':
    print(f"--- Fixed-Point Conversion Analysis for ASIC PoC ---")
    print(f"Activation fixed-point: TOTAL_BITS={TOTAL_BITS}, FRAC_BITS={FRAC_BITS}")
    print(f"Weight fixed-point: TOTAL_BITS={WEIGHT_TOTAL_BITS}, FRAC_BITS={WEIGHT_FRAC_BITS}\n")

    # 1. Instantiate the decoder
    decoder = SimpleDecoder(latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES)

    # --- !!! IMPORTANT: Load your TRAINED model weights here !!! ---
    # As written, the decoder keeps its random initial weights, so the numbers
    # printed below only exercise the arithmetic.
    # Example:
    # model_path = 'path_to_your_trained_SimpleVAE_model.pth'
    # if os.path.exists(model_path):
    #     try:
    #         full_vae_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    #         # Filter for decoder weights if VAE model is saved
    #         decoder_state_dict = {k.replace('decoder.', '', 1): v
    #                               for k, v in full_vae_state_dict.items()
    #                               if k.startswith('decoder.')}
    #         if not decoder_state_dict: # If only decoder was saved
    #             decoder_state_dict = full_vae_state_dict
    #         decoder.load_state_dict(decoder_state_dict)
    #         print(f"Successfully loaded trained decoder weights from {model_path}")
    #     except Exception as e:
    #         print(f"Error loading trained weights from {model_path}: {e}")
    #         print("Proceeding with randomly initialized weights.")
    # else:
    #     print(f"Trained model weights not found at {model_path}. Proceeding with randomly initialized weights.")
    print("NOTE: Ensure you load your TRAINED model weights for accurate analysis.")
    decoder.eval()

    # --- 2. Range Analysis (Methodology) ---
    # This section only prints instructions; no range analysis is performed here.
    print("\n--- Range Analysis Methodology (Action Required) ---")
    print("1. Generate a diverse set of N (e.g., 1000-10000) random latent vectors 'z'.")
    print("2. For each 'z', perform a full forward pass through your TRAINED FLOATING-POINT decoder.")
    print("3. Record the MIN and MAX floating-point values observed for:")
    print("   - All weight tensors (e.g., decoder.fc.weight, decoder.conv1.weight, etc.)")
    print("   - All bias tensors.")
    print("   - The INPUT to EACH layer (activations).")
    print("   - The OUTPUT of EACH layer (activations).")
    print("4. Use these global MIN/MAX ranges to choose appropriate TOTAL_BITS and FRAC_BITS for weights and activations.")
    print("   (You might use different bit-widths for weights vs. activations, or even per layer).")
    print("-----------------------------------------------------\n")

    # --- 3. Extract and Quantize Weights for self.fc ---
    fc_weights_float = decoder.fc.weight.data.detach().clone().numpy()
    fc_bias_float = decoder.fc.bias.data.detach().clone().numpy()

    # Weights use the weight format; the bias uses the activation format so it
    # can be added directly to the rescaled accumulator further down.
    fc_weights_quant = quantize_value(fc_weights_float, WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    fc_bias_quant = quantize_value(fc_bias_float, TOTAL_BITS, FRAC_BITS) # Bias often uses activation format

    print(f"FC Layer: fc_weights_quant shape: {fc_weights_quant.shape}, dtype: {fc_weights_quant.dtype}")
    print(f"FC Layer: fc_bias_quant shape: {fc_bias_quant.shape}, dtype: {fc_bias_quant.dtype}")


    # --- 4. Simulate FC Layer and first ReLU with a sample input ---
    sample_z_float_tensor = torch.randn(1, LATENT_DIM) # Batch size 1
    sample_z_float_np = sample_z_float_tensor.numpy().squeeze() # (LATENT_DIM,)

    # --- Floating-point computation (Reference) ---
    fc_output_float_tensor, relu1_output_float_tensor = decoder.forward_fc_relu1(sample_z_float_tensor)
    fc_output_float_np = fc_output_float_tensor.detach().numpy().squeeze()
    relu1_output_float_np = relu1_output_float_tensor.detach().numpy().squeeze()

    print(f"\n--- Reference Floating-Point Outputs ---")
    print(f"FC output float min: {fc_output_float_np.min():.4f}, max: {fc_output_float_np.max():.4f}")
    print(f"ReLU1 output float min: {relu1_output_float_np.min():.4f}, max: {relu1_output_float_np.max():.4f}")

    # --- Fixed-Point Simulation ---
    # Quantize the input sample_z
    sample_z_quant = quantize_value(sample_z_float_np, TOTAL_BITS, FRAC_BITS)

    # A. Simulate FC Layer (Fixed-Point)
    # output_scaled = (input_scaled @ weights_scaled.T) / 2^WEIGHT_FRAC_BITS + bias_scaled
    # Product: (z_act_q * w_q) -> scale is (2^FRAC_BITS * 2^WEIGHT_FRAC_BITS)
    # Accumulator needs to handle this. Then rescale before adding bias.

    fc_output_quant_scaled = np.zeros(decoder.fc.out_features, dtype=np.int64)
    intermediate_scale_factor = 2.0**WEIGHT_FRAC_BITS # Scale of weights

    for i in range(decoder.fc.out_features): # For each output neuron
        accumulator = np.int64(0)
        for j in range(decoder.fc.in_features): # Dot product
            # sample_z_quant[j] is scaled by 2^FRAC_BITS
            # fc_weights_quant[i,j] is scaled by 2^WEIGHT_FRAC_BITS
            # Product is scaled by 2^(FRAC_BITS + WEIGHT_FRAC_BITS)
            product_mega_scaled = np.int64(sample_z_quant[j]) * np.int64(fc_weights_quant[i, j])
            accumulator += product_mega_scaled

        # Accumulator is sum of (z_q * w_q), scaled by 2^(FRAC_BITS + WEIGHT_FRAC_BITS)
        # We want the result to be scaled by 2^FRAC_BITS (like activations)
        # So, divide accumulator by 2^WEIGHT_FRAC_BITS
        # The rescale is a float division followed by np.round, not an integer shift.
        acc_rescaled_for_activation = np.round(accumulator / intermediate_scale_factor).astype(np.int64)
        fc_output_quant_scaled[i] = acc_rescaled_for_activation + fc_bias_quant[i] # Bias is already in activation Q format

    # Clamp the output of FC layer to activation fixed-point range
    # (saturation is applied once, after the bias has been added).
    fc_output_quant_scaled = np.clip(fc_output_quant_scaled,
                                     -(2**(TOTAL_BITS - 1)),
                                     (2**(TOTAL_BITS - 1)) - 1)


    # B. Simulate ReLU Layer (Fixed-Point)
    # ReLU input is fc_output_quant_scaled (which are scaled integers)
    # ReLU: max(0, x)
    relu1_output_quant_scaled = np.maximum(0, fc_output_quant_scaled).astype(np.int64)
    # No change in scale for ReLU

    # --- Dequantize Fixed-Point Outputs for Comparison ---
    fc_output_fixed_dequant = dequantize_value(fc_output_quant_scaled, FRAC_BITS)
    relu1_output_fixed_dequant = dequantize_value(relu1_output_quant_scaled, FRAC_BITS)

    print(f"\n--- Simulated Fixed-Point Outputs (Dequantized) ---")
    print(f"FC output fixed (dequant) min: {fc_output_fixed_dequant.min():.4f}, max: {fc_output_fixed_dequant.max():.4f}")
    print(f"ReLU1 output fixed (dequant) min: {relu1_output_fixed_dequant.min():.4f}, max: {relu1_output_fixed_dequant.max():.4f}")

    # --- Compare Outputs ---
    # Mean squared and mean absolute error between float reference and simulation.
    mse_fc = np.mean((fc_output_float_np - fc_output_fixed_dequant)**2)
    mae_fc = np.mean(np.abs(fc_output_float_np - fc_output_fixed_dequant))
    mse_relu1 = np.mean((relu1_output_float_np - relu1_output_fixed_dequant)**2)
    mae_relu1 = np.mean(np.abs(relu1_output_float_np - relu1_output_fixed_dequant))

    print(f"\n--- Comparison of Layer Outputs (Float vs Fixed-Point Sim) ---")
    print(f"FC Layer MSE: {mse_fc:.2e}, MAE: {mae_fc:.2e}")
    print(f"ReLU1 Layer MSE: {mse_relu1:.2e}, MAE: {mae_relu1:.2e}")

    print("\nNext steps: Perform thorough range analysis on your TRAINED model, then extend this simulation to Conv1D and Upsample layers.")


--- Fixed-Point Conversion Analysis for ASIC PoC ---
Activation fixed-point: TOTAL_BITS=16, FRAC_BITS=8
Weight fixed-point: TOTAL_BITS=16, FRAC_BITS=10

NOTE: Ensure you load your TRAINED model weights for accurate analysis.

--- Range Analysis Methodology (Action Required) ---
1. Generate a diverse set of N (e.g., 1000-10000) random latent vectors 'z'.
2. For each 'z', perform a full forward pass through your TRAINED FLOATING-POINT decoder.
3. Record the MIN and MAX floating-point values observed for:
   - All weight tensors (e.g., decoder.fc.weight, decoder.conv1.weight, etc.)
   - All bias tensors.
   - The INPUT to EACH layer (activations).
   - The OUTPUT of EACH layer (activations).
4. Use these global MIN/MAX ranges to choose appropriate TOTAL_BITS and FRAC_BITS for weights and activations.
   (You might use different bit-widths for weights vs. activations, or even per layer).
-----------------------------------------------------

FC Layer: fc_weights_quant shape: (8192, 64), dt

In [None]:
# --- Code to Save Decoder Weights ---
# Expects a trained SimpleVAE instance named `vae` in this scope and writes the
# state_dict of its decoder submodule to disk.

# Destination folder. In Colab with Drive mounted at /content/drive you may
# prefer e.g. "/content/drive/MyDrive/Neural_Drum_Machine_Output".
output_directory = "./model_weights" # Saves in a 'model_weights' subdirectory
os.makedirs(output_directory, exist_ok=True)

decoder_weights_filename = "trained_simple_decoder_weights.pth"
decoder_weights_path = os.path.join(output_directory, decoder_weights_filename)

if not ('vae' in locals() and isinstance(vae, SimpleVAE)):
    print("Error: Trained VAE model instance named 'vae' not found in the current scope.")
    print("Please ensure your VAE model is trained and the instance is named 'vae',")
    print("or modify this script to use the correct variable name for your trained VAE model.")
else:
    print("Found trained VAE model instance 'vae'.")

    # Move the whole VAE (and therefore its decoder) to the CPU before saving
    # for broader compatibility. Note that `vae` stays on the CPU afterwards.
    vae.cpu()

    try:
        decoder_state = vae.decoder.state_dict()
        torch.save(decoder_state, decoder_weights_path)
        print(f"Successfully saved SimpleDecoder state_dict to: {decoder_weights_path}")
        print("\nYou can now use this path in your fixed-point analysis script to load these weights.")
    except Exception as e:
        print(f"Error saving decoder weights: {e}")

# Loading the weights into a standalone SimpleDecoder later:
#
# device_for_loading = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # 1. Instantiate a new SimpleDecoder (the class definition must be available)
# standalone_decoder = SimpleDecoder(latent_dim=LATENT_DIM, n_output_channels=N_MELS)
# # 2. Load the saved state_dict
# standalone_decoder.load_state_dict(torch.load(decoder_weights_path, map_location=device_for_loading))
# # 3. Switch to evaluation mode for inference
# standalone_decoder.eval()
# print(f"\nExample loading: standalone_decoder.load_state_dict(torch.load('{decoder_weights_path}'))")


Found trained VAE model instance 'vae'.
Successfully saved SimpleDecoder state_dict to: ./model_weights/trained_simple_decoder_weights.pth

You can now use this path in your fixed-point analysis script to load these weights.


In [None]:
import torch
import torch.nn as nn
import numpy as np
import os

# --- Constants (ensure these match your trained model) ---
LATENT_DIM = 64
N_MELS = 128
FIXED_FRAMES = 512

# --- SimpleDecoder Definition (Revised for detailed layer access) ---
class SimpleDecoder(nn.Module):
    """Decoder with every layer exposed as its own attribute.

    Pipeline: Linear -> ReLU -> reshape to (N, 64, fixed_frames // 4)
    -> 2x nearest upsample -> Conv1d(64, 32, k=3) -> ReLU
    -> 2x nearest upsample -> Conv1d(32, 16, k=3) -> ReLU
    -> Conv1d(16, n_output_channels, k=1) -> ReLU.

    Keeping each layer as a separate attribute (rather than inside an
    nn.Sequential) lets ``forward`` capture every intermediate activation.
    """

    # Checkpoints saved from an nn.Sequential called 'decode_layers' name the
    # convolutions by their index in that container.
    _LEGACY_KEY_PREFIXES = {
        'conv1': 'decode_layers.2',
        'conv2': 'decode_layers.5',
        'conv3': 'decode_layers.7',
    }

    def __init__(self, latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        self.n_output_channels = n_output_channels
        # Two 2x upsampling stages follow, so start at a quarter of the frames.
        self.initial_frames = fixed_frames // 4
        self.initial_channels = 64

        self.fc = nn.Linear(latent_dim, self.initial_channels * self.initial_frames)

        self.relu1 = nn.ReLU()
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv1d(self.initial_channels, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = nn.Conv1d(32, 16, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.conv3 = nn.Conv1d(16, self.n_output_channels, kernel_size=1)  # Final Mel conv
        self.relu4 = nn.ReLU()  # Final ReLU

    def forward(self, z, analyse_all_layers=True):
        """Decode ``z`` and optionally capture every intermediate activation.

        Args:
            z: Latent batch that ``self.fc`` accepts (last dim = latent_dim).
            analyse_all_layers: When True (default) return
                ``(output, activations)`` where ``activations`` maps a stage
                name to a detached clone of that stage's output. When False
                return only the output tensor.
        """
        stages = (
            ('fc_out', self.fc),
            ('relu1_out', self.relu1),
            ('reshape_out', lambda t: t.view(-1, self.initial_channels, self.initial_frames)),
            ('upsample1_out', self.upsample1),
            ('conv1_out', self.conv1),
            ('relu2_out', self.relu2),
            ('upsample2_out', self.upsample2),
            ('conv2_out', self.conv2),
            ('relu3_out', self.relu3),
            ('conv3_out_final_mel', self.conv3),
            ('relu4_out_final', self.relu4),
        )

        activations = {}
        x = z
        for stage_name, stage in stages:
            x = stage(x)
            if analyse_all_layers:
                # Detached clones so later ops can neither alias nor backprop into them.
                activations[stage_name] = x.detach().clone()

        if analyse_all_layers:
            return x, activations
        return x

    def load_custom_state_dict(self, state_dict_path, device):
        """Load decoder weights from a checkpoint, remapping legacy key names.

        Two key layouts are accepted for the convolutions: the legacy
        ``decode_layers.{2,5,7}.*`` names (tried first) and this class's own
        ``conv{1,2,3}.*`` names. ``fc.*`` is always read under its own name.
        Parameters absent from the checkpoint keep their current values and
        are reported in a warning, so a partial load is never silent.

        Args:
            state_dict_path: Path or file-like object handed to ``torch.load``.
            device: ``map_location`` for ``torch.load``.

        Raises:
            RuntimeError: Re-raised from ``load_state_dict`` (e.g. on a shape
                mismatch) after printing diagnostics.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(state_dict_path, map_location=device)
        new_state_dict = self.state_dict()

        loaded_count = 0
        missing_keys = []
        for target_key in list(new_state_dict):
            module_name, _, param_name = target_key.rpartition('.')
            candidate_keys = [target_key]
            legacy_prefix = self._LEGACY_KEY_PREFIXES.get(module_name)
            if legacy_prefix is not None:
                # Legacy names win when both layouts are present.
                candidate_keys.insert(0, f'{legacy_prefix}.{param_name}')

            source_key = next((key for key in candidate_keys if key in state_dict), None)
            if source_key is None:
                missing_keys.append(target_key)
                continue
            new_state_dict[target_key] = state_dict[source_key]
            loaded_count += 1

        if missing_keys:
            print(f"Warning: no weights found in {state_dict_path} for {missing_keys}; "
                  "these parameters keep their current (untrained) values.")

        try:
            self.load_state_dict(new_state_dict)
            print(f"Successfully loaded and mapped weights ({loaded_count} tensors) from: {state_dict_path}\n")
        except RuntimeError as e:
            print(f"RuntimeError during load_state_dict after mapping: {e}")
            print("Current model keys:", self.state_dict().keys())
            print("Attempted to load with keys:", new_state_dict.keys())
            print("Original loaded state_dict keys:", state_dict.keys())
            raise

# --- Chosen Fixed-Point Simulation Parameters (Based on Your Range Analysis) ---
# Each format is a (total word length, fractional bits) pair.
ACT_TOTAL_BITS, ACT_FRAC_BITS = 16, 7
WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS = 16, 13
# Biases reuse the activation format, so they can be added to a rescaled
# accumulator without any further shifting.
BIAS_TOTAL_BITS, BIAS_FRAC_BITS = ACT_TOTAL_BITS, ACT_FRAC_BITS

# --- Quantization and Dequantization Functions ---
def quantize_value(value_np, total_bits, frac_bits):
    """Convert float value(s) to saturated signed fixed-point integers.

    The input is multiplied by ``2**frac_bits``, rounded to the nearest
    integer (NumPy rounds exact halves to the even neighbour), clamped to the
    two's-complement range of a ``total_bits``-wide word and returned as int64.
    """
    half_range = 2.0 ** (total_bits - 1)
    as_integer_steps = np.rint(value_np * (2.0 ** frac_bits))
    saturated = np.clip(as_integer_steps, -half_range, half_range - 1)
    return saturated.astype(np.int64)

def dequantize_value(quantized_value_scaled, frac_bits):
    """Map scaled fixed-point integer(s) back to real values (divide by 2**frac_bits)."""
    # Multiplying by the exact reciprocal power of two equals dividing by 2**frac_bits.
    return quantized_value_scaled * (2.0 ** -frac_bits)

# --- Conv1D Fixed-Point Simulation Function ---
def simulate_conv1d_fixed_point(input_quant_scaled, conv_layer_torch,
                                weights_quant, bias_quant,
                                act_total_bits, act_frac_bits,
                                weight_frac_bits, bias_frac_bits):
    """Bit-accurate integer simulation of a 1-D convolution for one sample.

    Products of activation and weight integers are summed in int64, rescaled
    back to the activation format by dividing by ``2**weight_frac_bits``
    (rounded to nearest, exact halves to even, as ``np.round`` does), the bias
    is added and the result is saturated to ``act_total_bits``.

    Args:
        input_quant_scaled: Integer array of shape (in_channels, in_length)
            in the activation Q-format.
        conv_layer_torch: Object exposing ``out_channels`` and the tuples
            ``kernel_size``, ``padding`` and ``stride`` (e.g. ``nn.Conv1d``).
            Dilation and groups are not simulated.
        weights_quant: Integer array (out_channels, in_channels, kernel_size).
        bias_quant: Integer array (out_channels,).
        act_total_bits: Word length used for output saturation.
        act_frac_bits: Fractional bits of the activations (input and output).
        weight_frac_bits: Fractional bits of the weights.
        bias_frac_bits: Fractional bits of the bias; when it differs from
            ``act_frac_bits`` the bias is shifted into the activation format.

    Returns:
        int64 array of shape (out_channels, out_length) with
        ``out_length = (in_length + 2*padding - kernel_size) // stride + 1``.

    Raises:
        ValueError: If the geometry yields no output positions.
    """
    input_q = np.asarray(input_quant_scaled, dtype=np.int64)
    weights_q = np.asarray(weights_quant, dtype=np.int64)
    bias_q = np.asarray(bias_quant, dtype=np.int64)

    in_channels, in_length = input_q.shape
    out_channels = conv_layer_torch.out_channels
    kernel_size = conv_layer_torch.kernel_size[0]
    padding = conv_layer_torch.padding[0]
    stride = conv_layer_torch.stride[0]

    # General Conv1d length formula; equals in_length for the 'same'
    # configurations (K=3, P=1, S=1 and K=1, P=0, S=1).
    out_length = (in_length + 2 * padding - kernel_size) // stride + 1
    if out_length <= 0:
        raise ValueError(
            f"Conv1d geometry produces no output (in_length={in_length}, "
            f"kernel_size={kernel_size}, padding={padding}, stride={stride})."
        )

    padded_input_quant = np.pad(
        input_q,
        ((0, 0), (padding, padding)),
        mode='constant',
        constant_values=0  # Quantized zero
    )

    # One integer matrix product per kernel tap: identical sums to the
    # element-by-element loop, but without Python-level iteration.
    accumulator = np.zeros((out_channels, out_length), dtype=np.int64)
    last_offset = stride * (out_length - 1) + 1
    for k_idx in range(kernel_size):
        tap_inputs = padded_input_quant[:, k_idx:k_idx + last_offset:stride]
        accumulator += weights_q[:, :, k_idx] @ tap_inputs

    # act_q * weight_q carries act_frac_bits + weight_frac_bits fractional
    # bits; dropping weight_frac_bits returns to the activation format.
    acc_rescaled = np.round(accumulator / (2.0 ** weight_frac_bits)).astype(np.int64)

    # Bring the bias into the activation format before adding it.
    bias_shift = act_frac_bits - bias_frac_bits
    if bias_shift > 0:
        bias_aligned = bias_q * (1 << bias_shift)
    elif bias_shift < 0:
        bias_aligned = np.round(bias_q / (2.0 ** -bias_shift)).astype(np.int64)
    else:
        bias_aligned = bias_q

    output_quant_scaled = acc_rescaled + bias_aligned[:, np.newaxis]
    return np.clip(output_quant_scaled, -(2 ** (act_total_bits - 1)), (2 ** (act_total_bits - 1)) - 1)

# --- Main Fixed-Point Analysis Script ---
if __name__ == '__main__':
    # Runs one random latent vector through the float decoder and through an
    # integer-only re-implementation, then reports the per-layer error.
    print(f"--- Full Decoder Fixed-Point Simulation ---")
    device = torch.device("cpu")

    decoder = SimpleDecoder(latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES)
    decoder.to(device)

    decoder_weights_path = "./model_weights/trained_simple_decoder_weights.pth"
    if not os.path.exists(decoder_weights_path):
        print(f"Trained decoder weights not found at {decoder_weights_path}. Exiting.")
        # exit() is a convenience added by the `site` module / interactive
        # shells and is not meant for programs; raising SystemExit is the
        # portable way to stop here.
        raise SystemExit
    decoder.load_custom_state_dict(decoder_weights_path, device)
    decoder.eval()

    print(f"\n--- Using Chosen Fixed-Point Simulation Parameters ---")
    # Integer bits = total - fractional - 1 sign bit.
    act_int_bits = ACT_TOTAL_BITS - ACT_FRAC_BITS - 1
    weight_int_bits = WEIGHT_TOTAL_BITS - WEIGHT_FRAC_BITS - 1
    bias_int_bits = BIAS_TOTAL_BITS - BIAS_FRAC_BITS - 1
    print(f"Activation Q-format: Q{act_int_bits}.{ACT_FRAC_BITS} (Total: {ACT_TOTAL_BITS})")
    print(f"Weight Q-format: Q{weight_int_bits}.{WEIGHT_FRAC_BITS} (Total: {WEIGHT_TOTAL_BITS})")
    print(f"Bias Q-format: Q{bias_int_bits}.{BIAS_FRAC_BITS} (Total: {BIAS_TOTAL_BITS})")

    # --- Quantize all necessary weights and biases ---
    fc_weights_q = quantize_value(decoder.fc.weight.data.cpu().numpy(), WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    fc_bias_q = quantize_value(decoder.fc.bias.data.cpu().numpy(), BIAS_TOTAL_BITS, BIAS_FRAC_BITS)

    conv1_weights_q = quantize_value(decoder.conv1.weight.data.cpu().numpy(), WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    conv1_bias_q = quantize_value(decoder.conv1.bias.data.cpu().numpy(), BIAS_TOTAL_BITS, BIAS_FRAC_BITS)

    conv2_weights_q = quantize_value(decoder.conv2.weight.data.cpu().numpy(), WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    conv2_bias_q = quantize_value(decoder.conv2.bias.data.cpu().numpy(), BIAS_TOTAL_BITS, BIAS_FRAC_BITS)

    conv3_weights_q = quantize_value(decoder.conv3.weight.data.cpu().numpy(), WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    conv3_bias_q = quantize_value(decoder.conv3.bias.data.cpu().numpy(), BIAS_TOTAL_BITS, BIAS_FRAC_BITS)

    # --- Single Sample Simulation ---
    sim_z_float_tensor = torch.randn(1, LATENT_DIM, device=device)

    # --- Floating-Point Reference Pass (Full Decoder) ---
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        ref_final_output_float, ref_activations_float = decoder(sim_z_float_tensor, analyse_all_layers=True)

    # --- Fixed-Point Simulation (Full Decoder) ---
    fixed_activations_quant_scaled = {}  # Scaled-integer output of each layer
    fixed_activations_dequant = {}       # Same outputs as floats, for comparison

    def _record_activation(name, quant_scaled):
        """Store a layer's scaled-integer output and its dequantized view."""
        fixed_activations_quant_scaled[name] = quant_scaled
        fixed_activations_dequant[name] = dequantize_value(quant_scaled, ACT_FRAC_BITS)

    # Saturation limits of the activation word.
    act_q_min = -(2**(ACT_TOTAL_BITS-1))
    act_q_max = (2**(ACT_TOTAL_BITS-1))-1

    # 1. Input Quantization
    current_input_quant_scaled = quantize_value(sim_z_float_tensor.numpy().squeeze(), ACT_TOTAL_BITS, ACT_FRAC_BITS)
    _record_activation('latent_z', current_input_quant_scaled)

    # 2. FC Layer: int64 matrix-vector product (same sums as a per-neuron
    #    loop), rescaled to the activation format, then bias and saturation.
    product_rescale_fc = 2.0**WEIGHT_FRAC_BITS
    fc_accumulators = fc_weights_q @ current_input_quant_scaled
    fc_output_q_s = np.round(fc_accumulators / product_rescale_fc).astype(np.int64) + fc_bias_q
    current_output_quant_scaled = np.clip(fc_output_q_s, act_q_min, act_q_max)
    _record_activation('fc_out', current_output_quant_scaled)

    # 3. ReLU1
    current_output_quant_scaled = np.maximum(0, current_output_quant_scaled).astype(np.int64)
    _record_activation('relu1_out', current_output_quant_scaled)

    # 4. Reshape
    current_output_quant_scaled = current_output_quant_scaled.reshape((decoder.initial_channels, decoder.initial_frames))
    _record_activation('reshape_out', current_output_quant_scaled)

    # 5. Upsample1 (nearest neighbour = repeat every frame twice)
    current_output_quant_scaled = np.repeat(current_output_quant_scaled, 2, axis=1)
    _record_activation('upsample1_out', current_output_quant_scaled)

    # 6. Conv1
    current_output_quant_scaled = simulate_conv1d_fixed_point(
        current_output_quant_scaled, decoder.conv1, conv1_weights_q, conv1_bias_q,
        ACT_TOTAL_BITS, ACT_FRAC_BITS, WEIGHT_FRAC_BITS, BIAS_FRAC_BITS
    )
    _record_activation('conv1_out', current_output_quant_scaled)

    # 7. ReLU2
    current_output_quant_scaled = np.maximum(0, current_output_quant_scaled).astype(np.int64)
    _record_activation('relu2_out', current_output_quant_scaled)

    # 8. Upsample2
    current_output_quant_scaled = np.repeat(current_output_quant_scaled, 2, axis=1)
    _record_activation('upsample2_out', current_output_quant_scaled)

    # 9. Conv2
    current_output_quant_scaled = simulate_conv1d_fixed_point(
        current_output_quant_scaled, decoder.conv2, conv2_weights_q, conv2_bias_q,
        ACT_TOTAL_BITS, ACT_FRAC_BITS, WEIGHT_FRAC_BITS, BIAS_FRAC_BITS
    )
    _record_activation('conv2_out', current_output_quant_scaled)

    # 10. ReLU3
    current_output_quant_scaled = np.maximum(0, current_output_quant_scaled).astype(np.int64)
    _record_activation('relu3_out', current_output_quant_scaled)

    # 11. Conv3 (Final Mel)
    current_output_quant_scaled = simulate_conv1d_fixed_point(
        current_output_quant_scaled, decoder.conv3, conv3_weights_q, conv3_bias_q,
        ACT_TOTAL_BITS, ACT_FRAC_BITS, WEIGHT_FRAC_BITS, BIAS_FRAC_BITS
    )
    _record_activation('conv3_out_final_mel', current_output_quant_scaled)

    # 12. ReLU4 (Final)
    current_output_quant_scaled = np.maximum(0, current_output_quant_scaled).astype(np.int64)
    _record_activation('relu4_out_final', current_output_quant_scaled)

    # --- Compare All Layer Outputs ---
    print(f"\n--- Comparison of ALL Layer Outputs (using DERIVED Q-formats) ---")
    activation_keys_ordered = [  # Same order as the forward pass
        'fc_out', 'relu1_out', 'reshape_out', 'upsample1_out',
        'conv1_out', 'relu2_out', 'upsample2_out', 'conv2_out',
        'relu3_out', 'conv3_out_final_mel', 'relu4_out_final'
    ]

    for key in activation_keys_ordered:
        if key not in ref_activations_float or key not in fixed_activations_dequant:
            print(f"Key {key} not found in reference or fixed-point activations for comparison.")
            continue

        # The reference carries a leading batch axis of size 1; drop it.
        ref_val_np = ref_activations_float[key].cpu().numpy().squeeze()
        fixed_val_dequant_np = fixed_activations_dequant[key]

        if ref_val_np.shape != fixed_val_dequant_np.shape:
            # Same element count but different layout: reshape, else skip.
            try:
                fixed_val_dequant_np = fixed_val_dequant_np.reshape(ref_val_np.shape)
            except ValueError:
                print(f"Could not reshape fixed_val_dequant_np for key {key} to match reference. Ref shape: {ref_val_np.shape}, Fixed shape: {fixed_val_dequant_np.shape}")
                continue

        error = ref_val_np - fixed_val_dequant_np
        mse = np.mean(error**2)
        mae = np.mean(np.abs(error))
        print(f"{key} -> MSE: {mse:.3e}, MAE: {mae:.3e}")

    print("\nExamine the MSE/MAE for 'relu4_out_final' (the final output).")
    print("If this final error is acceptable, your Q-formats are good for the PoC.")


--- Full Decoder Fixed-Point Simulation ---
Successfully loaded and mapped weights (8 tensors) from: ./model_weights/trained_simple_decoder_weights.pth


--- Using Chosen Fixed-Point Simulation Parameters ---
Activation Q-format: Q8.7 (Total: 16)
Weight Q-format: Q2.13 (Total: 16)
Bias Q-format: Q8.7 (Total: 16)

--- Comparison of ALL Layer Outputs (using DERIVED Q-formats) ---
fc_out -> MSE: 2.016e-05, MAE: 3.557e-03
relu1_out -> MSE: 7.458e-06, MAE: 1.278e-03
reshape_out -> MSE: 7.458e-06, MAE: 1.278e-03
upsample1_out -> MSE: 7.458e-06, MAE: 1.278e-03
conv1_out -> MSE: 1.031e-04, MAE: 7.336e-03
relu2_out -> MSE: 1.952e-05, MAE: 1.307e-03
upsample2_out -> MSE: 1.952e-05, MAE: 1.307e-03
conv2_out -> MSE: 1.335e-04, MAE: 7.789e-03
relu3_out -> MSE: 7.558e-05, MAE: 3.287e-03
conv3_out_final_mel -> MSE: 6.174e-05, MAE: 5.369e-03
relu4_out_final -> MSE: 4.727e-08, MAE: 1.151e-05

Examine the MSE/MAE for 'relu4_out_final' (the final output).
If this final error is acceptable, your Q-formats

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os

# --- Constants (ensure these match your trained model) ---
LATENT_DIM = 64
N_MELS = 128
FIXED_FRAMES = 512

# --- SimpleDecoder Definition (Corrected for state_dict loading) ---
class SimpleDecoder(nn.Module):
    """Decoder with each layer held as its own attribute.

    Pipeline: Linear -> ReLU -> reshape to (N, 64, fixed_frames // 4)
    -> 2x nearest upsample -> Conv1d(64, 32, k=3) -> ReLU
    -> 2x nearest upsample -> Conv1d(32, 16, k=3) -> ReLU
    -> Conv1d(16, n_output_channels, k=1) -> ReLU.
    """

    # Index-based names the convolutions have in checkpoints saved from an
    # nn.Sequential called 'decode_layers'.
    _LEGACY_KEY_PREFIXES = {
        'conv1': 'decode_layers.2',
        'conv2': 'decode_layers.5',
        'conv3': 'decode_layers.7',
    }

    def __init__(self, latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES):
        super().__init__()
        self.n_output_channels = n_output_channels
        # Two 2x upsampling stages follow, so start at a quarter of the frames.
        self.initial_frames = fixed_frames // 4
        self.initial_channels = 64

        self.fc_out_features = self.initial_channels * self.initial_frames
        self.fc = nn.Linear(latent_dim, self.fc_out_features)

        self.relu1 = nn.ReLU()
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv1d(self.initial_channels, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = nn.Conv1d(32, 16, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.conv3 = nn.Conv1d(16, self.n_output_channels, kernel_size=1)  # Final Mel conv
        self.relu4 = nn.ReLU()  # Final ReLU

    def forward(self, z, analyse_all_layers=True):
        """Decode ``z``; optionally also return every intermediate activation.

        Args:
            z: Latent batch that ``self.fc`` accepts (last dim = latent_dim).
            analyse_all_layers: When True (default) return
                ``(output, activations)``, with ``activations`` mapping a stage
                name to a detached clone of that stage's output. When False
                return only the output tensor.
        """
        stages = (
            ('fc_out', self.fc),
            ('relu1_out', self.relu1),
            ('reshape_out', lambda t: t.view(-1, self.initial_channels, self.initial_frames)),
            ('upsample1_out', self.upsample1),
            ('conv1_out', self.conv1),
            ('relu2_out', self.relu2),
            ('upsample2_out', self.upsample2),
            ('conv2_out', self.conv2),
            ('relu3_out', self.relu3),
            ('conv3_out_final_mel', self.conv3),
            ('relu4_out_final', self.relu4),
        )

        activations = {}
        x_current = z
        for stage_name, stage in stages:
            x_current = stage(x_current)
            if analyse_all_layers:
                activations[stage_name] = x_current.detach().clone()

        if analyse_all_layers:
            return x_current, activations
        return x_current

    def load_custom_state_dict(self, state_dict_path, device):
        """Load decoder weights from a checkpoint, remapping legacy key names.

        The convolutions are looked up first under the legacy
        ``decode_layers.{2,5,7}.*`` names and then under this class's own
        ``conv{1,2,3}.*`` names; ``fc.*`` is read under its own name.
        Parameters absent from the checkpoint keep their current values and
        are listed in a warning, so a partial load is never silent.

        Args:
            state_dict_path: Path or file-like object handed to ``torch.load``.
            device: ``map_location`` for ``torch.load``.

        Raises:
            RuntimeError: Re-raised from ``load_state_dict`` (e.g. on a shape
                mismatch) after printing the error.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(state_dict_path, map_location=device)
        new_state_dict = self.state_dict()

        loaded_count = 0
        missing_keys = []
        for target_key in list(new_state_dict):
            module_name, _, param_name = target_key.rpartition('.')
            candidate_keys = [target_key]
            legacy_prefix = self._LEGACY_KEY_PREFIXES.get(module_name)
            if legacy_prefix is not None:
                # Legacy names win when both layouts are present.
                candidate_keys.insert(0, f'{legacy_prefix}.{param_name}')

            source_key = next((key for key in candidate_keys if key in state_dict), None)
            if source_key is None:
                missing_keys.append(target_key)
                continue
            new_state_dict[target_key] = state_dict[source_key]
            loaded_count += 1

        if missing_keys:
            print(f"Warning: no weights found in {state_dict_path} for {missing_keys}; "
                  "these parameters keep their current (untrained) values.")

        try:
            self.load_state_dict(new_state_dict)
            print(f"Successfully loaded and mapped weights ({loaded_count} unique tensors mapped) from: {state_dict_path}\n")
        except RuntimeError as e:
            print(f"RuntimeError during load_state_dict after mapping: {e}")
            raise

# --- Chosen Fixed-Point Simulation Parameters (Based on Your Range Analysis) ---
# Each format is a (total word length, fractional bits) pair.
ACT_TOTAL_BITS, ACT_FRAC_BITS = 16, 7
WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS = 16, 13
# Biases reuse the activation format, so they can be added to a rescaled
# accumulator without any further shifting.
BIAS_TOTAL_BITS, BIAS_FRAC_BITS = ACT_TOTAL_BITS, ACT_FRAC_BITS

# --- Quantization and Dequantization Functions ---
def quantize_value(value_np, total_bits, frac_bits):
    """Convert float value(s) to saturated signed fixed-point integers.

    The input is multiplied by ``2**frac_bits``, rounded to the nearest
    integer (NumPy rounds exact halves to the even neighbour), clamped to the
    two's-complement range of a ``total_bits``-wide word and returned as int64.
    """
    half_range = 2.0 ** (total_bits - 1)
    as_integer_steps = np.rint(value_np * (2.0 ** frac_bits))
    saturated = np.clip(as_integer_steps, -half_range, half_range - 1)
    return saturated.astype(np.int64)

# Inverse of quantize_value (without the saturation step).
def dequantize_value(quantized_value_scaled, frac_bits):
    """Map scaled fixed-point integer(s) back to real values (divide by 2**frac_bits)."""
    # Multiplying by the exact reciprocal power of two equals dividing by 2**frac_bits.
    return quantized_value_scaled * (2.0 ** -frac_bits)

# --- Verilog literal helper ---
def verilog_signed_decimal(value, total_bits):
    """Format an integer as a sized, signed decimal Verilog literal.

    Verilog takes the sign as a unary operator in front of the size, so -270
    in 16 bits is written ``-16'sd270``; ``16'sd-270`` is a syntax error.

    Args:
        value: Integer (Python int or NumPy integer scalar) to format.
        total_bits: Bit width placed before the ``'sd`` base specifier.
    """
    value = int(value)
    sign = '-' if value < 0 else ''
    return f"{sign}{total_bits}'sd{abs(value)}"


# --- Main Test-Vector Generation Script ---
if __name__ == '__main__':
    # Emits the quantized inputs, weights, bias and expected result of one FC
    # neuron as Verilog assignments for a DotProductBiasUnit testbench.
    print(f"--- Generating Test Vectors for DotProductBiasUnit ---")
    device = torch.device("cpu")

    decoder = SimpleDecoder(latent_dim=LATENT_DIM, n_output_channels=N_MELS, fixed_frames=FIXED_FRAMES)
    decoder.to(device)

    decoder_weights_path = "./model_weights/trained_simple_decoder_weights.pth"
    if not os.path.exists(decoder_weights_path):
        print(f"Trained decoder weights not found at {decoder_weights_path}. CANNOT GENERATE ACCURATE TEST VECTORS.")
        # exit() is a convenience added by the `site` module / interactive
        # shells and is not meant for programs; raise SystemExit instead.
        raise SystemExit
    decoder.load_custom_state_dict(decoder_weights_path, device)
    decoder.eval()

    # --- Quantize FC weights and biases (needed for test vector generation) ---
    fc_weights_q_for_tv = quantize_value(decoder.fc.weight.data.cpu().numpy(), WEIGHT_TOTAL_BITS, WEIGHT_FRAC_BITS)
    fc_bias_q_for_tv = quantize_value(decoder.fc.bias.data.cpu().numpy(), BIAS_TOTAL_BITS, BIAS_FRAC_BITS)

    # --- Generate a single, deterministic input sample for test vector ---
    # Fixed seeds make the "random" latent vector reproducible.
    torch.manual_seed(42)  # For PyTorch's random number generator
    np.random.seed(42)     # For NumPy's random number generator (if used for z elsewhere)

    sim_z_float_tensor_tv = torch.randn(1, LATENT_DIM, device=device)
    sim_z_float_np_tv = sim_z_float_tensor_tv.numpy().squeeze()  # (LATENT_DIM,)

    # Quantize the input sample
    sim_z_quant_tv = quantize_value(sim_z_float_np_tv, ACT_TOTAL_BITS, ACT_FRAC_BITS)

    # --- Simulate FC Layer (Fixed-Point) for the chosen neuron to get expected output ---
    neuron_index_for_tv = 0  # Test the first output neuron of the FC layer
    product_rescale_fc_tv = 2.0**WEIGHT_FRAC_BITS

    # Integer dot product of the input with this neuron's weight row (int64).
    accumulator_tv = np.dot(sim_z_quant_tv, fc_weights_q_for_tv[neuron_index_for_tv])

    # Drop the weight's fractional bits to return to the activation format,
    # then add the bias, which is already in that format.
    acc_rescaled_tv = np.round(accumulator_tv / product_rescale_fc_tv).astype(np.int64)
    neuron_output_q_s_tv = acc_rescaled_tv + fc_bias_q_for_tv[neuron_index_for_tv]

    # Saturate to the activation word length
    expected_neuron_output_quant_scaled = np.clip(
        neuron_output_q_s_tv,
        -(2**(ACT_TOTAL_BITS - 1)),
        (2**(ACT_TOTAL_BITS - 1)) - 1
    )

    # --- Print Test Vectors for Verilog ---
    print(f"\n--- Verilog Test Vectors for DotProductBiasUnit (neuron_index = {neuron_index_for_tv}) ---")
    print(f"// Fixed-Point Formats used for generation:")
    print(f"// Activation: Q{ACT_TOTAL_BITS-ACT_FRAC_BITS-1}.{ACT_FRAC_BITS} (Total: {ACT_TOTAL_BITS})")
    print(f"// Weight:     Q{WEIGHT_TOTAL_BITS-WEIGHT_FRAC_BITS-1}.{WEIGHT_FRAC_BITS} (Total: {WEIGHT_TOTAL_BITS})")
    print(f"// Bias:       Q{BIAS_TOTAL_BITS-BIAS_FRAC_BITS-1}.{BIAS_FRAC_BITS} (Total: {BIAS_TOTAL_BITS})\n")

    print(f"// Input Vector (sim_z_quant_tv) - {LATENT_DIM} elements, {ACT_TOTAL_BITS}-bit signed decimal:")
    for i, input_q in enumerate(sim_z_quant_tv):
        print(f"input_vector_regs[{i}] = {verilog_signed_decimal(input_q, ACT_TOTAL_BITS)};")

    print(f"\n// Weight Vector for neuron {neuron_index_for_tv} (fc_weights_q_for_tv[{neuron_index_for_tv}, :]) - {LATENT_DIM} elements, {WEIGHT_TOTAL_BITS}-bit signed decimal:")
    for i, weight_q in enumerate(fc_weights_q_for_tv[neuron_index_for_tv]):
        print(f"weight_vector_regs[{i}] = {verilog_signed_decimal(weight_q, WEIGHT_TOTAL_BITS)};")

    print(f"\n// Bias Value for neuron {neuron_index_for_tv} (fc_bias_q_for_tv[{neuron_index_for_tv}]) - {BIAS_TOTAL_BITS}-bit signed decimal:")
    print(f"bias_value_reg = {verilog_signed_decimal(fc_bias_q_for_tv[neuron_index_for_tv], BIAS_TOTAL_BITS)};")

    print(f"\n// Expected Output Neuron Value (scaled integer) - {ACT_TOTAL_BITS}-bit signed decimal:")
    print(f"// expected_output_neuron_scaled = {verilog_signed_decimal(expected_neuron_output_quant_scaled, ACT_TOTAL_BITS)};")

    # Also print dequantized version for human understanding
    expected_neuron_output_dequant = dequantize_value(expected_neuron_output_quant_scaled, ACT_FRAC_BITS)
    print(f"// Expected Output Neuron Value (dequantized float): {expected_neuron_output_dequant:.6f}")



--- Generating Test Vectors for DotProductBiasUnit ---
Successfully loaded and mapped weights (8 unique tensors mapped) from: ./model_weights/trained_simple_decoder_weights.pth


--- Verilog Test Vectors for DotProductBiasUnit (neuron_index = 0) ---
// Fixed-Point Formats used for generation:
// Activation: Q8.7 (Total: 16)
// Weight:     Q2.13 (Total: 16)
// Bias:       Q8.7 (Total: 16)

// Input Vector (sim_z_quant_tv) - 64 elements, 16-bit signed decimal:
input_vector_regs[0] = 16'sd247;
input_vector_regs[1] = 16'sd190;
input_vector_regs[2] = 16'sd115;
input_vector_regs[3] = 16'sd-270;
input_vector_regs[4] = 16'sd87;
input_vector_regs[5] = 16'sd-158;
input_vector_regs[6] = 16'sd-6;
input_vector_regs[7] = 16'sd-205;
input_vector_regs[8] = 16'sd-96;
input_vector_regs[9] = 16'sd211;
input_vector_regs[10] = 16'sd-50;
input_vector_regs[11] = 16'sd-180;
input_vector_regs[12] = 16'sd-93;
input_vector_regs[13] = 16'sd-72;
input_vector_regs[14] = 16'sd-98;
input_vector_regs[15] = 16'sd98;
in