# How to Organize and Use CNN Autoencoder Files

This notebook demonstrates how to organize all files related to the CNN autoencoder into a single folder and how to use them for training, diagnostics, and inference.

## 1. Import Required Libraries

We use the standard-library modules `os`, `shutil`, and `pathlib` for file and folder operations, and `torch` for loading the trained model.

In [None]:
import os
import shutil
import torch
from pathlib import Path

# Project root: assumed to be the parent of the directory this notebook runs in
WORK_DIR = Path('..').resolve()
CNN_FOLDER = WORK_DIR / 'cnn_autoencoder'
MODEL_FOLDER = CNN_FOLDER / 'models'
OUTPUT_FOLDER = CNN_FOLDER / 'output'

# Create the folder tree; exist_ok=True makes this cell safe to re-run
for folder in (CNN_FOLDER, MODEL_FOLDER, OUTPUT_FOLDER):
    os.makedirs(folder, exist_ok=True)
print(f"Working directory: {WORK_DIR}")
print(f"CNN folder: {CNN_FOLDER}")
print(f"Model folder: {MODEL_FOLDER}")
print(f"Output folder: {OUTPUT_FOLDER}")

## 2. Create a Folder for CNN Autoencoder Files

This step only makes sure the `cnn_autoencoder` folder exists. The setup cell above already created it, so running this cell again is harmless.

In [None]:
# Create the folder if it doesn't exist
os.makedirs(CNN_FOLDER, exist_ok=True)
print(f"Folder ready: {CNN_FOLDER}")
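Because `os.makedirs` creates any missing parent directories, creating `MODEL_FOLDER` alone would also create `CNN_FOLDER`, and `exist_ok=True` turns a repeated call into a no-op instead of a `FileExistsError`. The sketch below checks both behaviors in a throwaway temporary directory; the `demo_root` path is illustrative only and does not touch the real project folder.

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp) / 'cnn_autoencoder'  # illustrative path, not the real project folder
    nested = demo_root / 'models'

    # One call creates both 'cnn_autoencoder' and 'cnn_autoencoder/models'
    os.makedirs(nested, exist_ok=True)
    parent_created = demo_root.is_dir()

    # A second call on an existing folder does nothing instead of raising FileExistsError
    os.makedirs(nested, exist_ok=True)
    rerun_ok = nested.is_dir()

print(parent_created, rerun_ok)  # True True
```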

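## 3. Copy CNN Files into the Folder

Copy the CNN autoencoder scripts into `cnn_autoencoder` and the trained model files into `cnn_autoencoder/models`. Files that already exist at the destination are skipped, so the cell is safe to re-run; skip it entirely if your files are already organized.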

In [None]:
# Scripts and docs to copy into cnn_autoencoder/ (edit as needed)
files_to_copy = [
    'train_autoencoder_fixed.py',
    'diagnose_autoencoder.py',
    'generate_all_spectrograms.py',
    'cluster_events.py',
    'analyze_clusters.py',
    'README.md'
]

# Trained model files to copy into cnn_autoencoder/models/
# (originals are left in place; delete them manually if you want a true move)
model_files = [
    'autoencoder_model.pth',
    'encoder_model.pkl',
    'autoencoder_learner.pkl'
]

def copy_if_missing(fnames, dest_folder):
    """Copy each file from WORK_DIR to dest_folder, skipping files already present."""
    for fname in fnames:
        src = WORK_DIR / fname
        dst = dest_folder / fname
        if not src.exists():
            print(f"Source file not found: {src}")
        elif dst.exists():
            print(f"Already exists: {dst}")
        else:
            shutil.copy2(src, dst)  # copy2 also preserves timestamps and metadata
            print(f"Copied {src} -> {dst}")

copy_if_missing(files_to_copy, CNN_FOLDER)
copy_if_missing(model_files, MODEL_FOLDER)
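The copy step skips any destination that already exists, so an edited script in `WORK_DIR` will not replace an older copy in `cnn_autoencoder`. A minimal sketch for flagging such stale copies compares file contents with the standard-library `filecmp.cmp(..., shallow=False)`. The `find_stale_copies` helper is hypothetical and not part of the original scripts; the demo runs on a temporary directory.

```python
import filecmp
import tempfile
from pathlib import Path

def find_stale_copies(fnames, src_folder, dst_folder):
    """Return names whose copy in dst_folder differs in content from the source (hypothetical helper)."""
    stale = []
    for fname in fnames:
        src, dst = Path(src_folder) / fname, Path(dst_folder) / fname
        # Compare only files present in both places; shallow=False compares bytes, not just stat info
        if src.exists() and dst.exists() and not filecmp.cmp(src, dst, shallow=False):
            stale.append(fname)
    return stale

# Demo: one outdated copy, one identical copy, one file missing on both sides
with tempfile.TemporaryDirectory() as tmp:
    src_dir, dst_dir = Path(tmp) / 'src', Path(tmp) / 'dst'
    src_dir.mkdir()
    dst_dir.mkdir()
    (src_dir / 'a.py').write_bytes(b'print("v2")\n')
    (dst_dir / 'a.py').write_bytes(b'print("v1")\n')  # outdated copy
    (src_dir / 'b.py').write_bytes(b'x = 1\n')
    (dst_dir / 'b.py').write_bytes(b'x = 1\n')        # identical copy
    stale = find_stale_copies(['a.py', 'b.py', 'missing.py'], src_dir, dst_dir)

print(stale)  # ['a.py']
```

In the notebook you would call `find_stale_copies(files_to_copy, WORK_DIR, CNN_FOLDER)` and re-copy any listed file with `shutil.copy2` if the newer version is the one you want.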

## 4. List Files in the CNN Folder

Display all files now present in the `cnn_autoencoder` folder and its `models` and `output` subfolders.

In [None]:
# Print the sorted contents of each project folder
for folder in (CNN_FOLDER, MODEL_FOLDER, OUTPUT_FOLDER):
    print(f"Files in {folder.relative_to(WORK_DIR)} folder:")
    for f in sorted(os.listdir(folder)):
        print(f"- {f}")
    print()
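`os.listdir` shows only the top level of each folder. To see the whole tree at once, together with file sizes (useful for spotting an empty or truncated weights file), a sketch based on `Path.rglob` is shown below. The `list_tree` helper is hypothetical; the demo builds a small temporary tree shaped like `cnn_autoencoder/`.

```python
import tempfile
from pathlib import Path

def list_tree(root):
    """Return (relative_path, size_in_bytes) for every file under root, sorted by path."""
    root = Path(root)
    return sorted(
        (p.relative_to(root).as_posix(), p.stat().st_size)
        for p in root.rglob('*')
        if p.is_file()  # directories such as an empty 'output' are skipped
    )

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / 'models').mkdir()
    (root / 'output').mkdir()
    (root / 'README.md').write_bytes(b'notes\n')
    (root / 'models' / 'weights.bin').write_bytes(b'\x00' * 10)
    tree = list_tree(root)

for rel_path, size in tree:
    print(f"- {rel_path} ({size} bytes)")
```

In the notebook, `list_tree(CNN_FOLDER)` would cover the scripts, `models`, and `output` in one pass.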

## 5. Generate All Spectrograms for a Recording

To generate spectrogram images for all detected SWR events in a given recording, use the `generate_all_spectrograms.py` script. This script:

- Performs SWR detection on your LFP and spike data
- Computes a spectrogram for each detected event
- Saves each spectrogram as a PNG image in an output folder (e.g., `all_spectrograms` inside your recording folder)
- Saves the detected events as a pickle file for downstream analysis

**Usage:** You can run the script from the command line or call its main function from Python. The cell below shows both approaches for a particular recording.

**Notes:**
- Edit the `RECORDING_FOLDER` variable to point to your recording's folder.
- The script saves all spectrogram images in the `all_spectrograms` subfolder of your recording folder.
- Make sure the data paths in `generate_all_spectrograms.py` are set correctly for your data organization; the cell below does not pass `RECORDING_FOLDER` to the script.
- If needed, you can modify the script to accept command-line arguments for its input and output paths.

In [None]:
# Example: generate all spectrograms for a specific recording
import subprocess
import sys

# Set the path to your recording folder (edit as needed)
RECORDING_FOLDER = WORK_DIR / 'recording_01'  # Change to your recording folder
SPECTROGRAM_OUTPUT = RECORDING_FOLDER / 'all_spectrograms'
os.makedirs(SPECTROGRAM_OUTPUT, exist_ok=True)
print(f"Spectrograms should be saved in: {SPECTROGRAM_OUTPUT}")

# Option 1: run the script as a subprocess (recommended for the full pipeline).
# check=True raises CalledProcessError if the script exits with an error.
subprocess.run([sys.executable, str(CNN_FOLDER / 'generate_all_spectrograms.py')], check=True)

# Option 2: import and call the function directly (advanced users).
# Requires WORK_DIR on sys.path so that `cnn_autoencoder` is importable:
# sys.path.insert(0, str(WORK_DIR))
# from cnn_autoencoder.generate_all_spectrograms import generate_spectrogram_images
# generate_spectrogram_images()

## 6. Load the Trained Autoencoder and Run Inference

Load the trained autoencoder from the `models` folder, run it on one spectrogram image, and compare the input with its reconstruction.

In [None]:
import sys
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Make the project root importable so `cnn_autoencoder` resolves as a package
if str(WORK_DIR) not in sys.path:
    sys.path.insert(0, str(WORK_DIR))

# Import the model definition from the training script.
# If that script starts training at import time (no `if __name__ == '__main__':` guard), add the guard first.
from cnn_autoencoder.train_autoencoder_fixed import ImprovedAutoencoder

# Load the trained weights from the models folder onto the CPU
model_path = MODEL_FOLDER / 'autoencoder_model.pth'
model = ImprovedAutoencoder()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

# Load a sample image written by the spectrogram cell (adjust if your script writes elsewhere)
sample_img_path = SPECTROGRAM_OUTPUT / 'event_00000.png'
img = Image.open(sample_img_path).convert('L').resize((128, 128))  # grayscale, 128x128
img_array = np.array(img, dtype=np.float32) / 255.0                 # scale pixels to [0, 1]
img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 128, 128)

# Run inference without tracking gradients
with torch.no_grad():
    recon = model(img_tensor)

# Visualize original and reconstruction side by side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(img_array, cmap='viridis')
axes[0].set_title('Original')
axes[1].imshow(recon.squeeze().cpu().numpy(), cmap='viridis')
axes[1].set_title('Reconstruction')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()
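A side-by-side plot is a qualitative check. To put a number on reconstruction quality, for example to compare events or flag poorly reconstructed ones, a common choice is the per-image mean squared error between input and reconstruction. The numpy sketch below assumes both arrays have the same shape and the same intensity scale; `reconstruction_mse` is a hypothetical helper, and in the notebook you would pass `img_array` and `recon.squeeze().cpu().numpy()`.

```python
import numpy as np

def reconstruction_mse(original, reconstruction):
    """Mean squared error between two same-shaped images (hypothetical helper)."""
    original = np.asarray(original, dtype=np.float32)
    reconstruction = np.asarray(reconstruction, dtype=np.float32)
    if original.shape != reconstruction.shape:
        raise ValueError(f"shape mismatch: {original.shape} vs {reconstruction.shape}")
    return float(np.mean((original - reconstruction) ** 2))

# Toy 2x2 example: one pixel off by 0.5, so MSE = 0.5**2 / 4
orig = np.array([[0.0, 1.0], [0.5, 0.5]])
recon = np.array([[0.0, 1.0], [0.5, 1.0]])
err = reconstruction_mse(orig, recon)
print(err)  # 0.0625
```

Because the error is averaged over pixels, values stay comparable across images as long as every image is resized to the same 128x128 input.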