In [1]:
import tensorflow as tf
import keras
from keras.optimizers import Adam
from keras.layers import Dense
from keras.models import Model, Sequential
from keras.activations import relu,softmax
from keras.losses import SparseCategoricalCrossentropy

In [2]:
def get_feature_extractor():
    """Placeholder factory for the feature extractor; not implemented.

    NOTE(review): unused in the visible notebook, which instantiates
    ``FeatureExtractor`` directly inside ``SequenceDecoder``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def get_internal_slicer():
    """Placeholder factory for the internal slicer; not implemented.

    NOTE(review): unused in the visible notebook, which instantiates
    ``InternalSlicer`` directly inside ``SequenceDecoder``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


def get_rx_decoder():
    """Placeholder factory for the receiver decoder; not implemented.

    NOTE(review): unused in the visible notebook, which instantiates
    ``Rx_Decoder`` directly inside ``SequenceDecoder``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

def phase_multiply(internally_sliced_y, h):
    """Placeholder for the phase-correction step; not implemented.

    NOTE(review): a later cell redefines ``phase_multiply`` with a real
    implementation, which shadows this stub once that cell has run. This stub
    is dead code on a full Run All; consider deleting it.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

In [3]:
# SLICED_Y_LENGTH: width of one input window in real values. The values are
# interleaved (real, imag) pairs, i.e. SLICED_Y_LENGTH // 2 complex samples.
# BATCH_SIZE: mini-batch size passed to mySD.fit.
SLICED_Y_LENGTH, BATCH_SIZE = 16, 512

In [4]:

class FeatureExtractor(Model):
    """Two-layer MLP that maps an input window to 8 learned features.

    Architecture: Dense(256) -> ReLU -> Dense(8) with a linear output.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)  # hidden layer
        self.cf2 = Dense(8)    # feature output layer (no activation)

    def call(self, sliced_y):
        hidden = relu(self.cf1(sliced_y))
        return self.cf2(hidden)

class PhaseEstimator(Model):
    """Two-layer MLP that predicts a 2-value phase estimate from an input window.

    Architecture: Dense(256) -> ReLU -> Dense(2) with a linear output. In
    ``phase_multiply`` the two outputs are read as one complex number
    (real, imag); they are not normalised to unit magnitude.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)  # hidden layer
        self.cf2 = Dense(2)    # (real, imag) output, no activation

    def call(self, sliced_y):
        hidden = relu(self.cf1(sliced_y))
        return self.cf2(hidden)


class Rx_Decoder(Model):
    """Three-layer MLP classifier producing a probability over 16 classes.

    Architecture: Dense(256) -> ReLU -> Dense(256) -> ReLU -> Dense(16) -> softmax.
    Because the output is already a softmax, the loss must be configured with
    ``from_logits=False``.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)  # first hidden layer
        self.cf2 = Dense(256)  # second hidden layer
        self.cf3 = Dense(16)   # class scores

    def call(self, concat):
        hidden = concat
        for dense in (self.cf1, self.cf2):
            hidden = relu(dense(hidden))
        return softmax(self.cf3(hidden))


class InternalSlicer(Model):
    """Keep a window of complex samples centred on the middle of each input row.

    The real-valued input is read as interleaved (real, imag) pairs, so complex
    sample ``k`` occupies real positions ``2k`` and ``2k + 1``. With
    ``mid = complex_length // 2``, the layer keeps complex samples
    ``mid - l1`` through ``mid + l2`` (inclusive). It returns them in the same
    interleaved real layout, i.e. ``2 * (l1 + l2 + 1)`` values per row.

    Args:
        l1: Number of complex samples kept before the centre sample.
        l2: Number of complex samples kept after the centre sample.
        complex_length: Number of complex samples per input row.

    Raises:
        ValueError: If the requested window does not fit inside
            ``[0, complex_length)``. Without this check a negative start would
            wrap around (Python slice semantics) and an end past the row would
            be silently clipped.
    """

    def __init__(self, l1, l2, complex_length):
        super().__init__()

        # Slice boundaries, in complex-sample units.
        mid = complex_length // 2
        start = mid - l1
        end = mid + l2 + 1  # +1 because the sample at mid + l2 is included
        if start < 0 or end > complex_length or start >= end:
            raise ValueError(
                f"InternalSlicer window [{start}, {end}) does not fit in "
                f"{complex_length} complex samples (l1={l1}, l2={l2})"
            )
        self.start = start
        self.end = end

    def call(self, sliced_y):
        # Slicing the interleaved real tensor directly gives the same values as
        # a complex round trip (R2C, slice [start:end], C2R). Using the dynamic
        # batch size here also keeps partial batches and other batch sizes
        # working.
        flat = tf.cast(sliced_y, tf.float32)
        flat = tf.reshape(flat, [tf.shape(sliced_y)[0], -1])
        return flat[:, 2 * self.start:2 * self.end]

def R2C(a):
    """Convert an interleaved real tensor to a complex64 tensor.

    Each row of ``a`` is flattened and read as consecutive (real, imag) pairs:
    ``[re0, im0, re1, im1, ...] -> [re0 + 1j*im0, re1 + 1j*im1, ...]``.

    Args:
        a: Tensor whose first axis is the batch axis. The number of remaining
            elements per row must be even.

    Returns:
        complex64 tensor of shape (batch, elements_per_row // 2).
    """
    # Use the runtime batch size instead of the global BATCH_SIZE. With a
    # hard-coded batch size, a partial last batch or predict()/evaluate() at
    # another batch size either fails or silently mixes samples across rows.
    pairs = tf.reshape(tf.cast(a, tf.float32), [tf.shape(a)[0], -1, 2])
    return tf.complex(pairs[:, :, 0], pairs[:, :, 1])

def C2R(a):
    """Convert a complex tensor to an interleaved real tensor (inverse of R2C).

    ``[z0, z1, ...] -> [re(z0), im(z0), re(z1), im(z1), ...]`` per row.

    Args:
        a: Complex tensor of shape (batch, n), as produced by ``R2C``.

    Returns:
        Real tensor of shape (batch, 2 * n).
    """
    # Pair each real part with its imaginary part along a new last axis,
    # giving shape (batch, n, 2), then flatten the pairs per row.
    interleaved = tf.stack([tf.math.real(a), tf.math.imag(a)], axis=-1)
    # Runtime batch size rather than the global BATCH_SIZE, so any batch size works.
    return tf.reshape(interleaved, [tf.shape(a)[0], -1])

def phase_multiply(internally_sliced_y, estimated_phase):
    """Apply a complex phase correction to an interleaved real signal.

    Both arguments are interleaved (real, imag) real tensors. They are
    converted to complex with ``R2C``, multiplied element-wise (with
    broadcasting), and converted back with ``C2R``. For each sample:
    (a, b) * (c, d) = (ac - bd, ad + bc).

    In ``SequenceDecoder``, ``estimated_phase`` is the 2-unit
    ``PhaseEstimator`` output, i.e. one complex value per row. That value
    broadcasts across every complex sample of ``internally_sliced_y``.

    Returns:
        Interleaved real tensor holding the phase-corrected samples.
    """
    signal = R2C(internally_sliced_y)
    phase = R2C(estimated_phase)
    return C2R(phase * signal)






## PARAMS: synthetic training data

In [5]:
# Generate synthetic data: standard-normal input windows and uniformly random
# integer labels in [0, 16), i.e. one of the 16 classes Rx_Decoder outputs.
# Seed TensorFlow's global RNG so the synthetic data is reproducible on Run All.
SEED = 42
tf.random.set_seed(SEED)

m = 4 * BATCH_SIZE  # number of samples (2048), a whole number of batches
X = tf.random.normal(shape=(m, SLICED_Y_LENGTH), mean=0, stddev=1)

# Labels stay as integer class ids (not one-hot) because the loss is
# SparseCategoricalCrossentropy.
Y = tf.random.uniform(shape=(m, 1), minval=0, maxval=16, dtype=tf.int32)


In [6]:
# sequence decoder


class SequenceDecoder(Model):
    """End-to-end decoder from an input window to 16 class probabilities.

    Pipeline for each input window:
      1. ``FeatureExtractor`` produces 8 learned features.
      2. ``PhaseEstimator`` produces a 2-value (real, imag) phase estimate.
      3. ``InternalSlicer`` keeps the centre complex samples (3 before and
         3 after the middle sample), in interleaved real form.
      4. ``phase_multiply`` multiplies those samples by the phase estimate.
      5. The features and corrected samples are concatenated and passed to
         ``Rx_Decoder``, which returns softmax probabilities.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.phase_estimator = PhaseEstimator()
        # complex_length counts complex samples: each (real, imag) pair of the
        # real-valued input is one complex sample.
        self.internal_slicer = InternalSlicer(
            l1=3, l2=3, complex_length=SLICED_Y_LENGTH // 2
        )
        self.rx_decoder = Rx_Decoder()

    def call(self, sliced_y):
        features = self.feature_extractor(sliced_y)
        phase = self.phase_estimator(sliced_y)
        centre = self.internal_slicer(sliced_y)
        corrected = phase_multiply(centre, phase)
        return self.rx_decoder(tf.concat([features, corrected], axis=1))




In [7]:
# Build, compile and briefly train the sequence decoder on the synthetic data.
mySD = SequenceDecoder()

mySD.compile(
    optimizer=Adam(learning_rate=1e-2),
    # Rx_Decoder ends in a softmax, so the loss receives probabilities, not logits.
    loss=SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

mySD.fit(X, Y, epochs=2, batch_size=BATCH_SIZE)


Epoch 1/2
estimated_phase.shape (512, 2)
internally_sliced_y.shape (512, 14)
estimated_phase.shape (512, 2)
internally_sliced_y.shape (512, 14)
Epoch 2/2


<keras.callbacks.History at 0x1d683fe2fd0>