# Dataset

In [1]:
import torch

class EasyAcc:
    """Running accumulator of a count, a sum and a sum of squares.

    Values are fed in with ``acc += x`` or ``acc -= x``; summary statistics
    are read back with :meth:`mean`, :meth:`var` and :meth:`semean`.
    """

    def __init__(self):
        self.n = 0        # number of values folded in so far
        self.sum = 0      # signed running total
        self.sumsq = 0    # running total of squared values

    def __iadd__(self, other):
        """Fold ``other`` in: bump the count, add it to the sum and its square to ``sumsq``."""
        self.n += 1
        self.sum += other
        self.sumsq += other * other
        return self

    def __isub__(self, other):
        """Fold ``-other`` in.

        The count is still incremented and the square is still added; only
        the contribution to ``sum`` changes sign.
        """
        self.n += 1
        self.sum -= other
        self.sumsq += other * other
        return self

    def mean(self):
        """Return ``sum / n``, using a divisor of 1 while nothing has been accumulated."""
        return self.sum / max(self.n, 1)

    def var(self):
        """Return ``sqrt(sumsq / n - mean**2)``.

        NOTE: despite the name, this is the square root of the (population)
        variance, i.e. a standard deviation. The name is kept for callers.
        """
        from math import sqrt

        spread = self.sumsq / max(self.n, 1) - self.mean() ** 2
        # Rounding can leave a tiny negative value when the data are (nearly)
        # constant; sqrt would then raise "math domain error". NaN is left
        # alone because ``nan < 0`` is False.
        if spread < 0:
            spread = 0.0
        return sqrt(spread)

    def semean(self):
        """Return ``var() / sqrt(n)`` (standard error of the mean, given that var() is a std-dev)."""
        from math import sqrt

        return self.var() / sqrt(max(self.n, 1))

class MyDataset(torch.utils.data.Dataset):
    """Dataset whose items are Bernoulli draws from a fixed probability array.

    The array is loaded from ``652_contest.npy``. Item ``index`` is a fresh
    draw made with a generator seeded by ``index``, so the same index always
    yields the same realization.
    """

    def __init__(self):
        import numpy

        raw = numpy.load('652_contest.npy')
        self.data = raw               # array exactly as stored on disk
        self.ys = torch.Tensor(raw)   # same values, used as Bernoulli probabilities

    def __len__(self):
        # Fixed nominal length, independent of the loaded array.
        return 100_000

    def __getitem__(self, index):
        # Seed a private generator with the index so draws are reproducible.
        rng = torch.Generator()
        rng.manual_seed(index)
        return torch.bernoulli(self.ys, generator=rng)

In [8]:
class CountTable:
    """Per-action table of running reward statistics.

    Each row is ``[reward_total, pull_count, mean_reward]``; the mean is
    refreshed every time the row is updated.
    """

    def __init__(self, naction):
        self.table = [[0, 0, 0.0] for _ in range(naction)]

    def fhat(self):
        """Return the current mean reward of every action as a 1-D tensor."""
        return torch.Tensor([row[2] for row in self.table])

    def update(self, action, reward):
        """Fold a batch of (action, reward) pairs into the table.

        ``action`` and ``reward`` are iterated in lockstep and each element
        is converted to a Python scalar with ``.item()``.
        """
        for act, rew in zip(action, reward):
            row = self.table[act.item()]
            row[1] += 1
            row[0] += rew.item()
            row[2] = row[0] / row[1]

class CorralIGW(object):
    """Sampler that mixes ``nalgos`` exploration policies, one per gamma value.

    Each sub-algorithm is identified by an index into ``gammas`` (a geometric
    grid between ``gammamin`` and ``gammamax``). ``invpalgo`` stores the
    inverse of the probability with which each sub-algorithm is selected.
    The class name suggests Corral over inverse-gap-weighting learners.

    Args:
        eta: step size; it is divided by ``nalgos`` before use.
        gammamin: smallest gamma of the grid.
        gammamax: largest gamma of the grid.
        nalgos: number of grid points / sub-algorithms.
        device: device on which ``gammas`` and ``invpalgo`` are stored.
    """

    def __init__(self, *, eta, gammamin, gammamax, nalgos, device):
        import numpy

        super().__init__()

        self.eta = eta / nalgos
        # torch.tensor (not the legacy torch.Tensor constructor) so that a
        # non-CPU ``device`` is honoured; float32 matches the legacy default.
        self.gammas = torch.tensor(numpy.geomspace(gammamin, gammamax, nalgos),
                                   dtype=torch.float32,
                                   device=device)
        count = self.gammas.shape[0]
        # Start uniform: every sub-algorithm has probability 1 / count.
        self.invpalgo = torch.full((count,), float(count),
                                   dtype=torch.float32,
                                   device=device)

    def update(self, algo, invprop, reward):
        """Update the sub-algorithm selection probabilities from a batch.

        Args:
            algo: 1-D index tensor of the sub-algorithm used for each example.
            invprop: ``(N, 1)`` inverse selection probability of that
                sub-algorithm, as returned by :meth:`sample`.
            reward: ``(N, 1)`` rewards; every entry must lie in ``[0, 1]``.

        Raises:
            AssertionError: if a reward is outside ``[0, 1]`` or the root
                finder does not converge.
        """
        import numpy
        from scipy import optimize

        assert torch.all(reward >= 0) and torch.all(reward <= 1), reward

        # Importance-weighted loss (loss = -reward), scaled by the step size.
        weightedlosses = self.eta * (-reward.squeeze(1)) * invprop.squeeze(1)
        # Accumulate each loss onto the inverse probability of its algorithm;
        # scatter_add sums contributions that share an index.
        newinvpalgo = torch.scatter_add(self.invpalgo, 0, algo, weightedlosses)

        # just do this calc on the cpu
        invp = newinvpalgo.cpu().numpy()
        # Shift so the smallest entry is exactly 1; then sum(1 / invp) >= 1,
        # which guarantees the normaliser z found below is >= 0.
        invp += 1 - numpy.min(invp)
        # Bracket the z with sum(1 / (invp + z)) == 1 by doubling the upper end.
        Zlb = 0
        Zub = 1
        while numpy.sum(1 / (invp + Zub)) > 1:
            Zlb = Zub
            Zub *= 2
        root, res = optimize.brentq(lambda z: 1 - numpy.sum(1 / (invp + z)),
                                    Zlb, Zub, full_output=True)
        assert res.converged, res

        self.invpalgo = torch.tensor(invp + root,
                                     dtype=torch.float32,
                                     device=self.invpalgo.device)

    def sample(self, fhat):
        """Draw one action per row of ``fhat``.

        For every row a sub-algorithm is sampled with probability
        ``1 / invpalgo``. A uniformly random action is then proposed and
        played with probability ``K / (K + gamma * (max(fhat) - fhat[proposal]))``;
        otherwise the argmax action is played.

        Args:
            fhat: ``(N, K)`` tensor of estimated rewards.

        Returns:
            Tuple ``(action, algo, invpalgo)``: the ``(N,)`` chosen actions,
            the ``(N,)`` sampled sub-algorithm indices, and the ``(N, 1)``
            inverse probabilities of those sub-algorithms.
        """
        N, K = fhat.shape

        algosampler = torch.distributions.categorical.Categorical(probs=1.0 / self.invpalgo,
                                                                  validate_args=False)
        algo = algosampler.sample((N,))
        picked = algo.unsqueeze(1)
        invpalgo = torch.gather(input=self.invpalgo.unsqueeze(0).expand(N, -1),
                                dim=1,
                                index=picked)
        gamma = torch.gather(input=self.gammas.unsqueeze(0).expand(N, -1),
                             dim=1,
                             index=picked)

        fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)
        rando = torch.randint(high=K, size=(N, 1), device=fhat.device)
        fhatrando = torch.gather(input=fhat, dim=1, index=rando)
        probs = K / (K + gamma * (fhatstar - fhatrando))
        unif = torch.rand(size=(N, 1), device=fhat.device)
        shouldexplore = (unif <= probs).long()
        # shouldexplore is 0/1, so this selects ahatstar or rando per row.
        return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1), algo, invpalgo

def learnOnline(dataset, *, seed, eta, gammamin, gammamax, nalgos, batch_size):
    """Run one shuffled pass of online learning over ``dataset``.

    Every batch ``ys`` is treated as the realized reward of each action.
    A ``CountTable`` supplies the reward estimates, a ``CorralIGW`` sampler
    picks the action, and the observed reward updates both. A progress row
    is printed whenever the zero-based batch number is 0 or a power of two,
    and once more after the last batch.

    Args:
        dataset: dataset yielding one row of per-action rewards per item.
        seed: value passed to ``torch.manual_seed`` before the pass.
        eta, gammamin, gammamax, nalgos: forwarded to ``CorralIGW``.
        batch_size: ``DataLoader`` batch size.

    Returns:
        List of ``(n, 1 - mean_reward)`` tuples, one per printed row, where
        ``n`` is the number of batches processed so far.
    """
    import time

    trajectory = []

    torch.manual_seed(seed)
    generator = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    log_loss = torch.nn.BCELoss()
    model = None
    sampler = None
    start = None  # set once the first batch has arrived

    header_fmt = '{:<5s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}\t{:<10s}'
    row_fmt = '{:<5d}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}\t{:<10.5f}'

    print(header_fmt.format(
            'n', 'loss', 'since last', 'acc', 'since last', 'reward', 'since last', 'dt (sec)'),
          flush=True)
    avloss, sincelast, acc, accsincelast, avreward, rewardsincelast = [ EasyAcc() for _ in range(6) ]

    def report():
        # Print one progress row and record the checkpoint. float() accepts
        # both the 0-dim tensors held after the first batch and the plain 0
        # an untouched accumulator returns, so an empty pass does not crash.
        elapsed = time.time() - start if start is not None else 0.0
        print(row_fmt.format(
                avloss.n, float(avloss.mean()), float(sincelast.mean()), float(acc.mean()),
                float(accsincelast.mean()), float(avreward.mean()), float(rewardsincelast.mean()),
                elapsed),
              flush=True)
        trajectory.append((avloss.n, 1 - float(avreward.mean())))

    for bno, ys in enumerate(generator):
        if model is None:
            # Lazy construction: the number of actions is only known from
            # the first batch.
            model = CountTable(naction=ys.shape[1])
            sampler = CorralIGW(eta=eta, gammamin=gammamin, gammamax=gammamax, nalgos=nalgos, device=ys.device)
            start = time.time()

        with torch.no_grad():
            # Act on the current estimates and observe the chosen action's reward.
            fhat = model.fhat().unsqueeze(0).expand(ys.shape[0], -1)
            sample, algo, invpalgo = sampler.sample(fhat)
            chosen = sample.unsqueeze(1)
            reward = torch.gather(input=ys, dim=1, index=chosen).float()

            # Score the pre-update estimate, then update the table.
            samplefhat = torch.gather(input=fhat, index=chosen, dim=1)
            loss = log_loss(samplefhat, reward)
            model.update(sample, reward.squeeze(1))

            # "acc" is the mean realized reward of the greedy action under
            # the pre-update estimates.
            pred = torch.argmax(fhat, dim=1)
            ypred = torch.gather(input=ys, dim=1, index=pred.unsqueeze(1))
            greedy_reward = torch.mean(ypred).float()
            batch_reward = torch.mean(reward)
            acc += greedy_reward
            accsincelast += greedy_reward
            avloss += loss
            sincelast += loss
            avreward += batch_reward
            rewardsincelast += batch_reward
            sampler.update(algo, invpalgo, reward)

        # True for bno == 0 and for every power of two.
        if bno & (bno - 1) == 0:
            report()
            sincelast, accsincelast, rewardsincelast = [ EasyAcc() for _ in range(3) ]
            #print(f'sampler.palgo = { 1/sampler.invpalgo }')

    report()
    #print(f'sampler.palgo = { 1/sampler.invpalgo }')

    return trajectory

In [9]:
# Run the online learner over the dataset one example at a time, mixing 8
# sub-algorithms whose gammas are geometrically spaced from 1e3 to 1e6.
trajectory = learnOnline(MyDataset(), seed=4545, batch_size=1, eta=1, gammamin=1000, gammamax=1000000, nalgos=8)

n    	loss      	since last	acc       	since last	reward    	since last	dt (sec)  
1    	0.00000   	0.00000   	1.00000   	1.00000   	0.00000   	0.00000   	0.00691   
2    	0.00000   	0.00000   	1.00000   	1.00000   	0.00000   	0.00000   	0.01941   
3    	33.33333  	100.00000 	1.00000   	1.00000   	0.33333   	1.00000   	0.02731   
5    	40.00000  	50.00000  	0.60000   	0.00000   	0.20000   	0.00000   	0.03844   
9    	44.52146  	50.17329  	0.44444   	0.25000   	0.22222   	0.25000   	0.07238   
17   	58.96970  	75.22397  	0.41176   	0.37500   	0.41176   	0.62500   	0.15753   
33   	39.49024  	18.79332  	0.66667   	0.93750   	0.66667   	0.93750   	0.24698   
65   	30.88296  	22.00671  	0.72308   	0.78125   	0.63077   	0.59375   	0.43491   
129  	25.01378  	19.05289  	0.74419   	0.76562   	0.61240   	0.59375   	0.63657   
257  	22.81344  	20.59592  	0.73541   	0.72656   	0.62257   	0.63281   	0.86887   
513  	20.94362  	19.06649  	0.72710   	0.71875   	0.64133   	0.66016   	1.37877   
1025

In [11]:
# Show the (n, 1 - mean reward) checkpoints next to n * (1 - mean reward),
# i.e. the accumulated shortfall from a reward of 1 at each checkpoint.
trajectory, [ (n, n*v) for n, v in trajectory ]

([(1, 1.0),
  (2, 1.0),
  (3, 0.6666666567325592),
  (5, 0.7999999970197678),
  (9, 0.7777777761220932),
  (17, 0.5882352888584137),
  (33, 0.3333333134651184),
  (65, 0.3692307472229004),
  (129, 0.38759690523147583),
  (257, 0.3774319291114807),
  (513, 0.35867446660995483),
  (1025, 0.33463412523269653),
  (2049, 0.28989750146865845),
  (4097, 0.180375874042511),
  (8193, 0.10496765375137329),
  (16385, 0.06121450662612915),
  (32769, 0.04290640354156494),
  (65537, 0.03454536199569702),
  (100000, 0.028670012950897217)],
 [(1, 1.0),
  (2, 2.0),
  (3, 1.9999999701976776),
  (5, 3.999999985098839),
  (9, 6.999999985098839),
  (17, 9.999999910593033),
  (33, 10.999999344348907),
  (65, 23.999998569488525),
  (129, 50.00000077486038),
  (257, 97.00000578165054),
  (513, 184.00000137090683),
  (1025, 342.99997836351395),
  (2049, 593.9999805092812),
  (4097, 738.9999559521675),
  (8193, 859.9999871850014),
  (16385, 1002.9996910691261),
  (32769, 1405.9999376535416),
  (65537, 2263.9993