Online version of this notebook: https://colab.research.google.com/drive/12iz5_mTOmTa0oyn5IZZYnEskVUonViVW#scrollTo=VSTgsilO1_ur

In [1]:
import tensorflow as tf
import keras
from keras.optimizers import Adam
from keras.layers import Dense
from keras.models import Model, Sequential
from keras.activations import relu, softmax
from keras.losses import SparseCategoricalCrossentropy



In [2]:
def get_feature_extractor():
  """Placeholder factory; not implemented yet.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError

def get_internal_slicer():
  """Placeholder factory; not implemented yet.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError


def get_rx_decoder():
  """Placeholder factory; not implemented yet.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError

def phase_multiply(internally_sliced_y,h):
  """Placeholder signature; not implemented in this cell.

  A later cell of this notebook defines `phase_multiply` again, which
  rebinds the name once that cell has run.

  Raises:
    NotImplementedError: always.
  """
  raise NotImplementedError

## PARAMS

In [4]:
# Number of real values in one row of sliced_y (width of the model input).
SLICED_Y_LENGTH = 16
# Number of rows processed together as one batch.
BATCH_SIZE = 2048

# Design parameter "f" of the feature-extractor path: how many features
# the extractor outputs.
# "Our experiments have shown that even a small number of features,
#  e.g., F = 4, significantly improves the performance."
N_FEATURES_EXTRACTED = 8

In [5]:
def R2C(a):
    """Pack interleaved real values into a complex tensor.

    Each row of `a` is read as consecutive (real, imag) pairs and turned
    into one complex number per pair.

    The batch axis is taken from `a` itself rather than from the global
    BATCH_SIZE, so the function works for any number of rows. For inputs
    with exactly BATCH_SIZE rows the result is the same as before.

    Args:
        a: tensor (or array-like) whose first axis is the batch axis and
            whose rows hold an even number of values.

    Returns:
        A complex tensor of shape (batch, n_pairs), built from float32
        real and imaginary parts.
    """
    batch = tf.shape(a)[0]
    # (batch, n_pairs, 2): the last axis holds [real, imag].
    pairs = tf.cast(tf.reshape(a, shape=(batch, -1, 2)), tf.float32)
    return tf.complex(pairs[:, :, 0], pairs[:, :, 1])

def C2R(a):
    """Unpack a complex tensor into interleaved real values.

    Inverse layout of R2C: every complex entry becomes a consecutive
    (real, imag) pair along the row.

    The batch axis is taken from `a` itself rather than from the global
    BATCH_SIZE, so the function works for any number of rows. For inputs
    with exactly BATCH_SIZE rows the result is the same as before.

    Args:
        a: rank-2 complex tensor of shape (batch, n).

    Returns:
        A real tensor of shape (batch, 2 * n).
    """
    real_part = tf.expand_dims(tf.math.real(a), axis=2)
    imag_part = tf.expand_dims(tf.math.imag(a), axis=2)
    # (batch, n, 2) with [real, imag] on the last axis, then flattened per row.
    interleaved = tf.concat((real_part, imag_part), axis=2)
    return tf.reshape(interleaved, (tf.shape(a)[0], -1))


## Main Blocks in the Sequence Decoder

In [6]:

class FeatureExtractor(Model):
    """Dense feature-extraction head of the sequence decoder.

    A shared 256-unit ReLU layer feeds two linear heads: one producing
    N_FEATURES_EXTRACTED features and one producing an 8-value state.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)                   # shared hidden layer
        self.cf2 = Dense(N_FEATURES_EXTRACTED)  # feature head
        self.cf_state = Dense(8)                # state head

    def call(self,sliced_y,prev_state_FE):
        """Compute features and a new state from `sliced_y`.

        Args:
            sliced_y: input tensor passed to the first dense layer.
            prev_state_FE: accepted for the recurrent interface but not
                used by the computation below.

        Returns:
            Tuple `(features, state_FE)`.
        """
        print("sliced_y.shape", sliced_y.shape)
        hidden = relu(self.cf1(sliced_y))
        # The state head is evaluated before the feature head.
        state_FE = self.cf_state(hidden)
        features = self.cf2(hidden)
        return features, state_FE

class PhaseEstimator(Model):
    """Dense phase-estimation head of the sequence decoder.

    A shared 256-unit ReLU layer feeds two linear heads: a 2-value output
    (later interpreted as one complex number by `phase_multiply`) and an
    8-value state.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)     # shared hidden layer
        self.cf2 = Dense(2)       # phase head
        self.cf_state = Dense(8)  # state head

    def call(self,sliced_y,prev_state_PE):
        """Compute the phase estimate and a new state from `sliced_y`.

        Args:
            sliced_y: input tensor passed to the first dense layer.
            prev_state_PE: accepted for the recurrent interface but not
                used by the computation below.

        Returns:
            Tuple `(estimated_phase, state_PE)`.
        """
        hidden = relu(self.cf1(sliced_y))
        # The state head is evaluated before the phase head.
        state_PE = self.cf_state(hidden)
        estimated_phase = self.cf2(hidden)
        return estimated_phase, state_PE


class Rx_Decoder_old(Model):
    """Receiver decoder: two 256-unit ReLU layers followed by 16 logits."""

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)
        self.cf2 = Dense(256)
        self.cf3 = Dense(16)

    def call(self,concat):
        """Map the concatenated decoder input to 16 unnormalised logits.

        No softmax is applied here; the loss is expected to be built with
        `from_logits=True`.
        """
        hidden = relu(self.cf1(concat))
        hidden = relu(self.cf2(hidden))
        return self.cf3(hidden)


class Rx_Decoder_new(Model):
    """Receiver decoder used on the stateful path of SequenceDecoder.

    Two 256-unit ReLU layers followed by a 16-logit output layer named
    "final_out_cf3". An extra 8-unit state head (`cf4_state`) was tried
    and is left disabled as it was found useless.
    """

    def __init__(self):
        super().__init__()
        self.cf1 = Dense(256)
        self.cf2 = Dense(256)
        self.cf3 = Dense(16,name="final_out_cf3")
        # disabled: self.cf4_state = Dense(8,name="state_dense_cf4")

    def call(self,concat):
        """Map the concatenated decoder input to 16 unnormalised logits.

        No softmax is applied here; the loss is expected to be built with
        `from_logits=True`.
        """
        hidden = relu(self.cf1(concat))
        hidden = relu(self.cf2(hidden))
        # disabled: state = self.cf4_state(hidden)
        return self.cf3(hidden)

class InternalSlicer(Model):
    """Keeps a window of complex samples around the middle of each row.

    The window covers `l1` samples before and `l2` samples after index
    `complex_length // 2`, that index included.
    """

    def __init__(self,l1,l2,complex_length):
        super().__init__()
        # Half-open [start, end) bounds of the window, in complex samples.
        centre = complex_length // 2
        self.start, self.end = centre - l1, centre + l2 + 1

    def call(self,sliced_y):
        """Return the windowed samples in interleaved real layout.

        `sliced_y` is converted to complex with R2C, sliced along axis 1,
        and converted back with C2R.
        """
        print(f"internal slicer : sliced_y.shape : {sliced_y.shape}")
        as_complex = R2C(sliced_y)
        window = as_complex[:, self.start:self.end]
        ret = C2R(window)
        print(f"internal slicer : ret.shape : {ret.shape}")
        return ret


def phase_multiply(internally_sliced_y,estimated_phase):
    """Multiply the sliced samples by the estimated phase in the complex domain.

    Both arguments are interleaved-real tensors. They are converted with
    R2C, multiplied element-wise as complex numbers
    ((a,b) * (c,d) = (ac-bd, ad+bc)), and converted back with C2R. The
    estimate is used as given; it is not normalised to unit magnitude.

    Returns:
        The product in interleaved-real layout.
    """
    print("internally_sliced_y.shape:", internally_sliced_y.shape)
    print("estimated_phase.shape: ", estimated_phase.shape)
    samples = R2C(internally_sliced_y)
    rotation = R2C(estimated_phase)
    return C2R(rotation * samples)






## Synthetic (fake) data

In [7]:
# Generate fake data.
# The whole set is pushed through the model as one batch, so the sample
# count is tied to BATCH_SIZE instead of a separate magic number
# (previously 512 * 2**2, which is the same value, 2048).
m = BATCH_SIZE

# Inputs: standard-normal rows of SLICED_Y_LENGTH values.
X = tf.random.normal(shape=(m,SLICED_Y_LENGTH),
                     mean=0,
                     stddev=1)

# Targets: integer class indices in [0, 16), one per row, matching the
# sparse categorical cross-entropy loss used for training.
Y = tf.random.uniform(shape=(m,1),
                      minval=0,
                      maxval=16,
                      dtype=tf.int32)
# Y = keras.utils.to_categorical(Y,16)


## Main Model : Sequence Decoder Class

In [8]:
# sequence decoder


class SequenceDecoder(Model):
    """Sequence decoder combining feature extraction, phase correction and decoding.

    For each input row the model:
      1. extracts features (FeatureExtractor),
      2. estimates a complex phase (PhaseEstimator),
      3. keeps the central window of the input (InternalSlicer) and
         multiplies it by the estimated phase (phase_multiply),
      4. concatenates features and phase-corrected samples and decodes
         them into 16 logits (Rx_Decoder_new).

    Args:
        take_prev_phase_state: must be True. The stateless path is
            deliberately disabled and constructing the model with False
            raises.

    Raises:
        Exception: if `take_prev_phase_state` is falsy.
    """

    # Width of the state vectors produced by FeatureExtractor / PhaseEstimator.
    _STATE_SIZE = 8

    def __init__(self,take_prev_phase_state=False):
        super(SequenceDecoder,self).__init__()

        self.take_prev_phase_state = take_prev_phase_state

        self.feature_extractor = FeatureExtractor()
        self.phase_estimator = PhaseEstimator()
        self.internal_slicer = InternalSlicer(l1=3,l2=3,complex_length=SLICED_Y_LENGTH//2)
        if take_prev_phase_state:
            self.rx_decoder_RNN = Rx_Decoder_new()
        else:
            # Stateless path is intentionally disabled. The assignment that
            # used to follow this raise could never run and was removed.
            raise Exception("How here come??")

    @staticmethod
    def _zero_state(reference):
        """Return an all-zero state with one row per row of `reference`."""
        batch = reference.shape[0]
        if batch is None:
            # Static batch size unknown (symbolic input): use the dynamic one.
            batch = tf.shape(reference)[0]
        return tf.zeros((batch, SequenceDecoder._STATE_SIZE))

    def call(self,sliced_y,prev_state_FE=None,prev_state_PE=None):
        """Run one decoding step.

        Args:
            sliced_y: input rows in interleaved-real layout.
            prev_state_FE: previous feature-extractor state; zeros if None.
            prev_state_PE: previous phase-estimator state; zeros if None.

        Returns:
            `(st_hat, state_FE, state_PE)` when `take_prev_phase_state` is
            set, where `st_hat` holds the 16 logits per row; otherwise
            only `st_hat`.
        """
        # Default states are sized from the actual input, not from the
        # notebook-level global `X`, so the model does not depend on
        # unrelated kernel state.
        if prev_state_PE is None:
            print(" How this none?")
            prev_state_PE = self._zero_state(sliced_y)
        if prev_state_FE is None:
            print(" How this none?")
            prev_state_FE = self._zero_state(sliced_y)

        # RNN connection starts here
        extracted_features, state_FE = self.feature_extractor(
            sliced_y, prev_state_FE=prev_state_FE)
        estimated_phase, state_PE = self.phase_estimator(
            sliced_y, prev_state_PE=prev_state_PE)
        # RNN connection ends here

        internally_sliced_y = self.internal_slicer(sliced_y)

        print("SD call estimated_phase.shape",estimated_phase.shape)
        print("SD call internally_sliced_y.shape",internally_sliced_y.shape)

        phase_corrected_ = phase_multiply(internally_sliced_y,estimated_phase)

        concat = tf.concat((extracted_features,phase_corrected_),axis=1)
        if self.take_prev_phase_state:
            st_hat = self.rx_decoder_RNN(concat)
            return (st_hat,state_FE,state_PE)
        else:
            print("--PROBLEM--")
            st_hat = self.rx_decoder(concat)
            return st_hat

    def custom_train(self,X,Y,epochs=1): # X = vertically stacked sliced_y, Y = message index
        """Train with a hand-written loop of full-batch gradient steps.

        Each of the `epochs` iterations runs the whole of `X` through the
        model, evaluates the compiled loss against `Y` and applies one
        optimizer update. The states returned by one iteration are passed
        as the previous states of the next.

        Requires `compile()` to have been called so that the optimizer
        and loss are available.

        Args:
            X: input rows (vertically stacked sliced_y).
            Y: target message indices, one per row of `X`.
            epochs: number of gradient steps to take.

        Returns:
            The sum of the loss values over all iterations.
        """
        temp_prev_state_PE = tf.zeros((X.shape[0], self._STATE_SIZE))
        temp_prev_state_FE = tf.zeros((X.shape[0], self._STATE_SIZE))

        loss_acc = 0

        # train once per iteration on the full data
        for i in range(epochs):
            print(f"iterration : {i}")
            x = X
            print("x shape:", x.shape)
            y = Y
            print("y shape:", y.shape)
            with tf.GradientTape() as tape:
                st_hat, state_FE, state_PE = self.call(
                    x,
                    prev_state_PE=temp_prev_state_PE,
                    prev_state_FE=temp_prev_state_FE)
                loss = self.compiled_loss(y,st_hat)

                print("state_FE.shape",state_FE.shape)
                print("state_PE.shape",state_PE.shape)
                print("----------")
                # carry the states over to the next iteration
                temp_prev_state_FE = state_FE
                temp_prev_state_PE = state_PE

            grads = tape.gradient(loss,self.trainable_variables)
            self.optimizer.apply_gradients(zip(grads,self.trainable_variables))

            # Outside the tape nothing is recorded, so the value can be
            # read directly without tf.stop_gradient.
            loss_acc += loss.numpy()

        return loss_acc




In [9]:
# Sanity check: show the shape of the synthetic input tensor.
X.shape

TensorShape([2048, 16])

In [10]:
# test the SD
# Smoke test: build the stateful decoder, compile it and run the custom
# training loop once on the synthetic data.

# take_prev_phase_state=True is required; the constructor raises otherwise.
mySD =   SequenceDecoder(take_prev_phase_state=True)

# The decoder returns raw logits, hence from_logits=True.
# NOTE(review): the 'accuracy' metric is registered here, but custom_train
# only evaluates the loss and never reports metrics.
mySD.compile(optimizer=Adam(learning_rate=1e-2),
             loss=SparseCategoricalCrossentropy(from_logits=True),
             metrics=['accuracy'])


# Returns the loss summed over the (default: one) training iteration.
mySD.custom_train(X,Y)





iterration : 0
x shape: (2048, 16)
y shape: (2048, 1)
sliced_y.shape (2048, 16)
internal slicer : sliced_y.shape : (2048, 16)
internal slicer : ret.shape : (2048, 14)
SD call estimated_phase.shape (2048, 2)
SD call internally_sliced_y.shape (2048, 14)
internally_sliced_y.shape: (2048, 14)
estimated_phase.shape:  (2048, 2)
state_FE.shape (2048, 8)
state_PE.shape (2048, 8)
----------


2.7816436290740967

In [11]:
# mySD.build((2048,16))