In [10]:
import json

import tqdm
!pip install torch torchvision torchaudio librosa numpy matplotlib tqdm transformers mido

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [38]:
import torch.utils.data
import os
from mido import MidiFile


def find_all_mid_files(directory):
    mid_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith('.mid'):  # 检查是否为 .mid 文件
                mid_files.append(os.path.join(root, file))
    return mid_files


def midi_to_sequence(midi_path):
    midi = MidiFile(midi_path)
    sequence = []
    for track in midi.tracks:
        for msg in track:
            if msg.is_meta:
                continue
            if msg.type == 'note_on' and msg.velocity > 0:
                sequence.append(f"Note_On_{msg.note}")
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                sequence.append(f"Note_Off_{msg.note}")
            elif msg.type == 'control_change':
                sequence.append(f"Control_{msg.control}_{msg.value}")
            elif msg.type == 'program_change':
                sequence.append(f"Program_{msg.program}")
            elif msg.type == 'time_signature':
                sequence.append(f"Time_Signature_{msg.numerator}/{msg.denominator}")
            elif msg.type == 'set_tempo':
                sequence.append(f"Tempo_{msg.tempo}")
            else:
                raise Exception(f"Unknown message type: {msg.type}")
    return sequence


midi_files = find_all_mid_files("./dataset/nesmdb/nesmdb_midi")
_raw_sequence = dict()


def seq_of(file):
    if len(_raw_sequence) == 0:
        for _file in tqdm.tqdm(midi_files, desc="Building sequences"):
            _raw_sequence[_file] = midi_to_sequence(_file)
    return _raw_sequence[file]


seq_of(midi_files[0])

Building sequences: 100%|██████████| 5278/5278 [01:17<00:00, 68.45it/s] 


['Program_80',
 'Control_12_2',
 'Note_On_69',
 'Control_11_10',
 'Note_Off_69',
 'Note_On_71',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_71',
 'Note_On_72',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_72',
 'Note_On_79',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_79',
 'Note_On_78',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_78',
 'Note_On_74',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_74',
 'Note_On_77',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_77',
 'Note_On_71',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_71',
 'Note_On_76',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_76',
 'Note_On_69',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_69',
 'Note_On_71',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_71',
 'Note_On_72',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_72',
 'Note_On_79',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_79',
 'Note_On_78',
 'Control_11_12',
 'Control_11_10',
 'Note_Off_78',
 'Note_On_74',
 'Control_11_12',
 'Control_11_10

In [41]:
def build_vocab():
    vocab = set()
    vocab.add("<PAD>")
    vocab.add("<BOS>")
    vocab.add("<EOS>")
    for file in tqdm.tqdm(midi_files, desc="Building vocabulary"):
        for note in seq_of(file):
            vocab.add(note)
    import json
    V_dict = {token: idx for idx, token in enumerate(sorted(vocab))}
    with open("vocab.json", "w") as f:
        json.dump(V_dict, f)


def load_vocab():
    with open("vocab.json", "r") as f:
        return json.load(f)


build_vocab()
V_dict = load_vocab()
V_dict["<PAD>"]

Building vocabulary: 100%|██████████| 5278/5278 [00:01<00:00, 3183.14it/s]


1

In [71]:
def cal_avg_len():
    sum = 0
    for file in midi_files:
        sum += len(seq_of(file))
    return int(sum / len(midi_files))


# S_len = cal_avg_len()
S_len = 2048
print(f"每个词条的平均长度是：{cal_avg_len()}，取S_len为{S_len}", )

每个词条的平均长度是：2947，取S_len为2048


In [73]:
def to_std_seq_vector(seq, s_len=S_len):
    arr = [V_dict[a] for a in seq]
    arr.insert(0, V_dict["<BOS>"])
    arr = arr[:s_len - 1]
    if len(arr) < s_len:
        arr.append(V_dict["<EOS>"])
    while len(arr) < s_len:
        arr.append(V_dict["<PAD>"])
    return arr


class NesMusicDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(NesMusicDataset, self).__init__()

    def __len__(self):
        return len(midi_files)

    def __getitem__(self, idx):
        seq = seq_of(midi_files[idx])
        seq = to_std_seq_vector(seq)
        input_seq = seq[:-1]
        target_seq = seq[1:]
        return torch.tensor(input_seq), torch.tensor(target_seq)


dataset = NesMusicDataset()
len(to_std_seq_vector(["<PAD>"]))
# len(dataset[0][])
# std_seq_of()
# dataset[0][:10]
# V_dict["<BOS>"]

2048

In [84]:
import torch.nn as nn


class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_len, pad_token):
        super(MusicTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token)
        self.positional_encoding = nn.Parameter(torch.zeros(max_len, embed_dim))
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=2048,
        )
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        self.pad_token = pad_token

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Embedding + Positional Encoding
        src_emb = self.embedding(src) + self.positional_encoding[:src.size(1), :]
        tgt_emb = self.embedding(tgt) + self.positional_encoding[:tgt.size(1), :]

        # Transformer
        output = self.transformer(
            src_emb.permute(1, 0, 2),  # 转换为 [seq_len, batch_size, embed_dim]
            tgt_emb.permute(1, 0, 2),  # 转换为 [seq_len, batch_size, embed_dim]
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        return self.fc_out(output.permute(1, 0, 2))  # 转回 [batch_size, seq_len, vocab_size]

In [90]:
# 模型参数
embed_dim = 64
num_heads = 8
num_layers = 4
max_len = 2048  # 填充后的最大序列长度
pad_token = V_dict["<PAD>"]
vocab_size = len(V_dict)
device = "cpu"
if torch.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"

# 初始化模型
model = MusicTransformer(vocab_size, embed_dim, num_heads, num_layers, max_len, pad_token).to(device)

# 忽略 PAD token 的损失
criterion = nn.CrossEntropyLoss(ignore_index=pad_token).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [92]:
from torch.utils.data import DataLoader

# 训练循环
num_epochs = 10
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_inputs, batch_targets in tqdm.tqdm(dataloader, desc=f"[{epoch}]Training"):
        optimizer.zero_grad()

        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # 准备目标序列
        tgt_input = batch_inputs
        tgt_output = batch_targets

        # 注意力掩码
        src_padding_mask = (batch_inputs == pad_token)  # 忽略输入的 PAD
        tgt_padding_mask = (tgt_input == pad_token)    # 忽略目标的 PAD
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(batch_inputs.device)

        # 前向传播
        outputs = model(batch_inputs, tgt_input, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)

        # 计算损失
        loss = criterion(outputs.view(-1, vocab_size), tgt_output.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss:.4f}")

[0]Training:   0%|          | 1/2639 [01:12<53:22:08, 72.83s/it]


KeyboardInterrupt: 