In [None]:
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import Dataset, DataLoader

pi = math.pi

In [None]:
class Mynet(nn.Module):
    '''
    Cosine-series network with one extra learnable scalar, alpha.

    forward(w) expands each frequency in w into n cosine features (via
    tocosines), maps them through two bias-free linear layers (n -> 64 -> 1)
    with no nonlinearity in between, and returns a (batch, 2) tensor:
        column 0: the network response R(w)
        column 1: the learnable alpha, repeated on every row

    n:          number of cosine basis terms (input size of fc1)
    alpha_init: starting value of alpha (default 1.0)
    '''
    def __init__(self, n, alpha_init=1.0):
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        # Explicit initial value: alpha is used as a divisor in the loss, so an
        # uninitialized buffer (e.g. torch.Tensor(1)) could start at inf/NaN.
        # alpha lives directly on this module, so it is the first entry of
        # self.parameters() (before fc1/fc_last weights).
        self.alpha = Parameter(torch.tensor([float(alpha_init)]))

    def forward(self, w):
        a = tocosines(w, self.n)          # (batch, n)
        r = self.fc_last(self.fc1(a))     # (batch, 1)
        # expand() broadcasts alpha to one value per row without building a
        # separate CPU ones-tensor, so it follows alpha's device/dtype; the
        # gradient still accumulates into alpha.
        alpha_col = self.alpha.expand(r.size(0), 1)
        return torch.cat((r, alpha_col), 1)

In [None]:
def tocosines(w, n):
    '''
    Evaluate the weighted cosine basis c_k * cos(k * w), k = 0..n-1.

    w: tensor of frequencies, shape (batch,) or (batch, 1); flattened to (batch,)
    n: number of basis terms
    Returns a (batch, n) tensor: column 0 is cos(0*w) = 1 and columns k >= 1
    are 2*cos(k*w). The result uses w's floating dtype and device.
    '''
    w = w.reshape(-1)
    if not torch.is_floating_point(w):
        w = w.to(torch.get_default_dtype())
    k = torch.arange(0, n, dtype=w.dtype, device=w.device)
    weights = torch.full((n,), 2.0, dtype=w.dtype, device=w.device)
    if n > 0:
        weights[0] = 1
    # Outer product w_i * k_j via broadcasting (equivalent to torch.ger, but
    # works for any w shape after flattening).
    return torch.cos(w.unsqueeze(1) * k) * weights  # batch*n

In [None]:
def customized_loss(w, out, wa, wb, tradeoff):
    '''
    Weighted hinge loss on the network output, averaged over the batch.

    w:   frequencies, shape (batch,) or (batch, 1)
    out: network output, shape (batch, 2); column 0 is R(w), column 1 is alpha
    wa, wb, tradeoff: forwarded to reg_lambda, which builds the (batch, 3)
                      weights for the three hinge terms
    Returns a scalar tensor equal to
        ( sum(reg * relu([R*w/alpha - alpha, 1/alpha - R*w/alpha, -R]))
          + sum(alpha) ) / batch
    '''
    # Column vector so R*w stays (batch, 1); a 1-D w would broadcast against
    # the (batch, 1) R into a (batch, batch) matrix.
    w_col = w.reshape(-1, 1)
    R = out[:, 0].reshape(-1, 1)
    alpha = out[:, 1].reshape(-1, 1)
    ratio = R*w_col/alpha
    v = F.relu(torch.cat((ratio - alpha, 1/alpha - ratio, -R), 1))  # (batch, 3)
    reg = reg_lambda(w_col.reshape(-1), wa, wb, tradeoff)             # (batch, 3)
    return (torch.sum(v*reg) + torch.sum(alpha))/w_col.size(0)

In [None]:
def reg_lambda(w, wa, wb, tradeoff):
    '''
    Per-sample weights for the three hinge terms of the loss.

    w:        frequencies, shape (batch,) or (batch, 1)
    wa, wb:   open interval; samples with wa < w < wb (strict) get weight on
              columns 0 and 1
    tradeoff: scalar multiplier applied to every weight
    Returns a (batch, 3) tensor: columns 0 and 1 equal tradeoff inside
    (wa, wb) and 0 outside; column 2 equals tradeoff for every sample.
    '''
    w = w.reshape(-1)
    # Elementwise mask. Python's chained `wa < w < wb` expands to
    # `(wa < w) and (w < wb)`, which calls bool() on a multi-element tensor.
    in_band = (w > wa) & (w < wb)
    reg = torch.zeros(w.size(0), 3, device=w.device)
    reg[in_band, 0] = 1
    reg[in_band, 1] = 1
    reg[:, 2] = 1
    return reg*tradeoff #batch*3

In [None]:
def D(w):
    '''Return w ** (-1/2), i.e. 1/sqrt(w) for positive w; works elementwise on arrays/tensors.

    Not referenced by any of the visible notebook cells.
    '''
    exponent = -0.5
    return pow(w, exponent)

In [None]:
class MyDataset(Dataset):
    '''Map-style dataset whose i-th sample is simply W[i].'''

    def __init__(self, W):
        # Keep a reference to the backing sequence/tensor; nothing is copied.
        self.W = W

    def __len__(self):
        return len(self.W)

    def __getitem__(self, index):
        return self.W[index]

In [None]:
# Problem configuration.
wa = 0.01*pi      # lower edge of the open band (wa, wb) weighted by reg_lambda
wb = pi           # upper edge of that band
tradeoff = 20     # multiplier reg_lambda applies to every hinge weight
n = 50            # number of cosine basis terms (input size of Mynet)
m = n*20          # number of frequency samples on [0, pi]

# Seed every RNG the notebook uses (weight init, DataLoader shuffling,
# np.random) so Restart & Run All reproduces the same results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Frequency grid: computed in numpy (float64) and then cast to float32.
W = torch.from_numpy(np.linspace(0, pi, m).astype('float32'))

In [None]:
# Shuffled mini-batches over the frequency grid W.
batch_size = 40
train_loader = DataLoader(dataset=MyDataset(W), shuffle=True, batch_size=batch_size)

In [None]:
# Train Mynet with SGD + momentum; MultiStepLR scales the LR by 0.2 at
# epochs 400 and 600.
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[400, 600], gamma=0.2)
epochs = 1000

# model_best must be an independent snapshot: plain assignment would alias
# `model` and keep changing as training continues.
model_best = copy.deepcopy(model)
best_loss = float('inf')
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    num_batches = 0
    for w in train_loader:

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = model(w)
        loss = customized_loss(w, out, wa, wb, tradeoff)
        loss.backward()
        optimizer.step()

        # accumulate statistics
        running_loss += loss.item()
        num_batches += 1

    # Mean batch loss for the epoch: divide by the number of batches.
    epoch_loss = running_loss / max(num_batches, 1)
    print('epoch = %d, loss = %.6f' %(epoch, epoch_loss))

    # Track the lowest loss seen so far (not just "lower than last epoch").
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        model_best = copy.deepcopy(model)

    # Advance the LR schedule once per epoch, after this epoch's optimizer steps.
    scheduler.step()

print('Finished Training')

In [None]:
def get_result(model=None):
    '''
    Extract the learned quantities from a trained Mynet.

    model: network to read; defaults to the global `model_best`.
    Returns (delta, r):
        delta: the learned alpha parameter, shape (1,)
        r:     the single linear map equivalent to fc_last(fc1(.)), shape (1, n).
               This collapse is valid because forward applies no
               nonlinearity between fc1 and fc_last.
    Returned tensors are detached views of the parameters (no copy).
    '''
    if model is None:
        model = model_best
    # Access parameters by name rather than by position in parameters().
    delta = model.alpha.detach()
    fc1_w = model.fc1.weight.detach()          # (64, n)
    fc_last_w = model.fc_last.weight.detach()  # (1, 64)
    r = torch.mm(fc1_w.t(), fc_last_w.t())     # (n, 1)
    return delta, r.reshape(1, -1)

In [None]:
# Learned alpha and the collapsed cosine-coefficient row vector.
delta, r = get_result()
print(delta, r, sep='\n')

In [None]:
# Debug inspection: for one shuffled batch, print each intermediate of the
# loss using the same formula as customized_loss, then the loss itself.
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=20, shuffle=True)
with torch.no_grad():
    for w in test_loader:
        print(w)
        out = model(w)
        print(out)
        w_col = w.reshape(-1, 1)   # (batch, 1) so R*w does not broadcast to (batch, batch)
        R = out[:, 0].reshape(-1, 1)
        alpha = out[:, 1].reshape(-1, 1)
        v = torch.cat((R*w_col/alpha - alpha, 1/alpha - R*w_col/alpha, -R), 1)
        print(v)
        v = F.relu(v)
        print(v)
        reg = reg_lambda(w, wa, wb, tradeoff)
        print(reg)
        print(v*reg)
        loss = customized_loss(w, out, wa, wb, tradeoff)
        print(loss)
        break

In [None]:
# Pickles the whole module object; loading it later requires the Mynet class
# to be importable under the same name.
torch.save(obj=model_best, f='model_lowpass_attenuation.pth')

In [None]:
# Scratch check of the R*w/alpha expression on random column vectors,
# with alpha fixed at 2.
R = np.random.randn(4, 1)
w = np.random.randn(4, 1)
alpha = np.full((4, 1), 2.0)
out = R * w / alpha
print(R, w, alpha, out, sep='\n')

In [None]:
# Scratch: hand-evaluates R*w/alpha with alpha = 2 (the expression from the
# previous cell). The operands presumably come from one printed run of that
# cell, so they may not match a fresh execution.
0.18903622*-0.3123378/2

In [None]:
# Scratch: second hand-evaluated R*w/alpha with alpha = 2; operands presumably
# copied from one printed run of the random-draw cell above.
-0.63223123*0.98311085/2