In [1]:
# Cell 1: Imports for Spectrogram Generation
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

# Imports needed for the custom layer and generator architecture
from tensorflow.keras.layers import (Input, Dense, Reshape, Conv2DTranspose,
                                     LayerNormalization, LeakyReLU, Conv2D, Layer)
from tensorflow.keras.models import Sequential

2025-04-08 00:11:12.249432: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-08 00:11:12.261471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744051272.274885  302968 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744051272.278728  302968 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744051272.289575  302968 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Cell 2: Custom Self-Attention Layer Definition

class SelfAttention(Layer):
    """
    Self-attention layer based on SAGAN.

    Every spatial location attends over all locations of the same feature
    map. The attended features are scaled by a learned scalar ``gamma`` and
    added back onto the input. ``gamma`` is zero-initialised, so a freshly
    built layer returns its input unchanged.

    Input shape: (batch, height, width, channels)
    Output shape: (batch, height, width, channels_out) where channels_out is typically channels

    Args:
        channels_out: Number of channels produced by the value and output
            projections. ``None`` (the default) means "same as the input".
            Because the result is added to the input, it has to be
            broadcast-compatible with the input's channel count.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """
    def __init__(self, channels_out=None, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.channels_out = channels_out

    def build(self, input_shape):
        """Create the 1x1 projections and ``gamma`` once the channel count is known.

        Raises:
            ValueError: If the last (channel) dimension of ``input_shape`` is unknown.
        """
        self.input_channels = input_shape[-1]
        if self.input_channels is None:
            raise ValueError("SelfAttention requires a statically known channel dimension.")
        if self.channels_out is None:
            self.channels_out = self.input_channels

        # Query/key projections use an eighth of the input channels. Clamp to 1
        # so inputs with fewer than 8 channels do not request a 0-filter Conv2D.
        # For inputs with >= 8 channels this equals `input_channels // 8`.
        self.qk_channels = max(1, self.input_channels // 8)

        # Convolution layers for query, key, value.
        # NOTE(review): attribute names and creation order are deliberately left
        # untouched - they presumably determine how saved weights are matched on
        # load_weights(); confirm before renaming anything here.
        self.f = Conv2D(self.qk_channels, kernel_size=1, strides=1, padding='same', name='conv_f') # Query
        self.g = Conv2D(self.qk_channels, kernel_size=1, strides=1, padding='same', name='conv_g') # Key
        self.h = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_h') # Value
        self.out_conv = Conv2D(self.channels_out, kernel_size=1, strides=1, padding='same', name='conv_out')
        # Learned residual scale; starts at 0 so attention is blended in gradually.
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)
        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        """Apply self-attention to ``x`` and return ``gamma * attention(x) + x``."""
        # Dynamic shape: batch/height/width may be unknown when the graph is traced.
        dynamic_shape = tf.shape(x)
        batch_size, height, width = dynamic_shape[0], dynamic_shape[1], dynamic_shape[2]
        location_num = height * width

        f_proj = self.f(x)
        g_proj = self.g(x)
        h_proj = self.h(x)

        # Flatten the spatial grid so each location becomes one row.
        f_flat = tf.reshape(f_proj, shape=(batch_size, location_num, self.qk_channels))
        g_flat = tf.reshape(g_proj, shape=(batch_size, location_num, self.qk_channels))
        h_flat = tf.reshape(h_proj, shape=(batch_size, location_num, self.channels_out))

        # (batch, N, N) scores: row i holds query i against every key.
        g_flat_t = tf.transpose(g_flat, perm=[0, 2, 1])
        attention_score = tf.matmul(f_flat, g_flat_t)
        # Normalise over the key axis so each row of weights sums to 1.
        attention_prob = tf.nn.softmax(attention_score, axis=-1)

        attention_output = tf.matmul(attention_prob, h_flat)
        attention_output_reshaped = tf.reshape(attention_output, shape=(batch_size, height, width, self.channels_out))
        o = self.out_conv(attention_output_reshaped)
        y = self.gamma * o + x # Additive skip connection
        return y

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last dimension replaced by ``channels_out``."""
        # Fall back to the input's channel count if build() has not resolved it yet.
        channels = self.channels_out if self.channels_out is not None else input_shape[-1]
        # tuple(...) so list-typed shapes concatenate cleanly as well.
        return tuple(input_shape[:-1]) + (channels,)

    # --- IMPORTANT: Add get_config if you were saving/loading full model ---
    # Although not strictly needed for load_weights, it's best practice
    def get_config(self):
        """Return the base layer config extended with ``channels_out``."""
        config = super(SelfAttention, self).get_config()
        config.update({"channels_out": self.channels_out})
        return config

In [3]:
# Cell 3: Generator Model Definition

def create_generator(latent_dim, output_shape): # output_shape (n_mels, n_frames) e.g., (80, 126)
    """Creates the generator model with Self-Attention.

    The network projects a latent vector onto a small ``(init_h, init_w, 128)``
    feature map and upsamples it by 8x through three stride-2 transposed
    convolutions, with a ``SelfAttention`` layer before the last one. When
    the requested size is not a multiple of 8, the initial map is rounded up
    and a final 'valid' ``Conv2D`` trims the surplus rows/columns.

    Args:
        latent_dim: Length of the input noise vector.
        output_shape: ``(n_mels, n_frames)`` of the spectrogram to generate.

    Returns:
        A ``Sequential`` model named ``'generator'`` whose output is
        ``(batch, n_mels, n_frames, 1)``.

    Raises:
        ValueError: If either entry of ``output_shape`` is smaller than 1.
    """
    n_mels, n_frames = output_shape
    if n_mels < 1 or n_frames < 1:
        raise ValueError(f"output_shape entries must be >= 1, got {output_shape}")

    upsample_factor = 8  # three stride-2 transposed convolutions: 2 * 2 * 2
    # Ceil-divide BOTH axes so the upsampled map is never smaller than the
    # target; any surplus is trimmed by the adjustment layer at the end.
    init_h = -(-n_mels // upsample_factor)
    init_w = -(-n_frames // upsample_factor)
    init_c = 128

    current_height = init_h * upsample_factor
    current_width = init_w * upsample_factor
    needs_adjustment = current_height != n_mels or current_width != n_frames
    if needs_adjustment:
        print(f"Note: Output shape {output_shape} might require final adjustment layer.")

    nodes = init_h * init_w * init_c

    model = Sequential(name='generator')
    model.add(Input(shape=(latent_dim,)))
    model.add(Dense(nodes))
    model.add(LeakyReLU(negative_slope=0.2))
    model.add(Reshape((init_h, init_w, init_c)))

    # Upsample 1: -> (20, 32, 64) approx
    model.add(Conv2DTranspose(init_c // 2, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(LayerNormalization())
    model.add(LeakyReLU(negative_slope=0.2))

    # Upsample 2: -> (40, 64, 32) approx
    model.add(Conv2DTranspose(init_c // 4, kernel_size=(4, 4), strides=(2, 2), padding='same', use_bias=False))
    model.add(LayerNormalization())
    model.add(LeakyReLU(negative_slope=0.2))

    # Add Self-Attention Layer (ensure placement matches trained model)
    model.add(SelfAttention(channels_out=init_c // 4))

    # Upsample 3: -> (80, 128, 1) approx
    model.add(Conv2DTranspose(1, kernel_size=(4, 4), strides=(2, 2), padding='same', activation='tanh'))

    # Final adjustment layer if needed (e.g., width 128 -> 126).
    # A 'valid' convolution with kernel k shrinks an axis by k - 1, so a kernel
    # of (surplus + 1) lands exactly on the target size. Note this layer applies
    # tanh a second time; kept as-is so existing checkpoints stay loadable.
    if needs_adjustment:
        if current_width != n_frames:
            print(f"Generator adding final Conv2D to adjust width from {current_width} to {n_frames}")
        if current_height != n_mels:
            print(f"Generator adding final Conv2D to adjust height from {current_height} to {n_mels}")
        kernel_h = current_height - n_mels + 1
        kernel_w = current_width - n_frames + 1
        model.add(Conv2D(1, kernel_size=(kernel_h, kernel_w), padding='valid', activation='tanh'))

    return model

In [4]:
# Cell 4: Configuration for Generation (Multi-Epoch)

# --- Essential Parameters (Match the training setup) ---
latent_dim = 100      # length of the noise vector fed to the generator
N_MELS = 80           # spectrogram height (mel bins)
TARGET_FRAMES = 126   # spectrogram width (time frames)
mel_spectrogram_shape = (N_MELS, TARGET_FRAMES)

# --- Generation Parameters ---
num_images_per_epoch = 8  # How many examples to generate PER epoch for comparison
epochs_to_check = [75, 80, 85, 100, 125]  # <--- LIST of epochs weights you have saved and want to compare
output_dir_base = "generator_comparison"  # Base directory to save comparison images

# --- Prepare the output directory ---
# `os` comes from the imports cell at the top of the notebook; it is not
# re-imported here so that all dependencies stay declared in one place.
os.makedirs(output_dir_base, exist_ok=True)
print(f"Base output directory: {output_dir_base}")
print(f"Epochs to check: {epochs_to_check}")
print(f"Images per epoch: {num_images_per_epoch}")

Base output directory: generator_comparison
Epochs to check: [75, 80, 85, 100, 125]
Images per epoch: 8


In [5]:
# Cell 5: Load Generators (Multiple Epochs), Generate, and Save Images
#
# For every epoch listed in `epochs_to_check` this cell rebuilds the generator,
# loads that epoch's weights, and writes `num_images_per_epoch` spectrogram
# PNGs into `<output_dir_base>/epoch_<N>/`. Epochs whose weights are missing
# or fail to load are reported and skipped (best-effort by design).

print("\n--- Loading Generators and Generating Comparison Spectrograms ---")

# --- Ensure prerequisite definitions are available ---
# globals() is the notebook namespace where earlier cells define these names.
if 'SelfAttention' not in globals(): raise NameError("SelfAttention class definition not found.")
if 'create_generator' not in globals(): raise NameError("create_generator function definition not found.")


def _save_spectrogram_png(spec, epoch, sample_number, out_dir):
    """Render one generated spectrogram to a PNG; return True if a file was written.

    A trailing singleton channel axis is squeezed away; for any other array
    with more than two dimensions only the first channel is plotted. Arrays
    that are still not 2-D afterwards are reported and skipped.
    """
    if spec.shape[-1] == 1:
        spec = np.squeeze(spec, axis=-1)
    elif spec.ndim > 2:
        spec = spec[:, :, 0] # Plot first channel

    if spec.ndim != 2:
        print(f"Skipping plot for epoch {epoch} sample {sample_number} due to unexpected dimensions.")
        return False

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        image = ax.imshow(spec, cmap='viridis', origin='lower', aspect='auto')
        fig.colorbar(image, ax=ax, label='Log Amplitude (Normalized)')
        ax.set_title(f'Generated Spectrogram (Epoch {epoch} - Sample {sample_number})')
        ax.set_xlabel('Time Frames')
        ax.set_ylabel('Mel Bins')
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, f'generated_spec_epoch{epoch}_sample{sample_number}.png'))
    finally:
        # Always release the figure, even if rendering or saving raised.
        plt.close(fig)
    return True


loaded_generators = {} # Dictionary to hold loaded generators for each epoch

# --- Loop through epochs to load weights ---
for epoch in epochs_to_check:
    generator_weights_path = f'generator_wgan_sa_epoch_{epoch}.weights.h5'
    print(f"\n--- Processing Epoch {epoch} ---")

    if not os.path.exists(generator_weights_path):
        print(f"Weights file not found: {generator_weights_path}. Skipping this epoch.")
        continue

    print("Recreating generator structure...")
    try:
        gen = create_generator(latent_dim, mel_spectrogram_shape)
        # Build generator via dummy pass so every variable exists before loading.
        print("Building generator...")
        dummy_input = tf.zeros((1, latent_dim), dtype=tf.float32)
        _ = gen(dummy_input, training=False)
        print("Structure built.")
    except Exception as e:
        print(f"Error recreating/building generator for epoch {epoch}: {e}")
        continue # Skip to next epoch

    print(f"Loading weights from: {generator_weights_path}")
    try:
        gen.load_weights(generator_weights_path)
        print("Weights loaded successfully.")
        loaded_generators[epoch] = gen # Store the loaded generator
    except Exception as e:
        print(f"Error loading generator weights for epoch {epoch}: {e}.")
        # Don't add to dict if loading failed


# --- Generate and Save Images ---
print(f"\n--- Generating Images for Loaded Epochs: {list(loaded_generators.keys())} ---")

# One fixed, seeded batch of latent vectors shared by every epoch: sample i is
# then the SAME latent point at each checkpoint, so differences between the
# saved images reflect training progress rather than different random draws,
# and re-running the cell reproduces the same images.
NOISE_SEED = 42
noise_vectors = tf.random.stateless_normal([num_images_per_epoch, latent_dim], seed=[NOISE_SEED, 0])

for epoch, generator in loaded_generators.items():
    print(f"\nGenerating {num_images_per_epoch} images for Epoch {epoch}...")
    epoch_output_dir = os.path.join(output_dir_base, f"epoch_{epoch}")
    os.makedirs(epoch_output_dir, exist_ok=True)

    try:
        # Single batched forward pass for all samples of this epoch.
        generated_spectrograms = generator(noise_vectors, training=False)
        generated_spectrograms_np = generated_spectrograms.numpy()

        # Plot and save each generated spectrogram (sample numbers are 1-based).
        for sample_number, spec_to_plot in enumerate(generated_spectrograms_np, start=1):
            _save_spectrogram_png(spec_to_plot, epoch, sample_number, epoch_output_dir)

    except Exception as e:
        print(f"Error generating images for epoch {epoch}: {e}")

print(f"\nFinished generating comparison images in '{output_dir_base}'.")


--- Loading Generators and Generating Comparison Spectrograms ---

--- Processing Epoch 75 ---
Recreating generator structure...
Note: Output shape (80, 126) might require final adjustment layer.


I0000 00:00:1744051274.555790  302968 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 2143 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Generator adding final Conv2D to adjust width from 128 to 126
Building generator...


I0000 00:00:1744051275.152090  302968 cuda_dnn.cc:529] Loaded cuDNN version 90300


Structure built.
Loading weights from: generator_wgan_sa_epoch_75.weights.h5
Weights loaded successfully.

--- Processing Epoch 80 ---
Recreating generator structure...
Note: Output shape (80, 126) might require final adjustment layer.
Generator adding final Conv2D to adjust width from 128 to 126
Building generator...
Structure built.
Loading weights from: generator_wgan_sa_epoch_80.weights.h5
Weights loaded successfully.

--- Processing Epoch 85 ---
Recreating generator structure...
Note: Output shape (80, 126) might require final adjustment layer.
Generator adding final Conv2D to adjust width from 128 to 126
Building generator...
Structure built.
Loading weights from: generator_wgan_sa_epoch_85.weights.h5
Weights loaded successfully.

--- Processing Epoch 100 ---
Recreating generator structure...
Note: Output shape (80, 126) might require final adjustment layer.
Generator adding final Conv2D to adjust width from 128 to 126
Building generator...
Structure built.
Loading weights from: 