# TensorFlow implementation

## GenerateFlatFadingChannel

In [93]:
import tensorflow as tf
from sionna.channel import AWGN
from sionna.utils import complex_normal

class GenerateFlatFadingChannel:
    """Samples batches of flat-fading MIMO channel matrices (TensorFlow).

    Every call draws a ``[batch_size, num_rx_ant, num_tx_ant]`` tensor of standard
    complex Gaussian entries via Sionna's ``complex_normal``. When a spatial
    correlation model is configured, that model's output is returned instead.

    Parameters
    ----------
    num_tx_ant : int
        Number of transmit antennas (last axis of the output).
    num_rx_ant : int
        Number of receive antennas (second axis of the output).
    spatial_corr : callable or None
        Optional spatial correlation model, invoked as ``spatial_corr(h)``.
    dtype : tf.DType
        Complex dtype passed to ``complex_normal``. Defaults to ``tf.complex64``.
    **kwargs
        Forwarded to ``super().__init__`` (``object`` here, so it must be empty).
    """

    def __init__(self, num_tx_ant, num_rx_ant, spatial_corr=None, dtype=tf.complex64, **kwargs):
        super().__init__(**kwargs)
        self._num_tx_ant, self._num_rx_ant = num_tx_ant, num_rx_ant
        self._dtype = dtype
        # Stored through the property setter below.
        self.spatial_corr = spatial_corr

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        self._spatial_corr = value

    def __call__(self, batch_size):
        """Return ``batch_size`` channel realizations, correlated if a model is set."""
        uncorrelated = complex_normal(
            [batch_size, self._num_rx_ant, self._num_tx_ant], dtype=self._dtype
        )
        model = self.spatial_corr
        return uncorrelated if model is None else model(uncorrelated)

In [94]:
from sionna.channel.spatial_correlation import SpatialCorrelation
class SimpleSpatialCorrelation(SpatialCorrelation):
    """Right-multiplies channel tensors by a fixed square correlation matrix (TensorFlow).

    The last axis of ``h`` is multiplied by the matrix R, i.e.
    ``h_corr[..., j] = sum_i h[..., i] * R[i, j]``. For the channels produced by
    ``GenerateFlatFadingChannel`` in this notebook, that last axis is the
    transmit-antenna axis.

    NOTE(review): R itself is applied, so for i.i.d. unit-variance entries in ``h``
    the resulting transmit covariance is ``R^H R`` rather than ``R``. If R is meant
    to be the target correlation, a matrix square root of R should be applied instead.
    """

    def __init__(self, correlation_matrix):
        """
        Parameters
        ----------
        correlation_matrix : tf.Tensor
            A square matrix used to introduce spatial correlation. It is stored
            as ``tf.complex64``.
        """
        self.correlation_matrix = tf.cast(correlation_matrix, tf.complex64)

    def __call__(self, h, *args, **kwargs):
        """Apply the correlation matrix to the last axis of ``h``.

        Parameters
        ----------
        h : tf.Tensor
            Channel tensor whose last dimension equals the matrix size.

        Returns
        -------
        tf.Tensor
            Tensor with the same shape as ``h``.
        """
        h_shape = tf.shape(h)
        r = self.correlation_matrix
        # Match the channel's complex precision; a complex128 `h` would otherwise
        # fail in tf.matmul against the stored complex64 matrix.
        if h.dtype.is_complex:
            r = tf.cast(r, h.dtype)
        # Flatten all leading axes so a single 2-D matmul applies R to every row.
        h_reshaped = tf.reshape(h, [-1, h_shape[-1]])
        h_corr_reshaped = tf.matmul(h_reshaped, r)
        return tf.reshape(h_corr_reshaped, h_shape)
def test_generate_flat_fading_channel_with_spatial_corr():
    """Smoke test: print a small batch of spatially correlated channels (TensorFlow)."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 2

    # Symmetric 4x4 matrix with unit diagonal: one row/column per transmit antenna.
    correlation_matrix = tf.constant(
        [[1.0, 0.5, 0.5, 0.2],
         [0.5, 1.0, 0.3, 0.4],
         [0.5, 0.3, 1.0, 0.6],
         [0.2, 0.4, 0.6, 1.0]],
        dtype=tf.float32,
    )

    channel_generator = GenerateFlatFadingChannel(
        num_tx_ant=num_tx_ant,
        num_rx_ant=num_rx_ant,
        spatial_corr=SimpleSpatialCorrelation(correlation_matrix),
        dtype=tf.complex64,
    )
    channels = channel_generator(batch_size)

    print("Generated channel matrices with spatial correlation:")
    print(channels)


test_generate_flat_fading_channel_with_spatial_corr()

Generated channel matrices with spatial correlation:
tf.Tensor(
[[[ 0.28819647-0.18808052j  0.75670254-0.49159977j
    0.5917811 -0.12480513j  0.63494235-0.37096283j]
  [ 0.8017786 -0.11795726j -0.02173579-0.07250741j
   -0.61142975+0.3283163j  -1.3197908 +0.37896475j]]

 [[-0.81450444-0.5996711j  -0.7622298 -0.3086362j
   -0.67319655-1.4973358j  -0.3295982 -1.1230161j ]
  [-0.83335257-0.68523586j -0.72149324-1.1087103j
   -1.085548  -1.2565811j  -1.0349623 -1.3363783j ]]], shape=(2, 2, 4), dtype=complex64)


## ApplyFlatFadingChannel

In [95]:
class ApplyFlatFadingChannel(tf.keras.layers.Layer):
    """Applies flat-fading channel matrices to transmit vectors, optionally adding AWGN.

    Computes ``y = h @ x`` per batch element, with ``x`` treated as a column vector.
    If ``add_awgn`` is set, ``(y, no)`` is then passed through Sionna's ``AWGN`` layer.

    Parameters
    ----------
    add_awgn : bool
        If True, inputs are ``(x, h, no)`` and noise is added. Otherwise inputs are ``(x, h)``.
    dtype : tf.DType
        Layer dtype, also used for the internal ``AWGN`` layer. Defaults to ``tf.complex64``.

    Input
    -----
    x : [..., K] tensor
    h : [..., M, K] tensor
    no : noise power, only when ``add_awgn`` is True

    Output
    ------
    y : [..., M] tensor
    """

    def __init__(self, add_awgn=True, dtype=tf.complex64, **kwargs):
        super().__init__(trainable=False, dtype=dtype, **kwargs)
        self._add_awgn = add_awgn

    def build(self, input_shape): #pylint: disable=unused-argument
        # The AWGN sub-layer is created only when noise will be added.
        if self._add_awgn:
            self._awgn = AWGN(dtype=self.dtype)

    @staticmethod
    def _noiseless_output(x, h):
        """Return ``h @ x`` with ``x`` treated as a column vector per batch element."""
        return tf.squeeze(tf.matmul(h, tf.expand_dims(x, axis=-1)), axis=-1)

    def call(self, inputs):
        if not self._add_awgn:
            x, h = inputs
            return self._noiseless_output(x, h)
        x, h, no = inputs
        return self._awgn((self._noiseless_output(x, h), no))
def test_apply_flat_fading_channel():
    """Send random symbols through a random channel with AWGN and print everything (TensorFlow)."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 3

    # Transmit vectors x: [batch_size, num_tx_ant]; real and imaginary parts ~ N(0, 1).
    x_shape = [batch_size, num_tx_ant]
    x = tf.complex(tf.random.normal(x_shape), tf.random.normal(x_shape))

    # Channel matrices h: [batch_size, num_rx_ant, num_tx_ant].
    h_shape = [batch_size, num_rx_ant, num_tx_ant]
    h = tf.complex(tf.random.normal(h_shape), tf.random.normal(h_shape))

    # Noise power.
    no = tf.constant(0.1, dtype=tf.float32)

    # Instantiate the layer and call it eagerly.
    apply_channel = ApplyFlatFadingChannel(add_awgn=True, dtype=tf.complex64)
    y = apply_channel((x, h, no))

    for title, value in (("传输向量:", x), ("\n信道矩阵:", h), ("\n加了AWGN后的输出向量:", y)):
        print(title)
        print(value.numpy())


test_apply_flat_fading_channel()

传输向量:
[[ 1.1132426 +0.11130638j -0.5319816 +0.5356469j   1.396718  -1.6788299j
  -2.1668339 +0.6576924j ]
 [ 0.21215717-1.3061938j   0.10999919-0.31190628j -0.7991015 -0.01790807j
   0.07831591+1.3733525j ]
 [ 0.4104208 -2.1528537j  -0.66589326-1.8302702j   1.5569323 -1.5798086j
  -0.18739842+1.2966088j ]]

信道矩阵:
[[[ 0.16454397+0.6013255j  -1.2538235 +1.7570105j
   -0.5845626 -0.6918303j  -1.4415432 +0.11925775j]
  [-0.33987543+0.4709389j  -2.607881  -0.3853946j
   -0.7426765 -1.6712811j  -0.7299179 -0.16972645j]]

 [[-0.2922164 -0.13243763j  1.1041995 -0.87678397j
    2.12758   +0.82741284j  0.00658638+0.45426363j]
  [-0.5833805 -1.1152693j  -1.8485973 -0.74552035j
   -0.03089342-0.23460677j  0.694645  -0.21249084j]]

 [[ 2.0911696 +0.4285976j  -0.54920936-0.19643413j
    0.6413635 +0.38148257j  1.2151476 -0.05950265j]
  [ 0.41378117+0.46881416j  0.49113128+0.577954j
   -0.27631778+1.3978186j   0.4066226 +0.5923779j ]]]

加了AWGN后的输出向量:
[[ 1.5141959-1.8787501j  -0.9935119-1.8785846j ]
 

## FlatFadingChannel

In [96]:
from sionna.channel import GenerateFlatFadingChannel, ApplyFlatFadingChannel
class FlatFadingChannel(tf.keras.layers.Layer):
    """Flat-fading MIMO channel (TensorFlow).

    On each call, draws one channel matrix per batch element and applies it to
    ``x``, optionally adding AWGN.

    NOTE(review): ``GenerateFlatFadingChannel`` and ``ApplyFlatFadingChannel`` are
    resolved from the global namespace when an instance is constructed. The
    ``from sionna.channel import GenerateFlatFadingChannel, ApplyFlatFadingChannel``
    statement right before this class rebinds both names to Sionna's
    implementations. This layer therefore uses Sionna's versions, not the ones
    defined earlier in this notebook. Confirm that is intended.

    Parameters
    ----------
    num_tx_ant, num_rx_ant : int
        Number of transmit / receive antennas.
    spatial_corr : optional
        Spatial correlation model handed to the channel generator.
    add_awgn : bool
        If True, inputs are ``(x, no)`` and noise is added. Otherwise the input is ``x``.
    return_channel : bool
        If True, return ``(y, h)`` instead of ``y``.
    dtype : tf.DType
        Complex dtype for the layer and its sub-components.

    Input
    -----
    x : tensor with the batch on axis 0 (``[batch_size, num_tx_ant]`` in this notebook's tests)
    no : noise power, only when ``add_awgn`` is True
    """

    def __init__(self,
                 num_tx_ant,
                 num_rx_ant,
                 spatial_corr=None,
                 add_awgn=True,
                 return_channel=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(trainable=False, dtype=dtype, **kwargs)
        self._num_tx_ant = num_tx_ant
        self._num_rx_ant = num_rx_ant
        self._add_awgn = add_awgn
        self._return_channel = return_channel
        self._gen_chn = GenerateFlatFadingChannel(num_tx_ant,
                                                  num_rx_ant,
                                                  spatial_corr,
                                                  dtype=dtype)
        self._app_chn = ApplyFlatFadingChannel(add_awgn=add_awgn, dtype=dtype)

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._gen_chn.spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        self._gen_chn.spatial_corr = value

    @property
    def generate(self):
        """Calls the internal :class:`GenerateFlatFadingChannel`."""
        return self._gen_chn

    @property
    def apply(self):
        """Calls the internal :class:`ApplyFlatFadingChannel`."""
        return self._app_chn

    def call(self, inputs):
        x, no = inputs if self._add_awgn else (inputs, None)

        # One channel realization per batch element (batch taken from axis 0 of x).
        h = self._gen_chn(tf.shape(x)[0])

        apply_inputs = [x, h, no] if self._add_awgn else [x, h]
        y = self._app_chn(apply_inputs)

        return (y, h) if self._return_channel else y

In [98]:
def test_flat_fading_channel():
    """Run FlatFadingChannel end to end (AWGN on, channel returned) and print x, h, y."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 3

    # Complex transmit vectors x: [batch_size, num_tx_ant].
    x = tf.complex(
        tf.random.normal([batch_size, num_tx_ant], dtype=tf.float32),
        tf.random.normal([batch_size, num_tx_ant], dtype=tf.float32),
    )
    # Noise power.
    no = tf.constant(0.1, dtype=tf.float32)

    channel = FlatFadingChannel(num_tx_ant, num_rx_ant, add_awgn=True, return_channel=True)
    y, h = channel((x, no))

    for title, value in (("传输向量:", x), ("\n信道矩阵:", h), ("\n加了AWGN后的输出向量:", y)):
        print(title)
        print(value)


test_flat_fading_channel()

传输向量:
tf.Tensor(
[[ 0.6843453 +0.22757216j -0.26122835+1.0375805j   0.2260131 +2.4184716j
  -1.0160859 +0.50326806j]
 [ 0.48190844+0.27828395j  0.92021644+0.57757956j -0.22269917+0.5953742j
  -0.18767054+0.5871959j ]
 [ 0.7211384 -0.13487098j  0.1657252 -0.8465592j  -1.0538555 -0.17256583j
  -0.7745284 +0.21054038j]], shape=(3, 4), dtype=complex64)

信道矩阵:
tf.Tensor(
[[[-0.5358709 +0.611553j   -0.40919942-1.7566687j
   -0.7247089 -0.41668743j  0.42145267-0.30306837j]
  [ 0.5341157 +2.09381j    -0.47984308+0.23437928j
   -0.62534195+0.66592836j  0.87496334+0.34314734j]]

 [[ 0.24807462-0.3962823j  -0.50909466-0.12608887j
    0.42017648-0.5143275j  -0.0468226 +0.0418671j ]
  [-0.56326246-0.55790615j -0.21070991+0.13739674j
    0.37893227+0.35578674j -0.11870175-0.10376992j]]

 [[-0.21596894-0.06639432j -0.6416581 -0.4427852j
   -0.17835604+0.8335447j  -0.3046067 +1.4488416j ]
  [ 0.5513172 -0.8785j      0.33442724-0.8733942j
   -1.1926903 +0.00908366j  0.1839229 -0.1777566j ]]], shape=(3,

# PyTorch implementation

## GenerateFlatFadingChannel

In [77]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod

from my_code.mysionna.channel.torch_version import AWGN

from my_code.mysionna.utils import complex_normal

try:
    SpatialCorrelation
except NameError:
    # The only `SpatialCorrelation` import in this notebook is in the TensorFlow
    # section (sionna). Running the PyTorch cells on a fresh kernel would fail here
    # with a NameError, so fall back to a minimal abstract base with the same
    # `__call__` contract.
    class SpatialCorrelation(ABC):
        """Minimal stand-in for Sionna's abstract spatial-correlation interface."""

        @abstractmethod
        def __call__(self, h, *args, **kwargs):
            """Return a spatially correlated version of the channel tensor ``h``."""


class SimpleSpatialCorrelation(SpatialCorrelation):
    """Right-multiplies channel tensors by a fixed square correlation matrix (PyTorch).

    The last axis of ``h`` is multiplied by the matrix R, i.e.
    ``h_corr[..., j] = sum_i h[..., i] * R[i, j]``. For the channels produced by
    ``GenerateFlatFadingChannel`` in this notebook, that last axis is the
    transmit-antenna axis.

    NOTE(review): R itself is applied, so for i.i.d. unit-variance entries in ``h``
    the resulting transmit covariance is ``R^H R`` rather than ``R``. If R is meant
    to be the target correlation, a matrix square root of R should be applied instead.
    """

    def __init__(self, correlation_matrix):
        """
        Parameters
        ----------
        correlation_matrix : torch.Tensor or array-like
            A square matrix used to introduce spatial correlation. It is stored
            as ``torch.complex64``.
        """
        self.correlation_matrix = torch.as_tensor(correlation_matrix).to(torch.complex64)

    def __call__(self, h, *args, **kwargs):
        """Apply the correlation matrix to the last axis of ``h``.

        Parameters
        ----------
        h : torch.Tensor
            Channel tensor whose last dimension equals the matrix size. It need not
            be contiguous.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``h``.
        """
        # Follow h's device, and its precision when h is complex. Otherwise a complex128
        # or GPU `h` would fail in matmul against the stored complex64 CPU matrix.
        r = self.correlation_matrix.to(device=h.device)
        if h.is_complex():
            r = r.to(dtype=h.dtype)
        # torch.matmul broadcasts over all leading axes and, unlike `view`, does not
        # require `h` to be contiguous.
        return torch.matmul(h, r)
class GenerateFlatFadingChannel(nn.Module):
    """Generates batches of flat-fading MIMO channel matrices (PyTorch).

    Each call draws a ``[batch_size, num_rx_ant, num_tx_ant]`` tensor of standard
    complex Gaussian entries with ``complex_normal``. If a spatial correlation model
    is set, the tensor is passed through it.

    Parameters
    ----------
    num_tx_ant : int
        Number of transmit antennas (last axis of the output).
    num_rx_ant : int
        Number of receive antennas (second axis of the output).
    spatial_corr : callable or None
        Optional spatial correlation model, invoked as ``spatial_corr(h)``.
        ``None`` disables correlation.
    dtype : torch.dtype
        Complex dtype passed to ``complex_normal``. Defaults to ``torch.complex64``.
    **kwargs
        Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_tx_ant, num_rx_ant, spatial_corr=None, dtype=torch.complex64, **kwargs):
        super().__init__(**kwargs)
        self._num_tx_ant = num_tx_ant
        self._num_rx_ant = num_rx_ant
        self._dtype = dtype
        self.spatial_corr = spatial_corr

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        self._spatial_corr = value

    def forward(self, batch_size):
        """Draw ``batch_size`` channel realizations.

        Implemented as ``forward`` rather than by overriding ``__call__``, so that
        ``nn.Module.__call__``, including any registered hooks, runs as for any other
        module. Callers still invoke the instance directly: ``gen(batch_size)``.

        Parameters
        ----------
        batch_size : int
            Number of channel matrices to draw.

        Returns
        -------
        torch.Tensor
            ``[batch_size, num_rx_ant, num_tx_ant]`` channel tensor.
        """
        # Generate standard complex Gaussian matrices
        shape = [batch_size, self._num_rx_ant, self._num_tx_ant]
        h = complex_normal(shape, dtype=self._dtype)

        # Apply spatial correlation
        if self.spatial_corr is not None:
            h = self.spatial_corr(h)

        return h

In [78]:
def test_generate_flat_fading_channel_with_spatial_corr():
    """Smoke test: print a small batch of spatially correlated channels (PyTorch)."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 2

    # Symmetric 4x4 matrix with unit diagonal: one row/column per transmit antenna.
    correlation_matrix = torch.tensor(
        [[1.0, 0.5, 0.5, 0.2],
         [0.5, 1.0, 0.3, 0.4],
         [0.5, 0.3, 1.0, 0.6],
         [0.2, 0.4, 0.6, 1.0]],
        dtype=torch.float32,
    )

    generator = GenerateFlatFadingChannel(
        num_tx_ant=num_tx_ant,
        num_rx_ant=num_rx_ant,
        spatial_corr=SimpleSpatialCorrelation(correlation_matrix),
        dtype=torch.complex64,
    )
    channels = generator(batch_size)

    print("Generated channel matrices with spatial correlation:")
    print(channels)


test_generate_flat_fading_channel_with_spatial_corr()

Generated channel matrices with spatial correlation:
tensor([[[-0.7127+0.8141j, -0.8704+0.5202j, -0.3839+0.0325j, -0.2602+0.1362j],
         [-0.7857+0.5532j, -0.2742+1.0451j,  0.4278+1.1164j,  0.1717+1.3110j]],

        [[-0.6055+0.6815j, -0.8914+1.1921j, -0.7499+1.1598j, -1.3174+1.5669j],
         [-1.7984+0.6477j, -1.4103+0.8763j, -0.2251+0.1602j,  0.1251+0.2173j]]])


## ApplyFlatFadingChannel

In [79]:
class ApplyFlatFadingChannel(nn.Module):
    """Applies flat-fading channel matrices to transmit vectors, optionally adding AWGN (PyTorch).

    Computes ``y = h @ x`` per batch element, with ``x`` treated as a column vector.
    If ``add_awgn`` is set, ``(y, no)`` is passed through the project's ``AWGN`` module.

    Parameters
    ----------
    add_awgn : bool
        If True, inputs are ``(x, h, no)`` and noise is added. Otherwise inputs are ``(x, h)``.
    dtype : torch.dtype
        Dtype handed to the internal ``AWGN`` module (only used when ``add_awgn``).
    **kwargs
        Accepted for signature compatibility but ignored.

    Input
    -----
    x : [..., K] tensor
    h : [..., M, K] tensor
    no : noise power, only when ``add_awgn`` is True

    Output
    ------
    y : [..., M] tensor
    """

    def __init__(self, add_awgn=True, dtype=torch.complex64, **kwargs):
        super().__init__()
        self._add_awgn = add_awgn
        # The AWGN sub-module exists only when noise will be added.
        if add_awgn:
            self._awgn = AWGN(dtype=dtype)

    @staticmethod
    def _noiseless_output(x, h):
        """Return ``h @ x`` with ``x`` treated as a column vector per batch element."""
        return torch.matmul(h, x.unsqueeze(-1)).squeeze(-1)

    def forward(self, inputs):
        if not self._add_awgn:
            x, h = inputs
            return self._noiseless_output(x, h)
        x, h, no = inputs
        return self._awgn((self._noiseless_output(x, h), no))

In [80]:
def test_apply_flat_fading_channel():
    """Send random symbols through a random channel with AWGN and print everything (PyTorch)."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 3

    # Transmit vectors: [batch_size, num_tx_ant], complex64.
    x = torch.randn(batch_size, num_tx_ant, dtype=torch.cfloat)
    # Channel matrices: [batch_size, num_rx_ant, num_tx_ant], complex64.
    h = torch.randn(batch_size, num_rx_ant, num_tx_ant, dtype=torch.cfloat)
    # Noise power.
    no = torch.tensor(0.1, dtype=torch.float32)

    # Instantiate the module and run the forward pass.
    apply_channel = ApplyFlatFadingChannel(add_awgn=True)
    y = apply_channel((x, h, no))

    for title, value in (("传输向量:", x), ("\n信道矩阵:", h), ("\n加了AWGN后的输出向量:", y)):
        print(title)
        print(value)


test_apply_flat_fading_channel()

传输向量:
tensor([[ 0.1313-0.6977j,  0.0559+0.5227j,  1.2121-0.5464j,  0.6020+1.2260j],
        [ 0.2297+0.2950j,  0.2082+1.0829j,  1.4587-0.3842j, -0.1091-0.9006j],
        [ 0.4567+0.4774j,  0.4373-0.4676j,  0.2499-0.5733j,  0.4973-0.1342j]])

信道矩阵:
tensor([[[ 0.3331-0.1603j, -0.3885+0.6694j,  0.7061+0.1668j, -0.1577+0.0903j],
         [-0.9079+0.8356j,  0.8097+1.2289j,  0.2214+0.9629j, -2.3175-0.9456j]],

        [[-0.3671-0.0244j,  0.1677+1.0710j,  0.2633+0.4018j,  0.1188+1.1061j],
         [-0.1680-0.8280j, -0.3419+0.1864j, -1.2192-1.0328j, -0.7579+0.3268j]],

        [[-0.1402+0.2742j, -0.9989+0.2292j, -1.0726+1.3798j,  0.1012+0.7148j],
         [ 0.1439-0.3031j,  0.7712-0.6016j,  0.5720+0.2786j, -0.9295-0.3365j]]])

加了AWGN后的输出向量:
tensor([[ 0.3631-0.7183j,  0.3706-1.1167j],
        [ 0.4548+0.2452j, -1.6224-1.1649j],
        [ 0.2085+1.8357j, -0.2450-0.8538j]])


## FlatFadingChannel

In [86]:
class FlatFadingChannel(nn.Module):
    """Flat-fading MIMO channel (PyTorch).

    On each forward pass, draws one channel matrix per batch element with the
    internal ``GenerateFlatFadingChannel`` and applies it to ``x`` with the internal
    ``ApplyFlatFadingChannel``, optionally adding AWGN.

    Parameters
    ----------
    num_tx_ant, num_rx_ant : int
        Number of transmit / receive antennas.
    spatial_corr : optional
        Spatial correlation model handed to the channel generator.
    add_awgn : bool
        If True, inputs are ``(x, no)`` and noise is added. Otherwise the input is ``x``.
    return_channel : bool
        If True, return ``(y, h)`` instead of ``y``.
    dtype : torch.dtype
        Complex dtype for the sub-modules.
    **kwargs
        Accepted for signature compatibility but ignored.

    Input
    -----
    x : tensor with the batch on axis 0 (``[batch_size, num_tx_ant]`` in this notebook's tests)
    no : noise power, only when ``add_awgn`` is True
    """

    def __init__(self,
                 num_tx_ant,
                 num_rx_ant,
                 spatial_corr=None,
                 add_awgn=True,
                 return_channel=False,
                 dtype=torch.complex64,
                 **kwargs):
        super().__init__()
        self._num_tx_ant = num_tx_ant
        self._num_rx_ant = num_rx_ant
        self._add_awgn = add_awgn
        self._return_channel = return_channel
        self._gen_chn = GenerateFlatFadingChannel(num_tx_ant,
                                                  num_rx_ant,
                                                  spatial_corr,
                                                  dtype=dtype)
        self._app_chn = ApplyFlatFadingChannel(add_awgn=add_awgn, dtype=dtype)

    @property
    def spatial_corr(self):
        """The :class:`~sionna.channel.SpatialCorrelation` to be used."""
        return self._gen_chn.spatial_corr

    @spatial_corr.setter
    def spatial_corr(self, value):
        self._gen_chn.spatial_corr = value

    @property
    def generate(self):
        """Calls the internal :class:`GenerateFlatFadingChannel`."""
        return self._gen_chn

    @property
    def apply(self):
        """Calls the internal :class:`ApplyFlatFadingChannel`."""
        return self._app_chn

    def forward(self, inputs):
        x, no = inputs if self._add_awgn else (inputs, None)

        # One channel realization per batch element (batch taken from axis 0 of x).
        h = self._gen_chn(x.shape[0])

        apply_inputs = [x, h, no] if self._add_awgn else [x, h]
        y = self._app_chn(apply_inputs)

        return (y, h) if self._return_channel else y

In [87]:
def test_flat_fading_channel():
    """Run FlatFadingChannel end to end (AWGN on, channel returned) and print x, h, y (PyTorch)."""
    num_tx_ant, num_rx_ant, batch_size = 4, 2, 3

    # Complex transmit vectors: [batch_size, num_tx_ant].
    x = torch.randn(batch_size, num_tx_ant, dtype=torch.cfloat)
    # Noise power.
    no = torch.tensor(0.1, dtype=torch.float32)

    channel = FlatFadingChannel(num_tx_ant, num_rx_ant, add_awgn=True, return_channel=True)
    # Forward pass.
    y, h = channel((x, no))

    for title, value in (("传输向量:", x), ("\n信道矩阵:", h), ("\n加了AWGN后的输出向量:", y)):
        print(title)
        print(value)


test_flat_fading_channel()

传输向量:
tensor([[ 0.2329-1.1059j, -0.8660+0.3665j, -0.1342-1.3872j,  0.0426+0.5944j],
        [ 0.8383-0.0249j,  0.2961+1.1384j,  0.3307+0.4614j,  1.1108-0.0269j],
        [ 1.2513+0.3599j, -0.4263+0.2140j,  0.3840+0.2187j, -0.6161+0.5226j]])

信道矩阵:
tensor([[[-0.8303-0.1415j, -1.1471-0.2949j, -0.5462+1.3484j, -0.1300+1.3192j],
         [-1.0897+0.2057j,  2.1578-0.3344j,  0.6888+0.0411j,  0.2210-1.2011j]],

        [[ 1.7980-0.1506j, -1.1110+0.0184j, -0.3771+1.0635j, -1.2193+0.3169j],
         [ 0.8660+0.4287j, -0.3461+0.3921j,  0.1221-0.0784j,  0.4920-0.0571j]],

        [[ 1.7646-0.1076j, -0.4875-0.9969j,  1.2998+0.5949j,  0.2569-0.9195j],
         [ 0.6145-0.2512j,  1.3982+0.5809j,  0.4404-0.4569j,  0.1865+0.3349j]]])

加了AWGN后的输出向量:
tensor([[ 1.5648+1.0159j, -1.0795+1.5959j],
        [-0.9591-1.0191j,  1.0271+0.3200j],
        [ 4.0218+1.9995j,  0.1728-0.3820j]])
