In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import struct
import numpy as np

In [2]:
feature_dim = 100;
categories = 200000;
nsamples = 2000;

In [43]:
with open('/private/home/qiantong/tmp/lm_logs/adsm_test/adsm_input_b.bin', 'rb') as f:
    data = f.read()
    h_input = struct.unpack('f' * feature_dim * nsamples, data)
    print(len(h_input))
    
with open('/private/home/qiantong/tmp/lm_logs/adsm_test/adsm_target_b.bin', 'rb') as f:
    data = f.read()
    h_target = struct.unpack('i' * nsamples, data)
    print(len(h_target))

input = torch.zeros((feature_dim * nsamples,), dtype=torch.float32)
input = input.new_tensor(h_input).reshape(nsamples, feature_dim)
print(input[0])
print(input[250])
print(input[350])

target = torch.zeros((nsamples,), dtype=torch.int32)
target = target.new_tensor(h_target).to(torch.int64)
print(target[:10])

200000
2000
tensor([0.8344, 0.5809, 0.4471, 0.6506, 0.8366, 0.7644, 0.7056, 0.3341, 0.7679,
        0.2162, 0.3955, 0.1304, 0.7658, 0.5979, 0.1689, 0.9357, 0.6382, 0.6445,
        0.8030, 0.0611, 0.5084, 0.8336, 0.8747, 0.5925, 0.2188, 0.2527, 0.2798,
        0.6364, 0.3410, 0.9844, 0.5321, 0.9699, 0.6615, 0.3889, 0.3945, 0.1447,
        0.2450, 0.0102, 0.1717, 0.7890, 0.1923, 0.5241, 0.0500, 0.3537, 0.8072,
        0.7991, 0.6464, 0.9184, 0.8881, 0.4741, 0.6913, 0.5796, 0.2668, 0.6899,
        0.2302, 0.6519, 0.2830, 0.2789, 0.4707, 0.8481, 0.2815, 0.8284, 0.2369,
        0.1835, 0.9035, 0.3367, 0.1337, 0.6613, 0.4437, 0.8792, 0.5249, 0.6371,
        0.5496, 0.3567, 0.2098, 0.3443, 0.6151, 0.3868, 0.9220, 0.5020, 0.4515,
        0.6307, 0.2177, 0.3458, 0.1999, 0.6734, 0.9900, 0.4931, 0.8068, 0.2970,
        0.3995, 0.0382, 0.4447, 0.9714, 0.1020, 0.1118, 0.5209, 0.5527, 0.3575,
        0.0543])
tensor([0.1119, 0.1876, 0.3776, 0.1045, 0.6737, 0.0850, 0.9660, 0.5560, 0.7579,
        0.7

In [110]:
class test_module(nn.Module):
    def __init__(self, feature_dim=100, categories=200000, cutoffs={10000, 50000}, load_init_params=True):
        super().__init__()
        self.adsm = nn.AdaptiveLogSoftmaxWithLoss(feature_dim, categories, cutoffs=cutoffs, div_value=4)
        self.lin1 = nn.Linear(feature_dim, feature_dim, False)
        self.lin2 = nn.Linear(feature_dim, feature_dim, False)
        
        if load_init_params:
            ## nn
            for i, param in enumerate(self.lin1.parameters()):
                with open('/private/home/qiantong/tmp/lm_logs/adsm_test/nn_param_0.bin'.format(i), 'rb') as f:
                    data = f.read()
                    param.data = param.data.new_tensor(struct.unpack('f' * param.numel(), data)).reshape(param.size(1), param.size(0))
                    param.data = param.data.transpose(0, 1)
                print(param.data[0])
                print(type(param.data), param.size())
                
            for i, param in enumerate(self.lin2.parameters()):
                with open('/private/home/qiantong/tmp/lm_logs/adsm_test/nn_param_1.bin'.format(i), 'rb') as f:
                    data = f.read()
                    param.data = param.data.new_tensor(struct.unpack('f' * param.numel(), data)).reshape(param.size(1), param.size(0))
                    param.data = param.data.transpose(0, 1)
                print(param.data[0])
                print(type(param.data), param.size())
                
            ## adsm
            for i, param in enumerate(self.adsm.parameters()):
                with open('/private/home/qiantong/tmp/lm_logs/adsm_test/adsm_param_{}.bin'.format(i), 'rb') as f:
                    data = f.read()
                    param.data = param.data.new_tensor(struct.unpack('f' * param.numel(), data)).reshape(param.size(1), param.size(0))
                    param.data = param.data.transpose(0, 1)
                print(param.data[0])
                print(type(param.data), param.size())
     
    def forward(self, x, y):
        res = self.lin1(x)
        res = F.relu(res)
        res = self.lin2(res)
        res = self.adsm(res, y)

        return res


In [113]:
adsm = test_module(load_init_params=False).cuda()
optimizer = optim.SGD(adsm.parameters(), lr=0.01)

print(input.shape)
print(target.shape)

input = input.cuda()
target = target.cuda()

for epoch in range(500):
    optimizer.zero_grad()
    loss = adsm(input, target)
    loss[1].backward()
    optimizer.step()
    print(loss[1].item())

torch.Size([2000, 100])
torch.Size([2000])
20.2977352142334
20.27950096130371
20.26127052307129
20.24294662475586
20.22442626953125
20.205629348754883
20.186460494995117
20.16681671142578
20.146608352661133
20.12571144104004
20.104005813598633
20.08135986328125
20.05766487121582
20.032773971557617
20.00652503967285
19.978734970092773
19.94921875
19.917741775512695
19.884084701538086
19.848007202148438
19.8092098236084
19.767379760742188
19.722148895263672
19.673126220703125
19.619869232177734
19.561861038208008
19.49851417541504
19.429166793823242
19.353044509887695
19.26930046081543
19.176939010620117
19.07480239868164
18.96152114868164
18.8355712890625
18.695125579833984
18.538036346435547
18.361814498901367
18.163557052612305
17.939800262451172
17.686458587646484
17.39871597290039
17.07103157043457
16.697227478027344
16.271121978759766
15.788512229919434
15.252924919128418
14.689804077148438
14.168183326721191
13.785895347595215
13.572054862976074
13.463436126708984
13.3992776870727

12.197430610656738
12.197376251220703
12.197324752807617
12.197272300720215
12.197218894958496
12.197166442871094
12.197113037109375
12.197061538696289
12.19700813293457
12.196955680847168
12.196903228759766
12.196849822998047
12.196797370910645
12.196745872497559
12.196693420410156
12.196640968322754
12.196588516235352
12.196537017822266
12.196483612060547
12.196431159973145
12.196379661560059
12.196327209472656
12.196274757385254
12.196223258972168
12.196171760559082
12.19611930847168
12.196066856384277
12.196015357971191
12.195963859558105
12.195910453796387
12.1958589553833
12.195807456970215
12.195754051208496
12.19570255279541
12.195651054382324
12.195600509643555
12.195547103881836
12.19549560546875
12.195444107055664
12.195393562316895
12.195341110229492
12.195289611816406
12.19523811340332
12.195185661315918
12.195135116577148
12.195083618164062
12.195032119750977
12.194979667663574
12.194928169250488
12.194877624511719
12.19482707977295
12.194774627685547
12.194723129272461
1