In [45]:
import sys
import os
import torchvision
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

import torchvision.transforms as transforms
import torch
from torch.utils.data.sampler import Sampler
import numpy as np
import itertools
import random

from torchvision.datasets import SVHN, MNIST, FashionMNIST, CIFAR10, CelebA, Omniglot

In [2]:
def make_ssl_cifar_data_loaders(
        data_path,
        label_path,
        labeled_batch_size,
        unlabeled_batch_size,
        num_workers, 
        transform_train, 
        transform_test, 
        use_validation=True, 
        ):

    if use_validation:
        print("Using train + validation")
        train_dir = os.path.join(data_path, "train")
        test_dir = os.path.join(data_path, "val")
    else:
        train_dir = os.path.join(data_path, "train+val")
        test_dir = os.path.join(data_path, "test")
    train_set = torchvision.datasets.ImageFolder(train_dir, transform_train)
    test_set = torchvision.datasets.ImageFolder(test_dir, transform_test)

    with open(label_path) as f:
        labels = dict(line.split(' ') for line in f.read().splitlines())
    labeled_idxs, unlabeled_idxs, num_classes = relabel_dataset(train_set, labels)
    assert len(train_set.imgs) == len(labeled_idxs) + len(unlabeled_idxs)

    print("Num classes", num_classes)
    print("Labeled data: ", len(labeled_idxs))
    print("Unlabeled data:", len(unlabeled_idxs))

    batch_sampler = LabeledUnlabeledBatchSampler(
            labeled_idxs, unlabeled_idxs, labeled_batch_size, unlabeled_batch_size)

    train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=labeled_batch_size+unlabeled_batch_size,
            shuffle=False,
            num_workers=2*num_workers,  # Needs images twice as fast
            pin_memory=True,
            drop_last=False)

    return train_loader, test_loader, num_classes


#PAVEL: relabels the dataset using the labels file.
def relabel_dataset(dataset, labels):
    num_classes = 0
    unlabeled_idxs = []
    for idx in range(len(dataset.imgs)):
        path, _ = dataset.imgs[idx]
        filename = os.path.basename(path)
        if filename in labels:
            label_idx = dataset.class_to_idx[labels[filename]]
            if label_idx > num_classes:
                num_classes = label_idx
            dataset.imgs[idx] = path, label_idx
            del labels[filename]
        else:
            dataset.imgs[idx] = path, NO_LABEL
            unlabeled_idxs.append(idx)

    num_classes += 1

    if len(labels) != 0:
        message = "List of unlabeled contains {} unknown files: {}, ..."
        some_missing = ', '.join(list(labels.keys())[:5])
        raise LookupError(message.format(len(labels), some_missing))

    labeled_idxs = sorted(set(range(len(dataset.imgs))) - set(unlabeled_idxs))

    return labeled_idxs, unlabeled_idxs, num_classes


In [8]:
class LabeledUnlabeledBatchSampler(Sampler):
    """Minibatch index sampler for labeled and unlabeled indices. 

    An epoch is one pass through the labeled indices.
    """
    def __init__(
            self, 
            labeled_idx, 
            unlabeled_idx, 
            labeled_batch_size, 
            unlabeled_batch_size):

        self.labeled_idx = labeled_idx
        self.unlabeled_idx = unlabeled_idx
        self.unlabeled_batch_size = unlabeled_batch_size
        self.labeled_batch_size = labeled_batch_size

        assert len(self.labeled_idx) >= self.labeled_batch_size > 0
        assert len(self.unlabeled_idx) >= self.unlabeled_batch_size > 0

    @property
    def num_labeled(self):
        return len(self.labeled_idx)

    def __iter__(self):
        print("Balle balle")
        labeled_iter = iterate_once(self.labeled_idx)
        unlabeled_iter = iterate_eternally(self.unlabeled_idx)
        return (
            labeled_batch + unlabeled_batch
            for (labeled_batch, unlabeled_batch)
            in  zip(batch_iterator(labeled_iter, self.labeled_batch_size),
                    batch_iterator(unlabeled_iter, self.unlabeled_batch_size))
        )

    def __len__(self):
        return len(self.labeled_idx) // self.labeled_batch_size


def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def batch_iterator(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    args = [iter(iterable)] * n
    return zip(*args)


class TransformTwice:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2

In [27]:
NO_LABEL = -1

transform_train = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

transform_test = transforms.Compose([
    transforms.ToTensor()
])

transform_train = TransformTwice(transform_train)

trainloader, testloader, _ = make_ssl_cifar_data_loaders(
    data_path='/scratch/rm5708/ml/ML_Project/flowgmm-public/data/images/cifar/cifar10/by-image/',
    label_path='/scratch/rm5708/ml/ML_Project/flowgmm-public/data/labels/cifar10/1000_balanced_labels/00.txt',
    labeled_batch_size=32,
    unlabeled_batch_size=32,
    num_workers=2, 
    transform_train=transform_train, 
    transform_test=transform_test, 
    use_validation=False,
)

Num classes 10
Labeled data:  1000
Unlabeled data: 49000


In [28]:
s = iter(trainloader)

Balle balle
Balle balle


In [29]:
bs = next(s)

In [43]:
for ((x1, x2), y) in trainloader:
    print(x1.shape)
    print(x2.shape)
    print(y.shape)
    print((x1 == x2).all())
    break

Balle balle
Balle balle
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([64])
tensor(True)


In [31]:
train_set = torchvision.datasets.ImageFolder('/scratch/rm5708/ml/ML_Project/flowgmm-public/data/images/cifar/cifar10/by-image/', transform_train)

In [None]:
class Dataset():
    def __init__(self, config: dict):
        self.data_keys = set(['mnist', 'fashionmnist', 'cifar', 'svhn'])
        self.config = config
        self.labeled_ids = []
        self.unlabeled_ids = []
        self.image_tensors = []
        self.labels = []
    
    def prepare(self, in_data='mnist', indata_size=5000, outdata_size=1700, label_ratio=0.1):
        print(self.data_keys)
        self.data_keys.remove(in_data)
        # Prepare OOD data
        for k in self.data_keys:
            dataset = config[k]['dataset']
            transforms = config[k]['transforms']
            start_id = len(self.labels)
            end_id = start_id + int(label_ratio * outdata_size)
            for i, (img, _) in enumerate(dataset):
                if i == outdata_size:
                    break
                img_tensor = transforms(img)
                self.image_tensors.append(img_tensor)
            self.labels += [0] * (int(label_ratio * outdata_size))
            self.labels += [-1] * (int((1 - label_ratio) * outdata_size))
            self.labeled_ids += range(start_id, end_id)
            self.unlabeled_ids += range(end_id, len(self.labels))
        
        # Prepare ID data
        dataset = config[in_data]['dataset']
        transforms = config[in_data]['transforms']
        start_id = len(self.labels)
        end_id = start_id + int(label_ratio * indata_size)
        for i, (img, _) in enumerate(dataset):
            if i == indata_size:
                break
            img_tensor = transforms(img)
            self.image_tensors.append(img_tensor)
        self.labels += [1] * (int(label_ratio * indata_size))
        self.labels += [-1] * (int((1 - label_ratio) * indata_size))
        self.labeled_ids += range(start_id, end_id)
        self.unlabeled_ids += range(end_id, len(self.labels))
        
        random.shuffle(self.labeled_ids)
        random.shuffle(self.unlabeled_ids)
    
    def __len__(self):
        return len(self.image_tensors)
    
    def __getitem__(self, idx):
        return self.image_tensors[idx], self.labels[idx]

In [47]:
root = "/scratch/rm5708/ml/ML_Project/"

data_dir = os.path.join(root, 'data')

In [53]:
mnist_dataset = MNIST(root=data_dir, train=True, download=True)

In [56]:
mnist_ids = np.random.randint(1000, size=60000)

In [58]:
mnist_labeled_ids = np.random.choice(mnist_ids, int(1000 * 0.02))

ValueError: only one element tensors can be converted to Python scalars