<a href="https://colab.research.google.com/github/roccaab/WaveletGAN/blob/main/mainGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install PyWavelets
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import pywt
import matplotlib.pyplot as plt

# Signal normalization helper.
def normalize_signal(signal):
    """Return ``signal`` standardized to zero mean and unit standard deviation.

    Mean and standard deviation are taken over all elements (NumPy defaults).

    Args:
        signal: array-like of numbers.

    Returns:
        A float ndarray with the same shape as ``signal``. A constant signal
        (standard deviation 0) is returned centred but unscaled, i.e. all
        zeros, instead of NaN/inf from a division by zero.
    """
    values = np.asarray(signal, dtype=float)
    centred = values - np.mean(values)
    spread = np.std(values)
    if spread == 0:
        # Nothing to scale: avoid 0/0 -> NaN.
        return centred
    return centred / spread

# Multi-level discrete wavelet decomposition helper.
def wavelet_decompose(signal, wavelet='db4', level=6):
    """Decompose ``signal`` into wavelet coefficients with ``pywt.wavedec``.

    Args:
        signal: 1-D array-like signal.
        wavelet: PyWavelets wavelet name.
        level: number of decomposition levels.

    Returns:
        The coefficient list produced by ``pywt.wavedec`` (documented by
        PyWavelets as ``[cA_level, cD_level, ..., cD_1]``).
    """
    return pywt.wavedec(signal, wavelet=wavelet, level=level)

# Generator network construction.
def build_generator(input_dim):
    """Build the MLP generator.

    Args:
        input_dim: length of both the input noise vector and the output signal.

    Returns:
        An uncompiled ``tf.keras.Sequential`` mapping ``(batch, input_dim)`` to
        ``(batch, input_dim)``.
    """
    return tf.keras.Sequential([
        # Explicit Input layer instead of the legacy ``input_dim=`` layer
        # kwarg, which newer Keras versions warn about.
        layers.Input(shape=(input_dim,)),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        # NOTE(review): tanh bounds every output sample to [-1, 1], while
        # normalize_signal produces z-scored data that can exceed that range;
        # confirm this output activation is intended.
        layers.Dense(input_dim, activation='tanh'),
    ])

# Discriminator for a single wavelet component.
def build_discriminator(input_dim):
    """Build an MLP that scores a length-``input_dim`` vector as real or fake.

    Training elsewhere in this notebook labels real samples 1 and generated
    samples 0.

    Args:
        input_dim: length of the input vector (one wavelet coefficient array).

    Returns:
        An uncompiled ``tf.keras.Sequential`` with a single sigmoid output.
    """
    return tf.keras.Sequential([
        # Explicit Input layer instead of the legacy ``input_dim=`` layer
        # kwarg, which newer Keras versions warn about.
        layers.Input(shape=(input_dim,)),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid'),
    ])

# GAN configuration: one discriminator per wavelet band of the signal.


def _dwt_lengths(n, filter_len, level):
    """Return approximation lengths ``[n_0, ..., n_level]`` of a multi-level DWT.

    ``n_0 = n`` and ``n_{k+1} = (n_k + filter_len - 1) // 2``, which is the
    output-length rule implemented by ``_tf_wavedec``.
    """
    lengths = [n]
    for _ in range(level):
        lengths.append((lengths[-1] + filter_len - 1) // 2)
    return lengths


def _tf_wavedec(x, dec_lo, dec_hi, level):
    """Differentiable multi-level 1-D discrete wavelet transform in TensorFlow.

    Args:
        x: float32 tensor of shape ``(batch, n)``.
        dec_lo, dec_hi: decomposition low-/high-pass filter taps
            (e.g. ``pywt.Wavelet('db4').dec_lo`` / ``.dec_hi``).
        level: number of decomposition levels; 0 returns ``[x]``.

    Returns:
        List ``[cA_level, cD_level, ..., cD_1]`` of ``(batch, n_k)`` tensors
        (same ordering as ``pywt.wavedec``).

    Each level half-sample-symmetrically extends the signal, convolves it with
    the filters and keeps the odd-indexed outputs, intended to mirror pywt's
    default 'symmetric' mode. NOTE(review): exact numerical agreement with
    ``pywt.wavedec`` is not verified here; check with np.allclose if needed.
    Real and generated signals both go through this same function, so the
    discriminators always see consistently computed coefficients.
    """
    filter_len = len(dec_lo)

    def _taps(filt):
        # tf.nn.conv1d is a cross-correlation, so reverse the taps to convolve.
        reversed_taps = np.asarray(filt, dtype=np.float32)[::-1].copy()
        return tf.constant(reversed_taps.reshape(filter_len, 1, 1))

    lo, hi = _taps(dec_lo), _taps(dec_hi)
    pad = [[0, 0], [filter_len - 1, filter_len - 1], [0, 0]]

    approx = x
    details = []
    for _ in range(level):
        extended = tf.pad(approx[:, :, tf.newaxis], pad, mode='SYMMETRIC')
        # Dropping the first sample makes the stride-2 taps land on the odd
        # indices of the full convolution: output length (n + F - 1) // 2.
        extended = extended[:, 1:, :]
        details.append(tf.nn.conv1d(extended, hi, stride=2, padding='VALID')[:, :, 0])
        approx = tf.nn.conv1d(extended, lo, stride=2, padding='VALID')[:, :, 0]
    return [approx] + details[::-1]


class _WaveletSplit(layers.Layer):
    """Keras layer applying ``_tf_wavedec`` so the DWT can sit inside a model."""

    def __init__(self, dec_lo, dec_hi, num_levels, **kwargs):
        super().__init__(**kwargs)
        # Fixed (non-trainable) filter taps.
        self._dec_lo = np.asarray(dec_lo, dtype=np.float32)
        self._dec_hi = np.asarray(dec_hi, dtype=np.float32)
        self._num_levels = num_levels

    def call(self, inputs):
        return _tf_wavedec(inputs, self._dec_lo, self._dec_hi, self._num_levels)


class WaveletGAN:
    """Multi-discriminator GAN that judges generated signals band by band.

    Real and generated signals are both split with ``_tf_wavedec`` into
    ``num_discriminators`` coefficient arrays, i.e. a
    ``num_discriminators - 1``-level DWT giving ``[cA_L, cD_L, ..., cD_1]``.
    Discriminator ``i`` is sized for, and only ever sees, coefficient array
    ``i``. Because the split runs in TensorFlow, generator gradients flow
    through it.

    Attributes:
        generator: noise ``(batch, input_dim)`` -> signal ``(batch, input_dim)``.
        discriminators: one binary classifier per wavelet component.
        component_dims: input length of each discriminator.
        gan_optimizers: one Adam optimizer per discriminator.
        g_optimizer: Adam optimizer for the generator.
        combined: functional model mapping noise to the list of discriminator
            scores on the generated signal's components. It is not compiled;
            use ``train`` to fit the GAN.
    """

    def __init__(self, input_dim, num_discriminators=6, wavelet='db4'):
        """Build generator, per-band discriminators and optimizers.

        Args:
            input_dim: signal length (also the noise length).
            num_discriminators: number of wavelet components/discriminators (>= 1).
            wavelet: PyWavelets wavelet name used for the decomposition.

        Raises:
            ValueError: if ``num_discriminators < 1`` or ``input_dim`` is too
                short for the required number of decomposition levels.
        """
        if num_discriminators < 1:
            raise ValueError("num_discriminators must be at least 1")
        self.input_dim = input_dim
        self.num_discriminators = num_discriminators
        self.wavelet = wavelet
        self.level = num_discriminators - 1

        filters = pywt.Wavelet(wavelet)
        filter_len = len(filters.dec_lo)
        lengths = _dwt_lengths(input_dim, filter_len, self.level)
        # tf.pad(mode='SYMMETRIC') cannot pad by more than the current length.
        if any(n < filter_len - 1 for n in lengths[:-1]):
            raise ValueError(
                f"input_dim={input_dim} is too short for a {self.level}-level "
                f"'{wavelet}' decomposition; use fewer discriminators"
            )
        # Same order as _tf_wavedec output: [cA_L, cD_L, cD_{L-1}, ..., cD_1].
        self.component_dims = [lengths[-1]] + lengths[:0:-1]
        self.splitter = _WaveletSplit(filters.dec_lo, filters.dec_hi, self.level)

        self.generator = build_generator(input_dim)
        self.discriminators = [build_discriminator(dim) for dim in self.component_dims]

        self.gan_optimizers = [
            tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
            for _ in range(num_discriminators)
        ]
        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.loss = tf.keras.losses.BinaryCrossentropy()

        self.compile_models()

    def compile_models(self):
        """Compile the discriminators and assemble the ``combined`` model.

        Discriminators are compiled only so ``evaluate``/metrics keep working;
        training is done with explicit gradient tapes in ``train``, so no
        ``trainable = False`` freezing is needed.
        """
        for discriminator in self.discriminators:
            discriminator.compile(optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
                                  loss=self.loss, metrics=['accuracy'])

        inputs = layers.Input(shape=(self.input_dim,))
        components = self.splitter(self.generator(inputs))
        outputs = [disc(comp) for disc, comp in zip(self.discriminators, components)]
        self.combined = Model(inputs, outputs)

    def _train_discriminators(self, real_signals, noise, valid, fake):
        """Take one optimizer step per discriminator on its band; return mean loss."""
        # Generated outside any tape, so no gradient reaches the generator here.
        generated = self.generator(noise, training=True)
        real_parts = self.splitter(real_signals)
        fake_parts = self.splitter(generated)

        losses = []
        for disc, optimizer, real_part, fake_part in zip(
                self.discriminators, self.gan_optimizers, real_parts, fake_parts):
            with tf.GradientTape() as tape:
                loss_real = self.loss(valid, disc(real_part, training=True))
                loss_fake = self.loss(fake, disc(fake_part, training=True))
                loss = 0.5 * (loss_real + loss_fake)
            grads = tape.gradient(loss, disc.trainable_variables)
            optimizer.apply_gradients(zip(grads, disc.trainable_variables))
            losses.append(float(loss))
        return float(np.mean(losses))

    def _train_generator(self, noise, valid):
        """Take one generator step against all discriminators; return summed loss."""
        with tf.GradientTape() as tape:
            scores = self.combined(noise, training=True)
            if not isinstance(scores, (list, tuple)):
                # A single-output functional model may return a bare tensor.
                scores = [scores]
            # Sum of per-band losses, like a multi-output Keras compile with one loss.
            loss = tf.add_n([self.loss(valid, score) for score in scores])
        variables = self.generator.trainable_variables
        grads = tape.gradient(loss, variables)
        self.g_optimizer.apply_gradients(zip(grads, variables))
        return float(loss)

    def train(self, signals, epochs=10000, batch_size=64):
        """Train the GAN for ``epochs`` iterations of one random batch each.

        Args:
            signals: array of shape ``(num_signals, input_dim)``.
            epochs: number of training iterations.
            batch_size: samples per iteration (drawn with replacement).

        Raises:
            ValueError: if ``signals`` does not have shape ``(n, input_dim)``.
        """
        signals = np.asarray(signals, dtype=np.float32)
        if signals.ndim != 2 or signals.shape[1] != self.input_dim:
            raise ValueError(
                f"signals must have shape (num_signals, {self.input_dim}), got {signals.shape}"
            )

        valid = tf.ones((batch_size, 1))
        fake = tf.zeros((batch_size, 1))

        for epoch in range(epochs):
            idx = np.random.randint(0, signals.shape[0], batch_size)
            real_signals = tf.convert_to_tensor(signals[idx])

            noise = np.random.normal(0, 1, (batch_size, self.input_dim)).astype(np.float32)
            d_loss = self._train_discriminators(real_signals, noise, valid, fake)

            noise = np.random.normal(0, 1, (batch_size, self.input_dim)).astype(np.float32)
            g_loss = self._train_generator(noise, valid)

            if epoch % 100 == 0:
                print(f"Epoch {epoch}/{epochs} | D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}")

# Example: load the data, train the GAN, and show the network summaries.
# NOTE(review): each BHE*.csv is assumed to hold one 1-D signal; confirm.
file_names = [f"BHE{i:04d}.csv" for i in range(1, 11)]

raw_signals = [np.loadtxt(name, delimiter=',') for name in file_names]

# np.array() cannot stack recordings of different lengths (recent NumPy raises
# on ragged input), so trim every recording to the shortest one, then
# normalize the trimmed signal so its statistics match what is trained on.
common_length = min(len(raw) for raw in raw_signals)
if any(len(raw) != common_length for raw in raw_signals):
    print(f"Truncating all signals to {common_length} samples")
data = np.stack([normalize_signal(raw[:common_length]) for raw in raw_signals])
input_dim = data.shape[1]

# Build and train the GAN.
wavelet_gan = WaveletGAN(input_dim)
wavelet_gan.train(data, epochs=1000, batch_size=32)

# Print the network architectures.
wavelet_gan.generator.summary()
for i, discriminator in enumerate(wavelet_gan.discriminators):
    print(f"Discriminatore {i+1}:")
    discriminator.summary()


Collecting PyWavelets
  Downloading pywavelets-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.0 kB)
Downloading pywavelets-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyWavelets
Successfully installed PyWavelets-1.8.0


FileNotFoundError: BHE0001.csv not found.