# StyleGAN2-ADA Training — Latent Resonance Spectrograms (Kaggle)

Train StyleGAN2-ADA on 512×512 grayscale spectrogram images.
ADA (Adaptive Discriminator Augmentation) is purpose-built for limited-data regimes,
making it a better fit than StyleGAN3 for small datasets (435–489 images).

**Setup:** In the Kaggle sidebar, go to **Settings → Accelerator → GPU T4 x2**.

**Dataset:** Upload your `spectrograms.zip` as a [Kaggle Dataset](https://www.kaggle.com/datasets),
then add it to this notebook via **Add data** in the sidebar.

In [None]:
import subprocess
import os
import sys
import json
from pathlib import Path

def run_command(cmd, capture=True, check=False):
    """Run a shell command and report its result.

    Args:
        cmd: Command line, executed through the shell (``shell=True``), so it
            may use pipes, globs and ``&&``. Only pass trusted strings.
        capture: When true, capture output and return it as text; when false,
            leave output uncaptured and return a bool.
        check: Passed to ``subprocess.run``; a non-zero exit status then
            raises there and is handled like any other failure below.

    Returns:
        * ``capture=True``: stripped stdout, or stripped stderr when stdout is
          empty (e.g. the command only printed an error). If running the
          command raised, the exception's text is returned instead.
        * ``capture=False``: ``True`` if the command exited with status 0,
          otherwise ``False`` -- including when running it raised, so callers
          that test the result for truthiness never mistake an error for
          success.
    """
    try:
        result = subprocess.run(cmd, shell=True, capture_output=capture, text=True, check=check)
    except Exception as e:  # deliberately best-effort: report, never crash the notebook
        return str(e) if capture else False
    if capture:
        return result.stdout.strip() if result.stdout else result.stderr.strip()
    return result.returncode == 0

print("=" * 60)
print("PYTHON ENVIRONMENT SWITCHER FOR KAGGLE V3")
print("=" * 60)
print()

# Step 1: Document current environment
print("[1] DOCUMENTING CURRENT ENVIRONMENT")
print("-" * 40)
print(f"Current Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print(f"Python prefix: {sys.prefix}")
print()

print("Current conda environments:")
print(run_command("conda env list"))
print()

print("Current Python symlinks:")
print(run_command("ls -la /opt/conda/bin/python*"))
print()

# Save the concrete interpreter binary for recovery: the rollback script
# re-points /opt/conda/bin/python and python3 at it. Derive it from the running
# kernel rather than a hard-coded "python3.7", which only exists on images
# that ship Python 3.7. sys.prefix and sys.version_info are fixed when the
# kernel starts, so this path does not move if the symlinks are switched later.
_kernel_python = os.path.join(
    sys.prefix, "bin", f"python{sys.version_info.major}.{sys.version_info.minor}"
)
original_python = (
    _kernel_python if os.path.exists(_kernel_python) else os.path.realpath(sys.executable)
)
print(f"Original Python executable: {original_python}")
print()

# Save original Jupyter location
original_jupyter = run_command("which jupyter")
print(f"Original Jupyter location: {original_jupyter}")
print()

print("Backing up symlink information...")
backup_content = run_command("ls -la /opt/conda/bin/python*")
with open("/tmp/python_symlinks_backup.txt", "w") as f:
    f.write(backup_content)
    f.write(f"\nOriginal Python: {original_python}")
    f.write(f"\nOriginal Jupyter: {original_jupyter}")
print("Backup saved to /tmp/python_symlinks_backup.txt")
print()

# Step 2: write a throwaway probe script that reports which interpreter ran it.
print("[2] CREATING TEST SCRIPTS")
print("-" * 40)

test_version_code = "\n".join(
    [
        "import sys",
        'print(f"Python {sys.version}")',
        'print(f"Executable: {sys.executable}")',
        'print(f"Path prefix: {sys.prefix}")',
        "",
    ]
)
Path("/tmp/test_version.py").write_text(test_version_code)
print("Test scripts created")
print()

# Step 3: run the probe through the shell, i.e. with whichever `python` is on PATH now.
for banner in ("[3] TESTING ORIGINAL ENVIRONMENT", "-" * 40, "Running test with current Python:"):
    print(banner)
print(run_command("python /tmp/test_version.py"))
print()

# Step 4: Create new conda environment with Python
print("[4] CREATING NEW CONDA ENVIRONMENT WITH PYTHON")
print("-" * 40)

CONDA_OUTPUT_TAIL = 800  # characters of conda output echoed after each command


def _run_conda(cmd):
    """Run a conda command, echo the tail of its stdout+stderr, and return True on exit status 0.

    Without this, a failed solve/install would only show up later as
    "Environment directory not found!" in step 6.
    """
    proc = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    tail = (proc.stdout + proc.stderr).strip()[-CONDA_OUTPUT_TAIL:]
    if tail:
        print(tail)
    if proc.returncode != 0:
        print(f"!! command exited with status {proc.returncode}: {cmd}")
    return proc.returncode == 0


# Create environment with Python 3.7 (or whatever version you prefer)
print("Creating environment 'newCondaEnvironment' with Python 3.7...")
create_env_cmd = "conda create -n newCondaEnvironment python=3.7 -c conda-forge -y"
env_created = _run_conda(create_env_cmd)
print("Environment creation completed" if env_created else "Environment creation FAILED")
print()

# Step 5: Install critical packages in new environment
print("[5] INSTALLING CRITICAL PACKAGES IN NEW ENVIRONMENT")
print("-" * 40)
print("Installing Jupyter and other essential packages to prevent kernel failures...")

essential_packages = [
    "jupyter",
    "jupyter_core",
    "jupyter_client",
    "ipykernel",
    "nbconvert",
    "papermill",
    "numpy",
    "pandas",
]

install_cmd = f"conda install -n newCondaEnvironment {' '.join(essential_packages)} -c conda-forge -y"
print(f"Installing: {', '.join(essential_packages)}")
packages_installed = _run_conda(install_cmd)
print(
    "Essential packages installation completed"
    if packages_installed
    else "Essential packages installation FAILED"
)
print()

# Optional: install Python from the cctbx202208 channel as well.
USE_CCTBX_CHANNEL = False  # Set to True if you need cctbx channel
if USE_CCTBX_CHANNEL:
    print("Adding cctbx202208 channel packages...")
    add_channel_cmd = "conda install -n newCondaEnvironment -c cctbx202208 -c conda-forge --override-channels python -y"
    channel_ok = _run_conda(add_channel_cmd)
    print("Channel packages added" if channel_ok else "Channel packages FAILED")
    print()

# Step 6: confirm the new environment landed on disk and pick its interpreter.
print("[6] VERIFYING NEW ENVIRONMENT")
print("-" * 40)

env_path = "/opt/conda/envs/newCondaEnvironment"
print(f"Checking if environment exists at {env_path}...")
print("✓ Environment directory exists" if os.path.exists(env_path) else "✗ Environment directory not found!")
print()

# Probe the usual interpreter names; the first one present becomes the switch target.
print("Looking for Python executables in new environment:")
new_python_path = None
for candidate in (f"{env_path}/bin/{exe}" for exe in ("python", "python3", "python3.7")):
    if not os.path.exists(candidate):
        print(f"✗ Not found: {candidate}")
        continue
    print(f"✓ Found: {candidate}")
    if new_python_path is not None:
        continue
    new_python_path = candidate
    print(f"  Version: {run_command(f'{candidate} --version')}")
print()

# The environment's own Jupyter launcher, used later to replace the base one.
new_jupyter_path = f"{env_path}/bin/jupyter"
jupyter_msg = (
    f"✓ Jupyter found in new environment: {new_jupyter_path}"
    if os.path.exists(new_jupyter_path)
    else "✗ Jupyter NOT found in new environment!"
)
print(jupyter_msg)
print()

if new_python_path and os.path.exists(new_python_path):
    import shutil  # used for sudo detection and the Jupyter launcher backup

    new_python_version = run_command(f"{new_python_path} --version")
    print(f"Using Python at: {new_python_path}")
    print(f"Version: {new_python_version}")
    print()

    # Prefix privileged commands with sudo only when it is needed and present:
    # a root kernel can modify /opt/conda/bin directly, and on an image with no
    # sudo binary every `sudo ...` command would fail.
    SUDO = "sudo " if os.geteuid() != 0 and shutil.which("sudo") else ""

    def _priv_ok(cmd):
        """Run *cmd* (sudo-prefixed when needed); True only if it exited with status 0."""
        return run_command(f"{SUDO}{cmd}", capture=False) is True

    # Step 7: Remove old symlinks (but keep jupyter working)
    print("[7] REMOVING OLD PYTHON SYMLINKS")
    print("-" * 40)
    print("Removing existing Python symlinks...")

    # Only remove Python symlinks, not jupyter or other tools
    symlinks_to_remove = [
        "/opt/conda/bin/python",
        "/opt/conda/bin/python3",
    ]

    for symlink in symlinks_to_remove:
        if os.path.islink(symlink):
            # Report the real outcome rather than assuming the rm succeeded.
            status = "Removed" if _priv_ok(f"rm -f {symlink}") else "FAILED to remove"
            print(f"  {status}: {symlink}")

    print("Old Python symlinks processed")
    print()

    # Step 8: Create new symlinks
    print("[8] CREATING NEW SYMLINKS")
    print("-" * 40)
    print(f"Creating new symlinks to {new_python_path}...")

    symlinks_to_create = {
        "/opt/conda/bin/python": new_python_path,
        "/opt/conda/bin/python3": new_python_path,
    }

    for symlink, target in symlinks_to_create.items():
        status = "Created" if _priv_ok(f"ln -sf {target} {symlink}") else "FAILED to create"
        print(f"  {status}: {symlink} -> {target}")

    # Also repoint Jupyter if the new env has one. `rm -f` destroys whatever
    # launcher the base env had at this path, and recording only its path (as
    # original_jupyter does) is not enough to restore it -- so copy it aside
    # first for the rollback script. An existing backup is never overwritten,
    # which keeps the pre-switch copy if this cell is re-run.
    jupyter_launcher = "/opt/conda/bin/jupyter"
    jupyter_backup = "/tmp/jupyter_launcher_backup"
    if os.path.exists(new_jupyter_path):
        print(f"Updating Jupyter symlink...")
        if os.path.lexists(jupyter_launcher) and not os.path.lexists(jupyter_backup):
            try:
                shutil.copy2(jupyter_launcher, jupyter_backup, follow_symlinks=False)
                print(f"  Backed up: {jupyter_launcher} -> {jupyter_backup}")
            except OSError as exc:
                print(f"  Backup of {jupyter_launcher} failed: {exc}")
        if os.path.lexists(jupyter_backup) or not os.path.lexists(jupyter_launcher):
            _priv_ok(f"rm -f {jupyter_launcher}")
            ok = _priv_ok(f"ln -sf {new_jupyter_path} {jupyter_launcher}")
            status = "Created" if ok else "FAILED to create"
            print(f"  {status}: {jupyter_launcher} -> {new_jupyter_path}")
        else:
            print(f"  Skipped: no backup of {jupyter_launcher}, so it was left unchanged")

    print("Symlink creation finished")
    print()

    # Step 9: Verify symlink changes
    print("[9] VERIFYING SYMLINK CHANGES")
    print("-" * 40)
    print("Python symlinks:")
    print(run_command("ls -la /opt/conda/bin/python /opt/conda/bin/python3"))
    print()

    print("Jupyter symlink:")
    print(run_command("ls -la /opt/conda/bin/jupyter"))
    print()

    print("Verifying targets:")
    for symlink in ["/opt/conda/bin/python", "/opt/conda/bin/python3", "/opt/conda/bin/jupyter"]:
        if os.path.exists(symlink):
            target = run_command(f"readlink -f {symlink}")
            print(f"  {symlink} -> {target}")
    print()

    # Step 10: Test new environment
    print("[10] TESTING NEW ENVIRONMENT")
    print("-" * 40)
    print("Python version after switch:")
    print(f"  python: {run_command('python --version')}")
    print(f"  python3: {run_command('python3 --version')}")
    print()

    print("Jupyter version:")
    jupyter_version = run_command("jupyter --version")
    print(jupyter_version)
    print()

    print("Testing Jupyter modules:")
    test_jupyter_cmd = 'python -c "import jupyter_core; print(f\'jupyter_core: {jupyter_core.__version__}\')"'
    test_jupyter = run_command(test_jupyter_cmd)
    print(test_jupyter)
    print()

    print("Running test script:")
    print(run_command("python /tmp/test_version.py"))
    print()

    # Step 11: Test Python imports
    print("[11] TESTING PYTHON IMPORTS")
    print("-" * 40)

    # Test basic imports
    print("Testing basic imports:")
    basic_test_cmd = 'python -c "import sys, os; print(f\'✓ Basic imports OK. Python {sys.version.split()[0]} at {sys.executable}\')"'
    basic_test = run_command(basic_test_cmd)
    print(basic_test)
    print()

    # Test packages with proper syntax
    packages_to_test = ["numpy", "pandas", "matplotlib", "sklearn", "jupyter_core", "ipykernel"]
    print("Testing package availability:")
    for pkg in packages_to_test:
        # Use simpler syntax that works in all Python versions
        test_cmd = f"""python -c "
try:
    import {pkg}
    print('  ✓ {pkg}: ' + str(getattr({pkg}, '__version__', 'version unknown')))
except ImportError:
    print('  ✗ {pkg}: Not installed')
" """
        result = run_command(test_cmd)
        if result:
            print(result)
    print()

    # Step 12: Create rollback script
    print("[12] CREATING ROLLBACK CAPABILITY")
    print("-" * 40)

    # Values are embedded with !r so the generated code stays valid Python
    # whatever characters the paths contain.
    rollback_code = f'''
import os
import subprocess

print("Rolling back Python environment changes...")

SUDO = {SUDO!r}
original_python = {original_python!r}
jupyter_launcher = {jupyter_launcher!r}
jupyter_backup = {jupyter_backup!r}
switched_env = {env_path!r}


def _sh(cmd):
    return subprocess.run(SUDO + cmd, shell=True).returncode == 0


# Point python/python3 back at the original interpreter. The links are only
# touched when that interpreter exists, so a rollback can never leave the
# image without any `python` at all.
if os.path.exists(original_python):
    for link in ("/opt/conda/bin/python", "/opt/conda/bin/python3"):
        _sh(f"ln -sf {{original_python}} {{link}}")
    print(f"Restored Python symlinks to: {{original_python}}")
else:
    print(f"WARNING: {{original_python}} not found; python/python3 links left unchanged")

# Put back the base Jupyter launcher that was copied aside before the switch.
if os.path.lexists(jupyter_backup):
    _sh(f"rm -f {{jupyter_launcher}}")
    _sh(f"cp -a {{jupyter_backup}} {{jupyter_launcher}}")
    print(f"Restored Jupyter launcher from: {{jupyter_backup}}")
elif not os.path.exists(jupyter_launcher) or os.path.realpath(jupyter_launcher).startswith(switched_env + os.sep):
    # No backup: recreate the `jupyter` entry point (shipped by jupyter_core)
    # for the original interpreter. Never reinstall Python itself here --
    # that would change the base environment's version.
    _sh(f"rm -f {{jupyter_launcher}}")
    print("Reinstalling the jupyter launcher (jupyter_core)...")
    subprocess.run(f"{{original_python}} -m pip install --force-reinstall --no-deps jupyter_core", shell=True)

# Verify rollback
result = subprocess.run("python --version", shell=True, capture_output=True, text=True)
print(f"Python version after rollback: {{result.stdout.strip()}}")
result = subprocess.run("jupyter --version", shell=True, capture_output=True, text=True)
print(f"Jupyter after rollback: {{result.stdout.strip()[:50]}}")
print("Rollback complete!")
'''

    with open("/tmp/rollback_python.py", "w") as f:
        f.write(rollback_code)

    print("Rollback script created at /tmp/rollback_python.py")
    print("To rollback, run: exec(open('/tmp/rollback_python.py').read())")
    print()

    # Step 13: Final summary
    print("=" * 60)
    print("ENVIRONMENT SWITCH COMPLETE")
    print("=" * 60)
    print()
    print("SUMMARY:")
    print("  ✓ Original environment backed up")
    print(f"  ✓ New conda environment created with {new_python_version}")
    print("  ✓ Jupyter and essential packages installed")
    print("  ✓ Python symlinks updated")
    print("  ✓ Jupyter functionality preserved")
    print()
    print("CURRENT STATE:")
    current_version = run_command('python --version')
    current_location = run_command('which python')
    print(f"  Python version: {current_version}")
    print(f"  Python location: {current_location}")

    # Judge Jupyter by exit status: a "jupyter: not found" shell error also
    # contains the word "jupyter", so a substring test reports false success.
    jupyter_ok = run_command("jupyter --version >/dev/null 2>&1", capture=False) is True
    jupyter_status = '✓ Working' if jupyter_ok else '✗ Not working'
    print(f"  Jupyter status: {jupyter_status}")
    print()

    print("IMPORTANT NOTES:")
    print("  1. The Jupyter kernel itself is still running the original Python")
    print("  2. Only subprocess calls (!command) will use the new Python,")
    print("     which has only the packages installed in step [5] (e.g. no torch)")
    print("  3. Jupyter should continue working normally")
    print("  4. To rollback: run exec(open('/tmp/rollback_python.py').read())")
    print()
    print("=" * 60)

    # Final test
    print("FINAL TEST - Verify everything works:")
    print(f"Python: {run_command('python --version')}")

    import_test_cmd = 'python -c "import jupyter_core; print(True)"'
    import_test_result = run_command(import_test_cmd)
    print(f"Can import jupyter_core: {import_test_result}")
    print("=" * 60)

else:
    print("=" * 60)
    print("ERROR: Could not find or install a new Python version!")
    print("The environment switch failed.")
    print("=" * 60)


## 1. Setup & GPU Check

In [12]:
!nvidia-smi
!conda create -n py37 python=3.7 anaconda --yes
!source /opt/conda/bin/activate py37 && conda install -c py37 python -y

!pip install -q ninja
!pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3

import torch
assert torch.cuda.is_available(), "No GPU — enable it in Settings → Accelerator → GPU T4 x2"
print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}, GPU: {torch.cuda.get_device_name(0)}")

Mon Feb  2 12:59:00 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P8             10W /   70W |       3MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

## 2. Clone StyleGAN2-ADA & Apply Patches

In [None]:
import os
import sys
import pathlib
import shutil
import subprocess

# Fresh clone
if os.path.exists("stylegan2-ada-pytorch"):
    shutil.rmtree("stylegan2-ada-pytorch")
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
sys.path.insert(0, "stylegan2-ada-pytorch")

# ── Patch 1: Fix InfiniteSampler for PyTorch >=2.4 ────────────────────────
misc_path = pathlib.Path("stylegan2-ada-pytorch/torch_utils/misc.py")
src = misc_path.read_text()
src = src.replace("super().__init__(dataset)", "super().__init__()")

# ── Patch 2: Skip shape-mismatched params during resume ──────────────────
#    (needed for RGB ffhq512 -> 1-ch grayscale spectrograms)
#    Line-by-line patch to avoid fragile exact-string matching.
lines = src.split('\n')
new_lines = []
patched = False
for line in lines:
    if 'tensor.copy_(src_tensors[name].detach())' in line and not patched:
        indent = line[:len(line) - len(line.lstrip())]
        new_lines.append(f'{indent}if src_tensors[name].shape != tensor.shape:')
        new_lines.append(f'{indent}    continue')
        new_lines.append(line)
        patched = True
    else:
        new_lines.append(line)
src = '\n'.join(new_lines)
assert patched, "ERROR: could not patch copy_params_and_buffers"
print(f"  Patch 2: shape-mismatch guard {'applied' if patched else 'FAILED'}")

misc_path.write_text(src)
print(f"Patched {misc_path}")

# ── Patch 3: Fix Adam betas int -> float for PyTorch >=2.9 ────────────────
train_path = pathlib.Path("stylegan2-ada-pytorch/train.py")
src = train_path.read_text()
src = src.replace("betas=[0,0.99]", "betas=[0.0,0.99]")
train_path.write_text(src)
print(f"Patched {train_path}: Adam betas fix")

# ── Patch 4: Try CUDA ops compilation ────────────────────────────────────
cc_major, cc_minor = torch.cuda.get_device_capability(0)
arch = f"{cc_major}.{cc_minor}"
os.environ["TORCH_CUDA_ARCH_LIST"] = arch
os.environ["TORCH_EXTENSIONS_DIR"] = "/tmp/torch_extensions"
if os.path.exists("/tmp/torch_extensions"):
    shutil.rmtree("/tmp/torch_extensions")

result = subprocess.run(
    ["python", "-c",
     "import sys; sys.path.insert(0,'stylegan2-ada-pytorch'); "
     "from torch_utils.ops import bias_act; "
     "assert bias_act._init(), 'init failed'"],
    capture_output=True, text=True, timeout=180,
)

CUDA_OPS_OK = result.returncode == 0
if CUDA_OPS_OK:
    print(f"Custom CUDA ops compiled for sm_{cc_major}{cc_minor} — using fused kernels")
else:
    print(f"CUDA ops compilation failed (arch {arch}), using native PyTorch fallback")
    print(f"  Error: ...{result.stderr[-300:]}")
    ops_dir = pathlib.Path("stylegan2-ada-pytorch/torch_utils/ops")
    for name in ["bias_act.py", "upfirdn2d.py"]:
        p = ops_dir / name
        s = p.read_text()
        s = s.replace("def _init():", "def _init():\n    return False")
        p.write_text(s)
        print(f"  Patched {p}")

# ── Verify patch applied correctly ───────────────────────────────────────
verify = misc_path.read_text()
assert "if src_tensors[name].shape != tensor.shape:" in verify, \
    "FATAL: shape-mismatch patch missing from misc.py after write!"
print("Verified: shape-mismatch guard present in misc.py")

## 3. Load Dataset

Kaggle datasets are mounted at `/kaggle/input/<dataset-name>/`.

Set `KAGGLE_DATASET` to match your dataset name.

In [None]:
import os
import glob
import shutil
from collections import Counter

KAGGLE_DATASET = "spectrograms"  # <-- your Kaggle dataset name

input_dir = f"/kaggle/input/{KAGGLE_DATASET}"

# Find PNGs (may be in root or a subfolder)
pngs = sorted(glob.glob(f"{input_dir}/**/*.png", recursive=True))
if not pngs:
    raise FileNotFoundError(
        f"No PNG files found under {input_dir} — check KAGGLE_DATASET and that the "
        "dataset is attached via 'Add data'."
    )

# Files are copied flat into one folder, so same-named PNGs from different
# subfolders would silently overwrite each other and shrink the dataset.
name_counts = Counter(os.path.basename(p) for p in pngs)
clashes = sorted(name for name, count in name_counts.items() if count > 1)
if clashes:
    raise ValueError(
        f"{len(clashes)} file name(s) occur in more than one subfolder, e.g. {clashes[:5]}"
    )

# Copy to a writable working directory (Kaggle input is read-only). Start from
# an empty folder so PNGs left over from an earlier run cannot leak into the
# training set.
DATASET_PATH = "/kaggle/working/spectrograms"
shutil.rmtree(DATASET_PATH, ignore_errors=True)
os.makedirs(DATASET_PATH, exist_ok=True)
for p in pngs:
    shutil.copy(p, DATASET_PATH)

print(f"Found {len(pngs)} PNG files → copied to {DATASET_PATH}")

## 4. Prepare Dataset

In [None]:
!python stylegan2-ada-pytorch/dataset_tool.py \
    --source={DATASET_PATH} \
    --dest=./spectrograms.zip

## 5. Configure Training

In [None]:
import os

# Allocator tuning for the training subprocess, which inherits this
# environment. Newer PyTorch releases read PYTORCH_ALLOC_CONF while older ones
# only read PYTORCH_CUDA_ALLOC_CONF, so set both to cover either version.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

GPUS = 1              # single GPU — avoids multi-GPU sync overhead on T4
GAMMA = 3.0           # higher R1 regularization for small datasets
SNAP = 5              # snapshot every 5 ticks
KIMG = 500            # fits in one Kaggle session
AUG = "ada"           # adaptive discriminator augmentation
TARGET = 0.6          # ADA target heuristic — good default for small datasets
MIRROR = False        # no mirroring — horizontal flip would reverse time axis
METRICS = "none"
BATCH_SIZE = 4        # lower VRAM pressure — 12.4 GB at batch=8 was too tight
RESUME = "ffhq512"    # shape-mismatch patch (section 2) handles RGB->grayscale

print(f"Config: gpus={GPUS}, batch={BATCH_SIZE}, gamma={GAMMA}, aug={AUG}, target={TARGET}, mirror={MIRROR}")

## 6. Train

In [None]:
import torch
torch.cuda.empty_cache()

resume_flag = f"--resume={RESUME}" if RESUME else ""
mirror_int = 1 if MIRROR else 0

!python stylegan2-ada-pytorch/train.py \
    --outdir=./training-runs \
    --cfg=auto \
    --data=./spectrograms.zip \
    --gpus={GPUS} \
    --batch={BATCH_SIZE} \
    --gamma={GAMMA} \
    --snap={SNAP} \
    --kimg={KIMG} \
    --aug={AUG} \
    --target={TARGET} \
    --mirror={mirror_int} \
    --metrics={METRICS} \
    {resume_flag}

## 7. Generate Samples

## 8. Reconstruct Audio from Generated Spectrograms

Use Griffin-Lim phase estimation to convert the generated spectrogram images back into audio waveforms.

In [None]:
import glob
import pickle

import matplotlib.pyplot as plt
import torch

pkls = sorted(glob.glob("training-runs/**/*.pkl", recursive=True))
assert pkls, "No snapshots found — has training completed at least one snapshot?"
latest_pkl = pkls[-1]  # lexically last path
print(f"Loading {latest_pkl}")

# pickle.load can execute arbitrary code: only load snapshots you produced.
# Presumably needs stylegan2-ada-pytorch on sys.path (added in section 2).
with open(latest_pkl, "rb") as f:
    G = pickle.load(f)["G_ema"].cuda().eval()

NUM_SAMPLES = 5
z = torch.randn(NUM_SAMPLES, G.z_dim, device="cuda")
with torch.no_grad():
    imgs = G(z, None)  # c=None: assumes an unconditional model

# squeeze=False keeps `axes` 2-D even when NUM_SAMPLES == 1 (otherwise
# subplots returns a single, non-iterable Axes object).
fig, axes = plt.subplots(1, NUM_SAMPLES, figsize=(20, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    img = imgs[i, 0].cpu().numpy()
    ax.imshow(img, cmap="magma", aspect="auto")
    ax.set_title(f"Sample {i}")
    ax.axis("off")
plt.suptitle("Generated Spectrograms (StyleGAN2-ADA)")
plt.tight_layout()
plt.show()

## 9. Save Results

Kaggle persists everything in `/kaggle/working/` as notebook output.
Click **Save Version** (top right); `training-runs.zip` and `reconstructed_audio.zip` will then be available under **Output**.

In [None]:
!pip install -q librosa soundfile

import numpy as np
import librosa
import soundfile as sf
import IPython.display as ipd

SR = 22050
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 512
N_ITER = 32
DB_RANGE = 80.0

output_dir = "/kaggle/working/reconstructed_audio"
os.makedirs(output_dir, exist_ok=True)

for i in range(NUM_SAMPLES):
    # Extract spectrogram as numpy array in [-1, 1]
    spec = imgs[i, 0].cpu().numpy()

    # [-1, 1] → dB → power → linear STFT → Griffin-Lim
    S_db = (spec + 1.0) * (DB_RANGE / 2.0) - DB_RANGE
    S_power = librosa.db_to_power(S_db, ref=1.0)
    S_stft = librosa.feature.inverse.mel_to_stft(S_power, sr=SR, n_fft=N_FFT, power=2.0)
    audio = librosa.griffinlim(S_stft, n_iter=N_ITER, hop_length=HOP_LENGTH, n_fft=N_FFT)

    # Normalise to -1 dBFS peak
    peak = np.abs(audio).max()
    if peak > 0:
        audio = audio / peak * 10 ** (-1.0 / 20.0)

    # Save WAV
    wav_path = f"{output_dir}/sample_{i}.wav"
    sf.write(wav_path, audio, SR)
    print(f"Sample {i}: {len(audio)} samples ({len(audio)/SR:.2f}s) → {wav_path}")

    # Inline audio player
    ipd.display(ipd.Audio(audio, rate=SR))

In [None]:
import shutil

# training-runs/ was written relative to the working directory
# (train.py --outdir=./training-runs), while reconstructed_audio/ was written
# to an absolute path, so archive each from the directory that holds it.
runs_zip = shutil.make_archive("/kaggle/working/training-runs", "zip", ".", "training-runs")
audio_zip = shutil.make_archive(
    "/kaggle/working/reconstructed_audio", "zip", "/kaggle/working", "reconstructed_audio"
)
print(f"Created {runs_zip}")
print(f"Created {audio_zip}")
print("These will be saved as notebook output when you click Save Version.")