# TensorFlow: reference `ApplyTimeChannel` layer (Sionna)

In [92]:
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for applying channel responses to channel inputs in the time domain"""

import tensorflow as tf

import numpy as np

import scipy

from sionna.utils import insert_dims
from sionna.channel.awgn import AWGN

from sionna.constants import GLOBAL_SEED_NUMBER

class ApplyTimeChannel(tf.keras.layers.Layer):
    # pylint: disable=line-too-long
    r"""ApplyTimeChannel(num_time_samples, l_tot, add_awgn=True, dtype=tf.complex64, **kwargs)

    Keras layer that filters channel inputs ``x`` with time-variant,
    time-domain channel responses ``h_time`` and optionally adds white
    Gaussian noise.

    Since this class derives from the Keras `Layer` class, it can be used
    as a layer in a Keras model.

    Filtering a channel input of ``num_time_samples`` samples with a channel
    filter of ``l_tot`` taps yields ``num_time_samples`` + ``l_tot`` - 1
    output samples. A separate set of taps is used for every output sample,
    which is why ``num_time_samples`` + ``l_tot`` - 1 time steps of a channel
    realization are needed for each batch example.

    For a single-input single-output link with channel inputs
    :math:`x_0,\cdots,x_{N_B-1}`, where :math:`N_B` is ``num_time_samples``,
    this layer computes

    .. math::
        y_b = \sum_{\ell = 0}^{L_{\text{tot}}-1} x_{b-\ell} \bar{h}_{b,\ell} + w_b

    for :math:`b = 0,\cdots,N_B + L_{\text{tot}} - 2`, where
    :math:`L_{\text{tot}}` corresponds to ``l_tot``, :math:`w_b` is the
    additive noise, and :math:`\bar{h}_{b,\ell}` is the :math:`\ell^{th}` tap
    of the :math:`b^{th}` channel sample. The input :math:`x_{b}` is taken to
    be 0 for :math:`b < 0` and for :math:`b \geq N_B`.

    For multiple-input multiple-output (MIMO) links, the output of every
    antenna of every receiver is obtained by summing the filtered signals of
    all antennas of all transmitters.

    Parameters
    ----------

    num_time_samples : int
        Number of time samples forming the channel input (:math:`N_B`)

    l_tot : int
        Length of the channel filter (:math:`L_{\text{tot}} = L_{\text{max}} - L_{\text{min}} + 1`)

    add_awgn : bool
        Whether white Gaussian noise is added to the channel output.
        If `False`, no noise is added. Defaults to `True`.

    dtype : tf.DType
        Complex datatype used for internal processing and for the output.
        Defaults to `tf.complex64`.

    Input
    -----

    (x, h_time, no) or (x, h_time):
        Tuple:

    x :  [batch size, num_tx, num_tx_ant, num_time_samples], tf.complex
        Channel inputs

    h_time : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_tot - 1, l_tot], tf.complex
        Channel responses, one set of ``l_tot`` taps for each of the
        ``num_time_samples`` + ``l_tot`` - 1 output time steps.

    no : Scalar or Tensor, tf.float
        Noise power per complex dimension. Only required (and only unpacked
        from the input tuple) if ``add_awgn`` is `True`.
        A scalar adds noise of the same variance to all outputs. A tensor
        must have a shape that can be broadcast to the shape of the channel
        outputs, [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1],
        which allows, e.g., a different noise variance for each batch
        example. If ``no`` has a lower rank than the channel outputs, it is
        broadcast to their shape by adding dummy dimensions after the last
        axis.

    Output
    -------
    y : [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1], tf.complex
        Channel outputs, i.e., the channel input of length
        ``num_time_samples`` filtered by the time-variant channel filter of
        length ``l_tot``, plus noise if ``add_awgn`` is `True`.
    """

    def __init__(self, num_time_samples, l_tot, add_awgn=True,
                 dtype=tf.complex64, **kwargs):

        super().__init__(trainable=False, dtype=dtype, **kwargs)

        self._add_awgn = add_awgn

        # The channel is applied by gathering, for every output time step,
        # the `l_tot` transmitted symbols that contribute to it, and then
        # weighting them with the channel taps. The gathering uses an index
        # matrix G of size `num_time_samples + l_tot - 1` x `l_tot`, where
        # G[b, l] is the position of x_{b-l} in the (zero-padded) input
        # x = [x_0,...,x_{num_time_samples-1}, 0]^T.
        #
        # The index `num_time_samples` points at the single zero appended to
        # the input and is used wherever b-l falls outside the input.
        #
        # Example with 4 taps and `num_time_samples` = 10:
        #       [[0, 10, 10, 10]
        #        [1,  0, 10, 10]
        #        [2,  1,  0, 10]
        #        [3,  2,  1,  0]
        #        [4,  3,  2,  1]
        #        [5,  4,  3,  2]
        #        [6,  5,  4,  3]
        #        [7,  6,  5,  4]
        #        [8,  7,  6,  5]
        #        [9,  8,  7,  6]
        #        [10, 9,  8,  7]
        #        [10,10,  9,  8]
        #        [10,10, 10,  9]]
        # G is a Toeplitz matrix, so it is fully described by its first
        # column and its first row.
        zero_symbol_index = num_time_samples
        out_of_range = np.full([l_tot-1], zero_symbol_index)
        first_column = np.concatenate([np.arange(0, num_time_samples),
                                       out_of_range])
        first_row = np.concatenate([[0], out_of_range])
        self._g = scipy.linalg.toeplitz(first_column, first_row)

    def build(self, input_shape): #pylint: disable=unused-argument

        # The noise layer is only instantiated when noise is requested
        if not self._add_awgn:
            return
        self._awgn = AWGN(dtype=self.dtype)

    def call(self, inputs):

        # The noise power is only part of the inputs when noise is added
        no = None
        if self._add_awgn:
            x, h_time, no = inputs
        else:
            x, h_time = inputs

        # Append the zero symbol addressed by the out-of-range entries of G
        x_padded = tf.pad(x, [[0,0], [0,0], [0,0], [0,1]])
        # Insert the receiver dimensions so that x broadcasts against h_time
        x_padded = insert_dims(x_padded, 2, axis=1)

        # For every output time step, pick the symbols seen by each tap
        x_taps = tf.gather(x_padded, self._g, axis=-1)

        # Weight by the taps and sum over the taps...
        y = tf.reduce_sum(h_time*x_taps, axis=-1)
        # ...then over the transmit antennas and over the transmitters
        y = tf.reduce_sum(y, axis=4)
        y = tf.reduce_sum(y, axis=3)

        if not self._add_awgn:
            return y

        return self._awgn((y, no))
# Test
def test_apply_time_channel():
    """Smoke-test the TensorFlow ``ApplyTimeChannel`` layer on random data.

    Seeds TensorFlow's global RNG, draws random complex channel inputs and
    channel responses, runs the layer with AWGN enabled, checks that the
    output has the shape documented by the layer, and prints the inputs and
    the output.

    Raises
    ------
    AssertionError
        If the layer output does not have the shape
        [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1].
    """
    tf.random.set_seed(GLOBAL_SEED_NUMBER)
    # Shapes of the test data
    batch_size = 2
    num_time_samples = 10
    l_tot = 4
    num_tx = 3
    num_tx_ant = 2
    num_rx = 2
    num_rx_ant = 1

    # Generate the channel input x and the channel response h_time.
    # The draw order (real part first, then imaginary part) is kept fixed so
    # that the seeded values are reproducible.
    x_real = tf.random.normal(shape=(batch_size, num_tx, num_tx_ant, num_time_samples), dtype=tf.float32)
    x_imag = tf.random.normal(shape=(batch_size, num_tx, num_tx_ant, num_time_samples), dtype=tf.float32)
    x = tf.complex(x_real, x_imag)

    h_time_real = tf.random.normal(shape=(batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_tot - 1, l_tot), dtype=tf.float32)
    h_time_imag = tf.random.normal(shape=(batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_tot - 1, l_tot), dtype=tf.float32)
    h_time = tf.complex(h_time_real, h_time_imag)

    # Noise power no
    no = tf.constant(0.1, dtype=tf.float32)

    # Create the ApplyTimeChannel instance
    apply_channel = ApplyTimeChannel(num_time_samples=num_time_samples, l_tot=l_tot, add_awgn=True, dtype=tf.complex64)

    # Forward pass
    y = apply_channel((x, h_time, no))

    # A test should fail loudly when the contract is broken: check the
    # output shape documented by the layer.
    expected_shape = (batch_size, num_rx, num_rx_ant, num_time_samples + l_tot - 1)
    assert tuple(y.shape) == expected_shape, \
        f"unexpected output shape {tuple(y.shape)}, expected {expected_shape}"

    # Print the inputs and the output
    print("输入数据 x:")
    print(x)
    print("\n频率响应 h_time:")
    print(h_time)
    print("\n噪声功率 no:")
    print(no)
    print("\n输出数据 y:")
    print(y)
    

# Run the test function
test_apply_time_channel()  



输入数据 x:
tf.Tensor(
[[[[ 0.16052227+0.64973587j -1.6597689 +0.32791495j
    -1.2321332 -0.75198144j  0.5971658 -0.21430095j
     1.0609884 +0.52599317j -1.3277572 +1.1992904j
    -0.27911443-1.292074j   -0.02141875+0.1150163j
    -1.502249  -0.42695856j  0.3066489 +0.6561277j ]
   [ 0.5355358 +0.06107097j -1.3167298 +0.78159803j
     0.7335615 +1.5706491j  -1.1566194 +1.1413449j
     1.6611129 -0.5492473j   0.974669  +0.64840907j
    -0.10358131+1.053436j    2.0585718 +1.8392539j
    -2.1600595 +0.7877717j  -0.71011245+2.251656j  ]]

  [[-1.4161334 -1.1881466j  -0.61227006-0.33408013j
    -0.25455204+0.95721877j -0.18277623-2.432254j
    -0.93512297+0.70165807j -0.74408674+1.1474625j
     0.8501864 +0.2107991j  -0.8065856 +0.29559776j
    -0.9804924 -1.3067406j  -0.6469936 +0.626799j  ]
   [-1.0293618 -1.6479712j  -0.81056195+0.49235418j
    -0.6525119 +0.2504584j   0.5092518 -0.94027156j
     1.5420486 -1.068339j    1.4433792 -1.9018799j
    -0.7556428 -0.14345959j  0.16503793-1.095923

# PyTorch: port of the `ApplyTimeChannel` layer

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.linalg
from my_code.mysionna.channel.torch_version.awgn import AWGN
from my_code.mysionna.utils import insert_dims

import tensorflow as tf

from sionna.constants import GLOBAL_SEED_NUMBER

def gather_pytorch(input_data, axis=0, indices=None):
    """Gather slices of ``input_data`` along ``axis``, like ``tf.gather``.

    The dimension ``axis`` of ``input_data`` is replaced by the dimensions of
    ``indices``: the result has shape
    ``input_data.shape[:axis] + indices.shape + input_data.shape[axis + 1:]``.

    Parameters
    ----------
    input_data : torch.Tensor
        Tensor to gather from.
    axis : int
        Dimension to gather along. Negative values count from the last
        dimension. Defaults to 0.
    indices : torch.Tensor or array-like of int
        Positions to pick along ``axis``, of any shape (including a scalar).
        Required despite the ``None`` default, which is kept only so that the
        existing positional order ``(input_data, axis, indices)`` still works.

    Returns
    -------
    torch.Tensor
        The gathered tensor, with the same dtype and device as ``input_data``.

    Raises
    ------
    ValueError
        If ``indices`` is not provided.
    """
    if indices is None:
        raise ValueError("gather_pytorch() requires `indices`")

    # index_select needs an integer index tensor on the same device as the
    # input; this also lets callers pass lists or NumPy arrays.
    indices = torch.as_tensor(indices, dtype=torch.long,
                              device=input_data.device)

    if axis < 0:
        axis = input_data.dim() + axis

    # index_select only accepts 1-D indices: select with the flattened
    # indices, then unfold the selected dimension to the shape of `indices`.
    data = torch.index_select(input_data, axis, indices.reshape(-1))

    shape_input = list(input_data.shape)
    shape_output = shape_input[:axis] + \
        list(indices.shape) + shape_input[axis + 1:]

    return data.reshape(shape_output)


class ApplyTimeChannel(nn.Module):
    """PyTorch module applying time-domain channel responses to channel inputs.

    Every output time step is the sum, over the ``l_tot`` channel taps, of
    the tap coefficient times the correspondingly delayed input symbol,
    summed over dimensions 4 and 3 of the result, with optional AWGN.

    Parameters
    ----------
    num_time_samples : int
        Number of time samples forming the channel input.
    l_tot : int
        Number of taps of the channel filter.
    add_awgn : bool
        If `True` (default), the input tuple must contain the noise power
        ``no`` and the output is passed through the ``AWGN`` module.
    dtype : torch.dtype
        Stored on the module and forwarded to ``AWGN``. Defaults to
        `torch.complex64`.

    Notes
    -----
    Expected tensor layouts are not checked here. The operations are
    consistent with ``x`` of shape
    [batch size, num_tx, num_tx_ant, num_time_samples] and ``h_time`` of shape
    [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant,
    num_time_samples + l_tot - 1, l_tot], giving an output of shape
    [batch size, num_rx, num_rx_ant, num_time_samples + l_tot - 1]
    -- TODO confirm against ``insert_dims`` and ``AWGN``.
    """

    def __init__(self, num_time_samples, l_tot, add_awgn=True, dtype=torch.complex64):
        super().__init__()
        self.add_awgn = add_awgn
        self.dtype = dtype

        # Toeplitz index matrix of size (num_time_samples + l_tot - 1, l_tot):
        # entry [b, l] is the position of the input symbol multiplied by tap
        # l at output step b. The index `num_time_samples` addresses the zero
        # appended to the input in `forward`.
        first_column = np.concatenate([np.arange(0, num_time_samples), np.full([l_tot - 1], num_time_samples)])
        first_row = np.concatenate([[0], np.full([l_tot - 1], num_time_samples)])
        gather_indices = torch.tensor(scipy.linalg.toeplitz(first_column, first_row), dtype=torch.long)

        # Registered as a buffer so that `.to(device)` / `.cuda()` move it
        # together with the module; non-persistent so that `state_dict()`
        # is unchanged.
        self.register_buffer("_g", gather_indices, persistent=False)

        if self.add_awgn:
            self.awgn = AWGN(dtype=dtype)

    def forward(self, inputs):
        """Filter ``x`` with ``h_time`` and optionally add noise.

        Parameters
        ----------
        inputs : tuple
            ``(x, h_time, no)`` if ``add_awgn`` is `True`, else
            ``(x, h_time)``.

        Returns
        -------
        torch.Tensor
            The channel output.
        """
        if self.add_awgn:
            x, h_time, no = inputs
        else:
            x, h_time = inputs

        # Append one zero along the last dimension; it is the symbol that the
        # out-of-range entries of the index matrix point at.
        x = F.pad(x, (0, 1))
        # Insert two dimensions at axis 1 (presumably singleton receiver
        # dimensions, so that x broadcasts against h_time -- TODO confirm
        # against `insert_dims`).
        x = insert_dims(x, 2, axis=1)

        # For every output time step, pick the symbols seen by each tap.
        # The indices are aligned with the input's device so that the layer
        # also works when only the inputs were moved to another device.
        x = gather_pytorch(x, -1, self._g.to(x.device))

        # Weight by the taps and sum over the taps, then over dims 4 and 3
        y = torch.sum(h_time * x, dim=-1)
        y = torch.sum(torch.sum(y, dim=4), dim=3)

        # Add AWGN if requested
        if self.add_awgn:
            y = self.awgn((y, no))

        return y


## Test
def generate_torch_data(batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples, l_tot):
    """Draw random complex test data with TensorFlow and return it as PyTorch tensors.

    The values are sampled with ``tf.random.normal``, converted to NumPy
    arrays and finally to PyTorch tensors.

    Returns
    -------
    x_torch : torch.Tensor, complex64
        Shape (batch_size, num_tx, num_tx_ant, num_time_samples).
    h_time_torch : torch.Tensor, complex64
        Shape (batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant,
        num_time_samples + l_tot - 1, l_tot).
    """
    def _complex_normal(shape):
        # Real part is drawn before the imaginary part; the order of the
        # random draws is kept fixed.
        real_np = tf.random.normal(shape=shape, dtype=tf.float32).numpy()
        imag_np = tf.random.normal(shape=shape, dtype=tf.float32).numpy()
        return torch.tensor(np.complex64(real_np + 1j * imag_np))

    x_shape = (batch_size, num_tx, num_tx_ant, num_time_samples)
    h_time_shape = (batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant,
                    num_time_samples + l_tot - 1, l_tot)

    # x must be drawn before h_time to consume the random stream in the
    # same order as before.
    x_torch = _complex_normal(x_shape)
    h_time_torch = _complex_normal(h_time_shape)

    return x_torch, h_time_torch

def test_apply_time_channel():
    """Smoke-test the PyTorch ``ApplyTimeChannel`` module on random data.

    Seeds the TensorFlow RNG (used to draw the test data) and the PyTorch
    RNG, runs the module with AWGN enabled, checks the output shape, and
    prints the inputs and the output.

    Raises
    ------
    AssertionError
        If the output does not have the shape
        (batch_size, num_rx, num_rx_ant, num_time_samples + l_tot - 1).
    """
    # Seed the RNGs. The test data is drawn with TensorFlow; PyTorch is
    # seeded as well so that any randomness on the torch side (presumably the
    # AWGN noise -- TODO confirm) is reproducible too.
    tf.random.set_seed(GLOBAL_SEED_NUMBER)
    torch.manual_seed(GLOBAL_SEED_NUMBER)

    # Shapes of the test data
    batch_size = 2
    num_time_samples = 10
    l_tot = 4
    num_tx = 3
    num_tx_ant = 2
    num_rx = 2
    num_rx_ant = 1

    # Generate the channel input x and the channel response h_time
    x, h_time = generate_torch_data(batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples, l_tot)

    # Noise power no
    no = torch.tensor(0.1, dtype=torch.float32)

    # Create the ApplyTimeChannel instance
    apply_channel = ApplyTimeChannel(num_time_samples=num_time_samples, l_tot=l_tot, add_awgn=True, dtype=torch.complex64)

    # Forward pass
    y = apply_channel((x, h_time, no))

    # A test should fail loudly when the contract is broken: one output
    # sample per receive antenna and per output time step.
    expected_shape = (batch_size, num_rx, num_rx_ant, num_time_samples + l_tot - 1)
    assert tuple(y.shape) == expected_shape, \
        f"unexpected output shape {tuple(y.shape)}, expected {expected_shape}"

    # Print the inputs and the output
    print("输入数据 x:")
    print(x)
    print("\n频率响应 h_time:")
    print(h_time)
    print("\n噪声功率 no:")
    print(no)
    print("\n输出数据 y:")
    print(y)

# Run the test function
test_apply_time_channel()

输入数据 x:
tensor([[[[ 0.1605+0.6497j, -1.6598+0.3279j, -1.2321-0.7520j,  0.5972-0.2143j,
            1.0610+0.5260j, -1.3278+1.1993j, -0.2791-1.2921j, -0.0214+0.1150j,
           -1.5022-0.4270j,  0.3066+0.6561j],
          [ 0.5355+0.0611j, -1.3167+0.7816j,  0.7336+1.5706j, -1.1566+1.1413j,
            1.6611-0.5492j,  0.9747+0.6484j, -0.1036+1.0534j,  2.0586+1.8393j,
           -2.1601+0.7878j, -0.7101+2.2517j]],

         [[-1.4161-1.1881j, -0.6123-0.3341j, -0.2546+0.9572j, -0.1828-2.4323j,
           -0.9351+0.7017j, -0.7441+1.1475j,  0.8502+0.2108j, -0.8066+0.2956j,
           -0.9805-1.3067j, -0.6470+0.6268j],
          [-1.0294-1.6480j, -0.8106+0.4924j, -0.6525+0.2505j,  0.5093-0.9403j,
            1.5420-1.0683j,  1.4434-1.9019j, -0.7556-0.1435j,  0.1650-1.0959j,
            1.7516+1.2201j,  0.7456-1.7196j]],

         [[-1.0260-0.8269j, -0.6620+2.4498j, -0.1749+1.0104j, -0.1698-1.7347j,
           -0.2755+0.5248j,  1.3250-0.4136j, -1.6598+0.5964j, -1.1208+0.2654j,
            1.

# gather_pytorch 验证

In [21]:
import numpy as np
import tensorflow as tf
import torch


def gather_pytorch(input_data, axis=0, indices=None):
    """Gather slices of ``input_data`` along ``axis``, like ``tf.gather``.

    The dimension ``axis`` of ``input_data`` is replaced by the dimensions of
    ``indices``: the result has shape
    ``input_data.shape[:axis] + indices.shape + input_data.shape[axis + 1:]``.

    Parameters
    ----------
    input_data : torch.Tensor
        Tensor to gather from.
    axis : int
        Dimension to gather along. Negative values count from the last
        dimension. Defaults to 0.
    indices : torch.Tensor or array-like of int
        Positions to pick along ``axis``, of any shape (including a scalar).
        Required despite the ``None`` default, which is kept only so that the
        existing positional order ``(input_data, axis, indices)`` still works.

    Returns
    -------
    torch.Tensor
        The gathered tensor, with the same dtype and device as ``input_data``.

    Raises
    ------
    ValueError
        If ``indices`` is not provided.
    """
    if indices is None:
        raise ValueError("gather_pytorch() requires `indices`")

    # index_select needs an integer index tensor on the same device as the
    # input; this also lets callers pass lists or NumPy arrays.
    indices = torch.as_tensor(indices, dtype=torch.long,
                              device=input_data.device)

    if axis < 0:
        axis = input_data.dim() + axis

    # index_select only accepts 1-D indices: select with the flattened
    # indices, then unfold the selected dimension to the shape of `indices`.
    data = torch.index_select(input_data, axis, indices.reshape(-1))

    shape_input = list(input_data.shape)
    shape_output = shape_input[:axis] + \
        list(indices.shape) + shape_input[axis + 1:]

    return data.reshape(shape_output)


# Compare gather_pytorch against tf.gather on random data, gathering along
# the last axis with a 2-D index array.
data = torch.randn(2, 3, 4,13,5)
indices = torch.tensor([[0, 0, 2], [1, 1, 2]])

data_torch = gather_pytorch(data, -1, indices).numpy()

# Hand NumPy arrays to TensorFlow explicitly instead of relying on its
# implicit conversion of torch tensors.
data_tf = tf.gather(data.numpy(), indices.numpy(), axis=-1).numpy()

print(data_torch.shape)
print(data_tf.shape)

# Largest absolute deviation: a signed sum of differences could cancel to 0
# even when the arrays differ.
print(np.max(np.abs(data_torch - data_tf)))

# Gathering only copies values, so the results must match exactly.
np.testing.assert_array_equal(data_torch, data_tf)


(2, 3, 4, 13, 2, 3)
(2, 3, 4, 13, 2, 3)
0.0


# gather_pytorch 逐行调试 实验

In [23]:
import torch

def gather_pytorch(input_data, axis=0, indices=None):
    """Step-by-step traced version of a ``tf.gather``-style gather.

    Selects entries of ``input_data`` along ``axis`` at the positions given
    by ``indices`` and prints every intermediate result. The dimension
    ``axis`` is replaced by the dimensions of ``indices`` in the returned
    tensor.
    """
    # Step 1: turn a negative axis into its non-negative equivalent
    if axis < 0:
        axis += input_data.dim()
    print("Adjusted axis:", axis)

    # Step 2: index_select only takes 1-D indices
    flat_idx = indices.flatten()
    print("Flattened indices:", flat_idx)

    # Step 3: select along the axis with the flattened indices
    selected = torch.index_select(input_data, axis, flat_idx)
    print("Data after index_select:", selected)

    in_shape = list(input_data.shape)
    print("Input shape:", in_shape)

    # Step 4: target shape = input shape with `axis` replaced by the shape
    # of the indices
    out_shape = [*in_shape[:axis], *indices.shape, *in_shape[axis + 1:]]
    print("Output shape after combining shapes:", out_shape)

    # Step 5: unfold the selected dimension to the shape of the indices
    result = selected.reshape(out_shape)
    print("Final output shape:", result.shape)
    print("Final output data:", result)

    return result

# Worked example: gather columns of a 3x3 matrix (values 1..9, row-major)
# with a 2x2 grid of column indices.
input_data = torch.arange(1, 10).reshape(3, 3)
indices = torch.tensor([[2, 0], [1, 1]])
axis = 1

gather_pytorch(input_data, axis=axis, indices=indices)


Adjusted axis: 1
Flattened indices: tensor([2, 0, 1, 1])
Data after index_select: tensor([[3, 1, 2, 2],
        [6, 4, 5, 5],
        [9, 7, 8, 8]])
Input shape: [3, 3]
Output shape after combining shapes: [3, 2, 2]
Final output shape: torch.Size([3, 2, 2])
Final output data: tensor([[[3, 1],
         [2, 2]],

        [[6, 4],
         [5, 5]],

        [[9, 7],
         [8, 8]]])


tensor([[[3, 1],
         [2, 2]],

        [[6, 4],
         [5, 5]],

        [[9, 7],
         [8, 8]]])