# Dataset

In [1]:
import torch

class EasyAcc:
    """Running accumulator of a count, a sum and a sum of squares.

    ``acc += x`` records the observation ``x``.  ``acc -= x`` records the
    observation ``-x``: the count still goes up by one and the squared term
    is added, only the sign of the contribution to ``sum`` flips.

    Observations may be plain numbers or zero-dimensional tensors.
    """

    def __init__(self):
        self.n = 0
        self.sum = 0
        self.sumsq = 0

    def __iadd__(self, other):
        self.n += 1
        self.sum += other
        self.sumsq += other*other
        return self

    def __isub__(self, other):
        # Equivalent to ``self += -other``; n intentionally increases.
        self.n += 1
        self.sum -= other
        self.sumsq += other*other
        return self

    def mean(self):
        """Return the mean of the recorded observations (0 when empty)."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Return the square root of the population variance.

        Despite the name this is a standard deviation; the name is kept
        because ``semean`` and external callers rely on it.
        """
        from math import sqrt
        spread = self.sumsq / max(self.n, 1) - self.mean()**2
        # For (near-)constant data floating-point rounding can push the
        # difference slightly below zero, which would make sqrt raise.
        return sqrt(max(spread, 0.0))

    def semean(self):
        """Return the standard error of the mean (``var() / sqrt(n)``)."""
        from math import sqrt
        return self.var() / sqrt(max(self.n, 1))

def categoryCount():
    """Read ``entityfreq.gz`` and return a dict mapping entity -> frequency.

    Each useful line holds exactly two whitespace-separated fields,
    ``<freq> <entity>``.  Lines with any other number of fields (blank
    lines, entities containing spaces, ...) are skipped.  A non-integer
    frequency raises ``ValueError``.
    """
    import gzip

    counts = {}

    with gzip.open('entityfreq.gz', 'rt') as f:
        for line in f:
            fields = line.split()
            if len(fields) != 2:
                # Explicit check instead of a bare ``except:`` so that
                # unrelated errors (e.g. KeyboardInterrupt) are not swallowed.
                continue
            freq, entity = fields
            counts[entity] = int(freq)

    return counts

def getCategories(threshold):
    """Lazily yield ``(entity, vector)`` for every sufficiently frequent entity.

    An entity qualifies when its count from ``categoryCount()`` is at least
    ``threshold``.  Underscores in the entity name are replaced by spaces
    before the name is passed to the sentence encoder; ``vector`` is the
    first (only) row returned by ``encode``.
    """
    from sentence_transformers import SentenceTransformer
    import re

    encoder = SentenceTransformer('bert-base-nli-mean-tokens')
    underscore = re.compile(r'_')

    for name, count in categoryCount().items():
        if count < threshold:
            continue
        readable = underscore.sub(r' ', name)
        encoded = encoder.encode([readable])
        yield name, encoded[0]

def datasetStats(threshold):
    """Summarise the dataset obtained at a given frequency threshold.

    Returns a dict with ``numclasses`` (entities whose frequency is at
    least ``threshold``) and ``numexamples`` (``threshold`` per class).
    """
    kept = sum(1 for freq in categoryCount().values() if freq >= threshold)
    return { 'numclasses': kept, 'numexamples': threshold * kept }
            
def makeData(threshold, categories, batch_size=5):
    """Yield encoded training examples read from ``shuffled_dedup_entities.tsv``.

    Parameters
    ----------
    threshold : int
        Maximum number of examples kept per entity.
    categories : mapping
        ``entity -> (entityord, entityvec)``; only lines whose entity is a
        key of this mapping are used.
    batch_size : int, optional
        Number of examples whose contexts are encoded together in one
        ``encode`` call (default 5, the previously hard-coded value).

    Yields
    ------
    dict
        ``line`` (the raw input line), ``entityord``, ``entityvec``, and the
        encoder outputs ``pre`` / ``post`` for the text before / after the
        mention.

    Lines that do not split into exactly four tab-separated fields are
    skipped.  NOTE(review): because the line is stripped before splitting,
    a line whose first or last field is empty is also skipped; confirm this
    is intended.
    """
    from collections import defaultdict
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('bert-base-nli-mean-tokens')
    catcount = defaultdict(int)

    def encode_pending(pending):
        # Encode all pre/post contexts of the pending examples in one call
        # and emit one record per example.
        contexts = []
        for _, _, pre, post in pending:
            contexts.append(pre)
            contexts.append(post)
        embed = model.encode(contexts)

        for n, (rawline, entity, _, _) in enumerate(pending):
            entityord, entityvec = categories[entity]
            yield { 'line': rawline,
                    'entityord': entityord,
                    'entityvec': entityvec,
                    'pre': embed[2*n],
                    'post': embed[2*n+1] }

    with open('shuffled_dedup_entities.tsv') as f:
        pending = []
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) != 4:
                continue
            entity, pre, mention, post = fields

            if entity in categories and catcount[entity] < threshold:
                catcount[entity] += 1
                pending.append((line, entity, pre, post))

                if len(pending) >= batch_size:
                    yield from encode_pending(pending)
                    pending = []

        # Flush the final, possibly short, batch.
        if pending:
            yield from encode_pending(pending)

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of (context features, entity ordinal) pairs.

    Attributes:
        labelfeats: dict mapping entity -> (ordinal, label vector), with
            ordinals assigned in the order ``getCategories`` yields.
        Xs: one row per example, the ``pre`` encoding followed by the
            ``post`` encoding.
        ys: LongTensor of entity ordinals, aligned with ``Xs``.
    """

    def __init__(self, threshold):
        from tqdm.notebook import tqdm

        self.labelfeats = {}
        for ordinal, (entity, vector) in enumerate(getCategories(threshold)):
            self.labelfeats[entity] = (ordinal, vector)

        rows, labels = [], []
        for example in tqdm(makeData(threshold, self.labelfeats)):
            before = torch.tensor(example['pre'])
            after = torch.tensor(example['post'])
            rows.append(torch.cat((before, after)).unsqueeze(0))
            labels.append(example['entityord'])

        self.Xs = torch.cat(rows, dim=0)
        self.ys = torch.LongTensor(labels)

    def __len__(self):
        return self.Xs.shape[0]

    def __getitem__(self, index):
        # Features and label of the requested example(s).
        features = self.Xs[index]
        label = self.ys[index]
        return features, label

In [2]:
# Class count and capped example count (threshold per class) at three thresholds.
datasetStats(2000), datasetStats(1000), datasetStats(100)

({'numclasses': 311, 'numexamples': 622000},
 {'numclasses': 1154, 'numexamples': 1154000},
 {'numclasses': 32089, 'numexamples': 3208900})

## Build and cache the dataset (slow: takes days, so run it once only)

In [3]:
def makeMyDataset(threshold):
    """Build ``MyDataset(threshold)`` and pickle it to a gzip file.

    The dataset is constructed first; the result is then written to
    ``mydataset.<threshold>.pickle.gz`` in the working directory.
    """
    import gzip
    import pickle

    dataset = MyDataset(threshold)
    with gzip.open(f'mydataset.{threshold}.pickle.gz', 'wb') as sink:
        pickle.dump(dataset, sink, protocol=pickle.HIGHEST_PROTOCOL)
        
# Build the threshold=2000 dataset and write it to mydataset.2000.pickle.gz.
makeMyDataset(2000)

0it [00:00, ?it/s]

## Load cached processed data

In [2]:
def loadMyDataset(threshold):
    """Return the object pickled in ``mydataset.<threshold>.pickle.gz``.

    Unpickling executes code embedded in the file, so only load files
    produced locally by ``makeMyDataset``.
    """
    import gzip
    import pickle

    path = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(path, 'rb') as source:
        loaded = pickle.load(source)
    return loaded

In [4]:
# best constant predictor
# if you don't beat this, you have a problem

def bestconstant(threshold):
    """Score the best constant predictor for the dataset at ``threshold``.

    Every entity whose frequency is at least ``threshold`` contributes
    exactly ``threshold`` examples.  The constant predictor outputs the
    empirical class distribution.

    Returns a dict with:
      * ``best_constant_answer``: a class with the largest count (ties are
        broken towards the lexicographically largest name),
      * ``best_constant_average_logloss``: average cross-entropy of the
        constant distribution,
      * ``best_constant_average_accuracy``: accuracy of always answering
        the most frequent class.

    Raises ``ValueError`` when no entity reaches ``threshold``.
    """
    from math import fsum, log

    counts = { k: threshold for k, v in categoryCount().items() if v >= threshold }
    if not counts:
        raise ValueError(f'no category has frequency >= {threshold}')

    sumcounts = fsum(counts.values())
    # Cross-entropy is defined on logits, so use log-probabilities.  Feeding
    # raw probabilities only gives the right answer when they are uniform.
    logits = torch.Tensor([ log(v / sumcounts) for v in counts.values() ]).unsqueeze(0)
    # -log_softmax(logits)[m] is the cross-entropy against target class m;
    # computing it once replaces one loss evaluation per class.
    perclass = (-torch.log_softmax(logits, dim=1)).squeeze(0).tolist()

    sumloss = fsum(n * loss for n, loss in zip(counts.values(), perclass))
    denom = sum(counts.values())

    return { 'best_constant_answer': max((v, k) for k, v in counts.items())[1],
             'best_constant_average_logloss': sumloss / denom,
             'best_constant_average_accuracy': max(counts.values()) / denom }

# Constant-predictor baselines at three frequency thresholds.
bestconstant(2000), bestconstant(1000), bestconstant(200)

({'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 5.739792823791504,
  'best_constant_average_accuracy': 0.003215434083601286},
 {'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 7.050989627838135,
  'best_constant_average_accuracy': 0.0008665511265164644},
 {'best_constant_answer': 'weight_gain',
  'best_constant_average_logloss': 9.54902458190918,
  'best_constant_average_accuracy': 7.127075760815338e-05})

In [77]:
class Bilinear(torch.nn.Module):
    """Bilinear scorer: ``logit[n, k] = Xs[n] @ W @ Zs[k] + b``.

    ``W`` has shape (dobs, daction) and ``b`` is a single shared offset;
    both start at zero.  ``naction`` is accepted but not used.
    """

    def __init__(self, dobs, daction, naction, device):
        super().__init__()

        weight = torch.zeros(dobs, daction, device=device)
        offset = torch.zeros(1, device=device)
        self.W = torch.nn.Parameter(weight)
        self.b = torch.nn.Parameter(offset)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, Xs, Zs):
        """Return the logits of every row of ``Xs`` against every row of ``Zs``."""
        projected = Xs @ self.W
        return projected @ Zs.T + self.b

    def preq1(self, logits):
        """Map logits through the sigmoid."""
        return self.sigmoid(logits)
    
class RankOneDetset(object):
    """Batched D x D matrices ``S`` with explicitly maintained inverses.

    One matrix is kept per batch element (``N`` of them).  Columns of ``S``
    are replaced one at a time by (scaled) rows of ``rawactions``; the
    inverse ``Sinv`` is updated with the rank-one formula quoted in
    ``updateCoord`` instead of being recomputed.  After each replacement
    ``S`` and ``Sinv`` are rescaled and the accumulated log scale is stored
    in ``logdetfac``.
    """
    def __init__(self, batch_size, rawactions):
        self.N = batch_size
        self.rawactions = rawactions
        # rawactions must be 2-D: K candidate rows of dimension D.
        self.K, self.D = rawactions.shape
        self.device = rawactions.device
        
        # Identity broadcast over the batch (an expanded view, not N copies).
        self.batcheye = torch.eye(self.D, device=self.device).unsqueeze(0).expand(self.N, -1, -1)
        # S and its inverse both start as the identity; clone() materialises
        # independent, writable storage per batch element.
        self.S = self.batcheye.clone()
        self.Sinv = self.batcheye.clone()
        # Per-batch accumulated log scale factor, shape (N, 1).
        self.logdetfac = torch.zeros(self.N, 1, device=self.device)
        
    def computePhi(self, i): 
        """Return row ``i`` of every ``Sinv`` (shape (N, D)) and ``logdetfac``.

        Both returned tensors share storage with the internal state, so they
        change when ``updateCoord`` / ``updateAll`` is called afterwards.
        """
        # Sprime_a <- replace column i of S with action a where det(S)=1
        # Sprime_a = S + (a - S_i) e_i^\top = S + u v^\top
        # det(Sprime_a) = det(S) (1 + e_i^\top S^{-1} (a - S_i))
        #               = (1 - (S^{-T} e_i)^\top S_i) + (S^{-T} e_i)^\top a
        #               = 0 + \phi^\top a
        
        #Sinvtopei = torch.linalg.solve(torch.transpose(self.S, 1, 2), self.batcheye[:,:,i])
        Sinvtopei = self.Sinv[:, i, :]
        return Sinvtopei, self.logdetfac
    
    def computeAllPhi(self):
        """Return the full ``Sinv`` (row ``i`` is the phi of column ``i``) and ``logdetfac``.

        The returned tensors are the internal state itself, not copies.
        """
        return self.Sinv, self.logdetfac
    
    def updateAll(self, colstar, fstar, astar, denom):
        """Replace, per batch element, column ``colstar`` of ``S`` by action ``astar``.

        Same update as ``updateCoord`` but with a different column for each
        batch element.  The indexing below assumes ``colstar``, ``fstar`` and
        ``astar`` have shape (N, 1) and ``denom`` has shape (N, K).
        """
        # Y[n] = rawactions[astar[n]] -- gather returns a new tensor, so the
        # in-place divisions below do not modify rawactions.
        Y = torch.gather(input=self.rawactions.unsqueeze(0).expand(self.N, -1, -1), 
                         dim=1, 
                         index=astar.reshape(self.N, 1, 1).expand(self.N, 1, self.D)
                        ).squeeze(1)
        # Scale the chosen action by its own denominator and by the
        # accumulated scale factor.
        Ydenom = torch.gather(input=denom, dim=1, index=astar)
        Y /= Ydenom
        Y /= torch.exp(self.logdetfac).reshape(self.N, 1)
        
        # u = Y - (current column colstar of S); Sinv is updated with the
        # rank-one formula before S itself is overwritten.
        u = Y - torch.gather(input=self.S, dim=2, index=colstar.unsqueeze(1).expand(self.N, self.D, 1)).squeeze(2)
        Sinvu = torch.bmm(self.Sinv, u.unsqueeze(2)).squeeze(2)
        # v = e_colstar, so v^T Sinv is row colstar of Sinv and v^T Sinv u is
        # entry colstar of Sinv u.
        vtopSinv = torch.gather(input=self.Sinv, dim=1, index=colstar.unsqueeze(1).expand(self.N, 1, self.D)).squeeze(1)
        vtopSinvu = torch.gather(input=Sinvu, dim=1, index=colstar).unsqueeze(2)
        self.Sinv -= (1 / (1 + vtopSinvu)) * torch.bmm(Sinvu.unsqueeze(2), vtopSinv.unsqueeze(1))

        # Write Y into column colstar of S.
        self.S.scatter_(index=colstar.unsqueeze(1).expand(self.N, self.D, 1), 
                        dim=2, 
                        src=Y.unsqueeze(2))
        # Rescale S (and Sinv inversely) and fold the scale into logdetfac.
        # NOTE(review): presumably this restores the det(S)=1 normalisation
        # mentioned in computePhi -- confirm.
        thislogdet = 1/self.D * (torch.log(fstar) - self.logdetfac)
        scale = torch.exp(thislogdet).reshape(self.N, 1, 1)
        self.S /= scale
        self.Sinv *= scale
        self.logdetfac += thislogdet
    
    def updateCoord(self, i, fstar, astar, denom):
        """Replace column ``i`` of every ``S`` by the (scaled) action ``astar``.

        ``i`` is a single column index shared by the whole batch.  The
        indexing below assumes ``fstar`` and ``astar`` have shape (N, 1) and
        ``denom`` has shape (N, K).
        """
        # Y[n] = rawactions[astar[n]], then scaled as in updateAll.
        Y = torch.gather(input=self.rawactions.unsqueeze(0).expand(self.N, -1, -1), 
                         dim=1, 
                         index=astar.reshape(self.N, 1, 1).expand(self.N, 1, self.D)
                        ).squeeze(1)
        Ydenom = torch.gather(input=denom, dim=1, index=astar)
        Y /= Ydenom
        Y /= torch.exp(self.logdetfac).reshape(self.N, 1)

        # replace column i of S with y
        # -----------------------------
        # Sprime = S + (y - S_i) e_i^\top = S + u v^\top
        # Sprime^{-1} = S^{-1} - 1/(1 + v^\top S^{-1} u) (S^{-1} u) (v^\top S^{-1})^\top
        
        u = Y - self.S[:, :, i]
        Sinvu = torch.bmm(self.Sinv, u.unsqueeze(2)).squeeze(2)
        # v = e_i: v^T Sinv is row i of Sinv, v^T Sinv u is entry i of Sinv u.
        vtopSinv = self.Sinv[:, i, :]
        vtopSinvu = Sinvu[:, i].unsqueeze(1).unsqueeze(2)
        self.Sinv -= (1 / (1 + vtopSinvu)) * torch.bmm(Sinvu.unsqueeze(2), vtopSinv.unsqueeze(1))
        
        self.S[:,:,i] = Y
        # Rescale S (and Sinv inversely) and fold the scale into logdetfac.
        thislogdet = 1/self.D * (torch.log(fstar) - self.logdetfac)
        scale = torch.exp(thislogdet).reshape(self.N, 1, 1)
        self.S /= scale
        self.Sinv *= scale
        self.logdetfac += thislogdet

class SpannerIGW(torch.nn.Module):
    """Action sampler that explores over a per-context set of D actions.

    For each context a "design" of D action indices is selected (one per
    column of a D x D matrix, see ``_make_spanner``).  ``sample`` then
    either plays the greedy action or one uniformly chosen design action.
    """

    def __init__(self, actions, iota):
        super().__init__()

        self.rawactions = actions
        self.iota = iota

    def _make_spanner(self, denom):
        """Return an (N, D) long tensor of action indices, one row per context.

        ``denom`` is the per-context, per-action scaling from
        ``_compute_denom``.
        """
        from math import log

        # Algorithm 4 Approximate Barycentric Identification (Awerbuch and Kleinberg, 2008)
        C = 2

        batch, _ = denom.shape
        _, dim = self.rawactions.shape
        actions = self.rawactions
        detset = RankOneDetset(batch, actions)
        design = torch.zeros(batch, dim, device=actions.device).long()

        # First pass: fill every column with the action that maximises the
        # scaled |phi . action| for that column.
        for col in range(dim):
            phi, _ = detset.computePhi(col)
            scores = (actions @ phi.T).abs().T / denom
            best, chosen = scores.max(dim=1, keepdim=True)
            design[:, col] = chosen.squeeze(1)
            detset.updateCoord(col, best, chosen, denom)

        # Second pass: keep swapping in the best (column, action) pair while
        # at least one context still improves by the factor C.
        sweeps = int(dim * log(dim))
        for _ in range(sweeps):
            allphi, logdetfac = detset.computeAllPhi()
            stacked = actions.T.unsqueeze(0).expand(batch, -1, -1)
            scores = torch.bmm(allphi, stacked).abs() / denom.unsqueeze(1)
            best_per_col, action_per_col = scores.max(dim=2)
            best, colstar = best_per_col.max(dim=1, keepdim=True)
            chosen = action_per_col.gather(1, colstar)

            improved = best >= C * logdetfac.exp()
            if not torch.any(improved):
                break

            # The swap is applied to every context in the batch.
            design.scatter_(index=colstar, dim=1, src=chosen)
            detset.updateAll(colstar, best, chosen, denom)

        return design

    def _compute_denom(self, fhat, fhatstar):
        """Return ``sqrt(1 + D + iota * (fhatstar - fhat))`` elementwise."""
        dim = self.rawactions.shape[1]
        gap = fhatstar - fhat
        return torch.sqrt(1 + dim + self.iota * gap)

    def sample(self, fhat):
        """Draw one action index per row of ``fhat``; returns shape (N,).

        A design slot is drawn uniformly; its action is played with
        probability ``D / (1 + D + iota * gap)``, where ``gap`` is the
        prediction shortfall of that action relative to the greedy one.
        Otherwise the greedy action is played.
        """
        batch = fhat.shape[0]
        dim = self.rawactions.shape[1]
        fhatstar, greedy = fhat.max(dim=1, keepdim=True)
        design = self._make_spanner(self._compute_denom(fhat, fhatstar))

        # Random draws happen in this order: slot choice first, then the coin.
        slot = torch.randint(high=dim, size=(batch, 1), device=fhat.device)
        candidate = design.gather(1, slot)
        fhatcandidate = fhat.gather(1, candidate)
        accept = dim / (1 + dim + self.iota * (fhatstar - fhatcandidate))
        coin = torch.rand(size=(batch, 1), device=fhat.device)
        take = (coin <= accept).long()
        return (greedy + take * (candidate - greedy)).squeeze(1)

def learnOnline(dataset, seed=4545, rank=None, initlr=4e-1, tzero=100000, iota=1000, batch_size=32, cuda=False, extra=0):
    """Run one pass of online bandit learning over ``dataset`` and print progress.

    For every batch the bilinear model scores all actions, ``SpannerIGW``
    samples one action per example, the reward is 1 when the sampled action
    equals the label, and the model is trained with binary cross-entropy on
    the logit of the sampled action only.

    Parameters
    ----------
    dataset : object with ``labelfeats`` (entity -> (ordinal, vector)) that
        a ``DataLoader`` can iterate as ``(Xs, ys)`` batches.
    seed : int, torch RNG seed.
    rank : int or None; when given, the action features are replaced by
        their rank-``rank`` SVD projection ``U[:, :rank] @ diag(S[:rank])``.
    initlr : float, initial Adam learning rate.
    tzero : positive number; the learning rate is scaled by
        ``sqrt(tzero) / sqrt(tzero + t)`` at scheduler step ``t``.
    iota : exploration parameter passed to ``SpannerIGW``.
    batch_size : int, loader batch size.
    cuda : bool, put the action features (and hence everything) on the GPU.
    extra : int, number of duplicate copies of the last action appended to
        the action set; duplicates are mapped back to the last original
        action index.

    Progress rows are printed whenever the zero-based batch number is 0 or
    a power of two, and once more at the end.  Returns ``None``.
    """
    import time
    from math import sqrt

    # Contexts are multiplied by this constant before entering the model.
    xscale = 0.0001

    torch.manual_seed(seed)
    labelfeatsdict = { n: v for n, v in dataset.labelfeats.values() }
    labelfeats = [ torch.tensor(labelfeatsdict[n]).float().unsqueeze(0) for n in range(len(labelfeatsdict)) ]
    Zs = torch.cat(labelfeats, dim=0)

    if cuda:
        Zs = Zs.cuda()

    if rank is not None:
        with torch.no_grad():
            U, S, Vh = torch.linalg.svd(Zs, full_matrices=False)
            Zs = U[:, :rank] @ torch.diag(S[:rank])

    original = Zs.shape[0]
    if extra > 0:
        moarrows = Zs[-1,:].expand(extra, -1)
        Zs = torch.cat((Zs, moarrows), dim=0)

    generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    log_loss = torch.nn.BCEWithLogitsLoss()

    print('{:<5s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}'.format(
            'n', 'loss', 'since last', 'acc', 'since last', 'reward', 'since last', 'dt (sec)'),
          flush=True)
    avloss, sincelast, acc, accsincelast, avreward, rewardsincelast = [ EasyAcc() for _ in range(6) ]

    def report():
        # One progress row; reads the current accumulators and start time.
        print('{:<5d}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}'.format(
                avloss.n, avloss.mean(), sincelast.mean(), acc.mean(),
                accsincelast.mean(), avreward.mean(), rewardsincelast.mean(),
                time.time() - start),
              flush=True)

    model = None
    # Defined up front so the final report also works when the loader yields
    # no batch; reset below once the model exists, as before.
    start = time.time()

    for bno, (Xs, ys) in enumerate(generator):
        Xs, ys = Xs.to(Zs.device), ys.to(Zs.device)

        if model is None:
            # Built lazily because the context dimension comes from the data.
            model = Bilinear(dobs=Xs.shape[1], daction=Zs.shape[1], naction=Zs.shape[0], device=Zs.device)
            sampler = SpannerIGW(actions=Zs, iota=iota)
            opt = torch.optim.Adam(( p for p in model.parameters() if p.requires_grad ), lr=initlr)
            scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda = lambda t: sqrt(tzero) / sqrt(tzero + t))
            start = time.time()

        opt.zero_grad()
        # Single forward pass: the same logits drive sampling (without
        # gradient tracking) and the loss.
        logit = model.forward(xscale * Xs, Zs)

        with torch.no_grad():
            fhat = model.preq1(logit)
            presample = sampler.sample(fhat)
            # Indices >= original are the appended duplicates of the last action.
            sample = torch.clamp(presample, max=original - 1)
            reward = (sample == ys).unsqueeze(1).float()

        samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)
        loss = log_loss(samplelogit, reward)
        loss.backward()
        opt.step()
        scheduler.step()

        with torch.no_grad():
            # Accuracy of the greedy prediction made before this update.
            prepred = torch.argmax(logit, dim=1)
            pred = torch.clamp(prepred, max=original - 1)
            batchacc = torch.mean((pred == ys).float())
            batchreward = torch.mean(reward)
            acc += batchacc
            accsincelast += batchacc
            avloss += loss
            sincelast += loss
            avreward += batchreward
            rewardsincelast += batchreward

        # True for bno == 0 and for powers of two.
        if bno & (bno - 1) == 0:
            report()
            sincelast, accsincelast, rewardsincelast = [ EasyAcc() for _ in range(3) ]

    report()

In [3]:
# Load the cached threshold=2000 dataset (reads mydataset.2000.pickle.gz).
mydata = loadMyDataset(2000)

In [78]:
# CPU run with the action features reduced to their rank-50 SVD projection.
learnOnline(mydata, cuda=False, initlr=0.33, tzero=100000, rank=50, iota=14000*50/311)

n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.03622   
2    	0.61319   	0.53324   	0.01562   	0.00000   	0.00000   	0.00000   	0.07019   
3    	0.53545   	0.37998   	0.01042   	0.00000   	0.00000   	0.00000   	0.10568   
5    	0.41075   	0.22369   	0.00625   	0.00000   	0.00000   	0.00000   	0.17272   
9    	0.26150   	0.07493   	0.00347   	0.00000   	0.00000   	0.00000   	0.30608   
17   	0.14523   	0.01443   	0.00184   	0.00000   	0.00000   	0.00000   	0.58442   
33   	0.09242   	0.03630   	0.00473   	0.00781   	0.00284   	0.00586   	1.23097   
65   	0.06370   	0.03408   	0.00577   	0.00684   	0.00385   	0.00488   	2.53459   
129  	0.05744   	0.05109   	0.00872   	0.01172   	0.00630   	0.00879   	4.78488   
257  	0.04618   	0.03483   	0.01021   	0.01172   	0.00657   	0.00684   	9.22068   
513  	0.04758   	0.04899   	0.01377   	0.01733   	0.00926   	0.01196   	18.48969  
1025

In [79]:
# Statistically equivalent run with duplicated actions: extra=1024 appends
# 1024 copies of the last action row, and learnOnline maps any sampled or
# predicted copy back to the last original action index.
learnOnline(mydata, cuda=False, initlr=0.33, tzero=100000, rank=50, iota=14000*50/311, extra=1024)

n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	0.04451   
2    	0.61319   	0.53324   	0.01562   	0.00000   	0.00000   	0.00000   	0.10104   
3    	0.53545   	0.37998   	0.01042   	0.00000   	0.00000   	0.00000   	0.14878   
5    	0.41075   	0.22369   	0.00625   	0.00000   	0.00000   	0.00000   	0.23764   
9    	0.26150   	0.07493   	0.00347   	0.00000   	0.00000   	0.00000   	0.42172   
17   	0.14523   	0.01443   	0.00184   	0.00000   	0.00000   	0.00000   	0.77712   
33   	0.09242   	0.03630   	0.00473   	0.00781   	0.00284   	0.00586   	1.42541   
65   	0.06370   	0.03408   	0.00577   	0.00684   	0.00385   	0.00488   	2.77199   
129  	0.05744   	0.05109   	0.00872   	0.01172   	0.00630   	0.00879   	5.48490   
257  	0.04618   	0.03483   	0.01021   	0.01172   	0.00657   	0.00684   	11.57694  
513  	0.04758   	0.04899   	0.01377   	0.01733   	0.00926   	0.01196   	23.83897  
1025

In [4]:
# Load the cached threshold=1000 dataset (reads mydataset.1000.pickle.gz).
mydata = loadMyDataset(1000)

In [None]:
# Same experiment on the threshold=1000 dataset, run on the GPU (cuda=True).
learnOnline(mydata, initlr=4e-1, tzero=100000, rank=50, iota=666, batch_size=32, cuda=True)