## 1. Setup, Imports & Configuration

These cells load all required libraries and utilities for audio processing, modeling, visualization, and inline playback.

We configure project paths, the device (MPS/CPU), audio/spectrogram parameters, and training hyperparameters, and we import our custom `src/` modules (`ChordSpec`, `AblationGenerator128`, `train_gan`, `spec_to_audio`).
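The device line in the configuration cell only checks for MPS. A minimal sketch of a selection order that also covers NVIDIA GPUs (CUDA first, then MPS, then CPU) is shown below. `pick_device_name` is a hypothetical helper that takes the availability flags as arguments, so it can be tested without any GPU present.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return a torch device string, preferring CUDA, then MPS, then CPU.

    Hypothetical helper: in the notebook, pass torch.cuda.is_available()
    and torch.backends.mps.is_available() as the two flags.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Usage in the notebook (requires torch):
# DEVICE = torch.device(pick_device_name(torch.cuda.is_available(),
#                                        torch.backends.mps.is_available()))
print(pick_device_name(False, True))  # → mps
```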

### Note: Ensure you run these 2 cells before proceeding with the rest of the notebook.


In [None]:
import os
import sys
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
import librosa
import matplotlib.pyplot as plt
from scipy.signal import fftconvolve
from matplotlib.animation import FuncAnimation

from IPython.display import Audio, display, HTML

In [None]:
# configuration of paths
NOTEBOOK_DIR = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(NOTEBOOK_DIR, os.pardir))

if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# device selection: MPS (Apple Silicon) if available, otherwise CPU.
# All development was done on an Apple Silicon Mac; if you have an NVIDIA GPU, adjust this line to use CUDA.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# setting up paths
DATA_ROOT   = os.path.join(PROJECT_ROOT, "data", "guitar", "Training")
OUTPUT_DIR  = os.path.join(PROJECT_ROOT, "checkpoints", "guitar")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Training data root: {DATA_ROOT}")
print(f"Checkpoints: {OUTPUT_DIR}")

# Note: ChatGPT was used to research and set the following audio parameters

# audio & spectrogram parameters
# these parameters are all set to conventional values
SR = 22050 # sampling rate
N_FFT = 1024 # FFT window size
HOP_LENGTH = 256 # hop (stride) for STFT
N_MELS = 128 # number of mel bands
FMIN = 50.0 # min freq for mel scale
FMAX = 8000.0 # max freq for mel scale
CLIP_DUR = 3.0 # clip duration in seconds

# training hyperparameters
BATCH_SIZE = 16
LATENT_DIM = 100
LR = 2e-4
BETA1, BETA2 = 0.5, 0.999
EPOCHS = 500
CHECKPOINT_EPOCH = 100

# project modules
from src.data.dataset import ChordSpec
from src.models import AblationGenerator128
from src.train import train_gan
from src.utils import spec_to_audio
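
As a quick sanity check on the audio parameters above, the arithmetic below computes the linear-frequency bin spacing, the hop duration, and how many STFT frames one `CLIP_DUR`-second clip yields. The constants are repeated so the snippet runs on its own. The frame counts assume librosa-style framing (`center=True` pads `n_fft // 2` samples on each side; `center=False` does not); `ChordSpec` may crop or pad its output to a fixed width, so treat these as pre-processing values only.

```python
SR = 22050
N_FFT = 1024
HOP_LENGTH = 256
CLIP_DUR = 3.0

n_samples = int(SR * CLIP_DUR)                              # samples in one clip
freq_bins = 1 + N_FFT // 2                                   # rFFT bins per frame, before mel pooling
bin_hz = SR / N_FFT                                          # spacing between linear-frequency bins
hop_ms = 1000 * HOP_LENGTH / SR                              # time step between frames
frames_centered = 1 + n_samples // HOP_LENGTH                # center=True framing
frames_uncentered = 1 + (n_samples - N_FFT) // HOP_LENGTH    # center=False framing

print(n_samples, freq_bins, round(bin_hz, 2), round(hop_ms, 2), frames_centered, frames_uncentered)
# → 66150 513 21.53 11.61 259 255
```

The 128 mel bands then pool these 513 linear bins between `FMIN` and `FMAX`, so each mel band spans more linear bins at higher frequencies.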

## 2. Dataset Inspection & Spectrogram Reconstruction

This cell:

1. Loads the `ChordSpec` dataset and picks a random example.  
2. Visualizes its 128-band mel-spectrogram.  
3. Inverts the spectrogram back to audio using our `spec_to_audio` Griffin–Lim routine.  
4. Plays the original and reconstructed waveforms side by side for comparison.  

Run this cell to verify the data loader and the inversion pipeline before proceeding.  
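`spec_to_audio` lives in `src/utils.py` and is not reproduced here. For readers new to Griffin–Lim, the NumPy-only sketch below shows the core idea on a *linear* magnitude spectrogram: start from random phases, then repeatedly invert to a waveform and re-estimate the phase from that waveform's STFT while keeping the target magnitudes fixed. A mel spectrogram like the one above would first need mapping back to linear frequency (librosa offers `librosa.feature.inverse.mel_to_stft` for that step). The helpers `stft`, `istft`, and `griffin_lim` are illustrative, not the notebook's API.

```python
import numpy as np


def stft(y, n_fft=1024, hop=256):
    """Hann-windowed STFT with centered frames; returns shape (1 + n_fft // 2, n_frames)."""
    win = np.hanning(n_fft)
    y = np.pad(y, n_fft // 2)  # zero-pad so frames are centered on sample times
    n_frames = 1 + (len(y) - n_fft) // hop
    frames = np.stack([y[i * hop:i * hop + n_fft] * win for i in range(n_frames)], axis=1)
    return np.fft.rfft(frames, axis=0)


def istft(S, n_fft=1024, hop=256, length=None):
    """Least-squares inverse STFT via windowed overlap-add."""
    win = np.hanning(n_fft)
    frames = np.fft.irfft(S, n=n_fft, axis=0) * win[:, None]
    n_frames = S.shape[1]
    out_len = n_fft + hop * (n_frames - 1)
    y = np.zeros(out_len)
    wsum = np.zeros(out_len)
    for i in range(n_frames):
        y[i * hop:i * hop + n_fft] += frames[:, i]
        wsum[i * hop:i * hop + n_fft] += win ** 2
    y = y / np.maximum(wsum, 1e-8)
    y = y[n_fft // 2:]  # drop the centering pad
    if length is None:
        length = hop * (n_frames - 1)
    return y[:length]


def griffin_lim(mag, n_iter=32, n_fft=1024, hop=256, length=None, seed=0):
    """Estimate a waveform whose STFT magnitude approximates `mag`."""
    rng = np.random.default_rng(seed)
    angles = np.exp(2j * np.pi * rng.random(mag.shape))  # random initial phase
    for _ in range(n_iter):
        y = istft(mag * angles, n_fft, hop, length)
        # keep only the phase of the re-analysed signal; magnitudes stay fixed to `mag`
        angles = np.exp(1j * np.angle(stft(y, n_fft, hop)))
    return istft(mag * angles, n_fft, hop, length)
```

Passing the original signal length keeps the frame count of each re-analysis equal to `mag.shape[1]`. librosa's `librosa.griffinlim` implements the same loop with an optional momentum term.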


In [None]:
# instantiate dataset & pick one random sample wav
dataset = ChordSpec(
    root=DATA_ROOT,
    sr=SR,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    fmin=FMIN,
    fmax=FMAX,
    duration=CLIP_DUR
)
idx = np.random.randint(len(dataset))

# get & plot its mel-spectrogram
spec = dataset[idx].squeeze(0).numpy()
plt.figure(figsize=(6,4))
plt.imshow(spec, origin='lower', aspect='auto')
plt.title('Random Mel-spectrogram sample')
plt.colorbar(label='Normalized dB')
plt.tight_layout()
plt.show()

# load original waveform
file_path = dataset.files[idx]
y_orig, _ = librosa.load(file_path, sr=SR, mono=True, duration=CLIP_DUR)

# reconstruct via the Griffin–Lim utility (see src/utils.py for its implementation)
y_recon = spec_to_audio(
    spec,
    sr=SR,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH
)

# display and play both
print("Original:")
display(Audio(y_orig, rate=SR))
print("Reconstructed:")
display(Audio(y_recon, rate=SR))


## 3. Optional: Train GAN

To skip training, you can load any of the pretrained checkpoints in `./checkpoints/guitar/` (saved at epochs 100, 200, 300, and 400).

- **Skip training:** Move on to the next section to load a pretrained generator.
- **To train from scratch:** Uncomment the `train_gan(...)` call in the last line of the cell below and run the cell.
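If you skip training, a small helper can list which generator checkpoints are present and pick the most recent one. This is a hypothetical sketch that assumes every generator checkpoint follows the `G_ep<epoch>.pth` naming of `G_ep400.pth`; any other files in the folder are ignored.

```python
import os
import re

# Assumption: generator checkpoints are named G_ep<epoch>.pth, e.g. G_ep400.pth
_CKPT_RE = re.compile(r"^G_ep(\d+)\.pth$")


def list_generator_checkpoints(ckpt_dir):
    """Return {epoch: path} for every G_ep<epoch>.pth file, sorted by epoch."""
    found = {}
    for name in os.listdir(ckpt_dir):
        match = _CKPT_RE.match(name)
        if match:
            found[int(match.group(1))] = os.path.join(ckpt_dir, name)
    return dict(sorted(found.items()))


def latest_generator_checkpoint(ckpt_dir):
    """Path of the highest-epoch generator checkpoint, or None if there is none."""
    ckpts = list_generator_checkpoints(ckpt_dir)
    return ckpts[max(ckpts)] if ckpts else None
```

For example, `latest_generator_checkpoint(OUTPUT_DIR)` could stand in for the hard-coded `G_ep400.pth` path used in the inference section.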

In [None]:
# dataloader setup
loader = DataLoader(
    ChordSpec(
        root=DATA_ROOT,
        sr=SR,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        fmin=FMIN,
        fmax=FMAX,
        duration=CLIP_DUR
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=(DEVICE.type == 'cuda')  # pinned memory only benefits host-to-CUDA transfers
)

# training config
cfg = {
    'latent_dim': LATENT_DIM,
    'feature_maps': 64,
    'lr': LR,
    'beta1': BETA1,
    'beta2': BETA2,
    'epochs': EPOCHS,
    'n_mels': N_MELS,
    'checkpoint_epoch': CHECKPOINT_EPOCH
}

# run training loop. uncomment the line below if you wish to train the GAN
#G = train_gan(loader, DEVICE, OUTPUT_DIR, cfg)

## 4. Inference: Generate & Listen to Samples

This cell loads the pretrained generator from `G_ep400.pth` and produces five random spectrograms along with their audio reconstructions.

1. Load checkpoint: Instantiates `AblationGenerator128` and loads weights from epoch 400.  
2. Sample latent vectors: Draws 5 random `z` vectors and runs them through the generator.  
3. Visualize: Plots each generated mel-spectrogram.  
4. Listen: Uses `spec_to_audio` to invert each spectrogram and plays back the resulting waveform.  

Feel free to rerun this cell to hear different random samples.
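Rerunning produces new samples because `torch.randn` draws fresh noise each time. To revisit a particular set of samples, one option is to draw the latent vectors from a seeded NumPy generator on the CPU and move them to the device afterwards, which keeps the latent vectors identical whether `DEVICE` is MPS, CUDA, or CPU. `sample_latents` is a hypothetical helper, and the seed values are arbitrary.

```python
import numpy as np


def sample_latents(n, latent_dim, seed=None):
    """Draw n standard-normal latent vectors shaped (n, latent_dim, 1, 1).

    Hypothetical helper: the (N, LATENT_DIM, 1, 1) layout matches the
    torch.randn(5, LATENT_DIM, 1, 1, ...) call in this section's cell.
    """
    rng = np.random.default_rng(seed)
    return rng.standard_normal((n, latent_dim, 1, 1)).astype(np.float32)


# In the notebook (requires torch):
# zs = torch.from_numpy(sample_latents(5, LATENT_DIM, seed=42)).to(DEVICE)
```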


In [None]:
# load the pretrained generator
G = AblationGenerator128(latent_dim=LATENT_DIM, feature_maps=64, mask=None).to(DEVICE)
ckpt_path = os.path.join(OUTPUT_DIR, "G_ep400.pth")
G.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
G.eval()

# sample 5 random spectrograms
zs = torch.randn(5, LATENT_DIM, 1, 1, device=DEVICE)
with torch.no_grad():
    specs, _ = G(zs)
specs_np = specs.squeeze(1).cpu().numpy()

# plot the generated spectrograms
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, ax in enumerate(axes):
    ax.imshow(specs_np[i], origin='lower', aspect='auto')
    ax.set_title(f"Sample {i+1}")
    ax.axis('off')
plt.tight_layout()
plt.show()

# reconstruct audio from generated spectrograms and play them
for i, spec in enumerate(specs_np, start=1):
    y = spec_to_audio(
        spec,
        sr=SR,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        win_length=N_FFT
    )
    print(f"Play generated sample {i}")
    display(Audio(y, rate=SR))

## 5. Ablation Soundscape: Progressive Neuron Death

This cell synthesizes the full “memory‐loss” soundscape by:

1. **Loading the pretrained generator**: the epoch-400 weights (`G_ep400.pth`) are loaded into `AblationGenerator128`.
2. **Defining ablation parameters**: zeroing out 5 additional random channels per segment.  
3. **Building cumulative masks**: we shuffle all `(layer, channel)` pairs, chunk them into groups of 5, and accumulate the masks over segments (stored in `ablation_info`).  
4. **Generating segments**: for each mask, we draw a fresh random latent vector `z`, assign the mask to `Gm.mask`, generate its mel-spectrogram, and invert it via `spec_to_audio`.  
5. **Segment stitching**: consecutive segments overlap by 30%, and the overlapping samples are summed to create one continuous waveform.
6. **Optional reverb**: an exponential-decay impulse response adds some depth/reverb if desired.  
7. **Playback**: listen to the dry and reverberated soundscapes side by side.
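The mask-building step can be sketched in plain Python, independently of the generator. The hypothetical `build_cumulative_masks` below takes the number of output channels per deconvolution layer (keyed by layer index, as in `Gm.net`), shuffles every `(layer, channel)` pair once, and grows the set of killed channels by `kills_per_segment` at each step. It yields `{layer: [channels]}` dictionaries of the same shape as the notebook's `masks`, but updates a running per-layer set instead of rescanning the whole `used` set for every layer.

```python
import random


def build_cumulative_masks(layer_channels, kills_per_segment=5, seed=None):
    """Return (kill_order, masks) for progressive channel ablation.

    layer_channels: {layer_index: out_channels} for each deconv layer.
    masks[k] maps layer_index -> sorted channels zeroed in chunks 0..k.
    """
    order = [(l, c) for l, n in layer_channels.items() for c in range(n)]
    random.Random(seed).shuffle(order)

    killed = {}  # layer_index -> set of dead channels so far
    masks = []
    for start in range(0, len(order), kills_per_segment):
        for l, c in order[start:start + kills_per_segment]:
            killed.setdefault(l, set()).add(c)
        # snapshot the current state so later kills don't mutate earlier masks
        masks.append({l: sorted(chans) for l, chans in killed.items()})
    return order, masks
```

With `layer_channels = {i: layer.out_channels for i, layer in deconv_layers}`, the two return values play the roles of `neurons` and `masks` in the cell below.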

Run this cell to generate the soundscape and experience the gradual auditory degradation as neurons are “killed” in the generator.
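The stitching step sums the overlapping samples as-is, so amplitudes add up in each overlap before the final normalization. A common alternative is a linear crossfade, where the outgoing segment fades out while the incoming one fades in. The sketch below is a hypothetical replacement for the stitching loop (`crossfade_concat` is not part of `src/`). It also makes the output length explicit: `n` segments of `L` samples with an overlap of `k` samples give `L + (n - 1)(L - k)` samples.

```python
import numpy as np


def crossfade_concat(segments, overlap):
    """Join equal-length 1-D segments, linearly crossfading `overlap` samples."""
    if overlap <= 0:
        # also avoids a slicing pitfall: out[-0:] is the whole array, not an empty one
        return np.concatenate(segments)
    fade_in = np.linspace(0.0, 1.0, overlap)
    fade_out = 1.0 - fade_in
    out = np.asarray(segments[0], dtype=np.float64).copy()
    for seg in segments[1:]:
        seg = np.asarray(seg, dtype=np.float64)
        blended = out[-overlap:] * fade_out + seg[:overlap] * fade_in
        out = np.concatenate([out[:-overlap], blended, seg[overlap:]])
    return out
```

In the notebook, `out = crossfade_concat(segments, overlap)` would replace the loop. Linear fades keep the summed gain at 1; for uncorrelated material, equal-power (sine/cosine) fades are another option.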


In [None]:
import os
import random
import torch
import numpy as np
from scipy.signal import fftconvolve
from IPython.display import Audio, display

# LATENT_DIM, DEVICE, PROJECT_ROOT, SR, N_FFT, HOP_LENGTH, AblationGenerator128
# and spec_to_audio are defined in the setup cells of Section 1

# Note: ChatGPT was used to debug and sanity check the ablation mask code

# load pretrained generator for ablation study
Gm = AblationGenerator128(latent_dim=LATENT_DIM, feature_maps=64, mask=None).to(DEVICE)
ckpt = torch.load(
    os.path.join(PROJECT_ROOT, "checkpoints", "guitar", "G_ep400.pth"),
    map_location=DEVICE
)
Gm.load_state_dict(ckpt)
Gm.eval()

# ablation parameters
kills_per_segment = 5

# build deconv neuron list & random kill order
deconv_layers = [
    (i, layer) for i, layer in enumerate(Gm.net)
    if isinstance(layer, torch.nn.ConvTranspose2d)
]
neurons = [
    (li, c)
    for li, layer in deconv_layers
    for c in range(layer.out_channels)
]
random.shuffle(neurons)

# build cumulative masks for ablation
chunks = [
    neurons[i : i + kills_per_segment]
    for i in range(0, len(neurons), kills_per_segment)
]
used, masks = set(), []
for chunk in chunks:
    used.update(chunk)
    masks.append({
        l: [c for (l0, c) in used if l0 == l]
        for (l, _) in deconv_layers
        if any(l0 == l for (l0, _) in used)
    })
ablation_info = {"order": neurons, "masks": masks}

# generate one random z per segment, apply mask, invert to audio
segments = []
for mask in masks:
    z = torch.randn(1, LATENT_DIM, 1, 1, device=DEVICE)
    Gm.mask = mask
    with torch.no_grad():
        spec_t, _ = Gm(z)
    spec = spec_t.cpu().squeeze().numpy()  # squeeze() drops all singleton dims
    y = spec_to_audio(spec, sr=SR, n_fft=N_FFT, hop_length=HOP_LENGTH)
    segments.append(y)

# stitch together with 30% overlap between consecutive segments
overlap = int(len(segments[0]) * 0.30) 
out = segments[0].copy()
for seg in segments[1:]:
    # sum overlap region
    overlapped = out[-overlap:] + seg[:overlap]
    out = np.concatenate([
        out[:-overlap],
        overlapped,
        seg[overlap:]
    ])

# normalize to prevent clipping
out /= np.max(np.abs(out))

print(f"Total duration: {len(out)/SR:.2f}s, Segments: {len(segments)}")

# apply slight reverb to the output
reverb_secs = 0.2
decay_rate  = 4.0
ir_len      = int(reverb_secs * SR)
t_ir        = np.linspace(0, reverb_secs, ir_len)
ir = np.exp(-decay_rate * t_ir)  # ir[0] == exp(0) == 1.0 already

wet = fftconvolve(out, ir, mode="full")[:len(out)]
wet /= np.max(np.abs(wet))

# playback
print("Dry (no reverb)")
display(Audio(out, rate=SR))

print("Wet (with reverb)")
display(Audio(wet, rate=SR))
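
The impulse response in this cell is a smooth exponential with no noise, so convolving with it mostly averages neighboring samples over 0.2 s (a low-pass smoothing effect) rather than adding reflections. A common way to synthesize a simple reverb instead is to shape white noise with an exponential decay and mix the result with the dry signal. The sketch below is one such variant: `decaying_noise_ir`, `fft_convolve`, and `add_reverb` are hypothetical helpers, and the `decay_time` and `wet` values are illustrative, not tuned for this soundscape. It uses a small NumPy FFT convolution so it runs without SciPy; in the notebook, `scipy.signal.fftconvolve`, already used in the cell above, does the same job.

```python
import numpy as np


def decaying_noise_ir(sr, seconds=0.5, decay_time=0.3, seed=0):
    """White noise under an exponential envelope that reaches about -60 dB at decay_time."""
    rng = np.random.default_rng(seed)
    t = np.arange(int(seconds * sr)) / sr
    envelope = np.exp(-6.9 * t / decay_time)  # exp(-6.9) ~= 0.001, i.e. -60 dB
    ir = rng.standard_normal(len(t)) * envelope
    return ir / np.sqrt(np.sum(ir ** 2))  # normalize to unit energy


def fft_convolve(x, h):
    """Full linear convolution via zero-padded real FFTs (same as mode='full')."""
    n = len(x) + len(h) - 1
    nfft = 1 << (n - 1).bit_length()  # next power of two >= n
    return np.fft.irfft(np.fft.rfft(x, nfft) * np.fft.rfft(h, nfft), nfft)[:n]


def add_reverb(dry, ir, wet=0.25):
    """Blend the dry signal with its reverberated copy, then peak-normalize."""
    reverberant = fft_convolve(dry, ir)[:len(dry)]
    mixed = (1.0 - wet) * dry + wet * reverberant
    return mixed / (np.max(np.abs(mixed)) + 1e-12)
```

In the notebook this would read `wet = add_reverb(out, decaying_noise_ir(SR))`; the dry/wet blend keeps the original attack audible, which a fully wet signal tends to blur.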


## 6. Neuron Ablation Animation

This cell visualizes the sequence of neuron “deaths” in the generator as a synchronized animation:

1. **Lay out the neuron grid**: each deconvolution channel is drawn as a small white circle positioned by layer (y-axis) and channel index (x-axis).  
2. **Compile kill indices**: frame 0 shows all neurons alive; each subsequent frame colors every neuron ablated so far (cumulatively) in red.  
3. **Draw a Matplotlib scatter plot**: all neurons are plotted once with `ax.scatter`, and only their colors change from frame to frame.  
4. **Animate with `FuncAnimation`** at the same tempo as the soundscape segments (one frame per segment), so the visual and audio degradations progress at the same rate.  

Run this cell _after_ generating the soundscape to watch the network degrade alongside the audio playback. You can see exactly when and where neurons are switched off in the network’s latent-to-spectrogram pipeline.  
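The grid layout and frame bookkeeping can be sketched without Matplotlib or the generator. The hypothetical `neuron_layout` mirrors the placement rule of the cell below (one row per deconvolution layer from top to bottom, channels spaced evenly across each row, smaller markers for wider layers), and `kill_indices` turns cumulative masks into per-frame lists of scatter-point indices with a dictionary lookup. The frame interval follows from the segment length and overlap: assuming each reconstructed segment is exactly `CLIP_DUR` = 3 s at 22,050 Hz (66,150 samples) with a 30% overlap of 19,845 samples, each new segment adds (66150 - 19845) / 22050 = 2.1 s of playback.

```python
def neuron_layout(layer_sizes):
    """Return (xs, ys, marker_sizes) for a layer-by-row neuron grid in the unit square."""
    n_layers = len(layer_sizes)
    xs, ys, ss = [], [], []
    for depth, size in enumerate(layer_sizes):
        for c in range(size):
            xs.append((c + 0.5) / size)
            ys.append(1 - (depth + 0.5) / n_layers)  # first layer at the top
            ss.append(max(0.5, 500 / size))          # shrink markers in wide layers
    return xs, ys, ss


def kill_indices(neuron_list, masks):
    """Frame 0 has no dead neurons; frame k lists every neuron dead in masks[k - 1]."""
    index_of = {n: i for i, n in enumerate(neuron_list)}
    frames = [[]]
    for mask in masks:
        frames.append([index_of[(l, c)] for l, chans in mask.items() for c in chans])
    return frames


def frame_interval_ms(segment_len, overlap, sr):
    """Playback time in ms contributed by each additional overlapped segment."""
    return (segment_len - overlap) * 1000 // sr
```

Integer arithmetic in `frame_interval_ms` avoids float rounding in the interval passed to `FuncAnimation`.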


In [None]:
# build neuron grid: one row per deconv layer, channels spread evenly along x
n_layers = len(deconv_layers)
neuron_list, xs, ys, ss = [], [], [], []
for depth, (li, layer) in enumerate(deconv_layers):
    size = layer.out_channels
    for c in range(size):
        neuron_list.append((li, c))
        xs.append((c + 0.5) / size)
        ys.append(1 - (depth + 0.5) / n_layers)
        ss.append(max(0.5, 500 / size))  # smaller markers for wider layers

# build kill indices per frame (frame 0: all neurons alive)
index_of = {n: i for i, n in enumerate(neuron_list)}  # dict lookup instead of slow list.index
kill_by_frame = [[]]
for mask in masks:
    killed = [(l, c) for l, chans in mask.items() for c in chans]
    kill_by_frame.append([index_of[n] for n in killed])
total_frames = len(kill_by_frame)

# plot the setup
fig, ax = plt.subplots(figsize=(6,6))
fig.patch.set_facecolor('black'); ax.set_facecolor('black')
ax.axis('off'); ax.set_xlim(0,1); ax.set_ylim(0,1)
scat = ax.scatter(xs, ys, c='white', s=ss, edgecolors='none')

# update frame function
def update(frame):
    dead = set(kill_by_frame[frame])
    colors = ['red' if i in dead else 'white' for i in range(len(neuron_list))]
    scat.set_color(colors)
    return scat,

# one frame per segment; each frame lasts as long as a segment's non-overlapping part
interval_ms = int((len(segments[0]) - overlap) / SR * 1000)
anim = FuncAnimation(fig, update, frames=total_frames, interval=interval_ms, blit=True)

# display the animation
HTML(anim.to_jshtml())

In [None]:
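# save the animation as MP4 (requires ffmpeg on PATH; the video has no audio track)
out_dir = os.path.join(PROJECT_ROOT, 'animations')
os.makedirs(out_dir, exist_ok=True)
fps = 1000 / interval_ms

anim.save(os.path.join(out_dir, 'ablation_animation.mp4'), writer='ffmpeg', fps=fps)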
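
`anim.save` writes video frames only, so the MP4 carries no soundtrack. One way to pair it with the soundscape is to write the audio to a 16-bit WAV file and let ffmpeg combine the two streams. The sketch below uses Python's standard-library `wave` module for the WAV file and only *builds* the ffmpeg command; the helper names and output file names are hypothetical, and running the command requires ffmpeg on your PATH.

```python
import wave

import numpy as np


def write_wav_int16(path, audio, sr):
    """Write a mono float signal in [-1, 1] as a 16-bit PCM WAV file."""
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")  # little-endian int16
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # bytes per sample
        wf.setframerate(sr)
        wf.writeframes(pcm.tobytes())


def ffmpeg_mux_command(video_path, audio_path, out_path):
    """Build an ffmpeg command that copies the video stream and adds AAC audio."""
    return [
        "ffmpeg", "-y",
        "-i", video_path,
        "-i", audio_path,
        "-map", "0:v:0", "-map", "1:a:0",
        "-c:v", "copy", "-c:a", "aac",
        "-shortest",  # stop at the end of the shorter stream
        out_path,
    ]


# In the notebook (with `import subprocess` and ffmpeg installed):
# wav_path = os.path.join(out_dir, 'soundscape_wet.wav')
# write_wav_int16(wav_path, wet, SR)
# cmd = ffmpeg_mux_command(os.path.join(out_dir, 'ablation_animation.mp4'), wav_path,
#                          os.path.join(out_dir, 'ablation_with_audio.mp4'))
# subprocess.run(cmd, check=True)
```

Because the animation and the stitched audio are timed independently, their durations can differ slightly; `-shortest` trims the muxed file to the shorter of the two.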