In [31]:
import numpy as np
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
pi = math.pi


In [32]:
def sign(x):
    y = torch.zeros(x.size())
    y[x > 0] = 1
    return y

In [33]:
class Mynet(nn.Module):
    def __init__(self, n):
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        self.delta = Parameter(torch.Tensor(1))
        
    def forward(self, w):
        a = tocosines(w, self.n)
        out = self.fc1(a)
        r = self.fc_last(out)
        out = torch.cat((r, torch.ones(r.size(0),1)*self.delta),1)
        return out

In [34]:
def tocosines(w, n):
    '''
    w: batch*1
    '''
    t = torch.arange(0, n, 1.0)
    c = torch.ones(1, n)*2
    c[0,0] = 1
    out = torch.cos(torch.ger(w, t))*c
    return out #batch*n

In [35]:
def customized_loss(w, out, alpha, ws, wp, tradeoff):
    '''
    w: batch * n
    out: batch * 2
    
    '''
    R = out[:,0].reshape(-1, 1)
    delta = out[:,1].reshape(-1, 1)
    v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)
    v = F.relu(v)
#     print(v)
    reg = reg_lambda(w, ws, wp, tradeoff)
#     print (reg)
    loss = (torch.sum(v*reg) + torch.sum(delta))/w.size(0)
    return loss

In [36]:
def reg_lambda(w, ws, wp, tradeoff):
    reg = torch.zeros(w.size(0), 4)
    ind1 = ((w < wp).nonzero())
    ind2 = ((w > ws).nonzero())
    reg[ind1, 0] = 1
    reg[ind1, 1] = 1
    reg[ind2, 2] = 1
    reg[:, 3] = 1
    return reg*tradeoff #batch*4

In [37]:
class MyDataset(Dataset):
    def __init__(self, W):
        self.W = W

    def __len__(self):
        return len(self.W)

    def __getitem__(self, index):
        w = self.W[index]
        return w

In [38]:
alpha = 1.1
wp = 0.12*pi
ws = 0.24*pi
tradeoff = 50
n = 20
m = n*20
W = np.linspace(0, pi, m).astype('float32')
W = torch.from_numpy(W)

In [39]:
batch_size = 40
train_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=batch_size, shuffle=True)

In [43]:
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[400, 600], gamma=0.2)
epochs = 2000

model_best = model
previous_loss = 10000
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    scheduler.step(epoch)
    for i, w in enumerate(train_loader, 0):

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = model(w)
        loss = customized_loss(w, out, alpha, ws, wp, tradeoff)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    print('epoch = %d, loss = %.6f' %(epoch, running_loss/i))
    
    if running_loss < previous_loss:
        model_best = model
    
    previous_loss = running_loss
    
print('Finished Training')

epoch = 0, loss = 17.689024
epoch = 1, loss = 13.526695
epoch = 2, loss = 9.291194
epoch = 3, loss = 6.539594
epoch = 4, loss = 4.515940
epoch = 5, loss = 2.685839
epoch = 6, loss = 1.369485
epoch = 7, loss = 0.610682
epoch = 8, loss = 0.112034
epoch = 9, loss = 0.050084
epoch = 10, loss = 0.048673
epoch = 11, loss = 0.047913
epoch = 12, loss = 0.046854
epoch = 13, loss = 0.045856
epoch = 14, loss = 0.044784
epoch = 15, loss = 0.043687
epoch = 16, loss = 0.042581
epoch = 17, loss = 0.041487
epoch = 18, loss = 0.040477
epoch = 19, loss = 0.039527
epoch = 20, loss = 0.038507
epoch = 21, loss = 0.037428
epoch = 22, loss = 0.036327
epoch = 23, loss = 0.035220
epoch = 24, loss = 0.034110
epoch = 25, loss = 0.033173
epoch = 26, loss = 0.032276
epoch = 27, loss = 0.031416
epoch = 28, loss = 0.030545
epoch = 29, loss = 0.029685
epoch = 30, loss = 0.028729
epoch = 31, loss = 0.027916
epoch = 32, loss = 0.027198
epoch = 33, loss = 0.026197
epoch = 34, loss = 0.025361
epoch = 35, loss = 0.024505


epoch = 292, loss = 0.007442
epoch = 293, loss = 0.007392
epoch = 294, loss = 0.006364
epoch = 295, loss = 0.005598
epoch = 296, loss = 0.005710
epoch = 297, loss = 0.007173
epoch = 298, loss = 0.007372
epoch = 299, loss = 0.009530
epoch = 300, loss = 0.007711
epoch = 301, loss = 0.009985
epoch = 302, loss = 0.007529
epoch = 303, loss = 0.009102
epoch = 304, loss = 0.008731
epoch = 305, loss = 0.009024
epoch = 306, loss = 0.011634
epoch = 307, loss = 0.010872
epoch = 308, loss = 0.011207
epoch = 309, loss = 0.009485
epoch = 310, loss = 0.008719
epoch = 311, loss = 0.007836
epoch = 312, loss = 0.006899
epoch = 313, loss = 0.006077
epoch = 314, loss = 0.005878
epoch = 315, loss = 0.006433
epoch = 316, loss = 0.006926
epoch = 317, loss = 0.007683
epoch = 318, loss = 0.009502
epoch = 319, loss = 0.011431
epoch = 320, loss = 0.008119
epoch = 321, loss = 0.007565
epoch = 322, loss = 0.006999
epoch = 323, loss = 0.006893
epoch = 324, loss = 0.006641
epoch = 325, loss = 0.006542
epoch = 326, l

epoch = 577, loss = 0.002058
epoch = 578, loss = 0.001997
epoch = 579, loss = 0.001841
epoch = 580, loss = 0.001864
epoch = 581, loss = 0.002230
epoch = 582, loss = 0.002137
epoch = 583, loss = 0.001866
epoch = 584, loss = 0.002790
epoch = 585, loss = 0.002249
epoch = 586, loss = 0.002049
epoch = 587, loss = 0.002010
epoch = 588, loss = 0.002781
epoch = 589, loss = 0.002809
epoch = 590, loss = 0.002598
epoch = 591, loss = 0.002327
epoch = 592, loss = 0.002380
epoch = 593, loss = 0.002285
epoch = 594, loss = 0.002253
epoch = 595, loss = 0.002123
epoch = 596, loss = 0.002006
epoch = 597, loss = 0.001917
epoch = 598, loss = 0.001959
epoch = 599, loss = 0.001727
epoch = 600, loss = 0.001617
epoch = 601, loss = 0.001617
epoch = 602, loss = 0.001569
epoch = 603, loss = 0.001548
epoch = 604, loss = 0.001495
epoch = 605, loss = 0.001477
epoch = 606, loss = 0.001453
epoch = 607, loss = 0.001457
epoch = 608, loss = 0.001432
epoch = 609, loss = 0.001462
epoch = 610, loss = 0.001397
epoch = 611, l

epoch = 862, loss = 0.001207
epoch = 863, loss = 0.001222
epoch = 864, loss = 0.001159
epoch = 865, loss = 0.001156
epoch = 866, loss = 0.001132
epoch = 867, loss = 0.001197
epoch = 868, loss = 0.001221
epoch = 869, loss = 0.001196
epoch = 870, loss = 0.001249
epoch = 871, loss = 0.001292
epoch = 872, loss = 0.001446
epoch = 873, loss = 0.001306
epoch = 874, loss = 0.001441
epoch = 875, loss = 0.001345
epoch = 876, loss = 0.001250
epoch = 877, loss = 0.001229
epoch = 878, loss = 0.001196
epoch = 879, loss = 0.001176
epoch = 880, loss = 0.001242
epoch = 881, loss = 0.001263
epoch = 882, loss = 0.001361
epoch = 883, loss = 0.001280
epoch = 884, loss = 0.001270
epoch = 885, loss = 0.001213
epoch = 886, loss = 0.001178
epoch = 887, loss = 0.001221
epoch = 888, loss = 0.001198
epoch = 889, loss = 0.001227
epoch = 890, loss = 0.001373
epoch = 891, loss = 0.001283
epoch = 892, loss = 0.001340
epoch = 893, loss = 0.001223
epoch = 894, loss = 0.001186
epoch = 895, loss = 0.001257
epoch = 896, l

epoch = 1144, loss = 0.001403
epoch = 1145, loss = 0.001458
epoch = 1146, loss = 0.001276
epoch = 1147, loss = 0.001234
epoch = 1148, loss = 0.001336
epoch = 1149, loss = 0.001238
epoch = 1150, loss = 0.001313
epoch = 1151, loss = 0.001232
epoch = 1152, loss = 0.001204
epoch = 1153, loss = 0.001192
epoch = 1154, loss = 0.001215
epoch = 1155, loss = 0.001193
epoch = 1156, loss = 0.001237
epoch = 1157, loss = 0.001198
epoch = 1158, loss = 0.001146
epoch = 1159, loss = 0.001181
epoch = 1160, loss = 0.001155
epoch = 1161, loss = 0.001128
epoch = 1162, loss = 0.001138
epoch = 1163, loss = 0.001094
epoch = 1164, loss = 0.001120
epoch = 1165, loss = 0.001191
epoch = 1166, loss = 0.001166
epoch = 1167, loss = 0.001212
epoch = 1168, loss = 0.001206
epoch = 1169, loss = 0.001139
epoch = 1170, loss = 0.001254
epoch = 1171, loss = 0.001195
epoch = 1172, loss = 0.001255
epoch = 1173, loss = 0.001236
epoch = 1174, loss = 0.001226
epoch = 1175, loss = 0.001177
epoch = 1176, loss = 0.001250
epoch = 11

epoch = 1430, loss = 0.001246
epoch = 1431, loss = 0.001176
epoch = 1432, loss = 0.001302
epoch = 1433, loss = 0.001213
epoch = 1434, loss = 0.001202
epoch = 1435, loss = 0.001192
epoch = 1436, loss = 0.001128
epoch = 1437, loss = 0.001126
epoch = 1438, loss = 0.001266
epoch = 1439, loss = 0.001210
epoch = 1440, loss = 0.001202
epoch = 1441, loss = 0.001222
epoch = 1442, loss = 0.001283
epoch = 1443, loss = 0.001230
epoch = 1444, loss = 0.001185
epoch = 1445, loss = 0.001212
epoch = 1446, loss = 0.001192
epoch = 1447, loss = 0.001172
epoch = 1448, loss = 0.001165
epoch = 1449, loss = 0.001203
epoch = 1450, loss = 0.001162
epoch = 1451, loss = 0.001162
epoch = 1452, loss = 0.001145
epoch = 1453, loss = 0.001193
epoch = 1454, loss = 0.001239
epoch = 1455, loss = 0.001204
epoch = 1456, loss = 0.001225
epoch = 1457, loss = 0.001153
epoch = 1458, loss = 0.001171
epoch = 1459, loss = 0.001243
epoch = 1460, loss = 0.001213
epoch = 1461, loss = 0.001235
epoch = 1462, loss = 0.001196
epoch = 14

epoch = 1720, loss = 0.001163
epoch = 1721, loss = 0.001254
epoch = 1722, loss = 0.001405
epoch = 1723, loss = 0.001389
epoch = 1724, loss = 0.001189
epoch = 1725, loss = 0.001180
epoch = 1726, loss = 0.001279
epoch = 1727, loss = 0.001118
epoch = 1728, loss = 0.001144
epoch = 1729, loss = 0.001168
epoch = 1730, loss = 0.001109
epoch = 1731, loss = 0.001089
epoch = 1732, loss = 0.001114
epoch = 1733, loss = 0.001170
epoch = 1734, loss = 0.001140
epoch = 1735, loss = 0.001145
epoch = 1736, loss = 0.001122
epoch = 1737, loss = 0.001073
epoch = 1738, loss = 0.001092
epoch = 1739, loss = 0.001201
epoch = 1740, loss = 0.001167
epoch = 1741, loss = 0.001150
epoch = 1742, loss = 0.001145
epoch = 1743, loss = 0.001113
epoch = 1744, loss = 0.001141
epoch = 1745, loss = 0.001179
epoch = 1746, loss = 0.001187
epoch = 1747, loss = 0.001124
epoch = 1748, loss = 0.001165
epoch = 1749, loss = 0.001204
epoch = 1750, loss = 0.001164
epoch = 1751, loss = 0.001140
epoch = 1752, loss = 0.001131
epoch = 17

In [44]:
def get_result():
    a =[]
    for parameter in model_best.parameters():
        a.append(parameter.data)
    delta = a[0]
    r = torch.mm(a[1].permute(1, 0), a[2].permute(1, 0))
    return delta, r.reshape(1, -1)

In [45]:
delta, r = get_result()
print(delta)
print (r)

tensor([0.0009])
tensor([[ 0.1465,  0.1395,  0.1210,  0.0938,  0.0624,  0.0316,  0.0055, -0.0131,
         -0.0233, -0.0261, -0.0232, -0.0172, -0.0104, -0.0045, -0.0005,  0.0017,
          0.0022,  0.0018,  0.0011,  0.0006]])


In [46]:
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=W.shape[0], shuffle=True)
with torch.no_grad():
    for w in test_loader:
       
        out = model(w)
   
        R = out[:,0].reshape(-1, 1)
        delta = out[:,1].reshape(-1, 1)
        v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)

        v = F.relu(v)

        reg = reg_lambda(w, ws, wp, tradeoff)
        loss = (torch.sum(v*reg) + torch.sum(delta))/w.size(0)
        loss = customized_loss(w, out, alpha, ws, wp, tradeoff)
        print (loss)
        break

tensor(0.0010)


In [47]:
torch.save(model_best, 'model_lowpass_attenuation.pth')

  "type " + obj.__name__ + ". It won't be checked "


In [13]:
model = torch.load('model_lowpass_attenuation.pth')