In [1]:
import pdb
from src.components.data.data import EEGComponent, AudioComponent
from src.components.data.data import DataSegmenter, AudioMelDataset

from torch.utils.data import DataLoader

eeg_loader = EEGComponent(subject='01')
audio_loader = AudioComponent(subject='01')

eeg_data = eeg_loader._get_eeg()
audio_data = audio_loader._get_audio()


print(f"EEG data shape: {eeg_data.shape}")
print(f"Audio data shape: {audio_data.shape}")

segmenter = DataSegmenter(eeg_data, audio_data)
eeg_segments, mel_spec_segments, audio_segments = segmenter.segment_data()

print(f"Segmented EEG data shape: {eeg_segments.shape}")
print(f"Segmented mel data shape: {mel_spec_segments.shape}")
print(f"Segmented Audio data shape: {audio_segments.shape}")

dataset = AudioMelDataset(
    audio_segments=audio_segments, 
    mel_segments= mel_spec_segments
)

loader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)




Loading EEG data from /home/owaismujtaba/work/NeuralSpeech/data/raw/P01_sEEG.npy
stimulus from /home/owaismujtaba/work/NeuralSpeech/data/raw/P01_stimuli.npy
Channel info from /home/owaismujtaba/work/NeuralSpeech/data/raw/P01_channels.npy
Cleaning EEG data...
(307511, 130)
EEG cleaning completed.
Preprocessing EEG data...
Performing electrode shaft referencing...
(307511, 130)
Electrode shaft referencing completed.
Detrending and filtering EEG data...
Applying bandpass and notch filters...
EEG shape:  (307511, 130)
EEG preprocessing completed.


Loading and resampling audio from /home/owaismujtaba/work/NeuralSpeech/data/raw/P01_audio.npy
Audio loaded and resampled to 16kHz. New shape: (6621676,)
EEG data shape: (307511, 130)
Audio data shape: (6621676,)


Segmenting data into chunks of size 0.05 seconds
Mel spectrogram shape: (6009, 80)
EEG segmented shape: (6029, 51, 130)
Audio segments shape: (6008, 1102)
----------------------------------------------------------------------
Mel spe

In [4]:
from hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, discriminator_loss, generator_loss, feature_loss
from hifigan.env import AttrDict


In [10]:
config_file = '/home/owaismujtaba/work/NeuralSpeech/hifigan/pretrained/UNIVERSAL_V1/config.json'
model_path = '/home/owaismujtaba/work/NeuralSpeech/hifigan/pretrained/UNIVERSAL_V1/do_02500000'

In [11]:
import torch
import json
BATCH_SIZE = 8
LR_G = 1e-4
LR_D = 1e-4
EPOCHS = 50
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
lambda_feat = 10.0  # Feature matching loss weight
lambda_adv = 2.0    # Adversarial loss weight


# -----------------------------
# 4. Load HiFi-GAN generator and discriminators
# -----------------------------
with open(config_file) as f:
    config = AttrDict(json.load(f))

generator = Generator(config).to(DEVICE)
mpd = MultiPeriodDiscriminator().to(DEVICE)
msd = MultiScaleDiscriminator().to(DEVICE)

In [12]:
state_dict_g = torch.load(model_path, map_location="cpu")

In [15]:
mpd = state_dict_g['mpd']

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# -----------------------------

generator.load_state_dict(state_dict_g['generator'])
mpd.load_state_dict(state_dict_g['mpd'])
generator.train()
mpd.train()
msd.train()

# -----------------------------
# 2. Optimizers & Losses
# -----------------------------
optim_g = torch.optim.Adam(generator.parameters(), lr=LR_G, betas=(0.8, 0.99))
optim_d = torch.optim.Adam(list(mpd.parameters()) + list(msd.parameters()), lr=LR_D, betas=(0.8, 0.99))

l1_loss = nn.L1Loss()

# -----------------------------
# 3. Training loop
# -----------------------------
for epoch in range(EPOCHS):
    total_g_loss = 0.0
    total_d_loss = 0.0

    for audio, mel in loader:  # audio: [B, T], mel: [B, mel_channels, T]
        mel = mel.T.to(DEVICE).float()
        audio = audio.to(DEVICE).float()
        print(mel.shape, audio.shape)

        # ---------------------
        # Generator forward
        # ---------------------
        generated_audio = generator(mel)

        # Trim to the same length
        min_len = min(generated_audio.size(-1), audio.size(-1))
        audio = audio[:, :min_len]
        generated_audio = generated_audio[:, :min_len]

        # Add channel dimension for discriminators
        audio_in = audio.unsqueeze(1)              # [B, 1, T]
        generated_in = generated_audio.unsqueeze(1)

        # ---------------------
        # Discriminator update
        # ---------------------
        optim_d.zero_grad()

        mpd_real, mpd_fake, mpd_fmaps_r, mpd_fmaps_g = mpd(audio_in, generated_in.detach())
        msd_real, msd_fake, msd_fmaps_r, msd_fmaps_g = msd(audio_in, generated_in.detach())

        # Discriminator loss
        loss_d_mpd, _, _ = discriminator_loss(mpd_real, mpd_fake)
        loss_d_msd, _, _ = discriminator_loss(msd_real, msd_fake)
        loss_d = loss_d_mpd + loss_d_msd

        loss_d.backward()
        optim_d.step()
        total_d_loss += loss_d.item()

        # ---------------------
        # Generator update
        # ---------------------
        optim_g.zero_grad()

        mpd_real, mpd_fake, mpd_fmaps_r, mpd_fmaps_g = mpd(audio_in, generated_in)
        msd_real, msd_fake, msd_fmaps_r, msd_fmaps_g = msd(audio_in, generated_in)

        # Adversarial loss
        loss_adv_mpd, _ = generator_loss(mpd_fake)
        loss_adv_msd, _ = generator_loss(msd_fake)
        loss_adv = loss_adv_mpd + loss_adv_msd

        # Feature matching loss
        loss_fm = feature_loss(mpd_fmaps_r, mpd_fmaps_g) + feature_loss(msd_fmaps_r, msd_fmaps_g)

        # L1 waveform reconstruction
        loss_l1 = l1_loss(generated_audio, audio)

        # Total generator loss
        loss_g = lambda_adv * loss_adv + lambda_feat * loss_fm + loss_l1
        loss_g.backward()
        optim_g.step()
        total_g_loss += loss_g.item()

    # ---------------------
    # Epoch logging
    # ---------------------
    avg_g_loss = total_g_loss / len(loader)
    avg_d_loss = total_d_loss / len(loader)
    print(f"Epoch {epoch+1}/{EPOCHS} | G_Loss: {avg_g_loss:.6f} | D_Loss: {avg_d_loss:.6f}")

# -----------------------------
# -----------------------------
torch.save({'generator': generator.state_dict()}, "generator_finetuned_dutch.pt")
print("Generator saved successfully.")


KeyError: 'generator'

In [8]:
mel.T.shape

torch.Size([80, 16])

In [26]:
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
import torch

# Load HiFi-GAN vocoder
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# (Optionally) load a Dutch finetuned version if available
# vocoder = SpeechT5HifiGan.from_pretrained("your-hf-username/dutch-hifigan")


  from .autonotebook import tqdm as notebook_tqdm


In [28]:
vocoder.base_model

SpeechT5HifiGan(
  (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (upsampler): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(8,), stride=(4,), padding=(2,))
    (1): ConvTranspose1d(256, 128, kernel_size=(8,), stride=(4,), padding=(2,))
    (2): ConvTranspose1d(128, 64, kernel_size=(8,), stride=(4,), padding=(2,))
    (3): ConvTranspose1d(64, 32, kernel_size=(8,), stride=(4,), padding=(2,))
  )
  (resblocks): ModuleList(
    (0): HifiGanResidualBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      )
    )
    (1): HifiGanResidualBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256,