In [None]:
import os
import pickle
import difflib
from collections import defaultdict

import numpy as np
import pandas as pd
import pretty_midi
from music21 import chord, stream, note
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence

In [None]:
# -------------------------------
# Functions CSV (R_Pretrain) -> progressions
# -------------------------------
def binvec_to_pcset(binvec):
    """Return the positions of ``binvec`` whose value equals 1.

    Positions are reported in ascending order; entries with any other value
    are ignored.
    """
    pcset = []
    for position, flag in enumerate(binvec):
        if flag == 1:
            pcset.append(position)
    return pcset

def parse_binvec(cell):
    """Parse one CSV cell into a list of ints.

    String cells may be written as ``"[1 0 0 1]"`` (whitespace separated),
    ``"[1, 0, 0, 1]"`` (comma separated) or ``"[1. 0. 0. 1.]"`` (float
    formatted); brackets are optional. Any other value is converted with
    ``list(cell)``.

    Raises:
        ValueError: if a string entry is not a number.
    """
    if isinstance(cell, str):
        cleaned = cell.replace('[', '').replace(']', '').replace(',', ' ')
        values = []
        for tok in cleaned.split():
            try:
                values.append(int(tok))
            except ValueError:
                # Float-formatted entries such as "1." or "0.0".
                values.append(int(float(tok)))
        return values
    return list(cell)

def csv_to_chord_progressions(csv_path, chord_cols):
    """Read ``csv_path`` and turn every row into a chord progression.

    For each row, the cells named in ``chord_cols`` are parsed with
    ``parse_binvec`` and converted with ``binvec_to_pcset``. Empty pitch-class
    sets are dropped, and rows left with no chord at all are skipped.

    Returns:
        A list of progressions, each a list of pitch-class lists.
    """
    table = pd.read_csv(csv_path)
    progressions = []
    for _, record in table.iterrows():
        pcsets = (binvec_to_pcset(parse_binvec(record[name])) for name in chord_cols)
        progression = [pcs for pcs in pcsets if pcs]
        if progression:
            progressions.append(progression)
    return progressions

def transpose_progression(prog, semitones):
    """Shift every pitch class of every chord by ``semitones``, modulo 12."""
    shifted = []
    for chord in prog:
        shifted.append([(pitch_class + semitones) % 12 for pitch_class in chord])
    return shifted

def chord_to_token(chord):
    """Build the vocabulary token of a chord: its sorted members joined by '_'."""
    ordered = sorted(chord)
    return '_'.join(str(pc) for pc in ordered)

def build_vocab(progs):
    """Build the chord vocabulary of ``progs``.

    Ids 0, 1 and 2 are reserved for '<pad>', '<sos>' and '<eos>'. Chord tokens
    receive consecutive ids from 3 in order of first appearance.

    Returns:
        ``(token_to_id, id_to_token)`` dictionaries.
    """
    token_to_id = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    for progression in progs:
        for chord in progression:
            # Ids are dense from 0, so the current size is the next free id.
            token_to_id.setdefault(chord_to_token(chord), len(token_to_id))
    id_to_token = {idx: tok for tok, idx in token_to_id.items()}
    return token_to_id, id_to_token

def progression_to_token_seq(prog, token_to_id):
    """Encode each chord of ``prog`` as its id in ``token_to_id``.

    Raises:
        KeyError: if a chord token is missing from ``token_to_id``.
    """
    token_ids = []
    for ch in prog:
        token_ids.append(token_to_id[chord_to_token(ch)])
    return token_ids

# -------------------------------
# PRETRAIN from CSV (R_Pretrain)
# -------------------------------
CSV_PATH = "R_Pretrain.csv"
CHORD_COLS = [
    "Chord_1",
    "Chord_2",
    "Chord_3",
    "Chord_4",
    "Chord_5",
]

X_csv_raw = csv_to_chord_progressions(CSV_PATH, CHORD_COLS)
print(f"Total progresiones CSV (pretrain): {len(X_csv_raw)}")

# Transposition augmentation: every progression in all 12 keys.
X_pretrain_raw = [
    transpose_progression(prog, shift)
    for prog in X_csv_raw
    for shift in range(12)
]

# Vocabulary over the augmented progressions.
token_to_id_pretrain, id_to_token_pretrain = build_vocab(X_pretrain_raw)

# X is the plain token sequence; Y is the same sequence wrapped in
# <sos> (id 1) and <eos> (id 2), the ids reserved by build_vocab.
X_pretrain, Y_pretrain = [], []
for prog in X_pretrain_raw:
    token_seq = progression_to_token_seq(prog, token_to_id_pretrain)
    Y_pretrain.append(torch.tensor([1] + token_seq + [2], dtype=torch.long))
    X_pretrain.append(torch.tensor(token_seq, dtype=torch.long))

# -------------------------------
# TRAIN from MIDIs
# -------------------------------
# - TRAIN_DIR -> MIDIs X,Y
# -------------------------------
# Datasets and loaders
# -------------------------------
class PretrainDataset(Dataset):
    """Pairs ``X_data[i]`` with ``Y_data[i]``; length follows ``X_data``."""

    def __init__(self, X_data, Y_data):
        self.X, self.Y = X_data, Y_data

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.X[idx], self.Y[idx])

    @staticmethod
    def collate_fn(batch):
        """Right-pad inputs and targets of a batch with 0 (batch-first)."""
        inputs, targets = zip(*batch)
        padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
        padded_targets = pad_sequence(targets, batch_first=True, padding_value=0)
        return padded_inputs, padded_targets

pretrain_dataset = PretrainDataset(X_pretrain, Y_pretrain)
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=PretrainDataset.collate_fn,
)

print(f"Pretrain dataset size: {len(pretrain_dataset)}")

# Peek at one batch (if any) to check the padded shapes.
first_batch = next(iter(pretrain_loader), None)
if first_batch is not None:
    xb, yb = first_batch
    print("Pretrain batch shapes - X:", xb.shape, "Y:", yb.shape)


In [None]:
# ---------------- Config ----------------
USE_PRETRAIN_MIDI = False  # False: pretrain on the CSV (R_Pretrain); True: pretrain on the .mid files in PRETRAIN_DIR
USE_AUGMENTATION = False  # True: add every progression in all 12 transpositions
PRETRAIN_DIR = "PRE_TRAIN"  # folder scanned for pretraining MIDIs when USE_PRETRAIN_MIDI is True
TRAIN_DIR = "MIDI_TRAIN"  # folder holding the paired "<idx>_1_pf_org*" / "<idx>_2_pf_rhm*" MIDIs
TOLERANCE = 0.5  # onset grid step: note starts are rounded to multiples of this value
MIN_NOTES = 2  # onset groups with fewer notes than this are not treated as chords
BATCH_SIZE = 8  # batch size of the pretrain/train DataLoaders built in this cell

# ---------------- MIDIs ----------------
def midi_to_chords_train(midi_path, tolerance=TOLERANCE, min_notes=MIN_NOTES):
    """Extract a chord progression (pitch-class sets) from a MIDI file.

    Notes of all non-drum instruments are grouped by onset, after rounding
    each onset to the nearest multiple of ``tolerance``. Groups with fewer
    than ``min_notes`` notes are discarded; the rest are reduced to their
    sorted, de-duplicated pitch classes (``pitch % 12``).

    Args:
        midi_path: path of the MIDI file, passed to ``pretty_midi.PrettyMIDI``.
        tolerance: onset grid step (same unit as ``note.start``; presumably
            seconds as reported by pretty_midi -- confirm).
        min_notes: minimum number of notes (not distinct pitch classes) a
            group needs to count as a chord.

    Returns:
        List of pitch-class lists ordered by onset time.
    """
    pm = pretty_midi.PrettyMIDI(midi_path)

    pitches_by_onset = defaultdict(list)
    for instrument in pm.instruments:
        if instrument.is_drum:
            continue
        # `midi_note`, not `note`: avoids shadowing the music21 `note` import.
        for midi_note in instrument.notes:
            onset = round(midi_note.start / tolerance) * tolerance
            pitches_by_onset[onset].append(midi_note.pitch)

    chords_out = []
    for onset in sorted(pitches_by_onset):
        pitches = pitches_by_onset[onset]
        if len(pitches) < min_notes:
            continue
        # The pitch-class set needs no music21 object; the previous
        # chord.Chord(...) call was unused and both try/except branches
        # computed this same value.
        chords_out.append(sorted({pitch % 12 for pitch in pitches}))
    return chords_out

# ---------------- Other ----------------
def transpose_progression(prog, semitones):
    """Return ``prog`` with every pitch class moved by ``semitones`` (mod 12).

    NOTE(review): same definition as the one in the earlier CSV cell; this
    one rebinds the name when the cell runs.
    """
    def _shift(chord):
        return [(n + semitones) % 12 for n in chord]

    return list(map(_shift, prog))

def chord_to_token(chord):
    """Token of a chord: its sorted members joined with '_'.

    NOTE(review): same definition as the one in the earlier CSV cell.
    """
    parts = [str(member) for member in sorted(chord)]
    return '_'.join(parts)

def build_vocab(progs):
    """Assign an integer id to every distinct chord token in ``progs``.

    '<pad>', '<sos>' and '<eos>' take ids 0, 1 and 2; chord tokens follow
    from 3 in order of first appearance.

    NOTE(review): same definition as the one in the earlier CSV cell.

    Returns:
        ``(token_to_id, id_to_token)``.
    """
    token_to_id = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
    id_to_token = {0: '<pad>', 1: '<sos>', 2: '<eos>'}
    for prog in progs:
        for tok in map(chord_to_token, prog):
            if tok in token_to_id:
                continue
            new_id = len(token_to_id)  # ids are dense, so size == next id
            token_to_id[tok] = new_id
            id_to_token[new_id] = tok
    return token_to_id, id_to_token

def progression_to_token_seq(prog, token_to_id):
    """Map every chord of ``prog`` to its id; unknown chords raise KeyError.

    NOTE(review): same definition as the one in the earlier CSV cell.
    """
    tokens = map(chord_to_token, prog)
    return [token_to_id[tok] for tok in tokens]

# ---------------- Pretrain from CSV ----------------
# Read the CSV once; it is the pretraining source unless USE_PRETRAIN_MIDI is set.
X_csv_raw = csv_to_chord_progressions(CSV_PATH, CHORD_COLS)

# ---------------- Pretrain ----------------
if USE_PRETRAIN_MIDI:
    # sorted(): os.listdir returns names in arbitrary order, and vocabulary
    # ids depend on first-appearance order, so an unsorted listing would make
    # the pickled vocab differ between runs/machines.
    pretrain_files = sorted(
        f for f in os.listdir(PRETRAIN_DIR)
        if f.lower().endswith((".mid", ".midi"))
    )
    base_progressions = []
    for file_name in pretrain_files:
        chords = midi_to_chords_train(os.path.join(PRETRAIN_DIR, file_name))
        if chords:  # a MIDI with no detected chord would give an empty sequence
            base_progressions.append(chords)
else:
    base_progressions = X_csv_raw

X_pretrain_raw = []
for prog in base_progressions:
    if USE_AUGMENTATION:
        for shift in range(12):
            X_pretrain_raw.append(transpose_progression(prog, shift))
    else:
        X_pretrain_raw.append(prog)

token_to_id_pretrain, id_to_token_pretrain = build_vocab(X_pretrain_raw)

# Inputs are the bare token sequences; targets are wrapped in <sos>/<eos>.
sos_id = token_to_id_pretrain['<sos>']
eos_id = token_to_id_pretrain['<eos>']

X_pretrain = []
Y_pretrain = []
for prog in X_pretrain_raw:
    token_seq = progression_to_token_seq(prog, token_to_id_pretrain)
    X_pretrain.append(torch.tensor(token_seq, dtype=torch.long))
    Y_pretrain.append(torch.tensor([sos_id] + token_seq + [eos_id], dtype=torch.long))

# Persist the vocabulary next to the notebook for later decoding.
with open("token_to_id_pretrain.pkl", "wb") as f:
    pickle.dump(token_to_id_pretrain, f)
with open("id_to_token_pretrain.pkl", "wb") as f:
    pickle.dump(id_to_token_pretrain, f)


# ---------------- Superv-Train from MIDIs ----------------
# Pair "<idx>_1_pf_org*" (input) with "<idx>_2_pf_rhm*" (target) BY INDEX.
# Sorting the two lists independently and zipping them would silently
# misalign every following pair as soon as one counterpart is missing.
x_path_by_index = {}
y_path_by_index = {}

for file_name in sorted(os.listdir(TRAIN_DIR)):  # sorted for run-to-run determinism
    index = file_name.split("_")[0]
    if file_name.startswith(index + "_1_pf_org"):
        x_path_by_index[int(index)] = os.path.join(TRAIN_DIR, file_name)
    elif file_name.startswith(index + "_2_pf_rhm"):
        y_path_by_index[int(index)] = os.path.join(TRAIN_DIR, file_name)

paired_indices = sorted(x_path_by_index.keys() & y_path_by_index.keys())
unpaired_indices = sorted(x_path_by_index.keys() ^ y_path_by_index.keys())
if unpaired_indices:
    print(f"Skipping indices without both org/rhm files: {unpaired_indices}")

X_train_files = [(idx, x_path_by_index[idx]) for idx in paired_indices]
Y_train_files = [(idx, y_path_by_index[idx]) for idx in paired_indices]
assert len(X_train_files) == len(Y_train_files)

X_train_paths = [x[1] for x in X_train_files]
Y_train_paths = [y[1] for y in Y_train_files]

X_train_raw = [midi_to_chords_train(path) for path in X_train_paths]
Y_train_raw = [midi_to_chords_train(path) for path in Y_train_paths]

X_train_aug = []
Y_train_aug = []

for x_prog, y_prog in zip(X_train_raw, Y_train_raw):
    if USE_AUGMENTATION:
        # Input and target are transposed by the same shift to stay paired.
        for shift in range(12):
            X_train_aug.append(transpose_progression(x_prog, shift))
            Y_train_aug.append(transpose_progression(y_prog, shift))
    else:
        X_train_aug.append(x_prog)
        Y_train_aug.append(y_prog)

# One vocabulary shared by inputs and targets of the supervised set.
token_to_id_train, id_to_token_train = build_vocab(X_train_aug + Y_train_aug)

X_train = []
Y_train = []

for x_prog, y_prog in zip(X_train_aug, Y_train_aug):
    x_tokens = progression_to_token_seq(x_prog, token_to_id_train)
    y_tokens = progression_to_token_seq(y_prog, token_to_id_train)
    X_train.append(torch.tensor(x_tokens, dtype=torch.long))
    Y_train.append(torch.tensor(
        [token_to_id_train['<sos>']] + y_tokens + [token_to_id_train['<eos>']],
        dtype=torch.long,
    ))

# ---------------- Datasets ----------------
class PretrainDataset(Dataset):
    """Index-aligned (input, target) pairs for the pretraining phase.

    NOTE(review): same definition as the class in the earlier CSV cell.
    """

    def __init__(self, X_data, Y_data):
        self.X = X_data
        self.Y = Y_data

    def __len__(self):
        n_items = len(self.X)
        return n_items

    def __getitem__(self, idx):
        source, target = self.X[idx], self.Y[idx]
        return source, target

    @staticmethod
    def collate_fn(batch):
        """Zero-pad both sides of the batch to their longest sequence."""
        sources, targets = zip(*batch)
        return tuple(
            pad_sequence(group, batch_first=True, padding_value=0)
            for group in (sources, targets)
        )

class TrainDataset(Dataset):
    """Index-aligned (input, target) pairs for the supervised phase.

    NOTE(review): behaves exactly like ``PretrainDataset`` defined just above.
    """

    def __init__(self, X_data, Y_data):
        self.X, self.Y = X_data, Y_data

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.X[idx], self.Y[idx])

    @staticmethod
    def collate_fn(batch):
        """Zero-pad inputs and targets of a batch (batch-first)."""
        inputs, targets = zip(*batch)
        x_padded = pad_sequence(inputs, batch_first=True, padding_value=0)
        y_padded = pad_sequence(targets, batch_first=True, padding_value=0)
        return x_padded, y_padded

pretrain_dataset = PretrainDataset(X_pretrain, Y_pretrain)
train_dataset = TrainDataset(X_train, Y_train)

pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=PretrainDataset.collate_fn,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=TrainDataset.collate_fn,
)

print(f"Pretrain dataset size: {len(pretrain_dataset)}")
print(f"Train dataset size: {len(train_dataset)}")

# ---------------- Examples ----------------
# Decode the first pretraining pair back to chord tokens ("?" for unknown ids).
ej_idx = 0
x_tokens, y_tokens = pretrain_dataset[ej_idx]

x_chords = []
for tok in x_tokens.tolist():
    x_chords.append(id_to_token_pretrain.get(tok, "?"))
y_chords = []
for tok in y_tokens.tolist():
    y_chords.append(id_to_token_pretrain.get(tok, "?"))

print("ðŸ”¹  X (input):", x_chords)
print("ðŸ”¸  Y (target):", y_chords)


In [None]:

def chord_to_token(chord):
    """Join the sorted members of ``chord`` with '_' to form its token.

    NOTE(review): re-definition of the helper already defined in earlier cells.
    """
    pieces = []
    for member in sorted(chord):
        pieces.append(str(member))
    return '_'.join(pieces)

def fallback_token_str(chord_str, token_to_id):
    """Return the id of the vocabulary token most similar to ``chord_str``.

    Similarity is decided by ``difflib.get_close_matches``. When nothing is
    close enough, the id of '<pad>' is returned (0 if '<pad>' is absent).
    """
    closest = difflib.get_close_matches(chord_str, token_to_id.keys(), n=1)
    if not closest:
        return token_to_id.get('<pad>', 0)
    return token_to_id[closest[0]]

def progression_to_token_seq(prog, token_to_id):
    """Encode ``prog`` as token ids, tolerating out-of-vocabulary chords.

    A chord whose token is missing from ``token_to_id`` is replaced by the id
    chosen by ``fallback_token_str`` (closest known token, else '<pad>').
    """
    def _encode(chord):
        token = chord_to_token(chord)
        if token not in token_to_id:
            return fallback_token_str(token, token_to_id)
        return token_to_id[token]

    return [_encode(chord) for chord in prog]

class ChordPairDataset(Dataset):
    """Dataset of index-aligned (input, target) token tensors."""

    def __init__(self, X_data, Y_data):
        self.X_data, self.Y_data = X_data, Y_data

    def __len__(self):
        return len(self.X_data)

    def __getitem__(self, idx):
        return (self.X_data[idx], self.Y_data[idx])

    @staticmethod
    def collate_fn(batch):
        """Zero-pad inputs and targets of a batch to a common length each."""
        padded = []
        for group in zip(*batch):
            padded.append(pad_sequence(group, batch_first=True, padding_value=0))
        X_padded, Y_padded = padded
        return X_padded, Y_padded

# Encode the supervised pairs with the train vocabulary; targets are wrapped
# in <sos>/<eos>.
X_data, Y_data = [], []
for x_prog, y_prog in zip(X_train_aug, Y_train_aug):
    x_tokens = progression_to_token_seq(x_prog, token_to_id_train)
    y_tokens = [
        token_to_id_train['<sos>'],
        *progression_to_token_seq(y_prog, token_to_id_train),
        token_to_id_train['<eos>'],
    ]
    X_data.append(torch.tensor(x_tokens, dtype=torch.long))
    Y_data.append(torch.tensor(y_tokens, dtype=torch.long))

# 80/20 train/validation split with a fixed seed.
# NOTE(review): the training cells visible below build their own loaders from
# X_pretrain / X_train, so train_loader and val_loader from this cell do not
# appear to be consumed there -- confirm.
X_train_split, X_val_split, Y_train_split, Y_val_split = train_test_split(
    X_data,
    Y_data,
    test_size=0.2,
    random_state=42,
)

train_dataset = ChordPairDataset(X_train_split, Y_train_split)
val_dataset = ChordPairDataset(X_val_split, Y_val_split)

train_loader = DataLoader(
    train_dataset, batch_size=8, shuffle=True, collate_fn=ChordPairDataset.collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=8, shuffle=False, collate_fn=ChordPairDataset.collate_fn
)

# Show the shapes of one training batch, if there is one.
for X_batch, Y_batch in train_loader:
    for label, tensor in (("X batch shape:", X_batch), ("Y batch shape:", Y_batch)):
        print(label, tensor.shape)
    break

print("Total batches pretrain:", len(pretrain_loader))
print("Total batches train:", len(train_loader))

def token_seq_to_chords(token_seq, id_to_token):
    """Decode a sequence of token ids into chords (lists of ints).

    '<pad>', '<sos>' and '<eos>' are skipped, and so are ids missing from
    ``id_to_token`` and tokens that are not '_'-separated integers.

    Args:
        token_seq: iterable of ids; plain ints or 0-d tensors are accepted.
        id_to_token: mapping from int id to token string.
    """
    chords = []
    for idx in token_seq:
        # int(): tensor elements hash by identity and would never match the
        # integer keys of the vocabulary.
        token = id_to_token.get(int(idx))
        if token is None or token in ('<pad>', '<sos>', '<eos>'):
            continue
        try:
            chords.append([int(part) for part in token.split('_')])
        except ValueError:
            # Not a chord token (e.g. an '<unk>' placeholder): skip it.
            continue
    return chords


In [None]:
# Global vocabulary, seeded with the three special tokens and extended in
# place by construir_vocabulario.
token_to_id = {'<pad>': 0, '<sos>': 1, '<eos>': 2}
id_to_token = {idx: tok for tok, idx in token_to_id.items()}

def chord_to_token(chord):
    """Token of a chord: its sorted members joined with '_'.

    Tensors are accepted as well: a 0-d tensor is treated as a single-member
    chord, any other tensor is converted with ``tolist()``.
    """
    if isinstance(chord, torch.Tensor):
        chord = chord.tolist() if chord.ndim != 0 else [chord.item()]
    return '_'.join(map(str, sorted(chord)))

def construir_vocabulario(datasets):
    """Extend the global ``token_to_id`` / ``id_to_token`` in place.

    Every chord of every progression of every dataset is tokenised with
    ``chord_to_token``; unseen tokens receive the next free id (the current
    vocabulary size). Nothing is returned.
    """
    for data in datasets:
        for prog in data:
            for token in (chord_to_token(chord) for chord in prog):
                if token in token_to_id:
                    continue
                idx = len(token_to_id)
                token_to_id[token] = idx
                id_to_token[idx] = token

# Build ONE vocabulary from the raw pitch-class progressions.
# The encoded tensors (X_train, Y_train, X_pretrain, Y_pretrain) must not be
# used here: iterating them yields token *ids*, so the vocabulary would map
# ids to strings such as "17" instead of chords like "0_4_7", and it would
# mix two unrelated id spaces (token_to_id_pretrain vs token_to_id_train).
# X_pretrain_raw goes first so that pretraining chords keep exactly the ids
# they already have in token_to_id_pretrain (same first-appearance order,
# same start at 3); X_pretrain / Y_pretrain therefore stay valid as they are.
construir_vocabulario([X_pretrain_raw, X_train_aug, Y_train_aug])
vocab_size = len(token_to_id)

assert all(token_to_id.get(tok) == idx for tok, idx in token_to_id_pretrain.items()), \
    "shared vocabulary must extend the pretrain vocabulary without renumbering"

# Re-encode the supervised pairs with the shared vocabulary so that phase 1
# and phase 2 feed the model the same id for the same chord.
X_train = [
    torch.tensor([token_to_id[chord_to_token(ch)] for ch in prog], dtype=torch.long)
    for prog in X_train_aug
]
Y_train = [
    torch.tensor(
        [token_to_id['<sos>']]
        + [token_to_id[chord_to_token(ch)] for ch in prog]
        + [token_to_id['<eos>']],
        dtype=torch.long,
    )
    for prog in Y_train_aug
]

print(f"âœ… Vocab maked: {vocab_size} tokens")

In [None]:
class CVAE_LSTM(nn.Module):
    """LSTM variational autoencoder over sequences of chord-token ids.

    The encoder embeds the input ids, runs an LSTM and maps the last layer's
    final hidden state to the mean and log-variance of a Gaussian latent.
    The decoder turns a latent sample into the initial hidden state of a
    second LSTM and emits one distribution over the vocabulary per step.

    Args:
        input_dim: stored on the instance; not used by any layer here.
        hidden_dim: hidden size of both LSTMs.
        latent_dim: size of the latent vector ``z``.
        vocab_size: number of token ids (embedding rows / output logits).
        embedding_dim: size of the token embeddings.
        num_layers: number of stacked LSTM layers (encoder and decoder).
        dropout: inter-layer LSTM dropout.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, vocab_size, embedding_dim, num_layers=2, dropout=0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        # Encoder
        self.encoder_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim * num_layers)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def encode(self, x):
        """Return ``(mu, logvar)``, each ``(batch, latent_dim)``, for ids ``x`` of shape ``(batch, seq)``."""
        x = self.encoder_embedding(x)
        _, (h_n, _) = self.encoder_lstm(x)
        h_last = h_n[-1]  # final hidden state of the top layer
        mu = self.fc_mu(h_last)
        logvar = self.fc_logvar(h_last)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _init_decoder_state(self, z):
        """Map ``z`` (batch, latent) to the decoder's initial ``(hidden, cell)``.

        The linear layer yields ``(batch, num_layers * hidden)``. It is split
        per sample first and only then moved to the ``(num_layers, batch,
        hidden)`` layout nn.LSTM expects; viewing it directly as
        ``(num_layers, batch, hidden)`` would mix values of different samples
        whenever ``num_layers > 1``.
        """
        batch_size = z.size(0)
        num_layers = self.decoder_lstm.num_layers
        hidden = torch.tanh(self.latent_to_hidden(z))
        hidden = hidden.view(batch_size, num_layers, self.hidden_dim).transpose(0, 1).contiguous()
        cell = torch.zeros_like(hidden)
        return hidden, cell

    def decode(self, z, y_seq, teacher_forcing_ratio=0.4):
        """Decode ``z`` against target ``y_seq`` (batch, seq).

        The first column of ``y_seq`` is the start input. At every later step
        the next input is the ground-truth token with probability
        ``teacher_forcing_ratio`` (one draw per step for the whole batch),
        otherwise the arg-max prediction.

        Returns:
            Logits of shape ``(batch, seq - 1, vocab_size)``.
        """
        batch_size, seq_len = y_seq.size()
        hidden, cell = self._init_decoder_state(z)

        inputs = y_seq[:, 0]
        outputs = []

        for t in range(1, seq_len):
            input_embed = self.embedding(inputs).unsqueeze(1)
            output, (hidden, cell) = self.decoder_lstm(input_embed, (hidden, cell))
            output_logits = self.output_layer(output.squeeze(1))
            outputs.append(output_logits)

            teacher_force = (torch.rand(1).item() < teacher_forcing_ratio)
            top1 = output_logits.argmax(1)
            inputs = y_seq[:, t] if teacher_force else top1

        if not outputs:
            # Target of length 1: nothing to predict (torch.stack([]) would raise).
            return torch.zeros(batch_size, 0, self.vocab_size, device=z.device)
        return torch.stack(outputs, dim=1)

    def forward(self, x, y_seq=None, teacher_forcing_ratio=0.8):
        """Encode ``x``, sample ``z`` and decode.

        With ``y_seq`` the result is ``(logits, mu, logvar)``; without it the
        model samples up to 10 tokens and returns ``(token_ids, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        if y_seq is not None:
            y_hat = self.decode(z, y_seq, teacher_forcing_ratio)
        else:
            y_hat = self.generate(z, max_len=10)
        return y_hat, mu, logvar

    def generate(self, z, max_len=10, start_token_id=None, eos_token_id=None):
        """Sample up to ``max_len`` tokens per latent vector.

        Args:
            z: latent batch ``(batch, latent_dim)``.
            max_len: maximum number of sampled steps.
            start_token_id: id fed at the first step. ``None`` falls back to
                1, the '<sos>' id of the vocabularies built in this notebook.
            eos_token_id: if given, sampling stops early once every sequence
                of the batch emits this id at the same step.

        Returns:
            Long tensor ``(batch, steps)`` of sampled ids, ``steps <= max_len``.
        """
        if start_token_id is None:
            start_token_id = 1  # '<sos>'; torch.full cannot take None as fill value
        batch_size = z.size(0)
        hidden, cell = self._init_decoder_state(z)

        inputs = torch.full((batch_size,), start_token_id, dtype=torch.long, device=z.device)
        generated_tokens = []

        for _ in range(max_len):
            input_embed = self.embedding(inputs).unsqueeze(1)
            output, (hidden, cell) = self.decoder_lstm(input_embed, (hidden, cell))
            output_logits = self.output_layer(output.squeeze(1))

            probs = torch.softmax(output_logits, dim=-1)
            inputs = torch.multinomial(probs, num_samples=1).squeeze(1)

            generated_tokens.append(inputs)

            if eos_token_id is not None and (inputs == eos_token_id).all():
                break

        return torch.stack(generated_tokens, dim=1)

def detectar_escala(notas_midi):
    """Estimate the scale of ``notas_midi`` and return its pitch classes.

    The notes are appended to a music21 stream, the stream is analysed with
    ``analyze('key')``, and the pitch classes of the resulting key's scale
    are returned sorted and without duplicates.
    """
    partitura = stream.Stream()
    for altura in notas_midi:
        partitura.append(note.Note(altura))
    tonalidad = partitura.analyze('key')
    clases = {p.pitchClass for p in tonalidad.getScale().getPitches()}
    return sorted(clases)

def tokens_to_root_notes(tokens):
    """Return the lowest member of every chord token in ``tokens``.

    Ids are looked up in the global ``id_to_token``. Special tokens and
    tokens that are not '_'-separated integers are skipped.
    """
    especiales = ('<sos>', '<eos>', '<pad>')
    raices = []
    for token in tokens:
        etiqueta = id_to_token[token]
        if etiqueta in especiales:
            continue
        try:
            raices.append(min(int(x) for x in etiqueta.split('_')))
        except ValueError:
            continue
    return raices


def loss_coherencia_tonal_soft(y_hat_step, root_note, escala_mod12):
    """Expected out-of-scale fraction under the predicted distribution.

    Every vocabulary entry gets a weight equal to the fraction of its notes
    (each offset by ``root_note``) whose pitch class is not in
    ``escala_mod12``; special and non-numeric tokens weigh 0. The loss is the
    sum of ``softmax(y_hat_step) * weights``.

    Args:
        y_hat_step: logits whose last dimension indexes the vocabulary.
        root_note: integer added to every note of a chord token.
        escala_mod12: iterable of in-scale pitch classes (0-11).

    Returns:
        Scalar tensor on the device of ``y_hat_step``.
    """
    probs = F.softmax(y_hat_step, dim=-1)
    escala = set(escala_mod12)

    # Build the weights in plain Python and create the tensor once, instead
    # of writing into a (possibly CUDA) tensor element by element.
    pesos = []
    for idx in range(probs.size(-1)):
        chord = id_to_token[idx]
        if chord in ('<sos>', '<eos>', '<pad>'):
            pesos.append(0.0)
            continue
        try:
            notas = [root_note + int(x) for x in chord.split('_')]
        except ValueError:
            # Token is not a '_'-separated list of ints.
            pesos.append(0.0)
            continue
        count_fuera = sum(1 for n in notas if n % 12 not in escala)
        pesos.append(count_fuera / len(notas))

    fuera_escala = torch.tensor(pesos, dtype=torch.float32, device=probs.device)
    return torch.sum(probs * fuera_escala)

def loss_movimiento_armonico_soft(root_notes, device='cpu'):
    """Mean penalty for jumps larger than 7 between consecutive roots.

    Each jump contributes ``relu(|b - a| - 7)``; the sum is divided by the
    number of jumps (at least 1). The result is a scalar tensor on ``device``.
    """
    total = torch.tensor(0., device=device)
    n_saltos = len(root_notes) - 1
    for pos in range(n_saltos):
        exceso = abs(root_notes[pos + 1] - root_notes[pos]) - 7.0
        total += torch.relu(torch.tensor(exceso, device=device))
    return total / max(1, n_saltos)


def loss_tensiones_soft(y_hat_step):
    """Probability mass the prediction puts on chords with fewer than 4 notes.

    Every chord token of the global ``id_to_token`` with fewer than four
    '_'-separated parts weighs 1; special tokens and larger chords weigh 0.
    The loss is the sum of ``softmax(y_hat_step) * weights``.

    Returns:
        Scalar tensor on the device of ``y_hat_step``.
    """
    probs = F.softmax(y_hat_step, dim=-1)
    especiales = ('<sos>', '<eos>', '<pad>')

    # Weights are built in Python and moved to the device once.
    pesos = []
    for idx in range(probs.size(-1)):
        chord = id_to_token[idx]
        if chord in especiales:
            pesos.append(0.0)
        else:
            pesos.append(1.0 if len(chord.split('_')) < 4 else 0.0)

    penalizacion = torch.tensor(pesos, dtype=torch.float32, device=probs.device)
    return torch.sum(probs * penalizacion)

def cvae_loss(y_hat, y, mu, logvar, beta=0.1, free_bits=1.0, pad_idx=0):
    """Reconstruction + beta-weighted KL loss of the CVAE.

    Args:
        y_hat: logits ``(batch, steps, vocab)``.
        y: target ids ``(batch, steps + 1)``; the first column (start token)
            is not a prediction target.
        mu, logvar: parameters of the approximate posterior.
        beta: weight of the KL term.
        free_bits: lower bound applied to the summed KL term.
        pad_idx: target id ignored by the reconstruction loss.

    Returns:
        ``(total, recon_loss, kl_loss)``, all summed over the batch.
    """
    logits = y_hat.reshape(-1, y_hat.size(-1))
    targets = y[:, 1:].reshape(-1)
    recon_loss = F.cross_entropy(logits, targets, ignore_index=pad_idx, reduction='sum')

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_loss = torch.clamp(-0.5 * torch.sum(kl_terms), min=free_bits)

    total = recon_loss + beta * kl_loss
    return total, recon_loss, kl_loss


# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): input_dim is only stored on the instance by CVAE_LSTM.__init__.
model = CVAE_LSTM(
    input_dim=32,
    hidden_dim=64,
    latent_dim=64,
    vocab_size=vocab_size,
    embedding_dim=64,
    num_layers=1,
    dropout=0.0,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def _notas_de_token(token_id):
    """Return the ints encoded in a chord token id, or None if it is not a chord."""
    token = id_to_token.get(token_id)
    if token is None or token in ('<sos>', '<pad>', '<eos>'):
        return None
    try:
        return [int(n) for n in token.split('_')]
    except ValueError:
        return None


def entrenar_cvae(model, dataloader, optimizer, device, beta=0.1, teacher_forcing_ratio=0.8, usar_music_loss=False, w_coherencia=2.0, w_tensiones=0.1, w_movimiento=0.3):
    """Train ``model`` for one pass over ``dataloader``.

    The base objective is ``cvae_loss``. With ``usar_music_loss`` three
    weighted terms are added per sample and normalised by
    ``batch * steps``: tonal coherence (``w_coherencia``), chord size
    (``w_tensiones``) and root movement (``w_movimiento``). The root-movement
    term is computed from arg-max tokens, so it shifts the reported loss but
    carries no gradient.

    Args:
        model: CVAE called as ``model(x, y, teacher_forcing_ratio)``.
        dataloader: yields ``(x, y)`` batches of token ids.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        beta: KL weight passed to ``cvae_loss``.
        teacher_forcing_ratio: forwarded to the model.

    Returns:
        ``(total_loss, total_recon, total_kl, total_music)`` summed over
        batches (not averaged).
    """
    model.train()
    total_loss, total_recon, total_kl, total_music = 0, 0, 0, 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        y_hat, mu, logvar = model(x, y, teacher_forcing_ratio)
        loss, recon, kl = cvae_loss(y_hat, y, mu, logvar, beta)

        if usar_music_loss:
            batch_music_loss = 0
            for b in range(y.size(0)):
                y_hat_seq = y_hat[b]
                entrada = x[b].tolist()

                # Root = lowest note of the first real chord of the input (0 if none).
                acordes = [notas for notas in map(_notas_de_token, entrada) if notas is not None]
                root_note = min(acordes[0]) if acordes else 0

                # The scale is estimated from the second input token, as
                # before. When that token is missing or is padding (short
                # inputs inside a padded batch), fall back to the first chord
                # instead of raising IndexError / ValueError.
                referencia = _notas_de_token(entrada[1]) if len(entrada) > 1 else None
                if referencia is None:
                    referencia = acordes[0] if acordes else [0]
                escala = detectar_escala([root_note + n for n in referencia])

                # Both helpers soft-max over the last dimension and sum
                # everything, so one call on the whole (steps, vocab) matrix
                # equals the sum of per-step calls while building the
                # vocabulary-sized weight vector once instead of once per step.
                music1 = loss_coherencia_tonal_soft(y_hat_seq, root_note, escala)
                music2 = loss_tensiones_soft(y_hat_seq)
                batch_music_loss = batch_music_loss + w_coherencia * music1 + w_tensiones * music2

                root_pred = tokens_to_root_notes(y_hat_seq.argmax(dim=1).tolist())
                batch_music_loss = batch_music_loss + w_movimiento * loss_movimiento_armonico_soft(root_pred, device=device)

            batch_music_loss = batch_music_loss / (y.size(0) * y_hat.size(1))
            loss = loss + batch_music_loss
            total_music += batch_music_loss.item()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon.item()
        total_kl += kl.item()

    return total_loss, total_recon, total_kl, total_music


In [None]:
def preparar_dataset(progresiones_x, progresiones_y=None, pad_value=0):
    """Pad token sequences and wrap them in a ``TensorDataset``.

    Args:
        progresiones_x: input sequences (lists of ids or 1-D tensors).
        progresiones_y: optional target sequences; when omitted the padded
            inputs are used as targets as well.
        pad_value: value used to right-pad each side to its longest sequence.

    Returns:
        ``TensorDataset(x_padded, y_padded)`` with batch-first long tensors.
    """
    def _pad(progresiones):
        # as_tensor: accepts lists and existing tensors alike, without the
        # copy-construct warning torch.tensor() emits for tensor inputs.
        tensores = [torch.as_tensor(p, dtype=torch.long) for p in progresiones]
        return pad_sequence(tensores, batch_first=True, padding_value=pad_value)

    x_padded = _pad(progresiones_x)
    y_padded = x_padded if progresiones_y is None else _pad(progresiones_y)
    return TensorDataset(x_padded, y_padded)


# Pretrain as an autoencoder with the <sos> ... <eos> targets built earlier.
# Passing only X_pretrain would make the target equal to the input, so the
# decoder would start from the first chord and never see <sos>/<eos>, the
# very tokens generate() starts from and stops at.
dataset_pre = preparar_dataset(X_pretrain, Y_pretrain)
loader_pre = DataLoader(dataset_pre, batch_size=8, shuffle=True)

print("== Fase 1: Preentrenamiento ==")
# 600 epochs, full teacher forcing, no music loss.
for epoch in range(600):
    loss, recon, kl, _ = entrenar_cvae(model, loader_pre, optimizer, device,
                                       teacher_forcing_ratio=1, usar_music_loss=False)
    print(f"[{epoch+1}] Loss: {loss:.2f} | Recon: {recon:.2f} | KL: {kl:.2f}")



In [None]:

def mostrar_reconstrucciones(model, dataset, id_to_token, device, token_to_id, num_muestras=5, max_len=5):
    """Print a few inputs next to sequences sampled from their latent codes.

    One shuffled batch of up to ``num_muestras`` inputs is encoded, a latent
    is sampled for each, and ``model.generate`` produces up to ``max_len``
    tokens starting from '<sos>'. Special tokens are hidden, each sampled
    sequence is cut at its first '<eos>', and ids missing from
    ``id_to_token`` are shown as '<unk>'.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    x_batch, _ = next(iter(DataLoader(dataset, batch_size=num_muestras, shuffle=True)))
    x_batch = x_batch.to(device)

    with torch.no_grad():
        mu, logvar = model.encode(x_batch)
        z = model.reparameterize(mu, logvar)
        generated = model.generate(
            z,
            max_len=max_len,
            start_token_id=token_to_id['<sos>'],
            eos_token_id=token_to_id['<eos>']
        )

    eos_id = token_to_id['<eos>']
    especiales = {token_to_id['<pad>'], eos_id, token_to_id['<sos>']}

    def _a_acordes(tokens):
        # .get: the model may emit ids that this id_to_token does not contain.
        return [id_to_token.get(tok, '<unk>') for tok in tokens if tok not in especiales]

    # The batch can be smaller than num_muestras when the dataset is small.
    for i in range(x_batch.size(0)):
        entrada_tokens = x_batch[i].cpu().tolist()
        salida_tokens = generated[i].cpu().tolist()
        if eos_id in salida_tokens:
            # Sampling only stops when the whole batch hits <eos>; drop the tail.
            salida_tokens = salida_tokens[:salida_tokens.index(eos_id)]

        entrada_chords = _a_acordes(entrada_tokens)
        salida_chords = _a_acordes(salida_tokens)

        print(f"\nðŸ”¹ Input  ({i+1}): {' | '.join(entrada_chords)}")
        print(f"ðŸ”¸ Reconstr ({i+1}): {' | '.join(salida_chords)}")


# Inspect a few samples after pretraining, decoded with the pretrain vocabulary.
mostrar_reconstrucciones(
    model=model,
    dataset=dataset_pre,
    id_to_token=id_to_token_pretrain,
    device=device,
    token_to_id=token_to_id_pretrain,
    num_muestras=5,
    max_len=5,
)



In [None]:
# Checkpoint the pretrained weights under a timestamped file name
# (state_dict only; the vocabulary is stored separately).
timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
model_path = f"scl_pretrained_{timestamp}.pt"
torch.save(model.state_dict(), model_path)

In [None]:
# Phase 2: supervised fine-tuning on the (X_train, Y_train) pairs with the
# musical loss terms switched on. The model and optimizer from phase 1 are reused.
dataset_train = preparar_dataset(X_train, Y_train)
loader_train = DataLoader(dataset_train, batch_size=8, shuffle=True)

print("\n== Fase 2: Train ==")
for epoch in range(100):
    loss, recon, kl, music = entrenar_cvae(
        model=model,
        dataloader=loader_train,
        optimizer=optimizer,
        device=device,
        teacher_forcing_ratio=0.9,
        usar_music_loss=True,
        w_coherencia=0.1,
        w_tensiones=2.0,
        w_movimiento=0.3,
    )
    print(f"[{epoch+1}] Loss: {loss:.2f} | Recon: {recon:.2f} | KL: {kl:.2f} | Music: {music:.2f}")


In [None]:
# Checkpoint the fine-tuned weights under a timestamped file name
# (state_dict only).
timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
model_path = f"scl_fulltrained_{timestamp}.pt"
torch.save(model.state_dict(), model_path)