In [1]:
from torchvision.transforms import transforms 
from data_aug.gaussian_blur import GaussianBlur
from torchvision import transforms, datasets
from data_aug.view_generator import ContrastiveLearningViewGenerator
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
from exceptions.exceptions import InvalidDatasetSelection
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import torch
from torch.utils.data import ConcatDataset

In [2]:
def get_simclr_pipeline_transform(size, s=1):
        """Return a list of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.RandomApply([color_jitter], p=0.8),
                                                transforms.RandomGrayscale(p=0.2),
                                                GaussianBlur(kernel_size=int(0.1*size)),
                                                transforms.ToTensor()])
        return data_transforms

In [None]:
ContrastiveLearningViewGenerator(
                                                                    self.get_simclr_pipeline_transform(32),
                                                                    n_views)

In [14]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,
                                train=False,
                           transform=ContrastiveLearningViewGenerator(
                            get_simclr_pipeline_transform(32),
                            2))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=len(dataset_test), shuffle=True, num_workers=1)

Files already downloaded and verified


In [15]:
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

In [20]:
# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar_data/cifar10.1_v4_data.npy')

In [21]:
data_T = np.transpose(data_new, [0,3,1,2])
TT = transforms.Compose(ContrastiveLearningViewGenerator(
                            get_simclr_pipeline_transform(32),
                            2))
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,32,32])
data_T_tensor = torch.from_numpy(data_T)

In [13]:
len(label_all)

10000

In [8]:
data_new

array([[[[230, 219, 216],
         [237, 218, 211],
         [236, 218, 215],
         ...,
         [159, 120, 116],
         [198, 155, 140],
         [213, 181, 167]],

        [[237, 220, 212],
         [251, 214, 206],
         [251, 213, 209],
         ...,
         [ 79,  33,  26],
         [147,  77,  67],
         [181, 126, 112]],

        [[240, 217, 209],
         [255, 209, 203],
         [255, 210, 207],
         ...,
         [ 39,  16,  22],
         [ 77,  27,  36],
         [100,  63,  67]],

        ...,

        [[219, 212, 213],
         [217, 210, 207],
         [211, 213, 206],
         ...,
         [215, 211, 209],
         [219, 210, 204],
         [216, 209, 206]],

        [[255, 255, 255],
         [255, 255, 255],
         [253, 255, 255],
         ...,
         [252, 255, 255],
         [254, 255, 255],
         [255, 255, 255]],

        [[255, 255, 255],
         [252, 253, 254],
         [253, 253, 254],
         ...,
         [254, 254, 254],
        

In [None]:
class CustomCIFAR10_1(Dataset):
    def __init__(self, data_path, transform=None):
        self.images = np.load(data_path, allow_pickle=True)  # Assuming this is the image data array
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample_image = self.images[idx]

        # Convert to PIL Image
        sample_image = Image.fromarray(np.uint8(sample_image))

        if self.transform:
            sample_image = self.transform(sample_image)

        return sample_image  # If you don't have labels, return just the image

    
class ContrastiveLearningDataset:
    def __init__(self, root_folder):
        self.root_folder = root_folder
        
    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a list of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.RandomApply([color_jitter], p=0.8),
                                                transforms.RandomGrayscale(p=0.2),
                                                GaussianBlur(kernel_size=int(0.1*size)),
                                                transforms.ToTensor()])
        return data_transforms
    
    def get_dataset(self, name, n_views):
        valid_datasets = {'cifar10' : lambda: datasets.CIFAR10(self.root_folder, train=False,
                                                                transform = ContrastiveLearningViewGenerator(
                                                                    self.get_simclr_pipeline_transform(32),
                                                                    n_views),
                                                                download=True),
                        'stl10': lambda: datasets.STL10(self.root_finder, split='unlabeled',
                                                        transform = ContrastiveLearningViewGenerator(
                                                            self.get_simclr_pipeline_transform(96),
                                                            n_views),
                                                        download=True),
                        'cifar10_1': lambda: CustomCIFAR10_1('/data4/oldrain123/C2ST/data/cifar_data/cifar10.1_v4_data.npy',
                                                              transform=ContrastiveLearningViewGenerator(
                                                                  self.get_simclr_pipeline_transform(32),
                                                                  n_views))
                        }
        try:
            dataset_fn = valid_datasets[name]
        except KeyError:
            raise InvalidDatasetSelection()
        else:
            return dataset_fn()

In [None]:
from torch.utils.data import ConcatDataset, Subset, DataLoader
from random import sample

root_folder = "/data4/oldrain123/C2ST/data/cifar_data/cifar10"

class LabeledDataset(Dataset):
    def __init__(self, base_dataset, label):
        self.base_dataset = base_dataset
        self.label = label

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image = self.base_dataset[idx]
        return image, self.label

# Initialize CIFAR10 and CIFAR10.1 datasets
cifar10_dataset = datasets.CIFAR10(root_folder, train=False, transform=ContrastiveLearningViewGenerator(ContrastiveLearningDataset.get_simclr_pipeline_transform(32), n_views=2), download=True)
cifar10_1_dataset = CustomCIFAR10_1('/data4/oldrain123/C2ST/data/cifar_data/cifar10.1_v4_data.npy', transform=ContrastiveLearningViewGenerator(ContrastiveLearningDataset.get_simclr_pipeline_transform(32), n_views=2))

# Sample 2000 images from CIFAR10
indices = sample(range(len(cifar10_dataset)), len(cifar10_1_dataset), )
cifar10_subset = Subset(cifar10_dataset, indices)

# Add labels
labeled_cifar10 = LabeledDataset(cifar10_subset, 0)
labeled_cifar10_1 = LabeledDataset(cifar10_1_dataset, 1)

# Combine the two datasets
combined_dataset = ConcatDataset([labeled_cifar10, labeled_cifar10_1])

# Create DataLoader
dataloader = DataLoader(combined_dataset, batch_size=256, shuffle=True)

In [None]:
print(len(cifar10_1_dataset))

In [None]:
256 * 16