In [None]:
#!/usr/bin/env python
# coding: utf-8

import os
import sys
import time

In [None]:
import torch
import numpy as np

In [None]:
# Directory holding locations.dat and Example<k>.dat. Set the environment
# variable SPECTRAL_NN_DATA_DIR to override the author's local default; a
# trailing separator is appended because file names are later concatenated
# directly onto this string.
dirc = os.environ.get("SPECTRAL_NN_DATA_DIR")
if dirc:
    dirc = os.path.join(dirc, "")
else:
    dirc = "C:\\Soham\\Git\\spectral-NN\\Data\\"
repl = 0  # zero-based index of the example to load (files are numbered repl+1)

In [None]:
    # Load the sampling locations u (D points in R^d) and the observed curves x
    # (N rows, one column per location), then centre x across time.
    print('Example'+str(repl+1)+':')
    # NOTE(review): all examples share "locations.dat"; confirm that a
    # per-example 'locations<k>.dat' is not intended.
    u = np.loadtxt(dirc+"locations.dat",dtype="float32")
    if u.ndim == 1:
        # loadtxt returns a 1-D array for a single column: treat it as D
        # one-dimensional locations (a single-row file is read the same way)
        D, d = len(u), 1
        u = u.reshape(D,1)
    else:
        D, d = u.shape
    u = torch.from_numpy(u)

    file = dirc+'Example'+str(repl+1)+'.dat'
    # ndmin=2 keeps the (N, D) layout even for a single-row or single-column file
    x = np.loadtxt(file,dtype='float32',ndmin=2)
    N = x.shape[0]
    if x.shape[1] != D:
        # exit() inside a Jupyter kernel does not stop the cell; raise instead
        raise ValueError('Data shape mismatch!! Aborting..')
    print('N='+str(N)+', D='+str(D)+', d='+str(d))

    x = torch.from_numpy(x)
    # subtract the sample mean curve (mean over the N rows at each location)
    x = x - torch.mean(x,dim=0,keepdim=True)

In [None]:
class spectralNNShallow(torch.nn.Module):
    """Shallow spectral neural-network fit of a functional time series.

    For each of the 2L+1 window positions j and each of the M hidden units m, a
    unit act_fn(w_{m,j} . u + b_{m,j}) is evaluated at every sampling location.
    The fitted curve at time i is sum_{m,j} xi[m, i+j] * unit_{m,j}(u).

    Args:
        N - number of time points to fit (rows of the output)
        d - dimension of each sampling location
        M - number of hidden units per window position
        L - half-width of the window (2L+1 positions)
        act_fn - activation applied to the hidden units
        init - in-place initialiser applied to `weight` and `xi`
    """

    def __init__(self,N,d,M,L,act_fn=torch.nn.Sigmoid(),init=torch.nn.init.xavier_normal_):
        super().__init__()
        self.N = N
        self.L = L
        self.act_fn = act_fn
        self.init = init
        # nn.Parameter registers the tensors with the module (parameters(),
        # state_dict(), .to(device)); `params` keeps the explicit list.
        self.weight = torch.nn.Parameter(torch.zeros([M,2*L+1,d],dtype=torch.float32))  # hidden-unit weights
        self.bias = torch.nn.Parameter(torch.zeros([M,2*L+1,1],dtype=torch.float32))  # hidden-unit biases (start at 0)
        self.xi = torch.nn.Parameter(torch.zeros([M,N+2*L],dtype=torch.float32))  # the multipliers xi_{m,h}
        self.init(self.weight)
        self.init(self.xi)
        self.params = [self.weight, self.xi, self.bias]

    def first_step(self, u):
        """Hidden-unit activations at the locations u (D x d): shape M x 2L+1 x D."""
        return self.act_fn(torch.einsum("ijk,lk -> ijl", self.weight, u) + self.bias)

    def iter_prod(self, i, G):
        """Fitted curve at time i: contract xi[:, i:i+2L+1] with G; shape 1 x D."""
        return torch.einsum("ij,ijk -> k", self.xi[:,i:(i+2*self.L+1)], G).reshape(1,-1)

    def forward(self, u):
        """Return the N x D matrix of fitted curves at the locations u."""
        G = self.first_step(u)
        # use self, not a global instance, so any instance can be evaluated
        return torch.cat([self.iter_prod(i,G) for i in range(self.N)])

In [None]:
class loss_spectralNN:
    """Module to compute the loss function associated with the spectral NN estimator.

    With D_h = C_h(x) - C_h(x_tilde), where C_h(Z) = sum_t Z_{t+h} (x) Z_t is the
    unnormalised lag-h autocovariance operator and C_{-h} = C_h^*, the loss is
    the average over a grid of frequencies theta of

        sqrt( sum_{h,h'=-q..q} w_h w_h' cos((h - h') theta) <D_h, D_h'>_HS ) / N,

    with weights w_h = wt_fn(h/q). Up to constant factors this is the
    Hilbert-Schmidt distance between the lag-window spectral density estimates
    built from x and x_tilde.
    """

    def __init__(self, wt_fn, grid_size = 100, q=10):
        """
        Args:
            wt_fn - lag-window function; called once on the array h/q for
                    h = -q,...,q and must return an array of that length
            grid_size - NOTE(review): currently unused; the theta grid is
                        determined by q (argument kept for compatibility)
            q - lag value
        """
        self.q = q
        # theta_k = k*pi/(q+1), k = -q,...,q: 2q+1 symmetric points. Built from
        # integers so the number of points cannot depend on float rounding.
        self.thetas = torch.arange(-q, q + 1, dtype=torch.float32) * (np.pi / (q + 1))
        hs = np.arange(-q, q + 1, dtype="float32")
        self.C_diff = torch.from_numpy(hs[:, None] - hs[None, :])  # C_diff[i, j] = h_i - h_j
        self.w = torch.as_tensor(wt_fn(hs / q), dtype=torch.float32)

    @staticmethod
    def _lagged_pair_sum(G, h1, h2):
        """Return sum_{t,s} G[t+p1, s+p2] * G[t+r1, s+r2] over all valid t, s.

        For a lag h, (p, r) = (h, 0) when h >= 0 and (0, -h) when h < 0; t runs
        over G.shape[0] - |h1| values and s over G.shape[1] - |h2| values (none
        when the lag reaches the series length). With G = X @ Y.T this equals
        <C_h1(X), C_h2(Y)>_HS.
        """
        p1, r1 = (h1, 0) if h1 >= 0 else (0, -h1)
        p2, r2 = (h2, 0) if h2 >= 0 else (0, -h2)
        n1 = max(G.shape[0] - abs(h1), 0)
        n2 = max(G.shape[1] - abs(h2), 0)
        return torch.sum(G[p1:p1 + n1, p2:p2 + n2] * G[r1:r1 + n1, r2:r2 + n2])

    def inner_sum(self,A11,A22,A12,h1,h2):
        """a(h1, h2) = <D_h1, D_h2>_HS from the Gram matrices.

        Args:
            A11, A22, A12 - x x^T, x_tilde x_tilde^T and x x_tilde^T
            h1, h2 - signed lags; the sign matters for mixed-sign pairs because
                     C_{-h} is the adjoint of C_h, not C_h itself
        """
        lagged = self._lagged_pair_sum
        return (lagged(A11, h1, h2) + lagged(A22, h1, h2)
                - lagged(A12, h1, h2) - lagged(A12.T, h1, h2))

    def inner_part(self, x, x_tilde):
        """
        Calculates the inner part of the loss function a(h,h') for h,h'=-q,...,q

        Args:
            x - observed functional time series (NxD matrix)
            x_tilde - fitted time series using neural networks (NxD matrix)

        Returns:
            (2q+1)x(2q+1) float32 tensor whose [q+h, q+h'] entry is a(h,h')
        """
        A11 = torch.matmul(x,x.T)
        A22 = torch.matmul(x_tilde,x_tilde.T)
        A12 = torch.matmul(x,x_tilde.T)
        # a(h,h') = a(h',h) = a(-h,-h'): evaluate each symmetry orbit once
        cache = {}

        def entry(h1, h2):
            key = min((h1, h2), (h2, h1), (-h1, -h2), (-h2, -h1))
            if key not in cache:
                cache[key] = self.inner_sum(A11, A22, A12, *key)
            return cache[key]

        lags = range(-self.q, self.q + 1)
        A = torch.stack([torch.stack([entry(h1, h2) for h2 in lags]) for h1 in lags])
        return A.to(torch.float32)

    def loss_fn(self, x, x_tilde):
        """Average over the theta grid of the weighted HS distance, divided by N."""
        N = x.shape[0]
        A = self.inner_part(x, x_tilde)
        w = self.w.to(A.device)
        # cos((h - h') * theta) for every grid point: shape n_theta x (2q+1) x (2q+1)
        kernel = torch.cos(self.thetas.to(A.device).view(-1, 1, 1) * self.C_diff.to(A.device))
        quad = torch.einsum("i,tij,j->t", w, kernel * A, w)
        # quad >= 0 in exact arithmetic; clamp keeps round-off from turning sqrt into NaN
        return torch.sqrt(quad.clamp_min(0.0)).mean() / N

In [None]:
# Shallow spectral NN with M=10 hidden units per window position and window
# half-width L=4, sized to the N x D data loaded above.
model = spectralNNShallow(N=N, d=d, M=10, L=4)

In [None]:
def wt_fn(x):
    """Gaussian lag-window weight exp(-x^2), applied elementwise."""
    return np.exp(-x**2)

In [None]:
# Spectral loss over lags -10..10, weighted by wt_fn.
loss = loss_spectralNN(wt_fn=wt_fn, grid_size=100, q=10)

In [None]:
#M = 10
#L = 10
#N = 200
#d = 1
#K = 100
#D = K**d

In [None]:
#x = np.array(np.random.randn(N,D),dtype="float32")
#u = np.array(np.arange(1,K+1)/(K+1),dtype="float32").reshape(-1,1)
#print(x.shape)
#print(u.shape)
#x = torch.from_numpy(x)
#u = torch.from_numpy(u)

In [None]:
#print(model.params)
#x_hat = model(u)
#loss = torch.norm(x-x_hat)
#print(loss.item())

In [None]:
# Adam over the model's explicit parameter list.
optimizer = torch.optim.Adam(params=model.params, lr=1e-2)

In [None]:
# Wall-clock timestamp before training.
print(time.ctime(time.time()))

In [None]:
# Full-batch training: every step fits all N curves at every location in u.
N_STEPS = 1000
LOG_EVERY = 50  # report periodically instead of printing 1000 lines
loss_history = []
for step in range(N_STEPS):
    train_loss = loss.loss_fn(x, model(u))
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    loss_history.append(train_loss.item())
    if step % LOG_EVERY == 0 or step == N_STEPS - 1:
        print(f"step {step:4d}  loss {loss_history[-1]:.6f}")

In [None]:
# Wall-clock timestamp after training.
print(time.ctime(time.time()))

In [None]:
# Gram matrices of the observed (x) and fitted (x_tilde) curves for the
# diagnostic cells below. x_tilde is computed here, detached from the graph,
# so the cell does not rely on a leftover kernel variable.
x_tilde = model(u).detach()
A11 = torch.matmul(x, x.T)
A22 = torch.matmul(x_tilde, x_tilde.T)
A12 = torch.matmul(x, x_tilde.T)
print(A11.shape)
print(A22.shape)
print(A12.shape)

In [None]:
def _lagged_pair_sum(G, h1, h2):
    """Return sum_{t,s} G[t+p1, s+p2] * G[t+r1, s+r2] over all valid t, s.

    For a lag h, (p, r) = (h, 0) when h >= 0 and (0, -h) when h < 0; t runs over
    G.shape[0] - |h1| values and s over G.shape[1] - |h2| values (none when the
    lag reaches the series length). With G = X @ Y.T this equals
    <C_h1(X), C_h2(Y)>_HS for C_h(Z) = sum_t Z_{t+h} (x) Z_t and C_{-h} = C_h^*.
    """
    p1, r1 = (h1, 0) if h1 >= 0 else (0, -h1)
    p2, r2 = (h2, 0) if h2 >= 0 else (0, -h2)
    n1 = max(G.shape[0] - abs(h1), 0)
    n2 = max(G.shape[1] - abs(h2), 0)
    return torch.sum(G[p1:p1+n1, p2:p2+n2] * G[r1:r1+n1, r2:r2+n2])


def inner_sum(A11,A22,A12,h1,h2):
    """a(h1, h2) = <C_h1(x) - C_h1(x~), C_h2(x) - C_h2(x~)>_HS from Gram matrices.

    Args:
        A11, A22, A12 - x x^T, x~ x~^T and x x~^T
        h1, h2 - signed lags; the sign matters for mixed-sign pairs because
                 C_{-h} is the adjoint of C_h, not C_h itself
    """
    return (_lagged_pair_sum(A11, h1, h2) + _lagged_pair_sum(A22, h1, h2)
            - _lagged_pair_sum(A12, h1, h2) - _lagged_pair_sum(A12.T, h1, h2))

In [None]:
# Reference a(h, h') matrix: every entry obtained by passing the signed lags
# straight to inner_sum, for comparison with the symmetry shortcuts below.
q = loss.q
A = torch.zeros([2*q+1, 2*q+1], dtype=torch.float32)
with torch.no_grad():
    for h1 in range(q + 1):  # q + 1 so the extreme lags +/-q are filled too
        for h2 in range(q + 1):
            A[q+h1, q+h2] = inner_sum(A11, A22, A12, h1, h2)
            A[q-h1, q-h2] = inner_sum(A11, A22, A12, -h1, -h2)
            A[q-h1, q+h2] = inner_sum(A11, A22, A12, -h1, h2)
            A[q+h1, q-h2] = inner_sum(A11, A22, A12, h1, -h2)

In [None]:
# Same matrix filled from the non-negative quadrant only, assuming a(h, h')
# is unchanged when either lag flips sign.
q = loss.q
B = torch.zeros([2*q+1, 2*q+1], dtype=torch.float32)
with torch.no_grad():
    for h1 in range(q + 1):  # q + 1 so the extreme lags +/-q are filled too
        for h2 in range(q + 1):
            B[q+h1, q+h2] = inner_sum(A11, A22, A12, h1, h2)
            B[q-h1, q-h2] = B[q+h1, q+h2]
            B[q-h1, q+h2] = B[q+h1, q+h2]
            B[q+h1, q-h2] = B[q+h1, q+h2]

In [None]:
# Same matrix filled from h2 >= h1 >= 0 only, assuming a(h, h') is unchanged
# by sign flips of either lag and by swapping the two lags.
q = loss.q
C = torch.zeros([2*q+1, 2*q+1], dtype=torch.float32)
with torch.no_grad():
    for h1 in range(q + 1):  # q + 1 so the extreme lags +/-q are filled too
        for h2 in range(h1, q + 1):
            C[q+h1, q+h2] = inner_sum(A11, A22, A12, h1, h2)
            C[q-h1, q-h2] = C[q+h1, q+h2]
            C[q-h1, q+h2] = C[q+h1, q+h2]
            C[q+h1, q-h2] = C[q+h1, q+h2]
            C[q+h2, q+h1] = C[q+h1, q+h2]
            C[q-h2, q-h1] = C[q+h1, q+h2]
            C[q-h2, q+h1] = C[q+h1, q+h2]
            C[q+h2, q-h1] = C[q+h1, q+h2]

In [None]:
# A scaled by 1/N^2, rounded to one decimal.
print((A.detach().numpy() / N ** 2).round(1))

In [None]:
# B scaled by 1/N^2, rounded to one decimal.
print((B.detach().numpy() / N ** 2).round(1))

In [None]:
# C scaled by 1/N^2, rounded to one decimal.
print((C.detach().numpy() / N ** 2).round(1))

In [None]:
# Mean squared difference between each pair of A, B, C.
for lhs, rhs in ((A, B), (A, C), (B, C)):
    print(torch.mean((lhs - rhs) ** 2))

In [None]:
# Print, in row-major order, the (row, col) indices where the direct
# evaluation A and the symmetry-based C disagree. Indices come from the
# matrices themselves, so the cell does not depend on a global q.
for i, j in torch.nonzero(A != C).tolist():
    print(str(i)+","+str(j))

In [None]:
# Compare one entry across the three constructions.
i, j = 1, 7
print(*(mat[i, j] for mat in (A, B, C)))