In [182]:
import numpy as np
import torch

In [273]:
class ImportanceConditionalDatasetSampler:
    """Draw mini-batches without replacement, re-weighting items by recorded feedback.

    Every item starts with probability 1/N. Items drawn by ``get_batch`` are marked
    exhausted (probability 0). When fewer than ``batch_size`` items still have a
    positive probability, every exhausted item is refilled from ``trace``:
    ``1 - h + eps`` if ``record`` was called for it since the last refill, else 1/N.
    ``weights`` is refilled with the same values, and ``trace`` is reset to uniform.

    Parameters
    ----------
    dataset : torch.Tensor
        Indexable by an integer index array and supporting ``.clone()``.
    labels : torch.Tensor
        Non-negative integer class labels, one per dataset item.
    eps : float
        Added to ``1 - h`` so an item with ``h == 1`` keeps a small positive probability.
    rng : numpy.random.Generator or RandomState, optional
        Source of randomness (anything with ``.choice``). ``None`` uses the global
        ``np.random`` state, as before.
    """

    def __init__(self, dataset, labels, eps=0.001, rng=None):
        if len(dataset) != len(labels):
            raise ValueError(
                f"dataset and labels differ in length ({len(dataset)} vs {len(labels)})"
            )
        self.dataset = dataset
        self.labels = labels
        # one_hot requires num_classes > max label; counting distinct labels crashed
        # for non-contiguous label sets such as {0, 2}.
        self.num_classes = int(labels.max()) + 1 if len(labels) else 0

        self.indices = np.arange(len(self.dataset))
        self.probs = self._uniform()    # current sampling probabilities (0 = exhausted)
        self.trace = self._uniform()    # feedback collected since the last refill
        self.weights = self._uniform()  # values reported (re-centred) by get_weights

        self.last_idx = None
        self.eps = eps
        self.rng = rng

    def _uniform(self):
        """Return a fresh float array holding 1/N for each of the N items."""
        return np.ones_like(self.indices) / len(self.indices)

    def get_batch(self, batch_size):
        """Draw ``batch_size`` distinct items.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Cloned data rows, and float one-hot labels of width ``num_classes``.
        """
        if batch_size > np.count_nonzero(self.probs > 0):
            # Compute the mask once: refilling probs makes those entries non-zero,
            # so re-evaluating `probs == 0` for weights would select nothing.
            exhausted = self.probs == 0
            self.probs[exhausted] = self.trace[exhausted]
            self.weights[exhausted] = self.trace[exhausted]
            self.trace = self._uniform()

        self.probs /= np.sum(self.probs)
        choose = np.random.choice if self.rng is None else self.rng.choice
        idx = choose(self.indices, size=(batch_size,), replace=False, p=self.probs)
        self.probs[idx] = 0

        data = self.dataset[idx].clone()
        labels_oh = torch.nn.functional.one_hot(
            self.labels[idx].clone().long(), self.num_classes
        ).float()

        self.last_idx = idx
        return data, labels_oh

    def record(self, h):
        """Store feedback ``h`` for the most recent batch.

        Sets each item's trace to ``1 - h + eps``; this takes effect when the item is
        next refilled. ``h`` holds one value per item of the last batch, or a single
        value broadcast to all of them. Presumably it is a score in [0, 1] (TODO
        confirm): values above ``1 + eps`` would produce negative probabilities.

        Raises
        ------
        RuntimeError
            If called before any ``get_batch``. Otherwise ``trace[None]`` would
            silently broadcast into every item.
        """
        if self.last_idx is None:
            raise RuntimeError("record() called before get_batch()")
        if isinstance(h, torch.Tensor):
            # numpy cannot convert tensors that require grad or live on an accelerator.
            h = h.detach().cpu().numpy()
        h = np.asarray(h).reshape(-1)
        self.trace[self.last_idx] = 1 - h + self.eps

    def get_weights(self):
        """Return the last batch's weights, shifted additively so their mean is 1.

        Raises
        ------
        RuntimeError
            If called before any ``get_batch``.
        """
        if self.last_idx is None:
            raise RuntimeError("get_weights() called before get_batch()")
        a = self.weights[self.last_idx]
        return a + (1 - a.mean())

In [234]:
from utils import load_data
import torch

In [252]:
# Load the 'grid' dataset via the local utils.load_data helper, unpacked into (data, labels).
data, labels = load_data('grid')

In [253]:
# Number of samples loaded.
len(data)

2500

In [236]:
# Wrap the loaded arrays as float features and long labels. The class defined in this
# notebook is ImportanceConditionalDatasetSampler. The old name DifficultyConditionalDatasetSampler
# only resolved through stale kernel state and raises NameError on Restart & Run All.
sampler = ImportanceConditionalDatasetSampler(torch.tensor(data).float(), torch.tensor(labels).long())

In [237]:
# Sanity check against the sampler class actually defined in this notebook.
# The old name DifficultyConditionalDatasetSampler raises NameError on a fresh kernel.
isinstance(sampler, ImportanceConditionalDatasetSampler)

True

In [254]:
# Draw a single-item batch; the recorded output is a (data rows, one-hot labels) tuple.
sampler.get_batch(1)

(tensor([[-0.5350,  0.5240]]), tensor([[0., 0., 0., 1., 0., 0., 0., 0.]]))

In [247]:
# Report feedback h = 0 for the single item drawn by the previous get_batch(1) call.
sampler.record(np.array([0]))

In [271]:
# Scratch input for the re-centering check below: ten equal weights of 1/10.
a = np.full(10, 0.1)

In [272]:
# Same formula as get_weights(): shift the weights so their mean is 1 (uniform input -> all ones).
(a + (1 - a.mean()))

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

5.796822676124391