# (一)准备工作

In [None]:
# Download the project code from GitHub.
!git clone https://github.com/zmx110110/Multi-Music-Transformer

In [None]:
# Change into the repository root; later cells use paths under
# /content/Multi-Music-Transformer.
%cd /content/Multi-Music-Transformer

# (二)安装运行环境

In [None]:
#@title nvidia-smi gpu check  (查看显卡情况)
# Show the GPU(s) visible to this runtime. The model cell calls model.cuda(),
# so a CUDA GPU is required.
!nvidia-smi

In [None]:
#@title Install all dependencies (安装相关依赖,在Colab中每次运行都需要从新安装)

!pip install einops
!pip install torch
!pip install torch-summary

!pip install tqdm
!pip install matplotlib

!apt install fluidsynth  # Pip does not work for some reason. Only apt works
!pip install midi2audio

In [None]:
#@title Import all needed modules(导入需要的第三方库)

print('Loading needed modules. Please wait...')
import os
import random
import copy
import math
from collections import OrderedDict

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

import torch
from torchsummary import summary

print('Loading core modules...')
# Path used by the upstream Perceiver-Music-Transformer project, kept for reference:
# os.chdir('/content/Perceiver-Music-Transformer')
# This fork is cloned as Multi-Music-Transformer, so the new path is needed here.
# NOTE(review): presumably the chdir is what lets the project-local modules
# below (TMIDIX, zmx_ar_pytorch, autoregressive_wrapper) be found - confirm.
os.chdir('/content/Multi-Music-Transformer')
import TMIDIX

from zmx_ar_pytorch import PerceiverAR
from autoregressive_wrapper import AutoregressiveWrapper

from midi2audio import FluidSynth
from IPython.display import Audio, display

# Make sure the working directory is the repository root after the imports.
os.chdir('/content/Multi-Music-Transformer')
print('第三方库导入完成!')

# (三)下载训练好的权重

In [None]:
# Needed only on the first run (about 1.3 GB in total); later runs can skip it.
# The weights must have been uploaded to cloud storage beforehand.
# NOTE(review): the model cell expects the checkpoint at
# /content/Multi-Music-Transformer/Best-Model.pth - confirm the downloaded
# file name matches.
!gdown https://drive.google.com/uc\?id\=161z8svINvt_ShgKSBVLug5apjPT3Hixs

In [None]:
#@title Load/Reload the model (设置好预训练权重的路径) { vertical-output: true, form-width: "400px" }

full_path_to_model_checkpoint = "/content/Multi-Music-Transformer/Best-Model.pth" #@param {type:"string"}

print('Loading the model...')
# Load model

# constants
# Total context length, and the length of the prefix passed to the model as
# cross_attn_seq_len (the last 1024 tokens are outside that prefix).
SEQ_LEN = 8192 * 4 # 32k
PREFIX_SEQ_LEN = SEQ_LEN - 1024

model = PerceiverAR(
    num_tokens = 512,
    dim = 1024,
    depth = 24,
    heads = 16,
    dim_head = 64,
    cross_attn_dropout = 0.5,
    max_seq_len = SEQ_LEN,
    cross_attn_seq_len = PREFIX_SEQ_LEN
)
model = AutoregressiveWrapper(model)
model.cuda()

# Load the tensors onto the CPU. load_state_dict copies them into the model's
# CUDA parameters anyway, so this avoids a second ~1.3 GB copy on the GPU if
# the checkpoint was saved from GPU tensors. It also works for checkpoints
# saved on another device.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load(full_path_to_model_checkpoint, map_location='cpu')

model.load_state_dict(state_dict)

model.eval()

print('Done!')

# Model stats
summary(model)

# (四)进行推理

In [None]:
#@title Load Seed/Custom MIDI (传入一段引导音乐 , 并进行数据预处理 ) { vertical-output: true, form-width: "400px", display-mode: "both" }
full_path_to_custom_MIDI_file = "/content/Multi-Music-Transformer/Input-Midi-1.mid" #@param {type:"string"}

print('Loading custom MIDI file...')
# Read the whole MIDI file and parse it into a TMIDIX score. The `with` block
# closes the file handle right away instead of leaving it to the garbage
# collector.
with open(full_path_to_custom_MIDI_file, 'rb') as midi_file:
    score = TMIDIX.midi2ms_score(midi_file.read())

# 'note' and 'patch_change' events gathered from all tracks (filled in below).
events_matrix = []

# Index of the first score entry to read as a track; score[0] is skipped.
# NOTE(review): in the MIDI.py-style score format score[0] is presumably the
# ticks-per-beat value rather than a track - TODO confirm for TMIDIX.
itrack = 1

#==================================================

# Memories augmentator

def augment(inputs, num_chunks=64, max_shift=11):
    """Build a transposed, randomly re-chunked "memory" token stream.

    ``inputs`` is a flat token list in the layout produced by the seed cell:
    4 tokens per note, ``[chan_vel, time+128, dur+256, pitch+384]``.

    For each shift in ``1..max_shift``, every non-drum note
    (``chan_vel // 11 != 9``) is transposed up and down by that many
    semitones. One random contiguous slice of each transposed copy is kept.
    ``num_chunks`` of those slices are then drawn at random (with replacement)
    and concatenated.

    Args:
        inputs: flat list of int tokens, 4 per note. A trailing partial note
            is ignored.
        num_chunks: number of random slices to concatenate (default 64).
        max_shift: largest transposition in semitones (default 11).

    Returns:
        A flat list of int tokens in the same 4-tokens-per-note layout. Every
        transposed pitch token is kept inside the band 385..511.
    """
    # Pitch band used by the encoder (pitch 1..127 stored as pitch + 384).
    # Without clamping, +shift could exceed 511 (outside the model's 512-token
    # vocabulary) and -shift could fall into the duration band (< 385).
    min_pitch_tok, max_pitch_tok = 385, 511

    # Whole notes only: a trailing partial note would raise IndexError on [3].
    usable = len(inputs) - len(inputs) % 4
    notes = [inputs[j:j + 4] for j in range(0, usable, 4)]

    candidates = []
    for shift in range(1, max_shift + 1):
        notes_up = []
        notes_down = []
        for note in notes:
            if note[0] // 11 != 9:
                up = list(note)
                down = list(note)
                up[3] = min(max_pitch_tok, note[3] + shift)
                down[3] = max(min_pitch_tok, note[3] - shift)
            else:
                # Drums (channel 9) are never transposed.
                up = down = note
            notes_up.append(up)
            notes_down.append(down)

        # Keep a random slice [start, stop): start in the first half of the
        # sequence, stop in the second half.
        for seq in (notes_up, notes_down):
            half = len(seq) // 2
            candidates.append(seq[random.randint(0, half):random.randint(half, len(seq))])

    chunks = [random.choice(candidates) for _ in range(num_chunks)]
    return [token for chunk in chunks for note in chunk for token in note]

#==================================================


# Current patch (program) number of each of the 16 MIDI channels. Updated from
# 'patch_change' events as they are replayed in time order below.
patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# Instrument groups. A non-drum note whose current patch is in patch_map[k] is
# moved to model channel k. Entry 9 holds the impossible patch -1, so no
# melodic instrument is ever mapped onto the drum channel (9).
patch_map = [[0, 1, 2, 3, 4, 5, 6, 7], # Piano
              [24, 25, 26, 27, 28, 29, 30], # Guitar
              [32, 33, 34, 35, 36, 37, 38, 39], # Bass
              [40, 41], # Violin
              [42, 43], # Cello
              [46], # Harp
              [56, 57, 58, 59, 60], # Trumpet
              [71, 72], # Clarinet
              [73, 74, 75], # Flute
              [-1], # Fake Drums
              [52, 53], # Choir
              [16, 17, 18, 19, 20] # Organ
            ]

# Gather note and patch-change events from every track, starting at
# score[itrack]. Event layouts, as indexed below:
#   ['note', start, duration, channel, pitch, velocity]
#   ['patch_change', time, channel, patch]
while itrack < len(score):
    for event in score[itrack]:
        if event[0] == 'note' or event[0] == 'patch_change':
            events_matrix.append(event)
    itrack += 1

events_matrix.sort(key=lambda x: x[1])

events_matrix1 = []
for event in events_matrix:
    if event[0] == 'patch_change':
        patches[event[2]] = event[3]

    if event[0] == 'note':
        # Append the channel's current patch as event[6] (mutates the event).
        event.extend([patches[event[3]]])
        once = False

        for group_index, group in enumerate(patch_map):
            if event[6] in group and event[3] != 9: # Except the drums
                event[3] = group_index
                once = True

        if not once and event[3] != 9: # Except the drums
            event[3] = 0 # All other instruments/patches channel
            event[5] = max(80, event[5])

        # After remapping, every note is on channel 0-11 (a patch_map index,
        # 0, or the drum channel 9), so this filter is only a safety net.
        if event[3] < 12:
            events_matrix1.append(event)

# Sort by start time, then by channel.
events_matrix1.sort(key=lambda x: (x[1], x[3]))

# Quantize timings: start times in units of 16 ticks, durations in units of 32.
for e in events_matrix1:
    e[1] = int(e[1] / 16)
    e[2] = int(e[2] / 32)

# The encoder below starts from events_matrix1[0], so at least one note is needed.
if not events_matrix1:
    raise ValueError('The seed MIDI file contains no notes to use as a prompt.')

# Final processing. Each note becomes 4 tokens in disjoint bands:
#   chan_vel    = channel * 11 + velocity // 19   -> 1..127
#   time + 128  (delta start, 0..127)              -> 128..255
#   dur + 256   (1..127)                           -> 257..383
#   pitch + 384 (1..127)                           -> 385..511

inputs = []          # full token stream (continuation / inpainting prompt)

melody = []          # notes with a non-zero delta time, pitch folded to >= 60

melody_chords = []   # raw [time, dur, cha, ptc, vel] per note

pe = events_matrix1[0]
for e in events_matrix1:

    time = max(0, min(127, e[1]-pe[1]))
    dur = max(1, min(127, e[2]))
    cha = max(0, min(11, e[3]))
    ptc = max(1, min(127, e[4]))
    vel = max(19, min(127, e[5]))

    div_vel = int(vel / 19)

    chan_vel = (cha * 11) + div_vel

    # Continuation / Inpainting
    inputs.extend([chan_vel, time+128, dur+256, ptc+384])

    # Melody Orchestration: pitches below 60 are folded into 60..71. A
    # separate variable keeps this folding out of melody_chords below.
    if time != 0:
        mel_ptc = ptc
        if mel_ptc < 60:
            mel_ptc = (mel_ptc % 12) + 60

        melody.extend([div_vel, time+128, dur+256, mel_ptc+384])

    # For future development
    melody_chords.append([time, dur, cha, ptc, vel])

    pe = e

# =================================

# Decode the seed token stream back into notes, write it to a MIDI file, then
# render and plot it. This checks that the tokenization above round-trips.
# Each note is 4 tokens: [chan_vel, time+128, dur+256, pitch+384].

out1 = inputs

fname = '/content/Perceiver-Music-Composition'
song_f = []

if len(out1) != 0:

    song = out1
    time = 0
    dur = 0
    vel = 0
    pitch = 0
    channel = 0
    son = []
    song1 = []

    # Split the stream into notes: a token <= 127 (chan_vel) starts a new note
    # and larger tokens extend the current one. Only complete 4-token notes
    # are kept.
    for s in song:
        if s > 127:
            son.append(s)
        else:
            if len(son) == 4:
                song1.append(son)
            son = [s]
    # A note is only stored when the next one starts, so flush the final note
    # here; otherwise the last note of every song would be lost.
    if len(son) == 4:
        song1.append(son)

    for s in song1:
        # Accept only notes whose tokens sit in their own bands. A token in the
        # wrong slot would otherwise produce huge time jumps or durations.
        if s[0] > 0 and 128 <= s[1] < 256 and 256 < s[2] < 384 and 384 < s[3] < 512:

            channel = s[0] // 11

            # chan_vel % 11 can reach 10 in arbitrary token streams (the encoder
            # only produces 1..6), so clamp to the MIDI maximum of 127.
            vel = min(127, (s[0] % 11) * 19)

            time += (s[1] - 128) * 16   # undo the /16 start-time quantization

            dur = (s[2] - 256) * 32     # undo the /32 duration quantization

            pitch = s[3] - 384

            song_f.append(['note', time, dur, channel, pitch, vel])

if song_f:
    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Perceiver',
                                                        output_file_name = fname,
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

    print('Displaying resulting composition...')

    colors = ['red', 'yellow', 'green', 'cyan', 'blue', 'pink', 'orange', 'purple', 'gray', 'white', 'gold', 'silver']

    # Onset / 1000, pitch and channel colour of every note.
    x = [s[1] / 1000 for s in song_f]
    y = [s[4] for s in song_f]
    c = [colors[s[3]] for s in song_f]

    # NOTE(review): assumes the converter writes fname + '.mid' - TODO confirm in TMIDIX.
    FluidSynth("/usr/share/sounds/sf2/FluidR3_GM.sf2", 16000).midi_to_audio(str(fname + '.mid'), str(fname + '.wav'))
    display(Audio(str(fname + '.wav'), rate=16000))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.set_title(fname)
    ax.set_facecolor('black')
    ax.scatter(x, y, c=c)
    ax.set_xlabel("Time")
    ax.set_ylabel("Pitch")
    plt.show()
else:
    # Nothing decodable: skip writing and rendering instead of showing a stale
    # or missing file from an earlier run.
    print('No notes to render.')

In [None]:
# Keep the rendered seed (pre-processed input) under its own name. The next
# cells write to the same /content/Perceiver-Music-Composition.* paths.
!mv  /content/Perceiver-Music-Composition.mid   /content/Temp_output.mid
!mv  /content/Perceiver-Music-Composition.wav   /content/Temp_output.wav

# (五)扩展运用 

In [None]:
#@title Single Continuation Block Generator ( 模仿input进行生成 ) { form-width: "400px", display-mode: "code" }

##@markdown NOTE: Play with the settings to get different results
number_of_prime_tokens = 512 #@param {type:"slider", min:128, max:512, step:16}
number_of_tokens_to_generate = 512 #@param {type:"slider", min:64, max:512, step:32}
temperature = 0.8 #@param {type:"slider", min:0.1, max:1, step:0.1}

#===================================================================
print('=' * 70)
print('Perceiver Music Model Continuation Generator')
print('=' * 70)

print('Generation settings:')
print('=' * 70)
print('Number of prime tokens:', number_of_prime_tokens)
print('Number of tokens to generate:', number_of_tokens_to_generate)
print('Model temperature:', temperature)

print('=' * 70)
print('Generating...')

if len(inputs) == 0:
    raise ValueError('No seed tokens: run the "Load Seed/Custom MIDI" cell first.')

# inp = augment(inputs)

# Prime tokens that the generated block continues from.
prime = inputs[:number_of_prime_tokens]

# Space kept free at the end of the SEQ_LEN window for the new tokens. 512
# covers the whole slider range; larger requests are honoured so the context
# plus the generated block never exceeds SEQ_LEN.
generation_room = max(512, number_of_tokens_to_generate)

# Context window: the seed tiled up to SEQ_LEN tokens, trimmed at the front to
# make room, and ending with the prime tokens.
inp = (inputs * math.ceil(SEQ_LEN / len(inputs)))[:SEQ_LEN]
inp = inp[generation_room + len(prime):] + prime

inp1 = torch.LongTensor(inp).cuda()

# Inference only, so no autograd graph is needed (harmless if generate()
# already disables gradients itself).
with torch.no_grad():
    out = model.generate(inp1[None, ...],
                         number_of_tokens_to_generate,
                         temperature=temperature)

# Presumably only the new tokens: the decode step prepends the prime tokens itself.
out1 = out.cpu().tolist()[0]

# Decode prime tokens + generated tokens into notes, write a MIDI file, then
# render and plot it. Each note is 4 tokens: [chan_vel, time+128, dur+256, pitch+384].

fname = '/content/Perceiver-Music-Composition'
song_f = []

if len(out1) != 0:

    song = inputs[:number_of_prime_tokens] + out1
    time = 0
    dur = 0
    vel = 0
    pitch = 0
    channel = 0
    son = []
    song1 = []

    # Split the stream into notes: a token <= 127 (chan_vel) starts a new note
    # and larger tokens extend the current one. Only complete 4-token notes
    # are kept.
    for s in song:
        if s > 127:
            son.append(s)
        else:
            if len(son) == 4:
                song1.append(son)
            son = [s]
    # A note is only stored when the next one starts, so flush the final note
    # here; otherwise the last generated note would be lost.
    if len(son) == 4:
        song1.append(son)

    for s in song1:
        # Accept only notes whose tokens sit in their own bands. A model token
        # in the wrong slot would otherwise produce huge time jumps or durations.
        if s[0] > 0 and 128 <= s[1] < 256 and 256 < s[2] < 384 and 384 < s[3] < 512:

            channel = s[0] // 11

            # chan_vel % 11 can reach 10 in model output (the encoder only
            # produces 1..6), so clamp to the MIDI maximum of 127.
            vel = min(127, (s[0] % 11) * 19)

            time += (s[1] - 128) * 16   # undo the /16 start-time quantization

            dur = (s[2] - 256) * 32     # undo the /32 duration quantization

            pitch = s[3] - 384

            song_f.append(['note', time, dur, channel, pitch, vel])

if song_f:
    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Perceiver',
                                                        output_file_name = fname,
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

    print('Displaying resulting composition...')

    colors = ['red', 'yellow', 'green', 'cyan', 'blue', 'pink', 'orange', 'purple', 'gray', 'white', 'gold', 'silver']

    # Onset / 1000, pitch and channel colour of every note.
    x = [s[1] / 1000 for s in song_f]
    y = [s[4] for s in song_f]
    c = [colors[s[3]] for s in song_f]

    # NOTE(review): assumes the converter writes fname + '.mid' - TODO confirm in TMIDIX.
    FluidSynth("/usr/share/sounds/sf2/FluidR3_GM.sf2", 16000).midi_to_audio(str(fname + '.mid'), str(fname + '.wav'))
    display(Audio(str(fname + '.wav'), rate=16000))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.set_title(fname)
    ax.set_facecolor('black')
    ax.scatter(x, y, c=c)
    ax.set_xlabel("Time")
    ax.set_ylabel("Pitch")
    plt.show()
else:
    # Nothing decodable: skip writing and rendering instead of showing a stale
    # or missing file from an earlier run.
    print('No notes to render.')

In [None]:
# Keep the intermediate (single-block continuation) result under its own name
# before the next cell writes to /content/Perceiver-Music-Composition.* again.
!mv  /content/Perceiver-Music-Composition.mid   /content/Middle_output.mid
!mv  /content/Perceiver-Music-Composition.wav   /content/Middle_output.wav

In [None]:
#@title Auto-Continue Custom MIDI ( AI 自动生成 ) { form-width: "400px" }

number_of_continuation_notes = 270 #@param {type:"slider", min:10, max:500, step:10}
number_of_prime_tokens = 512 #@param {type:"slider", min:128, max:512, step:16}
temperature = 0.8 #@param {type:"slider", min:0.1, max:1, step:0.1}

#===================================================================
print('=' * 70)
print('Perceiver Music Model Auto-Continuation Generator')
print('=' * 70)

print('Generation settings:')
print('=' * 70)
print('Number of continuation notes:', number_of_continuation_notes)
print('Number of prime tokens:', number_of_prime_tokens)
print('Model temperature:', temperature)

print('=' * 70)
print('Generating...')

if len(inputs) == 0:
    raise ValueError('No seed tokens: run the "Load Seed/Custom MIDI" cell first.')

# Tokens generated per step: one note = [chan_vel, time, dur, pitch].
tokens_per_note = 4

# Space kept free at the end of the SEQ_LEN window.
generation_room = 512

# The seed tiled to exactly SEQ_LEN tokens. It does not change between steps,
# so build it once instead of on every iteration.
tiled_inputs = (inputs * math.ceil(SEQ_LEN / len(inputs)))[:SEQ_LEN]

# The continuation so far, starting with the prime tokens.
out2 = list(inputs[:number_of_prime_tokens])

# To condition on augmented memories instead, build tiled_inputs from augment(inputs).

for i in tqdm(range(number_of_continuation_notes)):

    # Context = tail of the tiled seed followed by everything generated so far.
    # Its length stays SEQ_LEN - generation_room while that tail is non-empty.
    inp = tiled_inputs[generation_room + len(out2):] + out2

    inp = torch.LongTensor(inp).cuda()

    # Inference only (harmless if generate() already disables gradients).
    with torch.no_grad():
        out = model.generate(inp[None, ...],
                             tokens_per_note,
                             temperature=temperature)

    out1 = out.cpu().tolist()[0]
    out2.extend(out1)

# Decode the full auto-continued token stream into notes, write a MIDI file,
# then render and plot it. Each note is 4 tokens: [chan_vel, time+128, dur+256, pitch+384].

fname = '/content/Perceiver-Music-Composition'
song_f = []

if len(out2) != 0:

    song = out2
    time = 0
    dur = 0
    vel = 0
    pitch = 0
    channel = 0
    son = []
    song1 = []

    # Split the stream into notes: a token <= 127 (chan_vel) starts a new note
    # and larger tokens extend the current one. Only complete 4-token notes
    # are kept.
    for s in song:
        if s > 127:
            son.append(s)
        else:
            if len(son) == 4:
                song1.append(son)
            son = [s]
    # A note is only stored when the next one starts, so flush the final note
    # here; otherwise the last generated note would be lost.
    if len(son) == 4:
        song1.append(son)

    for s in song1:
        # Accept only notes whose tokens sit in their own bands. A model token
        # in the wrong slot would otherwise produce huge time jumps or durations.
        if s[0] > 0 and 128 <= s[1] < 256 and 256 < s[2] < 384 and 384 < s[3] < 512:

            channel = s[0] // 11

            # chan_vel % 11 can reach 10 in model output (the encoder only
            # produces 1..6), so clamp to the MIDI maximum of 127.
            vel = min(127, (s[0] % 11) * 19)

            time += (s[1] - 128) * 16   # undo the /16 start-time quantization

            dur = (s[2] - 256) * 32     # undo the /32 duration quantization

            pitch = s[3] - 384

            song_f.append(['note', time, dur, channel, pitch, vel])

if song_f:
    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Perceiver',
                                                        output_file_name = fname,
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

    print('Displaying resulting composition...')

    colors = ['red', 'yellow', 'green', 'cyan', 'blue', 'pink', 'orange', 'purple', 'gray', 'white', 'gold', 'silver']

    # Onset / 1000, pitch and channel colour of every note.
    x = [s[1] / 1000 for s in song_f]
    y = [s[4] for s in song_f]
    c = [colors[s[3]] for s in song_f]

    # NOTE(review): assumes the converter writes fname + '.mid' - TODO confirm in TMIDIX.
    FluidSynth("/usr/share/sounds/sf2/FluidR3_GM.sf2", 16000).midi_to_audio(str(fname + '.mid'), str(fname + '.wav'))
    display(Audio(str(fname + '.wav'), rate=16000))

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.set_title(fname)
    ax.set_facecolor('black')
    ax.scatter(x, y, c=c)
    ax.set_xlabel("Time")
    ax.set_ylabel("Pitch")
    plt.show()
else:
    # Nothing decodable: skip writing and rendering instead of showing a stale
    # or missing file from an earlier run.
    print('No notes to render.')

In [None]:
# Save the final generated music under its own name.
!mv  /content/Perceiver-Music-Composition.mid   /content/Final_output.mid
!mv  /content/Perceiver-Music-Composition.wav   /content/Final_output.wav

# Congratulations! 您运行成功!