In [1]:
import json
import os
import pandas as pd
import librosa
import numpy as np
from IPython.display import Audio
import librosa
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Subset, DataLoader
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [2]:
base_dir = "nsynth-train"
json_path = os.path.join(base_dir, "examples.json")
audio_dir = os.path.join(base_dir, "audio")

# Load the whole file as one big JSON object
with open(json_path, "r") as f:
    data_dict = json.load(f)

data = []
for note_str, features in data_dict.items():
    features["note_str"] = note_str
    data.append(features)
df = pd.DataFrame(data)
df["audio_path"] = df["note_str"].apply(lambda x: os.path.join(audio_dir, x + ".wav"))

print("âœ… Loaded", len(df), "examples")


âœ… Loaded 289205 examples


In [3]:
df['instrument_family_str'].unique()

array(['guitar', 'bass', 'organ', 'keyboard', 'vocal', 'string', 'reed',
       'flute', 'mallet', 'brass', 'synth_lead'], dtype=object)

In [4]:
ins1 = 'guitar'
ins2 = 'keyboard'

In [5]:
subset = df[df["instrument_family_str"].isin([ins1, ins2])].copy()
subset = subset[subset["instrument_source_str"].isin(["acoustic"])].copy()
match_cols = ["pitch", "velocity", "sample_rate"]

grouped = subset.groupby(match_cols)

pairs = []

for key, group in grouped:
    guitars = group[group["instrument_family_str"] == ins1]
    keyboards = group[group["instrument_family_str"] == ins2]
    
    if len(guitars) > 0 and len(keyboards) > 0:
        for _, g_row in guitars.iterrows():
            for _, k_row in keyboards.iterrows():
                pairs.append({
                    **{col: g_row[col] for col in match_cols},  # common columns
                    f"{ins1}_audio": g_row["audio_path"],
                    f"{ins2}_audio": k_row["audio_path"],
                    f"{ins1}_inst": g_row["instrument_str"],
                    f"{ins2}_inst": k_row["instrument_str"],
                    f"{ins1}_note": g_row["note_str"],
                    f"{ins2}_note": k_row["note_str"],
                    f"{ins1}_source": g_row["instrument_source_str"],
                    f"{ins2}_source": k_row["instrument_source_str"],
                    f"{ins1}_qualities": g_row["qualities_str"],
                    f"{ins2}_qualities": k_row["qualities_str"],
                })

pairs_df = pd.DataFrame(pairs)
print(f"âœ… Created {len(pairs_df)} guitarâ€“keyboard pairs")
pairs_df.head()

âœ… Created 209750 guitarâ€“keyboard pairs


Unnamed: 0,pitch,velocity,sample_rate,guitar_audio,keyboard_audio,guitar_inst,keyboard_inst,guitar_note,keyboard_note,guitar_source,keyboard_source,guitar_qualities,keyboard_qualities
0,21,25,16000,nsynth-train/audio/guitar_acoustic_009-021-025...,nsynth-train/audio/keyboard_acoustic_014-021-0...,guitar_acoustic_009,keyboard_acoustic_014,guitar_acoustic_009-021-025,keyboard_acoustic_014-021-025,acoustic,acoustic,"[dark, percussive]",[]
1,21,25,16000,nsynth-train/audio/guitar_acoustic_009-021-025...,nsynth-train/audio/keyboard_acoustic_005-021-0...,guitar_acoustic_009,keyboard_acoustic_005,guitar_acoustic_009-021-025,keyboard_acoustic_005-021-025,acoustic,acoustic,"[dark, percussive]","[long_release, reverb]"
2,21,25,16000,nsynth-train/audio/guitar_acoustic_009-021-025...,nsynth-train/audio/keyboard_acoustic_000-021-0...,guitar_acoustic_009,keyboard_acoustic_000,guitar_acoustic_009-021-025,keyboard_acoustic_000-021-025,acoustic,acoustic,"[dark, percussive]",[reverb]
3,21,25,16000,nsynth-train/audio/guitar_acoustic_009-021-025...,nsynth-train/audio/keyboard_acoustic_002-021-0...,guitar_acoustic_009,keyboard_acoustic_002,guitar_acoustic_009-021-025,keyboard_acoustic_002-021-025,acoustic,acoustic,"[dark, percussive]","[dark, reverb]"
4,21,25,16000,nsynth-train/audio/guitar_acoustic_009-021-025...,nsynth-train/audio/keyboard_acoustic_019-021-0...,guitar_acoustic_009,keyboard_acoustic_019,guitar_acoustic_009-021-025,keyboard_acoustic_019-021-025,acoustic,acoustic,"[dark, percussive]","[dark, long_release, reverb]"


In [6]:
#listening to audio
sample = pairs_df.iloc[85000]

# Play the guitar audio
print("ðŸŽ¸ Guitar:", sample["guitar_audio"])
display(Audio(filename=sample["guitar_audio"], rate=16000))

print("ðŸŽ¹ Keyboard:", sample["keyboard_audio"])
display(Audio(filename=sample["keyboard_audio"], rate=16000))
print(sample)

ðŸŽ¸ Guitar: nsynth-train/audio/guitar_acoustic_023-057-050.wav


ðŸŽ¹ Keyboard: nsynth-train/audio/keyboard_acoustic_003-057-050.wav


pitch                                                                57
velocity                                                             50
sample_rate                                                       16000
guitar_audio          nsynth-train/audio/guitar_acoustic_023-057-050...
keyboard_audio        nsynth-train/audio/keyboard_acoustic_003-057-0...
guitar_inst                                         guitar_acoustic_023
keyboard_inst                                     keyboard_acoustic_003
guitar_note                                 guitar_acoustic_023-057-050
keyboard_note                             keyboard_acoustic_003-057-050
guitar_source                                                  acoustic
keyboard_source                                                acoustic
guitar_qualities                    [long_release, multiphonic, reverb]
keyboard_qualities                                                   []
Name: 85000, dtype: object


In [7]:
class GuitarKeyboardDataset(Dataset):
    def __init__(self, df, sr=16000, n_mels=128, hop_length=256, duration=2.0):
        """
        df: DataFrame with columns ['guitar_audio', 'keyboard_audio']
        sr: sample rate
        n_mels: number of mel bins
        hop_length: hop length for STFT
        duration: duration in seconds to fix length of audio/mel
        """
        self.df = df.reset_index(drop=True)
        self.sr = sr
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.num_samples = int(duration * sr)  # fixed number of samples

    def audio_to_mel(self, path):
        # Load audio
        y, _ = librosa.load(path, sr=self.sr)
        # Truncate/pad to fixed length
        if len(y) > self.num_samples:
            y = y[:self.num_samples]
        elif len(y) < self.num_samples:
            y = np.pad(y, (0, self.num_samples - len(y)))
        # Convert to mel
        mel = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels, hop_length=self.hop_length)
        mel_db = librosa.power_to_db(mel, ref=np.max)
        return mel_db.astype(np.float32)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        g_mel = self.audio_to_mel(row["guitar_audio"])
        k_mel = self.audio_to_mel(row["keyboard_audio"])
        # Add channel dimension for CNNs: (C=1, n_mels, time)
        g_mel = torch.tensor(g_mel).unsqueeze(0)
        k_mel = torch.tensor(k_mel).unsqueeze(0)
        return g_mel, k_mel

    def __len__(self):
        return len(self.df)


In [17]:
dataset = GuitarKeyboardDataset(pairs_df, duration=4.0)  # 2-second clips
loader = DataLoader(dataset, batch_size=16, shuffle=True)

In [18]:
# updated rectified flow to also take time t as input, so now input: (audio,time)
class RectifiedFlowModel(nn.Module):
    def __init__(self, n_mels=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(2, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
        )
        
    def forward(self, x, t):
        """
        x: (B, 1, n_mels, T)
        t: (B, 1, 1, 1) or broadcastable scalar
        """
        if t is None:
            raise ValueError("t must be provided for rectified flow model")

        # Make a constant time-channel with same H,W
        B, _, H, W = x.shape
        t_map = t.expand(B, 1, H, W)   # (B, 1, H, W)

        x_in = torch.cat([x, t_map], dim=1)  # (B, 2, H, W)
        return self.conv(x_in) 




In [19]:
# Instantiate model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RectifiedFlowModel(n_mels=128).to(device)
opt = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
print(device)
#device = 'cuda'

cuda


In [20]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
subset_indices = np.random.choice(len(loader.dataset), 10000, replace=False)
subset_dataset = Subset(loader.dataset, subset_indices)

# Create a new loader for the subset
subset_loader = DataLoader(subset_dataset, batch_size=loader.batch_size, shuffle=True)


In [21]:
total_size = len(loader.dataset)
all_indices = set(range(total_size))
test_pool = list(all_indices - set(subset_indices))


In [23]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

# ----------------------------
# CREATE TRAIN / VAL SUBSETS
# ----------------------------

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

total_size = len(loader.dataset)

# ---- TRAIN subset: 10,000 random samples ----
train_indices = np.random.choice(total_size, 10000, replace=False)

# remaining indices
remaining = list(set(range(total_size)) - set(train_indices))

# ---- VALIDATION subset: 100 random samples from remaining ----
val_indices = np.random.choice(remaining, 200, replace=False)

# build datasets
train_dataset = Subset(loader.dataset, train_indices)
val_dataset   = Subset(loader.dataset, val_indices)

# build loaders
train_loader = DataLoader(train_dataset, batch_size=loader.batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=loader.batch_size, shuffle=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")


Train samples: 10000
Validation samples: 200


In [None]:
epochs = 75
train_losses = []
val_losses = []

for epoch in range(epochs):

    # ----------------------------
    # TRAIN
    # ----------------------------
    model.train()
    running_train = 0.0

    for g, k in train_loader:
        g = g.to(device)
        k = k.to(device)

        # Sample t in [0,1]
        t = torch.rand(g.size(0), 1, 1, 1, device=device)

        x_t = (1 - t) * g + t * k
        target_v = (k - g)

        v_pred = model(x_t, t)

        # Crop if needed
        if v_pred.shape != target_v.shape:
            min_time = min(v_pred.shape[-1], target_v.shape[-1])
            v_pred   = v_pred[..., :min_time]
            target_v = target_v[..., :min_time]

        loss = loss_fn(v_pred, target_v)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_train += loss.item()

    epoch_train_loss = running_train / len(train_loader)
    train_losses.append(epoch_train_loss)

    # ----------------------------
    # VALIDATION
    # ----------------------------
    model.eval()
    running_val = 0.0

    with torch.no_grad():
        for g, k in val_loader:
            g = g.to(device)
            k = k.to(device)

            t = torch.rand(g.size(0), 1, 1, 1, device=device)
            x_t = (1 - t) * g + t * k
            target_v = (k - g)

            v_pred = model(x_t, t)
            if v_pred.shape != target_v.shape:
                min_time = min(v_pred.shape[-1], target_v.shape[-1])
                v_pred   = v_pred[..., :min_time]
                target_v = target_v[..., :min_time]

            loss = loss_fn(v_pred, target_v)
            running_val += loss.item()

    epoch_val_loss = running_val / len(val_loader)
    val_losses.append(epoch_val_loss)

    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {epoch_train_loss:.6f} | "
          f"Val Loss: {epoch_val_loss:.6f}")


Epoch 1/75 | Train Loss: 256.837051 | Val Loss: 252.015227
Epoch 2/75 | Train Loss: 244.249433 | Val Loss: 238.190612
Epoch 3/75 | Train Loss: 229.863202 | Val Loss: 210.055636
Epoch 4/75 | Train Loss: 223.455367 | Val Loss: 228.011392
Epoch 5/75 | Train Loss: 218.449403 | Val Loss: 222.550588
Epoch 6/75 | Train Loss: 212.823484 | Val Loss: 204.704491
Epoch 7/75 | Train Loss: 208.626208 | Val Loss: 217.461796
Epoch 8/75 | Train Loss: 206.488108 | Val Loss: 196.781656
Epoch 9/75 | Train Loss: 203.658880 | Val Loss: 200.714588
Epoch 10/75 | Train Loss: 203.825699 | Val Loss: 195.250056
Epoch 11/75 | Train Loss: 199.881851 | Val Loss: 192.595039
Epoch 12/75 | Train Loss: 197.843902 | Val Loss: 195.578138
Epoch 13/75 | Train Loss: 194.371584 | Val Loss: 200.304590
Epoch 14/75 | Train Loss: 196.053210 | Val Loss: 181.767962
Epoch 15/75 | Train Loss: 193.543823 | Val Loss: 190.393835
Epoch 16/75 | Train Loss: 192.792097 | Val Loss: 201.605295
Epoch 17/75 | Train Loss: 189.705485 | Val Loss: 

In [None]:
# Define a file path
save_path = f"models/rectified_flow_model_{len(subset_loader)}_{epochs}.pth"

# Save after training
torch.save(model.state_dict(), save_path)

print(f"Model weights saved to {save_path}")


In [None]:
plt.figure(figsize=(8,5))
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Velocity Loss")
plt.title("Train vs Validation Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
plt.plot(all_losses, marker='o', linestyle='-', color='b')
plt.title("Training Velocity Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("MSE Velocity Loss")
plt.grid(True, alpha=0.3)
plt.show()


In [54]:
epochs=50
save_path = f"models/rectified_flow_model_{len(subset_loader)}_{epochs}.pth"
model = RectifiedFlowModel(n_mels=128).to(device)
model.load_state_dict(torch.load(save_path, map_location=device))
model.eval()


RectifiedFlowModel(
  (conv): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [55]:
def evaluate_model(model, loader, steps=32, max_batches=20):
    model.eval()
    mse_total = 0.0
    l1_total = 0.0
    n_batches = 0

    with torch.no_grad():
        for g, k in loader:
            g = g.to(device)   
            k = k.to(device)

            pred = translate_batch(model, g, steps=steps)  
            k_cpu = k.cpu()

            # this is for cropping the time dimension if needed
            if pred.shape != k_cpu.shape:
                min_time = min(pred.shape[-1], k_cpu.shape[-1])
                pred   = pred[..., :min_time]
                k_cpu  = k_cpu[..., :min_time]

            mse = F.mse_loss(pred, k_cpu).item()
            l1  = F.l1_loss(pred, k_cpu).item()

            mse_total += mse
            l1_total  += l1
            n_batches += 1

            if n_batches >= max_batches:
                break

    if n_batches == 0:
        return None

    mse_avg = mse_total / n_batches
    l1_avg  = l1_total / n_batches
    return mse_avg, l1_avg


In [56]:
def translate_guitar_to_keyboard(model, g, steps=32):
    """
    g: torch.Tensor of shape (1, n_mels, T) on CPU
    returns: numpy array of shape (1, n_mels, T)
    """
    model.eval()
    # add dimension: (B=1, C=1, n_mels, T)
    x = g.unsqueeze(0).to(device)   # (1, 1, n_mels, T)

    t_vals = torch.linspace(0, 1, steps, device=device)
    dt = 1.0 / steps

    with torch.no_grad():
        for t in t_vals:
            v = model(x, t)
            x = x + v * dt          # Euler integration

    return x.squeeze(0).cpu().numpy()  # (1, n_mels, T)


In [57]:
g_batch, k_batch = next(iter(loader))
g_example = g_batch[0]          # shape: (1, n_mels, T)

pred_mel = translate_guitar_to_keyboard(model, g_example)


# to go back to the audio
sr = 16000
hop_length = 256

# Use first (and only) channel
pred_mel_2d = pred_mel[0]   # (n_mels, T)

pred_audio = librosa.feature.inverse.mel_to_audio(
    librosa.db_to_power(pred_mel_2d),
    sr=sr,
    hop_length=hop_length
)
Audio(pred_audio, rate=sr)



### Chat code to compare original guitar and ground-truth keyboard for the same pair:

In [61]:
num = 125000  # pick any index in test_pool
g_example, k_example = loader.dataset[test_pool[num]] 


# Translate guitar -> predicted keyboard
pred_mel = translate_guitar_to_keyboard(model, g_example)
pred_mel_2d = pred_mel[0]
pred_audio = librosa.feature.inverse.mel_to_audio(
    librosa.db_to_power(pred_mel_2d),
    sr=sr,
    hop_length=hop_length
)

print("ðŸŽ¸ Original guitar (waveform):")
sample_row = pairs_df.iloc[test_pool[num]]
display(Audio(filename=sample_row["guitar_audio"], rate=sr))

print("ðŸŽ¹ Ground-truth keyboard (NSynth):")
display(Audio(filename=sample_row["keyboard_audio"], rate=sr))

print("ðŸŽ¹ Predicted keyboard (from model):")
display(Audio(pred_audio, rate=sr))


ðŸŽ¸ Original guitar (waveform):


ðŸŽ¹ Ground-truth keyboard (NSynth):


ðŸŽ¹ Predicted keyboard (from model):


In [31]:
import torch
import librosa
from IPython.display import Audio, display

# --- Choose a sample from pairs_df ---
sample_idx = 110055  # any index you like
sample = pairs_df.iloc[sample_idx]

# --- Load guitar and keyboard audio as tensors ---
def load_audio_to_mel(path, sr=16000, n_mels=128, hop_length=256, duration=4.0):
    y, _ = librosa.load(path, sr=sr)
    # truncate/pad to fixed length
    num_samples = int(sr * duration)
    if len(y) > num_samples:
        y = y[:num_samples]
    elif len(y) < num_samples:
        y = np.pad(y, (0, num_samples - len(y)))
    # mel spectrogram
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, hop_length=hop_length)
    mel_db = librosa.power_to_db(mel, ref=np.max)
    # add channel dim for CNN: (1, n_mels, T)
    return torch.tensor(mel_db, dtype=torch.float32).unsqueeze(0)

g_mel = load_audio_to_mel(sample["guitar_audio"])
k_mel = load_audio_to_mel(sample["keyboard_audio"])

# --- Translate guitar -> predicted keyboard ---
def translate_guitar_to_keyboard(model, g, steps=32):
    model.eval()
    x = g.unsqueeze(0).to(device)  # (1,1,n_mels,T)
    dt = 1.0 / steps
    with torch.no_grad():
        for i in range(steps):
            t = torch.full((1,1,1,1), i/steps, device=device)
            v = model(x, t)
            x = x + v * dt
    return x.squeeze(0).cpu().numpy()  # (1, n_mels, T)

pred_mel = translate_guitar_to_keyboard(model, g_mel)

# --- Convert mel spectrograms back to audio ---
def mel_to_audio(mel_tensor, sr=16000, hop_length=256):
    # Convert PyTorch tensor to NumPy
    mel_np = mel_tensor.squeeze().cpu().numpy()
    # Convert from dB to power
    mel_power = librosa.db_to_power(mel_np)
    # Invert to audio
    audio = librosa.feature.inverse.mel_to_audio(
        mel_power,
        sr=sr,
        hop_length=hop_length,
        n_iter=64
    )
    return audio

# Now call it
guitar_audio = mel_to_audio(g_mel)
keyboard_audio_gt = mel_to_audio(k_mel)
keyboard_audio_pred = mel_to_audio(torch.tensor(pred_mel))  # if pred_mel is already NumPy, skip torch.tensor()


# --- Play audios ---
print("ðŸŽ¸ Original Guitar:")
display(Audio(guitar_audio, rate=16000))

print("ðŸŽ¹ Ground-truth Keyboard:")
display(Audio(keyboard_audio_gt, rate=16000))

print("ðŸŽ¹ Predicted Keyboard:")
display(Audio(keyboard_audio_pred, rate=16000))


ðŸŽ¸ Original Guitar:


ðŸŽ¹ Ground-truth Keyboard:


ðŸŽ¹ Predicted Keyboard:


In [32]:
sample

pitch                                                                65
velocity                                                            127
sample_rate                                                       16000
guitar_audio          nsynth-train/audio/guitar_acoustic_026-065-127...
keyboard_audio        nsynth-train/audio/keyboard_acoustic_009-065-1...
guitar_inst                                         guitar_acoustic_026
keyboard_inst                                     keyboard_acoustic_009
guitar_note                                 guitar_acoustic_026-065-127
keyboard_note                             keyboard_acoustic_009-065-127
guitar_source                                                  acoustic
keyboard_source                                                acoustic
guitar_qualities                                                     []
keyboard_qualities                                             [reverb]
Name: 110055, dtype: object