In [16]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import soundfile as sf
import matplotlib.pyplot as plt

In [17]:
# ====== Dataset Class ======
class GTZANMelDataset(Dataset):
    def __init__(self, data_dir, genre_list, max_width=1024):
        self.data = []
        for label, genre in enumerate(genre_list):
            folder = os.path.join(data_dir, genre)
            for file in os.listdir(folder):
                if file.endswith('.npy'):
                    self.data.append((os.path.join(folder, file), label))
        self.max_width = max_width

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, label = self.data[idx]
        mel = np.load(path)[:, :self.max_width]
        mel = (mel + 80.0) / 80.0 * 2.0 - 1.0
        mel = torch.tensor(mel, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label, dtype=torch.long)
        return mel, label

In [18]:
# ====== Generator ======
class Generator(nn.Module):
    def __init__(self, noise_dim=100, genre_dim=10):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + genre_dim, 512 * 4 * 8),
            nn.ReLU()
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, (2, 4), (2, 2), (0, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, (1, 4), (1, 4)),
            nn.Tanh()
        )

    def forward(self, z, labels):
        x = torch.cat([z, labels], dim=1)
        x = self.fc(x).view(-1, 512, 4, 8)
        return self.deconv(x)[:, :, :128, :1024]

# ====== Discriminator ======
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(11, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 1, 4, 2, 1),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(1 * 8 * 64, 1)

    def forward(self, mel, labels):
        B, _, H, W = mel.shape
        label_map = labels.view(B, 10, 1, 1).expand(B, 10, H, W)
        x = torch.cat([mel, label_map], dim=1)
        x = self.model(x)
        return self.fc(self.flatten(x))

In [None]:
# ====== Training Loop ======
def train_gan():
    genre_list = ['blues', 'classical', 'country', 'disco', 'hiphop',
                    'jazz', 'metal', 'pop', 'reggae', 'rock']
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = DataLoader(GTZANMelDataset(r"C:\Users\user\OneDrive\桌面\游老師機器學習\mel_spectrograms" , genre_list), batch_size=16, shuffle=True)
    G, D = Generator().to(device), Discriminator().to(device)
    g_opt = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))
    d_opt = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.9))
    criterion = nn.MSELoss()
    
    g_losses, d_losses = [], []

    for epoch in range(30):
        for real_mel, label_idx in dataloader:
            B = real_mel.size(0)
            real_mel, label_idx = real_mel.to(device), label_idx.to(device)
            labels = torch.eye(10)[label_idx].to(device)
            real_out = D(real_mel, labels)

            # Train Discriminator
            z = torch.randn(B, 100).to(device)
            fake_mel = G(z, labels).detach()
            fake_out = D(fake_mel, labels)
            d_loss = criterion(real_out, torch.ones_like(real_out)) + \
                    criterion(fake_out, torch.zeros_like(fake_out))
            D.zero_grad(); d_loss.backward(); d_opt.step()

            # Train Generator
            fake_mel = G(z, labels)
            fake_out = D(fake_mel, labels)
            g_loss = criterion(fake_out, torch.ones_like(fake_out))
            G.zero_grad(); g_loss.backward(); g_opt.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())
        print(f"Epoch {epoch+1}/30 | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    torch.save(G.state_dict(), "generator.pt")
    torch.save(D.state_dict(), "discriminator.pt")
    np.save("g_losses.npy", np.array(g_losses))
    np.save("d_losses.npy", np.array(d_losses))

    plt.plot(g_losses, label='G Loss')
    plt.plot(d_losses, label='D Loss')
    plt.legend(); plt.grid(); plt.title("Loss Curve")
    plt.savefig("loss_curve.png"); plt.show()
    
    G.eval()
    with torch.no_grad():
        z = torch.randn(1, 100).to(device)
        sample_label = torch.tensor([[1.0 if i == 0 else 0.0 for i in range(10)]]).to(device)
        fake_mel = G(z, sample_label)

    fake_mel = fake_mel.squeeze().detach().cpu().numpy()
    audio = librosa.feature.inverse.mel_to_audio(fake_mel, sr=22050, n_fft=2048, hop_length=512, win_length=2048, window='hann', power=2.0, n_iter=64)
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
    if len(audio.shape) == 1:
        audio = audio[:, np.newaxis]
    sf.write("generated_audio_sample.wav", audio, 22050)
    print("✅ 已產生音訊並儲存為 generated_audio_sample.wav")

if __name__ == '__main__':
    train_gan()


Epoch 1/30 | D Loss: 0.1533 | G Loss: 1.0065
Epoch 2/30 | D Loss: 0.3901 | G Loss: 0.8367
