In [None]:
!sudo apt install -y fluidsynth
!pip install --upgrade pyfluidsynth
!pip install pretty_midi
!pip install midi2audio
!pip install pydub

In [None]:
import tensorflow as tf
import pathlib
import glob
import fluidsynth
import pretty_midi

from IPython.display import display, Audio
from midi2audio import FluidSynth
from pydub import AudioSegment

_SAMPLING_RATE = 16000

In [None]:
# ALTERNATIVE DATASET: Lakh MIDI (LMD) data layout.
# tf.keras.utils.get_file saves each download under ./data (cache_dir='.',
# cache_subdir='data') and, with extract=True, unpacks archives there. Each
# existence check below looks at what its own download produces, so a rerun
# skips the get_file call (and the costly re-extraction) once data is present.
# NOTE(review): assumes the archives unpack to data/lmd_matched and
# data/lmd_aligned (the layout get_midi_path reads from); newer Keras versions
# may extract into a differently named folder — confirm after the first run.
_LMD_BASE_URL = 'http://hog.ee.columbia.edu/craffel/lmd/'

def _download_lmd_file(target, fname, extract):
  """Fetch `fname` from the LMD site into ./data unless `target` already exists."""
  if target.exists():
    return
  tf.keras.utils.get_file(
      fname,
      origin=_LMD_BASE_URL + fname,
      extract=extract,
      cache_dir='.', cache_subdir='data',
  )

data_matched_dir = pathlib.Path('data/lmd_matched')
_download_lmd_file(data_matched_dir, 'lmd_matched.tar.gz', extract=True)

data_allgined_dir = pathlib.Path('data/lmd_aligned')
_download_lmd_file(data_allgined_dir, 'lmd_aligned.tar.gz', extract=True)

# match_scores.json is a plain JSON file, not an archive: nothing to extract.
data_matchscores_dir = pathlib.Path('data/match_scores.json')
_download_lmd_file(data_matchscores_dir, 'match_scores.json', extract=False)

In [None]:
# ALTERNATIVE DATASET: Lakh MIDI
import json
import os

# Local path constants: inputs and derived results both live under ./data
DATA_PATH = 'data'
RESULTS_PATH = DATA_PATH
# Location of the match_scores.json file distributed with the LMD
score_file = f'{DATA_PATH}/match_scores.json'

# Utility functions for retrieving paths
def msd_id_to_dirs(msd_id):
    """Build the nested path prefix for an MSD ID.

    The 3rd, 4th and 5th characters of the ID each become one directory
    level, followed by the ID itself, e.g.
    TRABCD12345678 -> A/B/C/TRABCD12345678.
    """
    levels = (msd_id[k] for k in range(2, 5))
    return os.path.join(*levels, msd_id)

def msd_id_to_mp3(msd_id):
    """Return the expected path of the mp3 for `msd_id` under DATA_PATH/msd/mp3."""
    mp3_name = msd_id_to_dirs(msd_id) + '.mp3'
    return os.path.join(DATA_PATH, 'msd', 'mp3', mp3_name)

def msd_id_to_h5(msd_id):
    """Given an MSD ID, return the path to the corresponding h5.

    The file is expected at RESULTS_PATH/lmd_matched_h5/<A>/<B>/<C>/<msd_id>.h5
    (see msd_id_to_dirs for the directory prefix).
    """
    # The parameter must be the name the body reads; it was previously `h5`,
    # which made every call raise NameError on `msd_id`.
    return os.path.join(RESULTS_PATH, 'lmd_matched_h5',
                        msd_id_to_dirs(msd_id) + '.h5')

def get_midi_path(msd_id, midi_md5, kind):
    """Given an MSD ID and MIDI MD5, return the path to a Lakh MIDI file.

    kind should be one of 'matched' or 'aligned'. The path has the form
    DATA_PATH/lmd_<kind>/<msd_id_to_dirs(msd_id)>/<midi_md5>.mid.
    """
    # Build from DATA_PATH like the other path helpers, not a hard-coded 'data'.
    return os.path.join(DATA_PATH, 'lmd_{}'.format(kind),
                        msd_id_to_dirs(msd_id), midi_md5 + '.mid')

# Load the match scores to list every matched MIDI file (by default its aligned version)
def load_aligned_midi_paths(score_file, kind='aligned'):
    """Return the path of every MIDI file listed in the LMD match-scores file.

    Args:
      score_file: path to match_scores.json. Its top level maps MSD IDs to a
        collection of MIDI MD5s (presumably a dict MD5 -> match score; only
        the keys are used here).
      kind: dataset variant forwarded to get_midi_path ('matched' or 'aligned').

    Returns:
      A list of MIDI file paths, in the order they appear in the JSON file.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(score_file, 'r', encoding='utf-8') as file:
        scores = json.load(file)

    return [get_midi_path(msd_id, midi_md5, kind)
            for msd_id, matches in scores.items()
            for midi_md5 in matches]

In [None]:
aligned_midi_files = load_aligned_midi_paths(score_file, kind='aligned')
print(f'Number of files: {len(aligned_midi_files)}')

In [None]:
import random
import IPython.display as ipd
import copy

# Output folders for the rendered original and the augmented MIDI/WAV files.
original_dir = pathlib.Path('augmented/original')
pitch_dir = pathlib.Path('augmented/pitch')
vol_dir = pathlib.Path('augmented/volume')
tempo_dir = pathlib.Path('augmented/tempo')
for _out_dir in (original_dir, pitch_dir, vol_dir, tempo_dir):
  _out_dir.mkdir(parents=True, exist_ok=True)

# Seed the `random` module used by the augment_* helpers so the sampled
# pitch/volume/tempo shifts (and therefore the output file names) are
# reproducible from one run to the next.
AUGMENT_SEED = 42
random.seed(AUGMENT_SEED)

def save_as_wav(midi_file, file_name):
  """Render `midi_file` with FluidSynth to original_dir/<file_name>.wav.

  Returns the pathlib.Path of the written WAV file (not clipped).
  """
  target = original_dir.joinpath(f"{file_name}.wav")
  synth = FluidSynth()
  synth.midi_to_audio(midi_file, target)
  return target

def play_midi(pm: pretty_midi.PrettyMIDI, seconds=5):
  """Synthesize `pm` and return an IPython Audio player for its first `seconds`.

  Args:
    pm: the MIDI data to synthesize with pm.fluidsynth at _SAMPLING_RATE.
    seconds: length of the excerpt to play; may be fractional.
  """
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Take a sample of the generated waveform to mitigate kernel resets.
  # int() keeps the slice index integral when `seconds` is a float.
  num_samples = int(seconds * _SAMPLING_RATE)
  return Audio(waveform[:num_samples], rate=_SAMPLING_RATE)

def clip_length(file, seconds=5):
  """Truncate the WAV at `file` in place so it lasts at most `seconds`.

  The file is rewritten only when it is longer than the limit.
  """
  limit_ms = seconds * 1000  # pydub lengths and slices are in milliseconds
  audio = AudioSegment.from_wav(file)
  if len(audio) <= limit_ms:
    return
  audio[:limit_ms].export(file, format='wav')

def augment_pitch(pm, augment_dir, file_name, pitch_shift_range=(-6, 6)):
    """Transpose `pm` by a random number of semitones and render it to WAV.

    The shift is drawn with random.randint from the inclusive range
    `pitch_shift_range`. `pm` is modified in place, so pass a copy if the
    original is still needed. The transposed MIDI is written to pitch_dir and
    the rendered WAV (clipped with clip_length) to `augment_dir`.

    Returns:
      (midi_path, wav_path, pitch_shift)
    """
    pitch_shift = random.randint(*pitch_shift_range)
    print(f"Applying pitch shift of {pitch_shift} semitones.")
    for instrument in pm.instruments:
        # On drum tracks the note number selects a percussion sound rather than
        # a pitch, so transposing it would swap instruments (e.g. kick -> snare).
        if instrument.is_drum:
            continue
        for note in instrument.notes:
            # Clamp to the valid MIDI note range 0..127.
            note.pitch = max(0, min(127, note.pitch + pitch_shift))

    augmented_file_path = pitch_dir / f"{file_name}_pitch_{pitch_shift}.mid"

    pm.write(str(augmented_file_path))
    wav_file_path = str(augment_dir / f"{file_name}_pitch_{pitch_shift}.wav")
    FluidSynth().midi_to_audio(augmented_file_path, wav_file_path)
    clip_length(wav_file_path)

    return augmented_file_path, wav_file_path, pitch_shift

def augment_vol(pm, augment_dir, file_name, volume_augment_range=(-30, 30)):
  """Shift every note velocity of `pm` by one random amount and render it to WAV.

  The shift is drawn with random.randint from the inclusive range
  `volume_augment_range`. `pm` is modified in place. The MIDI is written to
  vol_dir and the rendered WAV (clipped with clip_length) to `augment_dir`.

  Returns:
    (midi_path, wav_path, vol_shift)
  """
  vol_shift = random.randint(*volume_augment_range)
  print(f"Applying volume shift of {vol_shift} velocity.")
  # Apply the velocity (loudness) shift. The lower bound is 1 because a
  # velocity of 0 means note-off in MIDI; 127 is the MIDI maximum.
  every_note = (note for instrument in pm.instruments for note in instrument.notes)
  for note in every_note:
    note.velocity = max(1, min(127, note.velocity + vol_shift))

  stem = f"{file_name}_vol_{vol_shift}"
  augmented_file_path = vol_dir / f"{stem}.mid"
  pm.write(str(augmented_file_path))
  wav_file_path = str(augment_dir / f"{stem}.wav")
  FluidSynth().midi_to_audio(augmented_file_path, wav_file_path)
  clip_length(wav_file_path)
  return augmented_file_path, wav_file_path, vol_shift


def augment_tempo(pm, augment_dir, file_name, tempo_augment_choices=(0.25, 0.5, 0.75, 1.25, 1.5, 1.75)):
  """Time-stretch `pm` by a randomly chosen factor and render it to WAV.

  Every per-instrument event time is multiplied by the chosen factor, so a
  factor > 1 makes the piece longer (slower) and < 1 makes it shorter (faster).
  `pm` is modified in place. The MIDI is written to tempo_dir and the rendered
  WAV (clipped with clip_length) to `augment_dir`.

  Returns:
    (midi_path, wav_path, tempo_shift)
  """
  tempo_shift = random.choice(tempo_augment_choices)
  print(f"Applying tempo shift of {tempo_shift} tempo.")
  for instrument in pm.instruments:
    for note in instrument.notes:
      note.start *= tempo_shift
      note.end *= tempo_shift
    # Scale the non-note events too; otherwise pitch bends and controller
    # changes (e.g. sustain pedal) stay at their original times and drift
    # out of sync with the stretched notes.
    for bend in instrument.pitch_bends:
      bend.time *= tempo_shift
    for control in instrument.control_changes:
      control.time *= tempo_shift

  augmented_file_path = tempo_dir / f"{file_name}_tempo_{tempo_shift}.mid"
  pm.write(str(augmented_file_path))
  wav_file_path = str(augment_dir / f"{file_name}_tempo_{tempo_shift}.wav")
  FluidSynth().midi_to_audio(augmented_file_path, wav_file_path)
  clip_length(wav_file_path)

  return augmented_file_path, wav_file_path, tempo_shift

In [None]:
# Use one aligned file as the running example for the augmentation cells below.
example_index = 5
file_5 = aligned_midi_files[example_index]
file_name = f'file_{example_index}'
original_wav_path = save_as_wav(file_5, file_name)
pm = pretty_midi.PrettyMIDI(file_5)
play_midi(pm)

In [None]:
# Augment a deep copy so `pm` stays untouched for the other augmentations.
pitch_midi_path, pitch_wav_path, pitch_shift = augment_pitch(
    copy.deepcopy(pm),
    pitch_dir,
    file_name,
)
# Reload what was written to disk and audition it.
pm_pitch = pretty_midi.PrettyMIDI(str(pitch_midi_path))
play_midi(pm_pitch)

In [None]:
# Augment a deep copy so `pm` stays untouched for the other augmentations.
vol_midi_path, vol_wav_path, vol_shift = augment_vol(
    copy.deepcopy(pm),
    vol_dir,
    file_name,
)
# Reload what was written to disk and audition it.
pm_vol = pretty_midi.PrettyMIDI(str(vol_midi_path))
play_midi(pm_vol)

In [None]:
# Force a single stretch factor of 1.75 (event times x1.75, i.e. slower).
tempo_midi_path, tempo_wav_path, tempo_shift = augment_tempo(
    copy.deepcopy(pm),
    tempo_dir,
    file_name,
    tempo_augment_choices=(1.75,),
)
# Reload what was written to disk and audition it.
pm_tempo = pretty_midi.PrettyMIDI(str(tempo_midi_path))
play_midi(pm_tempo)

In [None]:
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
!pip install audiocraft

In [None]:
import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from audiocraft.utils.notebook import display_audio

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = MusicGen.get_pretrained('facebook/musicgen-melody', device=device)
# Generate 10 seconds per call.
model.set_generation_params(duration=10)

In [None]:
# Folder that receives the MusicGen continuations written by audio_write.
generated_dir = pathlib.Path('generated')
generated_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Baseline: let MusicGen continue the un-augmented rendering.
melody, sample_rate = torchaudio.load(original_wav_path)
prompt_duration = 8
# Use at most the first `prompt_duration` seconds of the file as the prompt.
melody = melody[..., :int(prompt_duration * sample_rate)]
output = model.generate_continuation(melody, prompt_sample_rate=sample_rate, progress=True)
# The continuation is produced at the model's own rate, not the prompt file's
# rate; playing it at `sample_rate` gives the wrong speed/pitch when they differ.
display_audio(output, sample_rate=model.sample_rate)

In [None]:
# Let MusicGen continue the tempo-augmented rendering.
melody, sample_rate = torchaudio.load(tempo_wav_path)
prompt_duration = 8
# Use at most the first `prompt_duration` seconds of the file as the prompt.
melody = melody[..., :int(prompt_duration * sample_rate)]
output = model.generate_continuation(melody, prompt_sample_rate=sample_rate, progress=True)
# The continuation is produced at the model's own rate, not the prompt file's
# rate; playing it at `sample_rate` gives the wrong speed/pitch when they differ.
display_audio(output, sample_rate=model.sample_rate)

In [None]:
def apply_augmentation_and_generate_output(i, model, prompt_duration=8):
  """Augment the i-th aligned MIDI file and have `model` continue each version.

  Renders the original plus one pitch-, volume- and tempo-augmented version of
  aligned_midi_files[i] to WAV. Each WAV (at most `prompt_duration` seconds of
  it) is used as a prompt for model.generate_continuation. Each result is
  written under generated/ as
  <file_name>_<augmentation>_output[_<batch index>] via audio_write. The batch
  index suffix is only added when the model returns more than one waveform.

  Returns:
    A list of whatever audio_write returned for each written file.
  """
  midi_file = aligned_midi_files[i]
  file_name = f'file_{i}'
  original_wav_path = save_as_wav(midi_file, file_name)
  pm = pretty_midi.PrettyMIDI(midi_file)

  # Each augmentation gets its own deep copy so their edits don't compound.
  _, pitch_wav_path, pitch_shift = augment_pitch(copy.deepcopy(pm), pitch_dir, file_name)
  _, vol_wav_path, vol_shift = augment_vol(copy.deepcopy(pm), vol_dir, file_name)
  _, tempo_wav_path, tempo_shift = augment_tempo(copy.deepcopy(pm), tempo_dir, file_name)

  file_paths = {
      'original': original_wav_path,
      f'pitch_{pitch_shift}': pitch_wav_path,
      f'vol_{vol_shift}': vol_wav_path,
      f'tempo_{tempo_shift}': tempo_wav_path
  }

  written = []
  for augmentation, wav_path in file_paths.items():
    melody, sample_rate = torchaudio.load(wav_path)
    melody = melody[..., :int(prompt_duration * sample_rate)]
    output = model.generate_continuation(melody, prompt_sample_rate=sample_rate, progress=True)

    stem = f'generated/{file_name}_{augmentation}_output'
    for idx, one_wav in enumerate(output):
      # Give each batch item its own file; with a single item keep the plain stem.
      item_stem = stem if len(output) == 1 else f'{stem}_{idx}'
      written.append(
          audio_write(item_stem, one_wav.cpu(), model.sample_rate, strategy='loudness'))
  return written


In [None]:
for i in range(50):
  # Print a progress line for every fifth file.
  if not i % 5:
    print(f'Augmenting and Generating file {i}.')
  apply_augmentation_and_generate_output(i, model)

In [None]:
# Zip the augmented and generated folders so they can be downloaded.
import locale
# Override locale.getpreferredencoding to always report UTF-8.
# NOTE(review): presumably a Colab workaround for `!` shell commands failing
# with a locale/encoding error after the package installs above; confirm.
locale.getpreferredencoding = lambda: "UTF-8"

# NOTE(review): the /content paths assume a Colab runtime whose working
# directory is /content (the folders were created with relative paths).
!zip -r /content/augmented.zip /content/augmented
!zip -r /content/generated.zip /content/generated