In [65]:
# https://github.com/llSourcell/pytorch_in_5_minutes/blob/master/demo.py

import torch
from torch.autograd import Variable
import pandas as pd
from random import randint
import numpy as np
import torch.nn as nn
# from torchviz import make_dot # used for grad visualization
import torch.optim as optim

In [20]:
from torch.autograd import Function
from torch.nn.modules.distance import PairwiseDistance

class TripletLoss(Function):
    
    def __init__(self, alpha):
        super(TripletLoss, self).__init__()
        self.alpha = alpha
        self.pdist  = PairwiseDistance(2)
        
    def forward(self, anchor, positive, negative):
        pos_dist   = self.pdist.forward(anchor, positive).pow(2)
        neg_dist   = self.pdist.forward(anchor, negative).pow(2)
        hinge_dist = torch.clamp(self.alpha + pos_dist - neg_dist, min = 0.0)
        loss       = torch.mean(hinge_dist)
        return loss

In [21]:
%%time
#read in relevant data
trainData = torch.from_numpy(np.loadtxt('data/trainData.txt', dtype=np.float32))
queryData = torch.from_numpy(np.loadtxt('data/queryData.txt', dtype=np.float32))
df =  pd.read_pickle("./data/KNN.pkl")

CPU times: user 26 s, sys: 30.7 s, total: 56.7 s
Wall time: 1min


In [22]:
def generateTripplet(index):
    point = queryData[index].reshape(-1, 1)
#     pos = trainData[df.iloc[index].KNN[randint(0,K)]].reshape(-1, 1) # pos fom KNN
#     negIndicies = list(range(K,K + 10)) + list(range(df.shape[0]-20, df.shape[0]-1))
#     neg = trainData[df.iloc[index].KNN[np.random.choice(negIndicies)]].reshape(-1, 1)
#     neg = trainData[df.iloc[index].KNN[randint(K, df.shape[0]-1)]].reshape(-1, 1)
    pos = trainData[df.iloc[index].KNN[0]].reshape(-1, 1) # pos fom KNN
    neg = trainData[df.iloc[index].KNN[5004]].reshape(-1, 1)
    return point, pos, neg

In [23]:
sigmoid = nn.Sigmoid()
def forward_pass(query):
    return sigmoid(torch.norm(query.t() - anchors, 2, 1).reshape(-1, 1) - biases)

In [69]:
BATCH_SIZE, INPUT_D, HIDDEN_D, OUTPUT_D = 100, 192, 128, 128
ALPHA = 0.5
LEARNING_RATE = 1e2
K = 5

def init_model():
    print("--- Initialising Model Params --- ")
    
    anchors = Variable(torch.randn(OUTPUT_D, INPUT_D).type(torch.FloatTensor), requires_grad=True)
    # weights = Variable(torch.randn(HIDDEN_D, OUTPUT_D).type(torch.FloatTensor), requires_grad=True)

    # set biases to be mean value
    aggregate = torch.zeros(OUTPUT_D)
    for point in queryData:
        w0 = torch.norm(point.t() - anchors, 2, 1)
        aggregate += w0

    biases = Variable((aggregate/queryData.shape[0]).reshape(-1, 1), requires_grad=True)
    print("--- Done. Begining training ---")
    return anchors, biases

In [71]:
#training
anchors, biases = init_model()
optimizer = optim.Adam([anchors, biases], lr=LEARNING_RATE)

for epoch in range(10000):

    #generate batch and compute collective loss for batch
    # UNSTABLE LEARNING WHEN SAMPLE SIZE > BATCH SIZE
    batch_indicies = np.random.choice(queryData.shape[0], BATCH_SIZE, replace=False) 
    
    loss = 0
    for index in batch_indicies:
        query, pos, neg = generateTripplet(index)
        queryMapped, posMapped, negMapped = [forward_pass(x) for x in [query, pos, neg]]
        triplet_loss = TripletLoss(ALPHA).forward(queryMapped, posMapped, negMapped)
        loss += triplet_loss
    
    loss /= BATCH_SIZE # computes mean so learning rate remains the same
    
    print(epoch, loss.data)

        
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

--- Initialising Model Params --- 
--- Done. Begining training ---
0 tensor(0.3603)
1 tensor(0.3894)
2 tensor(0.3666)
3 tensor(0.3484)
4 tensor(0.3657)
5 tensor(0.3508)
6 tensor(0.3818)
7 tensor(0.3604)
8 tensor(0.3599)
9 tensor(0.3582)
10 tensor(0.3660)
11 tensor(0.3956)
12 tensor(0.3872)
13 tensor(0.3614)
14 tensor(0.3849)
15 tensor(0.3549)
16 tensor(0.3524)
17 tensor(0.3463)
18 tensor(0.3499)
19 tensor(0.3397)
20 tensor(0.3371)
21 tensor(0.3621)
22 tensor(0.3505)
23 tensor(0.3523)
24 tensor(0.3627)
25 tensor(0.3550)
26 tensor(0.3637)
27 tensor(0.3732)
28 tensor(0.3599)
29 tensor(0.3482)
30 tensor(0.3421)
31 tensor(0.3416)
32 tensor(0.3707)
33 tensor(0.3505)
34 tensor(0.3615)
35 tensor(0.3581)
36 tensor(0.3606)
37 tensor(0.3703)
38 tensor(0.3709)
39 tensor(0.3588)
40 tensor(0.3465)
41 tensor(0.3588)
42 tensor(0.3558)
43 tensor(0.3489)
44 tensor(0.3610)
45 tensor(0.3591)
46 tensor(0.3637)
47 tensor(0.3424)
48 tensor(0.3460)
49 tensor(0.3496)
50 tensor(0.3559)
51 tensor(0.3470)
52 tens

434 tensor(0.3416)
435 tensor(0.3352)
436 tensor(0.3392)
437 tensor(0.3500)
438 tensor(0.3381)
439 tensor(0.3382)
440 tensor(0.3291)
441 tensor(0.3433)
442 tensor(0.3274)
443 tensor(0.3304)
444 tensor(0.3370)
445 tensor(0.3382)
446 tensor(0.3386)
447 tensor(0.3330)
448 tensor(0.3092)
449 tensor(0.3332)
450 tensor(0.3420)
451 tensor(0.3377)
452 tensor(0.3229)
453 tensor(0.3200)
454 tensor(0.3264)
455 tensor(0.3272)
456 tensor(0.3256)
457 tensor(0.3353)
458 tensor(0.3219)
459 tensor(0.3373)
460 tensor(0.3632)
461 tensor(0.3292)
462 tensor(0.3372)
463 tensor(0.3366)
464 tensor(0.3351)
465 tensor(0.3220)
466 tensor(0.3339)
467 tensor(0.3572)
468 tensor(0.3200)
469 tensor(0.3371)
470 tensor(0.3471)
471 tensor(0.3442)
472 tensor(0.3285)
473 tensor(0.3258)
474 tensor(0.3440)
475 tensor(0.3299)
476 tensor(0.3056)
477 tensor(0.3013)
478 tensor(0.3254)
479 tensor(0.3192)
480 tensor(0.3353)
481 tensor(0.3183)
482 tensor(0.3330)
483 tensor(0.3402)
484 tensor(0.3127)
485 tensor(0.3459)
486 tensor(0

866 tensor(0.3325)
867 tensor(0.2824)
868 tensor(0.3411)
869 tensor(0.3232)
870 tensor(0.3280)
871 tensor(0.3075)
872 tensor(0.3521)
873 tensor(0.3173)
874 tensor(0.3193)
875 tensor(0.3090)
876 tensor(0.3314)
877 tensor(0.3187)
878 tensor(0.3197)
879 tensor(0.3286)
880 tensor(0.3406)
881 tensor(0.3233)
882 tensor(0.3429)
883 tensor(0.3288)
884 tensor(0.3124)
885 tensor(0.3340)
886 tensor(0.3349)
887 tensor(0.3020)
888 tensor(0.3169)
889 tensor(0.3273)
890 tensor(0.3353)
891 tensor(0.3204)
892 tensor(0.3062)
893 tensor(0.3079)
894 tensor(0.3115)
895 tensor(0.3005)
896 tensor(0.3094)
897 tensor(0.3121)
898 tensor(0.3322)
899 tensor(0.3125)
900 tensor(0.3410)
901 tensor(0.3176)
902 tensor(0.3170)
903 tensor(0.3198)
904 tensor(0.3089)
905 tensor(0.2974)
906 tensor(0.3286)
907 tensor(0.3194)
908 tensor(0.3379)
909 tensor(0.3216)
910 tensor(0.3299)
911 tensor(0.3372)
912 tensor(0.3283)
913 tensor(0.3366)
914 tensor(0.3175)
915 tensor(0.3155)
916 tensor(0.3231)
917 tensor(0.3239)
918 tensor(0

1283 tensor(0.3041)
1284 tensor(0.3102)
1285 tensor(0.3023)
1286 tensor(0.3228)
1287 tensor(0.2971)
1288 tensor(0.3212)
1289 tensor(0.3186)
1290 tensor(0.3142)
1291 tensor(0.3425)
1292 tensor(0.3019)
1293 tensor(0.3180)
1294 tensor(0.2975)
1295 tensor(0.2800)
1296 tensor(0.3139)
1297 tensor(0.2876)
1298 tensor(0.2922)
1299 tensor(0.2985)
1300 tensor(0.3311)
1301 tensor(0.3282)
1302 tensor(0.3078)
1303 tensor(0.3193)
1304 tensor(0.3047)
1305 tensor(0.3247)
1306 tensor(0.3086)
1307 tensor(0.2877)
1308 tensor(0.3215)
1309 tensor(0.3075)
1310 tensor(0.3226)
1311 tensor(0.2930)
1312 tensor(0.3107)
1313 tensor(0.3252)
1314 tensor(0.2966)
1315 tensor(0.3064)
1316 tensor(0.3229)
1317 tensor(0.2985)
1318 tensor(0.3325)
1319 tensor(0.3216)
1320 tensor(0.3325)
1321 tensor(0.3158)
1322 tensor(0.2988)
1323 tensor(0.3285)
1324 tensor(0.3171)
1325 tensor(0.3038)
1326 tensor(0.3156)
1327 tensor(0.3035)
1328 tensor(0.3216)
1329 tensor(0.3179)
1330 tensor(0.3115)
1331 tensor(0.3192)
1332 tensor(0.3441)


1693 tensor(0.2996)
1694 tensor(0.3041)
1695 tensor(0.2821)
1696 tensor(0.2997)
1697 tensor(0.3003)
1698 tensor(0.3002)
1699 tensor(0.2904)
1700 tensor(0.3084)
1701 tensor(0.2961)
1702 tensor(0.3038)
1703 tensor(0.3356)
1704 tensor(0.2810)
1705 tensor(0.3032)
1706 tensor(0.2873)
1707 tensor(0.3209)
1708 tensor(0.2859)
1709 tensor(0.3121)
1710 tensor(0.3154)
1711 tensor(0.2774)
1712 tensor(0.2696)
1713 tensor(0.3046)
1714 tensor(0.3208)
1715 tensor(0.2976)
1716 tensor(0.2986)
1717 tensor(0.3048)
1718 tensor(0.2744)
1719 tensor(0.3178)
1720 tensor(0.3180)
1721 tensor(0.3121)
1722 tensor(0.3040)
1723 tensor(0.3232)
1724 tensor(0.3009)
1725 tensor(0.3180)
1726 tensor(0.3117)
1727 tensor(0.3021)
1728 tensor(0.3076)
1729 tensor(0.2982)
1730 tensor(0.2891)
1731 tensor(0.3033)
1732 tensor(0.3017)
1733 tensor(0.3263)
1734 tensor(0.2831)
1735 tensor(0.2812)
1736 tensor(0.3321)
1737 tensor(0.3299)
1738 tensor(0.3061)
1739 tensor(0.3169)
1740 tensor(0.3066)
1741 tensor(0.3088)
1742 tensor(0.2916)


2103 tensor(0.2844)
2104 tensor(0.2856)
2105 tensor(0.2877)
2106 tensor(0.2986)
2107 tensor(0.3062)
2108 tensor(0.2900)
2109 tensor(0.2892)
2110 tensor(0.2874)
2111 tensor(0.3155)
2112 tensor(0.3039)
2113 tensor(0.3014)
2114 tensor(0.2898)
2115 tensor(0.2914)
2116 tensor(0.2972)
2117 tensor(0.3040)
2118 tensor(0.2838)
2119 tensor(0.3080)
2120 tensor(0.3022)
2121 tensor(0.2910)
2122 tensor(0.2802)
2123 tensor(0.2870)
2124 tensor(0.2783)
2125 tensor(0.3079)
2126 tensor(0.2967)
2127 tensor(0.3009)
2128 tensor(0.2807)
2129 tensor(0.3009)
2130 tensor(0.2882)
2131 tensor(0.2986)
2132 tensor(0.2892)
2133 tensor(0.2837)
2134 tensor(0.2961)
2135 tensor(0.2843)
2136 tensor(0.3162)
2137 tensor(0.3087)
2138 tensor(0.3022)
2139 tensor(0.2834)
2140 tensor(0.2969)
2141 tensor(0.2860)
2142 tensor(0.2837)
2143 tensor(0.2908)
2144 tensor(0.3131)
2145 tensor(0.3013)
2146 tensor(0.2936)
2147 tensor(0.3159)
2148 tensor(0.3029)
2149 tensor(0.2964)
2150 tensor(0.3060)
2151 tensor(0.2943)
2152 tensor(0.2981)


2513 tensor(0.3147)
2514 tensor(0.3100)
2515 tensor(0.3139)
2516 tensor(0.2892)
2517 tensor(0.2959)
2518 tensor(0.3185)
2519 tensor(0.3053)
2520 tensor(0.2926)
2521 tensor(0.2796)
2522 tensor(0.2813)
2523 tensor(0.2776)
2524 tensor(0.2946)
2525 tensor(0.3053)
2526 tensor(0.2794)
2527 tensor(0.2858)
2528 tensor(0.3029)
2529 tensor(0.3054)
2530 tensor(0.3183)
2531 tensor(0.2967)
2532 tensor(0.3018)
2533 tensor(0.3018)
2534 tensor(0.2978)
2535 tensor(0.3086)
2536 tensor(0.2726)
2537 tensor(0.3273)
2538 tensor(0.2977)
2539 tensor(0.3005)
2540 tensor(0.2866)
2541 tensor(0.3041)
2542 tensor(0.3130)
2543 tensor(0.2823)
2544 tensor(0.2889)
2545 tensor(0.3022)
2546 tensor(0.2862)
2547 tensor(0.3048)
2548 tensor(0.2962)
2549 tensor(0.2998)
2550 tensor(0.2940)
2551 tensor(0.3167)
2552 tensor(0.3055)
2553 tensor(0.2913)
2554 tensor(0.2985)
2555 tensor(0.3031)
2556 tensor(0.2796)
2557 tensor(0.2835)
2558 tensor(0.2889)
2559 tensor(0.2958)
2560 tensor(0.2678)
2561 tensor(0.2797)
2562 tensor(0.2980)


2923 tensor(0.2774)
2924 tensor(0.2911)
2925 tensor(0.2987)
2926 tensor(0.2887)
2927 tensor(0.3029)
2928 tensor(0.2967)
2929 tensor(0.2968)
2930 tensor(0.2913)
2931 tensor(0.2882)
2932 tensor(0.2900)
2933 tensor(0.2795)
2934 tensor(0.2849)
2935 tensor(0.2979)
2936 tensor(0.2993)
2937 tensor(0.3015)
2938 tensor(0.2880)
2939 tensor(0.3040)
2940 tensor(0.2845)
2941 tensor(0.2636)
2942 tensor(0.3074)
2943 tensor(0.2965)
2944 tensor(0.3055)
2945 tensor(0.2848)
2946 tensor(0.3132)
2947 tensor(0.2857)
2948 tensor(0.2722)
2949 tensor(0.3064)
2950 tensor(0.2759)
2951 tensor(0.2851)
2952 tensor(0.3034)
2953 tensor(0.2680)
2954 tensor(0.3058)
2955 tensor(0.3089)
2956 tensor(0.2898)
2957 tensor(0.3285)
2958 tensor(0.2813)
2959 tensor(0.3041)
2960 tensor(0.2826)
2961 tensor(0.2829)
2962 tensor(0.2633)
2963 tensor(0.2700)
2964 tensor(0.2793)
2965 tensor(0.3043)
2966 tensor(0.2849)
2967 tensor(0.2970)
2968 tensor(0.2943)
2969 tensor(0.2968)
2970 tensor(0.3073)
2971 tensor(0.2914)
2972 tensor(0.2951)


3333 tensor(0.2985)
3334 tensor(0.2794)
3335 tensor(0.2653)
3336 tensor(0.2749)
3337 tensor(0.2940)
3338 tensor(0.2809)
3339 tensor(0.2806)
3340 tensor(0.3058)
3341 tensor(0.2871)
3342 tensor(0.2782)
3343 tensor(0.3159)
3344 tensor(0.2782)
3345 tensor(0.3008)
3346 tensor(0.2808)
3347 tensor(0.2969)
3348 tensor(0.2856)
3349 tensor(0.2810)
3350 tensor(0.2938)
3351 tensor(0.2834)
3352 tensor(0.2901)
3353 tensor(0.2894)
3354 tensor(0.2814)
3355 tensor(0.2782)
3356 tensor(0.2880)
3357 tensor(0.3012)
3358 tensor(0.2908)
3359 tensor(0.2870)
3360 tensor(0.2975)
3361 tensor(0.2962)
3362 tensor(0.2716)
3363 tensor(0.3097)
3364 tensor(0.2819)
3365 tensor(0.2733)
3366 tensor(0.2828)
3367 tensor(0.2762)
3368 tensor(0.2569)
3369 tensor(0.3011)
3370 tensor(0.2770)
3371 tensor(0.2658)
3372 tensor(0.2957)
3373 tensor(0.2542)
3374 tensor(0.2789)
3375 tensor(0.2859)
3376 tensor(0.2878)
3377 tensor(0.2759)
3378 tensor(0.2848)
3379 tensor(0.2845)
3380 tensor(0.2936)
3381 tensor(0.2754)
3382 tensor(0.2996)


3743 tensor(0.2902)
3744 tensor(0.2855)
3745 tensor(0.2636)
3746 tensor(0.2534)
3747 tensor(0.3181)
3748 tensor(0.2927)
3749 tensor(0.2755)
3750 tensor(0.2832)
3751 tensor(0.2733)
3752 tensor(0.2943)
3753 tensor(0.2878)
3754 tensor(0.2952)
3755 tensor(0.2950)
3756 tensor(0.2915)
3757 tensor(0.2824)
3758 tensor(0.2884)
3759 tensor(0.2715)
3760 tensor(0.2719)
3761 tensor(0.2927)
3762 tensor(0.2827)
3763 tensor(0.2642)
3764 tensor(0.2845)
3765 tensor(0.2954)
3766 tensor(0.2729)
3767 tensor(0.2910)
3768 tensor(0.2743)
3769 tensor(0.2808)
3770 tensor(0.3036)
3771 tensor(0.2783)
3772 tensor(0.2829)
3773 tensor(0.2982)
3774 tensor(0.2872)
3775 tensor(0.2890)
3776 tensor(0.2994)
3777 tensor(0.2910)
3778 tensor(0.2910)
3779 tensor(0.2861)
3780 tensor(0.2879)
3781 tensor(0.2820)
3782 tensor(0.2860)
3783 tensor(0.2776)
3784 tensor(0.2850)
3785 tensor(0.2821)
3786 tensor(0.2730)
3787 tensor(0.3147)
3788 tensor(0.2658)
3789 tensor(0.2889)
3790 tensor(0.2814)
3791 tensor(0.2860)
3792 tensor(0.2856)


4153 tensor(0.2911)
4154 tensor(0.2932)
4155 tensor(0.2688)
4156 tensor(0.2794)
4157 tensor(0.2611)
4158 tensor(0.2952)
4159 tensor(0.2756)
4160 tensor(0.2987)
4161 tensor(0.2652)
4162 tensor(0.2656)
4163 tensor(0.3007)
4164 tensor(0.2684)
4165 tensor(0.2859)
4166 tensor(0.2744)
4167 tensor(0.2916)
4168 tensor(0.2906)
4169 tensor(0.2784)
4170 tensor(0.2800)
4171 tensor(0.2860)
4172 tensor(0.2820)
4173 tensor(0.2927)
4174 tensor(0.2930)
4175 tensor(0.2978)
4176 tensor(0.2833)
4177 tensor(0.2595)
4178 tensor(0.2896)
4179 tensor(0.2700)
4180 tensor(0.2863)
4181 tensor(0.2880)
4182 tensor(0.2940)
4183 tensor(0.2949)
4184 tensor(0.2808)
4185 tensor(0.2776)
4186 tensor(0.2720)
4187 tensor(0.3078)
4188 tensor(0.2879)
4189 tensor(0.2849)
4190 tensor(0.2546)
4191 tensor(0.2927)
4192 tensor(0.2970)
4193 tensor(0.2850)
4194 tensor(0.2740)
4195 tensor(0.2684)
4196 tensor(0.2836)
4197 tensor(0.2912)
4198 tensor(0.2832)
4199 tensor(0.3009)
4200 tensor(0.2869)
4201 tensor(0.2647)
4202 tensor(0.2633)


4563 tensor(0.2863)
4564 tensor(0.2799)
4565 tensor(0.2786)
4566 tensor(0.2789)
4567 tensor(0.2949)
4568 tensor(0.2833)
4569 tensor(0.2781)
4570 tensor(0.3059)
4571 tensor(0.2744)
4572 tensor(0.2865)
4573 tensor(0.2966)
4574 tensor(0.2872)
4575 tensor(0.2755)
4576 tensor(0.2554)
4577 tensor(0.2780)
4578 tensor(0.2782)
4579 tensor(0.2645)
4580 tensor(0.2793)
4581 tensor(0.2689)
4582 tensor(0.3100)
4583 tensor(0.2972)
4584 tensor(0.2829)
4585 tensor(0.2967)
4586 tensor(0.2814)
4587 tensor(0.3018)
4588 tensor(0.2845)
4589 tensor(0.2733)
4590 tensor(0.2706)
4591 tensor(0.2733)
4592 tensor(0.2649)
4593 tensor(0.2748)
4594 tensor(0.2759)
4595 tensor(0.2763)
4596 tensor(0.2859)
4597 tensor(0.2819)
4598 tensor(0.2567)
4599 tensor(0.2785)
4600 tensor(0.3074)
4601 tensor(0.2798)
4602 tensor(0.2889)
4603 tensor(0.2831)
4604 tensor(0.2884)
4605 tensor(0.3049)
4606 tensor(0.2948)
4607 tensor(0.2686)
4608 tensor(0.2872)
4609 tensor(0.2734)
4610 tensor(0.2846)
4611 tensor(0.2739)
4612 tensor(0.2841)


4973 tensor(0.2814)
4974 tensor(0.2864)
4975 tensor(0.2804)
4976 tensor(0.2974)
4977 tensor(0.2905)
4978 tensor(0.2643)
4979 tensor(0.2792)
4980 tensor(0.2780)
4981 tensor(0.2857)
4982 tensor(0.2898)
4983 tensor(0.2648)
4984 tensor(0.2563)
4985 tensor(0.3098)
4986 tensor(0.2710)
4987 tensor(0.2831)
4988 tensor(0.2808)
4989 tensor(0.2654)
4990 tensor(0.2637)
4991 tensor(0.2915)
4992 tensor(0.2855)
4993 tensor(0.2994)
4994 tensor(0.2898)
4995 tensor(0.2855)
4996 tensor(0.2660)
4997 tensor(0.2696)
4998 tensor(0.2808)
4999 tensor(0.2911)
5000 tensor(0.2774)
5001 tensor(0.2881)
5002 tensor(0.2776)
5003 tensor(0.2940)
5004 tensor(0.2800)
5005 tensor(0.2954)
5006 tensor(0.2854)
5007 tensor(0.2779)
5008 tensor(0.3055)
5009 tensor(0.2719)
5010 tensor(0.2700)
5011 tensor(0.2910)
5012 tensor(0.2827)
5013 tensor(0.2837)
5014 tensor(0.2844)
5015 tensor(0.3028)
5016 tensor(0.2887)
5017 tensor(0.3032)
5018 tensor(0.2840)
5019 tensor(0.2766)
5020 tensor(0.2781)
5021 tensor(0.2825)
5022 tensor(0.2882)


5383 tensor(0.2764)
5384 tensor(0.2873)
5385 tensor(0.2677)
5386 tensor(0.2894)
5387 tensor(0.2539)
5388 tensor(0.2786)
5389 tensor(0.2877)
5390 tensor(0.2754)
5391 tensor(0.3041)
5392 tensor(0.2837)
5393 tensor(0.2941)
5394 tensor(0.2951)
5395 tensor(0.2694)
5396 tensor(0.2771)
5397 tensor(0.2888)
5398 tensor(0.2785)
5399 tensor(0.2638)
5400 tensor(0.2809)
5401 tensor(0.2663)
5402 tensor(0.2892)
5403 tensor(0.2817)
5404 tensor(0.2909)
5405 tensor(0.2745)
5406 tensor(0.2823)
5407 tensor(0.2843)
5408 tensor(0.2676)
5409 tensor(0.2796)
5410 tensor(0.2824)
5411 tensor(0.2723)
5412 tensor(0.2788)
5413 tensor(0.2778)
5414 tensor(0.2729)
5415 tensor(0.2793)
5416 tensor(0.2828)
5417 tensor(0.2771)
5418 tensor(0.2772)
5419 tensor(0.2819)
5420 tensor(0.2615)
5421 tensor(0.2657)
5422 tensor(0.2826)
5423 tensor(0.2874)
5424 tensor(0.2849)
5425 tensor(0.2689)
5426 tensor(0.2657)
5427 tensor(0.2724)
5428 tensor(0.2872)
5429 tensor(0.2828)
5430 tensor(0.2757)
5431 tensor(0.2689)
5432 tensor(0.2968)


5793 tensor(0.2932)
5794 tensor(0.2684)
5795 tensor(0.2772)
5796 tensor(0.2777)
5797 tensor(0.2779)
5798 tensor(0.2828)
5799 tensor(0.2794)
5800 tensor(0.2965)
5801 tensor(0.2712)
5802 tensor(0.2687)
5803 tensor(0.2716)
5804 tensor(0.2872)
5805 tensor(0.2715)
5806 tensor(0.2890)
5807 tensor(0.2696)
5808 tensor(0.2755)
5809 tensor(0.2939)
5810 tensor(0.2822)
5811 tensor(0.2680)
5812 tensor(0.2866)
5813 tensor(0.2703)
5814 tensor(0.2857)
5815 tensor(0.2694)
5816 tensor(0.2728)
5817 tensor(0.2839)
5818 tensor(0.2572)
5819 tensor(0.2725)
5820 tensor(0.2759)
5821 tensor(0.2742)
5822 tensor(0.2857)
5823 tensor(0.2963)
5824 tensor(0.2738)
5825 tensor(0.2986)
5826 tensor(0.2916)
5827 tensor(0.2664)
5828 tensor(0.2701)
5829 tensor(0.2693)
5830 tensor(0.2779)
5831 tensor(0.2839)
5832 tensor(0.2770)
5833 tensor(0.2755)
5834 tensor(0.2752)
5835 tensor(0.2875)
5836 tensor(0.2649)
5837 tensor(0.2484)
5838 tensor(0.2760)
5839 tensor(0.2901)
5840 tensor(0.2551)
5841 tensor(0.2845)
5842 tensor(0.2783)


KeyboardInterrupt: 