In [None]:
import tensorflow as tf
import zipfile
import os
import matplotlib
matplotlib.use("Agg")  
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
from pathlib import Path
from pydub import AudioSegment
import librosa
from tensorflow.keras.utils import plot_model
import soundfile as sf
import shutil
from PIL import Image


import shap
import gc
from collections import Counter
from pydub.silence import detect_silence
import librosa.display
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures

import json
from PIL import Image
import shutil
import csv


# DATA_DIR = pathlib.Path('data')
# TEST_DIR = pathlib.Path('data/new_test')

# Parent of the current working directory, resolved to an absolute path.
ROOT_DIR = Path('../').resolve()  
# Folder that receives the unpacked test archives.
UNZIP_DIR = ROOT_DIR / 'Unzipped_Data_Picture' 
# UNZIP_DIR is already absolute, so pathlib discards the leading ROOT_DIR in
# this join; the result is simply UNZIP_DIR / 'new_test'.
TEST_DIR = ROOT_DIR / UNZIP_DIR / 'new_test'
# TEST_DIR = ROOT_DIR / UNZIP_DIR / 'new_test_2'
# TEST_DIR = ROOT_DIR / UNZIP_DIR / 'new_test_3'
# TEST_DIR = ROOT_DIR / UNZIP_DIR / 'final_test'

# Default model archive used by run() when no archive path is passed in.
_zip_file_path = "model_results/model_(10112_256px-resnet_model)_loss_0.215_acc_0.863_val_loss_1.141_val_acc_0.773.zip"

# Mutable module state: run() fills `session` with the session data loaded
# from the model archive and sets `downsize` from its `_downsize` argument.
session = None
# NOTE(review): this module-level max_length is not read by any function
# visible in this notebook; normalize_audio_length takes it from `session`.
max_length = 0
downsize = False

In [None]:
def extract_zip(zip_path, extract_to):
    """Unpack a zip archive into ``extract_to`` unless that was already done.

    ``zip_path`` may be given with or without the ``.zip`` suffix. The archive
    counts as already extracted when ``extract_to`` contains an entry named
    like the archive without its ``.zip`` suffix. A missing archive is only
    reported on stdout; nothing is raised for it.
    """
    archive_name = str(zip_path)
    if not archive_name.endswith('.zip'):
        archive_name = archive_name + '.zip'
    zip_file_path = pathlib.Path(archive_name)

    target_folder = pathlib.Path(extract_to) / zip_file_path.stem

    # Guard clauses: nothing to do if already unpacked or the archive is absent.
    if target_folder.exists():
        print(f"Das Verzeichnis {target_folder} existiert bereits. Überspringe das Extrahieren.")
        return

    if not zip_file_path.exists():
        print(f"Die Zip-Datei {zip_file_path} existiert nicht.")
        return

    print(f"Extrahiere die Zip-Datei {zip_file_path} nach {extract_to}.")
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print(f"Zip-Datei {zip_file_path} erfolgreich extrahiert.")

def rename_audio_files(root_path):
    """Prefix every .wav/.mp3 file below ``root_path`` with its folder name.

    A file ``x.wav`` inside a folder ``foo`` becomes ``foo_x.wav``. Files that
    already start with ``<folder>_`` are left alone, so running the function
    again does not add a second prefix. One progress line is printed per
    visited folder.
    """
    audio_suffixes = ('.wav', '.mp3')

    for folder, _subfolders, file_names in os.walk(root_path):
        folder_name = os.path.basename(folder)
        prefix = f"{folder_name}_"

        for file_name in file_names:
            already_prefixed = file_name.startswith(prefix)
            if already_prefixed or not file_name.endswith(audio_suffixes):
                continue
            os.rename(
                os.path.join(folder, file_name),
                os.path.join(folder, f"{folder_name}_{file_name}"),
            )

        print(f"renaming of {root_path}/{folder_name} complete")

In [None]:
def normalize_audio_length(input_dir):
    """Bring every .wav file below ``input_dir`` to one common duration.

    Results are written to a sibling folder ``<input_dir.name>_normalized``
    that mirrors the sub-folder layout of ``input_dir``. If that folder
    already exists nothing is processed and it is returned as is.

    The target duration (milliseconds) is read from the module-level
    ``session`` dict under ``"max_length"``. When ``session`` is unset or has
    no positive ``max_length``, the duration of the longest input file is used
    instead; without that fallback every file would be cut down to 0 ms.
    Longer files are truncated, shorter ones are padded with silence.

    :param input_dir: ``pathlib.Path`` of the folder holding the .wav files
    :return: ``pathlib.Path`` of the folder with the normalized files
    """
    output_dir = input_dir.parent / f"{input_dir.name}_normalized"
    if output_dir.exists():
        print(f"Überspringe Verarbeitung, da {output_dir} bereits existiert.")
        return output_dir
    os.makedirs(output_dir, exist_ok=True)

    # `session` is module state filled by run(); tolerate it being None.
    max_length = (session or {}).get("max_length", 0)
    print(f"max_length : {max_length}")

    # Only paths are collected up front; audio is decoded one file at a time
    # below so the whole data set never has to sit in memory at once.
    wav_paths = [
        Path(subdir) / file
        for subdir, _, files in os.walk(input_dir)
        for file in files
        if file.endswith(".wav")
    ]

    if not max_length or max_length <= 0:
        max_length = max((len(AudioSegment.from_file(wav_path)) for wav_path in wav_paths), default=0)
        print(f"Kein gültiges max_length in der Session. Verwende die längste Datei: {max_length} ms")

    print(f"Maximale Länge: {max_length / 1000} Sekunden")

    for input_path in wav_paths:
        audio = AudioSegment.from_file(input_path)
        if len(audio) > max_length:
            print(f"⚠️ {input_path.name} ist länger als {max_length / 1000} Sekunden. Kürze Datei!")
            audio = audio[:max_length]

        # Pad with silence up to the target; max(0, ...) keeps the pad non-negative.
        padded_audio = audio + AudioSegment.silent(duration=max(0, max_length - len(audio)))

        relative_path = input_path.parent.relative_to(input_dir)
        target_dir = output_dir / relative_path
        os.makedirs(target_dir, exist_ok=True)
        output_path = target_dir / input_path.name

        padded_audio.export(output_path, format="wav")
        print(f"Processed {input_path.name}: expanded to {max_length / 1000} seconds")

    print(f"Processing complete. Normalized files saved in {output_dir}")
    return output_dir

In [None]:
def process_audio_file(audio_file, input_dir, output_dir, n_mels, fmin, fmax):
    """
    Load one audio file and compute its mel spectrogram on a decibel scale.

    The folder that mirrors the file's location below ``output_dir`` is
    created as a side effect. The audio is loaded at a fixed 44100 Hz.

    :return: tuple ``(path relative to input_dir, dB mel spectrogram,
             sample rate, target folder, file name without suffix)``
    """
    relative_path = audio_file.relative_to(input_dir)

    # Mirror the input folder layout below output_dir.
    target_dir = output_dir / relative_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    samples, sample_rate = librosa.load(audio_file, sr=44100)
    mel_power = librosa.feature.melspectrogram(
        y=samples, sr=sample_rate, n_mels=n_mels, fmin=fmin, fmax=fmax
    )
    # Decibels relative to the loudest bin of this file.
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    return relative_path, mel_db, sample_rate, target_dir, audio_file.stem

def generate_mel_spectrograms_with_structure(input_dir, output_dir, n_mels=256, fmin=20, fmax=44100, batch_size=25, square = False):
    """
    Render a mel-spectrogram PNG for every .wav file below ``input_dir``.

    1. The spectrograms of one batch are computed in a thread pool.
    2. Plotting happens sequentially in the calling thread to avoid
       thread-safety problems in the plotting code.
    3. Files are handled in batches of ``batch_size`` to bound memory use.

    The PNGs are written below ``output_dir`` in the same sub-folder layout as
    the inputs, named ``<stem>_mel_spectrogram.png``. If ``output_dir`` already
    contains any PNG the whole run is skipped; a PNG that already exists is
    never overwritten. Plotting errors are reported per file and do not stop
    the run.

    NOTE(review): the default ``fmax=44100`` lies above the Nyquist frequency
    (22050 Hz) of the 44100 Hz audio loaded by ``process_audio_file``. It is
    left untouched because changing it changes the generated images.

    :param square: True renders a small 2x2 inch figure at 300 dpi, False a
                   wide strip whose width is 30 times its height
    :return: None
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    if output_dir.exists() and any(output_dir.rglob("*.png")):
        print(f"Überspringe Verarbeitung, da {output_dir} bereits Mel-Spektrogramme enthält.")
        return

    audio_files = list(input_dir.rglob("*.wav"))

    if not audio_files:
        print("Keine Audiodateien gefunden.")
        return

    total_files = len(audio_files)
    # Ceiling division: `total_files // batch_size + 1` reported one batch too
    # many whenever total_files was an exact multiple of batch_size.
    total_batches = (total_files + batch_size - 1) // batch_size
    print(f"{total_files} Audiodateien gefunden. Verarbeitung startet.")

    for batch_number, batch_start in enumerate(range(0, total_files, batch_size), start=1):
        batch_files = audio_files[batch_start:batch_start + batch_size]
        # Start every batch without leftover figures from the previous one.
        plt.close('all')
        gc.collect()
        print(f"Verarbeite Batch {batch_number} von {total_batches}")

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(process_audio_file, audio_file, input_dir, output_dir, n_mels, fmin, fmax)
                for audio_file in batch_files
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]

        for _relative_path, mel_spectrogram_db, sr, target_dir, audio_file_stem in results:
            mel_spectrogram_path = target_dir / f"{audio_file_stem}_mel_spectrogram.png"

            if mel_spectrogram_path.exists():
                print(f"Spektrogramm {mel_spectrogram_path} existiert bereits. Überspringen.")
                continue

            try:
                _save_mel_spectrogram_png(mel_spectrogram_db, sr, mel_spectrogram_path, fmin, fmax, square)
                print(f"{batch_number} von {total_batches}__Mel-Spektrogramm für {audio_file_stem} gespeichert in {mel_spectrogram_path}")
            except Exception as e:
                print(f"Fehler beim Plotten von {audio_file_stem}: {e}")

        # Release the spectrograms of this batch before the next one is computed.
        del results

    print(f"Alle Mel-Spektrogramme gespeichert in {output_dir}")


def _save_mel_spectrogram_png(mel_spectrogram_db, sr, png_path, fmin, fmax, square):
    """Plot one dB mel spectrogram without axes, save it as PNG and always close the figure."""
    if square:
        dpi = 300
        fig = plt.figure(figsize=(2, 2))
        try:
            librosa.display.specshow(mel_spectrogram_db, x_axis='time', y_axis='mel', sr=sr, cmap='magma', fmin=fmin, fmax=fmax)
            plt.axis('off')
            plt.savefig(png_path, bbox_inches='tight', pad_inches=0, dpi=dpi)
        finally:
            # Closing in `finally` keeps a failed plot from leaking its figure.
            plt.close(fig)
    else:
        # Pixel height of the strip; the module-level `downsize` flag selects the small variant.
        height = 333 if downsize else 2000
        width = height * 30
        dpi = 100
        figsize = (width / dpi, height / dpi)

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, frameon=False)
        try:
            librosa.display.specshow(mel_spectrogram_db, x_axis='time', y_axis='mel', sr=sr, cmap='magma', fmin=fmin, fmax=fmax)
            ax.set_axis_off()
            plt.savefig(png_path, bbox_inches='tight', pad_inches=0, dpi=dpi)
        finally:
            plt.close(fig)

In [None]:
def split_spectrogram(image_path, output_dir):
    """
    Cut a spectrogram image into square tiles and save each tile as a PNG.

    The tile side equals the image height. Tiles are indexed from the left
    edge, and only indices 1 .. (width // height) - 2 are written, i.e. the
    first tile and the last full tile are left out. A tile is written to the
    trash folder instead of ``output_dir`` when at least 70 % of its array
    values are below 10 or (RGBA images only) more than 20 % of its pixels
    are fully transparent. The trash folder is named after the parent of
    ``output_dir`` with a ``_trash`` suffix and lives next to that parent.

    :param image_path: path to the spectrogram (PNG)
    :param output_dir: folder for the kept tiles (str or Path); created if missing
    :return: None
    """
    # Coerce once so that the `.parent` arithmetic below also works for str input.
    output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    trash_dir = output_dir.parent.parent / f"{output_dir.parent.name}_trash"
    os.makedirs(trash_dir, exist_ok=True)

    image_stem = Path(image_path).stem
    silence_threshold = 10
    saved_segments = 0

    # The context manager closes the underlying file once all tiles are cut.
    with Image.open(image_path) as img:
        width, height = img.size
        segment_size = height

        num_segments = width // segment_size - 1
        for i in range(1, num_segments):
            left = i * segment_size
            right = left + segment_size
            segment = img.crop((left, 0, right, segment_size))
            segment_array = np.array(segment)

            # Share of all array values (every channel that is present) below the threshold.
            silence_ratio = np.mean(segment_array < silence_threshold)

            transparency_ratio = 0
            if segment.mode == 'RGBA':
                alpha_channel = segment_array[:, :, 3]
                transparency_ratio = np.mean(alpha_channel == 0)

            output_path = output_dir / f"{image_stem}_part{i}.png"
            if silence_ratio >= 0.7 or transparency_ratio > 0.2:
                output_path = trash_dir / f"{image_stem}_part{i}.png"
                print(f"Segment {i} enthält {silence_ratio * 100:.2f}% Stille und wird in den Trash {output_path} verschoben.")

            try:
                segment.save(output_path)
            except IOError:
                print(f"Fehler beim Speichern des Segments: {output_path}")
            else:
                saved_segments += 1

    # Report what was really written (kept and trashed tiles); the loop bound
    # itself is not the tile count and can even be negative for narrow images.
    print(f"Spektrogramm in {saved_segments} Segmente geschnitten und gespeichert in {output_dir}")


def process_spectrograms(input_dir):
    """
    Split every spectrogram PNG directly inside ``input_dir`` into tiles.

    Sub-folders of ``input_dir`` are not searched. Each spectrogram gets its
    own folder below ``<input_dir.name>_splits`` (created next to
    ``input_dir``), named after the PNG without the ``_mel_spectrogram``
    suffix. If the splits folder already exists nothing is processed. Images
    that fail the integrity check or the split are reported and skipped, and
    a per-spectrogram folder that ends up empty is removed again, also after
    such a failure.

    :param input_dir: folder containing the spectrograms
    :return: path of the splits folder
    """
    input_dir = Path(input_dir)
    output_dir = input_dir.parent / f"{input_dir.name}_splits"

    if output_dir.exists():
        print(f"Output directory already exists: {output_dir}. Skipping splitting.")
        return output_dir

    os.makedirs(output_dir, exist_ok=True)

    for image_path in input_dir.glob("*.png"):
        subdir_name = image_path.stem.replace("_mel_spectrogram", "")
        target_dir = output_dir / subdir_name
        os.makedirs(target_dir, exist_ok=True)

        try:
            # Integrity check only; the handle is closed right away and
            # split_spectrogram opens the file again on its own.
            with Image.open(image_path) as img:
                img.verify()
            split_spectrogram(image_path, target_dir)
        except (IOError, SyntaxError):
            print(f"Fehler: Beschädigtes oder ungültiges Bild übersprungen: {image_path}")

        # Runs after failures too, so broken images leave no empty folder behind.
        if not any(target_dir.iterdir()):
            os.rmdir(target_dir)
            print(f"Leerer Ordner {target_dir} wurde gelöscht.")

    return output_dir

In [None]:
def color_table(table, df=None):
    """Colour the cells of a matplotlib table in place and return the table.

    Header cells (row 0) become bold on grey. The remaining rows are only
    styled when ``df`` is given and the column index lies inside ``df``:
    column 0 gets a light blue background, columns 1 and 2 are green when the
    row's "True Class" equals its "Predicted Class" and red otherwise, and any
    further column is red for numeric text below 0.5, green for other numeric
    text and left uncoloured for non-numeric text.
    """
    header_grey = '#d3d3d3'
    label_blue = '#f0f8ff'
    good_green = '#d4edda'
    bad_red = '#f8d7da'

    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_fontsize(12)
            cell.set_text_props(weight='bold')
            cell.set_facecolor(header_grey)
            continue

        if df is None or not (row > 0 and col < len(df.columns)):
            continue

        if col == 0:
            cell.set_facecolor(label_blue)
        elif col in (1, 2):
            record = df.iloc[row - 1]
            is_match = record["True Class"] == record["Predicted Class"]
            cell.set_facecolor(good_green if is_match else bad_red)
        else:
            try:
                is_low = float(cell.get_text().get_text()) < 0.5
                cell.set_facecolor(bad_red if is_low else good_green)
            except ValueError:
                # Non-numeric cell text: keep the default background.
                pass

        cell.set_fontsize(10)

    return table



def generate_summary(results):
    """Summarise prediction results as formatted accuracy strings.

    Only the first whitespace-separated token of each "True Class" and
    "Predicted Class" value is compared, so values carrying a trailing
    probability such as "original (0.9312)" are matched by class name. An
    empty ``results`` list, or a class without any samples, yields 0 % instead
    of a division by zero.

    :param results: list of dicts with the keys 'True Class' and 'Predicted Class'
    :return: dict with the keys "Total Correct", "Original Accuracy",
             "Upscale Accuracy" and "Overall Accuracy"
    """
    def _label(text):
        # First token is the class name; a blank string maps to "".
        tokens = text.split()
        return tokens[0] if tokens else ""

    def _percent(hits, count):
        return hits / count * 100 if count > 0 else 0

    # Parse every result once instead of re-splitting it for each statistic.
    pairs = [(_label(result['True Class']), _label(result['Predicted Class'])) for result in results]

    total_predictions = len(pairs)
    correct_predictions = sum(1 for true_label, predicted_label in pairs if true_label == predicted_label)
    accuracy = _percent(correct_predictions, total_predictions)

    def _class_accuracy(class_name):
        predicted_for_class = [predicted_label for true_label, predicted_label in pairs if true_label == class_name]
        hits = sum(1 for predicted_label in predicted_for_class if predicted_label == class_name)
        return _percent(hits, len(predicted_for_class))

    original_accuracy = _class_accuracy('original')
    upscale_accuracy = _class_accuracy('upscale-from-mp3-128')

    summary = {
        "Total Correct": f"{correct_predictions} / {total_predictions} ({accuracy:.2f}%)",
        "Original Accuracy": f"{original_accuracy:.2f}%",
        "Upscale Accuracy": f"{upscale_accuracy:.2f}%",
        "Overall Accuracy": f"{accuracy:.2f}%",
    }

    return summary

def display_summary(summary):
    """Draw the summary dict as a two-column table figure and return the figure.

    ``summary`` must contain the keys "Total Correct", "Original Accuracy",
    "Upscale Accuracy" and "Overall Accuracy"; they become the table rows in
    that order. The header row is styled through ``color_table``.
    """
    row_labels = ("Total Correct", "Original Accuracy", "Upscale Accuracy", "Overall Accuracy")
    summary_df = pd.DataFrame(
        [[label, summary[label]] for label in row_labels],
        columns=["Parameter", "Value"],
    )

    fig, ax = plt.subplots(figsize=(8, 2))
    for axis_mode in ('tight', 'off'):
        ax.axis(axis_mode)

    table = ax.table(
        cellText=summary_df.values,
        colLabels=summary_df.columns,
        cellLoc="center",
        loc="center",
        colWidths=[0.5, 0.5],
    )
    color_table(table)

    plt.show()

    return fig

def calculate_accuracy(results):
    """
    Print how many entries were classified correctly and incorrectly.

    Expects a list of dicts with the keys 'True Class' and 'Predicted Class'.
    An entry counts as correct when both values are equal as a whole. For an
    empty list an accuracy of 0.00 % is printed instead of raising
    ZeroDivisionError.

    :return: None (the figures are printed only)
    """
    total_predictions = len(results)
    correct_predictions = sum(1 for result in results if result['True Class'] == result['Predicted Class'])
    # Guard the division so that an empty result list does not crash.
    accuracy = correct_predictions / total_predictions * 100 if total_predictions else 0.0
    print(f"Korrekte Vorhersagen: {correct_predictions}")
    print(f"Falsche Vorhersagen: {total_predictions - correct_predictions}")
    print(f"Genauigkeit: {accuracy:.2f}%")


def save_results(zip_file_path, table_figure, summary_figure):
    """Save both result figures as PNG files next to the model archive.

    The files are named ``<archive name without suffix>_table.png`` and
    ``..._summary.png``. A file that already exists is not overwritten; the
    skip is reported on stdout instead.
    """
    output_dir = os.path.dirname(zip_file_path)
    base_name, _suffix = os.path.splitext(os.path.basename(zip_file_path))

    exports = (
        ("Haupttabelle", table_figure, os.path.join(output_dir, f"{base_name}_table.png")),
        ("Zusammenfassungstabelle", summary_figure, os.path.join(output_dir, f"{base_name}_summary.png")),
    )

    for label, figure, image_path in exports:
        if os.path.exists(image_path):
            print(f"{label} übersprungen (bereits vorhanden): {image_path}")
            continue
        figure.savefig(image_path, bbox_inches="tight")
        print(f"{label} als Bild gespeichert: {image_path}")


def load_model_and_session(zip_file_path):
    """Load the model and the session data from a ZIP archive.

    The archive is unpacked into the folder ``restored_model`` relative to the
    current working directory. ``model.h5`` from that folder is loaded with
    Keras; ``session_data.json`` is parsed when present. A missing or
    unreadable session file is reported on stdout and yields an empty dict.

    :return: tuple ``(loaded model, session data dict)``
    """
    extract_path = "restored_model"

    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

    loaded_model = tf.keras.models.load_model(f"{extract_path}/model.h5")

    # Stays empty unless the JSON file exists and parses cleanly.
    session_data = {}
    session_data_path = f"{extract_path}/session_data.json"

    if not os.path.exists(session_data_path):
        print("Warnung: Die Datei existiert nicht.")
        return loaded_model, session_data

    try:
        with open(session_data_path, "r", encoding="utf-8") as json_file:
            session_data = json.load(json_file)
    except json.JSONDecodeError:
        print("Fehler: Die JSON-Datei ist fehlerhaft oder beschädigt.")
    except Exception as e:
        print(f"Ein unerwarteter Fehler ist aufgetreten: {e}")

    return loaded_model, session_data



def predict_with_model(model, splits_dir):
    """
    Predict every spectrogram tile of one track folder and average the results.

    Each PNG directly inside ``splits_dir`` is converted to RGB, resized
    (to 256x256 when the module-level ``downsize`` flag is set, otherwise to
    the height and width taken from ``model.input_shape``) and passed to
    ``model.predict`` as a float32 batch of one image. Tiles are visited in
    sorted name order so that the log and the averaging order are reproducible.

    :param model: loaded model; must provide ``input_shape`` with four entries
                  and ``predict``
    :param splits_dir: folder containing the tiles of one track
    :return: element-wise mean of the per-tile predictions
    :raises ValueError: if ``splits_dir`` is not an existing directory or
                        contains no PNG file
    """
    splits_dir = Path(splits_dir)
    if not splits_dir.is_dir():
        raise ValueError(f"Das Verzeichnis {splits_dir} existiert nicht oder ist kein Verzeichnis.")

    _, target_height, target_width, _ = model.input_shape
    print(f"Erwartete Bildgröße: {target_height} x {target_width}")

    # The resize call takes (width, height).
    target_size = (256, 256) if downsize else (target_width, target_height)

    predictions = []

    for image_path in sorted(splits_dir.glob("*.png")):
        # The context manager closes the file handle after the pixels were read.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB").resize(target_size)

        # Build the float32 batch explicitly. The earlier float32 tensor
        # conversion was overwritten on the next line and never reached the model.
        image_array = np.asarray(image, dtype=np.float32)
        input_data = np.expand_dims(image_array, axis=0)

        pred = model.predict(input_data)[0]
        print(f"Vorhersage für {image_path}: {pred}")
        predictions.append(pred)

    if not predictions:
        raise ValueError(f"Keine validen Spektrogramm-Splits in {splits_dir} gefunden.")

    avg_prediction = np.mean(predictions, axis=0)
    print(f"Durchschnittliche Vorhersage für {splits_dir.name}: {avg_prediction}")

    return avg_prediction


def run(zip_file_path = None, _downsize = False):
    """
    Run the complete evaluation: prepare the test data, predict every track
    and save the result table and the summary as images.

    Steps: unpack the test archive, load model and session data from the
    model archive, normalize the audio length, render and split the mel
    spectrograms, predict each track folder, then draw and save the per-track
    table and the summary next to the model archive. The module-level
    ``session`` and ``downsize`` values are overwritten. The temporary
    ``restored_model`` folder is removed at the end, also when a step fails.

    :param zip_file_path: model archive; defaults to the module-level ``_zip_file_path``
    :param _downsize: stored in the module-level ``downsize`` flag read by the
                      rendering and prediction steps
    :raises ValueError: if a track folder name matches no known class, or if
                        no track folder could be evaluated
    """
    global session, downsize

    extract_zip(TEST_DIR, UNZIP_DIR)
    downsize = _downsize
    if zip_file_path is None:
        zip_file_path = _zip_file_path

    try:
        model, session = load_model_and_session(zip_file_path)
        print(f"Model loaded from: {zip_file_path}")
        print (f"Session : {session}")

        path = normalize_audio_length(Path(TEST_DIR))
        test_mel_dir = Path(f"{path.stem}_mel_spectrograms" + ("_downsize" if downsize else ""))
        generate_mel_spectrograms_with_structure(path,test_mel_dir)
        test_mel_dir = process_spectrograms(test_mel_dir)

        results = []
        class_names = ["original", "upscale-from-mp3-128"]

        # Sorted so that the table rows come out in the same order on every run.
        for track_dir in sorted(Path(test_mel_dir).iterdir()):
            if not track_dir.is_dir():
                continue
            if not any(track_dir.iterdir()):
                print(f"Überspringe leeres Verzeichnis: {track_dir}")
                continue

            true_class = _true_class_from_dir_name(track_dir.name)

            predictions = predict_with_model(model, track_dir)
            predicted_class = class_names[np.argmax(predictions)]
            class_probabilities = dict(zip(class_names, predictions))

            # A true class the model has no output for is shown with probability 0.
            true_class_with_prob = f"{true_class} ({class_probabilities.get(true_class, 0.0):.4f})"
            predicted_class_with_prob = f"{predicted_class} ({class_probabilities[predicted_class]:.4f})"

            results.append({
                "Track_Dir": track_dir.name,
                "True Class": true_class_with_prob,
                "Predicted Class": predicted_class_with_prob
            })

        if not results:
            # Without rows the table figure would get height 0 and the summary
            # would divide by zero; fail with a clear message instead.
            raise ValueError(f"Keine auswertbaren Track-Ordner in {test_mel_dir} gefunden.")

        df = pd.DataFrame(results)

        fig, ax = plt.subplots(figsize=(16, len(df) * 0.4))
        ax.axis('tight')
        ax.axis('off')

        cell_text = [
            [str(value) if isinstance(value, str) else f"{value:.4f}" for value in row]
            for row in df.values
        ]

        table = ax.table(
            cellText=cell_text,
            colLabels=df.columns,
            cellLoc="center",
            loc="center",
            colWidths=[0.5] + [0.25] * (len(df.columns) - 1)
        )

        # Left-align the track names in the first column.
        for row in range(len(df)):
            cell = table[(row + 1, 0)]
            cell.get_text().set_ha('left')

        table = color_table(table, df)

        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.2)
        plt.show()

        print(results)
        summary = generate_summary(results)
        summary_figure = display_summary(summary)

        save_results(zip_file_path, fig, summary_figure)
    finally:
        # Always remove the unpacked model so a failed run leaves no stale files.
        shutil.rmtree("restored_model", ignore_errors=True)


def _true_class_from_dir_name(dir_name):
    """Map a track folder name to its true class via the marker substring it contains."""
    markers = (
        ("orig-16-44-mono", "original"),
        ("upscale-from-mp3-128", "upscale-from-mp3-128"),
        ("upscale-from-aac-128", "upscale-from-aac-128"),
    )
    for marker, label in markers:
        if marker in dir_name:
            return label
    raise ValueError(f"Unbekannte Klasse im Ordnernamen: {dir_name}")


# Execute the pipeline with the defaults: the archive named in _zip_file_path
# and _downsize=False.
run()