In [1]:
import mido
import numpy as np
import string
import os
import random
import torch
from torch.utils.data import Dataset
import gc
from torch.utils.data import DataLoader
import torch.nn.utils.rnn as rnn_utils
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()

In [3]:
def msg2dict(msg):
    result = dict()
    if 'note_on' in msg:
        on_ = True
    elif 'note_off' in msg:
        on_ = False
    else:
        on_ = None
    result['time'] = int(msg[msg.rfind('time'):].split(' ')[0].split('=')[1].translate(
        str.maketrans({a: None for a in string.punctuation})))

    if on_ is not None:
        for k in ['note', 'velocity']:
            result[k] = int(msg[msg.rfind(k):].split(' ')[0].split('=')[1].translate(
                str.maketrans({a: None for a in string.punctuation})))
    return [result, on_]

In [4]:
def switch_note(last_state, note, velocity, on_=True):
    # piano has 88 notes, corresponding to note id 21 to 108, any note out of this range will be ignored
    result = [0] * 88 if last_state is None else last_state.copy()
    if 21 <= note <= 108:
        result[note-21] = velocity if on_ else 0
    return result

In [5]:
def get_new_state(new_msg, last_state):
    new_msg, on_ = msg2dict(str(new_msg))
    new_state = switch_note(last_state, note=new_msg['note'], velocity=new_msg['velocity'], on_=on_) if on_ is not None else last_state
    return [new_state, new_msg['time']]

In [6]:
def track2seq(track):
    result = []
    last_state, last_time = get_new_state(str(track[0]), [0]*88)
    for i in range(1, len(track)):
        new_state, new_time = get_new_state(track[i], last_state)
        if new_time > 0:
            result += [last_state]*new_time
        last_state, last_time = new_state, new_time
    return result

In [7]:
def mid2arry(mid, min_msg_pct=0.1):
    tracks_len = [len(tr) for tr in mid.tracks]
    min_n_msg = max(tracks_len) * min_msg_pct
    # convert each track to nested list
    all_arys = []
    for i in range(len(mid.tracks)):
        if len(mid.tracks[i]) > min_n_msg:
            ary_i = track2seq(mid.tracks[i])
            all_arys.append(ary_i)
    # make all nested list the same length
    max_len = max([len(ary) for ary in all_arys])
    for i in range(len(all_arys)):
        if len(all_arys[i]) < max_len:
            all_arys[i] += [[0] * 88] * (max_len - len(all_arys[i]))
    all_arys = np.array(all_arys)
    all_arys = all_arys.max(axis=0)
    # trim: remove consecutive 0s in the beginning and at the end
    sums = all_arys.sum(axis=1)
    ends = np.where(sums > 0)[0]
    return all_arys[min(ends): max(ends)]

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [22]:
def load_generator(model_path, optimizer_path, lr, latent_dim):
    model = Generator(latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.load_state_dict(torch.load(model_path))
    optimizer.load_state_dict(torch.load(optimizer_path))
    return model, optimizer

def load_discriminator(model_path, optimizer_path, lr):
    model = Discriminator().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.load_state_dict(torch.load(model_path))
    optimizer.load_state_dict(torch.load(optimizer_path))
    return model, optimizer


In [10]:
def save_model_and_optimizer_generator(model, optimizer, epoch, model_dir="model_checkpoints_2"):
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"generator_epoch_{epoch+1}.pth")
    optimizer_path = os.path.join(model_dir, f"generator_optimizer_epoch_{epoch+1}.pth")

    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)

    print(f"Generator model and optimizer states saved after epoch {epoch+1}")

def save_model_and_optimizer_discriminator(model, optimizer, epoch, model_dir="model_checkpoints_2"):
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"discriminator_epoch_{epoch+1}.pth")
    optimizer_path = os.path.join(model_dir, f"discriminator_optimizer_epoch_{epoch+1}.pth")

    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)

    print(f"Discriminator model and optimizer states saved after epoch {epoch+1}")


In [11]:
class MIDIDataset(Dataset):
    def __init__(self, midi_dir, max_files=30000):
        self.file_paths = [os.path.join(midi_dir, f) for f in os.listdir(midi_dir) if f.endswith('.mid')]
        random.shuffle(self.file_paths)
        self.file_paths = self.file_paths[:max_files]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        for _ in range(10):  # Try up to 10 times before giving up
            path = self.file_paths[idx]
            try:
                mid = mido.MidiFile(path)
                arr = mid2arry(mid)

                if arr is None or len(arr) < 4000:
                    #print(f"Skipping (too short or silent): {path}")
                    idx = (idx + 1) % len(self.file_paths)
                    continue

                arr = arr[:8000]  # Cap at 8000 rows
                arr = np.array(arr, dtype=np.float32)
                arr /= 127.0
                tensor = torch.from_numpy(arr).clone()
                del arr
                gc.collect()
                return tensor

            except Exception as e:
                #print(f"Failed at {path}: {e}")
                idx = (idx + 1) % len(self.file_paths)

        # If it fails 10 times in a row, just return a zero tensor
        print("Too many failures. Returning dummy tensor.")
        return torch.zeros((8000, 88), dtype=torch.float32)

In [12]:
def midi_collate_fn(batch):
    """
    Pads each sample in the batch to exactly 8000 rows (sequence length).
    """
    target_len = 8000
    processed_batch = []

    for x in batch:
        length = x.size(0)
        if length < target_len:
            pad_len = target_len - length
            padding = torch.zeros((pad_len, x.size(1)))
            x_padded = torch.cat([x, padding], dim=0)
        else:
            x_padded = x[:target_len]
        processed_batch.append(x_padded)

    return torch.stack(processed_batch)  # Shape: (batch_size, 8000, features)

In [13]:
dataset = MIDIDataset(midi_dir="/home/sakshisahemail/Desktop/GANs/Dataset/Full_Dataset/", max_files=30000)

In [15]:
print(f"Total MIDI files found: {len(dataset)}")

Total MIDI files found: 30000


In [16]:
import multiprocessing
num_workers = max(1, multiprocessing.cpu_count() - 1)
num_workers

3

In [17]:
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=midi_collate_fn,
    num_workers=num_workers,  
    pin_memory=True,    
    drop_last=True      
)

In [18]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Starting from latent vector
            nn.Linear(latent_dim, 256 * 500 * 11),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 500, 11)),  # Shape: (256, 500, 11)

            # Upsample to (128, 1000, 22)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # Upsample to (64, 2000, 22) — only height increases
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 3), stride=(2, 1), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # Upsample to (32, 4000, 44)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # Final upsample to (1, 8000, 88)
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.net(z)


In [27]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=4),      # [B, 64, 2000, 22]
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=4),    # [B, 128, 500, 5]
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # [B, 256, ~250, ~3]
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # [B, 512, ~125, ~2]
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Flatten size calculation
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, 8000, 88)
            out = self.features(dummy_input)
            flattened_size = out.view(1, -1).size(1)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 1),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


In [34]:
lr_D = 0.0004  # slow it down
lr_G = 0.0008  # let G catch up

In [35]:
latent_dim = 100  

G, optimizer_G = load_generator(
    "model_checkpoints_1/generator_epoch_10.pth",
    "model_checkpoints_1/generator_optimizer_epoch_10.pth",
    lr_G,
    latent_dim
)

D, optimizer_D = load_discriminator(
    "model_checkpoints_1/discriminator_epoch_10.pth",
    "model_checkpoints_1/discriminator_optimizer_epoch_10.pth",
    lr_D
)

  model.load_state_dict(torch.load(model_path))
  optimizer.load_state_dict(torch.load(optimizer_path))
  model.load_state_dict(torch.load(model_path))
  optimizer.load_state_dict(torch.load(optimizer_path))


In [36]:
G.to(device)
D.to(device)

Discriminator(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(4, 4))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(4, 4))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=128000, out_

In [37]:
criterion = nn.BCEWithLogitsLoss()
latent_dim = 100
epochs = 12

In [38]:
G.train()
D.train()

for epoch in range(10):
    print(f"\nEpoch {epoch+1}/{10}")
    
    epoch_loss_D = 0.0
    epoch_loss_G = 0.0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}", leave=False)

    for batch_idx, real_imgs in progress_bar:
        real_imgs = real_imgs.to(device)
        real_imgs = real_imgs.unsqueeze(1)
        b_size = real_imgs.size(0)

        # === Train Discriminator ===
        z = torch.randn(b_size, latent_dim, device=device)
        fake_imgs = G(z)

        D_real = D(real_imgs).view(-1)
        D_fake = D(fake_imgs.detach()).view(-1)

        real_labels = torch.full_like(D_real, 0.9, device=device)  # label smoothing
        fake_labels = torch.zeros_like(D_fake, device=device)

        loss_D_real = criterion(D_real, real_labels)
        loss_D_fake = criterion(D_fake, fake_labels)
        loss_D = loss_D_real + loss_D_fake

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # === Train Generator ===
        D_fake = D(fake_imgs).view(-1)
        real_gen_labels = torch.ones_like(D_fake, device=device)
        loss_G = criterion(D_fake, real_gen_labels)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        epoch_loss_D += loss_D.item()
        epoch_loss_G += loss_G.item()

        # Update tqdm bar with loss info
        progress_bar.set_postfix({
            'Loss D': f"{loss_D.item():.4f}",
            'Loss G': f"{loss_G.item():.4f}"
        })

    avg_loss_D = epoch_loss_D / len(dataloader)
    avg_loss_G = epoch_loss_G / len(dataloader)
    print(f"✅ Epoch {epoch+1}/{10} | Avg Loss D: {avg_loss_D:.4f} | Avg Loss G: {avg_loss_G:.4f}")


Epoch 1/10


                                                                                

✅ Epoch 1/10 | Avg Loss D: 0.4395 | Avg Loss G: 11.1276

Epoch 2/10


                                                                                

✅ Epoch 2/10 | Avg Loss D: 0.3919 | Avg Loss G: 10.6032

Epoch 3/10


                                                                                

✅ Epoch 3/10 | Avg Loss D: 0.3975 | Avg Loss G: 9.0208

Epoch 4/10


                                                                                

✅ Epoch 4/10 | Avg Loss D: 0.3595 | Avg Loss G: 9.7055

Epoch 5/10


                                                                                

✅ Epoch 5/10 | Avg Loss D: 0.3746 | Avg Loss G: 9.9467

Epoch 6/10


                                                                                

✅ Epoch 6/10 | Avg Loss D: 0.4592 | Avg Loss G: 10.5539

Epoch 7/10


                                                                                

✅ Epoch 7/10 | Avg Loss D: 0.3853 | Avg Loss G: 7.8659

Epoch 8/10


                                                                                

✅ Epoch 8/10 | Avg Loss D: 0.3529 | Avg Loss G: 10.1053

Epoch 9/10


                                                                                

✅ Epoch 9/10 | Avg Loss D: 0.3467 | Avg Loss G: 11.2942

Epoch 10/10


                                                                                

✅ Epoch 10/10 | Avg Loss D: 0.3523 | Avg Loss G: 13.0075




In [39]:
save_model_and_optimizer_generator(G, optimizer_G, epoch)
save_model_and_optimizer_discriminator(D, optimizer_D, epoch)

Generator model and optimizer states saved after epoch 10
Discriminator model and optimizer states saved after epoch 10


In [40]:
def generate_and_save_samples(generator, num_samples=500, output_dir="generated_samples_round_2", batch_size=1):
    import gc
    os.makedirs(output_dir, exist_ok=True)
    generator.eval()

    num_batches = (num_samples + batch_size - 1) // batch_size
    sample_count = 0

    with torch.no_grad():
        for batch_idx in range(num_batches):
            current_batch_size = min(batch_size, num_samples - sample_count)
            z = torch.randn(current_batch_size, latent_dim, device=device)
            fake_imgs = generator(z)

            fake_imgs = fake_imgs.squeeze(1).cpu().numpy()
            for i, img in enumerate(fake_imgs):
                np.save(os.path.join(output_dir, f"generated_sample_{sample_count + 1}.npy"), img)
                sample_count += 1

            # Memory cleanup
            del z, fake_imgs
            torch.cuda.empty_cache()
            gc.collect()

    print(f"Generated {num_samples} samples and saved to {output_dir}")


In [42]:
generate_and_save_samples(G, num_samples=500)

Generated 500 samples and saved to generated_samples_round_2


In [None]:
array = np.load('path_to_file.npy')