In [28]:
#load utt
import json
from yacs.config import CfgNode
from deepspeech.models.u2 import U2Model
import paddle
from collections import OrderedDict
import numpy as np


def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that is true at padded positions.

    Entry ``[i, j]`` is true when ``j >= lengths[i]``; the mask has
    ``lengths.max()`` columns. ``make_non_pad_mask`` returns its negation.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: Mask of shape (B, max(lengths)), true on padding.
    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    # (TODO: Hui Zhang): jit does not support Tensor.dim() and Tensor.ndim
    # assert lengths.dim() == 1
    num_seqs = int(lengths.shape[0])
    longest = int(lengths.max())
    # Position indices [0, 1, ..., longest - 1], repeated once per sequence.
    positions = paddle.arange(0, longest, dtype=paddle.int64)
    position_grid = positions.unsqueeze(0).expand([num_seqs, longest])
    # Lengths as a column so the comparison broadcasts across positions.
    length_column = lengths.unsqueeze(-1)
    return position_grid >= length_column


def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a mask that is true at real (non-padded) positions.

    Sequences in a batch may differ in length, so shorter ones are padded
    to a common size. This mask lets context-dependent blocks such as
    attention or convolution ignore the padded part. It is the element-wise
    negation of ``make_pad_mask``: 1 for non-padded, 0 for padded.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: Mask tensor marking the non-padded part.
    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    padded = make_pad_mask(lengths)
    return ~padded

# Relative paths to the ESPnet and DeepSpeech-2.x directories.
# espnet_dir is interpolated into the `!ls $espnet_dir/...` shell cells.
espnet_dir = '../asr/espnet/'
ds2_dir = '../DeepSpeech-2.x/'

In [129]:
!ls ../asr/espnet//egs/librispeech/asr1/asr1/exp/train_960_pytorch_train_specaug/results/
!ls $espnet_dir/egs/librispeech/asr1/dump/dev_clean/deltafalse/data_unigram5000.json

acc.png		model.loss.best  snapshot.ep.16  snapshot.ep.23  snapshot.ep.6
att_ws		snapshot.ep.1	 snapshot.ep.17  snapshot.ep.24  snapshot.ep.7
ctc_prob	snapshot.ep.10	 snapshot.ep.18  snapshot.ep.25  snapshot.ep.8
log		snapshot.ep.11	 snapshot.ep.19  snapshot.ep.26  snapshot.ep.9
loss.png	snapshot.ep.12	 snapshot.ep.2	 snapshot.ep.27
model.acc.best	snapshot.ep.13	 snapshot.ep.20  snapshot.ep.3
model.json	snapshot.ep.14	 snapshot.ep.21  snapshot.ep.4
model.json.bak	snapshot.ep.15	 snapshot.ep.22  snapshot.ep.5
../asr/espnet//egs/librispeech/asr1/dump/dev_clean/deltafalse/data_unigram5000.json


In [3]:
!ls $espnet_dir/egs/librispeech/asr1/asr1/exp/train_960_pytorch_train_specaug/results/snapshot.ep.27

../asr/espnet//egs/librispeech/asr1/asr1/exp/train_960_pytorch_train_specaug/results/snapshot.ep.27


In [4]:
# Path of the dev_clean recognition manifest (same file listed by the `!ls` cell).
recog_json="../asr/espnet//egs/librispeech/asr1/dump/dev_clean/deltafalse/data_unigram5000.json"
# Read the manifest and keep only its "utts" mapping (utterance id -> metadata).
with open(recog_json, "rb") as f:
    js = json.load(f)["utts"]

In [5]:
# Print only the first (utterance id, metadata) entry to show the manifest layout.
for item in js.items():
    print(item)
    break

('1272-128104-0000', {'input': [{'feat': '/workspace/zhanghui/asr/espnet/egs/librispeech/asr1/dump/dev_clean/deltafalse/feats.1.ark:17', 'name': 'input1', 'shape': [584, 83]}], 'output': [{'name': 'target1', 'shape': [22, 5002], 'text': 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', 'token': '▁MISTER ▁QUI L TER ▁IS ▁THE ▁A PO ST LE ▁OF ▁THE ▁MIDDLE ▁CLASSES ▁AND ▁WE ▁ARE ▁GLAD ▁TO ▁WELCOME ▁HIS ▁GOSPEL', 'tokenid': '3008 3630 245 410 2598 4502 482 352 399 252 3204 4502 2974 1165 627 4845 689 2209 4577 4859 2391 2237'}], 'utt2spk': '1272-128104'})


In [6]:
# Load the reference dump stored in 'espnet.model'; return_numpy=True asks
# paddle.load to hand back numpy arrays instead of tensors.
state_dict = paddle.load('espnet.model', return_numpy=True)

  if isinstance(obj, collections.Iterable) and not isinstance(obj, (


In [7]:
# Split the reference dump into its parts.
print(state_dict.keys())
# Per-layer records with 'inputs' / 'outputs' entries.
torch_forward_dict = state_dict['forward']
# Per-layer records with 'grad_outs' / 'grad_ins' entries.
torch_backward_dict = state_dict['backward']
# Parameter name -> weight array; loaded into the Paddle model later.
model_dict = state_dict['model']
# The batch the reference was run on: features, feature lengths, targets, target lengths.
xs = state_dict['xs']
ilens = state_dict['ilens']
ys = state_dict['ys']
olens = state_dict['olens']

dict_keys(['forward', 'backward', 'model', 'xs', 'ys', 'ilens', 'olens'])


In [8]:
print(xs)

[[[-0.74641937 -0.58889526 -0.79081845 ...  1.2798967   0.22447938
   -0.11255249]
  [-0.74641937 -0.574557   -0.9586555  ...  1.3187659   0.24862868
    0.06486106]
  [-0.51503205 -0.41683638 -0.6949115  ...  1.3654089   0.28082776
    0.04357147]
  ...
  [-0.65000796 -0.54588056 -0.9586555  ...  0.7260699  -0.24240735
   -0.19061443]
  [-0.74641937 -0.6175717  -0.8387718  ...  0.7131443  -0.2343576
    0.11453688]
  [-0.47646755 -0.37382168 -0.79081845 ...  1.0933245  -0.2343576
    0.44807434]]]


In [9]:
print(model_dict.values())

dict_values([array([[[[ 7.1029529e-02,  1.9559536e-04, -1.5515685e-01],
         [-9.9848308e-02, -3.6037773e-01,  3.4337088e-01],
         [-3.2645259e-02,  4.3610886e-01,  2.0822534e-02]]],


       [[[-1.5891600e-01,  1.3020800e-01,  2.7415212e-02],
         [ 2.0975447e-01, -1.4134079e-01, -2.6494140e-02],
         [-3.7469115e-02,  1.1016304e-02,  3.3086690e-03]]],


       [[[ 5.8521651e-02, -6.4714767e-02,  1.4879146e-02],
         [-1.0177678e-01,  5.2017633e-02, -1.0318909e-02],
         [-6.6046283e-02,  2.2357166e-02,  7.1352371e-03]]],


       ...,


       [[[ 2.5821188e-01,  2.4085315e-02,  7.7300176e-02],
         [-2.2517703e-01, -1.5048036e-02, -1.9404091e-01],
         [-1.3471361e-01,  1.0731172e-02,  1.1089972e-01]]],


       [[[ 6.6544212e-02,  2.1687316e-02, -3.0199403e-02],
         [-1.3713089e-01,  2.3822840e-03,  3.2115195e-02],
         [ 3.0754156e-02, -2.2211686e-02, -1.6062601e-02]]],


       [[[-6.8446390e-02, -1.1640425e-02, -5.8525059e-02],
         

In [10]:
print(torch_forward_dict.values())

dict_values([{'inputs': [array([[[[-0.74641937, -0.58889526, -0.79081845, ...,  1.2798967 ,
           0.22447938, -0.11255249],
         [-0.74641937, -0.574557  , -0.9586555 , ...,  1.3187659 ,
           0.24862868,  0.06486106],
         [-0.51503205, -0.41683638, -0.6949115 , ...,  1.3654089 ,
           0.28082776,  0.04357147],
         ...,
         [-0.65000796, -0.54588056, -0.9586555 , ...,  0.7260699 ,
          -0.24240735, -0.19061443],
         [-0.74641937, -0.6175717 , -0.8387718 , ...,  0.7131443 ,
          -0.2343576 ,  0.11453688],
         [-0.47646755, -0.37382168, -0.79081845, ...,  1.0933245 ,
          -0.2343576 ,  0.44807434]]]], dtype=float32)], 'outputs': array([[[[-2.3296108e+00, -2.3804862e+00, -2.2985189e+00, ...,
          -2.2790453e+00, -1.7673831e+00, -2.1839311e+00],
         [-2.4855013e+00, -2.5069976e+00, -2.5415361e+00, ...,
          -2.0678055e+00, -1.9863954e+00, -2.2065032e+00],
         [-2.1821668e+00, -2.6111586e+00, -2.4770577e+00, ...,

In [11]:
print(torch_backward_dict.values())

dict_values([{'grad_outs': [array(1., dtype=float32)], 'grad_ins': [array(1., dtype=float32)]}, {'grad_outs': [array(1., dtype=float32)], 'grad_ins': [array([1.], dtype=float32)]}, {'grad_outs': [array([[[-4.06068875e-06,  1.56248958e-17,  1.52769069e-08, ...,
          1.09671638e-09,  2.51141209e-11,  1.59137523e-17],
        [-4.95299128e-06,  2.18775094e-17,  1.37869876e-08, ...,
          1.04155373e-09,  2.89430840e-11,  2.20381661e-17],
        [-4.11035398e-06,  2.86817466e-17,  1.27417881e-08, ...,
          1.01215381e-09,  2.52582399e-11,  2.86526563e-17],
        ...,
        [-3.15360594e-05,  1.33016598e-16,  3.34257919e-07, ...,
          1.06512243e-09,  3.57643415e-10,  1.48437720e-16],
        [-6.90042070e-05,  2.00907771e-16,  1.28025863e-06, ...,
          2.03509898e-09,  4.20984914e-10,  2.27859910e-16],
        [-3.08763883e-05,  6.87805378e-17,  8.43607211e-07, ...,
          2.22061369e-09,  9.67530500e-11,  7.67507274e-17]]],
      dtype=float32)], 'grad_ins'

In [14]:
# paddle transformer

In [15]:
!ls ../DeepSpeech-2.x/examples/librispeech/s2/conf/transformer.yaml

../DeepSpeech-2.x/examples/librispeech/s2/conf/transformer.yaml


In [16]:
config_path = '../DeepSpeech-2.x/examples/librispeech/s2/conf/transformer.yaml'

In [220]:
# config = CfgNode()
# config.set_new_allowed(True)
# config.merge_from_file(config_path)

In [130]:
# Parse the YAML config. The context manager closes the file handle, which a
# bare open() passed inline would leave open for the rest of the kernel session.
with open(config_path, 'rt') as config_file:
    config = CfgNode.load_cfg(config_file)
print(config)

collator:
  augmentation_config: conf/augmentation.json
  batch_bins: 0
  batch_count: auto
  batch_frames_in: 0
  batch_frames_inout: 0
  batch_frames_out: 0
  batch_size: 32
  feat_dim: 83
  maxlen_in: 512
  maxlen_out: 150
  minibatches: 0
  num_encs: 1
  num_workers: 2
  sortagrad: 0
  spm_model_prefix: data/train_960_unigram5000
  stride_ms: 10.0
  subsampling_factor: 1
  unit_type: spm
  vocab_filepath: data/train_960_unigram5000_units.txt
  window_ms: 25.0
data:
  dev_manifest: data/manifest.dev
  test_manifest: data/manifest.test-clean
  train_manifest: data/manifest.train
decoding:
  alpha: 2.5
  batch_size: 64
  beam_size: 10
  beta: 0.3
  ctc_weight: 0.5
  cutoff_prob: 1.0
  cutoff_top_n: 0
  decoding_chunk_size: -1
  decoding_method: attention
  error_rate_type: wer
  lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
  num_decoding_left_chunks: -1
  num_proc_bsearch: 8
  simulate_streaming: False
model:
  cmvn_file: None
  cmvn_file_type: json
  decoder: transfor

In [131]:
# Set the model dimensions before construction.
# 83 equals collator.feat_dim and the manifest's input feature width.
config.model.input_dim = 83
# 5002 equals the second dimension of the manifest's output shape.
config.model.output_dim = 5002
# 512 / 8 match the 512-wide linear layers seen in the printed model
# (presumably the ESPnet reference architecture — confirm against its config).
config.model.encoder_conf.output_size = 512
config.model.encoder_conf.attention_heads = 8
# Build the Paddle U2 model from the adjusted config.
pmodel = U2Model.from_config(config.model)

[INFO 2021/08/31 04:38:26 u2.py:854] U2 Encoder type: transformer


In [132]:
print(pmodel)

U2Model(
  (encoder): TransformerEncoder(
    (embed): Conv2dSubsampling4(
      (pos_enc): PositionalEncoding(
        (dropout): Dropout(p=0.1, axis=None, mode=upscale_in_train)
      )
      (conv): Sequential(
        (0): Conv2D(1, 512, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)
        (1): ReLU()
        (2): Conv2D(512, 512, kernel_size=[3, 3], stride=[2, 2], data_format=NCHW)
        (3): ReLU()
      )
      (out): Sequential(
        (0): Linear(in_features=10240, out_features=512, dtype=float32)
      )
    )
    (after_norm): LayerNorm(normalized_shape=[512], epsilon=1e-12)
    (encoders): LayerList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linear_q): Linear(in_features=512, out_features=512, dtype=float32)
          (linear_k): Linear(in_features=512, out_features=512, dtype=float32)
          (linear_v): Linear(in_features=512, out_features=512, dtype=float32)
          (linear_out): Linear(in_features=512, out_f

In [133]:
for key, val in pmodel.named_parameters():
    print(key, '\t', val.shape)

encoder.embed.conv.0.weight 	 [512, 1, 3, 3]
encoder.embed.conv.0.bias 	 [512]
encoder.embed.conv.2.weight 	 [512, 512, 3, 3]
encoder.embed.conv.2.bias 	 [512]
encoder.embed.out.0.weight 	 [10240, 512]
encoder.embed.out.0.bias 	 [512]
encoder.after_norm.weight 	 [512]
encoder.after_norm.bias 	 [512]
encoder.encoders.0.self_attn.linear_q.weight 	 [512, 512]
encoder.encoders.0.self_attn.linear_q.bias 	 [512]
encoder.encoders.0.self_attn.linear_k.weight 	 [512, 512]
encoder.encoders.0.self_attn.linear_k.bias 	 [512]
encoder.encoders.0.self_attn.linear_v.weight 	 [512, 512]
encoder.encoders.0.self_attn.linear_v.bias 	 [512]
encoder.encoders.0.self_attn.linear_out.weight 	 [512, 512]
encoder.encoders.0.self_attn.linear_out.bias 	 [512]
encoder.encoders.0.feed_forward.w_1.weight 	 [512, 2048]
encoder.encoders.0.feed_forward.w_1.bias 	 [2048]
encoder.encoders.0.feed_forward.w_2.weight 	 [2048, 512]
encoder.encoders.0.feed_forward.w_2.bias 	 [512]
encoder.encoders.0.norm1.weight 	 [512]
encode

In [134]:
for key, val in model_dict.items():
    print(key, '\t', val.shape)

encoder.embed.conv.0.weight 	 (512, 1, 3, 3)
encoder.embed.conv.0.bias 	 (512,)
encoder.embed.conv.2.weight 	 (512, 512, 3, 3)
encoder.embed.conv.2.bias 	 (512,)
encoder.embed.out.0.weight 	 (10240, 512)
encoder.embed.out.0.bias 	 (512,)
encoder.encoders.0.self_attn.linear_q.weight 	 (512, 512)
encoder.encoders.0.self_attn.linear_q.bias 	 (512,)
encoder.encoders.0.self_attn.linear_k.weight 	 (512, 512)
encoder.encoders.0.self_attn.linear_k.bias 	 (512,)
encoder.encoders.0.self_attn.linear_v.weight 	 (512, 512)
encoder.encoders.0.self_attn.linear_v.bias 	 (512,)
encoder.encoders.0.self_attn.linear_out.weight 	 (512, 512)
encoder.encoders.0.self_attn.linear_out.bias 	 (512,)
encoder.encoders.0.feed_forward.w_1.weight 	 (512, 2048)
encoder.encoders.0.feed_forward.w_1.bias 	 (2048,)
encoder.encoders.0.feed_forward.w_2.weight 	 (2048, 512)
encoder.encoders.0.feed_forward.w_2.bias 	 (512,)
encoder.encoders.0.norm1.weight 	 (512,)
encoder.encoders.0.norm1.bias 	 (512,)
encoder.encoders.0.norm

In [135]:
pmodel.set_state_dict(model_dict)

In [136]:
# Confirm that every parameter now in the Paddle model equals the reference weight.
for key, val in pmodel.state_dict().items():
    # Skip 'concat_linear' entries (presumably absent from the reference
    # weights — verify), since indexing model_dict with them would fail.
    if 'concat_linear' in key:
        continue
    print(key, '\t', np.allclose(val, model_dict[key]))

encoder.embed.conv.0.weight 	 True
encoder.embed.conv.0.bias 	 True
encoder.embed.conv.2.weight 	 True
encoder.embed.conv.2.bias 	 True
encoder.embed.out.0.weight 	 True
encoder.embed.out.0.bias 	 True
encoder.after_norm.weight 	 True
encoder.after_norm.bias 	 True
encoder.encoders.0.self_attn.linear_q.weight 	 True
encoder.encoders.0.self_attn.linear_q.bias 	 True
encoder.encoders.0.self_attn.linear_k.weight 	 True
encoder.encoders.0.self_attn.linear_k.bias 	 True
encoder.encoders.0.self_attn.linear_v.weight 	 True
encoder.encoders.0.self_attn.linear_v.bias 	 True
encoder.encoders.0.self_attn.linear_out.weight 	 True
encoder.encoders.0.self_attn.linear_out.bias 	 True
encoder.encoders.0.feed_forward.w_1.weight 	 True
encoder.encoders.0.feed_forward.w_1.bias 	 True
encoder.encoders.0.feed_forward.w_2.weight 	 True
encoder.encoders.0.feed_forward.w_2.bias 	 True
encoder.encoders.0.norm1.weight 	 True
encoder.encoders.0.norm1.bias 	 True
encoder.encoders.0.norm2.weight 	 True
encoder.enc

encoder.encoders.11.feed_forward.w_2.weight 	 True
encoder.encoders.11.feed_forward.w_2.bias 	 True
encoder.encoders.11.norm1.weight 	 True
encoder.encoders.11.norm1.bias 	 True
encoder.encoders.11.norm2.weight 	 True
encoder.encoders.11.norm2.bias 	 True
decoder.embed.0.weight 	 True
decoder.after_norm.weight 	 True
decoder.after_norm.bias 	 True
decoder.output_layer.weight 	 True
decoder.output_layer.bias 	 True
decoder.decoders.0.self_attn.linear_q.weight 	 True
decoder.decoders.0.self_attn.linear_q.bias 	 True
decoder.decoders.0.self_attn.linear_k.weight 	 True
decoder.decoders.0.self_attn.linear_k.bias 	 True
decoder.decoders.0.self_attn.linear_v.weight 	 True
decoder.decoders.0.self_attn.linear_v.bias 	 True
decoder.decoders.0.self_attn.linear_out.weight 	 True
decoder.decoders.0.self_attn.linear_out.bias 	 True
decoder.decoders.0.src_attn.linear_q.weight 	 True
decoder.decoders.0.src_attn.linear_q.bias 	 True
decoder.decoders.0.src_attn.linear_k.weight 	 True
decoder.decoders.0.

In [137]:
# Run on CPU and switch the model to evaluation mode.
paddle.set_device('cpu')
pmodel.eval()

# `model` is an alias for the same object; the hook functions look layers up
# through this name. The second eval() acts on the same model again.
model = pmodel
model.eval()

In [138]:


def numpy_from_tensor(xx):
    """Recursively convert paddle tensors in a nested structure to numpy.

    Args:
        xx: A ``paddle.Tensor``, ``None``, a list/tuple/dict of such values
            (nested arbitrarily), or any other object.
    Returns:
        Lists and tuples come back as lists, dicts as dicts with the same
        keys, tensors as ``xx.numpy()``; anything else (including ``None``)
        is returned unchanged.
    """
    if isinstance(xx, (list, tuple)):
        return [numpy_from_tensor(x) for x in xx]
    if isinstance(xx, dict):
        # Iterate over items(): iterating the dict itself yields only keys,
        # so unpacking into (key, val) raises or splits the key apart.
        return {key: numpy_from_tensor(val) for key, val in xx.items()}
    if xx is None:
        return None
    if isinstance(xx, paddle.Tensor):
        return xx.numpy()
    return xx


# Captured values keyed by layer name (forward) or by layer/parameter name
# (backward); filled by the hook functions defined in this cell.
forward_dict=OrderedDict()
backward_dict=OrderedDict()
def forward_hook(m, ins, outs):
    """Forward post-hook: record a layer's inputs and outputs by name.

    Finds ``m`` among ``model.named_sublayers()`` by identity and stores the
    numpy-converted ``ins`` and ``outs`` in the global ``forward_dict`` under
    each matching name.
    """
    for layer_name, layer in model.named_sublayers():
        if layer is not m:
            continue
        captured_inputs = numpy_from_tensor(ins)
        captured_outputs = numpy_from_tensor(outs)
        forward_dict[layer_name] = {
            'inputs': captured_inputs,
            'outputs': captured_outputs,
        }
    
def backward_hook(m, grad_ins, grad_outs):
    """Layer backward hook: record a layer's gradients by name.

    Does nothing when ``grad_outs`` or its first element is ``None``.
    Otherwise finds ``m`` among ``model.named_sublayers()`` by identity and
    stores the numpy-converted gradients in the global ``backward_dict``.
    """
    if grad_outs is None or grad_outs[0] is None:
        return
    for layer_name, layer in model.named_sublayers():
        if layer is not m:
            continue
        captured_grad_outs = numpy_from_tensor(grad_outs)
        captured_grad_ins = numpy_from_tensor(grad_ins)
        backward_dict[layer_name] = {
            'grad_outs': captured_grad_outs,
            'grad_ins': captured_grad_ins,
        }

            
# Attach the forward post-hook to every sublayer so a forward pass fills forward_dict.
# NOTE(review): re-running this cell calls register_forward_post_hook again on
# each layer; presumably hooks accumulate — confirm before re-executing.
for n, m in model.named_sublayers():
    if isinstance(m, paddle.nn.Layer):
        try:
            m.register_forward_post_hook(forward_hook)
            # Layer-level backward hooks are disabled; gradients are captured
            # per parameter by the register_hook loop further down.
            #m.register_backward_hook(backward_hook)
        except Exception as e:
            # Report the layer that failed and carry on with the rest.
            print(n, e)

            
def backwardhook(grad):
    """Tensor gradient hook: record a parameter's gradient by parameter name.

    Strips the last five characters from ``grad.name`` (presumably the
    ``@GRAD`` suffix — verify) to get the owning tensor's name, then stores
    the numpy-converted gradient in the global ``backward_dict`` under every
    entry of ``model.named_parameters()`` whose tensor has that name.
    Returns ``None`` in all cases.
    """
    if grad is None:
        return
    owner_tensor_name = grad.name[:-5]
    for param_name, param in model.named_parameters():
        if param.name != owner_tensor_name:
            continue
        backward_dict[param_name] = {
            'grad': numpy_from_tensor(grad),
        }

# Register the gradient hook on every parameter so backward() fills backward_dict.
for n, val in model.named_parameters():
    #print(val.name)
    val.register_hook(backwardhook)

In [139]:
def allclose(a, b, atol=1e-5, rtol=0.0):
    """Return True when ``a`` and ``b`` agree element-wise within tolerance.

    Lists and tuples are compared pairwise and recursively, stopping at the
    shorter sequence as ``zip`` does. Arrays pass when every element satisfies
    ``|a - b| < atol + rtol * |b|`` (strict inequality).

    Args:
        a: Array-like, or a list/tuple of array-likes (possibly nested).
        b: Value with the same structure as ``a``.
        atol (float): Absolute tolerance.
        rtol (float): Relative tolerance, scaled by ``|b|``. The default of
            0.0 gives a purely absolute comparison.
    Returns:
        bool: A plain Python bool, so ``is True`` / ``is False`` checks work.
    """
    if isinstance(a, (list, tuple)):
        # Pass the tolerances down; otherwise nested comparisons would
        # silently fall back to the defaults.
        return all(allclose(i, j, atol, rtol) for i, j in zip(a, b))
    # Skip the relative term when rtol is zero so the default comparison is
    # exactly |a - b| < atol.
    tolerance = atol if rtol == 0.0 else atol + rtol * np.abs(b)
    return bool(np.all(np.abs(a - b) < tolerance))

In [140]:
# Step through the encoder's convolutional front end layer by layer and check
# each intermediate result against the reference forward captures.
model.eval()
pmodel.clear_gradients()

# print(xs)
# Turn the reference batch (numpy) into Paddle tensors.
p_xs_pad = paddle.to_tensor(xs)
p_ilens = paddle.to_tensor(ilens)
p_text = paddle.to_tensor(ys)
p_text_lengths = paddle.to_tensor(olens)

# NOTE(review): `signature` and `masks` are not used anywhere else in this cell.
from inspect import signature

masks = make_non_pad_mask(p_ilens).unsqueeze(1)  # (B, 1, L)
x = p_xs_pad.unsqueeze(1)  # (b, c=1, t, f)
# print(x)
# x = pmodel.encoder.embed.conv(x)
# print(x)

# Reference records for the four layers of encoder.embed.conv
# (Conv2D, ReLU, Conv2D, ReLU according to the printed model).
t0 = torch_forward_dict['encoder.embed.conv.0']
t1 = torch_forward_dict['encoder.embed.conv.1']
t2 = torch_forward_dict['encoder.embed.conv.2']
t3 = torch_forward_dict['encoder.embed.conv.3']
# The reference's first conv input must be the same batch we feed in.
assert np.allclose(t0['inputs'], xs)
l0_out = t0['outputs']
l1_out = t1['outputs']
l2_out = t2['outputs']
l3_out = t3['outputs']

# Run the Paddle layers one at a time, keeping every intermediate output.
convs = list(pmodel.encoder.embed.conv._sub_layers.values())
l0 = x = convs[0](x)
l1 = x = convs[1](x)
l2 = x = convs[2](x)
l3 = x = convs[3](x)

# Layers 0 and 1 are checked with numpy's default tolerances.
assert np.allclose(t0['inputs'], xs)
assert np.allclose(l0_out, l0)

assert np.allclose(t1['inputs'], l0)
assert np.allclose(l1_out, l1)

assert np.allclose(t2['inputs'], l1)
# From the second convolution's output on, the notebook's own allclose
# (absolute tolerance only) is used instead of np.allclose.
assert( allclose(l2_out, l2.numpy()))

# print(np.allclose(convs[2].weight.numpy(), model_dict['encoder.embed.conv.2.weight']))
# print(np.allclose(convs[2].bias.numpy(), model_dict['encoder.embed.conv.2.bias']))

assert allclose(t3['inputs'], l2.numpy())
assert allclose(l3_out, l3.numpy())

# for l in pmodel.encoder.embed.conv._sub_layers.values():
#     if hasattr(l, 'weight'):
#         print(l.weight)

In [179]:
class CTCLoss(paddle.nn.Layer):
    """Wrapper around ``paddle.nn.CTCLoss`` that accepts batch-major logits."""

    def __init__(self, blank=0, reduction='sum', batch_average=False):
        super().__init__()
        # `blank` is passed straight through as the blank label id.
        self.loss = paddle.nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average

    def forward(self, logits, ys_pad, hlens, ys_lens):
        """Compute CTC loss.

        Args:
            logits ([paddle.Tensor]): [B, Tmax, D]
            ys_pad ([paddle.Tensor]): [B, Tmax]
            hlens ([paddle.Tensor]): [B]
            ys_lens ([paddle.Tensor]): [B]

        Returns:
            [paddle.Tensor]: scalar. If reduction is 'none', then (N),
                where N is the batch size.
        """
        # Read the batch size before the layout changes.
        batch_size = paddle.shape(logits)[0]
        # warp-ctc takes raw logits (it applies softmax itself) in
        # time-major layout: (B, L, D) -> (L, B, D).
        time_major_logits = logits.transpose([1, 0, 2])
        # (TODO:Hui Zhang) ctc loss does not support int64 labels
        int32_labels = ys_pad.astype(paddle.int32)
        ctc_loss = self.loss(
            time_major_logits,
            int32_labels,
            hlens,
            ys_lens,
            norm_by_batchsize=self.batch_average)
        if not self.batch_average:
            return ctc_loss
        # Batch-size average
        return ctc_loss / batch_size

In [180]:
# Run the encoder and the CTC head on the reference batch, back-propagate the
# CTC loss, and print values to compare against the reference run.
model.eval()
pmodel.clear_gradients()
# Replace the model's CTC criterion with the notebook's CTCLoss wrapper (default arguments).
pmodel.ctc.criterion = CTCLoss()

p_xs_pad = paddle.to_tensor(xs)
p_ilens = paddle.to_tensor(ilens)
p_text = paddle.to_tensor(ys)
p_text_lengths = paddle.to_tensor(olens)

encoder_out, encoder_mask = pmodel.encoder(p_xs_pad, p_ilens)
# print(encoder_out.shape, encoder_mask.shape) #[1, 145, 512] [1, 1, 145]
print(encoder_out.sum())
# Count the valid frames per utterance by summing the mask over time.
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(1)  #[B, 1, T] -> [B]
# print(encoder_out_lens)

loss_ctc = pmodel.ctc(encoder_out, encoder_out_lens, p_text,
                                p_text_lengths)
print('loss', loss_ctc.numpy())
# backward() triggers the registered parameter gradient hooks; the graph is
# kept so it can be reused afterwards.
loss_ctc.backward(retain_graph=True)
print(loss_ctc.grad.numpy())
print(pmodel.ctc)

# Recompute the CTC projection output on its own to inspect the logits.
lo = pmodel.ctc.ctc_lo(encoder_out)
print(lo.numpy().sum())
print(lo.numpy())
# 122.53613
# loss tensor(6.2267, grad_fn=<SumBackward0>)
# tensor(1.)
# -8020950.5
# [[[ 22.989422   -14.030043     6.0457883  ...   3.0186079   -0.23548293
#    -14.020504  ]
#   [ 23.250462   -13.562406     6.260898   ...   3.5122383    0.09232712
#    -13.554235  ]
#   [ 23.280766   -13.543007     6.225544   ...   3.5360448    0.14748597
#    -13.537972  ]
#   ...
#   [ 13.485205   -23.31453     -0.3501923  ...  -7.0977297   -8.506543
#    -23.212677  ]
#   [ 14.065702   -22.272661     0.30908418 ...  -5.896199    -7.6193686
#    -22.165726  ]
#   [ 13.672498   -22.783966     0.1425507  ...  -6.2963276   -8.160887
#    -22.68066   ]]]

Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
       [122.53607178])
loss [5.980134]
[1.]
CTCDecoder(
  (ctc_lo): Linear(in_features=512, out_features=5002, dtype=float32)
  (criterion): CTCLoss(
    (loss): CTCLoss()
  )
)
-8020951.0
[[[ 22.989422   -14.030044     6.045786   ...   3.0186083   -0.2354827
   -14.020504  ]
  [ 23.250462   -13.562402     6.260898   ...   3.5122392    0.09232831
   -13.554233  ]
  [ 23.280764   -13.543003     6.225544   ...   3.5360477    0.14748693
   -13.537969  ]
  ...
  [ 13.485203   -23.314537    -0.35019422 ...  -7.0977325   -8.506544
   -23.21268   ]
  [ 14.065701   -22.272663     0.30908418 ...  -5.896202    -7.6193695
   -22.16573   ]
  [ 13.672493   -22.78397      0.1425457  ...  -6.2963324   -8.160891
   -22.680668  ]]]


In [174]:
# Load saved CTC inputs from the 'ctc_lo' file and evaluate the criterion on
# them directly, bypassing the encoder.
ctc_dict = paddle.load('ctc_lo')
lot = ctc_dict['lo']
hs_len = ctc_dict['hs_len']
ys_pad = ctc_dict['ys_pad']
ys_len = ctc_dict['ys_len']
# Give the target length an explicit shape of (1,).
ys_len = paddle.to_tensor(ys_len).reshape((1,))
print(hs_len)
print(ys_pad)
print(ys_len)
print(lot.numpy().sum())
# Compare the Paddle logits from the previous cell with the saved ones
# (absolute tolerance 1e-5 by default).
print(allclose(lo.numpy(), lot.numpy()))
loss_a = pmodel.ctc.criterion(lot,  ys_pad, hs_len, ys_len)
print(loss_a)

# loss 5.9801354

Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [145])
Tensor(shape=[1, 22], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[3008, 3630, 245 , 410 , 2598, 4502, 482 , 352 , 399 , 252 , 3204, 4502, 2974, 1165, 627 , 4845, 689 , 2209, 4577, 4859, 2391, 2237]])
Tensor(shape=[1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [22])
-8020950.0
False
Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True,
       [5.98013544])


In [143]:
from pprint import pprint
pprint(list(forward_dict.keys()))

['encoder.embed.conv.0',
 'encoder.embed.conv.1',
 'encoder.embed.conv.2',
 'encoder.embed.conv.3',
 'encoder.embed.conv',
 'encoder.embed.out.0',
 'encoder.embed.out',
 'encoder.embed.pos_enc.dropout',
 'encoder.embed.pos_enc',
 'encoder.embed',
 'encoder.encoders.0.norm1',
 'encoder.encoders.0.self_attn.linear_q',
 'encoder.encoders.0.self_attn.linear_k',
 'encoder.encoders.0.self_attn.linear_v',
 'encoder.encoders.0.self_attn.dropout',
 'encoder.encoders.0.self_attn.linear_out',
 'encoder.encoders.0.self_attn',
 'encoder.encoders.0.dropout',
 'encoder.encoders.0.norm2',
 'encoder.encoders.0.feed_forward.w_1',
 'encoder.encoders.0.feed_forward.activation',
 'encoder.encoders.0.feed_forward.dropout',
 'encoder.encoders.0.feed_forward.w_2',
 'encoder.encoders.0.feed_forward',
 'encoder.encoders.0',
 'encoder.encoders.1.norm1',
 'encoder.encoders.1.self_attn.linear_q',
 'encoder.encoders.1.self_attn.linear_k',
 'encoder.encoders.1.self_attn.linear_v',
 'encoder.encoders.1.self_attn.drop

In [144]:
pprint(list(backward_dict.keys()))

['ctc.ctc_lo.bias',
 'ctc.ctc_lo.weight',
 'encoder.after_norm.bias',
 'encoder.after_norm.weight',
 'encoder.encoders.11.feed_forward.w_2.bias',
 'encoder.encoders.11.feed_forward.w_2.weight',
 'encoder.encoders.11.feed_forward.w_1.bias',
 'encoder.encoders.11.feed_forward.w_1.weight',
 'encoder.encoders.11.norm2.bias',
 'encoder.encoders.11.norm2.weight',
 'encoder.encoders.11.self_attn.linear_out.bias',
 'encoder.encoders.11.self_attn.linear_out.weight',
 'encoder.encoders.11.self_attn.linear_v.bias',
 'encoder.encoders.11.self_attn.linear_v.weight',
 'encoder.encoders.11.self_attn.linear_q.bias',
 'encoder.encoders.11.self_attn.linear_q.weight',
 'encoder.encoders.11.self_attn.linear_k.bias',
 'encoder.encoders.11.self_attn.linear_k.weight',
 'encoder.encoders.11.norm1.bias',
 'encoder.encoders.11.norm1.weight',
 'encoder.encoders.10.feed_forward.w_2.bias',
 'encoder.encoders.10.feed_forward.w_2.weight',
 'encoder.encoders.10.feed_forward.w_1.bias',
 'encoder.encoders.10.feed_forwa

In [200]:
def compare(key, atol=1e-5, flag=False):
    """Print whether a layer's captured inputs/outputs match the reference.

    Looks ``key`` up in the global ``forward_dict`` and ``torch_forward_dict``,
    compares their ``'inputs'`` and ``'outputs'`` with ``allclose`` and prints
    the two results between begin/end marker lines.

    Args:
        key (str): Layer name present in both capture dicts.
        atol (float): Absolute tolerance passed to ``allclose``.
        flag (bool): When true, also print both sides of any mismatch.
    Returns:
        None. Any exception (e.g. a missing key) is printed, not raised.
    """
    print(f'compare --B-- {key}')

    try:
        p = forward_dict[key]
        t = torch_forward_dict[key]

        ins = allclose(p['inputs'], t['inputs'], atol)
        outs = allclose(p['outputs'], t['outputs'], atol)
        print("\t", ins)
        print("\t", outs)
        # Test truthiness rather than `is False`: a numpy.bool_ result is
        # never identical to the False singleton, which hid mismatches.
        if flag and not ins:
            print('ins diff')
            print(p['inputs'])
            print(t['inputs'])
        if flag and not outs:
            print('outs diff')
            print(p['outputs'])
            print(t['outputs'])

    except Exception as e:
        # Best effort: report the problem and still print the end marker.
        print(e)

    print(f'compare --E-- {key}')

In [201]:
compare('encoder.embed.conv.0')
compare('encoder.embed.conv.1')
compare('encoder.embed.conv.2')
compare('encoder.embed.conv.3')

compare --B-- encoder.embed.conv.0
	 True
	 True
compare --E-- encoder.embed.conv.0
compare --B-- encoder.embed.conv.1
	 True
	 True
compare --E-- encoder.embed.conv.1
compare --B-- encoder.embed.conv.2
	 True
	 True
compare --E-- encoder.embed.conv.2
compare --B-- encoder.embed.conv.3
	 True
	 True
compare --E-- encoder.embed.conv.3


In [202]:
compare('encoder.encoders.0.norm1', flag=True)

compare --B-- encoder.encoders.0.norm1
	 False
	 True
ins diff
[array([[[  4.589593 ,  32.161453 , -26.933317 , ...,  -6.813931 ,
           4.8903193, -14.901575 ],
        [ 10.80398  ,  19.90617  , -31.725603 , ...,  -1.7484844,
          -1.6071154,   6.7884765],
        [ 13.564527 ,  13.202793 , -30.384214 , ...,   7.0861425,
          12.307296 ,  -6.329128 ],
        ...,
        [ -9.602592 ,  18.36042  , -19.767456 , ..., -24.487469 ,
           7.9358573, -10.122811 ],
        [-20.52908  ,  17.745    , -29.497602 , ...,   1.8749719,
           3.2002244,  -9.745176 ],
        [ -7.890733 ,  18.230509 , -23.437292 , ...,  12.516883 ,
          27.098549 , -11.436547 ]]], dtype=float32)]
[array([[[  4.589587 ,  32.16146  , -26.93332  , ...,  -6.8139286,
           4.890319 , -14.901578 ],
        [ 10.803984 ,  19.906166 , -31.725603 , ...,  -1.7484815,
          -1.6071137,   6.7884765],
        [ 13.564524 ,  13.202791 , -30.384218 , ...,   7.0861416,
          12.307297 , 

In [203]:
# Compare every captured entry whose name contains 'encoder.encoders.0'
# (the first encoder block and its sublayers), printing both sides on mismatch.
for key in list(forward_dict.keys()):
    if 'encoder.encoders.0' not in key:
        continue
    compare(key, flag=True)

compare --B-- encoder.encoders.0.norm1
	 False
	 True
ins diff
[array([[[  4.589593 ,  32.161453 , -26.933317 , ...,  -6.813931 ,
           4.8903193, -14.901575 ],
        [ 10.80398  ,  19.90617  , -31.725603 , ...,  -1.7484844,
          -1.6071154,   6.7884765],
        [ 13.564527 ,  13.202793 , -30.384214 , ...,   7.0861425,
          12.307296 ,  -6.329128 ],
        ...,
        [ -9.602592 ,  18.36042  , -19.767456 , ..., -24.487469 ,
           7.9358573, -10.122811 ],
        [-20.52908  ,  17.745    , -29.497602 , ...,   1.8749719,
           3.2002244,  -9.745176 ],
        [ -7.890733 ,  18.230509 , -23.437292 , ...,  12.516883 ,
          27.098549 , -11.436547 ]]], dtype=float32)]
[array([[[  4.589587 ,  32.16146  , -26.93332  , ...,  -6.8139286,
           4.890319 , -14.901578 ],
        [ 10.803984 ,  19.906166 , -31.725603 , ...,  -1.7484815,
          -1.6071137,   6.7884765],
        [ 13.564524 ,  13.202791 , -30.384218 , ...,   7.0861416,
          12.307297 , 

In [181]:
pprint(list(backward_dict.items()))

[('ctc.ctc_lo.bias',
  {'grad': array([-7.5980431e-01,  2.4060975e-12,  1.3913538e-02, ...,
        9.5984916e-08,  5.8366727e-08,  2.5355662e-12], dtype=float32)}),
 ('ctc.ctc_lo.weight',
  {'grad': array([[-4.87796992e-01,  6.01302699e-13, -3.17799946e-04, ...,
         2.03148485e-08, -5.06627074e-10,  6.62482439e-13],
       [ 1.26905739e-01,  4.69981476e-13, -6.22264668e-03, ...,
        -1.42200274e-09,  1.27226674e-09,  5.07522084e-13],
       [-1.67604446e-01,  4.74478258e-13, -1.63182733e-03, ...,
         1.06399058e-08, -4.29758851e-09,  5.16066736e-13],
       ...,
       [ 1.19449295e-01, -7.67460585e-13,  1.40323769e-03, ...,
        -4.96660579e-09, -8.68938754e-10, -8.40603574e-13],
       [-5.17768860e-02, -2.64615447e-14,  5.42499730e-03, ...,
        -1.24599198e-09, -6.50550558e-10, -3.34047975e-14],
       [ 2.36392662e-01, -5.24148867e-13, -5.61758736e-03, ...,
        -2.11227871e-08, -9.90126559e-09, -5.71188496e-13]], dtype=float32)}),
 ('encoder.after_norm.bia

  {'grad': array([[ 7.0358469e-04, -7.5150613e-04, -7.5274403e-04, ...,
        -1.9717781e-05, -3.0760004e-04, -5.4739672e-04],
       [-1.3037412e-03,  5.8270805e-04, -3.9569705e-04, ...,
         7.4460125e-04,  4.5698002e-04,  1.0642219e-04],
       [-6.0577458e-03,  2.4920350e-03, -4.5638438e-03, ...,
         2.8206829e-03,  1.3020423e-03, -9.4622560e-04],
       ...,
       [ 7.2496217e-03, -4.3493696e-03,  2.6932175e-03, ...,
        -3.4491455e-03, -2.1193225e-03, -1.2221956e-03],
       [ 3.7118322e-03, -1.9060067e-03,  2.1010083e-03, ...,
        -2.6282023e-03, -1.2150276e-03, -2.2761514e-04],
       [ 2.3310289e-03, -1.6424268e-03,  8.7533740e-04, ...,
         6.3053344e-04, -9.9757116e-04, -1.8605493e-03]], dtype=float32)}),
 ('encoder.encoders.10.self_attn.linear_q.bias',
  {'grad': array([-1.49549532e-03, -8.03963002e-03, -5.14440937e-03,  1.06479286e-03,
       -2.95204204e-03, -5.18526742e-03,  1.83835346e-03, -1.68688584e-03,
       -8.79413169e-03, -5.72132599e-03,

      dtype=float32)}),
 ('encoder.encoders.8.self_attn.linear_out.bias',
  {'grad': array([ 1.4382483e-04,  5.2213552e-04, -3.2233080e-04, -2.0222631e-03,
        2.8613309e-04, -3.1220894e-03,  1.2110268e-03,  2.0420107e-03,
       -5.7587388e-04,  2.7400427e-04,  1.2515046e-04, -7.9931348e-04,
       -4.4606026e-04, -1.7287227e-03,  9.3072653e-04,  1.5345557e-03,
        4.0796609e-03,  7.5611338e-04, -7.3755370e-04, -5.3040305e-04,
        1.6955297e-03,  2.6596484e-03,  1.3498989e-03,  3.0886324e-04,
       -1.9389295e-04,  3.0574424e-03, -2.4690310e-04,  7.0793205e-04,
       -2.6117754e-03, -1.6199027e-03, -2.2693716e-03, -2.1793155e-04,
        1.7342917e-03, -3.3008223e-03,  1.1302719e-03, -1.1930097e-03,
       -2.1051384e-04,  9.0533809e-04, -4.8357169e-03, -2.5139793e-03,
       -1.1547963e-04, -2.2786697e-03, -2.7160475e-04,  2.1218802e-04,
       -1.4141812e-03,  1.2314267e-03,  1.1988929e-03, -5.7567225e-04,
       -1.0373858e-03,  1.7886624e-03,  1.7689895e-03,  1.89204

  {'grad': array([-2.64866203e-01,  4.17262465e-01,  4.20839310e-01,  5.41534647e-02,
       -3.80996794e-01,  7.94707164e-02,  1.07565331e+00, -3.04445714e-01,
       -2.83773214e-01,  2.00691834e-01,  3.31396192e-01, -4.64294069e-02,
        7.97953159e-02, -3.96058291e-01,  1.91019937e-01, -1.57462075e-01,
       -6.14420176e-01, -3.06210741e-02,  2.68711038e-02,  1.30840400e-02,
       -1.14754476e-01, -1.70583293e-01,  8.32576975e-02, -6.81443140e-02,
       -3.11598986e-01, -7.20727146e-01,  1.24914035e-01,  4.50322121e-01,
        3.66669685e-01,  4.17923257e-02,  1.37750626e-01, -1.92657001e-02,
       -6.78844154e-02,  3.63209128e-01,  7.95010850e-02,  7.78455734e-02,
        1.10551186e-01, -2.08903566e-01,  2.91645437e-01, -1.49570704e-01,
        1.82645634e-01, -1.50872961e-01, -1.38239861e-01, -6.10721171e-01,
        5.67548573e-01,  2.41643414e-01, -6.94878817e-01, -1.00836024e-01,
       -1.37836874e-01,  3.61277968e-01, -9.29311216e-02,  9.12603140e-02,
        2.2054

 ('encoder.encoders.5.norm1.weight',
  {'grad': array([ 7.27066323e-02, -7.26326182e-02,  1.64067015e-01, -3.83047871e-02,
       -3.82522903e-02, -1.42215220e-02,  1.86826482e-01, -2.33801901e-02,
       -9.25074890e-03, -7.73940841e-03,  2.67077722e-02, -2.86743720e-03,
        3.71948145e-02,  2.85786781e-02,  3.48364487e-02,  5.13052307e-02,
        1.01086035e-01,  1.40815675e+00,  2.08091624e-02,  4.56687137e-02,
        1.22292779e-01, -5.31555153e-02,  8.56871437e-03, -4.90930714e-02,
        1.47386966e-02,  1.05465772e-02,  7.73067772e-02,  6.12093024e-02,
        3.81300750e-04,  5.58782340e-05,  2.59582531e-02, -1.53923845e-02,
        2.61087599e-03,  1.77464914e-02,  1.08807790e+00,  1.89964324e-02,
       -2.47139502e-02,  4.16952036e-02, -3.64689417e-02, -4.23430465e-02,
       -1.83628295e-02, -1.32073477e-01, -5.74012585e-02,  1.84487235e-02,
       -3.59035991e-02, -6.24219179e-02,  1.46104023e-02, -5.21366950e-03,
        3.28290090e-02, -3.83214317e-02, -2.27456335

  {'grad': array([ 4.25033085e-02, -2.14494318e-02, -4.64881165e-03,  8.59497041e-02,
       -1.47966137e-02, -3.09458952e-02,  3.82897824e-01, -1.49034858e-02,
       -3.50584351e-02, -4.42300625e-02,  3.77171487e-02, -8.11239406e-02,
        7.60458130e-03,  6.57098219e-02, -2.42629144e-02,  3.20944525e-02,
        4.52051759e-02,  9.64045584e-01, -9.84071046e-02,  2.46175118e-02,
       -9.88012925e-03, -4.07016017e-02,  3.68434452e-02,  4.08901051e-02,
       -1.39611736e-02,  2.98673734e-02,  1.45202577e-02, -2.49042790e-02,
       -5.46606295e-02,  1.01912394e-02,  1.63677074e-02,  4.44443114e-02,
       -7.62273790e-03, -1.40018947e-02,  7.02040672e-01,  2.80412496e-03,
       -8.42561275e-02,  4.55998741e-02,  3.44946831e-02,  5.20158149e-02,
        2.20656209e-02, -8.16464275e-02, -2.11080033e-02,  5.84599823e-02,
       -3.46273388e-04, -3.77144925e-02, -2.88715716e-02,  3.78679447e-02,
        2.00957339e-02, -6.21237850e-04,  1.95414238e-02,  3.32795829e-02,
        1.3552

  {'grad': array([[-1.48544117e-04, -1.70481435e-04,  4.60803625e-04, ...,
        -6.54856558e-04,  3.10310890e-04,  8.78230858e-05],
       [ 1.20932105e-04,  4.29786451e-05,  2.58068118e-04, ...,
        -6.54926407e-04, -4.56220805e-05, -9.02897955e-05],
       [ 1.40504271e-04,  1.83801429e-04, -6.06206304e-04, ...,
        -3.94000090e-04,  6.81056757e-04, -2.26811724e-04],
       ...,
       [ 1.67459439e-05,  1.93060361e-04, -7.70102430e-04, ...,
        -5.87491435e-04,  1.62875702e-04, -5.35585685e-04],
       [ 1.11376385e-05, -2.76254355e-06, -2.83918896e-04, ...,
        -4.85459714e-05,  5.25005395e-04, -2.00063922e-04],
       [ 1.01942605e-04, -3.72820243e-04,  2.53951206e-04, ...,
         1.42609730e-04, -5.83348155e-04,  3.70269525e-04]], dtype=float32)}),
 ('encoder.encoders.1.norm1.bias',
  {'grad': array([ 2.54925877e-01,  5.87818846e-02,  5.53244986e-02,  2.67834663e-01,
       -2.43211254e-01,  4.43823159e-01, -1.35102645e-01,  6.61320984e-02,
       -1.89733177

         [  0.      ,   0.      ,   0.      ]]]], dtype=float32)})]


In [186]:
# Show the gradients stored in backward_dict for the CTC output layer:
# the bias entry first, then the weight entry.
for param_name in ('ctc.ctc_lo.bias', 'ctc.ctc_lo.weight'):
    print(backward_dict[param_name]['grad'])
# ctc.ctc_lo.bias [-7.5972462e-01  2.4060894e-12  1.3913637e-02 ...  9.5985001e-08
#   5.8366982e-08  2.5355482e-12]
# ctc.ctc_lo.weight [[-4.8778719e-01  6.0130427e-13 -3.1780367e-04 ...  2.0314937e-08
#   -5.0664734e-10  6.6247995e-13]
#  [ 1.2689343e-01  4.6998050e-13 -6.2227007e-03 ... -1.4220235e-09
#    1.2722685e-09  5.0751840e-13]
#  [-1.6759640e-01  4.7447750e-13 -1.6318382e-03 ...  1.0639946e-08
#   -4.2976129e-09  5.1606310e-13]
#  ...
#  [ 1.1945984e-01 -7.6746086e-13  1.4032462e-03 ... -4.9666298e-09
#   -8.6895852e-10 -8.4059870e-13]
#  [-5.1776640e-02 -2.6462575e-14  5.4250457e-03 ... -1.2460214e-09
#   -6.5056005e-10 -3.3405353e-14]
#  [ 2.3638242e-01 -5.2414919e-13 -5.6176288e-03 ... -2.1122869e-08
#   -9.9013251e-09 -5.7118551e-13]]

[-7.5980431e-01  2.4060975e-12  1.3913538e-02 ...  9.5984916e-08
  5.8366727e-08  2.5355662e-12]
[[-4.87796992e-01  6.01302699e-13 -3.17799946e-04 ...  2.03148485e-08
  -5.06627074e-10  6.62482439e-13]
 [ 1.26905739e-01  4.69981476e-13 -6.22264668e-03 ... -1.42200274e-09
   1.27226674e-09  5.07522084e-13]
 [-1.67604446e-01  4.74478258e-13 -1.63182733e-03 ...  1.06399058e-08
  -4.29758851e-09  5.16066736e-13]
 ...
 [ 1.19449295e-01 -7.67460585e-13  1.40323769e-03 ... -4.96660579e-09
  -8.68938754e-10 -8.40603574e-13]
 [-5.17768860e-02 -2.64615447e-14  5.42499730e-03 ... -1.24599198e-09
  -6.50550558e-10 -3.34047975e-14]
 [ 2.36392662e-01 -5.24148867e-13 -5.61758736e-03 ... -2.11227871e-08
  -9.90126559e-09 -5.71188496e-13]]


In [207]:
# Read the gradients saved in the local file 'espnet.grad'. The cell below indexes
# the result by parameter name (e.g. 'ctc.ctc_lo.weight') and compares it against
# backward_dict.
# NOTE(review): presumably these are the ESPnet reference gradients, returned as
# numpy arrays because of return_numpy=True -- confirm against the paddle.load docs.
# NOTE(review): paddle.load appears to unpickle the file; only load trusted files -- confirm.
espnet_grad_dict = paddle.load('espnet.grad', return_numpy=True)

In [219]:
def allclose(a, b, atol=1e-5, rtol=0.0):
    """Count the elements of ``a`` and ``b`` that differ by more than the tolerance.

    Despite the name, this does not return a boolean: it returns the number of
    mismatching elements, so ``0`` means the two inputs agree everywhere.
    The boolean mismatch mask of every compared pair is printed as a side effect.

    Args:
        a: Array-like value, or a list/tuple of such values.
        b: Counterpart of ``a``. When ``a`` is a list/tuple the two are walked
            pairwise with ``zip``, so surplus items in the longer one are ignored.
        atol: Absolute tolerance.
        rtol: Relative tolerance, scaled by ``|b|``. With the default ``0.0`` the
            check is purely absolute.

    Returns:
        The number of elements for which ``|a - b| < atol + rtol * |b|`` does
        not hold (NaN differences therefore count as mismatches). For a
        list/tuple input, the sum of the counts of every pair.
    """
    if isinstance(a, (list, tuple)):
        # Forward the caller's tolerances; the recursion must not fall back
        # to the defaults.
        return sum([allclose(i, j, atol, rtol) for i, j in zip(a, b)])
    # Skip the relative term entirely when rtol is zero so the default path
    # stays a plain absolute comparison.
    tol = atol + rtol * np.abs(b) if rtol else atol
    # Strict "<" is kept on purpose: a difference exactly equal to the
    # tolerance is reported as a mismatch.
    mismatch = ~(np.abs(a - b) < tol)
    print(mismatch)
    return np.sum(mismatch)

# For the CTC output-layer weight and then its bias, print the value returned by
# allclose when the loaded espnet_grad_dict entry is compared with the matching
# backward_dict gradient at a tolerance of 1e-6.
for param_name in ('ctc.ctc_lo.weight', 'ctc.ctc_lo.bias'):
    print(allclose(espnet_grad_dict[param_name], backward_dict[param_name]['grad'], 1e-6))

[[ True False False ... False False False]
 [ True False False ... False False False]
 [ True False False ... False False False]
 ...
 [ True False False ... False False False]
 [False False False ... False False False]
 [ True False False ... False False False]]
1010
[ True False False ... False False False]
12
