# Dataset

In [1]:
import torch

class EasyAcc:
    """Running accumulator of count, sum and sum of squares.

    ``acc += x`` adds ``x`` to the running sum and ``acc -= x`` subtracts it;
    both count one observation and add ``x*x`` to ``sumsq``.  Works with
    anything supporting ``+``, ``-`` and ``*`` (e.g. floats or 0-dim tensors).
    """

    def __init__(self):
        self.n = 0       # number of observations
        self.sum = 0     # signed running sum
        self.sumsq = 0   # running sum of squares

    def __iadd__(self, other):
        self.n += 1
        self.sum += other
        self.sumsq += other*other
        return self

    def __isub__(self, other):
        # Records the observation -other; its square equals other*other.
        self.n += 1
        self.sum -= other
        self.sumsq += other*other
        return self

    def mean(self):
        """Mean of the observations (0 when nothing has been added)."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Population standard deviation of the observations.

        Despite the name, this returns the square root of the variance.
        E[x^2] - E[x]^2 can come out slightly negative through floating-point
        cancellation when all observations are (nearly) equal, so the
        variance is clamped at zero before taking the square root.
        Without the clamp, `math.sqrt` would raise a domain error.
        """
        from math import sqrt
        variance = self.sumsq / max(self.n, 1) - self.mean()**2
        return sqrt(max(variance, 0.0))

    def semean(self):
        """Standard error of the mean: standard deviation / sqrt(n)."""
        from math import sqrt
        return self.var() / sqrt(max(self.n, 1))

def categoryCount(path='entityfreq.gz'):
    """Read entity frequencies from a gzipped text file.

    Each well-formed line holds two whitespace-separated fields,
    ``<freq> <entity>``.  Lines that do not split into exactly two fields
    (blank or malformed lines) are skipped.

    Args:
        path: gzipped frequency file (defaults to the notebook's file).

    Returns:
        dict mapping entity -> int frequency.  A later line for the same
        entity overwrites an earlier one.

    Raises:
        ValueError: if a two-field line has a non-integer frequency.
    """
    import gzip

    counts = {}
    with gzip.open(path, 'rt') as f:
        for line in f:
            # split() with no argument already ignores surrounding whitespace.
            fields = line.split()
            if len(fields) != 2:
                continue  # malformed line: skip it rather than abort the scan
            freq, entity = fields
            counts[entity] = int(freq)

    return counts

def getCategories(threshold):
    """Yield ``(entity, embedding)`` for every entity seen at least `threshold` times.

    The embedding is the sentence-transformer encoding of the entity name
    with underscores replaced by spaces.  Entities are produced in the
    iteration order of `categoryCount()`.
    """
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer('bert-base-nli-mean-tokens')

    for name, count in categoryCount().items():
        if count < threshold:
            continue
        readable = name.replace('_', ' ')
        yield name, encoder.encode([readable])[0]

def datasetStats(threshold):
    """Summarise the dataset that a given `threshold` would produce.

    Returns the number of entities with frequency >= threshold.  It also
    returns the nominal example count, assuming every such class contributes
    exactly `threshold` examples (the per-class cap used by `makeData`).
    """
    kept = sum(1 for count in categoryCount().values() if count >= threshold)
    return {'numclasses': kept, 'numexamples': threshold * kept}
            
def makeData(threshold, categories, batchsize=5, path='shuffled_dedup_entities.tsv'):
    """Stream embedded training examples from the entity-mention TSV file.

    A usable line has four tab-separated fields ``entity, pre, mention, post``.
    A line is kept when `entity` is a key of `categories` and fewer than
    `threshold` lines for that entity have been kept so far.  Kept lines are
    embedded in batches: their ``pre`` and ``post`` contexts are encoded with
    the sentence transformer, and one dict is yielded per line.

    Args:
        threshold: maximum number of examples kept per entity.
        categories: mapping entity -> (ordinal, label embedding), e.g.
            ``MyDataset.labelfeats``.
        batchsize: number of lines encoded per model call (default 5).
        path: TSV file to read.

    Yields:
        dict with keys 'line' (the raw input line), 'entityord', 'entityvec',
        'pre' and 'post' (context embeddings).
    """
    from collections import defaultdict
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('bert-base-nli-mean-tokens')
    catcount = defaultdict(int)

    def _embed_batch(lines, contexts, entities):
        # Encode a whole batch in one call.  `contexts` interleaves the two
        # contexts as [pre_0, post_0, pre_1, post_1, ...], so line n owns
        # rows 2n and 2n+1 of the result.
        embed = model.encode(contexts)
        for n, (rawline, entity) in enumerate(zip(lines, entities)):
            entityord, entityvec = categories[entity]
            yield {'line': rawline,
                   'entityord': entityord,
                   'entityvec': entityvec,
                   'pre': embed[2*n],
                   'post': embed[2*n+1]}

    batchline, batchencode, batchentity = [], [], []
    with open(path) as f:
        for line in f:
            # Strip only the line terminator.  A blanket strip() would also
            # remove a trailing tab and silently drop lines whose `post`
            # context is empty.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) != 4:
                continue  # malformed line
            entity, pre, _mention, post = fields

            if entity in categories and catcount[entity] < threshold:
                catcount[entity] += 1
                batchline.append(line)
                batchencode.extend((pre, post))
                batchentity.append(entity)

                if len(batchline) >= batchsize:
                    yield from _embed_batch(batchline, batchencode, batchentity)
                    batchline, batchencode, batchentity = [], [], []

    # Flush the final, possibly partial, batch.
    if batchline:
        yield from _embed_batch(batchline, batchencode, batchentity)

class MyDataset(torch.utils.data.Dataset):
    """In-memory dataset of (context features, entity ordinal) pairs.

    Attributes:
        labelfeats: dict entity -> (ordinal, label embedding) for every entity
            produced by `getCategories(threshold)`, with ordinals in that order.
        Xs: tensor of shape (num_examples, 2 * embed_dim).  Each row is the
            `pre` context embedding concatenated with the `post` one.
        ys: long tensor of shape (num_examples,) holding the entity ordinals.
    """

    def __init__(self, threshold, maxexamples=None):
        """Build the dataset.  This is slow because it embeds every kept line.

        Args:
            threshold: minimum entity frequency and per-entity example cap.
            maxexamples: optional cap on the number of examples, for quick
                debugging runs.  None (the default) processes everything.

        Raises:
            ValueError: if no examples were produced.
        """
        from itertools import islice
        from tqdm.notebook import tqdm

        self.labelfeats = {k: (n, v) for n, (k, v) in enumerate(getCategories(threshold))}

        examples = makeData(threshold, self.labelfeats)
        if maxexamples is not None:
            examples = islice(examples, maxexamples)

        Xs, ys = [], []
        for what in tqdm(examples):
            pre = torch.tensor(what['pre'])
            post = torch.tensor(what['post'])
            Xs.append(torch.cat((pre, post)))
            ys.append(what['entityord'])

        if not Xs:
            # Stacking an empty list would fail with an opaque RuntimeError.
            raise ValueError(f'no examples produced for threshold={threshold}')

        self.Xs = torch.stack(Xs, dim=0)
        self.ys = torch.tensor(ys, dtype=torch.long)

    def __len__(self):
        return self.Xs.shape[0]

    def __getitem__(self, index):
        """Return the (features, label ordinal) pair at `index`."""
        return self.Xs[index], self.ys[index]

In [2]:
# Class / example counts for a few candidate thresholds.
tuple(datasetStats(t) for t in (2000, 1000, 100))

({'numclasses': 311, 'numexamples': 622000},
 {'numclasses': 1154, 'numexamples': 1154000},
 {'numclasses': 32089, 'numexamples': 3208900})

## Build and cache the dataset (slow: takes days, so run it only once)

In [3]:
def makeMyDataset(threshold):
    """Build `MyDataset(threshold)` and cache it as ``mydataset.<threshold>.pickle.gz``.

    The pickle is first written to a temporary file and then renamed into
    place.  An interrupted dump therefore never leaves a truncated cache
    behind for `loadMyDataset` to fail on.
    """
    import gzip
    import os
    import pickle

    dataset = MyDataset(threshold)
    path = f'mydataset.{threshold}.pickle.gz'
    tmppath = path + '.tmp'
    try:
        with gzip.open(tmppath, 'wb') as handle:
            pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmppath, path)  # atomic on POSIX and Windows
    finally:
        # Only still present if the dump failed part-way.
        if os.path.exists(tmppath):
            os.remove(tmppath)
        
import os

# The build takes days: run it only when the cached pickle is missing, so the
# notebook can be re-run top to bottom without repeating it.
if not os.path.exists('mydataset.2000.pickle.gz'):
    makeMyDataset(2000)

0it [00:00, ?it/s]

## Load cached processed data

In [2]:
def loadMyDataset(threshold):
    """Load the dataset cached as ``mydataset.<threshold>.pickle.gz`` (current directory).

    NOTE: unpickling can execute arbitrary code.  Only load caches you created.
    """
    import gzip
    import pickle

    path = f'mydataset.{threshold}.pickle.gz'
    with gzip.open(path, 'rb') as handle:
        return pickle.load(handle)

In [3]:
# best constant predictor
# if you don't beat this, you have a problem

def bestconstant(threshold, freqs=None):
    """Log loss and accuracy of the best input-independent predictor.

    Every entity with frequency >= threshold is a class.  As in `makeData`,
    each class is assumed to contribute exactly `threshold` examples.  The
    best constant predictor outputs the empirical class distribution.

    Args:
        threshold: minimum frequency and per-class example count.
        freqs: optional entity -> frequency mapping.  Defaults to
            `categoryCount()`.

    Returns:
        dict with three entries:
            'best_constant_answer': the most frequent class, with ties broken
                by the largest entity name.
            'best_constant_average_logloss': average log loss of the
                constant predictor.
            'best_constant_average_accuracy': accuracy of always predicting
                the best constant answer.

    Raises:
        ValueError: if no entity reaches `threshold`.
    """
    from math import fsum

    if freqs is None:
        freqs = categoryCount()
    counts = {k: threshold for k, v in freqs.items() if v >= threshold}
    if not counts:
        raise ValueError(f'no category has frequency >= {threshold}')

    denom = fsum(counts.values())
    # CrossEntropyLoss semantics expect logits, so use log-probabilities:
    # softmax(log p) == p.  Feeding the probabilities themselves as logits is
    # only correct when they are all equal.
    logits = torch.log(torch.tensor([v / denom for v in counts.values()]))
    neglogp = -torch.log_softmax(logits, dim=0)  # per-class cross-entropy of the constant prediction

    sumloss = fsum(n * neglogp[m].item() for m, n in enumerate(counts.values()))

    return {'best_constant_answer': max((v, k) for k, v in counts.items())[1],
            'best_constant_average_logloss': sumloss / denom,
            'best_constant_average_accuracy': max(counts.values()) / denom}

# Baselines for the same thresholds the dataset is built with.
tuple(bestconstant(t) for t in (2000, 1000, 200))

({'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 5.739792823791504,
  'best_constant_average_accuracy': 0.003215434083601286},
 {'best_constant_answer': 'public_domain',
  'best_constant_average_logloss': 7.050989627838135,
  'best_constant_average_accuracy': 0.0008665511265164644},
 {'best_constant_answer': 'weight_gain',
  'best_constant_average_logloss': 9.54902458190918,
  'best_constant_average_accuracy': 7.127075760815338e-05})

In [7]:
class ResidualBlock(torch.nn.Module):
    """Residual layer computing ``X + 0.001 * LeakyReLU(X @ W)`` with a square weight.

    `W` starts at zero, so a freshly built block is the identity map.
    """

    def __init__(self, d, device):
        super().__init__()
        self.W = torch.nn.Parameter(torch.zeros(d, d, device=device))   # (d, d)
        self.afunc = torch.nn.LeakyReLU(negative_slope=0.01, inplace=True)

    def forward(self, X):
        # The in-place activation only overwrites the fresh matmul result.
        update = self.afunc(X @ self.W)
        return X + 0.001 * update
    
class BilinearResidual(torch.nn.Module):
    """Bilinear scorer ``block(Xs) @ W @ Zs.T`` on top of `depth` residual blocks.

    `forward` returns one logit per (row of Xs, row of Zs) pair, and `preq1`
    maps logits to probabilities with a sigmoid.  `W` starts at zero, so
    every initial logit is 0.
    """

    def __init__(self, dobs, daction, device, depth):
        super().__init__()
        layers = [ResidualBlock(dobs, device) for _ in range(depth)]
        self.block = torch.nn.Sequential(*layers)
        self.W = torch.nn.Parameter(torch.zeros(dobs, daction, device=device))   # (dobs, daction)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, Xs, Zs):
        features = self.block(Xs)
        return features @ self.W @ Zs.T

    def preq1(self, logits):
        return self.sigmoid(logits)

class SpannerIGW(torch.nn.Module):
    """Exploration sampler built on an approximate barycentric spanner.

    `sample(fhat)` picks one action per row of `fhat`.  It either takes the
    greedy argmax or, with a gap-dependent probability, an action drawn
    uniformly from a spanner (D actions) of the reweighted action embeddings.

    `actions` is used as a 2-D (K, D) matrix of action embeddings.  It is
    expanded to (N, K, D) in `_reweight_actions`, and K must match
    ``fhat.shape[1]``.  `iota` scales how strongly an action's estimated gap
    ``max(fhat) - fhat`` shrinks its embedding and its exploration probability.
    """
    def __init__(self, actions, iota):
        super(SpannerIGW, self).__init__()
        
        self.rawactions = actions
        self.iota = iota
          
    def _make_spanner(self, actions):
        """Approximate barycentric spanner of each row's action set.

        Args:
            actions: (N, K, D) tensor.  Row n holds K candidate vectors of
                dimension D.

        Returns:
            (N, D) long tensor.  design[n, i] is the index (into dim 1 of
            `actions`) of the action placed in spanner column i.

        A greedy phase fills each column i with the action maximising
        |det S|.  A local-search phase then swaps a column whenever some
        action raises the scaled determinant by a factor of at least C.  It
        runs for at most int(D * log(D)) rounds.
        """
        from math import log
        
        # All working tensors live on the same device as the stored actions.
        device = self.rawactions.device
        
        # True: reference path that takes one determinant per candidate action
        # (K per column).  False: cofactor shortcut below (D determinants per column).
        slowdet = False
        
        # Algorithm 4 Approximate Barycentric Identification (Awerbuch and Kleinberg, 2008)
                  
        N, K, D = actions.shape
        # S starts as the identity; its columns are replaced one by one with actions.
        S = torch.eye(D, device=device).unsqueeze(0).expand(N, -1, -1).clone()   # N x D x D
        # S and the actions are both divided by exp(logscalefac).  Every column
        # update adds log(fprime)/D, which keeps det(S / exp(logscalefac)) at 1.
        # The scores therefore stay O(1) instead of overflowing.
        logscalefac = torch.zeros(N, device=device)
        design = torch.zeros(N, D, device=device).long()
        # Never modified: coords[:, j, :] is the unit vector e_j (cloned before S changes).
        coords = S.clone()
        for i in range(D):
            Sscale = S / torch.exp(logscalefac.reshape(shape=(N, 1, 1)))
            Ascale = actions / torch.exp(logscalefac.reshape(shape=(N, 1, 1)))

            if slowdet:
                Sprime = Sscale.unsqueeze(1).expand(-1, K, -1, -1).clone()     # N x K x D x D
                Sprime[:,:,:,i] = Ascale
                dets = torch.abs(torch.linalg.det(Sprime))
            else: 
                # det(S with column i := a) is linear in a.  It equals a . psi,
                # where psi[j] = det(S with column i := e_j) (cofactor expansion
                # along column i).  So D determinants score all K actions at once.
                Sprimeprime = Sscale.unsqueeze(1).expand(-1, D, -1, -1).clone()     # N x D x D x D
                Sprimeprime[:,:,:,i] = coords
                psi = torch.linalg.det(Sprimeprime) # N x D 
                dets = torch.abs(torch.bmm(Ascale, psi.unsqueeze(2))).squeeze(2) 
                
            # Greedy choice: the action giving the largest |det| for column i.
            fprime, aprime = torch.max(dets, dim=1)
            design[:,i] = aprime
            Y = torch.gather(input=actions, dim=1, index=aprime.reshape(N, 1, 1).expand(N, 1, D)).squeeze(1)
            S[:,:,i] = Y
            # NOTE(review): fprime == 0 (e.g. rank-deficient actions) makes this -inf and later NaN.
            logscalefac += 1/D * torch.log(fprime) # determinant is getting huge, scale to avoid floating point issues
            
        # Local-search phase: a swap must improve |det| by at least this factor.
        C = 2
        for _ in range(int(D * log(D))):
            replaced = False
            for i in range(D):
                Sscale = S / torch.exp(logscalefac.reshape(shape=(N, 1, 1)))
                Ascale = actions / torch.exp(logscalefac.reshape(shape=(N, 1, 1)))

                if slowdet:
                    Sprime = Sscale.unsqueeze(1).expand(-1, K, -1, -1).clone()     # N x K x D x D
                    Sprime[:,:,:,i] = Ascale
                    dets = torch.abs(torch.linalg.det(Sprime))
                else: 
                    Sprimeprime = Sscale.unsqueeze(1).expand(-1, D, -1, -1).clone()     # N x D x D x D
                    Sprimeprime[:,:,:,i] = coords
                    psi = torch.linalg.det(Sprimeprime) # N x D 
                    dets = torch.abs(torch.bmm(Ascale, psi.unsqueeze(2))).squeeze(2) 

                fprime, aprime = torch.max(dets, dim=1)
                
                # NOTE(review): when any row qualifies, column i is replaced in EVERY row.
                # Rows below C still take their argmax, which never lowers |det|.
                if torch.any(fprime >= C): # fprime >= C * det(S/scalefac), but our scaling ensures det(S/scalefac) == 1 
                    design[:,i] = aprime
                    Y = torch.gather(input=actions, dim=1, index=aprime.reshape(N, 1, 1).expand(N, 1, D)).squeeze(1)
                    S[:,:,i] = Y
                    logscalefac += 1/D * torch.log(fprime)
                    replaced = True
                    # Restart the scan over columns after a successful swap.
                    break

            # No column improved by factor C in any row: the spanner is final.
            if not replaced:
                break

        return design

    def _reweight_actions(self, fhat, fhatstar):
        """Shrink each action embedding by 1 / sqrt(1 + d + iota * gap).

        fhat is (N, K) and fhatstar is (N, 1).  The gap is fhatstar - fhat.
        Returns an (N, K, d) tensor: the stored (K, d) actions, rescaled per row.
        """
        N = fhat.shape[0]
        d = self.rawactions.shape[1]
        sqrtdenom = torch.sqrt(1 + d + self.iota * (fhatstar - fhat)).unsqueeze(-1)
        return self.rawactions.expand(N, -1, -1) / sqrtdenom

    def sample(self, fhat):
        """Draw one action index per row of `fhat`.

        fhat: (N, K) reward estimates.  In `learnOnline` these are sigmoid
        probabilities.  Returns an (N,) long tensor of action indices.

        Each row picks one spanner column uniformly at random.  It takes that
        action with probability D / (1 + D + iota * gap), where gap >= 0 is
        the action's shortfall from the row maximum, so the probability is
        below 1.  Otherwise it takes the greedy argmax.
        """
        device = self.rawactions.device  # (unused)

        N = fhat.shape[0]
        D = self.rawactions.shape[1]
        # Best estimate and greedy action per row, both (N, 1).
        fhatstar, exploit = torch.max(fhat, dim=1, keepdim=True)
        Zs = self._reweight_actions(fhat, fhatstar)
        design = self._make_spanner(Zs)
        
        # Uniformly chosen spanner column -> candidate exploration action (N, 1).
        exploreindex = torch.randint(high=D, size=(N, 1), device=fhat.device)
        explore = torch.gather(input=design, dim=1, index=exploreindex)
        fhatexplore = torch.gather(input=fhat, dim=1, index=explore) 
        probs = D / (1 + D + self.iota * (fhatstar - fhatexplore))
        unif = torch.rand(size=(N, 1), device=fhat.device)
        shouldexplore = (unif <= probs).long()
        # Branch-free select: explore when shouldexplore == 1, else exploit.
        return (exploit + shouldexplore * (explore - exploit)).squeeze(1)

def learnOnline(dataset, rank, initlr, tzero, iota, depth=1, cuda=False, seed=4545, batch_size=32):
    """Train a `BilinearResidual` reward model online with `SpannerIGW` exploration.

    Each minibatch goes through four steps:
      1. Score every label for every example.
      2. Let the sampler choose one label per example.
      3. Observe reward 1 if the chosen label equals the true label, else 0.
      4. Take one Adam step on the logistic loss of the chosen label only.

    A progress row is printed when the 0-based batch index is 0 or a power of
    two (rows n = 1, 2, 3, 5, 9, ...), and once more at the end.

    Args:
        dataset: a `MyDataset`-like object.  Iterating it through a
            DataLoader yields (features, label ordinal), and its `labelfeats`
            values are (ordinal, label embedding) pairs.
        rank: if not None, project the label embeddings onto their top-`rank`
            singular directions.
        initlr: initial Adam learning rate.  It decays as
            sqrt(tzero) / sqrt(tzero + t).
        tzero: learning-rate decay offset.
        iota: exploration parameter passed to `SpannerIGW`.
        depth: number of residual blocks (default 1).
        cuda: run on the GPU.
        seed: torch RNG seed.  It controls shuffling and exploration.
        batch_size: minibatch size (default 32, the previous hard-coded value).
    """
    import time
    from math import sqrt

    torch.manual_seed(seed)

    # Label embedding matrix: row n is the embedding of the label with ordinal n.
    labelfeatsdict = {n: v for n, v in dataset.labelfeats.values()}
    labelfeats = [torch.tensor(labelfeatsdict[n]).float().unsqueeze(0) for n in range(len(labelfeatsdict))]
    Zs = torch.cat(labelfeats, dim=0)

    if cuda:
        Zs = Zs.cuda()

    if rank is not None:
        # Keep only the top-`rank` singular directions of the label embeddings.
        with torch.no_grad():
            U, S, Vh = torch.linalg.svd(Zs, full_matrices=False)
            Zs = U[:, :rank] @ torch.diag(S[:rank])

    generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = None
    log_loss = torch.nn.BCEWithLogitsLoss()

    print('{:<5s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}'.format(
            'n', 'loss', 'since last', 'acc', 'since last', 'reward', 'since last', 'dt (sec)'),
          flush=True)
    avloss, sincelast, acc, accsincelast, avreward, rewardsincelast = [EasyAcc() for _ in range(6)]

    def _report():
        # One progress row: running averages, averages since the previous row, elapsed time.
        now = time.time()
        print('{:<5d}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}'.format(
                avloss.n, avloss.mean(), sincelast.mean(), acc.mean(),
                accsincelast.mean(), avreward.mean(), rewardsincelast.mean(),
                now - start),
              flush=True)

    # Reset once the model is built.  Initialised here so the final report
    # also works when the loader yields no batches.
    start = time.time()

    for bno, (Xs, ys) in enumerate(generator):
        Xs, ys = Xs.to(Zs.device), ys.to(Zs.device)

        if model is None:
            # Built lazily: the feature dimension is only known from the first batch.
            model = BilinearResidual(dobs=Xs.shape[1], daction=Zs.shape[1], depth=depth, device=Zs.device)
            sampler = SpannerIGW(actions=Zs, iota=iota)
            opt = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=initlr)
            scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda t: sqrt(tzero) / sqrt(tzero + t))
            start = time.time()

        opt.zero_grad()
        logit = model.forward(0.0001 * Xs, Zs)   # (batch, num_labels)

        with torch.no_grad():
            # Reuse this forward pass for exploration instead of running the model a
            # second time.  The parameters have not changed yet, so the scores are identical.
            fhat = model.preq1(logit.detach())
            sample = sampler.sample(fhat)
            reward = (sample == ys).unsqueeze(1).float()

        # Bandit feedback: only the chosen label's logit receives a loss.
        samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)
        loss = log_loss(samplelogit, reward)
        loss.backward()
        opt.step()
        scheduler.step()

        with torch.no_grad():
            pred = torch.argmax(logit, dim=1)
            batchacc = torch.mean((pred == ys).float())
            batchreward = torch.mean(reward)
            lossval = loss.detach()
            acc += batchacc
            accsincelast += batchacc
            avloss += lossval
            sincelast += lossval
            avreward += batchreward
            rewardsincelast += batchreward

        # Report at bno == 0 and at every power of two.
        if (bno & (bno - 1)) == 0:
            _report()
            sincelast, accsincelast, rewardsincelast = [EasyAcc() for _ in range(3)]

    _report()

In [3]:
# Cached dataset for threshold 2000 (311 classes).
mydata = loadMyDataset(threshold=2000)

In [8]:
learnOnline(
    mydata,
    initlr=0.33,
    tzero=100000,
    rank=50,
    iota=14000 * 50 / 311,   # NOTE(review): 50 presumably = rank and 311 = class count at threshold 2000; confirm
    depth=1,
)

n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.69315   	0.69315   	0.03125   	0.03125   	0.00000   	0.00000   	4.88414   
2    	0.68814   	0.68313   	0.01562   	0.00000   	0.00000   	0.00000   	11.62514  
3    	0.67061   	0.63556   	0.01042   	0.00000   	0.00000   	0.00000   	16.28875  
5    	0.62251   	0.55037   	0.00625   	0.00000   	0.00000   	0.00000   	29.66229  
9    	0.48351   	0.30975   	0.00347   	0.00000   	0.00000   	0.00000   	66.66567  
17   	0.29189   	0.07632   	0.00551   	0.00781   	0.00368   	0.00781   	121.54505 
33   	0.16682   	0.03393   	0.00379   	0.00195   	0.00379   	0.00391   	186.71466 
65   	0.10445   	0.04014   	0.00192   	0.00000   	0.00385   	0.00391   	334.13751 
129  	0.07253   	0.04011   	0.00218   	0.00244   	0.00412   	0.00439   	678.67136 
257  	0.05489   	0.03712   	0.00426   	0.00635   	0.00486   	0.00562   	1420.09770
513  	0.04502   	0.03511   	0.00591   	0.00757   	0.00579   	0.00671   	2874.62543
1025

In [4]:
# Cached dataset for threshold 1000 (1154 classes).
mydata = loadMyDataset(threshold=1000)

In [None]:
# `depth` must be supplied; 1 matches the threshold-2000 run above (TODO confirm intended depth).
# The batch size is left at learnOnline's 32.
learnOnline(mydata, initlr=4e-1, tzero=100000, rank=50, iota=666, depth=1, cuda=True)