In [None]:
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
import numpy as np
from tqdm import tqdm
import torchaudio

class PNP_Flow_Audio_Denoiser:
    """PnP-Flow audio denoiser that uses an F5-TTS flow-matching model as its prior.

    Each iteration performs a data-fidelity gradient step, re-noises the result
    to the current flow time, and applies the one-step flow denoiser
    ``D_t(x) = x + (1 - t) * v_t(x)`` built from the model's velocity field.
    The final mel spectrogram is turned into a waveform with ``vocoder``.
    """

    def __init__(self, model, vocoder, device):
        """Store the model and vocoder, and put the model in eval mode on ``device``.

        Args:
            model: the already-loaded (EMA) F5-TTS model. Must expose
                ``transformer`` and ``num_channels``.
            vocoder: the already-loaded vocoder; only its ``decode`` method is used.
            device: torch device (e.g. ``'cuda'``) that all tensors are placed on.
        """
        self.model = model.to(device)
        self.vocoder = vocoder
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def get_unconditional_velocity(self, m_t, t):
        """Return the model's velocity field at time ``t`` without text or audio conditioning.

        This is the prior that drives the PnP denoising step.

        Args:
            m_t: current noisy mel tensor, shape (B, T, D).
            t: flow time, either a Python/0-d scalar shared by the whole batch,
                or a tensor with one value per batch element (B,).

        Returns:
            The transformer output for ``m_t`` (the velocity estimate).
        """
        batch_size, seq_len = m_t.shape[0], m_t.shape[1]

        # Accept floats, 0-d tensors, 1-element tensors and per-sample (B,) tensors;
        # a single value is broadcast to every batch element.
        time_tensor = torch.as_tensor(t, device=self.device, dtype=m_t.dtype).reshape(-1)
        if time_tensor.numel() == 1:
            time_tensor = time_tensor.expand(batch_size)

        # An empty token sequence and an all-zero audio condition are used as the
        # "null" inputs. NOTE(review): this assumes the transformer treats them
        # like its dropped text/audio-cond (CFG-null) path -- confirm against the
        # backbone's drop_text / drop_audio_cond handling.
        null_text = torch.zeros(batch_size, 0, dtype=torch.long, device=self.device)
        null_cond = torch.zeros(
            batch_size, seq_len, self.model.num_channels, device=self.device, dtype=m_t.dtype
        )

        return self.model.transformer(
            x=m_t,
            cond=null_cond,
            text=null_text,
            time=time_tensor,
            cfg_infer=False,
        )

    @torch.no_grad()
    def denoise(self, noisy_mel, steps=100, lr=0.1, gamma_style='alpha_1_minus_t', alpha=0.5, num_samples=1):
        """Run the full PnP-Flow denoising loop on a noisy mel and vocode the result.

        Args:
            noisy_mel: noisy mel tensor, shape (1, T, D) or (T, D).
            steps: number of iterations N; iteration n uses time t_n = n / N,
                so t stays strictly below 1.
            lr: base step size of the data-fidelity gradient step.
            gamma_style: ``'alpha_1_minus_t'`` scales the step by (1 - t_n) ** alpha;
                any other value keeps the step constant at ``lr``.
            alpha: exponent used by the ``'alpha_1_minus_t'`` schedule.
            num_samples: number of noise draws epsilon averaged per iteration.

        Returns:
            Tuple ``(waveform, mel)`` of squeezed numpy arrays; ``mel`` is (T, D)
            for a single input.

        Raises:
            ValueError: if ``steps`` or ``num_samples`` is smaller than 1.
        """
        # Without these checks steps=0 returns pure noise and num_samples=0 divides by zero.
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        if noisy_mel.ndim == 2:
            noisy_mel = noisy_mel.unsqueeze(0)

        # Run the loop in the model's parameter dtype (e.g. fp16) to match the transformer.
        model_dtype = self.model.transformer.time_embed.time_mlp[0].weight.dtype
        noisy_mel = noisy_mel.to(self.device, dtype=model_dtype)

        # Start from pure Gaussian noise.
        m = torch.randn_like(noisy_mel)

        print(f"开始PnP-Flow去噪，共 {steps} 步...")
        for n in tqdm(range(steps)):
            t_n = n / steps

            # Step 1: gradient step on F(m) = 0.5 * ||m - y||^2 (identity degradation,
            # Gaussian noise assumed); its gradient is m - y, any factor 2 lives in lr.
            grad = m - noisy_mel
            current_lr = lr * ((1 - t_n) ** alpha) if gamma_style == 'alpha_1_minus_t' else lr
            z = m - current_lr * grad

            # Steps 2 and 3, averaged over num_samples noise draws.
            m_next_total = torch.zeros_like(m)
            for _ in range(num_samples):
                # Step 2: re-project z onto the flow path at time t_n.
                epsilon = torch.randn_like(z)
                m_tilde = (1 - t_n) * epsilon + t_n * z

                # Step 3: one-step flow denoiser D_t(x) = x + (1 - t) * v_t(x).
                v_t = self.get_unconditional_velocity(m_tilde, t_n)
                m_next_total += m_tilde + (1 - t_n) * v_t

            m = m_next_total / num_samples

        final_mel = m
        print("PnP-Flow去噪完成！")

        # The vocoder takes (B, D, T) in float32. This runs under no_grad, so the
        # decoded waveform does not require grad and .numpy() is allowed.
        final_wave = self.vocoder.decode(final_mel.permute(0, 2, 1).to(torch.float32))
        return final_wave.squeeze().cpu().numpy(), final_mel.squeeze().cpu().numpy()


# --- 如何使用 ---

import os
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import soundfile as sf
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from importlib.resources import files

sys.path.append("/mnt/workspace/zhangjunan/F5-TTS/src/")

from f5_tts.infer.utils_infer import (
    device,
    infer_process,
    load_model,
    load_vocoder,
)

# --- Load the vocoder and the pretrained F5-TTS EMA model, then wrap them ---
model_name = "F5TTS_v1_Base"
vocoder_name = "vocos"

vocoder = load_vocoder(vocoder_name=vocoder_name, device=device)

config_rel_path = f"configs/{model_name}.yaml"
model_config_path = str(files("f5_tts").joinpath(config_rel_path))
model_cfg = OmegaConf.load(model_config_path)

# Checkpoint is fetched (and cached locally) from the SWivid Hugging Face repo.
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
ckpt_uri = f"hf://SWivid/{repo_name}/{model_name}/model_{ckpt_step}.{ckpt_type}"
checkpoint_path = str(cached_path(ckpt_uri))

backbone_path = f"f5_tts.model.{model_cfg.model.backbone}"
model_cls = get_class(backbone_path)
model_arc = model_cfg.model.arch

ema_model = load_model(model_cls, model_arc, checkpoint_path, mel_spec_type=vocoder_name, device=device)

# Step 1: build the denoiser around the loaded model and vocoder.
pnp_denoiser = PNP_Flow_Audio_Denoiser(model=ema_model, vocoder=vocoder, device=device)

# Input recording to denoise. An alternative example clip is kept commented out
# (previously it was assigned and immediately overwritten, i.e. dead code):
# audio_path = "/mnt/workspace/zhangjunan/F5-TTS/src/f5_tts/infer/examples/basic/basic_ref_en.wav"
audio_path = "/mnt/workspace/zhangjunan/F5-TTS/src/f5_tts/infer/tests/p232_005.wav"
target_sample_rate = 24000  # must match the sample rate the vocoder/model expect

print(f"正在加载音频文件: {audio_path}")
waveform, original_sr = torchaudio.load(audio_path)

# Down-mix multi-channel audio to mono while keeping the channel axis.
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

# Resample only when the file's rate differs from the target rate.
if original_sr != target_sample_rate:
    print(f"音频采样率为 {original_sr}, 正在重采样至 {target_sample_rate}...")
    resampler = T.Resample(original_sr, target_sample_rate)
    waveform = resampler(waveform)

# Ensure a batch axis (B, T_wave), then move to the model's device.
if waveform.ndim == 1:
    waveform = waveform.unsqueeze(0)
waveform = waveform.to(device)

print("音频加载并预处理完成。")

# --- Step 2: convert the waveform into a mel spectrogram ---

print("正在使用模型的 mel_spec 将音频转换为Mel频谱图...")
with torch.no_grad():
    # Use the model's own mel front-end so features match what it was trained on.
    # It yields (B, D, T_mel); the denoiser works in (B, T_mel, D), hence the swap.
    mel_input = waveform.to(torch.float32)
    clean_mel_from_audio = ema_model.mel_spec(mel_input).transpose(1, 2)

print(f"成功生成干净的Mel频谱图, shape: {clean_mel_from_audio.shape}")


# --- Step 3: build the noisy mel observation and its audio (for comparison) ---

# Synthetic corruption is currently switched off: the input recording's own mel is
# used directly as the noisy observation. Set add_synthetic_noise = True to add
# Gaussian noise with standard deviation noise_level (a key knob when testing).
add_synthetic_noise = False
noise_level = 0.5
if add_synthetic_noise:
    print(f"正在添加高斯噪声, 强度: {noise_level}")
    noisy_mel = clean_mel_from_audio + torch.randn_like(clean_mel_from_audio) * noise_level
else:
    print("未添加人工噪声, 直接使用输入音频的Mel谱作为带噪观测")
    noisy_mel = clean_mel_from_audio

# Also synthesize the noisy mel so it can be compared by ear with the result.
print("正在合成带噪的音频用于对比...")
with torch.no_grad():
    # The vocoder expects (B, D, T), so swap the axes back from (B, T, D).
    noisy_mel_for_vocoder = noisy_mel.permute(0, 2, 1)
    noisy_audio = vocoder.decode(noisy_mel_for_vocoder.to(torch.float32)).squeeze().cpu().numpy()


# --- Step 4: run the PnP-Flow denoiser ---

separator = "=" * 30
print(separator)
denoised_audio, denoised_mel = pnp_denoiser.denoise(
    noisy_mel=noisy_mel,  # the noisy observation built in step 3
    steps=100,            # number of PnP iterations
    lr=0.5,               # base step size of the data-fidelity step
    alpha=0.5,            # exponent of the (1 - t)^alpha step-size schedule
    num_samples=1,        # noise draws averaged per step: 1 for speed, 3-5 is more stable
)
print(separator)

# --- Step 5: save every stage so they can be compared side by side ---

clean_audio_to_save = waveform.squeeze().cpu().numpy()

# (file name, audio, log label). The label for the noisy input no longer claims
# that noise was added manually, since the noise addition is disabled.
_outputs_to_save = [
    ("output_0_clean_original.wav", clean_audio_to_save, "原始干净音频已保存到"),
    ("output_1_noisy_input.wav", noisy_audio, "带噪输入音频已保存到"),
    ("output_2_denoised_pnpflow.wav", denoised_audio, "PnP-Flow去噪后的音频已保存到"),
]
for _out_path, _out_audio, _out_label in _outputs_to_save:
    sf.write(_out_path, _out_audio, target_sample_rate)
    print(f"{_out_label}: {_out_path}")

Download Vocos from huggingface charactr/vocos-mel-24khz

vocab :  /mnt/workspace/zhangjunan/F5-TTS/src/f5_tts/infer/examples/vocab.txt
token :  custom
model :  /mnt/workspace/zhangjunan/.cache/huggingface/hub/models--SWivid--F5-TTS/snapshots/84e5a410d9cead4de2f847e7c9369a6440bdfaca/F5TTS_v1_Base/model_1250000.safetensors 

正在加载音频文件: /mnt/workspace/zhangjunan/F5-TTS/src/f5_tts/infer/tests/p232_005.wav
音频采样率为 16000, 正在重采样至 24000...
音频加载并预处理完成。
正在使用模型的 mel_spec 将音频转换为Mel频谱图...
成功生成干净的Mel频谱图, shape: torch.Size([1, 586, 100])
正在添加高斯噪声, 强度: 0.5
正在合成带噪的音频用于对比...
开始PnP-Flow去噪，共 100 步...


  6%|▌         | 6/100 [00:00<00:01, 55.06it/s]

tensor([[[  0.4487,  -3.8320,  -3.2129,  ...,  -4.7930,  -8.8906, -10.7734],
         [ -2.4824,  -2.6387,  -1.4277,  ...,  -5.8711,  -8.4375,  -8.3281],
         [ -3.4980,  -4.2539,  -3.2363,  ...,  -6.7930,  -7.7188, -11.4609],
         ...,
         [ -3.3633,  -3.9004,  -1.6670,  ...,  -6.9961, -10.0703,  -9.3750],
         [ -2.6562,  -5.4922,  -2.6855,  ...,  -5.4609,  -7.5391,  -7.5234],
         [ -2.0117,  -2.1445,  -2.2070,  ...,  -3.8828,  -4.1211,  -4.0938]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[  0.6270,  -2.9551,  -1.3848,  ...,  -9.7188, -10.4141,  -8.5547],
         [ -2.6016,  -2.5156,  -2.6055,  ...,  -7.0234,  -7.8320, -11.1406],
         [ -1.8818,  -3.1973,  -3.5703,  ...,  -7.6992,  -8.0312,  -7.6133],
         ...,
         [ -2.0645,  -2.9688,  -3.3125,  ...,  -6.5430, -10.4766,  -9.8359],
         [ -3.7207,  -3.4883,  -4.7148,  ...,  -5.2734,  -6.7773,  -7.4219],
         [ -1.9619,  -2.1289,  -1.9463,  ...,  -3.8652,  -4.1094,  -4.0938]]],

 13%|█▎        | 13/100 [00:00<00:01, 58.12it/s]

tensor([[[  0.0792,   1.1523,   1.4434,  ...,  -4.8125,  -5.5000,  -8.6172],
         [  2.5977,   1.5127,   1.0322,  ...,  -3.9395,  -6.1289, -10.0938],
         [  3.3984,   0.9692,  -0.2311,  ...,  -4.8984,  -6.3516,  -6.6133],
         ...,
         [  2.2012,   0.6680,  -1.1729,  ...,  -5.9531,  -5.5000, -10.6172],
         [ -0.1482,   1.3350,  -0.7314,  ...,  -4.0469,  -5.3125,  -4.2344],
         [ -1.5068,  -1.6592,  -1.5098,  ...,  -3.1523,  -3.4219,  -3.6504]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[  2.4199,   3.7305,   0.6162,  ...,  -6.1445,  -6.7227,  -9.2656],
         [  3.0820,   0.6133,   1.9268,  ...,  -7.2930,  -5.3789, -10.2812],
         [  2.7344,   1.0420,  -0.9980,  ...,  -6.4766,  -8.5312,  -9.8672],
         ...,
         [  0.4724,   1.3545,   0.1873,  ...,  -3.7266,  -4.6406,  -8.3281],
         [  0.7944,   0.0922,   3.3926,  ...,  -3.3496,  -3.8125,  -4.1211],
         [ -1.4961,  -1.6240,  -1.5068,  ...,  -3.1055,  -3.3887,  -3.6270]]],

 20%|██        | 20/100 [00:00<00:01, 59.16it/s]

tensor([[[ 0.4819, -2.0000,  0.7617,  ..., -6.4727, -3.4844, -5.6836],
         [-0.4211,  1.5410,  2.5000,  ..., -4.3984, -4.6289, -2.8320],
         [ 2.6113,  2.0801,  3.1543,  ..., -5.0000, -3.3906, -3.7500],
         ...,
         [ 1.1309,  0.0468,  0.1429,  ..., -4.1328, -5.5898, -8.7812],
         [ 1.9727,  1.8955, -0.1150,  ..., -3.6270, -3.6035, -5.5469],
         [-1.5010, -1.6240, -1.4951,  ..., -3.0801, -3.3555, -3.6172]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 1.2686,  2.1758,  3.2109,  ..., -4.4258, -2.5879, -1.3926],
         [ 0.5479,  2.8691,  4.3594,  ..., -3.5879, -2.0586, -4.4180],
         [ 0.9497,  0.7690,  2.5938,  ..., -5.4258, -1.5703, -2.0723],
         ...,
         [ 1.2061,  0.8335,  1.5723,  ..., -4.0742, -5.2227, -8.2656],
         [-0.2744,  0.9512,  0.0858,  ..., -3.1152, -5.0312, -5.1055],
         [-1.4961, -1.6113, -1.4844,  ..., -3.0664, -3.3359, -3.5508]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 0.1490, -0.1982

 34%|███▍      | 34/100 [00:00<00:01, 59.88it/s]

tensor([[[ 2.5801e+00,  3.7891e+00,  1.6240e+00,  ..., -5.7656e+00,
          -5.1797e+00, -4.1055e+00],
         [ 3.6699e+00,  1.8242e+00,  2.3672e+00,  ..., -7.8633e+00,
          -3.3438e+00, -4.2578e+00],
         [ 4.9258e+00,  1.4932e+00,  2.9668e+00,  ..., -5.9062e+00,
          -4.5117e+00, -3.9375e+00],
         ...,
         [ 2.6504e+00,  3.2031e+00,  1.8418e+00,  ..., -3.9062e+00,
          -4.4648e+00, -8.8594e+00],
         [ 2.2383e+00,  1.6914e+00,  8.0185e-03,  ..., -2.2363e+00,
          -3.5781e+00, -3.3750e+00],
         [-1.4404e+00, -1.5400e+00, -1.4082e+00,  ..., -2.8730e+00,
          -3.2012e+00, -3.4648e+00]]], device='cuda:0', dtype=torch.float16)
tensor([[[  2.0820,   5.6641,   0.6421,  ...,  -4.5547,  -2.8438,  -3.6855],
         [  3.5645,  -0.1388,   2.0430,  ...,  -5.6445,  -5.5898,  -3.3262],
         [  2.4219,   5.4297,   1.4004,  ...,  -6.2891,  -2.5742,  -3.3887],
         ...,
         [  2.3750,   3.0293,   0.9878,  ...,  -5.4414,  -5.4102, -10.8

 47%|████▋     | 47/100 [00:00<00:00, 60.08it/s]

tensor([[[ 1.3311,  3.5898,  1.4336,  ..., -4.6289, -4.1094, -4.6250],
         [ 0.9399,  2.2559, -2.0508,  ..., -6.4766, -4.7188, -4.4180],
         [ 4.3633,  1.5576,  3.0039,  ..., -5.9141, -4.5000, -5.5977],
         ...,
         [ 2.0820,  2.1992,  0.4431,  ..., -2.9609, -3.5039, -6.1211],
         [ 2.7871,  0.2214, -0.5054,  ..., -3.1211, -3.6738, -3.5957],
         [-1.4521, -1.4775, -1.3398,  ..., -2.7793, -3.1270, -3.4043]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 1.9844,  3.3066,  0.1587,  ..., -3.8359, -4.0898, -3.5371],
         [ 3.1699,  1.1455,  3.0430,  ..., -5.0508, -4.1133, -5.9141],
         [ 3.3262,  2.6016, -0.2352,  ..., -3.8496, -3.6250, -5.0000],
         ...,
         [ 2.5781,  0.7192,  1.4346,  ..., -4.4531, -5.2930, -8.0938],
         [ 1.4355,  0.5825,  2.2871,  ..., -2.1777, -3.6211, -4.5859],
         [-1.4521, -1.4756, -1.3389,  ..., -2.7637, -3.1328, -3.3887]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 0.9033,  3.3887

 61%|██████    | 61/100 [00:01<00:00, 60.24it/s]

tensor([[[ 3.1914,  0.3621,  2.1641,  ..., -3.8184, -4.6094, -3.0332],
         [ 2.7090,  2.2715,  1.8330,  ..., -5.8789, -2.3066, -3.4219],
         [ 2.0039,  3.5137,  1.5693,  ..., -5.7891, -5.0352, -6.7891],
         ...,
         [ 1.2666,  2.7695,  1.5488,  ..., -3.8086, -3.9180, -7.3398],
         [ 2.7168,  0.4346,  2.4258,  ..., -4.6953, -2.7969, -2.1836],
         [-1.4316, -1.4307, -1.2764,  ..., -2.6621, -3.0703, -3.3926]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 2.2109, -0.9585,  3.1348,  ..., -4.8867, -2.0938, -4.5664],
         [ 2.3867,  3.0684,  2.5195,  ..., -3.6465, -3.7422, -3.9258],
         [ 3.0078,  1.9521,  1.9570,  ..., -5.9062, -5.2461, -7.8906],
         ...,
         [ 2.0430,  2.6445, -0.2622,  ..., -3.0605, -4.4375, -7.6133],
         [ 1.7314,  1.0840,  1.4131,  ..., -2.4102, -2.6426, -2.3926],
         [-1.4219, -1.4248, -1.2539,  ..., -2.6719, -3.0781, -3.3887]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 1.0713,  2.8184

 75%|███████▌  | 75/100 [00:01<00:00, 60.32it/s]

tensor([[[ 3.1113,  1.5156,  2.9609,  ..., -2.5977, -2.3867, -2.7930],
         [ 2.7832,  2.1523,  1.4619,  ..., -4.8164, -2.7461, -5.1797],
         [ 3.3047,  2.9785,  1.6211,  ..., -6.6836, -8.1562, -9.5859],
         ...,
         [ 1.9189,  2.5352,  1.5908,  ..., -4.4727, -4.2266, -8.4844],
         [ 1.4746,  0.1488,  2.2969,  ..., -3.1719, -2.8691, -3.0039],
         [-1.3779, -1.3779, -1.1963,  ..., -2.5801, -3.0488, -3.3789]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 3.0332,  1.7764,  2.0391,  ..., -3.3438, -2.4629, -2.5293],
         [ 2.5840,  1.9316,  1.7246,  ..., -3.8906, -3.1328, -4.1094],
         [ 4.0977,  1.8965,  2.2988,  ..., -7.6172, -8.7656, -9.2344],
         ...,
         [ 1.3887,  1.9990,  1.8701,  ..., -4.5625, -4.9141, -7.2461],
         [ 2.3105,  2.0781,  0.4783,  ..., -3.6270, -2.9453, -2.7988],
         [-1.3799, -1.3809, -1.2021,  ..., -2.5664, -3.0332, -3.3828]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 3.3906,  0.7275

 89%|████████▉ | 89/100 [00:01<00:00, 60.34it/s]

tensor([[[ 3.0605,  2.8848,  2.2910,  ..., -1.6738, -0.9980, -1.3027],
         [ 2.1777,  2.4980,  2.6875,  ..., -3.4570, -3.1191, -3.4512],
         [ 3.2949,  1.8525,  3.1504,  ..., -7.4180, -8.0156, -9.4297],
         ...,
         [ 2.2090,  2.1680,  1.4248,  ..., -4.3398, -5.3086, -7.2852],
         [ 0.3755,  1.2197,  1.6338,  ..., -4.1289, -3.3145, -3.4707],
         [-1.3281, -1.3379, -1.1270,  ..., -2.4473, -2.9512, -3.3594]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[  0.5420,   0.5244,  -1.2051,  ...,  -2.6367,  -2.5293,  -2.1602],
         [  2.9199,   3.9863,   1.1465,  ...,  -3.1328,  -3.8125,  -3.4922],
         [  2.0059,   3.4023,   1.7266,  ...,  -7.2305,  -8.0625, -10.2266],
         ...,
         [  2.4492,   2.4648,   0.5747,  ...,  -4.6602,  -4.8945,  -6.8281],
         [  0.2725,   1.3076,   2.0508,  ...,  -4.0391,  -3.6484,  -3.5254],
         [ -1.3242,  -1.3311,  -1.1328,  ...,  -2.4395,  -2.9375,  -3.3555]]],
       device='cuda:0', dtype=torch

100%|██████████| 100/100 [00:01<00:00, 59.89it/s]


tensor([[[ 1.7559,  1.7070,  1.9297,  ..., -2.3281, -1.5439, -1.5264],
         [ 2.4961,  2.0234,  2.1250,  ..., -2.6914, -2.7520, -2.7617],
         [ 2.5645,  1.9199,  1.3750,  ..., -6.8555, -8.0312, -9.1641],
         ...,
         [ 1.4570,  2.3613,  0.9648,  ..., -4.4844, -5.3750, -6.6523],
         [ 1.6738,  0.9224,  1.8809,  ..., -3.4961, -3.6484, -3.5273],
         [-1.2930, -1.3184, -1.0889,  ..., -2.2227, -2.7285, -3.2344]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 2.0098,  1.5342,  2.0996,  ..., -1.7041, -1.5527, -1.4580],
         [ 2.6328,  1.8848,  1.7051,  ..., -3.1406, -2.4922, -2.8594],
         [ 2.5742,  2.0957,  1.9854,  ..., -6.7773, -8.1250, -8.8750],
         ...,
         [ 1.3633,  2.6289,  1.0078,  ..., -4.6641, -5.3477, -6.5352],
         [ 1.5605,  0.8120,  2.1348,  ..., -3.3516, -3.7090, -3.6328],
         [-1.2881, -1.3164, -1.0859,  ..., -2.1934, -2.7031, -3.2227]]],
       device='cuda:0', dtype=torch.float16)
tensor([[[ 2.0645,  1.5537