In [2]:
from run_sacred import *

In [4]:
monitor_per = 100
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [13]:
def get_dataset(name, validation, test_domain, L, K):
    """Prepare datasets for train, valid and test with configurations.

    Parameter
    ---------
    name : str
    validation : str or list
    test_domain : str or list
    L : int
    K : int

    Return
    ------
    {train/valid}_dataset_{joint/marginal} : torch.Dataset
        {train/valid} datasets from {joint/marginal} distributions
    """
    if isinstance(validation, str):
        validation = validation.split('-')
    all_adls = OppG.get('all_adls')
    all_domain = OppG.get('all_domain_key')
    train_adls = sorted(list(set(all_adls) - set(validation)))
    train_domain = sorted(list(set(all_domain) - set([test_domain])))
    train_dataset_joint = OppG(
        train_domain, l_sample=30, interval=15, T=K+L, adl_ids=train_adls)
    valid_dataset_joint = OppG(
        train_domain, l_sample=30, interval=15, T=K+L, adl_ids=validation)

    # marginal sample come from same datasets for simplicity
    # Same train-valid split with joint dataset
    train_dataset_marginal = OppG(
        train_domain, l_sample=30, interval=15, T=K, adl_ids=train_adls)
    valid_dataset_marginal = OppG(
        train_domain, l_sample=30, interval=15, T=K, adl_ids=validation)
    test_dataset = OppG(test_domain, l_sample=30, interval=15, T=K+L)
    return train_dataset_joint, valid_dataset_joint, train_dataset_marginal, valid_dataset_marginal, test_dataset

datasets = get_dataset('opp', 'ADL4-ADL5', 'S1', 12, 3)
train_dataset_joint, valid_dataset_joint, train_dataset_marginal, valid_dataset_marginal, _ = datasets



### JointとMarginalが必ず、別のユーザからサンプリングされるようにしたい

In [7]:
import torch

In [30]:
import torch
import torch.utils.data
import torchvision


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
    """

    def __init__(self, dataset, indices=None, num_samples=None):
                
        # if indices is not provided, 
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices
            
        # if num_samples is not provided, 
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples
            
        # distribution of classes in the dataset 
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1
                
        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        dataset_type = type(dataset)
        for i in range(len(dataset.cummulative_sizes)):
            if idx < i:
                return i
                
    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples

In [31]:
idx = 10
[x > 10 for x in train_dataset_joint.cumulative_sizes]

[True, True, True]

In [34]:
train_loader = torch.utils.data.DataLoader(
    train_dataset_joint, 
    sampler=ImbalancedDatasetSampler(train_dataset_joint),
    batch_size=128, 
)




In [40]:
train_loader.sampler.weights

tensor([1.0000e+00, 1.0000e+00, 3.2695e-05,  ..., 3.2695e-05, 3.2695e-05,
        3.2695e-05], dtype=torch.float64)

In [125]:
from torch.utils.data import Sampler
from torch._six import int_classes as _int_classes


class SplitBatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        for _sampler in sampler:
            if not isinstance(_sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(_sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.num_sampler = len(sampler)
        self.which_sampler = [int(batch_size/self.num_sampler*(i+1)) for i in range(self.num_sampler)]
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        sampler_idx = 0
        
        for i in range(self.batch_size):
            idx = self.sampler[sampler_idx].__iter__().__next__()
            batch.append(idx)
            if len(batch) == self.which_sampler[sampler_idx]:
                sampler_idx += 1
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

        
def get_split_samplers(dataset, ids):
    assert len(dataset.datasets) == len(ids)
    size = [0] + dataset.cummulative_sizes
    sampler = []
    for i in ids:
        sampler.append(SubsetRandomSampler(range(size[i], size[i+1])))
    return sampler

sampler = get_split_samplers(train_dataset_joint, [0, 1, 2])
batch_smpler = OriginalBatchSampler(get_split_samplers(train_dataset_joint, [0, 1, 2]), 128, False)

In [122]:
sampler = SubsetRandomSampler(range(train_dataset_joint.cummulative_sizes[0]))
sampler1 = SubsetRandomSampler(range(100))
sampler2 = SubsetRandomSampler(range(100, 200))
batch_smpler = OriginalBatchSampler([sampler1, sampler2], 20, False)

  """Entry point for launching an IPython kernel.


In [138]:
batch_smpler.__iter__().__next__()

[2409,
 1773,
 3697,
 3843,
 854,
 1652,
 3520,
 8698,
 1785,
 6773,
 9680,
 990,
 2891,
 4658,
 4597,
 6586,
 4500,
 5134,
 5490,
 1702,
 10220,
 8117,
 9658,
 6475,
 6345,
 5165,
 5300,
 4352,
 9746,
 2685,
 9617,
 4197,
 3058,
 2646,
 1298,
 8865,
 7031,
 2558,
 10215,
 5750,
 7704,
 4784,
 16847,
 14490,
 19986,
 17321,
 21335,
 18366,
 12224,
 16873,
 14565,
 11454,
 18061,
 13758,
 11899,
 10827,
 13605,
 18330,
 12468,
 13337,
 15614,
 15197,
 10677,
 16573,
 17441,
 12106,
 18419,
 19984,
 13413,
 11579,
 15240,
 20550,
 17100,
 15673,
 19191,
 20734,
 15744,
 13460,
 14184,
 11431,
 17909,
 19065,
 15606,
 15780,
 11314,
 23498,
 25296,
 22388,
 27062,
 30561,
 27505,
 24021,
 25745,
 29287,
 29628,
 24752,
 28805,
 30065,
 27145,
 25280,
 29870,
 29502,
 23850,
 21695,
 27547,
 25488,
 24700,
 29961,
 27605,
 27172,
 24338,
 21548,
 27679,
 25102,
 22709,
 30109,
 23304,
 23318,
 24571,
 30478,
 27075,
 23724,
 24848,
 26241,
 27456,
 22557,
 28561,
 26478]

  This is separate from the ipykernel package so we can avoid doing imports until


  This is separate from the ipykernel package so we can avoid doing imports until


In [129]:
dataset.cummulative_sizes[i]

  """Entry point for launching an IPython kernel.


10616

In [130]:
size = [0] + dataset.cummulative_sizes

  """Entry point for launching an IPython kernel.


[0, 10616, 21336, 30588]