In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from collections import Counter
import numpy as np

In [None]:
WINDOW_SIZE, K = 3, 100
VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE = 30000, 100, 100 
BATCH_SIZE = 128

np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)

vocab_file = 'text8'
#prepare vocabulary
def tokenize(text):
    return text.split(' ')

text = None
with open(vocab_file) as fr:
    text = fr.readlines()[0]
    token_list = tokenize(text)
vocab =  Counter(token_list).most_common(VOCAB_SIZE - 1)
idx_to_word = [item[0] for item in vocab]
idx_to_word.append('UNK')
word_to_idx = {item: i for i, item in enumerate(idx_to_word)}
word_counts = [item[1] for item in vocab]
word_counts.append(len(text) - np.sum(word_counts))
frequence = word_counts / np.sum(word_counts)
frequence = frequence ** (3 / 4)
frequence = frequence / np.sum(frequence)

In [None]:
#data loader
class MyDataset(tud.Dataset):
    def __init__(self, text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence):
        super(MyDataset, self).__init__()
        self.text = text
        self.idx_to_word = idx_to_word
        self.word_to_idx = word_to_idx
        self.window_size = WINDOW_SIZE
        self.k = K
        self.device = device
        self.frequence = torch.FloatTensor(frequence)
        self.word_encode = torch.LongTensor([self.word_to_idx.get(word, self.word_to_idx['UNK']) \
                                             for word in self.text])
        
        
    def __len__(self):
        return len(self.word_encode)
    
    def __getitem__(self, idx):
        center_word = self.word_encode[idx]
        pos_index = list(range(idx - self.window_size, idx)) + list(range(idx + 1, idx + self.window_size + 1))
        pos_index = [idx % len(self.word_encode) for idx in pos_index]
        pos_words = self.word_encode[pos_index]
        neg_words = torch.multinomial(self.frequence, self.k * self.window_size * 2, replacement = True)
        return center_word, pos_words, neg_words

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = MyDataset(text, idx_to_word, word_to_idx, WINDOW_SIZE, K, device, frequence)  
dataloader = tud.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True, num_works = 4)  
for i, (center_word, pos_words, neg_words) in enumerate(dataloader):
    print("iter: {}".format(i))
    print(center_word.shape)
    print(pos_words.shape)
    print(neg_words.shape)
    break

In [123]:
#model
class MyModel(nn.Module):
    def __init__(self, VOCAB_SIZE, IN_EMBED_SIZE, OUT_ENBED_SIZE):
        super(MyModel, self).__init__()
        self.in_embedding = nn.Embedding(VOCAB_SIZE, IN_EMBED_SIZE)
        initrange = 0.5 / IN_EMBED_SIZE
        self.in_embedding.weight.data.uniform_(-initrange, initrange)
        self.out_embedding = nn.Embedding(VOCAB_SIZE, OUT_ENBED_SIZE)
        self.out_embedding.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, center_word, pos_words, neg_words):
        center_embeddings = self.in_embedding(center_word)  #[batch-size, embed_size]
        pos_embeddings = self.in_embedding(pos_words)      #[batch-size, window_size * 2, embed_size]
        neg_embeddings = self.out_embedding(neg_words)      #[batch-size, window_size * 2 * K, embed_size]
        center_embeddings = center_embeddings.unsqueeze(2)                     #[batch-size, embed_size, 1]
        
        pos_dot = pos_embeddings
        
        pos_dot = torch.bmm(pos_embeddings, center_embeddings).squeeze(2)  #[batch-size, window_size * 2]
        neg_dot = torch.bmm(neg_embeddings, -center_embeddings).squeeze(2)  #[batch-size, window_size * 2 * K]
#         print("shape of pos_dot: {}".format(pos_dot.shape))
#         print("shape of neg_dot: {}".format(neg_dot.shape))
        
        pos_loss = F.logsigmoid(pos_dot).sum(1)
        neg_loss = F.logsigmoid(neg_dot).sum(1)
        return - (pos_loss + neg_loss)

In [None]:
#train
model = MyModel(len(idx_to_word), IN_EMBED_SIZE, OUT_ENBED_SIZE)
learning_rate = 4e-4
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

for epoch in range(100):
    for iter, (center_word, pos_words, neg_words) in enumerate(dataloader):
        loss = model(center_word, pos_words, neg_words).mean()
        print('epoch: {}, iter: {}, loss: {}'.format(epoch, iter, loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

epoch: 0, iter: 0, loss: 420.04736328125
epoch: 0, iter: 1, loss: 420.03009033203125
epoch: 0, iter: 2, loss: 419.9935302734375
epoch: 0, iter: 3, loss: 419.9306945800781
epoch: 0, iter: 4, loss: 419.844970703125
epoch: 0, iter: 5, loss: 419.7273864746094
epoch: 0, iter: 6, loss: 419.57977294921875
epoch: 0, iter: 7, loss: 419.389892578125
epoch: 0, iter: 8, loss: 419.1412658691406
epoch: 0, iter: 9, loss: 418.9300231933594
epoch: 0, iter: 10, loss: 418.5442810058594
epoch: 0, iter: 11, loss: 418.2599182128906
epoch: 0, iter: 12, loss: 417.8762512207031
epoch: 0, iter: 13, loss: 417.38818359375
epoch: 0, iter: 14, loss: 416.8371276855469
epoch: 0, iter: 15, loss: 416.479736328125
epoch: 0, iter: 16, loss: 415.7598571777344
epoch: 0, iter: 17, loss: 415.1873474121094
epoch: 0, iter: 18, loss: 414.5743713378906
epoch: 0, iter: 19, loss: 413.8144836425781
epoch: 0, iter: 20, loss: 412.9191589355469
epoch: 0, iter: 21, loss: 411.9929504394531
epoch: 0, iter: 22, loss: 410.8089904785156
epo

epoch: 0, iter: 184, loss: 37.19917678833008
epoch: 0, iter: 185, loss: 39.419654846191406
epoch: 0, iter: 186, loss: 39.04886245727539
epoch: 0, iter: 187, loss: 37.275794982910156
epoch: 0, iter: 188, loss: 39.49872589111328
epoch: 0, iter: 189, loss: 36.0349235534668
epoch: 0, iter: 190, loss: 35.45886993408203
epoch: 0, iter: 191, loss: 35.364933013916016
epoch: 0, iter: 192, loss: 35.85647964477539
epoch: 0, iter: 193, loss: 34.69135284423828
epoch: 0, iter: 194, loss: 35.180606842041016
epoch: 0, iter: 195, loss: 37.652652740478516
epoch: 0, iter: 196, loss: 32.84576416015625
epoch: 0, iter: 197, loss: 35.96976852416992
epoch: 0, iter: 198, loss: 37.48291778564453
epoch: 0, iter: 199, loss: 30.677204132080078
epoch: 0, iter: 200, loss: 32.24811553955078
epoch: 0, iter: 201, loss: 27.56901741027832
epoch: 0, iter: 202, loss: 33.04693603515625
epoch: 0, iter: 203, loss: 31.832853317260742
epoch: 0, iter: 204, loss: 36.48822021484375
epoch: 0, iter: 205, loss: 30.751211166381836
epo

epoch: 0, iter: 365, loss: 8.239277839660645
epoch: 0, iter: 366, loss: 6.111988544464111
epoch: 0, iter: 367, loss: 7.9848480224609375
epoch: 0, iter: 368, loss: 7.52569580078125
epoch: 0, iter: 369, loss: 9.409832000732422
epoch: 0, iter: 370, loss: 7.406623363494873
epoch: 0, iter: 371, loss: 6.1284098625183105
epoch: 0, iter: 372, loss: 7.5131425857543945
epoch: 0, iter: 373, loss: 7.837387561798096
epoch: 0, iter: 374, loss: 8.782503128051758
epoch: 0, iter: 375, loss: 9.47379207611084
epoch: 0, iter: 376, loss: 8.982967376708984
epoch: 0, iter: 377, loss: 6.444701671600342
epoch: 0, iter: 378, loss: 6.314486980438232
epoch: 0, iter: 379, loss: 6.62692928314209
epoch: 0, iter: 380, loss: 6.396048545837402
epoch: 0, iter: 381, loss: 6.926065921783447
epoch: 0, iter: 382, loss: 6.265137672424316
epoch: 0, iter: 383, loss: 6.527756214141846
epoch: 0, iter: 384, loss: 6.232541561126709
epoch: 0, iter: 385, loss: 6.864943027496338
epoch: 0, iter: 386, loss: 7.641170024871826
epoch: 0, 

epoch: 0, iter: 547, loss: 3.0254158973693848
epoch: 0, iter: 548, loss: 2.7048699855804443
epoch: 0, iter: 549, loss: 3.013190984725952
epoch: 0, iter: 550, loss: 3.6170153617858887
epoch: 0, iter: 551, loss: 3.7059502601623535
epoch: 0, iter: 552, loss: 3.5170516967773438
epoch: 0, iter: 553, loss: 2.869431495666504
epoch: 0, iter: 554, loss: 2.9924416542053223
epoch: 0, iter: 555, loss: 2.855236053466797
epoch: 0, iter: 556, loss: 3.245734691619873
epoch: 0, iter: 557, loss: 3.768169403076172
epoch: 0, iter: 558, loss: 2.4542808532714844
epoch: 0, iter: 559, loss: 3.2376279830932617
epoch: 0, iter: 560, loss: 2.6136457920074463
epoch: 0, iter: 561, loss: 3.24263596534729
epoch: 0, iter: 562, loss: 3.4135239124298096
epoch: 0, iter: 563, loss: 3.3048596382141113
epoch: 0, iter: 564, loss: 3.497368574142456
epoch: 0, iter: 565, loss: 3.6534507274627686
epoch: 0, iter: 566, loss: 2.747440814971924
epoch: 0, iter: 567, loss: 3.797092914581299
epoch: 0, iter: 568, loss: 4.125815391540527

epoch: 0, iter: 727, loss: 1.5605430603027344
epoch: 0, iter: 728, loss: 1.6026798486709595
epoch: 0, iter: 729, loss: 1.9286208152770996
epoch: 0, iter: 730, loss: 1.6054809093475342
epoch: 0, iter: 731, loss: 2.0922303199768066
epoch: 0, iter: 732, loss: 1.9626744985580444
epoch: 0, iter: 733, loss: 1.4718313217163086
epoch: 0, iter: 734, loss: 1.9255881309509277
epoch: 0, iter: 735, loss: 1.9534367322921753
epoch: 0, iter: 736, loss: 1.731980323791504
epoch: 0, iter: 737, loss: 1.7605023384094238
epoch: 0, iter: 738, loss: 1.793337106704712
epoch: 0, iter: 739, loss: 1.9992640018463135
epoch: 0, iter: 740, loss: 1.8212709426879883
epoch: 0, iter: 741, loss: 1.7735066413879395
epoch: 0, iter: 742, loss: 2.0438787937164307
epoch: 0, iter: 743, loss: 1.6851727962493896
epoch: 0, iter: 744, loss: 1.647029161453247
epoch: 0, iter: 745, loss: 1.4580435752868652
epoch: 0, iter: 746, loss: 1.3663417100906372
epoch: 0, iter: 747, loss: 1.5175318717956543
epoch: 0, iter: 748, loss: 1.33078825

epoch: 0, iter: 907, loss: 0.9557017683982849
epoch: 0, iter: 908, loss: 1.0931652784347534
epoch: 0, iter: 909, loss: 0.9696844220161438
epoch: 0, iter: 910, loss: 1.5615962743759155
epoch: 0, iter: 911, loss: 1.0068817138671875
epoch: 0, iter: 912, loss: 1.1565732955932617
epoch: 0, iter: 913, loss: 1.1053777933120728
epoch: 0, iter: 914, loss: 1.0506956577301025
epoch: 0, iter: 915, loss: 1.0583138465881348
epoch: 0, iter: 916, loss: 1.116187334060669
epoch: 0, iter: 917, loss: 0.8794322609901428
epoch: 0, iter: 918, loss: 1.6545937061309814
epoch: 0, iter: 919, loss: 1.1023023128509521
epoch: 0, iter: 920, loss: 1.0429933071136475
epoch: 0, iter: 921, loss: 1.1140681505203247
epoch: 0, iter: 922, loss: 1.0475155115127563
epoch: 0, iter: 923, loss: 2.1443088054656982
epoch: 0, iter: 924, loss: 0.9740393161773682
epoch: 0, iter: 925, loss: 1.1580469608306885
epoch: 0, iter: 926, loss: 0.9893800020217896
epoch: 0, iter: 927, loss: 0.9631829261779785
epoch: 0, iter: 928, loss: 1.069795

epoch: 0, iter: 1084, loss: 0.6697788834571838
epoch: 0, iter: 1085, loss: 0.6412681937217712
epoch: 0, iter: 1086, loss: 0.6143359541893005
epoch: 0, iter: 1087, loss: 0.8032181262969971
epoch: 0, iter: 1088, loss: 0.7371881008148193
epoch: 0, iter: 1089, loss: 0.8272038102149963
epoch: 0, iter: 1090, loss: 0.6995515823364258
epoch: 0, iter: 1091, loss: 0.6262392997741699
epoch: 0, iter: 1092, loss: 0.633908212184906
epoch: 0, iter: 1093, loss: 0.7150279879570007
epoch: 0, iter: 1094, loss: 0.7426736950874329
epoch: 0, iter: 1095, loss: 0.7358941435813904
epoch: 0, iter: 1096, loss: 0.6709027290344238
epoch: 0, iter: 1097, loss: 0.6798449158668518
epoch: 0, iter: 1098, loss: 0.664645254611969
epoch: 0, iter: 1099, loss: 0.7002837657928467
epoch: 0, iter: 1100, loss: 0.7593224048614502
epoch: 0, iter: 1101, loss: 0.7267757654190063
epoch: 0, iter: 1102, loss: 0.7092708349227905
epoch: 0, iter: 1103, loss: 0.6480581760406494
epoch: 0, iter: 1104, loss: 1.1053706407546997
epoch: 0, iter:

epoch: 0, iter: 1259, loss: 0.6418747901916504
epoch: 0, iter: 1260, loss: 1.0477275848388672
epoch: 0, iter: 1261, loss: 0.5048621892929077
epoch: 0, iter: 1262, loss: 0.6484153866767883
epoch: 0, iter: 1263, loss: 0.44702568650245667
epoch: 0, iter: 1264, loss: 0.6784684062004089
epoch: 0, iter: 1265, loss: 0.5315065979957581
epoch: 0, iter: 1266, loss: 0.5235118865966797
epoch: 0, iter: 1267, loss: 0.5541567802429199
epoch: 0, iter: 1268, loss: 0.6272317171096802
epoch: 0, iter: 1269, loss: 0.49126145243644714
epoch: 0, iter: 1270, loss: 0.5456632375717163
epoch: 0, iter: 1271, loss: 0.575446367263794
epoch: 0, iter: 1272, loss: 0.4831530749797821
epoch: 0, iter: 1273, loss: 0.5969418883323669
epoch: 0, iter: 1274, loss: 0.5345850586891174
epoch: 0, iter: 1275, loss: 0.6409643888473511
epoch: 0, iter: 1276, loss: 0.46101105213165283
epoch: 0, iter: 1277, loss: 0.5690941214561462
epoch: 0, iter: 1278, loss: 0.5950028300285339
epoch: 0, iter: 1279, loss: 0.8833174705505371
epoch: 0, i

epoch: 0, iter: 1433, loss: 0.4738895297050476
epoch: 0, iter: 1434, loss: 0.41661375761032104
epoch: 0, iter: 1435, loss: 0.37646210193634033
epoch: 0, iter: 1436, loss: 0.40711888670921326
epoch: 0, iter: 1437, loss: 0.3963819146156311
epoch: 0, iter: 1438, loss: 0.39877861738204956
epoch: 0, iter: 1439, loss: 0.3447279632091522
epoch: 0, iter: 1440, loss: 0.4413313567638397
epoch: 0, iter: 1441, loss: 0.48978760838508606
epoch: 0, iter: 1442, loss: 0.44267967343330383
epoch: 0, iter: 1443, loss: 0.4235020577907562
epoch: 0, iter: 1444, loss: 0.38520368933677673
epoch: 0, iter: 1445, loss: 0.34896135330200195
epoch: 0, iter: 1446, loss: 0.41163697838783264
epoch: 0, iter: 1447, loss: 0.38085252046585083
epoch: 0, iter: 1448, loss: 0.36770057678222656
epoch: 0, iter: 1449, loss: 0.39913180470466614
epoch: 0, iter: 1450, loss: 0.3701694905757904
epoch: 0, iter: 1451, loss: 0.3757895231246948
epoch: 0, iter: 1452, loss: 0.4002091586589813
epoch: 0, iter: 1453, loss: 0.49900108575820923


epoch: 0, iter: 1606, loss: 0.5419074892997742
epoch: 0, iter: 1607, loss: 0.32170674204826355
epoch: 0, iter: 1608, loss: 0.274759978055954
epoch: 0, iter: 1609, loss: 0.2949628531932831
epoch: 0, iter: 1610, loss: 0.3002951145172119
epoch: 0, iter: 1611, loss: 0.30694109201431274
epoch: 0, iter: 1612, loss: 0.3399012088775635
epoch: 0, iter: 1613, loss: 0.3243033289909363
epoch: 0, iter: 1614, loss: 0.3379489779472351
epoch: 0, iter: 1615, loss: 0.3361891508102417
epoch: 0, iter: 1616, loss: 0.33204013109207153
epoch: 0, iter: 1617, loss: 0.3398498594760895
epoch: 0, iter: 1618, loss: 0.29907774925231934
epoch: 0, iter: 1619, loss: 0.3154378831386566
epoch: 0, iter: 1620, loss: 0.5234138369560242
epoch: 0, iter: 1621, loss: 0.3018214702606201
epoch: 0, iter: 1622, loss: 0.31964847445487976
epoch: 0, iter: 1623, loss: 0.2910546660423279
epoch: 0, iter: 1624, loss: 0.2875830829143524
epoch: 0, iter: 1625, loss: 0.3549130856990814
epoch: 0, iter: 1626, loss: 0.3287167549133301
epoch: 0,

epoch: 0, iter: 1780, loss: 0.22589385509490967
epoch: 0, iter: 1781, loss: 0.272408664226532
epoch: 0, iter: 1782, loss: 0.3294678330421448
epoch: 0, iter: 1783, loss: 0.2318638414144516
epoch: 0, iter: 1784, loss: 0.2541154623031616
epoch: 0, iter: 1785, loss: 0.3754962682723999
epoch: 0, iter: 1786, loss: 0.25633934140205383
epoch: 0, iter: 1787, loss: 0.3625313639640808
epoch: 0, iter: 1788, loss: 0.24927695095539093
epoch: 0, iter: 1789, loss: 0.23406365513801575
epoch: 0, iter: 1790, loss: 0.2587304413318634
epoch: 0, iter: 1791, loss: 0.347148597240448
epoch: 0, iter: 1792, loss: 0.3832854926586151
epoch: 0, iter: 1793, loss: 0.34095948934555054
epoch: 0, iter: 1794, loss: 0.2278563678264618
epoch: 0, iter: 1795, loss: 0.2568432688713074
epoch: 0, iter: 1796, loss: 0.2853279113769531
epoch: 0, iter: 1797, loss: 0.256742388010025
epoch: 0, iter: 1798, loss: 0.2775549292564392
epoch: 0, iter: 1799, loss: 0.282196581363678
epoch: 0, iter: 1800, loss: 0.23797446489334106
epoch: 0, i

epoch: 0, iter: 1953, loss: 0.2517950236797333
epoch: 0, iter: 1954, loss: 0.2133180946111679
epoch: 0, iter: 1955, loss: 0.2017853856086731
epoch: 0, iter: 1956, loss: 0.1983759105205536
epoch: 0, iter: 1957, loss: 0.18907296657562256
epoch: 0, iter: 1958, loss: 0.21254758536815643
epoch: 0, iter: 1959, loss: 0.23505693674087524
epoch: 0, iter: 1960, loss: 0.21736480295658112
epoch: 0, iter: 1961, loss: 0.20343263447284698
epoch: 0, iter: 1962, loss: 0.19786256551742554
epoch: 0, iter: 1963, loss: 0.21005186438560486
epoch: 0, iter: 1964, loss: 0.20812572538852692
epoch: 0, iter: 1965, loss: 0.1969226896762848
epoch: 0, iter: 1966, loss: 0.19521912932395935
epoch: 0, iter: 1967, loss: 0.20149008929729462
epoch: 0, iter: 1968, loss: 0.23890013992786407
epoch: 0, iter: 1969, loss: 0.19983282685279846
epoch: 0, iter: 1970, loss: 0.19823114573955536
epoch: 0, iter: 1971, loss: 0.20631633698940277
epoch: 0, iter: 1972, loss: 0.1944715827703476
epoch: 0, iter: 1973, loss: 0.2046265006065368

epoch: 0, iter: 2126, loss: 0.15080825984477997
epoch: 0, iter: 2127, loss: 0.26539885997772217
epoch: 0, iter: 2128, loss: 0.16173288226127625
epoch: 0, iter: 2129, loss: 0.14874771237373352
epoch: 0, iter: 2130, loss: 0.16871154308319092
epoch: 0, iter: 2131, loss: 0.147861048579216
epoch: 0, iter: 2132, loss: 0.19258254766464233
epoch: 0, iter: 2133, loss: 0.1570432037115097
epoch: 0, iter: 2134, loss: 0.16702505946159363
epoch: 0, iter: 2135, loss: 0.18794910609722137
epoch: 0, iter: 2136, loss: 0.14903898537158966
epoch: 0, iter: 2137, loss: 0.16638444364070892
epoch: 0, iter: 2138, loss: 0.1733226627111435
epoch: 0, iter: 2139, loss: 0.14892856776714325
epoch: 0, iter: 2140, loss: 0.2045583426952362
epoch: 0, iter: 2141, loss: 0.1666000336408615
epoch: 0, iter: 2142, loss: 0.19174957275390625
epoch: 0, iter: 2143, loss: 0.1731080859899521
epoch: 0, iter: 2144, loss: 0.19100166857242584
epoch: 0, iter: 2145, loss: 0.17028024792671204
epoch: 0, iter: 2146, loss: 0.18777143955230713

epoch: 0, iter: 2298, loss: 0.1504461169242859
epoch: 0, iter: 2299, loss: 0.14358776807785034
epoch: 0, iter: 2300, loss: 0.14688439667224884
epoch: 0, iter: 2301, loss: 0.1257040947675705
epoch: 0, iter: 2302, loss: 0.1401681751012802
epoch: 0, iter: 2303, loss: 0.1441364884376526
epoch: 0, iter: 2304, loss: 0.15279701352119446
epoch: 0, iter: 2305, loss: 0.1310625523328781
epoch: 0, iter: 2306, loss: 0.13116802275180817
epoch: 0, iter: 2307, loss: 0.15000078082084656
epoch: 0, iter: 2308, loss: 0.13591067492961884
epoch: 0, iter: 2309, loss: 0.17842070758342743
epoch: 0, iter: 2310, loss: 0.14761598408222198
epoch: 0, iter: 2311, loss: 0.1640254259109497
epoch: 0, iter: 2312, loss: 0.14323531091213226
epoch: 0, iter: 2313, loss: 0.14546340703964233
epoch: 0, iter: 2314, loss: 0.15518341958522797
epoch: 0, iter: 2315, loss: 0.1393776386976242
epoch: 0, iter: 2316, loss: 0.17258904874324799
epoch: 0, iter: 2317, loss: 0.1796731799840927
epoch: 0, iter: 2318, loss: 0.19725437462329865


epoch: 0, iter: 2470, loss: 0.12268492579460144
epoch: 0, iter: 2471, loss: 0.14485031366348267
epoch: 0, iter: 2472, loss: 0.12585479021072388
epoch: 0, iter: 2473, loss: 0.13030804693698883
epoch: 0, iter: 2474, loss: 0.14504007995128632
epoch: 0, iter: 2475, loss: 0.11583257466554642
epoch: 0, iter: 2476, loss: 0.13666772842407227
epoch: 0, iter: 2477, loss: 0.11570177227258682
epoch: 0, iter: 2478, loss: 0.12007030844688416
epoch: 0, iter: 2479, loss: 0.11892515420913696
epoch: 0, iter: 2480, loss: 0.12006991356611252
epoch: 0, iter: 2481, loss: 0.13235154747962952
epoch: 0, iter: 2482, loss: 0.13694362342357635
epoch: 0, iter: 2483, loss: 0.15284687280654907
epoch: 0, iter: 2484, loss: 0.10277731716632843
epoch: 0, iter: 2485, loss: 0.16014274954795837
epoch: 0, iter: 2486, loss: 0.19307804107666016
epoch: 0, iter: 2487, loss: 0.14678749442100525
epoch: 0, iter: 2488, loss: 0.11439687013626099
epoch: 0, iter: 2489, loss: 0.10022619366645813
epoch: 0, iter: 2490, loss: 0.1112251058

epoch: 0, iter: 2642, loss: 0.09537141770124435
epoch: 0, iter: 2643, loss: 0.1273989975452423
epoch: 0, iter: 2644, loss: 0.10569171607494354
epoch: 0, iter: 2645, loss: 0.10304170846939087
epoch: 0, iter: 2646, loss: 0.11724451929330826
epoch: 0, iter: 2647, loss: 0.12235330045223236
epoch: 0, iter: 2648, loss: 0.11430729925632477
epoch: 0, iter: 2649, loss: 0.17018923163414001
epoch: 0, iter: 2650, loss: 0.09418024122714996
epoch: 0, iter: 2651, loss: 0.10044238716363907
epoch: 0, iter: 2652, loss: 0.09164184331893921
epoch: 0, iter: 2653, loss: 0.1090991199016571
epoch: 0, iter: 2654, loss: 0.11988227069377899
epoch: 0, iter: 2655, loss: 0.11277785897254944
epoch: 0, iter: 2656, loss: 0.15359671413898468
epoch: 0, iter: 2657, loss: 0.09723848104476929
epoch: 0, iter: 2658, loss: 0.15780417621135712
epoch: 0, iter: 2659, loss: 0.14198695123195648
epoch: 0, iter: 2660, loss: 0.11705611646175385
epoch: 0, iter: 2661, loss: 0.1116577610373497
epoch: 0, iter: 2662, loss: 0.1169190183281

epoch: 0, iter: 2814, loss: 0.10095050930976868
epoch: 0, iter: 2815, loss: 0.08175211399793625
epoch: 0, iter: 2816, loss: 0.08822674304246902
epoch: 0, iter: 2817, loss: 0.08984629809856415
epoch: 0, iter: 2818, loss: 0.09824389964342117
epoch: 0, iter: 2819, loss: 0.1101008802652359
epoch: 0, iter: 2820, loss: 0.08602719753980637
epoch: 0, iter: 2821, loss: 0.10099086165428162
epoch: 0, iter: 2822, loss: 0.08886375278234482
epoch: 0, iter: 2823, loss: 0.07936549931764603
epoch: 0, iter: 2824, loss: 0.10007461905479431
epoch: 0, iter: 2825, loss: 0.08178184926509857
epoch: 0, iter: 2826, loss: 0.07520948350429535
epoch: 0, iter: 2827, loss: 0.09487920999526978
epoch: 0, iter: 2828, loss: 0.07941828668117523
epoch: 0, iter: 2829, loss: 0.08936485648155212
epoch: 0, iter: 2830, loss: 0.09205453097820282
epoch: 0, iter: 2831, loss: 0.08699432760477066
epoch: 0, iter: 2832, loss: 0.08516751229763031
epoch: 0, iter: 2833, loss: 0.09408657252788544
epoch: 0, iter: 2834, loss: 0.08940605074

epoch: 0, iter: 2986, loss: 0.10067273676395416
epoch: 0, iter: 2987, loss: 0.08826287090778351
epoch: 0, iter: 2988, loss: 0.07851304113864899
epoch: 0, iter: 2989, loss: 0.0873073935508728
epoch: 0, iter: 2990, loss: 0.07816165685653687
epoch: 0, iter: 2991, loss: 0.07346079498529434
epoch: 0, iter: 2992, loss: 0.06800998747348785
epoch: 0, iter: 2993, loss: 0.0805194079875946
epoch: 0, iter: 2994, loss: 0.07193790376186371
epoch: 0, iter: 2995, loss: 0.0758022889494896
epoch: 0, iter: 2996, loss: 0.08303914964199066
epoch: 0, iter: 2997, loss: 0.0719200149178505
epoch: 0, iter: 2998, loss: 0.14074978232383728
epoch: 0, iter: 2999, loss: 0.07324838638305664
epoch: 0, iter: 3000, loss: 0.06849701702594757
epoch: 0, iter: 3001, loss: 0.10140352696180344
epoch: 0, iter: 3002, loss: 0.07862315326929092
epoch: 0, iter: 3003, loss: 0.07817003130912781
epoch: 0, iter: 3004, loss: 0.08034372329711914
epoch: 0, iter: 3005, loss: 0.09397809207439423
epoch: 0, iter: 3006, loss: 0.09965851157903

epoch: 0, iter: 3158, loss: 0.05982523784041405
epoch: 0, iter: 3159, loss: 0.061764974147081375
epoch: 0, iter: 3160, loss: 0.07643433660268784
epoch: 0, iter: 3161, loss: 0.0614677332341671
epoch: 0, iter: 3162, loss: 0.06509963423013687
epoch: 0, iter: 3163, loss: 0.06424234807491302
epoch: 0, iter: 3164, loss: 0.06338534504175186
epoch: 0, iter: 3165, loss: 0.06161738187074661
epoch: 0, iter: 3166, loss: 0.070705845952034
epoch: 0, iter: 3167, loss: 0.0590163879096508
epoch: 0, iter: 3168, loss: 0.10113431513309479
epoch: 0, iter: 3169, loss: 0.05973955988883972
epoch: 0, iter: 3170, loss: 0.06734378635883331
epoch: 0, iter: 3171, loss: 0.05787801742553711
epoch: 0, iter: 3172, loss: 0.07172244042158127
epoch: 0, iter: 3173, loss: 0.0652005523443222
epoch: 0, iter: 3174, loss: 0.058586858212947845
epoch: 0, iter: 3175, loss: 0.06670983880758286
epoch: 0, iter: 3176, loss: 0.061688389629125595
epoch: 0, iter: 3177, loss: 0.06174998730421066
epoch: 0, iter: 3178, loss: 0.068441927433

epoch: 0, iter: 3329, loss: 0.0691838338971138
epoch: 0, iter: 3330, loss: 0.08603838086128235
epoch: 0, iter: 3331, loss: 0.05856762081384659
epoch: 0, iter: 3332, loss: 0.05887557938694954
epoch: 0, iter: 3333, loss: 0.0894056186079979
epoch: 0, iter: 3334, loss: 0.06383166462182999
epoch: 0, iter: 3335, loss: 0.0538632869720459
epoch: 0, iter: 3336, loss: 0.07447199523448944
epoch: 0, iter: 3337, loss: 0.06318673491477966
epoch: 0, iter: 3338, loss: 0.06150975823402405
epoch: 0, iter: 3339, loss: 0.06370987743139267
epoch: 0, iter: 3340, loss: 0.05465037748217583
epoch: 0, iter: 3341, loss: 0.058987077325582504
epoch: 0, iter: 3342, loss: 0.05468304455280304
epoch: 0, iter: 3343, loss: 0.05712208151817322
epoch: 0, iter: 3344, loss: 0.059612635523080826
epoch: 0, iter: 3345, loss: 0.05251115560531616
epoch: 0, iter: 3346, loss: 0.07712225615978241
epoch: 0, iter: 3347, loss: 0.05088004842400551
epoch: 0, iter: 3348, loss: 0.053971778601408005
epoch: 0, iter: 3349, loss: 0.0564714223

epoch: 0, iter: 3500, loss: 0.06098666042089462
epoch: 0, iter: 3501, loss: 0.05124189704656601
epoch: 0, iter: 3502, loss: 0.05899469181895256
epoch: 0, iter: 3503, loss: 0.0503072515130043
epoch: 0, iter: 3504, loss: 0.05547386407852173
epoch: 0, iter: 3505, loss: 0.0540146604180336
epoch: 0, iter: 3506, loss: 0.060511939227581024
epoch: 0, iter: 3507, loss: 0.04950427636504173
epoch: 0, iter: 3508, loss: 0.04403168708086014
epoch: 0, iter: 3509, loss: 0.04886170104146004
epoch: 0, iter: 3510, loss: 0.047394417226314545
epoch: 0, iter: 3511, loss: 0.04992057755589485
epoch: 0, iter: 3512, loss: 0.055168136954307556
epoch: 0, iter: 3513, loss: 0.09061471372842789
epoch: 0, iter: 3514, loss: 0.05057302862405777
epoch: 0, iter: 3515, loss: 0.04980648308992386
epoch: 0, iter: 3516, loss: 0.05458325147628784
epoch: 0, iter: 3517, loss: 0.05063014104962349
epoch: 0, iter: 3518, loss: 0.05122512951493263
epoch: 0, iter: 3519, loss: 0.055124666541814804
epoch: 0, iter: 3520, loss: 0.05362652

epoch: 0, iter: 3671, loss: 0.04602571204304695
epoch: 0, iter: 3672, loss: 0.07944484055042267
epoch: 0, iter: 3673, loss: 0.06471212208271027
epoch: 0, iter: 3674, loss: 0.047581564635038376
epoch: 0, iter: 3675, loss: 0.04938368499279022
epoch: 0, iter: 3676, loss: 0.0924224853515625
epoch: 0, iter: 3677, loss: 0.042067356407642365
epoch: 0, iter: 3678, loss: 0.05168309807777405
epoch: 0, iter: 3679, loss: 0.05037315934896469
epoch: 0, iter: 3680, loss: 0.05322495475411415
epoch: 0, iter: 3681, loss: 0.04834238439798355
epoch: 0, iter: 3682, loss: 0.04510461911559105
epoch: 0, iter: 3683, loss: 0.04038158804178238
epoch: 0, iter: 3684, loss: 0.03847889602184296
epoch: 0, iter: 3685, loss: 0.039296217262744904
epoch: 0, iter: 3686, loss: 0.039309266954660416
epoch: 0, iter: 3687, loss: 0.04375811666250229
epoch: 0, iter: 3688, loss: 0.044618479907512665
epoch: 0, iter: 3689, loss: 0.03968115895986557
epoch: 0, iter: 3690, loss: 0.0411873497068882
epoch: 0, iter: 3691, loss: 0.0572304

epoch: 0, iter: 3841, loss: 0.05889231339097023
epoch: 0, iter: 3842, loss: 0.03272898122668266
epoch: 0, iter: 3843, loss: 0.045416053384542465
epoch: 0, iter: 3844, loss: 0.03886474296450615
epoch: 0, iter: 3845, loss: 0.03677646443247795
epoch: 0, iter: 3846, loss: 0.039413899183273315
epoch: 0, iter: 3847, loss: 0.04179038852453232
epoch: 0, iter: 3848, loss: 0.03635450452566147
epoch: 0, iter: 3849, loss: 0.038480691611766815
epoch: 0, iter: 3850, loss: 0.045878417789936066
epoch: 0, iter: 3851, loss: 0.04837722331285477
epoch: 0, iter: 3852, loss: 0.05062216892838478
epoch: 0, iter: 3853, loss: 0.04736696928739548
epoch: 0, iter: 3854, loss: 0.04475318640470505
epoch: 0, iter: 3855, loss: 0.03882800415158272
epoch: 0, iter: 3856, loss: 0.04645353555679321
epoch: 0, iter: 3857, loss: 0.03561879321932793
epoch: 0, iter: 3858, loss: 0.044632986187934875
epoch: 0, iter: 3859, loss: 0.0419299453496933
epoch: 0, iter: 3860, loss: 0.050874870270490646
epoch: 0, iter: 3861, loss: 0.04176

epoch: 0, iter: 4011, loss: 0.04145403951406479
epoch: 0, iter: 4012, loss: 0.05950899422168732
epoch: 0, iter: 4013, loss: 0.045301467180252075
epoch: 0, iter: 4014, loss: 0.040846675634384155
epoch: 0, iter: 4015, loss: 0.03344467654824257
epoch: 0, iter: 4016, loss: 0.03446736931800842
epoch: 0, iter: 4017, loss: 0.03308536112308502
epoch: 0, iter: 4018, loss: 0.05282249301671982
epoch: 0, iter: 4019, loss: 0.032847821712493896
epoch: 0, iter: 4020, loss: 0.03490563482046127
epoch: 0, iter: 4021, loss: 0.03663443401455879
epoch: 0, iter: 4022, loss: 0.02999171055853367
epoch: 0, iter: 4023, loss: 0.03177998960018158
epoch: 0, iter: 4024, loss: 0.03334042429924011
epoch: 0, iter: 4025, loss: 0.030656980350613594
epoch: 0, iter: 4026, loss: 0.04254709556698799
epoch: 0, iter: 4027, loss: 0.03957657888531685
epoch: 0, iter: 4028, loss: 0.038009483367204666
epoch: 0, iter: 4029, loss: 0.03632767125964165
epoch: 0, iter: 4030, loss: 0.037944305688142776
epoch: 0, iter: 4031, loss: 0.0332

epoch: 0, iter: 4181, loss: 0.030674487352371216
epoch: 0, iter: 4182, loss: 0.032254815101623535
epoch: 0, iter: 4183, loss: 0.035850122570991516
epoch: 0, iter: 4184, loss: 0.0312664620578289
epoch: 0, iter: 4185, loss: 0.051362182945013046
epoch: 0, iter: 4186, loss: 0.029281210154294968
epoch: 0, iter: 4187, loss: 0.02958858385682106
epoch: 0, iter: 4188, loss: 0.02953571453690529
epoch: 0, iter: 4189, loss: 0.028133120387792587
epoch: 0, iter: 4190, loss: 0.04998622462153435
epoch: 0, iter: 4191, loss: 0.034053679555654526
epoch: 0, iter: 4192, loss: 0.028212204575538635
epoch: 0, iter: 4193, loss: 0.031126342713832855
epoch: 0, iter: 4194, loss: 0.03196826949715614
epoch: 0, iter: 4195, loss: 0.05845057591795921
epoch: 0, iter: 4196, loss: 0.036824166774749756
epoch: 0, iter: 4197, loss: 0.0318756103515625
epoch: 0, iter: 4198, loss: 0.03407017141580582
epoch: 0, iter: 4199, loss: 0.02867729961872101
epoch: 0, iter: 4200, loss: 0.03274432197213173
epoch: 0, iter: 4201, loss: 0.03

epoch: 0, iter: 4351, loss: 0.02927277982234955
epoch: 0, iter: 4352, loss: 0.04101858288049698
epoch: 0, iter: 4353, loss: 0.026149189099669456
epoch: 0, iter: 4354, loss: 0.03058532625436783
epoch: 0, iter: 4355, loss: 0.025981782004237175
epoch: 0, iter: 4356, loss: 0.025127246975898743
epoch: 0, iter: 4357, loss: 0.02841908484697342
epoch: 0, iter: 4358, loss: 0.028140531852841377
epoch: 0, iter: 4359, loss: 0.02877325750887394
epoch: 0, iter: 4360, loss: 0.030990565195679665
epoch: 0, iter: 4361, loss: 0.023151949048042297
epoch: 0, iter: 4362, loss: 0.026628490537405014
epoch: 0, iter: 4363, loss: 0.032148003578186035
epoch: 0, iter: 4364, loss: 0.029642999172210693
epoch: 0, iter: 4365, loss: 0.027427667751908302
epoch: 0, iter: 4366, loss: 0.025624819099903107
epoch: 0, iter: 4367, loss: 0.034097541123628616
epoch: 0, iter: 4368, loss: 0.030527789145708084
epoch: 0, iter: 4369, loss: 0.025506548583507538
epoch: 0, iter: 4370, loss: 0.027124356478452682
epoch: 0, iter: 4371, los

epoch: 0, iter: 4520, loss: 0.022720027714967728
epoch: 0, iter: 4521, loss: 0.035046227276325226
epoch: 0, iter: 4522, loss: 0.03543218597769737
epoch: 0, iter: 4523, loss: 0.023040499538183212
epoch: 0, iter: 4524, loss: 0.02728530392050743
epoch: 0, iter: 4525, loss: 0.023353390395641327
epoch: 0, iter: 4526, loss: 0.02106468938291073
epoch: 0, iter: 4527, loss: 0.02714882418513298
epoch: 0, iter: 4528, loss: 0.027363713830709457
epoch: 0, iter: 4529, loss: 0.023699777200818062
epoch: 0, iter: 4530, loss: 0.029461462050676346
epoch: 0, iter: 4531, loss: 0.027810022234916687
epoch: 0, iter: 4532, loss: 0.023119555786252022
epoch: 0, iter: 4533, loss: 0.028459277004003525
epoch: 0, iter: 4534, loss: 0.025952458381652832
epoch: 0, iter: 4535, loss: 0.024178223684430122
epoch: 0, iter: 4536, loss: 0.027181390672922134
epoch: 0, iter: 4537, loss: 0.023732857778668404
epoch: 0, iter: 4538, loss: 0.03658424690365791
epoch: 0, iter: 4539, loss: 0.02689170092344284
epoch: 0, iter: 4540, loss

epoch: 0, iter: 4689, loss: 0.023363439366221428
epoch: 0, iter: 4690, loss: 0.02483832836151123
epoch: 0, iter: 4691, loss: 0.02563631907105446
epoch: 0, iter: 4692, loss: 0.023706655949354172
epoch: 0, iter: 4693, loss: 0.021381964907050133
epoch: 0, iter: 4694, loss: 0.034299034625291824
epoch: 0, iter: 4695, loss: 0.023337818682193756
epoch: 0, iter: 4696, loss: 0.025489887222647667
epoch: 0, iter: 4697, loss: 0.023217713460326195
epoch: 0, iter: 4698, loss: 0.020878173410892487
epoch: 0, iter: 4699, loss: 0.020464975386857986
epoch: 0, iter: 4700, loss: 0.024138960987329483
epoch: 0, iter: 4701, loss: 0.02485448122024536
epoch: 0, iter: 4702, loss: 0.03061930648982525
epoch: 0, iter: 4703, loss: 0.019500527530908585
epoch: 0, iter: 4704, loss: 0.021148601546883583
epoch: 0, iter: 4705, loss: 0.02011759579181671
epoch: 0, iter: 4706, loss: 0.02045944519340992
epoch: 0, iter: 4707, loss: 0.023330068215727806
epoch: 0, iter: 4708, loss: 0.02150612510740757
epoch: 0, iter: 4709, loss: