In [1]:
import numpy as np
import torch
import torch.nn as nn
import math
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
pi = math.pi


In [2]:
def sign(x):
    y = torch.zeros(x.size())
    y[x > 0] = 1
    return y

In [3]:
class Mynet(nn.Module):
    def __init__(self, n):
        super(Mynet, self).__init__()
        self.n = n
        self.fc1 = nn.Linear(n, 64, bias=False)
        self.fc_last = nn.Linear(64, 1, bias=False)
        self.delta = Parameter(torch.Tensor(1))
        
    def forward(self, w):
        a = tocosines(w, self.n)
        out = self.fc1(a)
        r = self.fc_last(out)
        out = torch.cat((r, torch.ones(r.size(0),1)*self.delta),1)
        return out

In [4]:
def tocosines(w, n):
    '''
    w: batch*1
    '''
    t = torch.arange(0, n, 1.0)
    c = torch.ones(1, n)*2
    c[0,0] = 1
    out = torch.cos(torch.ger(w, t))*c
    return out #batch*n

In [5]:
def customized_loss(w, out, alpha, ws, wp, tradeoff):
    '''
    w: batch * n
    out: batch * 2
    
    '''
    R = out[:,0].reshape(-1, 1)
    delta = out[:,1].reshape(-1, 1)
    v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)
    v = F.relu(v)
#     print(v)
    reg = reg_lambda(w, ws, wp, tradeoff)
#     print (reg)
    loss = (torch.sum(v*reg) + torch.sum(delta))/w.size(0)
    return loss

In [6]:
def reg_lambda(w, ws, wp, tradeoff):
    reg = torch.zeros(w.size(0), 4)
    ind1 = ((w < wp).nonzero())
    ind2 = ((w > ws).nonzero())
    reg[ind1, 0] = 1
    reg[ind1, 1] = 1
    reg[ind2, 2] = 1
    reg[:, 3] = 1
    return reg*tradeoff #batch*4

In [7]:
class MyDataset(Dataset):
    def __init__(self, W):
        self.W = W

    def __len__(self):
        return len(self.W)

    def __getitem__(self, index):
        w = self.W[index]
        return w

In [8]:
alpha = 1.1
wp = 0.12*pi
ws = 0.24*pi
tradeoff = 20
n = 40
m = n*15
W = np.linspace(0, pi, m).astype('float32')
W = torch.from_numpy(W)

In [14]:
batch_size = 200
train_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=batch_size, shuffle=True)

In [15]:
model = Mynet(n)
optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[2000, 4000], gamma=0.2)
epochs = 6000

model_best = model
previous_loss = 10000
for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    scheduler.step(epoch)
    for i, w in enumerate(train_loader, 0):

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = model(w)
        loss = customized_loss(w, out, alpha, ws, wp, tradeoff)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    print('epoch = %d, loss = %.6f' %(epoch, running_loss/i))
    
    if running_loss < previous_loss:
        model_best = model
    
    previous_loss = running_loss
    
print('Finished Training')

epoch = 0, loss = 9.891221
epoch = 1, loss = 9.795806
epoch = 2, loss = 9.630456
epoch = 3, loss = 9.419255
epoch = 4, loss = 9.183475
epoch = 5, loss = 8.922299
epoch = 6, loss = 8.651001
epoch = 7, loss = 8.373721
epoch = 8, loss = 8.097283
epoch = 9, loss = 7.824198
epoch = 10, loss = 7.546487
epoch = 11, loss = 7.292489
epoch = 12, loss = 7.044749
epoch = 13, loss = 6.808045
epoch = 14, loss = 6.576346
epoch = 15, loss = 6.360705
epoch = 16, loss = 6.154924
epoch = 17, loss = 5.961677
epoch = 18, loss = 5.789044
epoch = 19, loss = 5.624713
epoch = 20, loss = 5.464049
epoch = 21, loss = 5.317451
epoch = 22, loss = 5.171974
epoch = 23, loss = 5.037196
epoch = 24, loss = 4.910702
epoch = 25, loss = 4.787076
epoch = 26, loss = 4.661668
epoch = 27, loss = 4.543801
epoch = 28, loss = 4.426100
epoch = 29, loss = 4.310467
epoch = 30, loss = 4.200254
epoch = 31, loss = 4.093387
epoch = 32, loss = 3.988854
epoch = 33, loss = 3.882932
epoch = 34, loss = 3.777977
epoch = 35, loss = 3.678701
ep

epoch = 308, loss = 0.008151
epoch = 309, loss = 0.008069
epoch = 310, loss = 0.008056
epoch = 311, loss = 0.008087
epoch = 312, loss = 0.008033
epoch = 313, loss = 0.008025
epoch = 314, loss = 0.007993
epoch = 315, loss = 0.008130
epoch = 316, loss = 0.008028
epoch = 317, loss = 0.007946
epoch = 318, loss = 0.008008
epoch = 319, loss = 0.007999
epoch = 320, loss = 0.008022
epoch = 321, loss = 0.008021
epoch = 322, loss = 0.007917
epoch = 323, loss = 0.007961
epoch = 324, loss = 0.007889
epoch = 325, loss = 0.007930
epoch = 326, loss = 0.007884
epoch = 327, loss = 0.007888
epoch = 328, loss = 0.007790
epoch = 329, loss = 0.007817
epoch = 330, loss = 0.007689
epoch = 331, loss = 0.007777
epoch = 332, loss = 0.007712
epoch = 333, loss = 0.007675
epoch = 334, loss = 0.007676
epoch = 335, loss = 0.007622
epoch = 336, loss = 0.007688
epoch = 337, loss = 0.007757
epoch = 338, loss = 0.007648
epoch = 339, loss = 0.007630
epoch = 340, loss = 0.007784
epoch = 341, loss = 0.007618
epoch = 342, l

epoch = 596, loss = 0.006366
epoch = 597, loss = 0.006524
epoch = 598, loss = 0.006479
epoch = 599, loss = 0.006520
epoch = 600, loss = 0.006452
epoch = 601, loss = 0.006573
epoch = 602, loss = 0.006541
epoch = 603, loss = 0.006419
epoch = 604, loss = 0.006424
epoch = 605, loss = 0.006478
epoch = 606, loss = 0.006359
epoch = 607, loss = 0.006392
epoch = 608, loss = 0.006402
epoch = 609, loss = 0.006337
epoch = 610, loss = 0.006305
epoch = 611, loss = 0.006343
epoch = 612, loss = 0.006304
epoch = 613, loss = 0.006375
epoch = 614, loss = 0.006241
epoch = 615, loss = 0.006386
epoch = 616, loss = 0.006213
epoch = 617, loss = 0.006314
epoch = 618, loss = 0.006336
epoch = 619, loss = 0.006222
epoch = 620, loss = 0.006287
epoch = 621, loss = 0.006301
epoch = 622, loss = 0.006266
epoch = 623, loss = 0.006297
epoch = 624, loss = 0.006305
epoch = 625, loss = 0.006298
epoch = 626, loss = 0.006284
epoch = 627, loss = 0.006329
epoch = 628, loss = 0.006279
epoch = 629, loss = 0.006392
epoch = 630, l

epoch = 882, loss = 0.005594
epoch = 883, loss = 0.005585
epoch = 884, loss = 0.005676
epoch = 885, loss = 0.005554
epoch = 886, loss = 0.005618
epoch = 887, loss = 0.005483
epoch = 888, loss = 0.005569
epoch = 889, loss = 0.005507
epoch = 890, loss = 0.005583
epoch = 891, loss = 0.005692
epoch = 892, loss = 0.005595
epoch = 893, loss = 0.005911
epoch = 894, loss = 0.005634
epoch = 895, loss = 0.005639
epoch = 896, loss = 0.005706
epoch = 897, loss = 0.005708
epoch = 898, loss = 0.005521
epoch = 899, loss = 0.005551
epoch = 900, loss = 0.005448
epoch = 901, loss = 0.005518
epoch = 902, loss = 0.005404
epoch = 903, loss = 0.005458
epoch = 904, loss = 0.005411
epoch = 905, loss = 0.005430
epoch = 906, loss = 0.005452
epoch = 907, loss = 0.005432
epoch = 908, loss = 0.005421
epoch = 909, loss = 0.005455
epoch = 910, loss = 0.005447
epoch = 911, loss = 0.005536
epoch = 912, loss = 0.005529
epoch = 913, loss = 0.005548
epoch = 914, loss = 0.005367
epoch = 915, loss = 0.005527
epoch = 916, l

epoch = 1173, loss = 0.004668
epoch = 1174, loss = 0.004835
epoch = 1175, loss = 0.004585
epoch = 1176, loss = 0.004572
epoch = 1177, loss = 0.004668
epoch = 1178, loss = 0.004657
epoch = 1179, loss = 0.004885
epoch = 1180, loss = 0.004850
epoch = 1181, loss = 0.004771
epoch = 1182, loss = 0.004786
epoch = 1183, loss = 0.004691
epoch = 1184, loss = 0.004807
epoch = 1185, loss = 0.004740
epoch = 1186, loss = 0.004789
epoch = 1187, loss = 0.004676
epoch = 1188, loss = 0.004783
epoch = 1189, loss = 0.004614
epoch = 1190, loss = 0.004644
epoch = 1191, loss = 0.004582
epoch = 1192, loss = 0.004637
epoch = 1193, loss = 0.004646
epoch = 1194, loss = 0.004610
epoch = 1195, loss = 0.004674
epoch = 1196, loss = 0.004566
epoch = 1197, loss = 0.004537
epoch = 1198, loss = 0.004658
epoch = 1199, loss = 0.004649
epoch = 1200, loss = 0.004748
epoch = 1201, loss = 0.004691
epoch = 1202, loss = 0.004670
epoch = 1203, loss = 0.004727
epoch = 1204, loss = 0.004802
epoch = 1205, loss = 0.004728
epoch = 12

epoch = 1468, loss = 0.003842
epoch = 1469, loss = 0.003862
epoch = 1470, loss = 0.003775
epoch = 1471, loss = 0.003777
epoch = 1472, loss = 0.003839
epoch = 1473, loss = 0.003808
epoch = 1474, loss = 0.003826
epoch = 1475, loss = 0.003829
epoch = 1476, loss = 0.003833
epoch = 1477, loss = 0.003743
epoch = 1478, loss = 0.003800
epoch = 1479, loss = 0.003769
epoch = 1480, loss = 0.003775
epoch = 1481, loss = 0.003734
epoch = 1482, loss = 0.003838
epoch = 1483, loss = 0.003811
epoch = 1484, loss = 0.003752
epoch = 1485, loss = 0.003842
epoch = 1486, loss = 0.003925
epoch = 1487, loss = 0.003804
epoch = 1488, loss = 0.004028
epoch = 1489, loss = 0.003933
epoch = 1490, loss = 0.003879
epoch = 1491, loss = 0.003788
epoch = 1492, loss = 0.003824
epoch = 1493, loss = 0.003907
epoch = 1494, loss = 0.003867
epoch = 1495, loss = 0.003880
epoch = 1496, loss = 0.003915
epoch = 1497, loss = 0.003896
epoch = 1498, loss = 0.004003
epoch = 1499, loss = 0.003952
epoch = 1500, loss = 0.004039
epoch = 15

epoch = 1747, loss = 0.003412
epoch = 1748, loss = 0.003422
epoch = 1749, loss = 0.003387
epoch = 1750, loss = 0.003478
epoch = 1751, loss = 0.003389
epoch = 1752, loss = 0.003471
epoch = 1753, loss = 0.003353
epoch = 1754, loss = 0.003340
epoch = 1755, loss = 0.003316
epoch = 1756, loss = 0.003396
epoch = 1757, loss = 0.003396
epoch = 1758, loss = 0.003231
epoch = 1759, loss = 0.003340
epoch = 1760, loss = 0.003190
epoch = 1761, loss = 0.003270
epoch = 1762, loss = 0.003279
epoch = 1763, loss = 0.003289
epoch = 1764, loss = 0.003240
epoch = 1765, loss = 0.003172
epoch = 1766, loss = 0.003179
epoch = 1767, loss = 0.003260
epoch = 1768, loss = 0.003288
epoch = 1769, loss = 0.003173
epoch = 1770, loss = 0.003271
epoch = 1771, loss = 0.003279
epoch = 1772, loss = 0.003153
epoch = 1773, loss = 0.003167
epoch = 1774, loss = 0.003220
epoch = 1775, loss = 0.003195
epoch = 1776, loss = 0.003152
epoch = 1777, loss = 0.003167
epoch = 1778, loss = 0.003283
epoch = 1779, loss = 0.003272
epoch = 17

epoch = 2048, loss = 0.002257
epoch = 2049, loss = 0.002248
epoch = 2050, loss = 0.002266
epoch = 2051, loss = 0.002247
epoch = 2052, loss = 0.002246
epoch = 2053, loss = 0.002250
epoch = 2054, loss = 0.002243
epoch = 2055, loss = 0.002275
epoch = 2056, loss = 0.002260
epoch = 2057, loss = 0.002272
epoch = 2058, loss = 0.002251
epoch = 2059, loss = 0.002251
epoch = 2060, loss = 0.002261
epoch = 2061, loss = 0.002242
epoch = 2062, loss = 0.002260
epoch = 2063, loss = 0.002246
epoch = 2064, loss = 0.002246
epoch = 2065, loss = 0.002240
epoch = 2066, loss = 0.002263
epoch = 2067, loss = 0.002239
epoch = 2068, loss = 0.002248
epoch = 2069, loss = 0.002261
epoch = 2070, loss = 0.002239
epoch = 2071, loss = 0.002263
epoch = 2072, loss = 0.002229
epoch = 2073, loss = 0.002250
epoch = 2074, loss = 0.002226
epoch = 2075, loss = 0.002254
epoch = 2076, loss = 0.002232
epoch = 2077, loss = 0.002272
epoch = 2078, loss = 0.002267
epoch = 2079, loss = 0.002240
epoch = 2080, loss = 0.002241
epoch = 20

epoch = 2326, loss = 0.002148
epoch = 2327, loss = 0.002148
epoch = 2328, loss = 0.002140
epoch = 2329, loss = 0.002143
epoch = 2330, loss = 0.002130
epoch = 2331, loss = 0.002148
epoch = 2332, loss = 0.002113
epoch = 2333, loss = 0.002146
epoch = 2334, loss = 0.002128
epoch = 2335, loss = 0.002107
epoch = 2336, loss = 0.002128
epoch = 2337, loss = 0.002117
epoch = 2338, loss = 0.002123
epoch = 2339, loss = 0.002116
epoch = 2340, loss = 0.002118
epoch = 2341, loss = 0.002122
epoch = 2342, loss = 0.002123
epoch = 2343, loss = 0.002114
epoch = 2344, loss = 0.002112
epoch = 2345, loss = 0.002093
epoch = 2346, loss = 0.002098
epoch = 2347, loss = 0.002098
epoch = 2348, loss = 0.002092
epoch = 2349, loss = 0.002118
epoch = 2350, loss = 0.002095
epoch = 2351, loss = 0.002100
epoch = 2352, loss = 0.002092
epoch = 2353, loss = 0.002093
epoch = 2354, loss = 0.002104
epoch = 2355, loss = 0.002111
epoch = 2356, loss = 0.002095
epoch = 2357, loss = 0.002103
epoch = 2358, loss = 0.002088
epoch = 23

epoch = 2621, loss = 0.001997
epoch = 2622, loss = 0.002029
epoch = 2623, loss = 0.001984
epoch = 2624, loss = 0.002000
epoch = 2625, loss = 0.002000
epoch = 2626, loss = 0.002019
epoch = 2627, loss = 0.002002
epoch = 2628, loss = 0.001986
epoch = 2629, loss = 0.002000
epoch = 2630, loss = 0.001993
epoch = 2631, loss = 0.001985
epoch = 2632, loss = 0.001978
epoch = 2633, loss = 0.001991
epoch = 2634, loss = 0.001990
epoch = 2635, loss = 0.001977
epoch = 2636, loss = 0.001965
epoch = 2637, loss = 0.001988
epoch = 2638, loss = 0.001974
epoch = 2639, loss = 0.001982
epoch = 2640, loss = 0.001972
epoch = 2641, loss = 0.002010
epoch = 2642, loss = 0.001977
epoch = 2643, loss = 0.001980
epoch = 2644, loss = 0.002011
epoch = 2645, loss = 0.001973
epoch = 2646, loss = 0.001981
epoch = 2647, loss = 0.001979
epoch = 2648, loss = 0.001970
epoch = 2649, loss = 0.001981
epoch = 2650, loss = 0.001993
epoch = 2651, loss = 0.002011
epoch = 2652, loss = 0.002029
epoch = 2653, loss = 0.002010
epoch = 26

epoch = 2906, loss = 0.001909
epoch = 2907, loss = 0.001879
epoch = 2908, loss = 0.001894
epoch = 2909, loss = 0.001861
epoch = 2910, loss = 0.001876
epoch = 2911, loss = 0.001874
epoch = 2912, loss = 0.001881
epoch = 2913, loss = 0.001869
epoch = 2914, loss = 0.001872
epoch = 2915, loss = 0.001872
epoch = 2916, loss = 0.001870
epoch = 2917, loss = 0.001861
epoch = 2918, loss = 0.001842
epoch = 2919, loss = 0.001854
epoch = 2920, loss = 0.001850
epoch = 2921, loss = 0.001864
epoch = 2922, loss = 0.001868
epoch = 2923, loss = 0.001863
epoch = 2924, loss = 0.001857
epoch = 2925, loss = 0.001864
epoch = 2926, loss = 0.001854
epoch = 2927, loss = 0.001875
epoch = 2928, loss = 0.001870
epoch = 2929, loss = 0.001883
epoch = 2930, loss = 0.001860
epoch = 2931, loss = 0.001876
epoch = 2932, loss = 0.001865
epoch = 2933, loss = 0.001859
epoch = 2934, loss = 0.001855
epoch = 2935, loss = 0.001868
epoch = 2936, loss = 0.001855
epoch = 2937, loss = 0.001864
epoch = 2938, loss = 0.001848
epoch = 29

epoch = 3196, loss = 0.001774
epoch = 3197, loss = 0.001749
epoch = 3198, loss = 0.001758
epoch = 3199, loss = 0.001761
epoch = 3200, loss = 0.001761
epoch = 3201, loss = 0.001795
epoch = 3202, loss = 0.001766
epoch = 3203, loss = 0.001765
epoch = 3204, loss = 0.001772
epoch = 3205, loss = 0.001769
epoch = 3206, loss = 0.001791
epoch = 3207, loss = 0.001742
epoch = 3208, loss = 0.001759
epoch = 3209, loss = 0.001739
epoch = 3210, loss = 0.001747
epoch = 3211, loss = 0.001753
epoch = 3212, loss = 0.001767
epoch = 3213, loss = 0.001753
epoch = 3214, loss = 0.001752
epoch = 3215, loss = 0.001756
epoch = 3216, loss = 0.001751
epoch = 3217, loss = 0.001727
epoch = 3218, loss = 0.001738
epoch = 3219, loss = 0.001736
epoch = 3220, loss = 0.001734
epoch = 3221, loss = 0.001747
epoch = 3222, loss = 0.001758
epoch = 3223, loss = 0.001731
epoch = 3224, loss = 0.001717
epoch = 3225, loss = 0.001723
epoch = 3226, loss = 0.001719
epoch = 3227, loss = 0.001729
epoch = 3228, loss = 0.001746
epoch = 32

epoch = 3493, loss = 0.001640
epoch = 3494, loss = 0.001626
epoch = 3495, loss = 0.001636
epoch = 3496, loss = 0.001626
epoch = 3497, loss = 0.001673
epoch = 3498, loss = 0.001642
epoch = 3499, loss = 0.001654
epoch = 3500, loss = 0.001651
epoch = 3501, loss = 0.001657
epoch = 3502, loss = 0.001643
epoch = 3503, loss = 0.001623
epoch = 3504, loss = 0.001629
epoch = 3505, loss = 0.001644
epoch = 3506, loss = 0.001631
epoch = 3507, loss = 0.001635
epoch = 3508, loss = 0.001641
epoch = 3509, loss = 0.001673
epoch = 3510, loss = 0.001639
epoch = 3511, loss = 0.001693
epoch = 3512, loss = 0.001645
epoch = 3513, loss = 0.001678
epoch = 3514, loss = 0.001664
epoch = 3515, loss = 0.001648
epoch = 3516, loss = 0.001650
epoch = 3517, loss = 0.001638
epoch = 3518, loss = 0.001650
epoch = 3519, loss = 0.001654
epoch = 3520, loss = 0.001658
epoch = 3521, loss = 0.001666
epoch = 3522, loss = 0.001678
epoch = 3523, loss = 0.001605
epoch = 3524, loss = 0.001637
epoch = 3525, loss = 0.001646
epoch = 35

epoch = 3788, loss = 0.001551
epoch = 3789, loss = 0.001540
epoch = 3790, loss = 0.001534
epoch = 3791, loss = 0.001553
epoch = 3792, loss = 0.001534
epoch = 3793, loss = 0.001552
epoch = 3794, loss = 0.001544
epoch = 3795, loss = 0.001537
epoch = 3796, loss = 0.001539
epoch = 3797, loss = 0.001529
epoch = 3798, loss = 0.001569
epoch = 3799, loss = 0.001532
epoch = 3800, loss = 0.001525
epoch = 3801, loss = 0.001513
epoch = 3802, loss = 0.001537
epoch = 3803, loss = 0.001517
epoch = 3804, loss = 0.001511
epoch = 3805, loss = 0.001526
epoch = 3806, loss = 0.001537
epoch = 3807, loss = 0.001532
epoch = 3808, loss = 0.001547
epoch = 3809, loss = 0.001511
epoch = 3810, loss = 0.001521
epoch = 3811, loss = 0.001527
epoch = 3812, loss = 0.001530
epoch = 3813, loss = 0.001542
epoch = 3814, loss = 0.001550
epoch = 3815, loss = 0.001575
epoch = 3816, loss = 0.001564
epoch = 3817, loss = 0.001553
epoch = 3818, loss = 0.001535
epoch = 3819, loss = 0.001545
epoch = 3820, loss = 0.001558
epoch = 38

epoch = 4082, loss = 0.001411
epoch = 4083, loss = 0.001409
epoch = 4084, loss = 0.001411
epoch = 4085, loss = 0.001411
epoch = 4086, loss = 0.001406
epoch = 4087, loss = 0.001409
epoch = 4088, loss = 0.001412
epoch = 4089, loss = 0.001407
epoch = 4090, loss = 0.001404
epoch = 4091, loss = 0.001409
epoch = 4092, loss = 0.001409
epoch = 4093, loss = 0.001409
epoch = 4094, loss = 0.001410
epoch = 4095, loss = 0.001407
epoch = 4096, loss = 0.001403
epoch = 4097, loss = 0.001404
epoch = 4098, loss = 0.001408
epoch = 4099, loss = 0.001408
epoch = 4100, loss = 0.001404
epoch = 4101, loss = 0.001412
epoch = 4102, loss = 0.001405
epoch = 4103, loss = 0.001405
epoch = 4104, loss = 0.001407
epoch = 4105, loss = 0.001406
epoch = 4106, loss = 0.001406
epoch = 4107, loss = 0.001408
epoch = 4108, loss = 0.001407
epoch = 4109, loss = 0.001405
epoch = 4110, loss = 0.001408
epoch = 4111, loss = 0.001407
epoch = 4112, loss = 0.001405
epoch = 4113, loss = 0.001408
epoch = 4114, loss = 0.001412
epoch = 41

epoch = 4356, loss = 0.001392
epoch = 4357, loss = 0.001391
epoch = 4358, loss = 0.001391
epoch = 4359, loss = 0.001390
epoch = 4360, loss = 0.001392
epoch = 4361, loss = 0.001389
epoch = 4362, loss = 0.001392
epoch = 4363, loss = 0.001394
epoch = 4364, loss = 0.001391
epoch = 4365, loss = 0.001392
epoch = 4366, loss = 0.001395
epoch = 4367, loss = 0.001394
epoch = 4368, loss = 0.001387
epoch = 4369, loss = 0.001389
epoch = 4370, loss = 0.001389
epoch = 4371, loss = 0.001392
epoch = 4372, loss = 0.001393
epoch = 4373, loss = 0.001389
epoch = 4374, loss = 0.001388
epoch = 4375, loss = 0.001390
epoch = 4376, loss = 0.001388
epoch = 4377, loss = 0.001389
epoch = 4378, loss = 0.001391
epoch = 4379, loss = 0.001389
epoch = 4380, loss = 0.001395
epoch = 4381, loss = 0.001389
epoch = 4382, loss = 0.001391
epoch = 4383, loss = 0.001393
epoch = 4384, loss = 0.001391
epoch = 4385, loss = 0.001388
epoch = 4386, loss = 0.001387
epoch = 4387, loss = 0.001391
epoch = 4388, loss = 0.001390
epoch = 43

epoch = 4630, loss = 0.001377
epoch = 4631, loss = 0.001376
epoch = 4632, loss = 0.001381
epoch = 4633, loss = 0.001377
epoch = 4634, loss = 0.001373
epoch = 4635, loss = 0.001374
epoch = 4636, loss = 0.001372
epoch = 4637, loss = 0.001371
epoch = 4638, loss = 0.001373
epoch = 4639, loss = 0.001374
epoch = 4640, loss = 0.001375
epoch = 4641, loss = 0.001374
epoch = 4642, loss = 0.001371
epoch = 4643, loss = 0.001372
epoch = 4644, loss = 0.001374
epoch = 4645, loss = 0.001374
epoch = 4646, loss = 0.001371
epoch = 4647, loss = 0.001376
epoch = 4648, loss = 0.001375
epoch = 4649, loss = 0.001372
epoch = 4650, loss = 0.001372
epoch = 4651, loss = 0.001371
epoch = 4652, loss = 0.001373
epoch = 4653, loss = 0.001372
epoch = 4654, loss = 0.001372
epoch = 4655, loss = 0.001371
epoch = 4656, loss = 0.001375
epoch = 4657, loss = 0.001375
epoch = 4658, loss = 0.001375
epoch = 4659, loss = 0.001374
epoch = 4660, loss = 0.001371
epoch = 4661, loss = 0.001369
epoch = 4662, loss = 0.001369
epoch = 46

epoch = 4909, loss = 0.001357
epoch = 4910, loss = 0.001361
epoch = 4911, loss = 0.001356
epoch = 4912, loss = 0.001364
epoch = 4913, loss = 0.001361
epoch = 4914, loss = 0.001360
epoch = 4915, loss = 0.001358
epoch = 4916, loss = 0.001357
epoch = 4917, loss = 0.001358
epoch = 4918, loss = 0.001356
epoch = 4919, loss = 0.001358
epoch = 4920, loss = 0.001358
epoch = 4921, loss = 0.001357
epoch = 4922, loss = 0.001357
epoch = 4923, loss = 0.001358
epoch = 4924, loss = 0.001355
epoch = 4925, loss = 0.001357
epoch = 4926, loss = 0.001354
epoch = 4927, loss = 0.001357
epoch = 4928, loss = 0.001355
epoch = 4929, loss = 0.001357
epoch = 4930, loss = 0.001355
epoch = 4931, loss = 0.001360
epoch = 4932, loss = 0.001359
epoch = 4933, loss = 0.001358
epoch = 4934, loss = 0.001358
epoch = 4935, loss = 0.001355
epoch = 4936, loss = 0.001355
epoch = 4937, loss = 0.001356
epoch = 4938, loss = 0.001354
epoch = 4939, loss = 0.001355
epoch = 4940, loss = 0.001354
epoch = 4941, loss = 0.001356
epoch = 49

epoch = 5205, loss = 0.001340
epoch = 5206, loss = 0.001340
epoch = 5207, loss = 0.001344
epoch = 5208, loss = 0.001339
epoch = 5209, loss = 0.001342
epoch = 5210, loss = 0.001344
epoch = 5211, loss = 0.001340
epoch = 5212, loss = 0.001343
epoch = 5213, loss = 0.001341
epoch = 5214, loss = 0.001340
epoch = 5215, loss = 0.001342
epoch = 5216, loss = 0.001341
epoch = 5217, loss = 0.001343
epoch = 5218, loss = 0.001341
epoch = 5219, loss = 0.001345
epoch = 5220, loss = 0.001341
epoch = 5221, loss = 0.001339
epoch = 5222, loss = 0.001340
epoch = 5223, loss = 0.001340
epoch = 5224, loss = 0.001339
epoch = 5225, loss = 0.001339
epoch = 5226, loss = 0.001340
epoch = 5227, loss = 0.001340
epoch = 5228, loss = 0.001338
epoch = 5229, loss = 0.001341
epoch = 5230, loss = 0.001339
epoch = 5231, loss = 0.001343
epoch = 5232, loss = 0.001339
epoch = 5233, loss = 0.001344
epoch = 5234, loss = 0.001342
epoch = 5235, loss = 0.001346
epoch = 5236, loss = 0.001343
epoch = 5237, loss = 0.001339
epoch = 52

epoch = 5485, loss = 0.001327
epoch = 5486, loss = 0.001326
epoch = 5487, loss = 0.001329
epoch = 5488, loss = 0.001323
epoch = 5489, loss = 0.001328
epoch = 5490, loss = 0.001326
epoch = 5491, loss = 0.001326
epoch = 5492, loss = 0.001324
epoch = 5493, loss = 0.001322
epoch = 5494, loss = 0.001325
epoch = 5495, loss = 0.001326
epoch = 5496, loss = 0.001323
epoch = 5497, loss = 0.001325
epoch = 5498, loss = 0.001325
epoch = 5499, loss = 0.001325
epoch = 5500, loss = 0.001324
epoch = 5501, loss = 0.001324
epoch = 5502, loss = 0.001326
epoch = 5503, loss = 0.001322
epoch = 5504, loss = 0.001327
epoch = 5505, loss = 0.001323
epoch = 5506, loss = 0.001322
epoch = 5507, loss = 0.001323
epoch = 5508, loss = 0.001325
epoch = 5509, loss = 0.001327
epoch = 5510, loss = 0.001326
epoch = 5511, loss = 0.001322
epoch = 5512, loss = 0.001325
epoch = 5513, loss = 0.001327
epoch = 5514, loss = 0.001329
epoch = 5515, loss = 0.001332
epoch = 5516, loss = 0.001326
epoch = 5517, loss = 0.001329
epoch = 55

epoch = 5761, loss = 0.001310
epoch = 5762, loss = 0.001310
epoch = 5763, loss = 0.001311
epoch = 5764, loss = 0.001315
epoch = 5765, loss = 0.001312
epoch = 5766, loss = 0.001309
epoch = 5767, loss = 0.001314
epoch = 5768, loss = 0.001312
epoch = 5769, loss = 0.001312
epoch = 5770, loss = 0.001314
epoch = 5771, loss = 0.001310
epoch = 5772, loss = 0.001310
epoch = 5773, loss = 0.001307
epoch = 5774, loss = 0.001310
epoch = 5775, loss = 0.001311
epoch = 5776, loss = 0.001309
epoch = 5777, loss = 0.001310
epoch = 5778, loss = 0.001310
epoch = 5779, loss = 0.001308
epoch = 5780, loss = 0.001307
epoch = 5781, loss = 0.001310
epoch = 5782, loss = 0.001306
epoch = 5783, loss = 0.001310
epoch = 5784, loss = 0.001310
epoch = 5785, loss = 0.001310
epoch = 5786, loss = 0.001309
epoch = 5787, loss = 0.001309
epoch = 5788, loss = 0.001309
epoch = 5789, loss = 0.001309
epoch = 5790, loss = 0.001307
epoch = 5791, loss = 0.001307
epoch = 5792, loss = 0.001312
epoch = 5793, loss = 0.001310
epoch = 57

In [18]:
def get_result():
    a =[]
    for parameter in model_best.parameters():
        a.append(parameter.data)
    delta = a[0]
    r = torch.mm(a[1].permute(1, 0), a[2].permute(1, 0))
    return delta, r.reshape(1, -1)

In [19]:
delta, r = get_result()
print(delta)
print (r)

tensor([0.0002])
tensor([[ 0.1470,  0.1387,  0.1161,  0.0847,  0.0517,  0.0237,  0.0049, -0.0040,
         -0.0053, -0.0034, -0.0020, -0.0037, -0.0085, -0.0141, -0.0176, -0.0163,
         -0.0096,  0.0014,  0.0134,  0.0230,  0.0270,  0.0242,  0.0155,  0.0036,
         -0.0082, -0.0167, -0.0200, -0.0180, -0.0120, -0.0040,  0.0037,  0.0093,
          0.0119,  0.0119,  0.0100,  0.0072,  0.0044,  0.0023,  0.0009,  0.0002]])


In [None]:
test_loader = torch.utils.data.DataLoader(
    MyDataset(W),
    batch_size=20, shuffle=True)
with torch.no_grad():
    for w in test_loader:
        print(w)
        out = model(w)
        print (out)
        R = out[:,0].reshape(-1, 1)
        delta = out[:,1].reshape(-1, 1)
        v = torch.cat((R-alpha**2, 1/(alpha**2)-R, R-delta, -R), 1)
        print (v)
        v = F.relu(v)
        print(v)
        reg = reg_lambda(w, ws, wp, tradeoff)
        print (reg)
        print (v*reg)
        loss = (torch.sum(v*reg) + torch.sum(delta))/w.size(0)
        loss = customized_loss(w, out, alpha, ws, wp, tradeoff)
        break