In [1]:
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import Dataset, DataLoader
pi = math.pi

In [2]:
class Mynet(nn.Module):
    """Two stacked bias-free linear layers on a cosine basis, plus a learnable scalar.

    The input frequencies are expanded by ``tocosines`` into ``n`` features,
    passed through ``fc1`` (n -> 64) and ``fc_last`` (64 -> 1) with no
    non-linearity in between, and the scalar parameter ``alpha`` is appended
    as a second output column.
    """

    def __init__(self, n):
        """n: number of cosine basis terms fed to the first linear layer."""
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        # torch.Tensor(1) returns *uninitialised* memory, so the former
        # `torch.Tensor(1) + 1` started alpha at an arbitrary (possibly
        # NaN/inf) value. Initialise it to exactly 1 instead.
        self.alpha = Parameter(torch.ones(1))

    def forward(self, w):
        """Return a (batch, 2) tensor: column 0 is the network output, column 1 is alpha."""
        a = tocosines(w, self.n)
        out = self.fc1(a)
        r = self.fc_last(out)
        # Broadcast the single alpha value to one entry per sample; expand()
        # keeps it on alpha's device and still back-propagates into alpha.
        alpha_col = self.alpha.expand(r.size(0), 1)
        return torch.cat((r, alpha_col), 1)

In [3]:
def tocosines(w, n):
    '''
    Evaluate the scaled cosine basis c_k * cos(k * w) for k = 0 .. n-1.

    w: tensor holding one frequency per sample, shape (batch,) or (batch, 1)
    n: number of basis terms
    returns: batch*n tensor; column 0 is scaled by 1, every other column by 2
    '''
    # Accept both a flat vector and a column vector (the old torch.ger call
    # rejected anything that was not 1-D).
    w = w.reshape(-1)
    # Build the helper tensors with w's dtype/device so the product does not
    # mix devices; integer inputs fall back to the default float dtype.
    dtype = w.dtype if w.is_floating_point() else torch.get_default_dtype()
    k = torch.arange(n, dtype=dtype, device=w.device)
    scale = torch.full((n,), 2.0, dtype=dtype, device=w.device)
    scale[:1] = 1.0  # the k = 0 (constant) term is not doubled
    # Outer product w_i * k via broadcasting: (batch, 1) * (n,) -> (batch, n).
    return torch.cos(w.to(dtype).unsqueeze(1) * k) * scale

In [4]:
def customized_loss(w, out, wa, wb, tradeoff):
    '''
    Weighted sum of three hinge terms plus the sum of alpha, averaged over the batch.

    w: batch of frequencies (flattened to a batch*1 column here)
    out: batch * 2, column 0 is R and column 1 is alpha
    wa, wb, tradeoff: forwarded unchanged to reg_lambda

    returns: scalar tensor
    '''
    w_col = w.reshape(-1, 1)
    R = out[:, 0].reshape(-1, 1)
    alpha = out[:, 1].reshape(-1, 1)

    # R*w/alpha appears in the first two hinge terms.
    ratio = R * w_col / alpha
    hinge_terms = torch.cat((ratio - alpha, 1 / alpha - ratio, -R), 1)
    violations = F.relu(hinge_terms)

    weights = reg_lambda(w, wa, wb, tradeoff)
    batch = w.size(0)
    return (torch.sum(violations * weights) + torch.sum(alpha)) / batch

In [5]:
def reg_lambda(w, wa, wb, tradeoff):
    """Per-sample weights for the three hinge terms of the loss.

    w: 1-D tensor of frequencies
    wa, wb: open interval (wa, wb) in which the first two terms are active
    tradeoff: multiplier applied to every weight

    returns: batch*3 tensor; columns 0 and 1 are ``tradeoff`` where
    wa < w < wb and 0 elsewhere, column 2 is always ``tradeoff``.
    """
    in_band = (w > wa) & (w < wb)
    weights = torch.zeros(w.size(0), 3)
    weights[in_band, :2] = 1
    weights[:, 2] = 1
    return weights * tradeoff #batch*3

In [6]:
def D(w):
    """Return ``w`` raised to the power -1/2 (the inverse square root)."""
    exponent = -0.5
    return w ** exponent

In [7]:
class MyDataset(Dataset):
    """Dataset view over an indexable collection; each item is one element of ``W``."""

    def __init__(self, W):
        # Stored as-is; no copy or conversion is made.
        self.W = W

    def __len__(self):
        """Number of samples, i.e. ``len(W)``."""
        return len(self.W)

    def __getitem__(self, index):
        """Return the element of ``W`` at ``index``."""
        return self.W[index]

In [8]:
# Interval bounds and penalty weight handed to customized_loss / reg_lambda.
wa, wb = 0.01, pi
tradeoff = 20

# n: number of cosine basis terms; m: number of sampled frequencies.
n = 50
m = 20 * n

# Uniform float32 grid of m frequencies on [0.01*pi, pi].
# (An earlier variant also sampled [0, 0.01*pi] with n*5 points.)
W = torch.from_numpy(np.linspace(0.01 * pi, pi, m).astype('float32'))

In [9]:
# Mini-batches of `batch_size` frequencies drawn from W, reshuffled on every pass.
batch_size = 4
train_loader = DataLoader(MyDataset(W), batch_size=batch_size, shuffle=True)

In [10]:
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.000001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[400, 600, 800], gamma=0.2)
epochs = 1000

# `model_best` must be an independent copy: plain assignment only aliases the
# live model, so it would always equal the weights of the final epoch.
model_best = copy.deepcopy(model)
# Lowest summed epoch loss seen so far (not merely the previous epoch's loss).
best_loss = float('inf')
num_batches = len(train_loader)

for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for w in train_loader:

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = model(w)

        loss = customized_loss(w, out, wa, wb, tradeoff)
        loss.backward()
        optimizer.step()

        # accumulate statistics
        running_loss += loss.item()

    # Advance the LR schedule once per epoch, after the optimizer updates.
    # The epoch argument is deprecated; counting calls gives the same
    # milestones (the LR drops at the start of epochs 400, 600 and 800).
    scheduler.step()

    # Current alpha (second output column), then the mean loss per batch.
    # Divide by the batch count, not by the last batch index.
    print (out[0,1].item())
    print('epoch = %d, loss = %.6f' %(epoch, running_loss/num_batches))

    # Snapshot the weights whenever the epoch loss improves on the best so far.
    # running_loss is accumulated while the weights are changing, so this is
    # the training loss over the epoch, not a separate evaluation.
    if running_loss < best_loss:
        best_loss = running_loss
        model_best = copy.deepcopy(model)

print('Finished Training')

1.0564519166946411
epoch = 0, loss = 25.124162
1.0956920385360718
epoch = 1, loss = 19.333559
1.123392105102539
epoch = 2, loss = 15.997992
1.1462589502334595
epoch = 3, loss = 13.641335
1.1650151014328003
epoch = 4, loss = 11.838157
1.1811416149139404
epoch = 5, loss = 10.497504
1.1952003240585327
epoch = 6, loss = 9.509873
1.2070493698120117
epoch = 7, loss = 8.736655
1.218227505683899
epoch = 8, loss = 8.157906
1.2283061742782593
epoch = 9, loss = 7.728696
1.23769211769104
epoch = 10, loss = 7.368098
1.246625304222107
epoch = 11, loss = 7.048340
1.2549861669540405
epoch = 12, loss = 6.765895
1.2627675533294678
epoch = 13, loss = 6.501524
1.2699954509735107
epoch = 14, loss = 6.261445
1.276777744293213
epoch = 15, loss = 6.037513
1.2831367254257202
epoch = 16, loss = 5.834662
1.289258599281311
epoch = 17, loss = 5.650653
1.2949336767196655
epoch = 18, loss = 5.469536
1.3000565767288208
epoch = 19, loss = 5.293510
1.3049782514572144
epoch = 20, loss = 5.135945
1.3095695972442627
epoch

1.2801933288574219
epoch = 175, loss = 1.938814
1.2789703607559204
epoch = 176, loss = 1.935059
1.2776589393615723
epoch = 177, loss = 1.931388
1.2764768600463867
epoch = 178, loss = 1.927769
1.2753080129623413
epoch = 179, loss = 1.924338
1.274073600769043
epoch = 180, loss = 1.920778
1.2728337049484253
epoch = 181, loss = 1.917361
1.2716548442840576
epoch = 182, loss = 1.914014
1.2704236507415771
epoch = 183, loss = 1.910569
1.2691529989242554
epoch = 184, loss = 1.907123
1.2678812742233276
epoch = 185, loss = 1.903771
1.2666085958480835
epoch = 186, loss = 1.900381
1.2654920816421509
epoch = 187, loss = 1.897174
1.2641704082489014
epoch = 188, loss = 1.893936
1.2629367113113403
epoch = 189, loss = 1.890717
1.2616393566131592
epoch = 190, loss = 1.887497
1.2603158950805664
epoch = 191, loss = 1.884349
1.259161353111267
epoch = 192, loss = 1.881214
1.257925033569336
epoch = 193, loss = 1.878027
1.2567285299301147
epoch = 194, loss = 1.874983
1.2553914785385132
epoch = 195, loss = 1.87

1.080201268196106
epoch = 347, loss = 1.505123
1.0793803930282593
epoch = 348, loss = 1.503280
1.078829288482666
epoch = 349, loss = 1.501444
1.0779635906219482
epoch = 350, loss = 1.499691
1.077100157737732
epoch = 351, loss = 1.497978
1.0764585733413696
epoch = 352, loss = 1.496298
1.0758129358291626
epoch = 353, loss = 1.494597
1.0748753547668457
epoch = 354, loss = 1.492790
1.0743637084960938
epoch = 355, loss = 1.491052
1.0734537839889526
epoch = 356, loss = 1.489421
1.0728733539581299
epoch = 357, loss = 1.487506
1.0721721649169922
epoch = 358, loss = 1.485781
1.0713887214660645
epoch = 359, loss = 1.484121
1.0707262754440308
epoch = 360, loss = 1.482461
1.0699771642684937
epoch = 361, loss = 1.480762
1.0693615674972534
epoch = 362, loss = 1.479109
1.0686970949172974
epoch = 363, loss = 1.477453
1.0679479837417603
epoch = 364, loss = 1.475869
1.0672529935836792
epoch = 365, loss = 1.474136
1.0664148330688477
epoch = 366, loss = 1.472500
1.06586492061615
epoch = 367, loss = 1.4708

1.0389316082000732
epoch = 519, loss = 1.394914
1.0388469696044922
epoch = 520, loss = 1.394654
1.0387612581253052
epoch = 521, loss = 1.394410
1.0386581420898438
epoch = 522, loss = 1.394176
1.0385664701461792
epoch = 523, loss = 1.393929
1.0384572744369507
epoch = 524, loss = 1.393688
1.0384092330932617
epoch = 525, loss = 1.393452
1.038345217704773
epoch = 526, loss = 1.393218
1.0383235216140747
epoch = 527, loss = 1.392990
1.0381853580474854
epoch = 528, loss = 1.392760
1.0380750894546509
epoch = 529, loss = 1.392513
1.0380080938339233
epoch = 530, loss = 1.392263
1.0378755331039429
epoch = 531, loss = 1.392019
1.0378332138061523
epoch = 532, loss = 1.391795
1.037711262702942
epoch = 533, loss = 1.391548
1.0376522541046143
epoch = 534, loss = 1.391301
1.0375181436538696
epoch = 535, loss = 1.391065
1.0374726057052612
epoch = 536, loss = 1.390836
1.0373554229736328
epoch = 537, loss = 1.390594
1.0372768640518188
epoch = 538, loss = 1.390360
1.0372315645217896
epoch = 539, loss = 1.3

1.0306792259216309
epoch = 691, loss = 1.371653
1.0306649208068848
epoch = 692, loss = 1.371607
1.0306696891784668
epoch = 693, loss = 1.371569
1.0306532382965088
epoch = 694, loss = 1.371522
1.0306249856948853
epoch = 695, loss = 1.371474
1.0306071043014526
epoch = 696, loss = 1.371430
1.0306013822555542
epoch = 697, loss = 1.371386
1.030591607093811
epoch = 698, loss = 1.371343
1.0305770635604858
epoch = 699, loss = 1.371297
1.0305614471435547
epoch = 700, loss = 1.371253
1.0305448770523071
epoch = 701, loss = 1.371208
1.030524492263794
epoch = 702, loss = 1.371161
1.030514121055603
epoch = 703, loss = 1.371116
1.0305016040802002
epoch = 704, loss = 1.371073
1.030478596687317
epoch = 705, loss = 1.371028
1.0304793119430542
epoch = 706, loss = 1.370987
1.0304535627365112
epoch = 707, loss = 1.370939
1.0304462909698486
epoch = 708, loss = 1.370896
1.0304148197174072
epoch = 709, loss = 1.370850
1.0304042100906372
epoch = 710, loss = 1.370809
1.030388355255127
epoch = 711, loss = 1.3707

1.028863787651062
epoch = 863, loss = 1.366512
1.0288623571395874
epoch = 864, loss = 1.366505
1.0288608074188232
epoch = 865, loss = 1.366499
1.0288606882095337
epoch = 866, loss = 1.366494
1.0288602113723755
epoch = 867, loss = 1.366489
1.0288580656051636
epoch = 868, loss = 1.366483
1.0288538932800293
epoch = 869, loss = 1.366476
1.0288527011871338
epoch = 870, loss = 1.366470
1.0288519859313965
epoch = 871, loss = 1.366464
1.0288485288619995
epoch = 872, loss = 1.366458
1.0288459062576294
epoch = 873, loss = 1.366452
1.028845191001892
epoch = 874, loss = 1.366446
1.028842806816101
epoch = 875, loss = 1.366441
1.0288387537002563
epoch = 876, loss = 1.366435
1.0288385152816772
epoch = 877, loss = 1.366429
1.0288382768630981
epoch = 878, loss = 1.366424
1.0288358926773071
epoch = 879, loss = 1.366418
1.0288337469100952
epoch = 880, loss = 1.366412
1.0288288593292236
epoch = 881, loss = 1.366406
1.0288249254226685
epoch = 882, loss = 1.366400
1.0288262367248535
epoch = 883, loss = 1.36

In [11]:
def get_result():
    """Read alpha and the combined linear map out of the global ``model_best``.

    Parameters are taken in the order ``model_best.parameters()`` yields them:
    the first is returned as alpha, and the transposes of the second and third
    are matrix-multiplied to collapse the two weight matrices into one.

    returns: (alpha, r) where r is reshaped to a single row
    """
    weights = [p.data for p in model_best.parameters()]
    alpha = weights[0]
    first, last = weights[1], weights[2]
    r = torch.mm(first.permute(1, 0), last.permute(1, 0))
    return alpha, r.reshape(1, -1)

In [12]:
# Print alpha, then the single-row tensor r returned by get_result.
alpha, r = get_result()
print(alpha, r, sep='\n')

tensor([1.0286])
tensor([[ 1.2266,  0.6780,  0.4175,  0.2939,  0.1949,  0.1300,  0.0741,  0.0352,
         -0.0006, -0.0305, -0.0551, -0.0728, -0.0860, -0.0973, -0.1092, -0.1169,
         -0.1193, -0.1212, -0.1231, -0.1242, -0.1232, -0.1197, -0.1157, -0.1117,
         -0.1077, -0.1034, -0.0970, -0.0899, -0.0851, -0.0794, -0.0728, -0.0663,
         -0.0607, -0.0548, -0.0498, -0.0432, -0.0389, -0.0342, -0.0287, -0.0246,
         -0.0223, -0.0173, -0.0140, -0.0127, -0.0094, -0.0065, -0.0054, -0.0042,
         -0.0022, -0.0012]])


In [16]:
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=20, shuffle=True)
with torch.no_grad():
    # Inspect one batch only; shuffle=True makes it a random one.
    w = next(iter(test_loader))
    print(w)
    out = model(w)
    print (out)
    # Rebuild the three hinge terms used inside customized_loss so that the
    # individual violations can be printed.
    R = out[:,0].reshape(-1, 1)
    alpha = out[:,1].reshape(-1, 1)
    v = torch.cat((R*w.reshape(-1, 1)/alpha-alpha, 1/alpha-R*w.reshape(-1, 1)/alpha, -R), 1)
    v = F.relu(v)
    print(v)
    # The loss comes from the shared helper; the hand-written duplicate that
    # was immediately overwritten by this call has been dropped.
    loss = customized_loss(w, out, wa, wb, tradeoff)

tensor([0.1124, 2.0270, 0.4704, 0.3147, 1.8122, 0.3552, 0.5762, 0.7724, 0.8720,
        1.7873, 2.4971, 3.0482, 1.1802, 2.8894, 2.0021, 1.2830, 2.6653, 1.2207,
        0.4579, 0.5825])
tensor([[7.3938, 1.0286],
        [0.5037, 1.0286],
        [2.2412, 1.0286],
        [3.3375, 1.0286],
        [0.5835, 1.0286],
        [2.9664, 1.0286],
        [1.8247, 1.0286],
        [1.3690, 1.0286],
        [1.2132, 1.0286],
        [0.5853, 1.0286],
        [0.4236, 1.0286],
        [0.3470, 1.0286],
        [0.8749, 1.0286],
        [0.3646, 1.0286],
        [0.5173, 1.0286],
        [0.7799, 1.0286],
        [0.3964, 1.0286],
        [0.8319, 1.0286],
        [2.2983, 1.0286],
        [1.8062, 1.0286]])
tensor([[0.0000, 0.1645, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.000

In [None]:
# Serialises the whole model_best module object (not just a state_dict) to disk.
# NOTE(review): loading this file presumably needs the Mynet class to be importable — confirm.
torch.save(model_best, 'model_chebychev_approx.pth')

In [None]:
# Scratch check of a two-segment grid: [0, 0.01*pi] followed by [0.01*pi, pi].
# Note: the last assignment rebinds the global W to this small NumPy array.
W1 = np.linspace(0, 0.01 * pi, num=2).astype('float32')
print(W1)
W2 = np.linspace(0.01 * pi, pi, num=4).astype('float32')
print(W2)
W = np.concatenate((W1, W2))
print(W)