In [1]:
import numpy as np
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
pi = math.pi

In [2]:
class Mynet(nn.Module):
    def __init__(self, n):
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        self.alpha = Parameter(torch.Tensor(1)+1)
        
    def forward(self, w):
        a = tocosines(w, self.n)
        out = self.fc1(a)
        r = self.fc_last(out)
        out = torch.cat((r, torch.ones(r.size(0),1)*self.alpha),1)
        return out

In [3]:
def tocosines(w, n):
    '''
    w: batch*1
    '''
    t = torch.arange(0, n, 1.0)
    c = torch.ones(1, n)*2
    c[0,0] = 1
    out = torch.cos(torch.ger(w, t))*c
    return out #batch*n

In [4]:
def customized_loss(w, out, wa, wb, tradeoff):
    '''
    w: batch * n
    out: batch * 2
    
    '''
    R = out[:,0].reshape(-1, 1)
    alpha = out[:,1].reshape(-1, 1)
    v = torch.cat((R*w.reshape(-1, 1)/alpha-alpha, 1/alpha-R*w.reshape(-1, 1)/alpha, -R), 1)
    v = F.relu(v)
    reg = reg_lambda(w, wa, wb, tradeoff)
    loss = (torch.sum(v*reg) + torch.sum(alpha))/w.size(0)
    return loss

In [5]:
def reg_lambda(w, wa, wb, tradeoff):
    reg = torch.zeros(w.size(0), 3)
    a = ((w > wa) * (w <wb))
    ind1 = a.nonzero()
    reg[ind1, 0] = 1*tradeoff[0]
    reg[ind1, 1] = 1*tradeoff[1]
    reg[:, 2] = 1*tradeoff[2]
    return reg #batch*3

In [6]:
def D(w):
    return w**(-1/2)

In [7]:
class MyDataset(Dataset):
    def __init__(self, W):
        self.W = W

    def __len__(self):
        return len(self.W)

    def __getitem__(self, index):
        w = self.W[index]
        return w

In [8]:
wa = 0.01*pi
wb = pi
tradeoff = [50, 80, 20]
n = 50
m = n*60
# W1 = np.linspace(0, 0.01*pi, n*5).astype('float32')
# W2 = np.linspace(0.01*pi, pi, m).astype('float32')
W = np.linspace(0.01*pi, pi, m).astype('float32')
W = torch.from_numpy(W)

In [9]:
batch_size = 4
train_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=batch_size, shuffle=True)

In [10]:
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[200, 400], gamma=0.2)
epochs = 700

model_best = model
previous_loss = 10000
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    scheduler.step(epoch)
    for i, w in enumerate(train_loader, 0):

        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        out = model(w)
    
        loss = customized_loss(w, out, wa, wb, tradeoff)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    print (out[0,1].item())
    print('epoch = %d, loss = %.6f' %(epoch, running_loss/i))
    
    if running_loss < previous_loss:
        model_best = model
    
    previous_loss = running_loss
    
print('Finished Training')

1.6184911727905273
epoch = 0, loss = 12.325908
1.6479922533035278
epoch = 1, loss = 3.907630
1.6468271017074585
epoch = 2, loss = 3.215095
1.6341311931610107
epoch = 3, loss = 2.839537
1.6059870719909668
epoch = 4, loss = 2.580584
1.5682964324951172
epoch = 5, loss = 2.387726
1.5312079191207886
epoch = 6, loss = 2.221573
1.5348410606384277
epoch = 7, loss = 2.237031
1.5733705759048462
epoch = 8, loss = 2.302593
1.5760977268218994
epoch = 9, loss = 2.218486
1.5751876831054688
epoch = 10, loss = 2.119332
1.5547523498535156
epoch = 11, loss = 2.027841
1.4991416931152344
epoch = 12, loss = 1.877777
1.5508815050125122
epoch = 13, loss = 2.265578
1.4978114366531372
epoch = 14, loss = 1.837743
1.579820156097412
epoch = 15, loss = 2.352523
1.5233495235443115
epoch = 16, loss = 1.834009
1.463180422782898
epoch = 17, loss = 1.765666
1.5353312492370605
epoch = 18, loss = 2.076387
1.51995849609375
epoch = 19, loss = 1.956784
1.4900691509246826
epoch = 20, loss = 1.829283
1.4348828792572021
epoch =

1.3166615962982178
epoch = 174, loss = 1.408208
1.3041125535964966
epoch = 175, loss = 1.444276
1.3587645292282104
epoch = 176, loss = 1.615563
1.3777581453323364
epoch = 177, loss = 1.574309
1.3356566429138184
epoch = 178, loss = 1.453227
1.3583903312683105
epoch = 179, loss = 1.620087
1.310436725616455
epoch = 180, loss = 1.428665
1.3119231462478638
epoch = 181, loss = 1.409857
1.382962942123413
epoch = 182, loss = 1.772756
1.3391879796981812
epoch = 183, loss = 1.481420
1.2713549137115479
epoch = 184, loss = 1.347826
1.3546026945114136
epoch = 185, loss = 1.815138
1.3207064867019653
epoch = 186, loss = 1.435616
1.2681241035461426
epoch = 187, loss = 1.361279
1.3253811597824097
epoch = 188, loss = 1.694007
1.3255232572555542
epoch = 189, loss = 1.501014
1.2636536359786987
epoch = 190, loss = 1.353386
1.3248451948165894
epoch = 191, loss = 1.644923
1.3596352338790894
epoch = 192, loss = 1.558907
1.3093465566635132
epoch = 193, loss = 1.396493
1.3121143579483032
epoch = 194, loss = 1.4

1.0685304403305054
epoch = 346, loss = 1.103257
1.068331241607666
epoch = 347, loss = 1.104163
1.0777719020843506
epoch = 348, loss = 1.157681
1.0946619510650635
epoch = 349, loss = 1.187885
1.0804786682128906
epoch = 350, loss = 1.116529
1.0726367235183716
epoch = 351, loss = 1.105067
1.069358468055725
epoch = 352, loss = 1.104607
1.0671309232711792
epoch = 353, loss = 1.099966
1.0920575857162476
epoch = 354, loss = 1.213919
1.0815976858139038
epoch = 355, loss = 1.128716
1.0888748168945312
epoch = 356, loss = 1.182918
1.0752184391021729
epoch = 357, loss = 1.108504
1.0999584197998047
epoch = 358, loss = 1.250237
1.0894005298614502
epoch = 359, loss = 1.128817
1.0772371292114258
epoch = 360, loss = 1.108040
1.0740338563919067
epoch = 361, loss = 1.102394
1.0707061290740967
epoch = 362, loss = 1.100415
1.086457371711731
epoch = 363, loss = 1.203523
1.106414556503296
epoch = 364, loss = 1.225920
1.101428747177124
epoch = 365, loss = 1.173525
1.0903828144073486
epoch = 366, loss = 1.1303

1.0493354797363281
epoch = 518, loss = 1.080459
1.0496944189071655
epoch = 519, loss = 1.080560
1.049414873123169
epoch = 520, loss = 1.080597
1.049896240234375
epoch = 521, loss = 1.082380
1.0497660636901855
epoch = 522, loss = 1.080589
1.0495173931121826
epoch = 523, loss = 1.080505
1.049727439880371
epoch = 524, loss = 1.080498
1.0498013496398926
epoch = 525, loss = 1.080526
1.0496175289154053
epoch = 526, loss = 1.080481
1.0493574142456055
epoch = 527, loss = 1.080398
1.0496866703033447
epoch = 528, loss = 1.080408
1.049369215965271
epoch = 529, loss = 1.080380
1.0493546724319458
epoch = 530, loss = 1.080397
1.049257755279541
epoch = 531, loss = 1.080360
1.0499738454818726
epoch = 532, loss = 1.080627
1.0495781898498535
epoch = 533, loss = 1.080431
1.0500961542129517
epoch = 534, loss = 1.080834
1.0499777793884277
epoch = 535, loss = 1.080631
1.0498597621917725
epoch = 536, loss = 1.080576
1.0497848987579346
epoch = 537, loss = 1.080556
1.0498201847076416
epoch = 538, loss = 1.0805

1.0510106086730957
epoch = 690, loss = 1.080222
1.051183819770813
epoch = 691, loss = 1.080252
1.0508981943130493
epoch = 692, loss = 1.080200
1.0508534908294678
epoch = 693, loss = 1.080161
1.0511611700057983
epoch = 694, loss = 1.080187
1.051020860671997
epoch = 695, loss = 1.080448
1.0511715412139893
epoch = 696, loss = 1.080408
1.051180124282837
epoch = 697, loss = 1.080420
1.0512416362762451
epoch = 698, loss = 1.080367
1.05112886428833
epoch = 699, loss = 1.081014
Finished Training


In [14]:
def get_result():
    a =[]
    for parameter in model_best.parameters():
        a.append(parameter.data)
    alpha = a[0]
    r = torch.mm(a[1].permute(1, 0), a[2].permute(1, 0))
    return alpha, r.reshape(1, -1)

In [15]:
alpha, r = get_result()
print(alpha)
print (r)

tensor([1.0511])
tensor([[1.8163, 1.2618, 0.9975, 0.8735, 0.7716, 0.7010, 0.6376, 0.5884, 0.5425,
         0.5052, 0.4699, 0.4391, 0.4107, 0.3850, 0.3608, 0.3392, 0.3194, 0.2996,
         0.2823, 0.2652, 0.2494, 0.2350, 0.2216, 0.2072, 0.1964, 0.1835, 0.1735,
         0.1618, 0.1521, 0.1433, 0.1328, 0.1249, 0.1171, 0.1094, 0.1018, 0.0945,
         0.0881, 0.0817, 0.0753, 0.0699, 0.0647, 0.0600, 0.0545, 0.0496, 0.0455,
         0.0413, 0.0375, 0.0340, 0.0297, 0.0183]])


In [21]:
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size = W.shape[0] , shuffle=True)
with torch.no_grad():
    for w in test_loader:
        out = model(w)
        R = out[:,0].reshape(-1, 1)
        alpha = out[:,1].reshape(-1, 1)
        v = torch.cat((R*w.reshape(-1, 1)/alpha-alpha, 1/alpha-R*w.reshape(-1, 1)/alpha, -R), 1)
        v = F.relu(v)
        reg = reg_lambda(w, wa, wb, tradeoff)
        tradeoff_tensor = torch.from_numpy(np.asarray(tradeoff).astype('float32'))
        loss = (torch.sum(v*reg/tradeoff_tensor) + torch.sum(alpha))/w.size(0)
        
if abs(loss - alpha[0])< 1e-3:
    print ('Succeed!')
    print ('loss = %.4f' %loss)        
       

Succeed!
loss = 1.0515


In [22]:
torch.save(model_best, 'model_chebychev_approx.pth')

  "type " + obj.__name__ + ". It won't be checked "
