In [1]:
import numpy as np
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
pi = math.pi

In [2]:
class Mynet(nn.Module):
    def __init__(self, n):
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        self.delta = Parameter(torch.Tensor(1))
        
    def forward(self, w):
        a = tocosines(w, self.n)
        out = self.fc1(a)
        r = self.fc_last(out)
        out = torch.cat((r, torch.ones(r.size(0),1)*self.delta),1)
        return out

In [3]:
def tocosines(w, n):
    '''
    w: batch*1
    '''
    t = torch.arange(0, n, 1.0)
    c = torch.ones(1, n)*2
    c[0,0] = 1
    out = torch.cos(torch.ger(w, t))*c
    return out #batch*n

In [5]:
def customized_loss(w, out, alpha, ws, wp, tradeoff):
    '''
    w: batch * n
    out: batch * 2
    
    '''
    R = out[:,0].reshape(-1, 1)
    delta = out[:,1].reshape(-1, 1)
    v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)
    v = F.relu(v)
    reg = reg_lambda(w, ws, wp, tradeoff)
#     print (reg)
    loss = (torch.sum(v*reg) + torch.sum(delta))/w.size(0)
    return loss

In [6]:
def reg_lambda(w, ws, wp, tradeoff):
    reg = torch.zeros(w.size(0), 4)
    ind1 = ((w < wp).nonzero())
    ind2 = ((w > ws).nonzero())
    reg[ind1, 0] = 1*tradeoff[0]
    reg[ind1, 1] = 1*tradeoff[1]
    reg[ind2, 2] = 1*tradeoff[2]
    reg[:, 3] = 1*tradeoff[3]
    return reg #batch*4

In [7]:
class MyDataset(Dataset):
    def __init__(self, W):
        self.W = W

    def __len__(self):
        return len(self.W)

    def __getitem__(self, index):
        w = self.W[index]
        return w

In [184]:
alpha = 1.1
wp = 0.12*pi
ws = 0.24*pi
tradeoff = [10, 10, 20, 20]
n = 20
m = n*40
W = np.linspace(0, pi*1.1, m).astype('float32')
W = torch.from_numpy(W)

In [185]:
batch_size = 100
train_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=batch_size, shuffle=True)

In [186]:
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[400, 800, 1500, 2000], gamma=0.2)
epochs = 2500

model_best = model
previous_loss = 10000
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    scheduler.step(epoch)
    for i, w in enumerate(train_loader, 0):

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = model(w)
        loss = customized_loss(w, out, alpha, ws, wp, tradeoff)
        loss.backward()
        optimizer.step()

        # print statistics
    print('epoch = %d, loss = %.6f' %(epoch, loss))
    
    if running_loss < previous_loss:
        model_best = model
    
    previous_loss = running_loss
    
print('Finished Training')

epoch = 0, loss = 4.804831
epoch = 1, loss = 4.599703
epoch = 2, loss = 4.002461
epoch = 3, loss = 3.643502
epoch = 4, loss = 3.200828
epoch = 5, loss = 3.064402
epoch = 6, loss = 3.031597
epoch = 7, loss = 2.519829
epoch = 8, loss = 2.300697
epoch = 9, loss = 1.282029
epoch = 10, loss = 1.734572
epoch = 11, loss = 1.060347
epoch = 12, loss = 0.893185
epoch = 13, loss = 1.199900
epoch = 14, loss = 0.554724
epoch = 15, loss = 0.498867
epoch = 16, loss = 0.454237
epoch = 17, loss = 0.445779
epoch = 18, loss = 0.474876
epoch = 19, loss = 0.509845
epoch = 20, loss = 0.424407
epoch = 21, loss = 0.252635
epoch = 22, loss = 0.437786
epoch = 23, loss = 0.358960
epoch = 24, loss = 0.293566
epoch = 25, loss = 0.445366
epoch = 26, loss = 0.198551
epoch = 27, loss = 0.383895
epoch = 28, loss = 0.342723
epoch = 29, loss = 0.181380
epoch = 30, loss = 0.230176
epoch = 31, loss = 0.513438
epoch = 32, loss = 0.234315
epoch = 33, loss = 0.137088
epoch = 34, loss = 0.529957
epoch = 35, loss = 0.211857
ep

epoch = 295, loss = 0.002088
epoch = 296, loss = 0.001541
epoch = 297, loss = 0.001411
epoch = 298, loss = 0.001126
epoch = 299, loss = 0.002200
epoch = 300, loss = 0.001267
epoch = 301, loss = 0.002157
epoch = 302, loss = 0.003051
epoch = 303, loss = 0.002977
epoch = 304, loss = 0.002963
epoch = 305, loss = 0.001884
epoch = 306, loss = 0.001729
epoch = 307, loss = 0.001322
epoch = 308, loss = 0.000983
epoch = 309, loss = 0.004858
epoch = 310, loss = 0.002111
epoch = 311, loss = 0.003884
epoch = 312, loss = 0.002522
epoch = 313, loss = 0.002340
epoch = 314, loss = 0.002043
epoch = 315, loss = 0.002034
epoch = 316, loss = 0.001311
epoch = 317, loss = 0.001351
epoch = 318, loss = 0.001343
epoch = 319, loss = 0.002725
epoch = 320, loss = 0.001782
epoch = 321, loss = 0.001770
epoch = 322, loss = 0.001538
epoch = 323, loss = 0.001569
epoch = 324, loss = 0.004996
epoch = 325, loss = 0.002561
epoch = 326, loss = 0.002732
epoch = 327, loss = 0.002220
epoch = 328, loss = 0.001969
epoch = 329, l

epoch = 591, loss = 0.000521
epoch = 592, loss = 0.000494
epoch = 593, loss = 0.000470
epoch = 594, loss = 0.000780
epoch = 595, loss = 0.000359
epoch = 596, loss = 0.000443
epoch = 597, loss = 0.000430
epoch = 598, loss = 0.001064
epoch = 599, loss = 0.000571
epoch = 600, loss = 0.000514
epoch = 601, loss = 0.000545
epoch = 602, loss = 0.000443
epoch = 603, loss = 0.000631
epoch = 604, loss = 0.000721
epoch = 605, loss = 0.000601
epoch = 606, loss = 0.000754
epoch = 607, loss = 0.000629
epoch = 608, loss = 0.000543
epoch = 609, loss = 0.000513
epoch = 610, loss = 0.000480
epoch = 611, loss = 0.000486
epoch = 612, loss = 0.000551
epoch = 613, loss = 0.000763
epoch = 614, loss = 0.001032
epoch = 615, loss = 0.001174
epoch = 616, loss = 0.000619
epoch = 617, loss = 0.000718
epoch = 618, loss = 0.000806
epoch = 619, loss = 0.000511
epoch = 620, loss = 0.000739
epoch = 621, loss = 0.000717
epoch = 622, loss = 0.000543
epoch = 623, loss = 0.000597
epoch = 624, loss = 0.000505
epoch = 625, l

epoch = 883, loss = 0.000269
epoch = 884, loss = 0.000335
epoch = 885, loss = 0.000406
epoch = 886, loss = 0.000394
epoch = 887, loss = 0.000243
epoch = 888, loss = 0.000277
epoch = 889, loss = 0.000292
epoch = 890, loss = 0.000239
epoch = 891, loss = 0.000291
epoch = 892, loss = 0.000299
epoch = 893, loss = 0.000283
epoch = 894, loss = 0.000242
epoch = 895, loss = 0.000629
epoch = 896, loss = 0.000304
epoch = 897, loss = 0.000478
epoch = 898, loss = 0.000456
epoch = 899, loss = 0.000313
epoch = 900, loss = 0.000352
epoch = 901, loss = 0.000248
epoch = 902, loss = 0.000253
epoch = 903, loss = 0.000288
epoch = 904, loss = 0.000271
epoch = 905, loss = 0.000272
epoch = 906, loss = 0.000263
epoch = 907, loss = 0.000298
epoch = 908, loss = 0.000276
epoch = 909, loss = 0.000464
epoch = 910, loss = 0.000445
epoch = 911, loss = 0.000237
epoch = 912, loss = 0.000228
epoch = 913, loss = 0.000336
epoch = 914, loss = 0.000250
epoch = 915, loss = 0.000372
epoch = 916, loss = 0.000274
epoch = 917, l

epoch = 1170, loss = 0.000233
epoch = 1171, loss = 0.000416
epoch = 1172, loss = 0.000341
epoch = 1173, loss = 0.000264
epoch = 1174, loss = 0.000553
epoch = 1175, loss = 0.000257
epoch = 1176, loss = 0.000269
epoch = 1177, loss = 0.000235
epoch = 1178, loss = 0.000314
epoch = 1179, loss = 0.000501
epoch = 1180, loss = 0.000239
epoch = 1181, loss = 0.000469
epoch = 1182, loss = 0.000182
epoch = 1183, loss = 0.000462
epoch = 1184, loss = 0.000673
epoch = 1185, loss = 0.000262
epoch = 1186, loss = 0.000421
epoch = 1187, loss = 0.000215
epoch = 1188, loss = 0.000261
epoch = 1189, loss = 0.000205
epoch = 1190, loss = 0.000216
epoch = 1191, loss = 0.000286
epoch = 1192, loss = 0.000279
epoch = 1193, loss = 0.000273
epoch = 1194, loss = 0.000214
epoch = 1195, loss = 0.000297
epoch = 1196, loss = 0.000252
epoch = 1197, loss = 0.000269
epoch = 1198, loss = 0.000248
epoch = 1199, loss = 0.000274
epoch = 1200, loss = 0.000383
epoch = 1201, loss = 0.000475
epoch = 1202, loss = 0.000278
epoch = 12

epoch = 1452, loss = 0.000253
epoch = 1453, loss = 0.000299
epoch = 1454, loss = 0.000547
epoch = 1455, loss = 0.000229
epoch = 1456, loss = 0.000368
epoch = 1457, loss = 0.000244
epoch = 1458, loss = 0.000220
epoch = 1459, loss = 0.000552
epoch = 1460, loss = 0.000251
epoch = 1461, loss = 0.000366
epoch = 1462, loss = 0.000206
epoch = 1463, loss = 0.000253
epoch = 1464, loss = 0.000265
epoch = 1465, loss = 0.000469
epoch = 1466, loss = 0.000243
epoch = 1467, loss = 0.000241
epoch = 1468, loss = 0.000228
epoch = 1469, loss = 0.000348
epoch = 1470, loss = 0.000202
epoch = 1471, loss = 0.000200
epoch = 1472, loss = 0.000211
epoch = 1473, loss = 0.000222
epoch = 1474, loss = 0.000399
epoch = 1475, loss = 0.000336
epoch = 1476, loss = 0.000294
epoch = 1477, loss = 0.000186
epoch = 1478, loss = 0.000378
epoch = 1479, loss = 0.000247
epoch = 1480, loss = 0.000254
epoch = 1481, loss = 0.000337
epoch = 1482, loss = 0.000306
epoch = 1483, loss = 0.000303
epoch = 1484, loss = 0.000255
epoch = 14

epoch = 1733, loss = 0.000248
epoch = 1734, loss = 0.000474
epoch = 1735, loss = 0.000561
epoch = 1736, loss = 0.000413
epoch = 1737, loss = 0.000494
epoch = 1738, loss = 0.000228
epoch = 1739, loss = 0.000415
epoch = 1740, loss = 0.000201
epoch = 1741, loss = 0.000205
epoch = 1742, loss = 0.000394
epoch = 1743, loss = 0.000225
epoch = 1744, loss = 0.000194
epoch = 1745, loss = 0.000273
epoch = 1746, loss = 0.000427
epoch = 1747, loss = 0.000168
epoch = 1748, loss = 0.000600
epoch = 1749, loss = 0.000532
epoch = 1750, loss = 0.000167
epoch = 1751, loss = 0.000168
epoch = 1752, loss = 0.000290
epoch = 1753, loss = 0.000190
epoch = 1754, loss = 0.000247
epoch = 1755, loss = 0.000194
epoch = 1756, loss = 0.000402
epoch = 1757, loss = 0.000216
epoch = 1758, loss = 0.000213
epoch = 1759, loss = 0.000442
epoch = 1760, loss = 0.000386
epoch = 1761, loss = 0.000249
epoch = 1762, loss = 0.000253
epoch = 1763, loss = 0.000204
epoch = 1764, loss = 0.000257
epoch = 1765, loss = 0.000209
epoch = 17

epoch = 2020, loss = 0.000353
epoch = 2021, loss = 0.000222
epoch = 2022, loss = 0.000500
epoch = 2023, loss = 0.000254
epoch = 2024, loss = 0.000224
epoch = 2025, loss = 0.000335
epoch = 2026, loss = 0.000224
epoch = 2027, loss = 0.000246
epoch = 2028, loss = 0.000358
epoch = 2029, loss = 0.000372
epoch = 2030, loss = 0.000410
epoch = 2031, loss = 0.000206
epoch = 2032, loss = 0.000227
epoch = 2033, loss = 0.000326
epoch = 2034, loss = 0.000228
epoch = 2035, loss = 0.000463
epoch = 2036, loss = 0.000447
epoch = 2037, loss = 0.000174
epoch = 2038, loss = 0.000250
epoch = 2039, loss = 0.000682
epoch = 2040, loss = 0.000273
epoch = 2041, loss = 0.000198
epoch = 2042, loss = 0.000366
epoch = 2043, loss = 0.000222
epoch = 2044, loss = 0.000312
epoch = 2045, loss = 0.000321
epoch = 2046, loss = 0.000195
epoch = 2047, loss = 0.000476
epoch = 2048, loss = 0.000489
epoch = 2049, loss = 0.000311
epoch = 2050, loss = 0.000182
epoch = 2051, loss = 0.000156
epoch = 2052, loss = 0.000366
epoch = 20

epoch = 2294, loss = 0.000270
epoch = 2295, loss = 0.000171
epoch = 2296, loss = 0.000205
epoch = 2297, loss = 0.000214
epoch = 2298, loss = 0.000215
epoch = 2299, loss = 0.000167
epoch = 2300, loss = 0.000420
epoch = 2301, loss = 0.000376
epoch = 2302, loss = 0.000205
epoch = 2303, loss = 0.000258
epoch = 2304, loss = 0.000427
epoch = 2305, loss = 0.000398
epoch = 2306, loss = 0.000153
epoch = 2307, loss = 0.000520
epoch = 2308, loss = 0.000485
epoch = 2309, loss = 0.000169
epoch = 2310, loss = 0.000381
epoch = 2311, loss = 0.000263
epoch = 2312, loss = 0.000414
epoch = 2313, loss = 0.000306
epoch = 2314, loss = 0.000303
epoch = 2315, loss = 0.000191
epoch = 2316, loss = 0.000256
epoch = 2317, loss = 0.000185
epoch = 2318, loss = 0.000444
epoch = 2319, loss = 0.000284
epoch = 2320, loss = 0.000477
epoch = 2321, loss = 0.000306
epoch = 2322, loss = 0.000269
epoch = 2323, loss = 0.000231
epoch = 2324, loss = 0.000314
epoch = 2325, loss = 0.000263
epoch = 2326, loss = 0.000381
epoch = 23

In [187]:
def get_result():
    a =[]
    for parameter in model_best.parameters():
        a.append(parameter.data)
    delta = a[0]
    r = torch.mm(a[1].permute(1, 0), a[2].permute(1, 0))
    return delta, r.reshape(1, -1)

In [219]:
delta, r = get_result()
print('attenuation = %.4f' %delta)
print ('The autocorrelation coefficients are %s' %r[0])

attenuation = 0.0002
The autocorrelation coefficients are tensor([ 0.1499,  0.1436,  0.1260,  0.0999,  0.0693,  0.0386,  0.0116, -0.0089,
        -0.0219, -0.0276, -0.0274, -0.0234, -0.0177, -0.0118, -0.0069, -0.0034,
        -0.0014, -0.0003,  0.0000,  0.0001])


In [235]:
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=W.shape[0], shuffle=False)
with torch.no_grad():
    for w in test_loader:
        out = model(w)
        R = out[:,0].reshape(-1, 1)
        delta = out[:,1].reshape(-1, 1)
        v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)
        v = F.relu(v)
        reg = reg_lambda(w, ws, wp, tradeoff)
        tradeoff_tensor = torch.from_numpy(np.asarray(tradeoff).astype('float32'))
        loss = (torch.sum(v*reg/tradeoff_tensor) + torch.sum(delta))/w.size(0)

if abs(loss - delta[0])< 1e-5:
    print ('Succeed!')
    print ('loss = %.4f' %loss)        

Succeed!
loss = 0.0002


In [236]:
torch.save(model_best, 'model_lowpass_attenuation.pth')

  "type " + obj.__name__ + ". It won't be checked "
