# Dataset

In [5]:
import torch

class EasyAcc:
    """Running accumulator of count, sum and sum of squares.

    ``acc += x`` records the observation ``x``; ``acc -= x`` records ``-x``
    (its square equals ``x*x``, so ``sumsq`` grows by ``x*x`` in both cases).
    Works with plain numbers and with any value supporting ``+``, ``-`` and
    ``*`` (e.g. single-element tensors).
    """

    def __init__(self):
        self.n = 0      # number of recorded observations
        self.sum = 0    # running sum of observations
        self.sumsq = 0  # running sum of squared observations

    def __iadd__(self, other):
        self.n += 1
        self.sum += other
        self.sumsq += other * other
        return self

    def __isub__(self, other):
        # Records the observation -other; (-other)**2 == other*other.
        self.n += 1
        self.sum -= other
        self.sumsq += other * other
        return self

    def mean(self):
        """Mean of the recorded observations (0 when nothing was recorded)."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Population *standard deviation* of the observations.

        Despite the name this returns ``sqrt(E[x^2] - E[x]^2)``, i.e. the square
        root of the variance; 0 when nothing was recorded.
        """
        from math import sqrt
        variance = self.sumsq / max(self.n, 1) - self.mean() ** 2
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation when the observations are (nearly) identical; clamp so
        # sqrt() does not raise "math domain error".
        return sqrt(max(float(variance), 0.0))

    def semean(self):
        """Standard error of the mean: standard deviation / sqrt(n)."""
        from math import sqrt
        return self.var() / sqrt(max(self.n, 1))

def categoryCount(path='entityfreq.gz'):
    """Read the entity frequency table.

    Each line of the gzip-compressed text file at ``path`` is expected to be
    ``<freq> <entity>`` (whitespace separated).  Lines that do not split into
    exactly two fields, or whose frequency is not an integer, are skipped.

    Returns:
        dict mapping entity -> int frequency.  If an entity appears on more
        than one line, the last line wins.
    """
    import gzip

    counts = {}
    with gzip.open(path, 'rt') as f:
        for line in f:
            try:
                freq, entity = line.split()
                counts[entity] = int(freq)
            except ValueError:
                # Malformed line: wrong number of fields or non-integer freq.
                continue
    return counts

def getCategories(threshold, model_name='bert-base-nli-mean-tokens'):
    """Yield ``(entity, embedding)`` for every entity with frequency >= threshold.

    Frequencies come from categoryCount().  Each entity id is embedded with a
    SentenceTransformer after replacing underscores with spaces.

    Args:
        threshold: minimum frequency for an entity to be yielded.
        model_name: SentenceTransformer model used for the embeddings.

    Yields:
        (entity, embedding) pairs in categoryCount() iteration order.
    """
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name)

    for entity, freq in categoryCount().items():
        if freq >= threshold:
            # Entity ids use underscores between words (e.g. 'public_domain');
            # embed the space-separated phrase instead.
            niceentity = entity.replace('_', ' ')
            yield entity, model.encode([niceentity])[0]

def datasetStats(threshold):
    """Predicted size of the dataset built for ``threshold``.

    ``numclasses`` is the number of entities whose frequency reaches
    ``threshold``.  makeData keeps at most ``threshold`` examples per entity, so
    ``numexamples`` = ``threshold * numclasses``.  This is exact only if every
    such entity has at least ``threshold`` usable lines (TODO confirm).
    """
    eligible = sum(1 for freq in categoryCount().values() if freq >= threshold)
    return {'numclasses': eligible, 'numexamples': eligible * threshold}
            
def makeData(threshold, categories, path='shuffled_dedup_entities.tsv', batch_size=5, model=None):
    """Yield embedded training examples, keeping at most ``threshold`` per entity.

    Reads tab-separated lines ``entity<TAB>pre<TAB>mention<TAB>post`` from
    ``path``.  Lines that do not split into exactly four fields are skipped, and
    the mention field is not used.  A line is kept when its entity is a key of
    ``categories`` and that entity has fewer than ``threshold`` kept examples.
    Contexts are embedded ``batch_size`` examples at a time: one encode() call
    per batch, covering the pre and post context of each example.

    Args:
        threshold: maximum number of examples per entity.
        categories: dict entity -> (ordinal, embedding), as built by MyDataset.
        path: input TSV file.
        batch_size: number of examples per encode() call (must be >= 1).
        model: object with ``encode(list_of_str)`` returning an indexable
            sequence of embeddings.  Defaults to
            SentenceTransformer('bert-base-nli-mean-tokens').

    Yields:
        dict with 'line' (raw input line, newline included), 'entityord' and
        'entityvec' (from ``categories``), and 'pre'/'post' (context embeddings).
    """
    from collections import defaultdict

    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    if model is None:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer('bert-base-nli-mean-tokens')

    def _emit(batch):
        # Encode all contexts in one call; the order is [pre_0, post_0, pre_1, post_1, ...].
        texts = [text for _, _, pre, post in batch for text in (pre, post)]
        embed = model.encode(texts)
        for n, (rawline, entity, _, _) in enumerate(batch):
            entityord, entityvec = categories[entity]
            yield {'line': rawline,
                   'entityord': entityord,
                   'entityvec': entityvec,
                   'pre': embed[2 * n],
                   'post': embed[2 * n + 1]}

    catcount = defaultdict(int)
    batch = []  # pending (line, entity, pre, post) tuples awaiting encoding

    with open(path) as f:
        for line in f:
            try:
                entity, pre, _mention, post = line.strip().split('\t')
            except ValueError:
                # Not exactly four tab-separated fields.
                continue

            if entity in categories and catcount[entity] < threshold:
                catcount[entity] += 1
                batch.append((line, entity, pre, post))
                if len(batch) >= batch_size:
                    yield from _emit(batch)
                    batch = []

        # Flush the final partial batch.
        if batch:
            yield from _emit(batch)

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of (context features, entity ordinal) pairs.

    Attributes:
        labelfeats: dict entity -> (ordinal, embedding), built from getCategories.
        Xs: tensor with one row per example, holding the pre-context embedding
            concatenated with the post-context embedding.
        ys: LongTensor of entity ordinals, aligned with the rows of Xs.
    """

    def __init__(self, threshold, max_examples=None):
        """Embed every example produced by makeData(threshold, ...).

        Args:
            threshold: entity frequency threshold / per-entity example cap.
            max_examples: optional cap on the total number of examples, useful
                for quick debugging runs.  None means no cap.

        Raises:
            ValueError: if no examples were produced.
        """
        from tqdm.notebook import tqdm

        # Ordinals are assigned in getCategories() iteration order.
        self.labelfeats = {k: (n, v) for n, (k, v) in enumerate(getCategories(threshold))}

        rows, ys = [], []
        for n, what in tqdm(enumerate(makeData(threshold, self.labelfeats))):
            if max_examples is not None and n >= max_examples:
                break
            pre = torch.tensor(what['pre'])
            post = torch.tensor(what['post'])
            rows.append(torch.cat((pre, post)))
            ys.append(what['entityord'])

        if not rows:
            raise ValueError(f'makeData produced no examples for threshold={threshold}')

        self.Xs = torch.stack(rows, dim=0)
        self.ys = torch.LongTensor(ys)

    def __len__(self):
        return self.Xs.shape[0]

    def __getitem__(self, index):
        # (features, label) for one example.
        return self.Xs[index], self.ys[index]

In [2]:
# Number of classes and (capped) examples at several frequency thresholds.
datasetStats(2000), datasetStats(1000), datasetStats(200)

({'numclasses': 311, 'numexamples': 622000},
 {'numclasses': 1154, 'numexamples': 1154000},
 {'numclasses': 14031, 'numexamples': 2806200})

## Build and cache the dataset (slow: takes days, so run it only once)

In [3]:
def makeMyDataset(threshold):
    """Build MyDataset(threshold) and cache it as ``mydataset.<threshold>.pickle.gz``.

    The pickle is first written to a temporary file next to the target and then
    renamed over it.  An interrupted dump therefore never leaves a truncated
    cache file for loadMyDataset to trip over later.
    """
    import gzip
    import os
    import pickle

    dataset = MyDataset(threshold)

    path = f'mydataset.{threshold}.pickle.gz'
    tmppath = path + '.tmp'
    try:
        with gzip.open(tmppath, 'wb') as handle:
            pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
        # Atomic on POSIX and Windows when source and target share a filesystem.
        os.replace(tmppath, path)
    except BaseException:
        # Also clean up on KeyboardInterrupt, then re-raise.
        if os.path.exists(tmppath):
            os.remove(tmppath)
        raise
        
# Expensive one-off build of the threshold-2000 cache file.
makeMyDataset(2000)

0it [00:00, ?it/s]

## Load the cached, preprocessed dataset

In [6]:
def loadMyDataset(threshold):
    """Return the object pickled in ``mydataset.<threshold>.pickle.gz``.

    That file is written by makeMyDataset and holds a MyDataset instance, so
    the MyDataset class must be defined before calling this.  Unpickling can
    execute arbitrary code: only load cache files you produced yourself.
    """
    import gzip
    import pickle

    cache_path = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(cache_path, 'rb') as handle:
        dataset = pickle.load(handle)
    return dataset

In [3]:
# best constant predictor
# if you don't beat this, you have a problem

def bestconstant(threshold):
    """Average log loss and accuracy of the best constant (input-independent) predictor.

    Mirrors the dataset built by makeData: every entity whose frequency reaches
    ``threshold`` contributes exactly ``threshold`` examples.  The best constant
    predictor outputs the empirical class distribution p_k = n_k / N.  Its
    average log loss is the entropy -sum_k p_k log p_k, and its accuracy is
    max_k p_k.

    Returns:
        dict with 'best_constant_answer' (the predicted entity; ties go to the
        lexicographically largest entity), 'best_constant_average_logloss'
        (nats) and 'best_constant_average_accuracy'.

    Raises:
        ValueError: if threshold <= 0 or no entity reaches it.
    """
    from math import fsum, log

    counts = {k: threshold for k, v in categoryCount().items() if v >= threshold}
    if threshold <= 0 or not counts:
        raise ValueError(f'no usable categories for threshold={threshold}')

    denom = fsum(counts.values())
    # Weighted cross-entropy of predicting p_k for the examples of class k.
    # Computed directly from probabilities: torch's CrossEntropyLoss expects
    # logits, so feeding it p_k is only correct when all counts are equal.
    logloss = -fsum(n * log(n / denom) for n in counts.values()) / denom

    return {'best_constant_answer': max((v, k) for k, v in counts.items())[1],
            'best_constant_average_logloss': logloss,
            'best_constant_average_accuracy': max(counts.values()) / denom}

# Baseline log loss / accuracy for each threshold; learned models must beat these.
bestconstant(2000), bestconstant(1000), bestconstant(200)

({'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 5.739792823791504,
  'best_constant_average_accuracy': 0.003215434083601286},
 {'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 7.050989627838135,
  'best_constant_average_accuracy': 0.0008665511265164644},
 {'best_constant_answer': 'weight_gain',
  'best_constant_average_logloss': 9.54902458190918,
  'best_constant_average_accuracy': 7.127075760815338e-05})

In [14]:
class ResidualBlock(torch.nn.Module):
    """Residual layer computing ``X + 0.001 * LeakyReLU(X @ W)`` with a d x d weight W.

    W starts at zero, so a freshly constructed block is the identity map.
    """

    def __init__(self, d, device):
        super().__init__()
        # Zero init => the block initially passes its input through unchanged.
        self.W = torch.nn.Parameter(torch.zeros(d, d, device=device))
        # inplace is safe here: the activation overwrites a fresh matmul result.
        self.afunc = torch.nn.LeakyReLU(negative_slope=0.01, inplace=True)

    def forward(self, X):
        correction = self.afunc(X @ self.W)
        return X + 0.001 * correction
    
class BilinearResidual(torch.nn.Module):
    """Bilinear scorer ``block(Xs) @ W @ Zs.T`` where ``block`` is ``depth`` ResidualBlocks.

    For Xs of shape (batch, dobs) and Zs of shape (num_labels, daction), the
    result has shape (batch, num_labels).  W starts at zero, so every score is
    0 before training.
    """

    def __init__(self, dobs, daction, device, depth):
        super().__init__()
        layers = (ResidualBlock(dobs, device) for _ in range(depth))
        self.block = torch.nn.Sequential(*layers)
        self.W = torch.nn.Parameter(torch.zeros(dobs, daction, device=device))

    def forward(self, Xs, Zs):
        hidden = self.block(Xs)
        projected = hidden @ self.W
        return projected @ Zs.T

def learnOnline(dataset, initlr, tzero, rank, depth=1, cuda=False, seed=4545):
    """Train a BilinearResidual entity classifier with a single online pass.

    Each minibatch is scored *before* the optimizer step that trains on it, so
    the reported loss/accuracy are progressive-validation estimates.

    Args:
        dataset: torch Dataset yielding (features, label) pairs that also has a
            ``labelfeats`` dict mapping entity -> (ordinal, embedding), e.g. MyDataset.
        initlr: initial Adam learning rate.
        tzero: decay offset; after t scheduler steps the rate is
            initlr * sqrt(tzero) / sqrt(tzero + t).
        rank: if not None, label embeddings are replaced by U_r diag(S_r), their
            top-``rank`` SVD projection.
        depth: number of ResidualBlocks applied to the features.
        cuda: run on the default CUDA device.
        seed: seed for torch's global RNG (used by minibatch shuffling).

    Prints a progress row after batches 1, 2, 3, 5, 9, 17, ... (0-based batch
    index 0 or a power of two) and once more at the end.
    """
    import time
    from math import sqrt

    def _progress_row(avloss, sincelast, acc, accsincelast, elapsed):
        # Column order matches the header printed below.
        return '{:<5d}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}'.format(
            avloss.n, avloss.mean(), sincelast.mean(), acc.mean(), accsincelast.mean(), elapsed)

    torch.manual_seed(seed)

    # Label matrix Zs: row n is the embedding of the entity with ordinal n.
    labelfeatsdict = {n: v for n, v in dataset.labelfeats.values()}
    Zs = torch.cat([torch.tensor(labelfeatsdict[n]).float().unsqueeze(0)
                    for n in range(len(labelfeatsdict))], dim=0)

    if cuda:
        Zs = Zs.cuda()

    if rank is not None:
        with torch.no_grad():
            U, S, Vh = torch.linalg.svd(Zs, full_matrices=False)
            Zs = U[:, :rank] @ torch.diag(S[:rank])

    generator = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    model = None
    log_loss = torch.nn.CrossEntropyLoss()

    print('{:<5s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}'.format('n', 'loss', 'since last', 'acc', 'since last', 'dt (sec)'), flush=True)
    avloss, acc, sincelast, accsincelast = EasyAcc(), EasyAcc(), EasyAcc(), EasyAcc()
    # Fallback so the final row can be printed even for an empty dataset; it is
    # reset once the model exists, so dt excludes model construction.
    start = time.time()

    for bno, (Xs, ys) in enumerate(generator):
        Xs, ys = Xs.to(Zs.device), ys.to(Zs.device)

        if model is None:
            # The feature dimension is only known once the first batch arrives.
            model = BilinearResidual(dobs=Xs.shape[1], daction=Zs.shape[1], depth=depth, device=Zs.device)
            opt = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=initlr)
            scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda t: sqrt(tzero) / sqrt(tzero + t))
            start = time.time()

        opt.zero_grad()
        # NOTE(review): fixed 1e-4 input scaling — kept as in the original.
        score = model(0.0001 * Xs, Zs)
        loss = log_loss(score, ys)
        loss.backward()
        opt.step()
        scheduler.step()

        # Accumulate plain Python floats, not tensors tied to the device/graph.
        with torch.no_grad():
            batchacc = torch.mean((torch.argmax(score, dim=1) == ys).float()).item()
        batchloss = loss.item()
        acc += batchacc
        accsincelast += batchacc
        avloss += batchloss
        sincelast += batchloss

        if bno & (bno - 1) == 0:
            print(_progress_row(avloss, sincelast, acc, accsincelast, time.time() - start), flush=True)
            sincelast, accsincelast = EasyAcc(), EasyAcc()

    print(_progress_row(avloss, sincelast, acc, accsincelast, time.time() - start), flush=True)

In [8]:
# Load the cached threshold-2000 dataset built by makeMyDataset(2000).
mydata = loadMyDataset(2000)

In [15]:
# Depth sweep on the threshold-2000 data (same schedule and label rank): 1 residual block.
learnOnline(mydata, initlr=1.6, tzero=1000, rank=50, depth=1)

n    	loss      	since last	acc       	since last	dt (sec)  
1    	5.73979   	5.73979   	0.03125   	0.03125   	1.59583   
2    	5.74189   	5.74398   	0.04688   	0.06250   	1.61233   
3    	5.74583   	5.75372   	0.03125   	0.00000   	1.62748   
5    	5.63384   	5.46584   	0.01875   	0.00000   	1.65971   
9    	5.49162   	5.31384   	0.04167   	0.07031   	1.71931   
17   	5.37390   	5.24147   	0.03676   	0.03125   	1.83441   
33   	5.09394   	4.79647   	0.05019   	0.06445   	2.07867   
65   	4.78399   	4.46436   	0.06923   	0.08887   	2.52995   
129  	4.49957   	4.21070   	0.09859   	0.12842   	3.39347   
257  	4.25538   	4.00928   	0.12707   	0.15576   	5.09717   
513  	4.04204   	3.82787   	0.15515   	0.18335   	8.51150   
1025 	3.83618   	3.62993   	0.18399   	0.21289   	15.28262  
2049 	3.67509   	3.51382   	0.20579   	0.22760   	29.04722  
4097 	3.53127   	3.38738   	0.22584   	0.24590   	56.78693  
8193 	3.39453   	3.25775   	0.24602   	0.26620   	111.52929 
16385	3.25441   	3.11428

In [16]:
# Depth sweep: 2 residual blocks.
learnOnline(mydata, initlr=1.6, tzero=1000, rank=50, depth=2)

n    	loss      	since last	acc       	since last	dt (sec)  
1    	5.73979   	5.73979   	0.03125   	0.03125   	0.02975   
2    	5.74189   	5.74398   	0.04688   	0.06250   	0.06424   
3    	5.73980   	5.73563   	0.03125   	0.00000   	0.09255   
5    	5.63467   	5.47698   	0.01875   	0.00000   	0.15084   
9    	5.48668   	5.30170   	0.04167   	0.07031   	0.27073   
17   	5.35737   	5.21189   	0.03493   	0.02734   	0.50968   
33   	5.09352   	4.81317   	0.05019   	0.06641   	0.94030   
65   	4.80017   	4.49766   	0.06827   	0.08691   	1.80242   
129  	4.52182   	4.23912   	0.09593   	0.12402   	3.57374   
257  	4.27489   	4.02603   	0.12622   	0.15674   	7.00461   
513  	4.05326   	3.83076   	0.15363   	0.18115   	14.02894  
1025 	3.84507   	3.63647   	0.18229   	0.21100   	28.35510  
2049 	3.67987   	3.51450   	0.20480   	0.22733   	56.35309  
4097 	3.52884   	3.37773   	0.22626   	0.24774   	112.48561 
8193 	3.38435   	3.23981   	0.24758   	0.26890   	225.65657 
16385	3.23519   	3.08601

In [17]:
# Depth sweep: 3 residual blocks.
learnOnline(mydata, initlr=1.6, tzero=1000, rank=50, depth=3)

n    	loss      	since last	acc       	since last	dt (sec)  
1    	5.73979   	5.73979   	0.03125   	0.03125   	0.04361   
2    	5.74189   	5.74398   	0.04688   	0.06250   	0.08212   
3    	5.73457   	5.71995   	0.03125   	0.00000   	0.12103   
5    	5.64684   	5.51525   	0.02500   	0.01562   	0.19858   
9    	5.49352   	5.30186   	0.04861   	0.07812   	0.36066   
17   	5.36423   	5.21878   	0.03676   	0.02344   	0.69372   
33   	5.10200   	4.82338   	0.05019   	0.06445   	1.30164   
65   	4.81347   	4.51592   	0.07019   	0.09082   	2.50194   
129  	4.54179   	4.26587   	0.09254   	0.11523   	4.96825   
257  	4.30416   	4.06468   	0.12281   	0.15332   	9.97540   
513  	4.08976   	3.87452   	0.15046   	0.17822   	19.43409  
1025 	3.87696   	3.66374   	0.17930   	0.20819   	38.36506  
2049 	3.71016   	3.54320   	0.20147   	0.22366   	76.37812  
4097 	3.55729   	3.40435   	0.22384   	0.24622   	151.93329 
8193 	3.40762   	3.25791   	0.24648   	0.26913   	305.66595 
16385	3.25106   	3.09444

In [4]:
# Switch to the larger threshold-1000 dataset (more classes, fewer examples per class).
mydata = loadMyDataset(1000)

In [6]:
# The output below appears to come from an earlier learnOnline (its header has
# no dt column).  depth is passed explicitly so this cell runs against a
# learnOnline that requires it.
learnOnline(mydata, initlr=8e-1, tzero=400, rank=50, depth=1)

n    	loss    	since last	acc     	acc since last
1    	7.05099 	7.05099 	0.00000 	0.00000 
2    	7.03190 	7.01281 	0.00000 	0.00000 
3    	7.04831 	7.08113 	0.00000 	0.00000 
5    	7.04976 	7.05195 	0.00000 	0.00000 
9    	6.96595 	6.86118 	0.00347 	0.00781 
17   	6.87694 	6.77679 	0.00735 	0.01172 
33   	6.70233 	6.51681 	0.01042 	0.01367 
65   	6.45273 	6.19534 	0.02067 	0.03125 
129  	6.11979 	5.78165 	0.03464 	0.04883 
257  	5.81036 	5.49852 	0.04815 	0.06177 
513  	5.51839 	5.22528 	0.06335 	0.07861 
1025 	5.28478 	5.05070 	0.07966 	0.09601 
2049 	5.10783 	4.93069 	0.09255 	0.10544 
4097 	4.96485 	4.82180 	0.10396 	0.11537 
8193 	4.85280 	4.74070 	0.11337 	0.12280 
16385	4.76203 	4.67124 	0.12150 	0.12963 
32769	4.69070 	4.61934 	0.12856 	0.13561 
36063	4.68256 	4.60159 	0.12936 	0.13735 
