# Dataset

In [1]:
import torch

class EasyAcc:
    """Running accumulator for the mean and spread of a stream of scalars.

    ``acc += x`` records ``x`` as one observation; ``acc -= x`` records ``-x``
    as one observation. Both increase the observation count by one.
    """

    def __init__(self):
        self.n = 0      # number of observations recorded so far
        self.sum = 0    # running sum of the observations
        self.sumsq = 0  # running sum of the squared observations

    def __iadd__(self, other):
        """Record ``other`` as one observation."""
        self.n += 1
        self.sum += other
        self.sumsq += other*other
        return self

    def __isub__(self, other):
        """Record ``-other`` as one observation (the count still goes up)."""
        self.n += 1
        self.sum -= other
        self.sumsq += other*other
        return self

    def mean(self):
        """Return the mean of the observations, or 0 if there are none."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Return the square root of the (population) variance.

        Despite its name this is a standard deviation; ``semean`` relies on
        that, so the return value is kept as is.
        """
        from math import sqrt
        variance = self.sumsq / max(self.n, 1) - self.mean()**2
        # E[x^2] - E[x]^2 can come out marginally negative through rounding
        # when the observations are (nearly) constant; clamp so sqrt does not
        # raise a math domain error.
        return sqrt(max(variance, 0))

    def semean(self):
        """Return the standard error of the mean, ``var() / sqrt(n)``."""
        from math import sqrt
        return self.var() / sqrt(max(self.n, 1))

def categoryCount():
    """Read per-entity frequencies from ``entityfreq.gz``.

    Each usable line holds a frequency and an entity name separated by
    whitespace. Lines that do not split into exactly two fields (blank lines,
    names containing whitespace, ...) are skipped. If an entity appears more
    than once, the last occurrence wins.

    Returns:
        dict mapping entity name (str) to its frequency (int).

    Raises:
        ValueError: if a two-field line has a non-integer frequency.
    """
    import gzip

    counts = {}

    with gzip.open('entityfreq.gz', 'rt') as f:
        for line in f:
            fields = line.split()
            # Explicit field-count check instead of a bare ``except:``, which
            # would also have swallowed KeyboardInterrupt and real errors.
            if len(fields) != 2:
                continue
            freq, entity = fields
            counts[entity] = int(freq)

    return counts

def getCategories(threshold):
    """Yield ``(entity, embedding)`` for each sufficiently frequent entity.

    An entity qualifies when its count from ``categoryCount()`` is at least
    ``threshold``. Its name, with underscores turned into spaces, is encoded
    with the 'bert-base-nli-mean-tokens' sentence-transformer model; the
    original (underscored) name is what gets yielded alongside the embedding.
    """
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer('bert-base-nli-mean-tokens')

    for entity, freq in categoryCount().items():
        if freq < threshold:
            continue
        spaced = entity.replace('_', ' ')
        yield entity, encoder.encode([spaced])[0]

def datasetStats(threshold):
    """Summarise the dataset that ``threshold`` would produce.

    Returns a dict with 'numclasses', the number of entities whose frequency
    is at least ``threshold``, and 'numexamples', which is
    ``threshold * numclasses``.
    """
    frequencies = categoryCount().values()
    kept = sum(1 for freq in frequencies if freq >= threshold)
    return { 'numclasses': kept, 'numexamples': threshold * kept }
            
def _encodeBatch(model, categories, batchline, batchencode, batchentity):
    """Encode one buffered batch and yield an example dict per buffered line."""
    embed = model.encode(batchencode)

    for n, (rawline, name) in enumerate(zip(batchline, batchentity)):
        # batchencode is laid out as [pre_0, post_0, pre_1, post_1, ...]
        entityord, entityvec = categories[name]
        yield { 'line': rawline,
                'entityord': entityord,
                'entityvec': entityvec,
                'pre': embed[2*n],
                'post': embed[2*n+1] }

def makeData(threshold, categories, batchsize=5):
    """Yield embedded training examples from ``shuffled_dedup_entities.tsv``.

    Each usable line has four tab-separated fields: entity, text before the
    mention, the mention, and text after it. Lines with any other number of
    fields are skipped. At most ``threshold`` examples are taken per entity,
    in file order, and only for entities present in ``categories``.

    Args:
        threshold: maximum number of examples to emit per entity.
        categories: dict mapping entity -> (ordinal, embedding).
        batchsize: number of examples encoded per call to the sentence
            encoder (default 5, the previously hard-coded value).

    Yields:
        dicts with keys 'line' (the raw input line), 'entityord',
        'entityvec', 'pre' and 'post' (embeddings of the surrounding text).
    """
    from collections import defaultdict
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('bert-base-nli-mean-tokens')
    catcount = defaultdict(int)
    unfilled = len(categories)  # categories still below ``threshold`` examples
    batchline, batchencode, batchentity = [], [], []

    with open('shuffled_dedup_entities.tsv') as f:
        for line in f:
            # NOTE(review): strip() also removes leading/trailing tabs, so a
            # line whose last field is empty ends up with fewer than four
            # fields and is skipped -- confirm that this is intended.
            fields = line.strip().split('\t')
            if len(fields) != 4:
                continue
            entity, pre, mention, post = fields

            if entity not in categories or catcount[entity] >= threshold:
                continue

            catcount[entity] += 1
            batchline.append(line)
            batchencode.extend((pre, post))
            batchentity.append(entity)

            if len(batchline) >= batchsize:
                yield from _encodeBatch(model, categories, batchline, batchencode, batchentity)
                batchline, batchencode, batchentity = [], [], []

            # An entity reaches the threshold exactly once; when every
            # category is full no later line can be accepted, so stop
            # scanning the rest of the file.
            if catcount[entity] >= threshold:
                unfilled -= 1
                if unfilled == 0:
                    break

    # Flush the final, possibly partial, batch.
    if batchline:
        yield from _encodeBatch(model, categories, batchline, batchencode, batchentity)

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of mention-context embeddings and entity labels.

    Construction runs ``getCategories`` and ``makeData`` for ``threshold`` and
    keeps everything in tensors.
    """

    def __init__(self, threshold):
        from tqdm.notebook import tqdm

        # entity -> (ordinal, embedding), ordinals assigned in iteration order
        self.labelfeats = {}
        for ordinal, (entity, vec) in enumerate(getCategories(threshold)):
            self.labelfeats[entity] = (ordinal, vec)

        features, labels = [], []
        for n, example in tqdm(enumerate(makeData(threshold, self.labelfeats))):
            # one feature row = 'pre' embedding followed by 'post' embedding
            context = torch.cat((torch.tensor(example['pre']),
                                 torch.tensor(example['post'])))
            features.append(context.unsqueeze(0))
            labels.append(example['entityord'])

        # Xs: one row per example; ys: the matching entity ordinals
        self.Xs = torch.cat(features, dim=0)
        self.ys = torch.LongTensor(labels)

    def __len__(self):
        """Number of examples (rows of ``Xs``)."""
        return self.Xs.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.Xs[index], self.ys[index]

In [2]:
# Class and example counts for the three frequency thresholds considered.
datasetStats(2000), datasetStats(1000), datasetStats(200)

({'numclasses': 311, 'numexamples': 622000},
 {'numclasses': 1154, 'numexamples': 1154000},
 {'numclasses': 14031, 'numexamples': 2806200})

## Build the dataset cache (slow: takes days, so run this once only)

In [3]:
def makeMyDataset(threshold):
    """Build ``MyDataset(threshold)`` and cache it on disk.

    The dataset is pickled, gzip-compressed, to
    ``mydataset.{threshold}.pickle.gz`` in the working directory.
    """
    import gzip
    import pickle

    dataset = MyDataset(threshold)
    with gzip.open(f'mydataset.{threshold}.pickle.gz', 'wb') as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
# Build and pickle the dataset for threshold 2000.
makeMyDataset(2000)

0it [00:00, ?it/s]

## Load cached processed data

In [2]:
def loadMyDataset(threshold):
    """Load the object pickled in ``mydataset.{threshold}.pickle.gz``.

    Returns whatever object the gzip-compressed pickle contains.
    """
    import gzip
    import pickle

    cachepath = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(cachepath, 'rb') as handle:
        # NOTE: unpickling can execute arbitrary code -- only load cache
        # files that you produced yourself.
        dataset = pickle.load(handle)
    return dataset

In [3]:
# best constant predictor
# if you don't beat this, you have a problem

def bestconstant(threshold):
    """Score the best constant predictor for the ``threshold`` dataset.

    Every entity with frequency >= ``threshold`` is assumed to contribute
    exactly ``threshold`` examples, so the constant prediction is each class's
    share of the examples.

    Returns:
        dict with 'best_constant_answer' (the class with the largest count,
        ties broken by the largest name), 'best_constant_average_logloss' and
        'best_constant_average_accuracy'.

    Raises:
        ValueError: if no entity reaches ``threshold``.
    """
    from math import fsum

    counts = { k: threshold for k, v in categoryCount().items() if v >= threshold }
    if not counts:
        raise ValueError(f'no category has frequency >= {threshold}')

    sumcounts = fsum(counts.values())
    # CrossEntropyLoss applies log-softmax to its input, so it must be given
    # log-probabilities (logits), not probabilities. Feeding raw probabilities
    # only gives the right loss when the distribution happens to be uniform.
    predict = torch.log(torch.Tensor([ v / sumcounts for v in counts.values() ])).unsqueeze(0)
    log_loss = torch.nn.CrossEntropyLoss()
    sumloss, denom = 0.0, 0

    for m, n in enumerate(counts.values()):
        actual = torch.LongTensor([m])
        # weight each class's loss by its number of examples
        sumloss += n * log_loss(predict, actual).item()
        denom += n

    return { 'best_constant_answer': max((v, k) for k, v in counts.items())[1],
             'best_constant_average_logloss': sumloss / denom,
             'best_constant_average_accuracy': max(counts.values()) / denom }

# Constant-predictor baselines for the three frequency thresholds.
bestconstant(2000), bestconstant(1000), bestconstant(200)

({'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 5.739792823791504,
  'best_constant_average_accuracy': 0.003215434083601286},
 {'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 7.050989627838135,
  'best_constant_average_accuracy': 0.0008665511265164644},
 {'best_constant_answer': 'weight_gain',
  'best_constant_average_logloss': 9.54902458190918,
  'best_constant_average_accuracy': 7.127075760815338e-05})

In [5]:
class Bilinear(torch.nn.Module):
    """Bilinear scorer: ``forward(Xs, Zs)`` returns ``Xs @ W @ Zs.T``.

    ``W`` is a learnable ``dobs x daction`` matrix initialised to zeros.
    ``naction`` is accepted by the constructor but not used.
    """

    def __init__(self, dobs, daction, naction):
        super().__init__()

        weight = torch.zeros(dobs, daction)
        self.W = torch.nn.Parameter(weight)

    def forward(self, Xs, Zs):
        """Score every row of ``Xs`` against every row of ``Zs``."""
        projected = Xs @ self.W
        return projected @ Zs.T

def learnOnline(dataset, seed=4545, initlr=2e-1, tzero=200, rank=None, batch_size=32, xscale=0.0001):
    """Train a ``Bilinear`` model in a single shuffled pass, printing progress.

    Progress rows (running loss and accuracy, overall and since the previous
    row) are printed after batches 1, 2, 3, 5, 9, 17, ... and once more at the
    end. Nothing is returned.

    Args:
        dataset: object yielding ``(features, label)`` pairs and exposing
            ``labelfeats``, a dict mapping entity -> (ordinal, embedding).
        seed: seed passed to ``torch.manual_seed``.
        initlr: initial Adam learning rate.
        tzero: learning rate at step t is ``initlr * sqrt(tzero / (tzero + t))``.
        rank: if given, label embeddings are replaced by their top-``rank``
            SVD factors ``U[:, :rank] @ diag(S[:rank])``.
        batch_size: minibatch size (default 32, the previously fixed value).
        xscale: factor applied to the input features before scoring
            (default 0.0001, the previously fixed value).
    """
    from math import sqrt

    def report(avloss, sincelast, acc, accsincelast):
        # one progress row; shared by the in-loop and the final report
        print('{:<5d}\t{:<8.5f}\t{:<8.5f}\t{:<8.5f}\t{:<8.5f}'.format(avloss.n, avloss.mean(), sincelast.mean(), acc.mean(), accsincelast.mean()), flush=True)

    torch.manual_seed(seed)

    # Re-key the label embeddings by ordinal and stack them in ordinal order,
    # so that row j of Zs belongs to label j.
    labelfeatsdict = { n: v for n, v in dataset.labelfeats.values() }
    labelfeats = [ torch.tensor(labelfeatsdict[n]).float().unsqueeze(0) for n in range(len(labelfeatsdict)) ]
    Zs = torch.cat(labelfeats, dim=0)

    if rank is not None:
        with torch.no_grad():
            U, S, _ = torch.linalg.svd(Zs, full_matrices=False)
            Zs = U[:, :rank] @ torch.diag(S[:rank])

    generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = None
    log_loss = torch.nn.CrossEntropyLoss()

    print('{:<5s}\t{:<8s}\t{:<8s}\t{:<8s}\t{:<8s}'.format('n', 'loss', 'since last', 'acc', 'acc since last'), flush=True)
    avloss, acc, sincelast, accsincelast = EasyAcc(), EasyAcc(), EasyAcc(), EasyAcc()

    for bno, (Xs, ys) in enumerate(generator):
        if model is None:
            # Built lazily because the feature width is read off the first batch.
            model = Bilinear(dobs=Xs.shape[1], daction=Zs.shape[1], naction=Zs.shape[0])
            opt = torch.optim.Adam([ p for p in model.parameters() if p.requires_grad ], lr=initlr)
            scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda = lambda t: sqrt(tzero) / sqrt(tzero + t))

        opt.zero_grad()
        # Call the module rather than .forward() so registered hooks still run.
        score = model(xscale * Xs, Zs)
        loss = log_loss(score, ys)
        loss.backward()
        opt.step()
        scheduler.step()

        with torch.no_grad():
            pred = torch.argmax(score, dim=1)
            # .item() gives Python floats, so the running sums are kept in
            # double precision instead of drifting in float32 tensors.
            batchacc = torch.mean((pred == ys).float()).item()
            batchloss = loss.item()

        acc += batchacc
        accsincelast += batchacc
        avloss += batchloss
        sincelast += batchloss

        # True for bno = 0 and every power of two: report at 1, 2, 3, 5, 9, ...
        if bno & (bno - 1) == 0:
            report(avloss, sincelast, acc, accsincelast)
            sincelast, accsincelast = EasyAcc(), EasyAcc()

    report(avloss, sincelast, acc, accsincelast)

In [4]:
# Load the cached threshold-2000 dataset.
mydata = loadMyDataset(2000)

In [4]:
# Train on the loaded data with label embeddings reduced to rank 50.
learnOnline(mydata, initlr=1.6, rank=50)

n    	loss    	since last	acc     	acc since last
1    	5.73979 	5.73979 	0.03125 	0.03125 
2    	5.74189 	5.74398 	0.04688 	0.06250 
3    	5.75276 	5.77451 	0.03125 	0.00000 
5    	5.68035 	5.57173 	0.01875 	0.00000 
9    	5.60537 	5.51164 	0.03125 	0.04688 
17   	5.47308 	5.32425 	0.03125 	0.03125 
33   	5.19774 	4.90520 	0.03883 	0.04688 
65   	4.87412 	4.54038 	0.06154 	0.08496 
129  	4.56322 	4.24746 	0.09278 	0.12451 
257  	4.28451 	4.00364 	0.12524 	0.15796 
513  	4.04570 	3.80595 	0.15722 	0.18933 
1025 	3.83418 	3.62225 	0.18427 	0.21136 
2049 	3.67436 	3.51437 	0.20446 	0.22467 
4097 	3.54882 	3.42322 	0.22160 	0.23875 
8193 	3.44915 	3.34948 	0.23554 	0.24947 
16385	3.36143 	3.27368 	0.24843 	0.26132 
19438	3.34360 	3.24788 	0.25124 	0.26634 


In [4]:
# Load the cached threshold-1000 dataset (rebinds mydata).
mydata = loadMyDataset(1000)

In [6]:
# Train on the loaded data with a smaller initial learning rate and slower decay, rank 50.
learnOnline(mydata, initlr=8e-1, tzero=400, rank=50)

n    	loss    	since last	acc     	acc since last
1    	7.05099 	7.05099 	0.00000 	0.00000 
2    	7.03190 	7.01281 	0.00000 	0.00000 
3    	7.04831 	7.08113 	0.00000 	0.00000 
5    	7.04976 	7.05195 	0.00000 	0.00000 
9    	6.96595 	6.86118 	0.00347 	0.00781 
17   	6.87694 	6.77679 	0.00735 	0.01172 
33   	6.70233 	6.51681 	0.01042 	0.01367 
65   	6.45273 	6.19534 	0.02067 	0.03125 
129  	6.11979 	5.78165 	0.03464 	0.04883 
257  	5.81036 	5.49852 	0.04815 	0.06177 
513  	5.51839 	5.22528 	0.06335 	0.07861 
1025 	5.28478 	5.05070 	0.07966 	0.09601 
2049 	5.10783 	4.93069 	0.09255 	0.10544 
4097 	4.96485 	4.82180 	0.10396 	0.11537 
8193 	4.85280 	4.74070 	0.11337 	0.12280 
16385	4.76203 	4.67124 	0.12150 	0.12963 
32769	4.69070 	4.61934 	0.12856 	0.13561 
36063	4.68256 	4.60159 	0.12936 	0.13735 
