In [None]:
import mido
import numpy as np
import string
import os
import random
import torch
from torch.utils.data import Dataset
import gc
from torch.utils.data import DataLoader
import torch.nn.utils.rnn as rnn_utils
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
# Report whether a GPU is visible to PyTorch.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# get_device_name(0) raises when no CUDA device exists, so only ask for the
# name when one is actually available (the notebook later falls back to CPU).
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

True
1
NVIDIA L4


In [None]:
# Ask PyTorch's CUDA caching allocator to use expandable segments.
# NOTE(review): this variable is presumably read when the CUDA allocator is first
# initialised; an earlier cell has already queried the GPU, so confirm the setting
# still takes effect here (setting it before `import torch` would be the safe order).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Release cached GPU memory that is currently unused.
torch.cuda.empty_cache()

In [None]:
def msg2dict(msg):
    """Extract the numeric fields of a MIDI message from its string form.

    Args:
        msg: Text of one message (``get_new_state`` passes ``str(message)``).

    Returns:
        A two-element list ``[fields, on_]``. ``fields`` always holds ``'time'``
        and, for note events, ``'note'`` and ``'velocity'`` (all ints).
        ``on_`` is True when the text contains ``note_on``, False when it
        contains ``note_off`` and None for every other message.
    """
    # Translation table that deletes every ASCII punctuation character.
    drop_punctuation = str.maketrans({ch: None for ch in string.punctuation})

    def _int_field(key):
        # Text from the LAST occurrence of `key` up to the next space, e.g. "time=0)".
        token = msg[msg.rfind(key):].split(' ')[0]
        # Keep the part after '=', strip stray punctuation such as ')' and parse.
        return int(token.split('=')[1].translate(drop_punctuation))

    on_ = True if 'note_on' in msg else (False if 'note_off' in msg else None)

    result = {'time': _int_field('time')}
    if on_ is not None:
        for key in ('note', 'velocity'):
            result[key] = _int_field(key)
    return [result, on_]

In [None]:
def switch_note(last_state, note, velocity, on_=True):
    """Return a copy of the 88-key state with one key switched on or off.

    Args:
        last_state: Previous key state (88 entries), or None for all keys off.
        note: MIDI note id; only ids 21..108 (the 88 piano keys) are applied.
        velocity: Value stored for the key when ``on_`` is truthy.
        on_: Truthy stores ``velocity``, falsy stores 0.

    Returns:
        A new state; ``last_state`` itself is not modified.
    """
    state = last_state.copy() if last_state is not None else [0] * 88
    key_index = note - 21
    # Notes outside the piano range leave the state untouched.
    if 0 <= key_index <= 87:
        if on_:
            state[key_index] = velocity
        else:
            state[key_index] = 0
    return state

In [None]:
def get_new_state(new_msg, last_state):
    """Apply one MIDI message to the current key state.

    Args:
        new_msg: A message (or its string form); it is parsed via ``msg2dict(str(new_msg))``.
        last_state: Key state before this message.

    Returns:
        ``[state, time]`` where ``time`` is the message's parsed ``time`` field.
        For non-note messages ``state`` is ``last_state`` itself (no copy).
    """
    fields, on_ = msg2dict(str(new_msg))
    if on_ is None:
        state = last_state
    else:
        state = switch_note(
            last_state,
            note=fields['note'],
            velocity=fields['velocity'],
            on_=on_,
        )
    return [state, fields['time']]

In [None]:
def track2seq(track):
    """Expand one MIDI track into a list of per-tick 88-key states.

    Every message after the first contributes ``time`` copies of the state that
    was active *before* that message, so the result has one entry per tick.

    Args:
        track: Indexable sequence of MIDI messages (anything ``get_new_state`` accepts).

    Returns:
        List of key states (each with 88 entries); ``[]`` for an empty track.
    """
    # An empty track has no first message to seed the state from.
    if len(track) == 0:
        return []

    result = []
    # NOTE(review): the `time` of the first message is discarded, as in the
    # original code, so any delay before the first message is not emitted.
    # Confirm this is intended when several tracks are merged.
    last_state, _ = get_new_state(track[0], [0] * 88)
    for i in range(1, len(track)):
        new_state, new_time = get_new_state(track[i], last_state)
        if new_time > 0:
            # The same state object is repeated; switch_note copies before it
            # edits, so the repeated entries are never mutated afterwards.
            result += [last_state] * new_time
        last_state = new_state
    return result

In [None]:
def mid2arry(mid, min_msg_pct=0.1):
    """Convert a MIDI file object into a single (ticks, 88) piano-roll array.

    Tracks with more than ``min_msg_pct`` times the message count of the longest
    track are expanded with ``track2seq``, zero-padded to a common length and
    merged with an element-wise maximum. Leading and trailing silent rows are
    then trimmed.

    Args:
        mid: Object exposing ``tracks`` (a sequence of message sequences).
        min_msg_pct: Fraction of the longest track's message count a track must
            exceed to be included.

    Returns:
        2-D numpy array of key velocities, or None when there are no tracks, no
        ticks, or no sounding note at all (the dataset already skips None).
    """
    tracks_len = [len(tr) for tr in mid.tracks]
    if not tracks_len:
        return None
    min_n_msg = max(tracks_len) * min_msg_pct

    # Convert each sufficiently long track to a nested list of key states.
    all_arys = [track2seq(tr) for tr in mid.tracks if len(tr) > min_n_msg]

    max_len = max((len(ary) for ary in all_arys), default=0)
    if max_len == 0:
        return None

    # Pad every track with silence so they all have the same number of ticks.
    silence = [0] * 88
    padded = [ary + [silence] * (max_len - len(ary)) for ary in all_arys]
    merged = np.array(padded).max(axis=0)

    # Trim: remove the silent rows at the beginning and at the end.
    active = np.where(merged.sum(axis=1) > 0)[0]
    if active.size == 0:
        return None
    # np.where returns ascending indices; +1 keeps the last sounding row, which
    # the previous `[min(ends): max(ends)]` slice dropped.
    return merged[active[0]: active[-1] + 1]

In [None]:
class MIDIDataset(Dataset):
    """Dataset that turns ``.mid`` files into normalised piano-roll tensors.

    Files are listed once, shuffled and capped at ``max_files``. Each item is a
    float32 tensor of at most ``max_len`` rows and 88 columns, with values
    divided by 127.

    Args:
        midi_dir: Directory scanned (non-recursively) for names ending in ``.mid``.
        max_files: Maximum number of files kept after shuffling.
        min_len: Piano rolls with fewer rows than this are skipped.
        max_len: Rows kept per item; also the length of the fallback tensor.
    """

    def __init__(self, midi_dir, max_files=30000, min_len=6000, max_len=8000):
        self.file_paths = [os.path.join(midi_dir, f) for f in os.listdir(midi_dir) if f.endswith('.mid')]
        random.shuffle(self.file_paths)
        self.file_paths = self.file_paths[:max_files]
        self.min_len = min_len
        self.max_len = max_len

    def __len__(self):
        """Number of files kept after the ``max_files`` cap."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return the piano roll for file ``idx``, falling forward on bad files.

        Unreadable, silent or too-short files are skipped by moving to the next
        index (wrapping around). After 10 consecutive failures an all-zero
        ``(max_len, 88)`` tensor is returned instead of raising.
        """
        n_files = len(self.file_paths)
        for _ in range(10):  # Try up to 10 times before giving up
            path = self.file_paths[idx]
            try:
                arr = mid2arry(mido.MidiFile(path))

                if arr is None or len(arr) < self.min_len:
                    idx = (idx + 1) % n_files
                    continue

                # Fresh float32 copy capped at max_len rows. The tensor shares
                # this buffer and keeps it alive, so no clone or explicit
                # garbage collection is needed.
                arr = np.array(arr[:self.max_len], dtype=np.float32)
                arr /= 127.0
                return torch.from_numpy(arr)

            except Exception:
                # Deliberate best-effort: a corrupt file must not kill a worker.
                idx = (idx + 1) % n_files

        # If it fails 10 times in a row, just return a zero tensor
        print("Too many failures. Returning dummy tensor.")
        return torch.zeros((self.max_len, 88), dtype=torch.float32)

In [None]:
def midi_collate_fn(batch, target_len=8000):
    """
    Pads or truncates each sample in the batch to exactly ``target_len`` rows
    (sequence length) and stacks the results.

    Args:
        batch: Iterable of tensors shaped ``(length, ...)`` whose trailing
            dimensions all match.
        target_len: Number of rows every sample is brought to (default 8000).

    Returns:
        Tensor of shape ``(batch_size, target_len, ...)``.
    """
    processed_batch = []

    for x in batch:
        length = x.size(0)
        if length < target_len:
            # new_zeros inherits x's dtype and device, so concatenation never
            # promotes the dtype or mixes devices.
            padding = x.new_zeros((target_len - length, *x.shape[1:]))
            x_padded = torch.cat([x, padding], dim=0)
        else:
            x_padded = x[:target_len]
        processed_batch.append(x_padded)

    return torch.stack(processed_batch)  # Shape: (batch_size, target_len, features)

In [None]:
# Location of the MIDI corpus. Set the MIDI_DIR environment variable to run the
# notebook on another machine; the default is the original author's local path.
midi_dir = os.environ.get("MIDI_DIR", "/home/sakshisahemail/Desktop/GANs/Dataset/Full_Dataset/")
dataset = MIDIDataset(midi_dir=midi_dir, max_files=30000)

In [None]:
# len(dataset) counts the files kept after the max_files cap, not every file on disk.
print("Total MIDI files found: {}".format(len(dataset)))

Total MIDI files found: 30000


In [None]:
import multiprocessing

# Data-loading workers: every core but one, and never fewer than one.
num_workers = max(multiprocessing.cpu_count() - 1, 1)
num_workers

3

In [None]:
# Shuffled batches of 32 piano rolls, each padded or cut by midi_collate_fn.
# drop_last discards a final incomplete batch so every batch has 32 samples.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=midi_collate_fn,
    pin_memory=True,
    drop_last=True,
)

In [None]:
class Generator(nn.Module):
    """Maps a latent vector to a (1, 8000, 88) piano roll with values in (0, 1).

    A linear layer projects the latent vector to a (256, 500, 11) feature map,
    which four transposed convolutions upsample to (1, 8000, 88). The last one
    is followed by a sigmoid.

    Args:
        latent_dim: Size of the input latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Project the latent vector and reshape it to (256, 500, 11).
        layers = [
            nn.Linear(latent_dim, 256 * 500 * 11),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 500, 11)),
        ]

        # (in_channels, out_channels, kernel, stride, padding) per upsampling stage:
        #   -> (128, 1000, 22), -> (64, 2000, 22) height only, -> (32, 4000, 44)
        upsample_specs = [
            (256, 128, 4, 2, 1),
            (128, 64, (4, 3), (2, 1), (1, 1)),
            (64, 32, 4, 2, 1),
        ]
        for in_ch, out_ch, kernel, stride, pad in upsample_specs:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Final upsample to (1, 8000, 88), squashed into (0, 1).
        layers.append(nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Generate piano rolls from latent vectors ``z`` of shape (batch, latent_dim)."""
        return self.net(z)


In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores a (1, 8000, 88) piano roll with one value.

    Four strided convolutions with LeakyReLU(0.2) feed a single linear layer.
    The output is an unbounded score (no sigmoid).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel, stride, padding) per conv stage.
        conv_specs = [
            (1, 64, 4, 4, 0),
            (64, 128, 4, 4, 0),
            (128, 256, 3, 2, 1),
            (256, 512, 3, 2, 1),
        ]
        stages = []
        for in_ch, out_ch, kernel, stride, pad in conv_specs:
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*stages)

        # Run an all-zero (1, 1, 8000, 88) input through the conv stack once to
        # learn how many features the linear layer receives.
        with torch.no_grad():
            probe = self.features(torch.zeros(1, 1, 8000, 88))
        flattened_size = probe.flatten(1).shape[1]

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 1),
        )

    def forward(self, x):
        """Return one score per sample for input ``x`` of shape (batch, 1, 8000, 88)."""
        return self.classifier(self.features(x))

In [None]:
# Mean and standard deviation for normal weight initialisation.
# NOTE(review): weights_init passes 0.0 and 0.02 as its own literals and does not
# read these names, so changing them here has no effect on initialisation.
mean, std = 0.0, 0.02

In [None]:
def weights_init(m):
    """Initialise one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights drawn
    from N(0, 0.02) and zero bias. Modules whose class name contains
    ``BatchNorm`` get weights drawn from N(1, 0.02) and zero bias. Any other
    module is left untouched.

    Args:
        m: The module to initialise. A missing or None ``weight``/``bias``
            (for example ``bias=False`` or ``affine=False``) is skipped.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    # nn.init functions run without gradient tracking, so `.data` is not needed.
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    if bias is not None:
        nn.init.constant_(bias, 0)

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Gradient penalty on random interpolates between real and fake samples.

    For each sample a random ``alpha`` in [0, 1) mixes the real and fake
    tensors. ``D`` is evaluated on the mixture, and the penalty is the batch
    mean of ``(||grad_x D(x)||_2 + 1e-8 - 1) ** 2``.

    Args:
        D: Callable critic mapping a batch of samples to a batch of scores.
        real_samples: Batch of real samples; the first dimension is the batch.
        fake_samples: Batch of generated samples with the same shape (pass them
            detached if the generator should not receive gradients).

    Returns:
        Scalar tensor that stays attached to the graph (``create_graph=True``),
        so it can be back-propagated into ``D``.
    """
    # One alpha per sample, broadcast over all remaining dimensions. It is
    # created on the samples' own device instead of a notebook-global `device`.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = D(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) also handles a non-contiguous gradient tensor.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) + 1e-8 - 1) ** 2).mean()
    return gradient_penalty

In [None]:
# Size of the latent vector fed to the generator.
z_dim = 100

# Build both networks first, then apply the custom initialisation to each.
G, D = Generator(z_dim), Discriminator()
G.apply(weights_init)
D.apply(weights_init)

Discriminator(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(4, 4))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=128000, out_features=1, bias=True)
  )
)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
# Move the generator onto `device`. As the cell's last expression, the returned
# module is displayed as the cell output.
G.to(device)

Generator(
  (net): Sequential(
    (0): Linear(in_features=100, out_features=1408000, bias=True)
    (1): ReLU(inplace=True)
    (2): Unflatten(dim=1, unflattened_size=(256, 500, 11))
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 3), stride=(2, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Sigmoid()
  )
)

In [None]:
# Move the discriminator onto `device`. As the cell's last expression, the
# returned module is displayed as the cell output.
D.to(device)

Discriminator(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(4, 4))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=128000, out_features=1, bias=True)
  )
)

In [None]:
# Alias for z_dim; the training loop and the sampling function read `latent_dim`.
latent_dim = z_dim

# The discriminator learns at half the generator's rate.
lr_G = 1e-4
lr_D = 5e-5

optimizer_G = optim.Adam(params=G.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(params=D.parameters(), lr=lr_D, betas=(0.5, 0.999))

In [None]:
# lambda_gp: weight of the gradient penalty in the discriminator loss.
# n_critic:  discriminator updates per generator update.
# epochs:    number of passes over the dataloader.
lambda_gp, n_critic, epochs = 10, 5, 20

In [None]:
# WGAN-GP style training: for every batch the discriminator (critic) takes
# n_critic steps on the Wasserstein loss plus gradient penalty, then the
# generator takes one step.
G.train()
D.train()

for epoch in range(epochs):
    # Show the real epoch total; this used to be hard-coded as "/10".
    print(f"\nEpoch {epoch+1}/{epochs}")

    epoch_loss_D = 0.0
    epoch_loss_G = 0.0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}", leave=False)

    for batch_idx, real_imgs in progress_bar:
        # (batch, 8000, 88) -> (batch, 1, 8000, 88): add the channel dimension.
        real_imgs = real_imgs.to(device).unsqueeze(1)
        b_size = real_imgs.size(0)

        # === Train Discriminator ===
        # NOTE(review): all n_critic steps reuse the same real batch.
        for _ in range(n_critic):
            z = torch.randn(b_size, latent_dim, device=device)
            # The generator is not updated here, so skip building its autograd
            # graph instead of building it and detaching the result.
            with torch.no_grad():
                fake_imgs = G(z)

            D_real = D(real_imgs)
            D_fake = D(fake_imgs)

            gradient_penalty = compute_gradient_penalty(D, real_imgs, fake_imgs)

            loss_D = -torch.mean(D_real) + torch.mean(D_fake) + lambda_gp * gradient_penalty

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

        # === Train Generator ===
        z = torch.randn(b_size, latent_dim, device=device)
        fake_imgs = G(z)
        D_fake = D(fake_imgs)

        loss_G = -torch.mean(D_fake)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # loss_D is the value from the last of the n_critic steps.
        epoch_loss_D += loss_D.item()
        epoch_loss_G += loss_G.item()

        progress_bar.set_postfix({
            'Loss D': f"{loss_D.item():.4f}",
            'Loss G': f"{loss_G.item():.4f}"
        })

    # Guard against an empty dataloader so the averages cannot divide by zero.
    n_batches = max(len(dataloader), 1)
    avg_loss_D = epoch_loss_D / n_batches
    avg_loss_G = epoch_loss_G / n_batches
    print(f"✅ Epoch {epoch+1} | Avg Loss D: {avg_loss_D:.4f} | Avg Loss G: {avg_loss_G:.4f}")


Epoch 1/10


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                                                                                

✅ Epoch 1 | Avg Loss D: -591.5999 | Avg Loss G: -1510.0460

Epoch 2/10


                                                                                

✅ Epoch 2 | Avg Loss D: -170.1729 | Avg Loss G: -271.3940

Epoch 3/10


                                                                                

✅ Epoch 3 | Avg Loss D: -119.2517 | Avg Loss G: -11.2613

Epoch 4/10


                                                                                

✅ Epoch 4 | Avg Loss D: -47.9646 | Avg Loss G: -22.8571

Epoch 5/10


                                                                                

✅ Epoch 5 | Avg Loss D: -35.2043 | Avg Loss G: 0.3676

Epoch 6/10


                                                                                

✅ Epoch 6 | Avg Loss D: -30.9095 | Avg Loss G: 6.0554

Epoch 7/10


                                                                                

✅ Epoch 7 | Avg Loss D: -27.8434 | Avg Loss G: 0.9519

Epoch 8/10


                                                                                

✅ Epoch 8 | Avg Loss D: -25.9520 | Avg Loss G: 1.3571

Epoch 9/10


                                                                                

✅ Epoch 9 | Avg Loss D: -24.6375 | Avg Loss G: -7.2292

Epoch 10/10


                                                                                

✅ Epoch 10 | Avg Loss D: -23.2596 | Avg Loss G: -23.9502

Epoch 11/10


                                                                                

✅ Epoch 11 | Avg Loss D: -22.5672 | Avg Loss G: -38.7396

Epoch 12/10


                                                                                

✅ Epoch 12 | Avg Loss D: -21.3228 | Avg Loss G: -42.3104

Epoch 13/10


                                                                                

✅ Epoch 13 | Avg Loss D: -20.4848 | Avg Loss G: -49.6574

Epoch 14/10


                                                                                

✅ Epoch 14 | Avg Loss D: -20.0593 | Avg Loss G: -56.8813

Epoch 15/10


                                                                                

✅ Epoch 15 | Avg Loss D: -19.4485 | Avg Loss G: -67.2189

Epoch 16/10


                                                                                

✅ Epoch 16 | Avg Loss D: -19.0795 | Avg Loss G: -74.1820

Epoch 17/10


                                                                                

✅ Epoch 17 | Avg Loss D: -18.9488 | Avg Loss G: -80.0631

Epoch 18/10


                                                                                

✅ Epoch 18 | Avg Loss D: -18.5834 | Avg Loss G: -87.7245

Epoch 19/10


                                                                                

✅ Epoch 19 | Avg Loss D: -18.4222 | Avg Loss G: -92.7662

Epoch 20/10


                                                                                

✅ Epoch 20 | Avg Loss D: -18.4916 | Avg Loss G: -97.8763




In [None]:
def generate_and_save_samples(generator, num_samples=500, output_dir="generated_samples_wgangp", batch_size=1):
    """Sample from ``generator`` and save each sample as its own ``.npy`` file.

    Files are written to ``output_dir`` as ``generated_sample_<k>.npy`` with
    ``k`` starting at 1. The generator is switched to eval mode and is left in
    eval mode afterwards.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to samples with a
            singleton dimension at index 1 (squeezed before saving). The
            notebook-level ``latent_dim`` gives the noise size.
        num_samples: Total number of samples to write.
        output_dir: Target directory, created if missing.
        batch_size: Samples generated per forward pass.
    """
    os.makedirs(output_dir, exist_ok=True)
    generator.eval()

    # Sample on whatever device the generator lives on, not a global `device`.
    first_param = next(generator.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    num_batches = (num_samples + batch_size - 1) // batch_size
    sample_count = 0

    with torch.no_grad():
        for _ in range(num_batches):
            current_batch_size = min(batch_size, num_samples - sample_count)
            z = torch.randn(current_batch_size, latent_dim, device=run_device)
            fake_imgs = generator(z).squeeze(1).cpu().numpy()

            for img in fake_imgs:
                sample_count += 1
                np.save(os.path.join(output_dir, f"generated_sample_{sample_count}.npy"), img)

            # Memory cleanup between batches.
            del z, fake_imgs
            torch.cuda.empty_cache()
            gc.collect()

    print(f"Generated {sample_count} samples and saved to {output_dir}")


In [None]:
# Write 500 generated piano rolls as .npy files into the function's default output directory.
generate_and_save_samples(G, num_samples=500)

Generated 500 samples and saved to generated_samples_wgangp


In [None]:
def _save_checkpoint(model, optimizer, epoch, model_dir, file_prefix, label):
    """Save the model and optimizer state dicts under ``model_dir``, named with ``file_prefix`` and ``epoch + 1``."""
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"{file_prefix}_epoch_{epoch+1}.pth")
    optimizer_path = os.path.join(model_dir, f"{file_prefix}_optimizer_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    print(f"{label} model and optimizer states saved after epoch {epoch+1}")


def save_model_and_optimizer_generator(model, optimizer, epoch, model_dir="model_checkpoints_wgangp"):
    """Save the generator and its optimizer state.

    Writes ``generator_epoch_<epoch+1>.pth`` and
    ``generator_optimizer_epoch_<epoch+1>.pth`` into ``model_dir``, which is
    created if missing.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Zero-based epoch index; file names use ``epoch + 1``.
        model_dir: Target directory.
    """
    _save_checkpoint(model, optimizer, epoch, model_dir, "generator", "Generator")


def save_model_and_optimizer_discriminator(model, optimizer, epoch, model_dir="model_checkpoints_wgangp"):
    """Save the discriminator and its optimizer state.

    Writes ``discriminator_epoch_<epoch+1>.pth`` and
    ``discriminator_optimizer_epoch_<epoch+1>.pth`` into ``model_dir``, which is
    created if missing.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Zero-based epoch index; file names use ``epoch + 1``.
        model_dir: Target directory.
    """
    _save_checkpoint(model, optimizer, epoch, model_dir, "discriminator", "Discriminator")

In [None]:
# Save both networks and their optimizers once, after training has finished.
# NOTE(review): `epoch` is the loop variable left over from the training cell
# (its last value), so this cell only works after that loop has run.
save_model_and_optimizer_generator(G, optimizer_G, epoch)
save_model_and_optimizer_discriminator(D, optimizer_D, epoch)

Generator model and optimizer states saved after epoch 20
Discriminator model and optimizer states saved after epoch 20
