In [4]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from scipy.io import wavfile
from sklearn.preprocessing import MinMaxScaler

class AudioDataset:
    def __init__(self, audio_dir, annotation_dir):
        self.audio_dir = audio_dir
        self.annotation_dir = annotation_dir
        self.audio_files = self._get_matched_files()
    
    def _get_matched_files(self):
        """
        Find audio files with corresponding annotation files
        """
        matched_files = []
        for audio_file in os.listdir(self.audio_dir):
            base_name = os.path.splitext(audio_file)[0]
            annotation_path = os.path.join(self.annotation_dir, base_name + '.txt')
            
            if os.path.exists(annotation_path):
                matched_files.append(os.path.join(self.audio_dir, audio_file))
        
        return matched_files
    
    def load_audio(self, file_path, target_sr=16000, max_duration=3):
        """
        Load and preprocess audio file
        """
        waveform, sample_rate = torchaudio.load(file_path)
        
        # Resample if necessary
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)
        
        # Trim or pad to fixed length
        max_length = max_duration * target_sr
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]
        elif waveform.shape[1] < max_length:
            pad_length = max_length - waveform.shape[1]
            waveform = nn.functional.pad(waveform, (0, pad_length))
        
        return waveform.squeeze()
    
    def prepare_dataset(self):
        """
        Prepare normalized audio data
        """
        audio_data = []
        for file in self.audio_files:
            audio = self.load_audio(file)
            audio_data.append(audio.numpy())
        
        # Normalize
        scaler = MinMaxScaler(feature_range=(-1, 1))
        normalized_data = scaler.fit_transform(np.array(audio_data).reshape(-1, 1)).reshape(-1)
        
        return torch.tensor(normalized_data, dtype=torch.float32)

class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, output_dim),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

def train_gan(dataset, latent_dim=100, epochs=200, batch_size=64):
    """
    Train GAN for synthetic audio generation
    """
    generator = Generator(latent_dim, dataset.shape[1])
    discriminator = Discriminator(dataset.shape[1])
    
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    criterion = nn.BCELoss()
    
    for epoch in range(epochs):
        # Train discriminator
        d_optimizer.zero_grad()
        real_data = dataset
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z)
        
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        
        real_loss = criterion(discriminator(real_data), real_labels)
        fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        
        d_loss.backward()
        d_optimizer.step()
        
        # Train generator
        g_optimizer.zero_grad()
        z = torch.randn(batch_size, latent_dim)
        fake_data = generator(z)
        g_loss = criterion(discriminator(fake_data), real_labels)
        
        g_loss.backward()
        g_optimizer.step()
    
    return generator

def save_generated_audio(generator, output_dir, num_samples=10, sample_rate=16000):
    """
    Generate and save synthetic audio samples
    """
    os.makedirs(output_dir, exist_ok=True)
    
    with torch.no_grad():
        z = torch.randn(num_samples, 100)
        generated_audio = generator(z).numpy()
        
        for i, audio in enumerate(generated_audio):
            wavfile.write(
                os.path.join(output_dir, f'synthetic_audio_{i}.wav'), 
                sample_rate, 
                audio
            )

def main(audio_dir, annotation_dir, output_dir):
    dataset = AudioDataset(audio_dir, annotation_dir)
    processed_data = dataset.prepare_dataset()
    
    generator = train_gan(processed_data)
    save_generated_audio(generator, output_dir)

if __name__ == "__main__":
    main('./cleaned_wav_files', 
         './JSON', 
         './dummy_audios')

ValueError: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required by MinMaxScaler.

In [3]:
os.listdir('./cleaned_wav_files')

['Voice 007_sd.wav',
 'R1.wav',
 'Voice 004.wav',
 'omg dj.wav',
 'practice .wav',
 'Voice 006.wav',
 'Voice 001_sd (1).wav',
 'Sep 25, 1.56 sa.wav',
 'R3.wav',
 'R2.wav',
 'Voice 007.wav',
 'hjjjhhggh.wav',
 'tamanna.wav',
 'Voice 003.wav',
 'Oct 8, 1.45 PM.wav',
 'R7.wav',
 'Voice 002.wav',
 'cricket toppic.wav',
 'Sep 26, 1.30 PM.wav',
 'Voice 009_sd.wav',
 'R5.wav',
 'R4.wav',
 'Voice 001.wav',
 'Voice 001 (2).wav',
 'Account.wav',
 'kerala.wav',
 'Voice 001_sd.wav',
 'R10.wav',
 'REC005.wav',
 'Voice 002 (2).wav',
 'progga.wav',
 'Oct 20, 12.28 PM.wav',
 'English 1st conversation.wav',
 'SS practice.wav',
 '2024-02-27 21-05-26.wav',
 'Voice 002 (1).wav',
 'keralas calture.wav',
 'New Recording 6.wav',
 'Recode.wav',
 'Voice 008_sd.wav',
 'Mirpur Road.wav',
 'summery.wav',
 'Oct hhjh.wav',
 'R9.wav',
 'the scientist.wav',
 'R8.wav',
 'Voice 001 (1).wav',
 'sanjid meyad.wav',
 'Oct 7, 9.34 AM.wav',
 'Recording_2.wav',
 'Sep 23, 9.24 AM.wav',
 'Voice 010_sd.wav',
 'Recording_1.wav',


In [7]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from scipy.io import wavfile
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

class AudioDataset:
    def __init__(self, audio_dir, annotation_dir):
        self.audio_dir = audio_dir
        self.annotation_dir = annotation_dir
        self.audio_files = self._get_matched_files()
        
        if not self.audio_files:
            raise ValueError(f"No matched audio files found in {audio_dir}")
    
    def _get_matched_files(self):
        """Find audio files with corresponding annotation files"""
        matched_files = []
        for audio_file in os.listdir(self.audio_dir):
            base_name = os.path.splitext(audio_file)[0]
            annotation_path = os.path.join(self.annotation_dir, base_name + '.json')
            
            if os.path.exists(annotation_path):
                matched_files.append(os.path.join(self.audio_dir, audio_file))
        
        return matched_files
    
    def load_audio(self, file_path, target_sr=16000, max_duration=3):
        """Load and preprocess individual audio file"""
        try:
            waveform, sample_rate = torchaudio.load(file_path)
            
            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            
            # Resample if necessary
            if sample_rate != target_sr:
                resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
                waveform = resampler(waveform)
            
            # Trim or pad to fixed length
            max_length = max_duration * target_sr
            if waveform.shape[1] > max_length:
                waveform = waveform[:, :max_length]
            elif waveform.shape[1] < max_length:
                pad_length = max_length - waveform.shape[1]
                waveform = nn.functional.pad(waveform, (0, pad_length))
            
            return waveform.squeeze()
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None
    
    def prepare_dataset(self):
        """Prepare normalized audio data"""
        audio_data = []
        for file in self.audio_files:
            audio = self.load_audio(file)
            if audio is not None:
                audio_data.append(audio.numpy())
        
        if not audio_data:
            raise ValueError("No valid audio files could be loaded")
        
        # Reshape and normalize
        audio_data = np.array(audio_data)
        audio_data = audio_data.reshape(audio_data.shape[0], -1)
        
        scaler = MinMaxScaler(feature_range=(-1, 1))
        normalized_data = scaler.fit_transform(audio_data)
        
        return {
            'data': torch.tensor(normalized_data, dtype=torch.float32),
            'scaler': scaler
        }

class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, output_dim),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

def train_gan(dataset, latent_dim=100, epochs=200, batch_size=64):
    """Train GAN for synthetic audio generation"""
    # Split data
    X_train, X_test = train_test_split(dataset['data'], test_size=0.2)
    
    generator = Generator(latent_dim, X_train.shape[1])
    discriminator = Discriminator(X_train.shape[1])
    
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    criterion = nn.BCELoss()
    
    for epoch in range(epochs):
        for i in range(0, len(X_train), batch_size):
            batch = X_train[i:i+batch_size]
            
            # Train discriminator
            d_optimizer.zero_grad()
            real_data = batch
            z = torch.randn(len(batch), latent_dim)
            fake_data = generator(z)
            
            real_labels = torch.ones(len(batch), 1)
            fake_labels = torch.zeros(len(batch), 1)
            
            real_loss = criterion(discriminator(real_data), real_labels)
            fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
            d_loss = real_loss + fake_loss
            
            d_loss.backward()
            d_optimizer.step()
            
            # Train generator
            g_optimizer.zero_grad()
            z = torch.randn(len(batch), latent_dim)
            fake_data = generator(z)
            g_loss = criterion(discriminator(fake_data), torch.ones(len(batch), 1))
            
            g_loss.backward()
            g_optimizer.step()
        
        # Optional: Print epoch loss
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: D Loss = {d_loss.item()}, G Loss = {g_loss.item()}")
    
    return generator, dataset['scaler']

def save_generated_audio(generator, scaler, output_dir, num_samples=10, sample_rate=16000):
    """Generate and save synthetic audio samples"""
    os.makedirs(output_dir, exist_ok=True)
    
    with torch.no_grad():
        z = torch.randn(num_samples, 100)
        generated_audio = generator(z).numpy()
        
        # Inverse transform to original scale
        generated_audio = scaler.inverse_transform(generated_audio)
        
        for i, audio in enumerate(generated_audio):
            # Normalize to int16 range
            audio = audio.reshape(-1)
            audio = (audio / np.max(np.abs(audio)) * 32767).astype(np.int16)
            
            wavfile.write(
                os.path.join(output_dir, f'synthetic_audio_{i}.wav'), 
                sample_rate, 
                audio
            )

def main(audio_dir, annotation_dir, output_dir):
    # Initialize dataset
    dataset = AudioDataset(audio_dir, annotation_dir)
    
    # Prepare data
    processed_data = dataset.prepare_dataset()
    
    # Train GAN
    generator, scaler = train_gan(processed_data)
    
    # Generate and save synthetic audio
    save_generated_audio(generator, scaler, output_dir)

if __name__ == "__main__":
    main('./cleaned_wav_files', 
         './JSON', 
         './dummy_audios')

Epoch 0: D Loss = 1.390404462814331, G Loss = 1.435276746749878
Epoch 10: D Loss = 1.5485504865646362, G Loss = 1.465993046760559
Epoch 20: D Loss = 0.6792703866958618, G Loss = 2.950965642929077
Epoch 30: D Loss = 0.19791191816329956, G Loss = 4.28369140625
Epoch 40: D Loss = 1.5522111654281616, G Loss = 1.8755478858947754
Epoch 50: D Loss = 1.364415168762207, G Loss = 0.8510583639144897
Epoch 60: D Loss = 0.8493151068687439, G Loss = 1.168866515159607
Epoch 70: D Loss = 0.6472721695899963, G Loss = 2.6100823879241943
Epoch 80: D Loss = 0.27620455622673035, G Loss = 2.524670362472534
Epoch 90: D Loss = 0.2219024896621704, G Loss = 2.7784674167633057
Epoch 100: D Loss = 0.3148179054260254, G Loss = 2.6815073490142822
Epoch 110: D Loss = 0.3387452960014343, G Loss = 2.319704294204712
Epoch 120: D Loss = 0.7214686274528503, G Loss = 3.800121784210205
Epoch 130: D Loss = 0.44992297887802124, G Loss = 1.609717845916748
Epoch 140: D Loss = 0.20898646116256714, G Loss = 3.567080020904541
Epo

In [1]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import noisereduce as nr
from scipy.io import wavfile
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import librosa
import soundfile as sf

class AudioProcessor:
    def __init__(self, sr=16000):
        self.sr = sr
        self.pause_detector = self._initialize_pause_detector()
    
    def _initialize_pause_detector(self):
        keywords = ['aaah', 'mmm', 'hmm', 'uhh', 'umm', 'err']
        return {word: librosa.sequence.dtw for word in keywords}
    
    def reduce_noise(self, audio):
        reduced_noise = nr.reduce_noise(
            y=audio,
            sr=self.sr,
            prop_decrease=0.95,
            n_fft=2048,
            win_length=2048,
            hop_length=512
        )
        return reduced_noise
    
    def detect_pauses(self, audio):
        mfcc = librosa.feature.mfcc(y=audio, sr=self.sr, n_mfcc=13)
        segments = librosa.effects.split(audio, top_db=20)
        
        pauses = []
        for start, end in segments:
            segment = audio[start:end]
            segment_mfcc = librosa.feature.mfcc(y=segment, sr=self.sr, n_mfcc=13)
            
            is_filled_pause = False
            for keyword, dtw in self.pause_detector.items():
                if len(segment) > self.sr * 0.1:  # Longer than 100ms
                    is_filled_pause = True
                    break
            
            pauses.append({
                'start': float(start / self.sr),
                'end': float(end / self.sr),
                'type': 'filled_pause' if is_filled_pause else 'non_filled_pause'
            })
        
        return pauses

class AudioDataset:
    def __init__(self, audio_dir, annotation_dir, target_sr=16000, max_duration=3):
        self.audio_dir = audio_dir
        self.annotation_dir = annotation_dir
        self.target_sr = target_sr
        self.max_duration = max_duration
        self.audio_processor = AudioProcessor(sr=target_sr)
        self.audio_files = self._get_matched_files()
    
    def load_and_process_audio(self, file_path):
        try:
            waveform, sr = torchaudio.load(file_path)
            
            # Convert to mono
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            
            # Resample if needed
            if sr != self.target_sr:
                waveform = torchaudio.transforms.Resample(sr, self.target_sr)(waveform)
            
            # Convert to numpy for noise reduction
            audio_np = waveform.squeeze().numpy()
            cleaned_audio = self.audio_processor.reduce_noise(audio_np)
            
            # Process fixed length
            max_length = self.max_duration * self.target_sr
            if len(cleaned_audio) > max_length:
                cleaned_audio = cleaned_audio[:max_length]
            else:
                pad_length = max_length - len(cleaned_audio)
                cleaned_audio = np.pad(cleaned_audio, (0, pad_length))
            
            return cleaned_audio
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return None
    
    def prepare_dataset(self, batch_size=32):
        processed_data = []
        annotations = []
        
        for file in self.audio_files:
            if audio := self.load_and_process_audio(file):
                processed_data.append(audio)
                pauses = self.audio_processor.detect_pauses(audio)
                annotations.append(pauses)
        
        if not processed_data:
            raise ValueError("No valid audio files could be processed")
        
        processed_data = np.array(processed_data)
        scaler = MinMaxScaler(feature_range=(-1, 1))
        normalized_data = scaler.fit_transform(processed_data.reshape(len(processed_data), -1))
        
        return self._create_dataloaders(normalized_data, annotations, batch_size), scaler

def save_generated_audio_with_annotations(generator, scaler, output_dir, num_samples=10, 
                                        sample_rate=16000, device='cuda'):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'annotations'), exist_ok=True)
    
    audio_processor = AudioProcessor(sr=sample_rate)
    
    with torch.no_grad():
        z = torch.randn(num_samples, generator.model[0].in_features, device=device)
        generated_audio = generator(z).cpu().numpy()
        generated_audio = scaler.inverse_transform(generated_audio)
        
        for i, audio in enumerate(generated_audio):
            # Save audio
            audio = audio.reshape(-1)
            audio = np.clip(audio, -1, 1)
            audio_int16 = (audio * 32767).astype(np.int16)
            
            audio_path = os.path.join(output_dir, f'synthetic_audio_{i}.wav')
            wavfile.write(audio_path, sample_rate, audio_int16)
            
            # Generate and save annotations
            pauses = audio_processor.detect_pauses(audio)
            annotation = {
                'audio_file': f'synthetic_audio_{i}.wav',
                'sample_rate': sample_rate,
                'duration': len(audio) / sample_rate,
                'pauses': pauses
            }
            
            json_path = os.path.join(output_dir, 'annotations', f'synthetic_audio_{i}.json')
            with open(json_path, 'w') as f:
                json.dump(annotation, f, indent=2)

# [Previous WGAN and other model classes remain the same]

def main(audio_dir, annotation_dir, output_dir, device='cuda'):
    dataset = AudioDataset(audio_dir, annotation_dir)
    processed_data, scaler = dataset.prepare_dataset()
    generator, _ = train_gan(processed_data, device=device)
    save_generated_audio_with_annotations(generator, scaler, output_dir, device=device)

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    main('./cleaned_wav_files', './JSON', './generated_audio', device)



AttributeError: 'AudioDataset' object has no attribute '_get_matched_files'

In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import noisereduce as nr
from scipy.io import wavfile
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import librosa
import soundfile as sf

class AudioProcessor:
    def __init__(self, sr=16000):
        self.sr = sr
        self.pause_detector = self._initialize_pause_detector()
    
    def _initialize_pause_detector(self):
        keywords = ['aaah', 'mmm', 'hmm', 'uhh', 'umm', 'err']
        return {word: librosa.sequence.dtw for word in keywords}
    
    def reduce_noise(self, audio):
        return nr.reduce_noise(
            y=audio,
            sr=self.sr,
            prop_decrease=0.95,
            n_fft=2048,
            win_length=2048,
            hop_length=512
        )
    
    def detect_pauses(self, audio):
        segments = librosa.effects.split(audio, top_db=20)
        pauses = []
        
        for start, end in segments:
            segment = audio[start:end]
            is_filled_pause = len(segment) > self.sr * 0.1
            
            pauses.append({
                'start': float(start / self.sr),
                'end': float(end / self.sr),
                'type': 'filled_pause' if is_filled_pause else 'non_filled_pause'
            })
        
        return pauses

class AudioDataset:
    def __init__(self, audio_dir, annotation_dir, target_sr=16000, max_duration=3):
        self.audio_dir = audio_dir
        self.annotation_dir = annotation_dir
        self.target_sr = target_sr
        self.max_duration = max_duration
        self.audio_processor = AudioProcessor(sr=target_sr)
        self.audio_files = self._get_matched_files()
        
        if not self.audio_files:
            raise ValueError(f"No matched audio files found in {audio_dir}")
    
    def _get_matched_files(self):
        audio_files = []
        for audio_file in os.listdir(self.audio_dir):
            base_name = os.path.splitext(audio_file)[0]
            annotation_path = os.path.join(self.annotation_dir, f"{base_name}.json")
            
            if os.path.exists(annotation_path):
                audio_files.append(os.path.join(self.audio_dir, audio_file))
        return audio_files
    
    def load_and_process_audio(self, file_path):
        try:
            waveform, sr = torchaudio.load(file_path)
            
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            
            if sr != self.target_sr:
                waveform = torchaudio.transforms.Resample(sr, self.target_sr)(waveform)
            
            audio_np = waveform.squeeze().numpy()
            
            if len(audio_np) > 0:
                cleaned_audio = self.audio_processor.reduce_noise(audio_np)
                
                max_length = self.max_duration * self.target_sr
                if len(cleaned_audio) > max_length:
                    cleaned_audio = cleaned_audio[:max_length]
                else:
                    pad_length = max_length - len(cleaned_audio)
                    cleaned_audio = np.pad(cleaned_audio, (0, pad_length))
                
                return cleaned_audio
            return None
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            return None
    
    def prepare_dataset(self, batch_size=32):
        processed_data = []
        annotations = []
        
        for file in self.audio_files:
            audio = self.load_and_process_audio(file)
            if audio is not None:
                processed_data.append(audio)
                pauses = self.audio_processor.detect_pauses(audio)
                annotations.append(pauses)
        
        if not processed_data:
            raise ValueError("No valid audio files could be processed")
        
        processed_data = np.array(processed_data)
        scaler = MinMaxScaler(feature_range=(-1, 1))
        normalized_data = scaler.fit_transform(processed_data.reshape(len(processed_data), -1))
        tensor_data = torch.tensor(normalized_data, dtype=torch.float32)
        
        train_data, test_data = train_test_split(tensor_data, test_size=0.2)
        train_loader = DataLoader(TensorDataset(train_data), batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(TensorDataset(test_data), batch_size=batch_size)
        
        return {'train_loader': train_loader, 'test_loader': test_loader, 'scaler': scaler}

class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

class WGAN(nn.Module):
    def __init__(self, latent_dim, output_dim, device='cuda'):
        super().__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.generator = Generator(latent_dim, output_dim).to(device)
        self.discriminator = Discriminator(output_dim).to(device)
        
        self.g_optimizer = optim.RMSprop(self.generator.parameters(), lr=0.00005)
        self.d_optimizer = optim.RMSprop(self.discriminator.parameters(), lr=0.00005)
    
    def train_step(self, real_data):
        batch_size = real_data.size(0)
        real_data = real_data.to(self.device)
        
        # Train discriminator
        for _ in range(5):
            self.d_optimizer.zero_grad()
            
            z = torch.randn(batch_size, self.latent_dim, device=self.device)
            fake_data = self.generator(z)
            
            d_real = self.discriminator(real_data)
            d_fake = self.discriminator(fake_data.detach())
            
            d_loss = -(torch.mean(d_real) - torch.mean(d_fake))
            d_loss.backward()
            self.d_optimizer.step()
            
            # Weight clipping
            for p in self.discriminator.parameters():
                p.data.clamp_(-0.01, 0.01)
        
        # Train generator
        self.g_optimizer.zero_grad()
        fake_data = self.generator(z)
        g_loss = -torch.mean(self.discriminator(fake_data))
        g_loss.backward()
        self.g_optimizer.step()
        
        return d_loss.item(), g_loss.item()

def train_gan(dataset, latent_dim=100, epochs=200, device='cuda'):
    output_dim = next(iter(dataset['train_loader']))[0].shape[1]
    model = WGAN(latent_dim, output_dim, device)
    
    for epoch in range(epochs):
        d_losses, g_losses = [], []
        
        for batch in dataset['train_loader']:
            d_loss, g_loss = model.train_step(batch[0])
            d_losses.append(d_loss)
            g_losses.append(g_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: D Loss = {np.mean(d_losses):.4f}, G Loss = {np.mean(g_losses):.4f}")
    
    return model.generator, dataset['scaler']

def save_generated_audio_with_annotations(generator, scaler, output_dir, num_samples=10, 
                                        sample_rate=16000, device='cuda'):
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'annotations'), exist_ok=True)
    
    audio_processor = AudioProcessor(sr=sample_rate)
    
    with torch.no_grad():
        z = torch.randn(num_samples, generator.model[0].in_features, device=device)
        generated_audio = generator(z).cpu().numpy()
        generated_audio = scaler.inverse_transform(generated_audio)
        
        for i, audio in enumerate(generated_audio):
            # Save audio
            audio = audio.reshape(-1)
            audio = np.clip(audio, -1, 1)
            audio_int16 = (audio * 32767).astype(np.int16)
            
            audio_path = os.path.join(output_dir, f'synthetic_audio_{i}.wav')
            wavfile.write(audio_path, sample_rate, audio_int16)
            
            # Generate and save annotations
            pauses = audio_processor.detect_pauses(audio)
            annotation = {
                'audio_file': f'synthetic_audio_{i}.wav',
                'sample_rate': sample_rate,
                'duration': len(audio) / sample_rate,
                'pauses': pauses
            }
            
            json_path = os.path.join(output_dir, 'annotations', f'synthetic_audio_{i}.json')
            with open(json_path, 'w') as f:
                json.dump(annotation, f, indent=2)

def main(audio_dir, annotation_dir, output_dir, device='cuda'):
    dataset = AudioDataset(audio_dir, annotation_dir)
    processed_data = dataset.prepare_dataset()
    generator, scaler = train_gan(processed_data, device=device)
    save_generated_audio_with_annotations(generator, scaler, output_dir, device=device)

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    main('./cleaned_wav_files', './JSON', './generated_audio', device)

Epoch 0: D Loss = -0.5372, G Loss = -0.3834
Epoch 10: D Loss = -0.9829, G Loss = -0.0128
Epoch 20: D Loss = -0.9924, G Loss = -0.0047
Epoch 30: D Loss = -0.9956, G Loss = -0.0026
Epoch 40: D Loss = -0.9972, G Loss = -0.0020
Epoch 50: D Loss = -0.9988, G Loss = -0.0013
Epoch 60: D Loss = -0.9989, G Loss = -0.0009
Epoch 70: D Loss = -0.9993, G Loss = -0.0005
Epoch 80: D Loss = -0.9996, G Loss = -0.0005
Epoch 90: D Loss = -0.9998, G Loss = -0.0002
Epoch 100: D Loss = -0.9998, G Loss = -0.0003
Epoch 110: D Loss = -0.9999, G Loss = -0.0001
Epoch 120: D Loss = -0.9999, G Loss = -0.0001
Epoch 130: D Loss = -0.9998, G Loss = -0.0001
Epoch 140: D Loss = -0.9999, G Loss = -0.0001
Epoch 150: D Loss = -1.0000, G Loss = -0.0000
Epoch 160: D Loss = -1.0000, G Loss = -0.0000
