In [17]:
import os
import json
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from collections import OrderedDict
import numpy as np
import re
from scipy.io import wavfile
from matplotlib import pyplot as plt
from pypinyin import pinyin, Style
from string import punctuation
from mindspore import context
import hifigan
from text.symbols import symbols
print("MindSpore版本:", ms.__version__)
device = context.get_context("device_target")
if not device:
    has_gpu = ms.compatible.get_device_id() != -1
    device = "GPU" if has_gpu else "CPU"
    context.set_context(device_target=device)

print(f"当前使用设备: {device}")

MindSpore版本: 2.6.0.dev20250323
当前使用设备: CPU


In [None]:
def get_mask_from_lengths(lengths, max_len=None):
    if not isinstance(lengths, ms.Tensor):
        lengths = ms.Tensor(lengths, dtype=ms.int64)
    elif lengths.dtype != ms.int64:
        lengths = ops.cast(lengths, ms.int64)

    batch_size = int(lengths.shape[0])
    if max_len is None:
        max_len = int(ops.max(lengths)[0].asnumpy().item())
    else:
        if isinstance(max_len, ms.Tensor):
            max_len_np = max_len.asnumpy()
            max_len = int(max_len_np.item() if hasattr(max_len_np, "item") else np.array(max_len_np).reshape(-1)[0])
        else:
            max_len = int(max_len)
    max_len = max(max_len, 1)
    start = ms.Tensor(0.0, dtype=ms.float32)
    end = ms.Tensor(float(max_len - 1), dtype=ms.float32)
    ids = ops.linspace(start, end, max_len).reshape(1, -1)
    ids = ops.broadcast_to(ids, (batch_size, max_len))
    ids = ops.cast(ids, ms.int64)

    lengths_expand = ops.expand_dims(lengths, 1)
    lengths_expand = ops.broadcast_to(lengths_expand, (batch_size, max_len))
    mask = ids >= lengths_expand
    return mask


def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([batch.shape[0] for batch in input_ele])
    
    out_list = []
    for batch in input_ele:
        seq_len = batch.shape[0]
        pad_right = max(max_len - seq_len, 0)
        
        if len(batch.shape) == 1:
            pad_width = (0, pad_right)
        elif len(batch.shape) == 2:
            pad_width = (0, pad_right, 0, 0)
        else:
            raise ValueError(f"仅支持1D/2D张量，当前是{len(batch.shape)}维")
        one_batch_padded = ops.pad(
            input_x=batch,
            padding=pad_width,
            mode='constant',
        )
        out_list.append(one_batch_padded)
    out_padded = ops.stack(out_list)
    return out_padded


def read_lexicon(lex_path):
    lexicon = {}
    with open(lex_path) as f:
        for line in f:
            temp = re.split(r"\s+", line.strip("\n"))
            word = temp[0]
            phones = temp[1:]
            if word.lower() not in lexicon:
                lexicon[word.lower()] = phones
    return lexicon


def preprocess_mandarin(text, preprocess_config):
    text = text.rstrip(punctuation)
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
    pinyins = [p[0] for p in pinyin(text, style=Style.TONE3, strict=False, neutral_tone_with_five=True)]
    phones = []
    for p in pinyins:
        phones += lexicon[p] if p in lexicon else ["sp"]
    phones = "{" + " ".join(phones) + "}"
    print(f"原始文本: {text}")
    print(f"音素序列: {phones}")
    from text import text_to_sequence
    sequence = np.array(text_to_sequence(phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]))
    if len(sequence) == 0:
        raise ValueError(f"文本预处理后序列长度为0，请检查输入文本或词典：{text}")
    return sequence


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) 
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 
    if padding_idx is not None:
        sinusoid_table[padding_idx] = 0.0
    return ms.Tensor(sinusoid_table, dtype=ms.float32)

def expand(values, durations):
    out = list()
    for value, d in zip(values, durations):
        out += [value] * max(0, int(d))
    return np.array(out)

In [19]:
class LengthRegulator(nn.Cell):
    def __init__(self):
        super(LengthRegulator, self).__init__()
    def LR(self, x, duration, max_len):
        output = []
        mel_len = []
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            output.append(expanded)
            mel_len.append(expanded.shape[0])

        output = pad(output, max_len) if max_len else pad(output)
        return output, ms.Tensor(mel_len, dtype=ms.int64)
    def expand(self, batch, predicted):
        out = []
        for i, vec in enumerate(batch):
            expand_size = int(predicted[i].asnumpy()) 
            expanded_vec = ops.broadcast_to(vec, (max(expand_size, 0), vec.shape[0]))
            out.append(expanded_vec)
        return ops.concat(out, axis=0)
    def construct(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len

In [None]:
class Conv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            has_bias=bias,
            pad_mode='pad'
        )

    def construct(self, x):
        x = x.transpose(1, 2)
        x = self.conv(x)
        x = x.transpose(1, 2)
        return x


class VariancePredictor(nn.Cell):
    def __init__(self, model_config):
        super(VariancePredictor, self).__init__()
        self.input_size = model_config["transformer"]["encoder_hidden"]
        self.filter_size = model_config["variance_predictor"]["filter_size"]
        self.kernel = model_config["variance_predictor"]["kernel_size"]
        self.dropout = model_config["variance_predictor"]["dropout"]
        self.conv_layer = nn.SequentialCell(OrderedDict([
            ("conv1d_1", Conv(
                self.input_size, self.filter_size,
                kernel_size=self.kernel, padding=(self.kernel - 1) // 2,bias=True
            )),
            ("relu_1", nn.ReLU()),
            ("layer_norm_1", nn.LayerNorm((self.filter_size,))),
            ("dropout_1", nn.Dropout(p=self.dropout)),
            ("conv1d_2", Conv(
                self.filter_size, self.filter_size,
                kernel_size=self.kernel, padding=1,bias=True
            )),
            ("relu_2", nn.ReLU()),
            ("layer_norm_2", nn.LayerNorm((self.filter_size,))),
            ("dropout_2", nn.Dropout(p=self.dropout))
        ]))

        self.linear_layer = nn.Dense(self.filter_size, 1)

    def construct(self, encoder_output, mask):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = ops.masked_fill(out, mask.astype(ms.bool_), 0.0)
        return out

In [None]:
class VarianceAdaptor(nn.Cell):
    def __init__(self, preprocess_config, model_config):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(model_config)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(model_config)
        self.energy_predictor = VariancePredictor(model_config)
        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]
        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]
        with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")) as f:
            stats = json.load(f)
            pitch_min, pitch_max = stats["pitch"][:2]
            energy_min, energy_max = stats["energy"][:2]
        if pitch_quantization == "log":
            pitch_bins = np.exp(np.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1))
        else:
            pitch_bins = np.linspace(pitch_min, pitch_max, n_bins - 1)
        self.pitch_bins = ms.Parameter(ms.Tensor(pitch_bins, dtype=ms.float32), requires_grad=False)

        if energy_quantization == "log":
            energy_bins = np.exp(np.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1))
        else:
            energy_bins = np.linspace(energy_min, energy_max, n_bins - 1)
        self.energy_bins = ms.Parameter(ms.Tensor(energy_bins, dtype=ms.float32), requires_grad=False)
        self.pitch_embedding = nn.Embedding(n_bins, model_config["transformer"]["encoder_hidden"])
        self.energy_embedding = nn.Embedding(n_bins, model_config["transformer"]["encoder_hidden"])

    def get_pitch_embedding(self, x, target, mask, control):
        prediction = self.pitch_predictor(x, mask)
        pitch_bins_list = self.pitch_bins.asnumpy().tolist()
        pitch_bins_list = [float(bin_val) for bin_val in pitch_bins_list]
        if target is not None:
            bucket_idx = ops.bucketize(target, pitch_bins_list)
            embedding = self.pitch_embedding(bucket_idx)
        else:
            prediction = prediction * control

            bucket_idx = ops.bucketize(prediction, pitch_bins_list)
            embedding = self.pitch_embedding(bucket_idx)
        return prediction, embedding

    def get_energy_embedding(self, x, target, mask, control):
        prediction = self.energy_predictor(x, mask)
        energy_bins_list = self.energy_bins.asnumpy().tolist()
        energy_bins_list = [float(bin_val) for bin_val in energy_bins_list]
        if target is not None:
            bucket_idx = ops.bucketize(target, energy_bins_list)
            embedding = self.energy_embedding(bucket_idx)
        else:
            prediction = prediction * control
            bucket_idx = ops.bucketize(prediction, energy_bins_list)
            embedding = self.energy_embedding(bucket_idx)
        return prediction, embedding

    def construct(
        self, x, src_mask, mel_mask=None, max_len=None,
        pitch_target=None, energy_target=None, duration_target=None,
        p_control=1.0, e_control=1.0, d_control=1.0
    ):
        log_duration_prediction = self.duration_predictor(x, src_mask)
        if self.pitch_feature_level == "phoneme_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(x, pitch_target, src_mask, p_control)
            x = x + pitch_embedding
        if self.energy_feature_level == "phoneme_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(x, energy_target, src_mask, e_control)
            x = x + energy_embedding
        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            duration_rounded = ops.clamp((ops.round(ops.exp(log_duration_prediction) - 1) * d_control), min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)
        if self.pitch_feature_level == "frame_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(x, pitch_target, mel_mask, p_control)
            x = x + pitch_embedding
        if self.energy_feature_level == "frame_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(x, energy_target, mel_mask, e_control)
            x = x + energy_embedding
        return x, pitch_prediction, energy_prediction, log_duration_prediction, duration_rounded, mel_len, mel_mask

In [22]:
class ScaledDotProductAttention(nn.Cell):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(axis=2)

    def construct(self, q, k, v, mask=None):
        batch_size, len_q, d_k = q.shape
        _, len_k, _ = k.shape
        _, len_v, d_v = v.shape
        k_trans = ops.transpose(k, (0, 2, 1))
        attn = ops.matmul(q, k_trans)
        attn = ops.clip_by_value(attn, clip_value_min=-8.0, clip_value_max=8.0)
        attn = attn / self.temperature
        if mask is not None:
            attn = ops.masked_fill(attn, mask, -np.inf)

        attn = self.softmax(attn)

        output = ops.matmul(attn, v)

        return output, attn

In [None]:
class MultiHeadAttention(nn.Cell):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Dense(d_model, n_head * d_k)
        self.w_ks = nn.Dense(d_model, n_head * d_k)
        self.w_vs = nn.Dense(d_model, n_head * d_v)
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm((d_model,))
        self.fc = nn.Dense(n_head * d_v, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def construct(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        batch_size, len_q, d_model_q = q.shape
        _, len_k, d_model_k = k.shape
        _, len_v, d_model_v = v.shape


        residual = q
        residual = ops.cast(residual, ms.float32)
        q = ops.cast(q, ms.float32)
        k = ops.cast(k, ms.float32)
        v = ops.cast(v, ms.float32)
        q = self.w_qs(q)
        k = self.w_ks(k)
        v = self.w_vs(v)


        q = q.reshape(batch_size, len_q, n_head, d_k)
        k = k.reshape(batch_size, len_k, n_head, d_k)
        v = v.reshape(batch_size, len_v, n_head, d_v)

        transpose_axis = (0, 2, 1, 3)  # 转置轴：B→0, H→2, L→1, d→3

        q = ops.transpose(q, transpose_axis)
        k = ops.transpose(k, transpose_axis)
        v = ops.transpose(v, transpose_axis)


        q = q.reshape(-1, len_q, d_k)
        k = k.reshape(-1, len_k, d_k)
        v = v.reshape(-1, len_v, d_v)

        if mask is not None:
            mask_int = ops.cast(mask, ms.int32)
            mask4_int = ops.repeat_interleave(mask_int, repeats=n_head, axis=0)
            mask4 = ops.cast(mask4_int, ms.bool_)
        else:
            mask4 = None

        output, attn = self.attention(q, k, v, mask=mask4)
        output = output.reshape(batch_size, n_head, len_q, d_v)
        output = ops.transpose(output, (0, 2, 1, 3))
        output = output.reshape(batch_size, len_q, -1)

        output = self.fc(output)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output, attn

In [None]:
class PositionwiseFeedForward(nn.Cell):
    def __init__(self, d_in, d_hid, kernel_size, dropout=0.1):
        super().__init__()
        if isinstance(kernel_size, (list, tuple)):
            assert len(kernel_size) == 2, "kernel_size 应为长度为2的序列，如 [9, 1]"
            ks1, ks2 = int(kernel_size[0]), int(kernel_size[1])
        else:
            ks1 = int(kernel_size)
            ks2 = 1

        pad1 = (ks1 - 1) // 2
        pad2 = (ks2 - 1) // 2

        self.w_1 = nn.Conv1d(
            in_channels=d_in, out_channels=d_hid,
            kernel_size=ks1, padding=pad1, pad_mode='pad',has_bias=True
        )
        self.w_2 = nn.Conv1d(
            in_channels=d_hid, out_channels=d_in,
            kernel_size=ks2, padding=pad2, pad_mode='pad',has_bias=True
        )

        self.layer_norm = nn.LayerNorm((d_in,), epsilon=1e-4)
        self.dropout = nn.Dropout(p=dropout)

    def construct(self, x):
        residual = x
        x = x.transpose(1, 2)              # (B, T, C) -> (B, C, T)
        x = self.w_2(ops.relu(self.w_1(x)))
        x = x.transpose(1, 2)              # (B, C, T) -> (B, T, C)
        x = self.dropout(x)
        x = self.layer_norm(x + residual)
        return x

class FFTBlock(nn.Cell):
    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, kernel_size, dropout=dropout)
        self.norm_before_attn = nn.LayerNorm((d_model,), epsilon=1e-4)
        self.norm_before_pos_ffn = nn.LayerNorm((d_model,), epsilon=1e-4)
        self.norm_before_attn.gamma.set_data(ms.Tensor(np.ones(d_model) * 0.9, dtype=ms.float32))
        self.norm_before_attn.beta.set_data(ms.Tensor(np.zeros(d_model), dtype=ms.float32))
        self.norm_before_pos_ffn.gamma.set_data(ms.Tensor(np.ones(d_model) * 0.9, dtype=ms.float32))
        self.norm_before_pos_ffn.beta.set_data(ms.Tensor(np.zeros(d_model), dtype=ms.float32))

    def construct(self, enc_input, mask=None, slf_attn_mask=None):

        enc_input_norm = self.norm_before_attn(enc_input)
        enc_output, enc_slf_attn = self.slf_attn(enc_input_norm, enc_input_norm, enc_input_norm, mask=slf_attn_mask)

        if mask is not None:
            mask_expanded = mask.expand_dims(2)
            fill_value = ms.Tensor(0.0, dtype=enc_output.dtype)
            enc_output = ops.masked_fill(enc_output, mask_expanded.astype(ms.bool_), fill_value)

        enc_output = self.pos_ffn(self.norm_before_pos_ffn(enc_output))

        if mask is not None:
            enc_output = ops.masked_fill(enc_output, mask.expand_dims(-1).astype(ms.bool_), 0.0)
        return enc_output, enc_slf_attn

In [25]:
class ConvNorm(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=None, dilation=1, bias=True):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, has_bias=bias,pad_mode='pad'
        )

    def construct(self, signal):
        return self.conv(signal)


class PostNet(nn.Cell):
    def __init__(self, n_mel_channels=80, postnet_embedding_dim=512, postnet_kernel_size=5, postnet_n_convolutions=5):
        super(PostNet, self).__init__()
        self.convolutions = nn.CellList()
        self.convolutions.append(nn.SequentialCell([
            ConvNorm(n_mel_channels, postnet_embedding_dim, kernel_size=postnet_kernel_size,bias=True),
            nn.BatchNorm1d(postnet_embedding_dim)
        ]))
        for _ in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(nn.SequentialCell([
                ConvNorm(postnet_embedding_dim, postnet_embedding_dim, kernel_size=postnet_kernel_size,bias=True),
                nn.BatchNorm1d(postnet_embedding_dim)
            ]))
        self.convolutions.append(nn.SequentialCell([
            ConvNorm(postnet_embedding_dim, n_mel_channels, kernel_size=postnet_kernel_size,bias=True),
            nn.BatchNorm1d(n_mel_channels)
        ]))

    def construct(self, x):
        x = x.transpose(1, 2)
        for i in range(len(self.convolutions) - 1):
            x = ops.tanh(self.convolutions[i](x))
            x = nn.Dropout(p=0.5)(x) if self.training else x
        x = self.convolutions[-1](x)
        x = nn.Dropout(p=0.5)(x) if self.training else x
        x = x.transpose(1, 2)  
        return x

In [26]:
class Encoder(nn.Cell):
    def __init__(self, config):
        super(Encoder, self).__init__()
        n_position = config["max_seq_len"] + 1
        n_src_vocab = len(symbols) + 1
        d_word_vec = config["transformer"]["encoder_hidden"]
        n_layers = config["transformer"]["encoder_layer"]
        n_head = config["transformer"]["encoder_head"]
        d_k = d_v = config["transformer"]["encoder_hidden"] // config["transformer"]["encoder_head"]
        d_model = config["transformer"]["encoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["encoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model
        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=0)  # PAD=0
        self.position_enc = ms.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).expand_dims(0),
            requires_grad=False
        )
        self.layer_stack = nn.CellList([
            FFTBlock(d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout)
            for _ in range(n_layers)
        ])

    def construct(self, src_seq, mask, return_attns=False):
        enc_slf_attn_list = []
        batch_size, max_len = src_seq.shape[0], src_seq.shape[1]
        slf_attn_mask = mask.expand_dims(1).broadcast_to((batch_size, max_len, max_len))
        if not self.training and max_len > self.max_seq_len:
            pos_enc = get_sinusoid_encoding_table(max_len, self.d_model).expand_dims(0)
            pos_enc = pos_enc.broadcast_to((batch_size, max_len, self.d_model))
        else:
            pos_enc = self.position_enc[:, :max_len, :].broadcast_to((batch_size, max_len, self.d_model))

        enc_output = self.src_word_emb(src_seq) + pos_enc
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, mask=mask, slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        return enc_output if not return_attns else (enc_output, enc_slf_attn_list)


class Decoder(nn.Cell):
    def __init__(self, config):
        super(Decoder, self).__init__()
        n_position = config["max_seq_len"] + 1
        d_word_vec = config["transformer"]["decoder_hidden"]
        n_layers = config["transformer"]["decoder_layer"]
        n_head = config["transformer"]["decoder_head"]
        d_k = d_v = config["transformer"]["decoder_hidden"] // config["transformer"]["decoder_head"]
        d_model = config["transformer"]["decoder_hidden"]
        d_inner = config["transformer"]["conv_filter_size"]
        kernel_size = config["transformer"]["conv_kernel_size"]
        dropout = config["transformer"]["decoder_dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        self.position_enc = ms.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).expand_dims(0),
            requires_grad=False
        )
        self.layer_stack = nn.CellList([
            FFTBlock(d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout)
            for _ in range(n_layers)
        ])

    def construct(self, enc_seq, mask, return_attns=False):
        dec_slf_attn_list = []
        batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1]

        isnan = ops.IsNan()
        isinf = ops.IsInf()
        any_op = ops.ReduceAny()


        if not self.training and max_len > self.max_seq_len:
            pos_enc = get_sinusoid_encoding_table(max_len, self.d_model).expand_dims(0)
            pos_enc = pos_enc.broadcast_to((batch_size, max_len, self.d_model))
            slf_attn_mask = mask.expand_dims(1).broadcast_to((batch_size, max_len, max_len))
            dec_output = enc_seq + pos_enc
        else:
            max_len = min(max_len, self.max_seq_len)
            enc_seq = enc_seq[:, :max_len, :]
            mask = mask[:, :max_len]
            pos_enc = self.position_enc[:, :max_len, :].broadcast_to((batch_size, max_len, self.d_model))
            slf_attn_mask = mask.expand_dims(1).broadcast_to((batch_size, max_len, max_len))
            dec_output = enc_seq + pos_enc


        for layer_idx, dec_layer in enumerate(self.layer_stack):
            dec_output, dec_slf_attn = dec_layer(dec_output, mask=mask, slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list.append(dec_slf_attn)
        return (dec_output, mask) if not return_attns else (dec_output, mask, dec_slf_attn_list)


class FastSpeech2(nn.Cell):
    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2, self).__init__()
        self.model_config = model_config
        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Dense(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
        )
        self.postnet = PostNet()
        self.speaker_emb = None
        if model_config["multi_speaker"]:
            with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "speakers.json"), "r") as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(n_speaker, model_config["transformer"]["encoder_hidden"])


    def construct(
                self, speakers, texts, src_lens, max_src_len,
                mels=None, mel_lens=None, max_mel_len=None,
                p_targets=None, e_targets=None, d_targets=None,
                p_control=1.0, e_control=1.0, d_control=1.0
            ):
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = get_mask_from_lengths(mel_lens, max_mel_len) if mel_lens is not None else None
        enc_output = self.encoder(texts, src_masks)

        if self.speaker_emb is not None:
            spk_emb = self.speaker_emb(speakers).expand_dims(1)
            B, T, D = enc_output.shape
            spk_emb = ops.broadcast_to(spk_emb, (B, T, spk_emb.shape[-1]))
            enc_output = enc_output + spk_emb

        (enc_output, p_predictions, e_predictions, log_d_predictions,
        d_rounded, mel_lens, mel_masks) = self.variance_adaptor(
            enc_output, src_masks, mel_masks, max_mel_len,
            p_targets, e_targets, d_targets,
            p_control, e_control, d_control
        )
        dec_output, mel_masks = self.decoder(enc_output, mel_masks)
        mel_outputs = self.mel_linear(dec_output)
        postnet_output = self.postnet(mel_outputs) + mel_outputs

        return (
            mel_outputs, postnet_output, p_predictions, e_predictions,
            log_d_predictions, d_rounded, src_masks, mel_masks, src_lens, mel_lens
        )

In [27]:
def get_vocoder(config, device):
    name = config["vocoder"]["model"]
    speaker = config["vocoder"]["speaker"]
    if name == "HiFi-GAN":
        with open("hifigan/config_SF.json", "r") as f:
            hifi_config = json.load(f)
        vocoder = hifigan.Generator_SF(hifi_config)
        ckpt_path = "hifigan/generator_universal.ckpt" if speaker == "universal" else "hifigan/generator_LJSpeech.ckpt"
        param_dict = ms.load_checkpoint(ckpt_path)
        param_not_load = ms.load_param_into_net(vocoder, param_dict,strict_load=True)
        if param_not_load:
            print(f"警告：未加载的参数{param_not_load}，可能影响推理效果")
        vocoder.set_train(False)
        vocoder.remove_weight_norm()
        return vocoder
    else:
        raise ValueError(f"不支持的声码器类型：{name}（当前仅支持MindSpore版HiFi-GAN）")

In [28]:
def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
    name = model_config["vocoder"]["model"]
    if not isinstance(mels, ms.Tensor):
        mels = ms.Tensor(mels, dtype=ms.float32)
    vocoder.set_train(False)
    if name == "HiFi-GAN":
        wavs_ms = vocoder(mels).squeeze(1)  # 移除多余维度
    max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]
    wavs = (wavs_ms.asnumpy() * max_wav_value).astype("int16")
    wavs = [wav for wav in wavs]
    if lengths is not None:
        for i in range(len(wavs)):
            length_scalar = int(lengths[i].asnumpy().item())
            wavs[i] = wavs[i][:length_scalar]
    return wavs

def plot_mel(data, stats, titles):
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    titles = titles if titles else [None]*len(data)
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean
    def add_axis(fig, old_ax):
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for i in range(len(data)):
        mel, pitch, energy = data[i]
        pitch = pitch * pitch_std + pitch_mean
        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")
        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color="tomato")
        ax1.set_xlim(0, mel.shape[1])
        ax1.set_ylim(0, pitch_max)
        ax1.set_ylabel("F0", color="tomato")
        ax1.tick_params(labelsize="x-small", colors="tomato", bottom=False, labelbottom=False)
        ax2 = add_axis(fig, axes[i][0])
        ax2.plot(energy, color="darkviolet")
        ax2.set_xlim(0, mel.shape[1])
        ax2.set_ylim(energy_min, energy_max)
        ax2.set_ylabel("Energy", color="darkviolet")
        ax2.yaxis.set_label_position("right")
        ax2.tick_params(labelsize="x-small", colors="darkviolet", bottom=False, labelbottom=False, left=False, labelleft=False, right=True, labelright=True)
    
    return fig


def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    basenames = [f"sample_{i}" for i in range(len(predictions[0]))]  # 样本名称
    os.makedirs(path, exist_ok=True)  # 创建保存目录

    with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]  # [p_min, p_max, p_mean, p_std, e_min, e_max]

    for i in range(len(predictions[0])):
        mel_len = int(predictions[9][i].asnumpy())  # Mel谱实际长度
        mel_pred = predictions[1][i, :mel_len].transpose(0, 1).asnumpy()  # (mel_dim, seq_len)
        src_len = int(predictions[8][i].asnumpy())  # 文本序列长度
        duration = predictions[5][i, :src_len].asnumpy()  # 时长预测结果

        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].asnumpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].asnumpy()

        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].asnumpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].asnumpy()

        fig = plot_mel([(mel_pred, pitch, energy)], stats, ["Synthetized Spectrogram"])
        plt.savefig(os.path.join(path, f"{basenames[i]}.png"))
        plt.close()

    mel_predictions = predictions[1].transpose(1, 2)  # (batch, mel_dim, seq_len)
    hop_length = preprocess_config["preprocessing"]["stft"]["hop_length"]
    lengths = predictions[9] * hop_length  # 音频实际长度（采样点）
    wav_predictions = vocoder_infer(mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths)

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, f"{basename}.wav"), sampling_rate, wav)

In [29]:
def load_fastspeech2_model(preprocess_config, model_config, ckpt_path):
    model = FastSpeech2(preprocess_config, model_config)
    param_dict = ms.load_checkpoint(ckpt_path)
    encoder_layer_num = model_config["transformer"]["encoder_layer"]
    for layer_idx in range(encoder_layer_num):
        w1_name = f"encoder.layer_stack.{layer_idx}.pos_ffn.w_1.weight"
        if w1_name in param_dict:
            old_w1 = param_dict[w1_name].asnumpy()
            new_w1 = old_w1.reshape(1024, 256, 1, 9)
            param_dict[w1_name] = ms.Parameter(ms.Tensor(new_w1, dtype=ms.float32), requires_grad=False)
        w2_name = f"encoder.layer_stack.{layer_idx}.pos_ffn.w_2.weight"
        if w2_name in param_dict:
            old_w2 = param_dict[w2_name].asnumpy()
            new_w2 = old_w2.reshape(256, 1024, 1, 1)
            param_dict[w2_name] = ms.Parameter(ms.Tensor(new_w2, dtype=ms.float32), requires_grad=False)
    
    decoder_layer_num = model_config["transformer"]["decoder_layer"]
    for layer_idx in range(decoder_layer_num):
        w1_name = f"decoder.layer_stack.{layer_idx}.pos_ffn.w_1.weight"
        if w1_name in param_dict:
            old_w1 = param_dict[w1_name].asnumpy()
            new_w1 = old_w1.reshape(1024, 256, 1, 9)
            param_dict[w1_name] = ms.Parameter(ms.Tensor(new_w1, dtype=ms.float32), requires_grad=False)
        w2_name = f"decoder.layer_stack.{layer_idx}.pos_ffn.w_2.weight"
        if w2_name in param_dict:
            old_w2 = param_dict[w2_name].asnumpy()
            new_w2 = old_w2.reshape(256, 1024, 1, 1)
            param_dict[w2_name] = ms.Parameter(ms.Tensor(new_w2, dtype=ms.float32), requires_grad=False)
    
    predictors = ["duration_predictor", "pitch_predictor", "energy_predictor"]
    conv_layers = ["conv1d_1", "conv1d_2"]
    for predictor in predictors:
        for conv_layer in conv_layers:
            param_name = f"variance_adaptor.{predictor}.conv_layer.{conv_layer}.conv.weight"
            if param_name in param_dict:
                old_weight = param_dict[param_name].asnumpy()
                new_weight = old_weight.reshape(256, 256, 1, 3)
                param_dict[param_name] = ms.Parameter(ms.Tensor(new_weight, dtype=ms.float32), requires_grad=False)
    
    postnet_conv_num = 5 
    for conv_idx in range(postnet_conv_num):
        param_name = f"postnet.convolutions.{conv_idx}.0.conv.weight"
        if param_name in param_dict:
            old_weight = param_dict[param_name].asnumpy()
            new_weight = old_weight.reshape(*old_weight.shape[:2], 1, old_weight.shape[2])
            param_dict[param_name] = ms.Parameter(ms.Tensor(new_weight, dtype=ms.float32), requires_grad=False)
    
    # -------------------------- 5. 加载参数 --------------------------
    param_not_load = ms.load_param_into_net(model, param_dict)
    model.set_train(False)
    return model


def load_vocoder(model_config):
    return get_vocoder(model_config, device)


def infer_single_sentence():
    preprocess_config = preprocess_config_infer
    model_config = model_config_infer
    control_params = infer_control_params
    input_data = infer_input
    ckpt_path = "./output/ckpt/AISHELL3/600000.ckpt"
    model = load_fastspeech2_model(preprocess_config, model_config, ckpt_path)
    vocoder = load_vocoder(model_config)
    text_sequence = preprocess_mandarin(input_data["text"], preprocess_config)
    text_sequence = ms.Tensor(text_sequence, dtype=ms.int64) 
    speakers = ms.Tensor([input_data["speaker_id"]], dtype=ms.int64)  # 说话人ID
    texts = text_sequence.expand_dims(0)  # (1, seq_len)：批量大小=1
    src_lens = ms.Tensor([text_sequence.shape[0]], dtype=ms.int64)  # 文本长度
    max_src_len = int(text_sequence.shape[0])  # ✅ 纯 Python int
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')  # 或 GPU，根据需要选择
    output = model(
        speakers=speakers,
        texts=texts,
        src_lens=src_lens,
        max_src_len=max_src_len,
        p_control=control_params["p_control"],  # 音高控制
        e_control=control_params["e_control"],  # 能量控制
        d_control=control_params["d_control"]   # 语速控制
    )
    save_path = "./output/result/AISHELL3"
    batch = (
        [input_data["text"][:100]],  # 文本摘要
        [input_data["text"]],        # 原始文本
        speakers.asnumpy(),          # 说话人ID（numpy）
        texts.asnumpy(),             # 文本序列（numpy）
        src_lens.asnumpy(),          # 文本长度（numpy）
        max_src_len              # 最大文本长度（numpy）
    )
    synth_samples(batch, output, vocoder, model_config, preprocess_config, save_path)
    print(f"语音合成完成！结果保存至：{save_path}")

In [30]:
preprocess_config_infer = {
    "path": {
        "lexicon_path": "lexicon/pinyin-lexicon-r.txt",  # 音素词典路径
        "preprocessed_path": "./preprocessed_data/AISHELL3",  # 预处理数据路径（含stats.json）
        "speakers_json_path": "./preprocessed_data/AISHELL3/speakers.json"  # 说话人列表
    },
    "preprocessing": {
        "text": {
            "language": "zh",  # 中文
            "text_cleaners": []  # 文本清理规则（与训练一致）
        },
        "audio": {
            "sampling_rate": 22050,  # 采样率（与声码器一致）
            "max_wav_value": 32768.0  # 音频最大值（int16范围）
        },
        "stft": {
            "filter_length": 1024,
            "hop_length": 256,
            "win_length": 1024
        },
        "mel": {
            "n_mel_channels": 80,  # Mel谱通道数（与模型输出一致）
            "mel_fmin": 0,
            "mel_fmax": 8000  # HiFi-GAN常用配置
        },
        "pitch": {
            "feature": "phoneme_level",  # 音素级音高（与训练一致）
            "normalization": True
        },
        "energy": {
            "feature": "phoneme_level",  # 音素级能量（与训练一致）
            "normalization": True
        }
    }
}

model_config_infer = {
    "transformer": {
        "encoder_layer": 4,
        "encoder_head": 2,
        "encoder_hidden": 256,
        "decoder_layer": 6,
        "decoder_head": 2,
        "decoder_hidden": 256,
        "conv_filter_size": 1024,
        "conv_kernel_size": [9, 1],
        "encoder_dropout": 0.2,
        "decoder_dropout": 0.3
    },
    "variance_predictor": {
        "filter_size": 256,
        "kernel_size": 3,
        "dropout": 0.5
    },
    "variance_embedding": {
        "pitch_quantization": "linear",
        "energy_quantization": "linear",
        "n_bins": 256
    },
    "multi_speaker": True,  # 多说话人模型
    "max_seq_len": 1000,    # 最大文本长度
    "vocoder": {
        "model": "HiFi-GAN",  # 声码器类型
        "speaker": "universal"  # 通用声码器
    }
}

infer_control_params = {
    "p_control": 1.0,
    "e_control": 1.0, 
    "d_control": 1.0  
}

infer_input = {
    "text": "李白，字太白，号青莲居士，唐代伟大的浪漫主义诗人，被后人誉为“诗仙”。他的诗歌以豪放飘逸、想象丰富著称，代表作有《将进酒》《静夜思》《早发白帝城》等，深受人们喜爱。",
    "speaker_id": 0,
    "max_src_len": 1000
}



In [31]:
if __name__ == "__main__":
    infer_single_sentence()



警告：未加载的参数([], [])，可能影响推理效果
原始文本: 李白，字太白，号青莲居士，唐代伟大的浪漫主义诗人，被后人誉为“诗仙”。他的诗歌以豪放飘逸、想象丰富著称，代表作有《将进酒》《静夜思》《早发白帝城》等，深受人们喜爱。
音素序列: {l i3 b ai2 sp z ii4 t ai4 b ai2 sp h ao4 q ing1 l ian2 j v1 sh iii4 sp t ang2 d ai4 w uei3 d a4 d e5 l ang4 m an4 zh u3 y i4 sh iii1 r en2 sp b ei4 h ou4 r en2 y v4 w uei2 sp sh iii1 x ian1 sp t a1 d e5 sh iii1 g e1 y i3 h ao2 f ang4 p iao1 y i4 sp x iang3 x iang4 f eng1 f u4 zh u4 ch eng1 sp d ai4 b iao3 z uo4 y iou3 sp q iang1 j in4 j iou3 sp j ing4 y ie4 s ii1 sp z ao3 f a1 b ai2 d i4 ch eng2 sp d eng3 sp sh en1 sh ou4 r en2 m en5 x i3 ai4 sp}
语音合成完成！结果保存至：./output/result/AISHELL3
