# Dataset

In [1]:
import torch

class EasyAcc:
    """Running accumulator for a stream of scalars (Python numbers or 0-d tensors).

    ``acc += x`` records ``x``; ``acc -= x`` records ``-x``.  Only the count, the
    sum and the sum of squares are stored, so summary statistics are O(1) to query.
    """

    def __init__(self):
        self.n = 0       # number of recorded values
        self.sum = 0     # running sum of recorded values
        self.sumsq = 0   # running sum of squared recorded values

    def __iadd__(self, other):
        """Record ``other``."""
        self.n += 1
        self.sum += other
        self.sumsq += other*other
        return self

    def __isub__(self, other):
        """Record ``-other`` (its square is still ``other*other``)."""
        self.n += 1
        self.sum -= other
        self.sumsq += other*other
        return self

    def mean(self):
        """Mean of the recorded values; 0 when nothing has been recorded."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Population standard deviation of the recorded values.

        Despite the name this returns sqrt(variance); the name is kept for
        backward compatibility (``semean`` relies on it returning a std).
        """
        from math import sqrt
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation (e.g. all values equal); clamp so sqrt cannot raise.
        variance = self.sumsq / max(self.n, 1) - self.mean()**2
        return sqrt(max(variance, 0.0))

    def semean(self):
        """Standard error of the mean: standard deviation / sqrt(n)."""
        from math import sqrt
        return self.var() / sqrt(max(self.n, 1))

def categoryCount(path='entityfreq.gz'):
    """Read per-entity frequencies from a gzipped ``<freq> <entity>`` text file.

    Args:
        path: gzip-compressed text file; each line holds an integer frequency and
            an entity name separated by whitespace.  Defaults to ``entityfreq.gz``.

    Returns:
        dict mapping entity name -> int frequency.  Malformed lines (wrong number
        of fields, or a frequency that is not an integer) are skipped; if an entity
        appears more than once, the last occurrence wins.
    """
    import gzip

    counts = {}
    with gzip.open(path, 'rt') as f:
        for line in f:
            try:
                freq, entity = line.split()
                counts[entity] = int(freq)
            except ValueError:
                # Best-effort parse: skip blank, header and otherwise malformed rows.
                continue
    return counts

def getCategories(threshold, model_name='bert-base-nli-mean-tokens'):
    """Yield ``(entity, embedding)`` for every entity with frequency >= ``threshold``.

    Entities are visited in ``categoryCount()`` iteration order; callers assign
    label indices by enumerating this generator, so the order matters.  Each entity
    name has underscores replaced by spaces before being embedded with a
    SentenceTransformer model.

    Args:
        threshold: minimum frequency for an entity to be kept.
        model_name: SentenceTransformer model to load (makeData embeds contexts
            with the same default model).
    """
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name)

    for entity, freq in categoryCount().items():
        if freq >= threshold:
            # encode() takes a batch of sentences: wrap the single name, unwrap the result.
            yield entity, model.encode([entity.replace('_', ' ')])[0]

def datasetStats(threshold):
    """Dataset size at a frequency ``threshold``.

    Returns a dict with ``numclasses`` (number of entities whose frequency is
    >= threshold) and ``numexamples`` (threshold * numclasses, i.e. every class
    filled to the per-class cap that makeData applies).
    """
    qualifying = sum(1 for freq in categoryCount().values() if freq >= threshold)
    return {'numclasses': qualifying, 'numexamples': threshold * qualifying}
            
def makeData(threshold, categories, path='shuffled_dedup_entities.tsv', batch_size=5,
             model_name='bert-base-nli-mean-tokens'):
    """Yield embedded training examples, at most ``threshold`` per category.

    Reads a tab-separated file whose lines are ``entity<TAB>pre<TAB>mention<TAB>post``
    and keeps lines whose entity is a key of ``categories``, up to ``threshold`` lines
    per entity (first ones in file order).  The pre- and post-contexts are embedded
    with a SentenceTransformer model, ``batch_size`` examples per ``encode`` call.

    Args:
        threshold: maximum number of examples kept per entity.
        categories: dict entity -> (label_index, label_embedding), as built by
            ``MyDataset`` from ``getCategories``.
        path: input TSV file.
        batch_size: number of examples embedded per ``encode`` call.
        model_name: SentenceTransformer model used for the contexts.

    Yields:
        dicts with keys ``line`` (the raw input line), ``entityord`` (label index),
        ``entityvec`` (label embedding), ``pre`` and ``post`` (context embeddings).
    """
    from collections import defaultdict
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name)
    catcount = defaultdict(int)

    def _embedBatch(batchline, batchencode, batchentity):
        # One encode() call per batch; embeddings come back interleaved as
        # [pre_0, post_0, pre_1, post_1, ...].
        embed = model.encode(batchencode)
        for n, (rawline, entity) in enumerate(zip(batchline, batchentity)):
            entityord, entityvec = categories[entity]
            yield {'line': rawline,
                   'entityord': entityord,
                   'entityvec': entityvec,
                   'pre': embed[2*n],
                   'post': embed[2*n+1]}

    batchline, batchencode, batchentity = [], [], []
    with open(path) as f:
        for line in f:
            # Strip only the line terminator: strip() would also eat the tab in front
            # of an empty post-context and the line would be dropped as malformed.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) != 4:
                continue
            entity, pre, _mention, post = fields

            if entity in categories and catcount[entity] < threshold:
                catcount[entity] += 1
                batchline.append(line)
                batchencode.extend((pre, post))
                batchentity.append(entity)

                if len(batchline) >= batch_size:
                    yield from _embedBatch(batchline, batchencode, batchentity)
                    batchline, batchencode, batchentity = [], [], []

    # Flush the final partial batch.
    if batchline:
        yield from _embedBatch(batchline, batchencode, batchentity)

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of (context features, label index) pairs.

    Attributes:
        labelfeats: dict entity -> (label_index, label_embedding); label indices
            follow the order in which ``getCategories`` yields entities.
        Xs: float tensor, one row per example: the pre-context embedding
            concatenated with the post-context embedding.
        ys: LongTensor of label indices, aligned with the rows of ``Xs``.
    """

    def __init__(self, threshold, max_examples=None):
        """Embed every example for ``threshold`` (slow: runs the sentence encoder).

        Args:
            threshold: minimum entity frequency, also the per-entity example cap.
            max_examples: optional cap on the number of examples, handy for quick
                debugging runs; ``None`` keeps everything.

        Raises:
            ValueError: if no example is produced.
        """
        from itertools import islice
        from tqdm.notebook import tqdm

        self.labelfeats = {k: (n, v) for n, (k, v) in enumerate(getCategories(threshold))}
        examples = islice(makeData(threshold, self.labelfeats), max_examples)

        Xs, ys = [], []
        for what in tqdm(examples):
            Xs.append(torch.cat((torch.tensor(what['pre']), torch.tensor(what['post']))))
            ys.append(what['entityord'])

        if not Xs:
            raise ValueError(f'no examples produced for threshold={threshold}')
        self.Xs = torch.stack(Xs)
        self.ys = torch.LongTensor(ys)

    def __len__(self):
        return self.Xs.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label_index)`` pair at ``index``."""
        return self.Xs[index], self.ys[index]

In [2]:
# Number of classes and of examples (threshold * classes) at three frequency thresholds.
datasetStats(2000), datasetStats(1000), datasetStats(200)

({'numclasses': 311, 'numexamples': 622000},
 {'numclasses': 1154, 'numexamples': 1154000},
 {'numclasses': 14031, 'numexamples': 2806200})

## Build and cache the processed dataset — slow (takes days), run only once

In [3]:
def makeMyDataset(threshold):
    """Build ``MyDataset(threshold)`` and cache it as ``mydataset.<threshold>.pickle.gz``."""
    import gzip
    import pickle

    dataset = MyDataset(threshold)
    cachepath = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(cachepath, 'wb') as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
# Build and cache the threshold-2000 dataset (very slow; run once).
makeMyDataset(2000)

0it [00:00, ?it/s]

## Load cached processed data

In [2]:
def loadMyDataset(threshold):
    """Load the object cached in ``mydataset.<threshold>.pickle.gz`` (see makeMyDataset).

    Security note: this unpickles the file, which can execute arbitrary code;
    only load caches you created yourself.
    """
    import gzip
    import pickle

    cachepath = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(cachepath, 'rb') as handle:
        dataset = pickle.load(handle)
    return dataset

In [3]:
# best constant predictor
# if you don't beat this, you have a problem

def bestconstant(threshold, freqs=None):
    """Metrics of the best input-independent predictor at a frequency ``threshold``.

    Every entity with frequency >= ``threshold`` contributes exactly ``threshold``
    examples (makeData caps each class at ``threshold``), giving class prior
    p_k = n_k / N.  Predicting that prior for every example yields

    * average log loss  = -sum_k p_k log p_k  (the entropy of the prior), and
    * average accuracy  = max_k p_k           (always answering the top class).

    The log loss is computed directly from the probabilities; feeding
    probabilities to ``CrossEntropyLoss`` (which expects logits) is only correct
    when the prior is uniform.

    Args:
        threshold: minimum entity frequency.
        freqs: optional dict entity -> frequency; defaults to ``categoryCount()``.

    Returns:
        dict with ``best_constant_answer``, ``best_constant_average_logloss`` and
        ``best_constant_average_accuracy``.

    Raises:
        ValueError: if no entity reaches ``threshold``.
    """
    from math import fsum, log

    if freqs is None:
        freqs = categoryCount()
    counts = {k: threshold for k, v in freqs.items() if v >= threshold}
    if not counts:
        raise ValueError(f'no entity has frequency >= {threshold}')

    total = sum(counts.values())
    probs = [n / total for n in counts.values()]

    # Ties are broken by the lexicographically largest entity name.
    return {'best_constant_answer': max((v, k) for k, v in counts.items())[1],
            'best_constant_average_logloss': -fsum(p * log(p) for p in probs),
            'best_constant_average_accuracy': max(counts.values()) / total}

# Constant-predictor baselines at the three frequency thresholds.
bestconstant(2000), bestconstant(1000), bestconstant(200)

({'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 5.739792823791504,
  'best_constant_average_accuracy': 0.003215434083601286},
 {'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 7.050989627838135,
  'best_constant_average_accuracy': 0.0008665511265164644},
 {'best_constant_answer': 'weight_gain',
  'best_constant_average_logloss': 9.54902458190918,
  'best_constant_average_accuracy': 7.127075760815338e-05})

In [34]:
class Bilinear(torch.nn.Module):
    """Bilinear scorer: logit[i, a] = Xs[i] @ W @ Zs[a].

    Args:
        dobs: context feature dimension (columns of Xs).
        daction: action feature dimension (columns of Zs).
        naction: number of actions; accepted for interface symmetry but unused.
        device: device on which W is allocated.
    """

    def __init__(self, dobs, daction, naction, device):
        super().__init__()
        # W starts at zero, so every initial logit is 0 (preq1 then gives 0.5).
        self.W = torch.nn.Parameter(torch.zeros(dobs, daction, device=device))
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, Xs, Zs):
        """For Xs of shape (N, dobs) and Zs of shape (K, daction), return (N, K) logits."""
        projected = Xs @ self.W
        return projected @ Zs.T

    def preq1(self, logits):
        """Map logits to probabilities with a sigmoid."""
        return self.sigmoid(logits)

class IGW(object):
    """Inverse-gap-weighting (IGW) action sampler implemented by rejection.

    For each row a uniformly random action a is proposed (probability 1/K) and
    accepted with probability K / (K + gamma * (max fhat - fhat[a])); otherwise the
    greedy action is played.  A non-greedy action a therefore ends up with
    probability 1 / (K + gamma * gap(a)), and the greedy action receives the
    remaining mass.  Larger ``gamma`` means less exploration.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def sample(self, fhat):
        """Sample one action index per row of ``fhat`` (shape (N, K), higher is better)."""
        N, K = fhat.shape
        # The draw order (randint, then rand) is kept fixed so seeded runs reproduce.
        proposal = torch.randint(high=K, size=(N, 1), device=fhat.device)
        best, greedy = torch.max(fhat, dim=1, keepdim=True)
        gap = best - torch.gather(input=fhat, dim=1, index=proposal)
        accept_prob = K / (K + self.gamma * gap)
        accepted = torch.rand(size=(N, 1), device=fhat.device) <= accept_prob
        return torch.where(accepted, proposal, greedy).squeeze(1)

def _printProgress(avloss, sincelast, acc, accsincelast, avreward, rewardsincelast, elapsed):
    """Print one row of learnOnline's progress table (columns match its header)."""
    print('{:<5d}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}'.format(
            avloss.n, avloss.mean(), sincelast.mean(), acc.mean(),
            accsincelast.mean(), avreward.mean(), rewardsincelast.mean(),
            elapsed),
          flush=True)


def learnOnline(dataset, seed=4545, rank=None, initlr=4e-1, tzero=100000, gamma=2000, batch_size=32, cuda=False, xscale=1e-4):
    """Train a bilinear reward model online from bandit feedback.

    Each minibatch is one round: the model scores every action (label), an IGW
    sampler picks one action per example, and the only feedback is whether the
    sampled action equals the true label (reward 1) or not (reward 0).  The model
    is updated with a binary log loss on the sampled action's logit only.  A
    progress row is printed after batches 0, 1, 2, 4, 8, ... and once at the end.

    Args:
        dataset: torch Dataset of (features, label_index) pairs with a
            ``labelfeats`` dict mapping key -> (label_index, label_embedding).
        seed: torch RNG seed (controls shuffling and action sampling).
        rank: if given, replace label embeddings by their top-``rank`` left
            singular vectors scaled by the singular values.
        initlr: initial Adam learning rate.
        tzero: decay time scale; the lr multiplier is sqrt(tzero) / sqrt(tzero + t).
        gamma: IGW exploration parameter (larger means greedier).
        batch_size: minibatch size.
        cuda: run on the default CUDA device.
        xscale: multiplier applied to the features before the bilinear form
            (previously hard-coded as 0.0001).
    """
    import time
    from math import sqrt

    torch.manual_seed(seed)

    # Row n of Zs is the embedding of label n.
    labelfeatsdict = {n: v for n, v in dataset.labelfeats.values()}
    labelfeats = [torch.tensor(labelfeatsdict[n]).float().unsqueeze(0) for n in range(len(labelfeatsdict))]
    Zs = torch.cat(labelfeats, dim=0)

    if cuda:
        Zs = Zs.cuda()

    if rank is not None:
        with torch.no_grad():
            U, S, _ = torch.linalg.svd(Zs, full_matrices=False)
            Zs = U[:, :rank] @ torch.diag(S[:rank])

    generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    log_loss = torch.nn.BCEWithLogitsLoss()
    model = sampler = opt = scheduler = start = None

    print('{:<5s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}'.format(
            'n', 'loss', 'since last', 'acc', 'since last', 'reward', 'since last', 'dt (sec)'),
          flush=True)
    avloss, sincelast, acc, accsincelast, avreward, rewardsincelast = [EasyAcc() for _ in range(6)]

    for bno, (Xs, ys) in enumerate(generator):
        Xs, ys = Xs.to(Zs.device), ys.to(Zs.device)

        if model is None:
            # Built lazily: the feature dimension is only known from the first batch.
            model = Bilinear(dobs=Xs.shape[1], daction=Zs.shape[1], naction=Zs.shape[0], device=Zs.device)
            sampler = IGW(gamma=gamma)
            opt = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=initlr)
            scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda t: sqrt(tzero) / sqrt(tzero + t))
            start = time.time()

        opt.zero_grad()
        # A single forward pass serves both action sampling (detached) and the update.
        logit = model.forward(xscale * Xs, Zs)

        with torch.no_grad():
            fhat = model.preq1(logit.detach())
            sample = sampler.sample(fhat)
            reward = (sample == ys).unsqueeze(1).float()

        # Bandit feedback: only the sampled action's logit receives a gradient.
        samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)
        loss = log_loss(samplelogit, reward)
        loss.backward()
        opt.step()
        scheduler.step()

        with torch.no_grad():
            # Progressive validation: accuracy of the pre-update greedy prediction.
            batchacc = torch.mean((torch.argmax(logit, dim=1) == ys).float())
            batchreward = torch.mean(reward)
            acc += batchacc
            accsincelast += batchacc
            avloss += loss
            sincelast += loss
            avreward += batchreward
            rewardsincelast += batchreward

        # True for bno = 0 and every power of two (`&` binds tighter than `==`).
        if bno & (bno - 1) == 0:
            _printProgress(avloss, sincelast, acc, accsincelast, avreward, rewardsincelast,
                           time.time() - start)
            sincelast, accsincelast, rewardsincelast = [EasyAcc() for _ in range(3)]

    if start is None:
        # Empty dataset: nothing was trained, so there is no final row to report.
        return

    _printProgress(avloss, sincelast, acc, accsincelast, avreward, rewardsincelast,
                   time.time() - start)

In [4]:
# Load the cached threshold-2000 dataset (written by makeMyDataset).
mydata = loadMyDataset(2000)

In [36]:
# One online-learning run on the threshold-2000 data, label embeddings reduced to rank 50.
learnOnline(mydata, initlr=0.33, tzero=100000, rank=50, gamma=14000)

n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00200   
2    	0.69364   	0.69414   	0.01562   	0.00000   	0.00000   	0.00000   	0.00672   
3    	0.68035   	0.65375   	0.01042   	0.00000   	0.00000   	0.00000   	0.01054   
5    	0.65557   	0.61842   	0.00625   	0.00000   	0.00000   	0.00000   	0.01973   
9    	0.55516   	0.42964   	0.00347   	0.00000   	0.00000   	0.00000   	0.02972   
17   	0.40323   	0.23231   	0.00368   	0.00391   	0.00184   	0.00391   	0.04770   
33   	0.24758   	0.08219   	0.00284   	0.00195   	0.00189   	0.00195   	0.08209   
65   	0.14187   	0.03286   	0.00192   	0.00098   	0.00192   	0.00195   	0.15620   
129  	0.08558   	0.02841   	0.00412   	0.00635   	0.00266   	0.00342   	0.28595   
257  	0.05992   	0.03406   	0.00596   	0.00781   	0.00413   	0.00562   	0.57432   
513  	0.04637   	0.03276   	0.00816   	0.01038   	0.00530   	0.00647   	1.12511   
1025

In [5]:
def flass(dataset=None, ntrials=100, seed=None):
    """Random hyperparameter search over ``learnOnline`` (rank fixed at 50).

    Each trial draws initlr ~ U[0.05, 0.85), tzero ~ U[10000, 310000) and
    gamma ~ U[1000, 16000), prints them, then runs ``learnOnline``.

    Args:
        dataset: dataset to train on; defaults to the notebook-global ``mydata``.
        ntrials: number of random configurations to try.
        seed: optional seed for the hyperparameter draws; ``None`` draws from the
            global ``random`` module state, as before.
    """
    import random

    if dataset is None:
        # Backward-compatible default: ``mydata`` is loaded by an earlier cell.
        dataset = mydata
    rng = random if seed is None else random.Random(seed)

    for _ in range(ntrials):
        initlr = 0.05 + 0.8 * rng.random()
        tzero = 10000 + 300000 * rng.random()
        gamma = 1000 + 15000 * rng.random()
        print(initlr, tzero, gamma)
        learnOnline(dataset, initlr=initlr, tzero=tzero, gamma=gamma, rank=50)

# Random hyperparameter search on the globally loaded mydata.
flass()

0.5954897083423462 230667.97077365717 7377.412444397972
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.93971   
2    	0.70245   	0.71176   	0.01562   	0.00000   	0.00000   	0.00000   	0.94560   
3    	0.67577   	0.62240   	0.02083   	0.03125   	0.00000   	0.00000   	0.95080   
5    	0.59392   	0.47114   	0.01250   	0.00000   	0.00000   	0.00000   	0.96053   
9    	0.45543   	0.28232   	0.01042   	0.00781   	0.00347   	0.00781   	0.97273   
17   	0.28186   	0.08660   	0.00551   	0.00000   	0.00184   	0.00000   	0.99807   
33   	0.16711   	0.04518   	0.00379   	0.00195   	0.00379   	0.00586   	1.03218   
65   	0.10073   	0.03227   	0.00481   	0.00586   	0.00433   	0.00488   	1.09917   
129  	0.06032   	0.01928   	0.00581   	0.00684   	0.00339   	0.00244   	1.22759   
257  	0.05120   	0.04201   	0.00827   	0.01074   	0.00608   	0.00879   	1.48161   
513  	0.04737   	0.04352   	0.0

9    	0.59392   	0.53005   	0.00347   	0.00000   	0.00000   	0.00000   	0.04834   
17   	0.50621   	0.40754   	0.00551   	0.00781   	0.00368   	0.00781   	0.07476   
33   	0.37965   	0.24518   	0.00473   	0.00391   	0.00284   	0.00195   	0.11774   
65   	0.25389   	0.12421   	0.00769   	0.01074   	0.00481   	0.00684   	0.18458   
129  	0.15429   	0.05314   	0.00775   	0.00781   	0.00388   	0.00293   	0.31255   
257  	0.10043   	0.04614   	0.00936   	0.01099   	0.00596   	0.00806   	0.56050   
513  	0.06686   	0.03316   	0.01011   	0.01086   	0.00646   	0.00696   	1.08498   
1025 	0.04735   	0.02781   	0.01274   	0.01538   	0.00677   	0.00708   	2.09590   
2049 	0.04292   	0.03848   	0.01752   	0.02231   	0.00898   	0.01120   	4.33965   
4097 	0.05717   	0.07143   	0.03305   	0.04858   	0.01745   	0.02592   	8.90494   
8193 	0.09294   	0.12872   	0.06460   	0.09616   	0.03725   	0.05706   	18.60551  
16385	0.14670   	0.20046   	0.10929   	0.15398   	0.06948   	0.10170   	35.07265  
1943

513  	0.04401   	0.04087   	0.01194   	0.01672   	0.00731   	0.00952   	1.04888   
1025 	0.03943   	0.03485   	0.01607   	0.02020   	0.00851   	0.00970   	2.09503   
2049 	0.04247   	0.04551   	0.02259   	0.02911   	0.01150   	0.01450   	4.25510   
4097 	0.05225   	0.06204   	0.03384   	0.04510   	0.01687   	0.02225   	8.49464   
8193 	0.08016   	0.10808   	0.05810   	0.08236   	0.03122   	0.04558   	17.04835  
16385	0.11952   	0.15888   	0.09710   	0.13610   	0.05653   	0.08184   	34.62064  
19438	0.12944   	0.18268   	0.10652   	0.15711   	0.06309   	0.09828   	41.14330  
0.3598080745921428 296235.0123348684 10363.408032461206
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00291   
2    	0.70291   	0.71267   	0.01562   	0.00000   	0.01562   	0.03125   	0.00917   
3    	0.68100   	0.63718   	0.03125   	0.06250   	0.01042   	0.00000   	0.02266   
5    	0.63331   	0.56177   	0.0

19438	0.24757   	0.32768   	0.14481   	0.20216   	0.12177   	0.17494   	39.46389  
0.4883703952157747 218281.52612512154 11119.239809001503
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00235   
2    	0.70017   	0.70720   	0.01562   	0.00000   	0.00000   	0.00000   	0.00558   
3    	0.68465   	0.65361   	0.02083   	0.03125   	0.01042   	0.03125   	0.00833   
5    	0.62168   	0.52721   	0.01250   	0.00000   	0.00625   	0.00000   	0.01314   
9    	0.50528   	0.35978   	0.00694   	0.00000   	0.00347   	0.00000   	0.02198   
17   	0.33384   	0.14096   	0.00551   	0.00391   	0.00368   	0.00391   	0.03859   
33   	0.19490   	0.04727   	0.00473   	0.00391   	0.00379   	0.00391   	0.07125   
65   	0.10655   	0.01545   	0.00240   	0.00000   	0.00240   	0.00098   	0.17702   
129  	0.06336   	0.01950   	0.00266   	0.00293   	0.00242   	0.00244   	0.33381   
257  	0.05274   	0.04204   	0.

5    	0.62192   	0.53152   	0.00625   	0.00000   	0.00000   	0.00000   	0.01755   
9    	0.51140   	0.37325   	0.00694   	0.00781   	0.00694   	0.01562   	0.02724   
17   	0.33566   	0.13796   	0.00368   	0.00000   	0.00368   	0.00000   	0.04447   
33   	0.19061   	0.03649   	0.00284   	0.00195   	0.00189   	0.00000   	0.07696   
65   	0.11422   	0.03544   	0.00625   	0.00977   	0.00337   	0.00488   	0.13979   
129  	0.06775   	0.02056   	0.00654   	0.00684   	0.00315   	0.00293   	0.26338   
257  	0.05424   	0.04062   	0.00948   	0.01245   	0.00608   	0.00903   	0.51102   
513  	0.04516   	0.03605   	0.01036   	0.01123   	0.00688   	0.00769   	1.01362   
1025 	0.04600   	0.04683   	0.01567   	0.02100   	0.01009   	0.01331   	2.01849   
2049 	0.05470   	0.06341   	0.02465   	0.03363   	0.01614   	0.02219   	4.10502   
4097 	0.08626   	0.11784   	0.04705   	0.06946   	0.03221   	0.04829   	8.19712   
8193 	0.15686   	0.22748   	0.09060   	0.13417   	0.06905   	0.10590   	16.44344  
1638

257  	0.05668   	0.02164   	0.00644   	0.00635   	0.00389   	0.00366   	0.53493   
513  	0.04845   	0.04019   	0.00902   	0.01160   	0.00658   	0.00928   	1.08773   
1025 	0.05627   	0.06411   	0.01738   	0.02576   	0.01302   	0.01947   	2.10972   
2049 	0.07687   	0.09750   	0.03201   	0.04666   	0.02455   	0.03610   	4.16592   
4097 	0.12122   	0.16559   	0.05787   	0.08374   	0.04640   	0.06825   	8.33918   
8193 	0.19355   	0.26589   	0.09904   	0.14023   	0.08387   	0.12135   	16.47325  
16385	0.26201   	0.33047   	0.14404   	0.18904   	0.12609   	0.16831   	32.99725  
19438	0.27596   	0.35083   	0.15391   	0.20686   	0.13546   	0.18577   	39.22054  
0.08201538146836426 299653.05431176745 5914.029178024478
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00207   
2    	0.68746   	0.68177   	0.01562   	0.00000   	0.00000   	0.00000   	0.00669   
3    	0.68258   	0.67282   	0.

16385	0.08790   	0.11579   	0.09727   	0.13693   	0.04406   	0.06501   	32.64828  
19438	0.09527   	0.13483   	0.10757   	0.16290   	0.04953   	0.07894   	38.74380  
0.14640825280632958 56566.65990098979 14399.432125867566
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00206   
2    	0.69254   	0.69193   	0.01562   	0.00000   	0.00000   	0.00000   	0.00736   
3    	0.68173   	0.66013   	0.02083   	0.03125   	0.00000   	0.00000   	0.01184   
5    	0.66161   	0.63143   	0.01250   	0.00000   	0.00000   	0.00000   	0.01760   
9    	0.61184   	0.54963   	0.00694   	0.00000   	0.00000   	0.00000   	0.02719   
17   	0.52103   	0.41885   	0.00735   	0.00781   	0.00368   	0.00781   	0.04434   
33   	0.39012   	0.25103   	0.00473   	0.00195   	0.00473   	0.00586   	0.07734   
65   	0.25855   	0.12287   	0.00577   	0.00684   	0.00577   	0.00684   	0.14194   
129  	0.15575   	0.05134   	0.

3    	0.67866   	0.64007   	0.02083   	0.03125   	0.00000   	0.00000   	0.01093   
5    	0.62285   	0.53913   	0.01250   	0.00000   	0.00000   	0.00000   	0.01666   
9    	0.51943   	0.39016   	0.00694   	0.00000   	0.00347   	0.00781   	0.02656   
17   	0.36076   	0.18225   	0.00735   	0.00781   	0.00551   	0.00781   	0.04448   
33   	0.21494   	0.06002   	0.00568   	0.00391   	0.00379   	0.00195   	0.07795   
65   	0.13286   	0.04822   	0.00673   	0.00781   	0.00577   	0.00781   	0.14265   
129  	0.08334   	0.03304   	0.00557   	0.00439   	0.00557   	0.00537   	0.27386   
257  	0.06047   	0.03742   	0.00790   	0.01025   	0.00657   	0.00757   	0.59761   
513  	0.05553   	0.05057   	0.01231   	0.01672   	0.00920   	0.01184   	1.12324   
1025 	0.05402   	0.05252   	0.01905   	0.02582   	0.01244   	0.01569   	2.15001   
2049 	0.06981   	0.08561   	0.03017   	0.04129   	0.02036   	0.02829   	4.26340   
4097 	0.11806   	0.16633   	0.05782   	0.08548   	0.04238   	0.06441   	8.42780   
8193

129  	0.15098   	0.05365   	0.00315   	0.00342   	0.00339   	0.00342   	0.26880   
257  	0.09106   	0.03066   	0.00413   	0.00513   	0.00328   	0.00317   	0.53147   
513  	0.06654   	0.04192   	0.00816   	0.01221   	0.00573   	0.00818   	1.10283   
1025 	0.05375   	0.04093   	0.01155   	0.01495   	0.00823   	0.01074   	2.14128   
2049 	0.05981   	0.06588   	0.02007   	0.02859   	0.01455   	0.02087   	4.23675   
4097 	0.08774   	0.11568   	0.03803   	0.05600   	0.02815   	0.04175   	8.35966   
8193 	0.14229   	0.19685   	0.07031   	0.10259   	0.05587   	0.08361   	16.54240  
16385	0.21798   	0.29368   	0.11979   	0.16927   	0.10061   	0.14536   	33.22498  
19438	0.23464   	0.32409   	0.13136   	0.19350   	0.11121   	0.16808   	39.38260  
0.08034461449241288 150922.26486145283 1817.8268788942737
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00222   
2    	0.68606   	0.67898   	0

8193 	0.13369   	0.18380   	0.07569   	0.10816   	0.05018   	0.07373   	16.49086  
16385	0.19141   	0.24913   	0.11505   	0.15442   	0.08159   	0.11300   	32.88689  
19438	0.20316   	0.26626   	0.12418   	0.17315   	0.08929   	0.13061   	38.91966  
0.21418871446768917 172784.6567558816 14694.472894070543
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00201   
2    	0.69525   	0.69735   	0.01562   	0.00000   	0.00000   	0.00000   	0.00658   
3    	0.68696   	0.67038   	0.02083   	0.03125   	0.00000   	0.00000   	0.01052   
5    	0.65365   	0.60368   	0.01250   	0.00000   	0.00000   	0.00000   	0.01612   
9    	0.58862   	0.50733   	0.01042   	0.00781   	0.00347   	0.00781   	0.02509   
17   	0.46671   	0.32957   	0.00551   	0.00000   	0.00184   	0.00000   	0.04173   
33   	0.31491   	0.15361   	0.00284   	0.00000   	0.00095   	0.00000   	0.07357   
65   	0.19585   	0.07307   	0.

2    	0.71885   	0.74455   	0.01562   	0.00000   	0.00000   	0.00000   	0.00648   
3    	0.69728   	0.65416   	0.02083   	0.03125   	0.01042   	0.03125   	0.01055   
5    	0.62652   	0.52038   	0.01250   	0.00000   	0.00625   	0.00000   	0.01629   
9    	0.48725   	0.31316   	0.00694   	0.00000   	0.00694   	0.00781   	0.02576   
17   	0.30453   	0.09898   	0.00368   	0.00000   	0.00368   	0.00000   	0.04315   
33   	0.17078   	0.02866   	0.00284   	0.00195   	0.00284   	0.00195   	0.08842   
65   	0.09913   	0.02524   	0.00337   	0.00391   	0.00337   	0.00391   	0.15198   
129  	0.06749   	0.03535   	0.00509   	0.00684   	0.00460   	0.00586   	0.27786   
257  	0.05400   	0.04040   	0.00681   	0.00854   	0.00571   	0.00684   	0.52285   
513  	0.04585   	0.03768   	0.00969   	0.01257   	0.00658   	0.00745   	1.07150   
1025 	0.05147   	0.05710   	0.01488   	0.02008   	0.01055   	0.01453   	2.01923   
2049 	0.07719   	0.10293   	0.03108   	0.04730   	0.02332   	0.03610   	4.03930   
4097

65   	0.08180   	0.01997   	0.00337   	0.00391   	0.00240   	0.00293   	0.12973   
129  	0.05518   	0.02815   	0.00630   	0.00928   	0.00339   	0.00439   	0.25128   
257  	0.04444   	0.03362   	0.00839   	0.01050   	0.00511   	0.00684   	0.48837   
513  	0.04464   	0.04483   	0.01096   	0.01355   	0.00719   	0.00928   	1.01718   
1025 	0.04418   	0.04372   	0.01460   	0.01825   	0.00915   	0.01111   	1.97455   
2049 	0.06291   	0.08167   	0.02518   	0.03577   	0.01710   	0.02505   	3.87169   
4097 	0.09139   	0.11988   	0.04679   	0.06842   	0.03199   	0.04689   	7.87590   
8193 	0.14607   	0.20076   	0.08130   	0.11581   	0.05864   	0.08530   	15.67226  
16385	0.21370   	0.28133   	0.12319   	0.16509   	0.09445   	0.13026   	31.07640  
19438	0.22792   	0.30424   	0.13289   	0.18496   	0.10315   	0.14988   	36.79760  
0.31027835282079846 174709.86807493577 10714.697651492404
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0

4097 	0.07571   	0.09450   	0.04428   	0.06415   	0.02545   	0.03732   	7.83901   
8193 	0.11674   	0.15779   	0.07650   	0.10873   	0.04856   	0.07166   	15.75163  
16385	0.16735   	0.21797   	0.11628   	0.15607   	0.07957   	0.11059   	31.48629  
19438	0.17875   	0.23993   	0.12574   	0.17648   	0.08709   	0.12747   	37.28205  
0.5374306847499736 227381.2798345313 11724.984555718698
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00208   
2    	0.70497   	0.71679   	0.01562   	0.00000   	0.00000   	0.00000   	0.00699   
3    	0.68590   	0.64775   	0.02083   	0.03125   	0.00000   	0.00000   	0.01119   
5    	0.62143   	0.52473   	0.01250   	0.00000   	0.00000   	0.00000   	0.01738   
9    	0.49686   	0.34115   	0.01042   	0.00781   	0.00694   	0.01562   	0.02705   
17   	0.31816   	0.11712   	0.00551   	0.00000   	0.00368   	0.00000   	0.04392   
33   	0.17644   	0.02586   	0.0

1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00206   
2    	0.69586   	0.69857   	0.01562   	0.00000   	0.00000   	0.00000   	0.00641   
3    	0.68701   	0.66932   	0.02083   	0.03125   	0.00000   	0.00000   	0.01037   
5    	0.64542   	0.58304   	0.01250   	0.00000   	0.00000   	0.00000   	0.01583   
9    	0.56756   	0.47023   	0.00694   	0.00000   	0.00000   	0.00000   	0.02511   
17   	0.43639   	0.28882   	0.00368   	0.00000   	0.00184   	0.00391   	0.04193   
33   	0.28635   	0.12694   	0.00284   	0.00195   	0.00284   	0.00391   	0.07432   
65   	0.17441   	0.05897   	0.00529   	0.00781   	0.00433   	0.00586   	0.13664   
129  	0.10395   	0.03238   	0.00388   	0.00244   	0.00388   	0.00342   	0.27045   
257  	0.06692   	0.02960   	0.00462   	0.00537   	0.00413   	0.00439   	0.51017   
513  	0.05209   	0.03720   	0.00786   	0.01111   	0.00542   	0.00671   	0.98699   
1025 	0.05447   	0.05686   	0.01402   	0.02020   	0.01043   	0.01544   	1.96493   
2049

33   	0.23134   	0.08207   	0.00379   	0.00000   	0.00379   	0.00195   	0.07572   
65   	0.13557   	0.03681   	0.00529   	0.00684   	0.00337   	0.00293   	0.13609   
129  	0.08422   	0.03206   	0.00581   	0.00635   	0.00388   	0.00439   	0.25410   
257  	0.05901   	0.03360   	0.00766   	0.00952   	0.00535   	0.00684   	0.57049   
513  	0.04311   	0.02715   	0.01109   	0.01453   	0.00591   	0.00647   	1.03488   
1025 	0.03672   	0.03033   	0.01598   	0.02087   	0.00726   	0.00861   	1.98221   
2049 	0.04170   	0.04669   	0.02375   	0.03152   	0.01098   	0.01471   	3.94496   
4097 	0.06028   	0.07886   	0.04399   	0.06424   	0.02145   	0.03192   	7.85955   
8193 	0.09968   	0.13908   	0.07881   	0.11365   	0.04310   	0.06475   	15.70819  
16385	0.14532   	0.19097   	0.12128   	0.16376   	0.07267   	0.10225   	31.27610  
19438	0.15628   	0.21508   	0.13183   	0.18845   	0.08047   	0.12231   	37.08315  
0.3263798620667253 243828.19499149604 14467.021533426461
n    	loss      	since last	ac

2049 	0.06327   	0.07845   	0.02545   	0.03610   	0.01527   	0.02209   	3.98460   
4097 	0.08785   	0.11246   	0.04344   	0.06143   	0.02733   	0.03940   	7.87588   
8193 	0.13496   	0.18208   	0.07768   	0.11192   	0.05259   	0.07786   	15.69234  
16385	0.18971   	0.24446   	0.11829   	0.15891   	0.08521   	0.11784   	31.48331  
19438	0.20338   	0.27675   	0.12775   	0.17852   	0.09330   	0.13670   	37.46718  
0.6303390780182125 10910.367946860879 5447.6284110035485
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00210   
2    	0.70255   	0.71196   	0.01562   	0.00000   	0.00000   	0.00000   	0.00537   
3    	0.67151   	0.60944   	0.02083   	0.03125   	0.00000   	0.00000   	0.00807   
5    	0.58595   	0.45761   	0.01250   	0.00000   	0.00000   	0.00000   	0.01254   
9    	0.44382   	0.26616   	0.00694   	0.00000   	0.00000   	0.00000   	0.02112   
17   	0.27288   	0.08058   	0.

1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00208   
2    	0.69472   	0.69629   	0.01562   	0.00000   	0.00000   	0.00000   	0.00703   
3    	0.68132   	0.65453   	0.02083   	0.03125   	0.00000   	0.00000   	0.01122   
5    	0.65420   	0.61350   	0.01250   	0.00000   	0.00000   	0.00000   	0.01739   
9    	0.60077   	0.53400   	0.00694   	0.00000   	0.00000   	0.00000   	0.02684   
17   	0.50133   	0.38946   	0.00735   	0.00781   	0.00368   	0.00781   	0.04395   
33   	0.36430   	0.21871   	0.00568   	0.00391   	0.00379   	0.00391   	0.07561   
65   	0.23288   	0.09735   	0.00433   	0.00293   	0.00337   	0.00293   	0.13783   
129  	0.14092   	0.04751   	0.00654   	0.00879   	0.00339   	0.00342   	0.31934   
257  	0.09215   	0.04300   	0.00766   	0.00879   	0.00596   	0.00854   	0.57122   
513  	0.06479   	0.03732   	0.00950   	0.01135   	0.00749   	0.00903   	1.06537   
1025 	0.05748   	0.05016   	0.01369   	0.01788   	0.01018   	0.01288   	2.01906   
2049

33   	0.27556   	0.12261   	0.00663   	0.00781   	0.00379   	0.00781   	0.07650   
65   	0.16294   	0.04679   	0.00529   	0.00391   	0.00337   	0.00293   	0.15035   
129  	0.10633   	0.04885   	0.00727   	0.00928   	0.00581   	0.00830   	0.27310   
257  	0.06899   	0.03135   	0.00803   	0.00879   	0.00559   	0.00537   	0.52015   
513  	0.05715   	0.04527   	0.01096   	0.01392   	0.00804   	0.01050   	1.00072   
1025 	0.05583   	0.05451   	0.01598   	0.02100   	0.01216   	0.01630   	2.14639   
2049 	0.07138   	0.08694   	0.02590   	0.03583   	0.02024   	0.02832   	4.07249   
4097 	0.11373   	0.15609   	0.04885   	0.07181   	0.03961   	0.05899   	8.08890   
8193 	0.18497   	0.25624   	0.09032   	0.13181   	0.07738   	0.11517   	16.29602  
16385	0.26100   	0.33703   	0.13960   	0.18888   	0.12349   	0.16961   	32.38738  
19438	0.27630   	0.35845   	0.15023   	0.20732   	0.13364   	0.18807   	38.32784  
0.7366630179577417 187236.3365946736 3649.4415866411914
n    	loss      	since last	acc

2049 	0.06053   	0.07189   	0.02440   	0.03375   	0.01603   	0.02234   	3.88252   
4097 	0.10134   	0.14217   	0.04773   	0.07106   	0.03392   	0.05182   	7.86343   
8193 	0.17186   	0.24240   	0.09106   	0.13441   	0.06999   	0.10608   	15.62069  
16385	0.23486   	0.29786   	0.13296   	0.17486   	0.10714   	0.14429   	31.74215  
19438	0.24725   	0.31375   	0.14199   	0.19048   	0.11527   	0.15889   	37.75331  
0.13197320863481526 19422.167321497698 2984.0854805163167
n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.00224   
2    	0.68305   	0.67294   	0.01562   	0.00000   	0.00000   	0.00000   	0.00520   
3    	0.67226   	0.65070   	0.02083   	0.03125   	0.00000   	0.00000   	0.00842   
5    	0.65017   	0.61703   	0.01250   	0.00000   	0.01250   	0.03125   	0.01319   
9    	0.60272   	0.54340   	0.01042   	0.00781   	0.00694   	0.00000   	0.02117   
17   	0.51704   	0.42065   	0

In [4]:
# Switch to the cached threshold-1000 dataset.
mydata = loadMyDataset(1000)

In [None]:
# Online-learning run on the threshold-1000 data.
learnOnline(mydata, initlr=4e-1, tzero=100000, rank=50, gamma=4000)

In [4]:
# Switch to the cached threshold-200 dataset.
mydata = loadMyDataset(200)

In [None]:
# Online-learning run on the threshold-200 data.
learnOnline(mydata, initlr=1e-1, tzero=300000, rank=50, gamma=16000)