In [None]:
import torch
from torch import nn
from torch import optim

from tqdm import tqdm

import numpy as np
import scipy
from scipy import signal

import matplotlib.pyplot as plt
from datetime import datetime


In [None]:
# Reproducibility: one seed for every RNG, and deterministic cuDNN kernels.
myseed = 42
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

np.random.seed(myseed)
torch.manual_seed(myseed)
torch.cuda.manual_seed_all(myseed)

# Device selection: use CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# Base folder on the mounted Google Drive; only meaningful when running in Colab
# with the (commented-out) drive.mount call above enabled.
PATH = '/content/drive/My Drive/Notebooks/'
# Folder for saved results, derived from PATH.
RESULTS_PATH = PATH + "results/"

import sys

# sys.path.append(PATH)

# Utils

In [None]:
# from utils import *

import torch

## ---- Music -----
import numpy as np
from scipy import signal

def get_steering_vector(array):
    '''
        Build the steering-vector function of a linear array.

        array : 1-D numpy array of sensor positions, e.g. np.arange(0, M)

        Returns a function theta -> np.exp(-1j * pi * positions * sin(theta)),
        where positions is `array` as a column vector and theta is in radians
        (the angle goes straight into np.sin).
    '''
    assert len(array.shape) == 1, 'Must be 1-D array'
    # column vector, so that a 1-D theta broadcasts to (len(array), len(theta))
    positions = array.reshape((-1, 1))

    def steering_vector(theta):
        return np.exp(-1j * np.pi * positions * np.sin((theta)))

    return steering_vector

## ----- MUSIC -----

class MUSIC:
    ''' MUltiple SIgnal Classification evaluated on a fixed grid of candidate angles '''
    def __init__(self, steering_vector, angles_grid):
        '''
            steering_vector   : function mapping angles to steering vectors
            angles_grid       : numpy array of candidate angles
        '''
        self.angles_grid = angles_grid
        # Steering matrix on the grid and its conjugate transpose, computed once.
        self.A_grid = steering_vector(angles_grid)
        self.AH_grid = self.A_grid.conj().T

    def spec(self, R, d):
        '''
            R: covariance matrix
            d: number of sources

            Returns the pseudo-spectrum, one value per grid angle:
            1 / ||a(theta)^H Qn||, where Qn holds the left singular vectors
            of R from index d onwards.
        '''
        left_vectors = np.linalg.svd(R)[0]
        noise_subspace = left_vectors[:, d:]
        projections = self.AH_grid @ noise_subspace
        return 1 / np.linalg.norm(projections, axis=1)

    def estimate(self, R, d):
        '''
            R: covariance matrix
            d: number of sources

            Returns the grid angles of the d highest spectrum peaks in
            ascending order (fewer than d values when fewer peaks exist).
        '''
        spectrum = self.spec(R, d)
        peak_idx = signal.find_peaks(spectrum)[0]
        # order the peaks by height and keep the d tallest
        by_height = peak_idx[np.argsort(spectrum[peak_idx])]
        strongest = by_height[-d:]
        return np.sort(self.angles_grid[strongest])



## -------------------------------------------------

def calc_R(Y, n_snapshots):
    '''Sample covariance of Y: (Y @ Y^H) divided by n_snapshots.'''
    outer = Y @ Y.conj().T
    return outer / n_snapshots

## ---- Loss -----

def permutations(predDoA):
    '''
        Recursively list every ordering of the list predDoA.

        Returns a list of lists; an empty input gives [] and a single-element
        input gives [predDoA] (the input object itself, not a copy).
    '''
    count = len(predDoA)
    if count == 0:
        return []
    if count == 1:
        return [predDoA]

    orderings = []
    for idx, head in enumerate(predDoA):
        # every ordering that starts with `head`
        rest = predDoA[:idx] + predDoA[idx + 1:]
        orderings.extend([head] + tail for tail in permutations(rest))
    return orderings

# Lookup table: PERM[n] lists every ordering of the indices 0..n-1, precomputed for n = 2..6.
PERM = {n:permutations(np.arange(0, n, dtype=np.int32).tolist()) for n in range(2, 6+1)}
# perm = PERM[5]
# Cache of permutation index tensors, keyed by the number of sources.
_PERM_INDEX_CACHE = {}

def _permutation_index(D):
    '''Return a (D!, D) long tensor listing every ordering of range(D), built once per D.'''
    import itertools  # local import keeps this block self-contained
    if D not in _PERM_INDEX_CACHE:
        _PERM_INDEX_CACHE[D] = torch.tensor(
            list(itertools.permutations(range(D))), dtype=torch.long)
    return _PERM_INDEX_CACHE[D]

def rmspe_torch(Y, Y_hat, D):
    '''
        Permutation-invariant root-mean-square periodic error.

        Y     : true angles, shape (B, D) or (D,)
        Y_hat : predicted angles, shape (B, D) or (D,)
        D     : number of sources (any integer >= 1)

        For each row the error is wrapped into [-pi/2, pi/2), the RMS over the
        D sources is taken for every ordering of Y, and the smallest value is
        kept. Returns a tensor of shape (B,).

        Raises Exception if an input is not 1-D or 2-D, ValueError if D < 1.
    '''
    if Y.ndim != 2:
        if Y.ndim == 1:     Y = Y.unsqueeze(0)
        else:               raise Exception('dim must be 2 or 1')
    if Y_hat.ndim != 2:
        if Y_hat.ndim == 1:     Y_hat = Y_hat.unsqueeze(0)
        else:                   raise Exception('dim must be 2 or 1')
    if D < 1:
        raise ValueError('D must be at least 1')

    # Orderings are generated on demand, so D is not limited to a precomputed table.
    perm = _permutation_index(D).to(Y.device)
    # (B, P, D) -> (P, B, D): every ordering of every row of Y
    Y_perm = Y[:, perm].permute(1, 0, 2)
    e = Y_perm - Y_hat

    PI = torch.pi
    # wrap the angular error into [-pi/2, pi/2)
    e = torch.remainder(e + PI/2, PI) - PI/2
    # RMS over the D sources: ||e|| / sqrt(D) equals sqrt(mean(e^2))
    e = torch.norm(e, dim=2) / (D ** 0.5)
    # best ordering per row
    e, _ = torch.min(e, dim=0)
    return e

def rmspe_numpy(Y, Y_hat, D):
    '''NumPy front end of rmspe_torch: converts both inputs to tensors and returns a numpy array.'''
    targets, predictions = torch.tensor(Y), torch.tensor(Y_hat)
    return rmspe_torch(targets, predictions, D).numpy()

def mse(y, y_hat):
    '''Mean of the squared element-wise difference between y and y_hat.'''
    residual = y - y_hat
    return np.mean(residual ** 2)

## ---- Base Model -----

import torch
from torch import nn
from torch import optim

import matplotlib.pyplot as plt
from datetime import datetime

class Model(nn.Module):
    '''
        Shared training scaffold for the networks in this notebook.

        Holds the learning rate, target device, optimizer, epoch counter and
        loss history. Subclasses supply the layers and implement fit_epoch().
    '''
    def __init__(self, lr, device=None):
        '''
            lr     : learning rate handed to the optimizer
            device : device name / torch.device; None picks 'cuda:0' when
                     CUDA is available and 'cpu' otherwise
        '''
        super().__init__()

        self.lr = lr

        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.gpu_is_available = self.device.type == 'cuda'

        # Defined up front so change_lr() / load() also work before
        # configure_optimizer() has been called.
        self.optim = None

        self.epoch = 0
        self.history = {'train_loss' : [], 'val_loss':[]}

    def configure_optimizer(self, optimizer):
        '''Instantiate `optimizer` (an optimizer class) over this model's parameters with self.lr.'''
        self.optim = optimizer(self.parameters(), lr=self.lr)

    def prepare_batch(self, batch):
        '''Return `batch` on self.device, moving it only when it is elsewhere.'''
        if batch.device != self.device:
            batch = batch.to(self.device)
        return batch

    def change_lr(self, lr):
        '''Store a new learning rate and push it into the optimizer, if one is configured.'''
        self.lr = lr
        if self.optim is not None:
            for group in self.optim.param_groups:
                group['lr'] = self.lr

    def total_number_of_params(self):
        '''Total number of elements over all parameters.'''
        return sum(p.numel() for p in self.parameters())

    def forward(self, X):
        '''Default forward pass through self.net, which the subclass must define.'''
        y = self.net(X)
        return y

    def fit(self, max_epoch, train_dataloader, val_dataloader=None, lr_scheduler=None):
        '''
            Run fit_epoch() max_epoch times.

            The loaders and the scheduler are stored on self for fit_epoch()
            to use. self.epoch keeps counting across calls, and the latest
            training loss is printed every 5th epoch.
        '''
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.lr_scheduler = lr_scheduler

        for _ in range(max_epoch):
            self.epoch += 1

            self.fit_epoch()

            if self.epoch % 5 == 0:
                print(f"{self.epoch:<3} | Loss : {self.history['train_loss'][-1]:.8f}")

    def fit_epoch(self):
        '''One pass over the data; must append to self.history. Implemented by subclasses.'''
        raise NotImplementedError()

    def plot(self):
        '''Plot the training loss history, plus the validation loss when any was recorded.'''
        plt.figure(figsize=(5, 3))
        plt.plot(self.history['train_loss'], label='train_loss')
        if self.history['val_loss']:
            plt.plot(self.history['val_loss'], label='val_loss')
        plt.grid()
        plt.xlabel('Epoch')
        plt.ylabel('loss')
        plt.legend()

    def save(self, fname=None):
        '''
            Save weights, history, learning rate and epoch with torch.save.

            fname : target path; None uses '<ClassName>_<YYYY-MM-DD_HH-MM>'
                    in the current directory.
        '''
        if fname is None:
            # %m is the month and %M the minute
            stamp = datetime.now().strftime('%Y-%m-%d_%H-%M')
            fname = self.__class__.__name__ + '_' + stamp
        torch.save({
            'model_state_dict': self.state_dict(),
            'model_history': self.history,
            'model_lr': self.lr,
            'model_epoch' : self.epoch
        }, fname)

    def load(self, path):
        '''Restore weights, history, learning rate and epoch from a file written by save().'''
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(path, weights_only=False, map_location=torch.device('cpu'))

        self.load_state_dict(checkpoint['model_state_dict'])
        self.to(self.device)

        self.history = checkpoint['model_history']
        self.change_lr(checkpoint['model_lr'])
        self.epoch = checkpoint['model_epoch']
        print('LOAD :', path)

def plot_results(x_data, results_dict_list, fig, ax, figsize=(10, 3), black=[]):
    '''
        Draw every result curve on two axes: ax[0] linear, ax[1] log-y.

        x_data            : shared x values
        results_dict_list : list of {label: y values} dicts
        fig, ax           : figure to resize and a pair of axes to draw on
        figsize           : (width, height) applied to fig
        black             : labels to leave out

        Empty curves are skipped. Labels starting with 'model' use '-^'
        markers, all others '-o'.
    '''
    width, height = figsize[0], figsize[1]
    fig.set_figwidth(width)
    fig.set_figheight(height)

    linear_ax, log_ax = ax[0], ax[1]
    for results in results_dict_list:
        for label in results:
            curve = results[label]
            if curve == [] or label in black:
                continue

            if label.startswith('model'):
                marker = '-^'
            else:
                marker = '-o'

            linear_ax.plot(x_data, curve, marker, label=label)
            log_ax.semilogy(x_data, curve, marker, label=label)

    linear_ax.legend()
    linear_ax.grid()
    log_ax.legend()
    log_ax.grid()

## -------------------------------------------------
SQRT2 = np.sqrt(2)
def complex_randn(sigma, shape):
    '''
        Complex Gaussian samples of the given shape.

        Real and imaginary parts are independent standard normal draws
        (real part drawn first), and the sum is scaled by sigma / sqrt(2).
    '''
    real_part = np.random.randn(*shape)
    imag_part = np.random.randn(*shape)
    return (sigma / SQRT2) * (real_part + 1j * imag_part)

def snrdb2snr(snr_db):
    '''Convert a ratio from decibels to linear scale: 10 ** (snr_db / 10).'''
    exponent = snr_db / 10
    return 10 ** exponent

def loss_varD(Y, Y_hat, X_d):
    '''
        Mean rmspe_torch error over a batch with a per-sample source count.

        For each sample only the first d entries of the target and of the
        prediction are compared, d being that sample's entry of X_d
        (read with .item(), so X_d must hold tensor elements).
    '''
    per_sample = []
    for target, prediction, n_src in zip(Y, Y_hat, X_d):
        per_sample.append(rmspe_torch(target[:n_src], prediction[:n_src], n_src.item()))
    return torch.mean(torch.cat(per_sample))

def total_number_of_params(m):
    '''Count the elements of every parameter tensor of module m.'''
    count = 0
    for param in m.parameters():
        count += param.numel()
    return count

In [None]:
class DOA:
    '''
        Signal simulator for an M-sensor array.

        Generates source waveforms S, noiseless array outputs X = A(theta) @ S
        (A scaled by its norm) and noisy observations Y with Gaussian or
        alpha-stable noise.
    '''
    def __init__(self, M, array, n_snapshots, angels_grid, steering_vector=None) -> None:
        '''
            M               : number of sensors
            array           : 1-D sensor positions, used to build the steering
                              function when none is supplied
            n_snapshots     : number of time samples per realisation
            angels_grid     : candidate angles, stored as self.angles_grid
            steering_vector : optional function theta -> steering matrix
        '''
        self.M = M
        self.n_snapshots = n_snapshots
        self.steering_vector = (
            get_steering_vector(array=array) if steering_vector is None else steering_vector
        )
        self.angles_grid = angels_grid
        self.A_grid = self.steering_vector(self.angles_grid)

    @staticmethod
    def _stable_noise(alpha, gamma, shape):
        '''Complex symmetric alpha-stable noise: independent real and imaginary parts with scale gamma.'''
        real_part = scipy.stats.levy_stable.rvs(alpha, 0, scale=gamma, size=shape)
        imag_part = scipy.stats.levy_stable.rvs(alpha, 0, scale=gamma, size=shape)
        return real_part + 1j * imag_part

    def construct_signal(self, theta, snr, alpha, S=None, X=None):
        '''
            Build (S, X, Y) with unit-scale alpha-stable noise.

            S is drawn with scale 10 ** (snr / 10) when neither S nor X is
            given; X is computed from S when X is not given.
        '''
        n_src = theta.shape[0]

        if X is None:
            if S is None:
                S = complex_randn(10 ** (snr / 10), (n_src, self.n_snapshots))
            X = self.calc_X(S, theta)

        Y = X + self._stable_noise(alpha, 1., X.shape)
        return S, X, Y

    def construct_signal_noise_normal(self, theta, snr, S=None, X=None):
        '''Same as construct_signal, but with unit-sigma complex Gaussian noise of shape (M, n_snapshots).'''
        n_src = theta.shape[0]

        if X is None:
            if S is None:
                S = complex_randn(10 ** (snr / 10), (n_src, self.n_snapshots))
            X = self.calc_X(S, theta)

        noise = complex_randn(1., (self.M, self.n_snapshots))
        return S, X, X + noise

    def generate_S_equal_power(self, n_src):
        '''Draw n_src unit-sigma complex Gaussian source waveforms.'''
        return complex_randn(1, (n_src, self.n_snapshots))

    def generate_S_diff_power(self, n_src, etta_db, S=None):
        '''Scale each source (drawn here when S is None) by sqrt of its power etta_db, given in dB.'''
        if S is None:
            S = complex_randn(1, (n_src, self.n_snapshots))
        power = 10 ** (np.array(etta_db) / 10)
        return np.diag(np.sqrt(power)) @ S

    def generate_S_coherent_sources(self, n_src, n_coh, coh_idx=None, coh_etta=None, perm=None):
        '''
            Draw n_src sources of which n_coh are scaled copies of independent ones.

            coh_idx  : indices of the independent sources to copy (random when None)
            coh_etta : complex gain per copy (random magnitude in [0.1, 0.9)
                       and random phase when None)
            perm     : row order of the result (random permutation when None)
        '''
        S_indep = self.generate_S_equal_power(n_src - n_coh)
        if n_coh == 0:
            return S_indep

        if coh_idx is None:
            coh_idx = np.random.randint(0, n_src-n_coh, n_coh)
        if coh_etta is None:
            magnitude = np.random.rand(n_coh) * .8 + .1
            coh_etta = magnitude * np.exp(-1j * np.random.rand(n_coh) * np.pi)

        S_coh = np.diag(coh_etta) @ S_indep[coh_idx]
        S = np.concatenate((S_indep, S_coh), axis=0)

        order = np.random.permutation(S.shape[0]) if perm is None else perm
        return S[order]

    def calc_X(self, S, theta):
        '''Noiseless array output: the steering matrix for theta, divided by its norm, times S.'''
        A = self.steering_vector(theta)
        return (A / np.linalg.norm(A)) @ S

    @staticmethod
    def calc_Y_snr_n(X, snr_db):
        '''Add complex Gaussian noise with sigma = sqrt(1 / snr), snr being snr_db in linear scale.'''
        sigma = np.sqrt(1 / (10 ** (snr_db / 10)))
        return X + complex_randn(sigma, X.shape)

    @staticmethod
    def calc_Y_snr_s(X):
        '''Add unit-sigma complex Gaussian noise.'''
        return X + complex_randn(1, X.shape)

    @staticmethod
    def calc_Y_gsnr_n(X, gsnr_db, alpha):
        '''Add alpha-stable noise with scale 1 / gsnr, gsnr being gsnr_db in linear scale.'''
        gamma = 1 / (10 ** (gsnr_db / 10))
        return X + DOA._stable_noise(alpha, gamma, X.shape)

    @staticmethod
    def calc_Y_gsnr_s(X, alpha):
        '''Add alpha-stable noise with unit scale.'''
        return X + DOA._stable_noise(alpha, 1., X.shape)




In [None]:
# Array / scenario configuration shared by the cells below.
N_SRC = 5
# Number of sensors.
M = 8
# Time samples per realisation.
SNAPSHOTS = 200

# Sensor positions 0..M-1 passed to DOA.
array = np.arange(0, M)
# Candidate angles: -90..90 degrees in 1-degree steps (181 points), converted to radians.
angles_grid = np.arange(-90., 90. + 1e-6, 1)
angles_grid = np.deg2rad(angles_grid)
N_ANGLES_GRID = angles_grid.shape[0]

# Signal simulator for this array.
doa = DOA(M=M , array=array , n_snapshots=SNAPSHOTS , angels_grid=angles_grid)

# Classical MUSIC estimator on the same grid and steering function.
music = MUSIC(doa.steering_vector, doa.angles_grid)

# Range of source counts; N_SRC_MAX also sets the width of the label array in Create_Dataset
# and the output size of the networks built below.
N_SRC_MIN = 2
N_SRC_MAX = 5
# N_SRC_MAX = 2 # diff power
BATCH_SIZE = 64

In [None]:
def flos(Y, d, p=1.1):
    '''
        DoA estimation with MUSIC on a fractional lower-order covariance matrix.

        Y : complex array output, shape (sensors, snapshots)
        d : number of sources
        p : fractional order used in the weights |y| ** (p - 2)

        Builds Rf[i, k] = mean_t  y_i(t) * |y_k(t)| ** (p - 2) * conj(y_k(t))
        and runs the module-level `music` estimator on it. Returns d angles,
        or np.zeros(d) when MUSIC does not find exactly d peaks.
    '''
    # Use the snapshot count of Y itself rather than a global constant.
    n_snapshots = Y.shape[1]

    magnitude = np.abs(Y.T)
    # |y| ** (p - 2) is infinite at y == 0 for p < 2; the weighted term
    # y * |y| ** (p - 2) is multiplied by conj(y) = 0 there, so use weight 0
    # instead of producing inf * 0 = nan.
    weights = np.power(magnitude, p - 2, out=np.zeros_like(magnitude), where=magnitude > 0)

    Rf = (Y @ (weights * Y.T.conj())) / n_snapshots
    theta_hat = music.estimate(Rf, d)
    if theta_hat.shape[0] != d:
        theta_hat = np.zeros(d)
    return theta_hat

# Data

In [None]:
# data class
class Doa_Data:
    '''
        Scenario recipes built on the module-level `doa` simulator.

        Every method draws the sources, forms the noiseless output X for the
        given angles and returns (X, Y) with Y the noisy observation.
    '''
    def __init__(self):
        pass

    @staticmethod
    def gsnr_eq_pow(d, theta, gsnr, alpha):
        '''Equal-power sources, alpha-stable noise whose scale is set by gsnr (dB).'''
        sources = doa.generate_S_equal_power(d)
        clean = doa.calc_X(sources, theta)
        return clean, doa.calc_Y_gsnr_n(clean, gsnr, alpha)

    @staticmethod
    def snr_eq_pow(d, theta, snr):
        '''Equal-power sources, Gaussian noise whose sigma is set by snr (dB).'''
        sources = doa.generate_S_equal_power(d)
        clean = doa.calc_X(sources, theta)
        return clean, doa.calc_Y_snr_n(clean, snr)

    @staticmethod
    def gsnr_diff_pow(d, theta, etta, alpha):
        '''Sources with per-source power etta (dB), unit-scale alpha-stable noise.'''
        sources = doa.generate_S_diff_power(d, etta)
        clean = doa.calc_X(sources, theta)
        return clean, doa.calc_Y_gsnr_s(clean, alpha)

    @staticmethod
    def gsnr_eq_pow_coherent(n_src, n_coh, theta, gsnr, alpha):
        '''n_src sources of which n_coh are coherent copies, alpha-stable noise set by gsnr (dB).'''
        sources = doa.generate_S_coherent_sources(n_src, n_coh, coh_idx=None, coh_etta=None, perm=None)
        clean = doa.calc_X(sources, theta)
        return clean, doa.calc_Y_gsnr_n(clean, gsnr, alpha)

In [None]:
# Re-seed NumPy and PyTorch (CPU and all CUDA devices) so that the dataset
# generation below starts from a known RNG state.
np.random.seed(myseed)
torch.manual_seed(myseed)
torch.cuda.manual_seed_all(myseed)

class Create_Dataset:
    '''
        Simulate array snapshots for DoA estimation and split them into
        train / validation / test sets.

        Per sample the dataset stores
            X  : noisy observation, real and imaginary parts stacked -> (2*M, T)
            d  : number of sources
            Rx : covariance of the noiseless output, real / imag -> (2, M, M)
            Y  : source angles in radians, padded with pi up to N_SRC_MAX

        NOTE: the samples are produced by Doa_Data, which reads the
        module-level `doa`; pass that same object as the `doa` argument.
    '''
    def __init__(self, n_train, n_val, n_test,
                doa,
                eq_pow = True,
                gaussian_noise = False,
                d = 5,
                snr=None,
                alpha=None, gsnr=None,
                is_coherent=False):
        '''
            n_train, n_val, n_test : split sizes
            doa            : simulator supplying n_snapshots and M
            eq_pow         : True for equal-power sources, False for a random
                             power per source drawn from {5, 10, 15} dB
            gaussian_noise : True for Gaussian noise (needs snr), False for
                             alpha-stable noise (needs gsnr and alpha)
            d              : number of sources in every sample (<= N_SRC_MAX)
            snr            : SNR in dB for the Gaussian case
            alpha          : alpha value, or list of values to draw from
                             uniformly per sample
            gsnr           : GSNR in dB for the equal-power alpha-stable case
            is_coherent    : with equal power and alpha-stable noise, make 0
                             or 1 source (random per sample) a coherent copy

            Raises Exception for missing noise parameters, ValueError when
            d > N_SRC_MAX, NotImplementedError for different-power sources
            with Gaussian noise.
        '''
        size = n_train + n_val + n_test

        doa_data = Doa_Data()

        snapshots = doa.n_snapshots
        n_sensors = doa.M

        if d > N_SRC_MAX:
            raise ValueError('d must not exceed N_SRC_MAX')

        # (size, 2*M, T): real part stacked on top of the imaginary part
        data_X = np.zeros((size, 2*n_sensors, snapshots), dtype=np.float32)
        data_d = np.ones(size, dtype=np.int32) * d

        data_Rx = np.zeros((size, 2, n_sensors, n_sensors), dtype=np.float32)
        # unused label slots keep the padding value pi
        data_Y = np.ones((size, N_SRC_MAX), dtype=np.float32) * np.pi

        if eq_pow:
            print('Eq Power')
            if is_coherent:
                print('**** COHERENT ****')
        else:
            print('Diff Power')

        alpha_per_sample = None
        etta = None
        if gaussian_noise:
            if snr is None:    raise Exception('snr is none')
            print('Gaussian Noise')
            if not eq_pow:
                raise NotImplementedError('different-power sources with Gaussian noise')
        else:
            print('Alpha Stable Noise')
            if gsnr is None:    raise Exception('gsnr is none')
            if alpha is None:   raise Exception('alpha is none')

            # One alpha per sample, drawn uniformly from the values passed in.
            alpha_choices = np.atleast_1d(np.asarray(alpha, dtype=float))
            alpha_per_sample = alpha_choices[np.random.randint(0, len(alpha_choices), size)]

            if not eq_pow:
                # Per-source power in dB, drawn independently for every sample.
                GSNR = np.array([5., 10., 15.])
                etta = GSNR[np.random.randint(0, len(GSNR), (size, d))]

        for i in tqdm(range(size)):
            n_src = data_d[i]
            theta = np.pi * (np.random.rand(n_src) - .5)

            if eq_pow:
                if gaussian_noise:
                    X, Y = doa_data.snr_eq_pow(n_src, theta, snr)
                elif is_coherent:
                    # zero or one coherent source among the n_src sources
                    n_coh = np.random.randint(0, 2)
                    X, Y = doa_data.gsnr_eq_pow_coherent(n_src, n_coh, theta, gsnr, alpha_per_sample[i])
                else:
                    X, Y = doa_data.gsnr_eq_pow(n_src, theta, gsnr, alpha_per_sample[i])
            else:
                X, Y = doa_data.gsnr_diff_pow(n_src, theta, etta[i], alpha_per_sample[i])

            Rx = calc_R(X, snapshots)

            data_X[i] = np.concatenate((np.real(Y), np.imag(Y)), axis=0)
            data_Rx[i][0] = np.real(Rx); data_Rx[i][1] = np.imag(Rx)
            data_Y[i][:n_src] = theta

        data_X = torch.tensor(data_X)
        data_d = torch.tensor(data_d, dtype=torch.long)
        data_Rx = torch.tensor(data_Rx)
        data_Y = torch.tensor(data_Y)
        print(data_X.shape, data_Y.shape)

        # random split into train / val / test
        indices = np.random.permutation(data_X.shape[0])
        train_indices = indices[:n_train]
        val_indices = indices[n_train: n_train + n_val]
        test_indices = indices[n_train + n_val:]

        self.traindata_X = data_X[train_indices]
        self.traindata_d = data_d[train_indices]
        self.traindata_Rx = data_Rx[train_indices]
        self.traindata_Y = data_Y[train_indices]

        self.val_X = data_X[val_indices]
        self.val_d = data_d[val_indices]
        self.val_Rx = data_Rx[val_indices]
        self.val_Y = data_Y[val_indices]

        self.test_X = data_X[test_indices]
        self.test_d = data_d[test_indices]
        self.test_Rx = data_Rx[test_indices]
        self.test_Y = data_Y[test_indices]

        print(f"train X: {self.traindata_X.shape}, train Y: {self.traindata_Y.shape}")
        print(f"val X: {self.val_X.shape}, val Y: {self.val_Y.shape}")
        print(f"test X: {self.test_X.shape}, test Y: {self.test_Y.shape}")

    @staticmethod
    def _make_loader(tensors, batch_size, shuffle):
        '''Wrap (X, d, Rx, Y) tensors in a DataLoader without copying them into a Python list.'''
        dataset = torch.utils.data.TensorDataset(*tensors)
        return torch.utils.data.DataLoader(dataset, batch_size, shuffle=shuffle)

    def get_train_dataloader(self, batch_size, shuffle=True):
        '''DataLoader over the training split; each batch is (X, d, Rx, Y).'''
        return self._make_loader(
            (self.traindata_X, self.traindata_d, self.traindata_Rx, self.traindata_Y),
            batch_size, shuffle)

    def get_val_dataloader(self, batch_size, shuffle=False):
        '''DataLoader over the validation split; each batch is (X, d, Rx, Y).'''
        return self._make_loader(
            (self.val_X, self.val_d, self.val_Rx, self.val_Y),
            batch_size, shuffle)

    def get_test_dataloader(self, batch_size, shuffle=False):
        '''DataLoader over the test split; each batch is (X, d, Rx, Y).'''
        return self._make_loader(
            (self.test_X, self.test_d, self.test_Rx, self.test_Y),
            batch_size, shuffle)

In [None]:
# Scenario 1: equal-power sources in alpha-stable noise, 5 sources per sample.
# NOTE: this cell and the two dataset cells after it all assign `da_music_data`;
# run only the one matching the experiment (the last one executed wins).
doa.n_snapshots = SNAPSHOTS

# Split sizes.
N_TRAIN = 150_000 # 50_000
N_VAL = 32_000
N_TEST = 32_000

N_SRC = 5;  D = N_SRC


## ALPHA STABLE (Eq Pow)
da_music_data = Create_Dataset(N_TRAIN, N_VAL, N_TEST, doa,
                eq_pow = True,
                gaussian_noise = False,
                d = 5,
                snr=None,
                gsnr=10., alpha=[1.5, 2])


In [None]:
# Scenario 2: two sources with different powers in alpha-stable noise.
# Rebinds `da_music_data`, replacing any dataset created by another dataset cell.
doa.n_snapshots = SNAPSHOTS

# Split sizes.
N_TRAIN = 100_000
N_VAL = 21_000
N_TEST = 21_000
# size = N_TRAIN + 25_000

# '''
#     "N_SRC_MAX" must be changed to 2
# '''
## ALPHA STABLE - Diff Pow
da_music_data = Create_Dataset(N_TRAIN, N_VAL, N_TEST, doa,
                eq_pow = False,
                gaussian_noise = False,
                d = 2,
                snr=None,
                gsnr=10., alpha=[1.5, 2])


In [None]:
# Scenario 3: equal-power sources with coherent copies enabled, alpha-stable noise.
# Rebinds `da_music_data`, replacing any dataset created by another dataset cell.
doa.n_snapshots = SNAPSHOTS

# N_TRAIN = 100_000 # 50_000
# size = N_TRAIN + 25_000

# Split sizes.
N_TRAIN = 150_000 # 50_000
N_VAL = 32_000
N_TEST = 32_000
# size = N_TRAIN + 30_000

N_SRC = 5;  D = N_SRC

## ALPHA STABLE - Coh
da_music_data = Create_Dataset(N_TRAIN, N_VAL, N_TEST, doa,
                eq_pow = True, is_coherent=True,
                gaussian_noise = False,
                d = 5,
                snr=None,
                gsnr=10., alpha=[1.5, 2])


In [None]:
# Build the loaders from whichever dataset `da_music_data` currently refers to.
# The test loader uses twice the training batch size.
print(BATCH_SIZE)
train_dataloader = da_music_data.get_train_dataloader(BATCH_SIZE)
val_dataloader = da_music_data.get_val_dataloader(BATCH_SIZE)
test_dataloader = da_music_data.get_test_dataloader(BATCH_SIZE * 2)

In [None]:
# X, X_d, X_R, Y = next(iter(train_dataloader))
# X.shape, X_d.shape, X_R.shape, Y.shape

In [None]:
# DATE = '0000-00-00'

# def save_model(date, model, model_name, n_traindata, batch_size, save1=True, save2=True):
#     fname = '{date}_{model_name}_{n_traindata}_B{batchsize}_E{epoch}_lr{lr}.pth'.format(
#         date=date, model_name=model_name, n_traindata=n_traindata, batchsize=batch_size, epoch=model.epoch, lr=model.lr
#     )
#     print(fname)
#     if save1:   model.save(fname)
#     else: print('**** 1-NOT SAVED ****')
#     if save2:   model.save(RESULTS_PATH + fname)
#     else: print('**** 2-NOT SAVED ****')

# # save_model(DATE, da_music, 'nn', '50k', BATCH_SIZE) # Ex

# Train

## DA_MUSIC

In [None]:
## DA MUSIC
class DA_Music_VarD(Model):
    def __init__(self, lr, optimizer, M, n_angles_grid, A_grid, n_src_max):
        super().__init__(lr)

        self.M = M
        # M : num sensors
        self.BatchNorm = nn.BatchNorm1d(2*M)
        self.gru = nn.GRU(2*M, 2*M, batch_first=True)
        self.fc = nn.Linear(2*M, 2*M*M)

        spec_len = n_angles_grid
        n = 2*M
        self.net = nn.Sequential(
            nn.Linear(spec_len, n), nn.ReLU(),
            nn.Linear(n, n), nn.ReLU(),
            nn.Linear(n, n), nn.ReLU(),
            nn.Linear(n, n_src_max)
        )

        self.A = torch.tensor(A_grid, dtype=torch.complex64, device=self.device)
        self.A = self.A.swapaxes(0, 1).unsqueeze(2)
        self.AH = self.A.swapaxes(1, 2).conj()

        self.configure_optimizer(optimizer)
        self.to(self.device)

    def loss(self, Y, Y_hat, X_d):
        return loss_varD(Y.cpu(), Y_hat.cpu(), X_d).to(self.device)

    def forward(self, X, X_d):
        M = self.M
        # X.shape = (B, M, T)
        # X.shape = (B, 2*M, T)     cat(real, imag)
        # ----------X.shape = (B, T, 2*M)
        X = self.BatchNorm(X).swapaxes(1, 2)    # (B, T, 2*M)

        y, _ = self.gru(X)
        # print(y.shape)
        y = y[:, -1, :]
        y = self.fc(y)
        y = y.reshape((-1, 2*M, M))
        y = torch.complex(y[:, :M, :], y[:, M:, :])

        Q, _, _ = torch.svd(y)
        ## Qn = Q[:, :, D:]
        # print(Qn.shape)

        # QnQnH = torch.cat([(q @ q.conj().T).unsqueeze(0) for q in Qn])
        QnQnH = torch.cat([(Q[i][:, X_d[i]:] @ Q[i][:, X_d[i]:].conj().T).unsqueeze(0) for i in range(Q.shape[0])])
        QnQnH = QnQnH.unsqueeze(1)

        spec = self.AH @ QnQnH @ self.A
        spec = spec.squeeze()
        spec = 1 / torch.abs(spec)
        # find doa
        y = self.net(spec)
        return y

    def fit_epoch(self):
        self.train()
        epoch_loss_avg = 0
        for X, X_d, X_R, Y in tqdm(self.train_dataloader):
            if self.gpu_is_available:
                X = X.to(self.device)
                Y = Y.to(self.device)

            Y_hat = self(X, X_d)
            # loss = self.loss(Y, Y_hat)
            loss = self.loss(Y, Y_hat, X_d)
            # loss = torch.mean((Y - Y_hat) ** 2)
            epoch_loss_avg += loss.detach().item()

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        self.history['train_loss'].append(epoch_loss_avg / len(self.train_dataloader))

        ## change lr
        if self.lr_scheduler:
            # self.lr = self.lr_scheduler(self.lr, self.epoch)
            # self.optim.param_groups[0]['lr'] = self.lr
            self.change_lr(self.lr_scheduler(self.lr, self.epoch))

        if self.val_dataloader != None:
            with torch.no_grad():
                self.eval()
                epoch_loss_avg = 0
                for X, X_d, X_R, Y in self.val_dataloader:
                    if self.gpu_is_available:
                        X = X.to(self.device)
                        Y = Y.to(self.device)
                    Y_hat = self(X, X_d)

                    loss = self.loss(Y, Y_hat, X_d)

                    epoch_loss_avg += loss.detach().item()

                self.history['val_loss'].append(epoch_loss_avg / len(self.val_dataloader))

    def predict(self, Y, Y_d):
        ## Y complex Array output
        Ynet = torch.tensor(
                np.concatenate((np.real(Y), np.imag(Y)), axis=0),
                dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        # ).swapaxes(0, 1).unsqueeze(0)
        self.eval()
        theta_hat = self(Ynet, Y_d).cpu().detach().numpy()
        return theta_hat



In [None]:
# Instantiate DA-MUSIC with Adam, the simulator's steering matrix and N_SRC_MAX outputs.
lr = 0.001
n_angles_grid=angles_grid.shape[0]; print(n_angles_grid)
A_grid = doa.A_grid
da_music = DA_Music_VarD(lr=lr, optimizer=torch.optim.Adam,
                        M=M, n_angles_grid=n_angles_grid, A_grid=A_grid,
                        n_src_max=N_SRC_MAX
)

In [None]:
# Train for 10 epochs; re-running this cell continues from the current epoch count.
da_music.fit(10, train_dataloader, val_dataloader)

In [None]:

# Loss curves recorded in da_music.history.
da_music.plot()

## DCNTA-MUSIC

In [None]:
class NSFE_Block(nn.Module): # Noise Suppression and Feature Extraction Block
    '''Two length-preserving 1-D convolutions (kernel 3, padding 1): Conv -> Tanh -> Conv -> ReLU.'''
    def __init__(self, c_in, c_out):
        '''
            c_in  : input channels
            c_out : output channels of both convolutions
        '''
        super().__init__()

        layers = [
            nn.Conv1d(c_in, c_out, 3, padding=1),
            nn.Tanh(),
            nn.Conv1d(c_out, c_out, 3, padding=1),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        '''Apply the block to x of shape (B, c_in, T); the result has shape (B, c_out, T).'''
        out = self.net(x)
        return out

# c1 = nn.Conv1d(2*M, 4*M, 3, padding=1)
# NOTE(review): these two module-level names are not referenced in the cells visible here.
c_in, c_out = 2*M, 2*M
# v15
class First_Block(nn.Module):
    '''
        Convolutional front end: maps stacked observations (B, 2*M, T) to a
        flat vector of 2*M*M values per sample. M is the module-level sensor count.
    '''
    def __init__(self):
        super().__init__()

        c_in = 2*M
        c_mid = 2*c_in      # 4*M channels
        c_top = 4*c_in      # 8*M channels
        self.net0 = nn.Sequential(
            NSFE_Block(c_in, c_mid),     # (B, 4*M, T)
            nn.MaxPool1d(2),             # (B, 4*M, T/2)
            NSFE_Block(c_mid, c_top),    # (B, 8*M, T/2)
            nn.AdaptiveAvgPool1d(1),     # (B, 8*M, 1)
        )
        self.net1 = nn.Sequential(
            nn.Linear(c_top, 2*M*M),     # (B, 2*M*M)
        )

    def forward(self, x):
        '''Extract features, drop the pooled time axis and project to 2*M*M values.'''
        features = self.net0(x).squeeze(-1)
        return self.net1(features)

In [None]:
class DCNTA_MUSIC(Model):
    '''
        MUSIC variant with a convolutional front end.

        Pipeline of forward(): First_Block -> complex M x M matrix -> SVD ->
        noise-subspace projector -> MUSIC spectrum on the angle grid -> MLP
        ending in Tanh, so the outputs lie in [-1, 1]. Training compares them
        with the true angles divided by pi/2.
    '''
    # Targets in radians are divided by this before the loss.
    _ANGLE_SCALE = np.pi / 2.0

    def __init__(self, lr, optimizer, M, n_angles_grid, A_grid, n_src_max):
        '''
            lr            : learning rate
            optimizer     : optimizer class, instantiated via configure_optimizer
            M             : number of sensors (First_Block is built from the
                            module-level M, so the two must agree)
            n_angles_grid : number of grid angles (input width of the MLP)
            A_grid        : steering matrix on the grid
                            (assumed shape (M, n_angles_grid) - TODO confirm)
            n_src_max     : number of angle outputs
        '''
        super().__init__(lr)

        # M : num sensors
        self.M = M

        self.net0 = First_Block()

        spec_len = n_angles_grid
        n = 2*M
        self.net = nn.Sequential(
            nn.Linear(spec_len, n), nn.ReLU(),
            nn.Linear(n, n), nn.ReLU(),
            nn.Linear(n, n), nn.ReLU(),
            nn.Linear(n, n_src_max),
            nn.Tanh()
        )

        # Steering vectors as a stack of columns, and their conjugate transposes.
        self.A = torch.tensor(A_grid, dtype=torch.complex64, device=self.device)
        self.A = self.A.swapaxes(0, 1).unsqueeze(2)
        self.AH = self.A.swapaxes(1, 2).conj()

        self.configure_optimizer(optimizer)
        self.to(self.device)

    def loss(self, Y, Y_hat, X_d):
        '''Batch loss via loss_varD, computed on the CPU and moved back to self.device.'''
        return loss_varD(Y.cpu(), Y_hat.cpu(), X_d).to(self.device)

    def forward(self, X, X_d):
        '''
            X   : observations, real and imaginary parts stacked, (B, 2*M, T)
            X_d : number of sources per sample (indexable, length B)

            Returns the Tanh-bounded output of the final MLP, n_src_max values
            per sample.
        '''
        M = self.M

        y = self.net0(X)                                # (B, 2*M*M)
        # first M rows are the real part, last M rows the imaginary part
        y = y.reshape((-1, 2*M, M))                     # (B, 2*M, M)
        y = torch.complex(y[:, :M, :], y[:, M:, :])     # (B, M, M)

        Q, _, _ = torch.svd(y)

        # per-sample noise-subspace projector: columns X_d[i] onwards of Q[i]
        QnQnH = torch.cat([(Q[i][:, X_d[i]:] @ Q[i][:, X_d[i]:].conj().T).unsqueeze(0) for i in range(Q.shape[0])])
        QnQnH = QnQnH.unsqueeze(1)

        # a^H Qn Qn^H a for every grid angle, then inverted to give the spectrum
        spec = self.AH @ QnQnH @ self.A
        # NOTE(review): squeeze() drops every size-1 dim, so a batch of one
        # also loses its batch axis here.
        spec = spec.squeeze()
        spec = 1 / torch.abs(spec)
        # find doa
        y = self.net(spec)
        return y

    def _batch_loss(self, X, X_d, Y):
        '''Forward one batch and return its loss against the targets rescaled to [-1, 1].'''
        if self.gpu_is_available:
            X = X.to(self.device)
            Y = Y.to(self.device)

        Y_hat = self(X, X_d)
        # Out-of-place rescaling: the batch tensor handed in is left untouched.
        # NOTE(review): rmspe_torch wraps errors with period pi, which matches
        # angles in radians; confirm that period is intended for targets
        # rescaled by pi/2.
        return self.loss(Y / self._ANGLE_SCALE, Y_hat, X_d)

    def fit_epoch(self):
        '''
            One training pass, an optional learning-rate update, then one
            validation pass. Appends the mean batch loss of each pass to
            self.history.
        '''
        self.train()
        epoch_loss_avg = 0
        for X, X_d, X_R, Y in tqdm(self.train_dataloader):
            loss = self._batch_loss(X, X_d, Y)
            epoch_loss_avg += loss.detach().item()

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        self.history['train_loss'].append(epoch_loss_avg / len(self.train_dataloader))

        ## change lr
        # the scheduler is a function (current lr, epoch) -> new lr
        if self.lr_scheduler:
            self.change_lr(self.lr_scheduler(self.lr, self.epoch))

        if self.val_dataloader is not None:
            with torch.no_grad():
                self.eval()
                epoch_loss_avg = 0
                for X, X_d, X_R, Y in self.val_dataloader:
                    loss = self._batch_loss(X, X_d, Y)
                    epoch_loss_avg += loss.detach().item()

                self.history['val_loss'].append(epoch_loss_avg / len(self.val_dataloader))

    def predict(self, Y, Y_d):
        '''
            Run the network on one complex observation.

            Y   : complex array output; real and imaginary parts are stacked
                  along axis 0 and a batch axis is added
            Y_d : number of sources, indexable by sample

            Returns the raw network output as a numpy array. It is
            Tanh-bounded, i.e. on the same pi/2-normalised scale as the
            training targets, not in radians.
        '''
        Ynet = torch.tensor(
                np.concatenate((np.real(Y), np.imag(Y)), axis=0),
                dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        self.eval()
        theta_hat = self(Ynet, Y_d).cpu().detach().numpy()
        return theta_hat

In [None]:
# Instantiate DCNTA-MUSIC with Adam, the simulator's steering matrix and N_SRC_MAX outputs.
lr = 0.001
n_angles_grid=angles_grid.shape[0]; print(n_angles_grid)
A_grid = doa.A_grid
dcnta = DCNTA_MUSIC(lr=lr, optimizer=torch.optim.Adam,
                        M=M, n_angles_grid=n_angles_grid, A_grid=A_grid,
                        n_src_max=N_SRC_MAX
)

In [None]:
# Two consecutive 40-epoch runs; the epoch counter carries over, giving 80 epochs in total.
dcnta.fit(40, train_dataloader, val_dataloader)
dcnta.fit(40, train_dataloader, val_dataloader)

In [None]:
# Loss curves recorded in dcnta.history.
dcnta.plot()

## Trans-MUSIC

In [None]:
# Trans_MUSIC Source Code: https://github.com/jijunkai/Transformer_Music

import math

# Short aliases used by the Trans-MUSIC code: m = number of sensors, r = number of grid angles.
m = M
r = angles_grid.shape[0]

def ULA_action_vector(theta):
    '''Steering vector exp(-1j * pi * k * sin(theta)) for sensor indices k = 0..m-1 (m is module-level).'''
    sensor_idx = np.linspace(0, m, m, endpoint=False)
    phase = - 1j * np.pi * sensor_idx * np.sin(theta)
    return np.exp(phase)

## -----------------------------------------------------
# Steering matrix on the angle grid: column i is the steering vector of angles_grid[i].
# Stored as complex64 on `device`, shape (m, r).
a = torch.zeros([m, r]) + 1j * torch.zeros([m, r])
for i, grid_angle in enumerate(angles_grid[:r]):
    a[:, i] = torch.from_numpy(ULA_action_vector(grid_angle))
a = torch.complex(a.real.float(), a.imag.float())
a = a.to(device)


def calculate_spectrum(En):
    """Evaluate 1 / |a_k^H (En En^H) a_k| for every column a_k of ``a``.

    ``a`` is the module-level steering matrix of shape [m, r] and ``device``
    the module-level torch device.

    En : complex tensor of shape [batch, m, k] (3-D is required by the
         permute below).
    Returns a real tensor of shape [batch, r] on ``device``. No guard is
    applied against a zero denominator.
    """
    basis = En.to(device)
    # En @ En^H  -> [batch, m, m]
    projector = basis @ torch.conj(basis.permute(0, 2, 1))
    # (En En^H) a -> [batch, m, r]
    projected = torch.matmul(projector, a)
    # Column-wise inner product with conj(a): sum over the element axis.
    quadratic_form = torch.sum(torch.mul(projected, torch.conj(a)), dim=1)
    return (1.0 / torch.abs(quadratic_form)).to(device)



class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for sequence-first inputs.

    A table ``pe`` of shape [max_len, 1, d_model] is computed once in the
    constructor and registered as a buffer. Even feature indices hold
    ``sin(position * div_term)`` and odd feature indices hold
    ``cos(position * div_term)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
            d_model : size of the feature (last) axis; may be even or odd
            max_len : number of positions stored in the table
        """
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # the frequencies are trimmed to fit. For even d_model the slice keeps
        # every frequency and the table is unchanged.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(0)`` table rows to ``x``.

            x : tensor of shape [seq_len, batch_size, d_model]

        Both operands are moved to the module-level ``device`` before the
        addition, so the result lives on that device.
        """
        return x.to(device) + self.pe[:x.size(0)].to(device)


#**********************************#
#   trans_music                    #
#**********************************#

class Trans_MUSIC(Model):
    """Transformer encoder followed by a MUSIC-style spectrum and an MLP head.

    Pipeline (see ``forward``): batch norm over the ``2*m`` input channels,
    sinusoidal positional encoding, a 3-layer Transformer encoder, mean over
    the sequence axis, a linear layer producing ``2*m*m`` values that are
    reshaped into one complex [m, m] matrix per sample, ``calculate_spectrum``
    on that matrix, and an MLP that outputs ``n_src_max`` values.

    Depends on names defined elsewhere in the notebook: the ``Model`` base
    class (optimizer, history, device and lr-scheduler handling), ``loss_varD``,
    ``N_ANGLES_GRID``, the module-level ``device`` and the steering matrix read
    by ``calculate_spectrum``.
    """

    def __init__(self, lr, optimizer, m, n_src_max, nhead=8):
        """
            lr         : learning rate, forwarded to ``Model.__init__``
            optimizer  : forwarded to ``configure_optimizer``
            m          : size of the complex [m, m] matrix fed to
                         ``calculate_spectrum``; the network input must have
                         ``2*m`` channels
            n_src_max  : length of the output vector
            nhead      : number of attention heads; must divide ``2*m``
        """
        super().__init__(lr)

        self.m = m

        # All layer widths follow from m. With m == 8 these are the previously
        # hard-coded 16 (channels / d_model) and 128 (flattened matrix).
        d_model = 2 * m
        if d_model % nhead != 0:
            raise ValueError(
                f"2*m (= {d_model}) must be divisible by nhead (= {nhead})"
            )

        self.BN = torch.nn.BatchNorm1d(d_model).to(device)

        self.pos_encoder = PositionalEncoding(d_model)  # position embedding

        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=1024,
            dropout=0,
            activation="relu",
            layer_norm_eps=1e-05,
            batch_first=True,
            norm_first=False,
        ).to(device)

        self.encoder = torch.nn.TransformerEncoder(
            encoder_layer,
            num_layers=3,
            norm=None,
        ).to(device)

        # Real and imaginary halves of an [m, m] matrix, flattened.
        self.input_linear = nn.Linear(in_features=d_model, out_features=2 * m * m).to(device)

        # The head consumes the output of calculate_spectrum, so N_ANGLES_GRID
        # must equal the number of columns of the steering matrix used there.
        self.output = nn.Sequential(
            nn.Linear(in_features=N_ANGLES_GRID, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(inplace=False),
            nn.Linear(in_features=16, out_features=n_src_max),
        ).to(device)

        self.configure_optimizer(optimizer)
        self.to(self.device)

    def forward(self, x, x_d):
        """Map a batch of inputs to ``n_src_max`` outputs per sample.

            x   : real tensor of shape [size, 2*m, seq_len]
            x_d : not used by this model (presumably kept so every model
                  shares the same call signature; confirm against callers)

        Returns a tensor of shape [size, n_src_max].
        """
        size = x.shape[0]
        m = self.m

        x = self.BN(x)                      # [size, 2m, seq_len]
        x = x.permute(2, 0, 1).float()      # [seq_len, size, 2m]
        x = self.pos_encoder(x)             # adds pe[:seq_len]
        x = x.permute(1, 0, 2)              # [size, seq_len, 2m] (batch_first)

        encoded = self.encoder(x)           # [size, seq_len, 2m]
        pooled = torch.mean(encoded, dim=1)  # [size, 2m]
        flat = self.input_linear(pooled)    # [size, 2*m*m]

        # First m rows are the real part, last m rows the imaginary part.
        stacked = flat.reshape(size, 2 * m, m)
        En = torch.complex(stacked[:, :m, :], stacked[:, m:, :])  # [size, m, m]

        spectrum = calculate_spectrum(En).float()
        return self.output(spectrum)

    def loss(self, Y, Y_hat, X_d):
        """Evaluate ``loss_varD`` on CPU copies and return it on ``self.device``."""
        return loss_varD(Y.cpu(), Y_hat.cpu(), X_d).to(self.device)

    def fit_epoch(self):
        """Run one training pass and, if a validation loader is set, one validation pass.

        The per-batch mean losses are appended to ``self.history['train_loss']``
        and ``self.history['val_loss']``. The module is left in eval mode when
        validation ran, and in train mode otherwise.
        """
        self.train()
        epoch_loss_avg = 0
        for X, X_d, X_R, Y in tqdm(self.train_dataloader):
            if self.gpu_is_available:
                X = X.to(self.device)
                Y = Y.to(self.device)

            Y_hat = self(X, X_d)
            loss = self.loss(Y, Y_hat, X_d)
            epoch_loss_avg += loss.detach().item()

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        self.history['train_loss'].append(epoch_loss_avg / len(self.train_dataloader))

        # Learning-rate schedule: the scheduler maps (current lr, epoch) to a new lr.
        if self.lr_scheduler:
            self.change_lr(self.lr_scheduler(self.lr, self.epoch))

        if self.val_dataloader is not None:
            with torch.no_grad():
                self.eval()
                epoch_loss_avg = 0
                for X, X_d, X_R, Y in self.val_dataloader:
                    if self.gpu_is_available:
                        X = X.to(self.device)
                        Y = Y.to(self.device)
                    Y_hat = self(X, X_d)

                    loss = self.loss(Y, Y_hat, X_d)

                    epoch_loss_avg += loss.detach().item()

                self.history['val_loss'].append(epoch_loss_avg / len(self.val_dataloader))

    def predict(self, Y, Y_d):
        """Run the network on a single complex array output.

            Y   : complex array-like; real and imaginary parts are stacked
                  along axis 0 before being fed to the network
            Y_d : forwarded unchanged as the second argument of ``forward``

        Returns the network output as a NumPy array. The module is switched
        to eval mode and left in that mode.
        """
        Ynet = torch.tensor(
            np.concatenate((np.real(Y), np.imag(Y)), axis=0),
            dtype=torch.float32, device=self.device
        ).unsqueeze(0)  # leading axis of size 1 is the batch axis read in forward
        self.eval()
        theta_hat = self(Ynet, Y_d).cpu().detach().numpy()
        return theta_hat

In [None]:
# Learning rate handed to the model constructor.
lr = 1e-3

trans_music = Trans_MUSIC(
    lr=lr,
    optimizer=torch.optim.Adam,
    m=M,
    n_src_max=N_SRC_MAX,
)

In [None]:
# Train on the training loader and evaluate on the validation loader
# (first argument presumably the number of epochs; fit() is defined on the
# base class outside this section).
trans_music.fit(80, train_dataloader, val_dataloader)

In [None]:
# Plot the results recorded for the Trans-MUSIC model (presumably the
# train/val loss history; plot() is defined outside this section, so verify).
trans_music.plot()