In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, USPS, SVHN, CIFAR10, STL10

import tensorflow_datasets as tfds

import os
import pickle
import numpy as np
from scipy.io import loadmat
import PIL
from PIL import Image

from tools.autoaugment import SVHNPolicy, CIFAR10Policy
from tools.randaugment import RandAugment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Home directory of the current user; dataset downloads live under f'{HOME}/.pytorch/...'.
# expanduser('~') returns $HOME when it is set, but unlike os.environ['HOME'] it does
# not raise KeyError on platforms where that variable is missing (e.g. Windows).
HOME = os.path.expanduser('~')

In [3]:
def load_cifar10(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load CIFAR-10 as float tensors scaled to [0, 1] in [N, C, H, W] layout.

    The converted tensors are cached in ``data/cifar10-{split}.pkl`` so the
    torchvision dataset only has to be decoded on the first call.

    Args:
        split: 'train' selects the training set; any other value selects the test set.
        translate: if not None, used as the horizontal and vertical ``translate``
            fraction of ``transforms.RandomAffine`` (rotation is fixed to 0).
        twox: forwarded unchanged to ``myTensorDataset``.
        ntr: if not None, keep only the first ``ntr`` samples.
        autoaug: 'AA' appends ``CIFAR10Policy()`` (AutoAugment), 'RA' appends
            ``RandAugment(3, 4)``. Any other non-None value -- including
            'FastAA', which has no implementation here -- adds no policy and
            emits a warning.
        channels: 3 returns all channels; 1 returns only channel 0.
            NOTE(review): load_svhn/load_mnist_m average the channels instead;
            confirm that keeping only the first channel is intended here.

    Returns:
        ``TensorDataset(x, y)`` when neither ``translate`` nor ``autoaug`` is
        given, otherwise ``myTensorDataset`` wrapping the same tensors together
        with the augmentation pipeline.
    '''
    path = f'data/cifar10-{split}.pkl'

    if os.path.exists(path):
        # NOTE(review): pickle.load can execute arbitrary code; only load cache
        # files that were produced by this function.
        with open(path, 'rb') as f:
            x, y = pickle.load(f)
    else:
        # No `transform` is passed: `.data` exposes the raw uint8 array and
        # dataset transforms are only applied inside __getitem__.
        dataset = CIFAR10(f'{HOME}/.pytorch/CIFAR10', train=(split == 'train'), download=True)
        x = torch.tensor(dataset.data).float() / 255.
        x = x.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
        y = torch.tensor(dataset.targets)
        # Create the cache directory on a fresh checkout instead of failing
        # with FileNotFoundError after the download has finished.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)

    if channels == 1:
        x = x[:, 0:1, :, :]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    # Without data augmentation
    if translate is None and autoaug is None:
        return TensorDataset(x, y)

    # Data augmentation pipeline: tensor -> PIL -> (affine) -> (policy) -> tensor
    pipeline = [transforms.ToPILImage()]
    if translate is not None:
        pipeline.append(transforms.RandomAffine(0, [translate, translate]))
    if autoaug == 'AA':
        pipeline.append(CIFAR10Policy())  # originally SVHNPolicy()
    elif autoaug == 'RA':
        pipeline.append(RandAugment(3, 4))
    elif autoaug is not None:
        import warnings
        warnings.warn(
            f"load_cifar10: unsupported autoaug={autoaug!r}; no augmentation policy applied",
            stacklevel=2,
        )
    pipeline.append(transforms.ToTensor())
    return myTensorDataset(x, y, transform=transforms.Compose(pipeline), twox=twox)




def load_usps(split='train', channels=3):
    '''
    Load USPS as float tensors scaled by 1/255, with the single grey channel
    repeated three times, and cache the result in ``data/usps-{split}.pkl``.

    Args:
        split: 'train' selects the training set; any other value selects the test set.
        channels: 3 returns the replicated three-channel tensor; 1 returns
            only the first channel (all three are identical copies).

    Returns:
        ``TensorDataset(x, y)``.
    '''
    path = f'data/usps-{split}.pkl'
    if not os.path.exists(path):
        dataset = USPS(f'{HOME}/.pytorch/USPS', train=(split=='train'), download=True)
        x, y = dataset.data, dataset.targets
        # NOTE(review): resize_imgs is not defined in this notebook; presumably it
        # rescales every image to 32x32 -- confirm where it is imported from.
        x = torch.tensor(resize_imgs(x, 32))
        x = (x.float()/255.).unsqueeze(1).repeat(1,3,1,1)
        y = torch.tensor(y)
        # Create the cache directory on a fresh checkout instead of failing
        # with FileNotFoundError after the download has finished.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)
    # NOTE(review): pickle.load can execute arbitrary code; only load cache
    # files that were produced by this function.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:,0:1,:,:]
    return TensorDataset(x, y)

In [4]:
def load_svhn(split='train', channels=3):
    """
    Load an SVHN split (downloading it if necessary) as a ``TensorDataset``.

    Pixel values are converted to float32 and divided by 255. When
    ``channels == 1`` the images are averaged over dimension 1, keeping that
    dimension with size one.
    """
    svhn = SVHN(f'{HOME}/.pytorch/SVHN', split=split, download=True)
    images = torch.tensor(svhn.data.astype('float32') / 255.)
    labels = torch.tensor(svhn.labels)
    if channels == 1:
        images = images.mean(1, keepdim=True)
    return TensorDataset(images, labels)

In [5]:
def load_pacs(split='train', translate=None, twox=False, ntr=None, autoaug=None, channels=3):
    '''
    Load one PACS domain from ``./data/PACS`` as a ``TensorDataset``.

    Each domain has 7 classes ('dog', 'elephant', 'giraffe', 'guitar', 'horse',
    'house', 'person'). Images are read with ``ImageFolder``, converted to
    tensors, resized to 224x224 and cached in ``data/pacs-{split}.pkl``.

    Args:
        split: 'photo', 'art', 'cartoon' or 'sketch'; 'train' is an alias for
            the photo domain. An unknown value raises ``KeyError`` when no
            cache file exists for it.
        translate, autoaug: augmentation options. No augmentation pipeline is
            implemented for PACS, so passing either raises
            ``NotImplementedError`` (the function previously fell off the end
            and returned ``None`` in that case).
        twox: accepted for signature parity with ``load_cifar10``; unused.
        ntr: if not None, keep only the first ``ntr`` samples.
        channels: 3 returns all channels; 1 returns only channel 0.

    Returns:
        ``TensorDataset(x, y)``.
    '''
    if translate is not None or autoaug is not None:
        # Fail loudly before loading anything rather than handing None to the caller.
        raise NotImplementedError(
            'load_pacs does not implement data augmentation; '
            'call it with translate=None and autoaug=None'
        )

    pacs_convertor = {
        'train': './data/PACS/photo',
        'photo': './data/PACS/photo',
        'art': './data/PACS/art_painting',
        'cartoon': './data/PACS/cartoon',
        'sketch': './data/PACS/sketch',
    }
    path = f'data/pacs-{split}.pkl'

    if not os.path.exists(path):
        # The transform is only needed on a cache miss, so build it here.
        pacs_transforms_train = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])
        dataset = torchvision.datasets.ImageFolder(pacs_convertor[split], transform=pacs_transforms_train)
        # A single batch spanning the whole dataset materialises every image at once.
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), drop_last=True)
        # The loader already yields tensors; re-wrapping them in torch.tensor()
        # only made an extra copy and triggered a UserWarning.
        x, y = next(iter(train_loader))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump([x, y], f)

    # NOTE(review): pickle.load can execute arbitrary code; only load cache
    # files that were produced by this function.
    with open(path, 'rb') as f:
        x, y = pickle.load(f)
    if channels == 1:
        x = x[:, 0:1, :, :]

    if ntr is not None:
        x, y = x[0:ntr], y[0:ntr]

    return TensorDataset(x, y)

In [6]:
def load_mnist_m(split='train', channels=3):
    """
    Load a pre-pickled MNIST-M split from ``data/mnist_m-{split}.pkl``.

    The pickle is expected to hold an ``(images, labels)`` pair of array-like
    objects supporting ``astype``. Images are converted to float32 and divided
    by 255; with ``channels == 1`` they are averaged over dimension 1 (kept as
    a size-one dimension). Returns a ``TensorDataset``.
    """
    cache_file = f'data/mnist_m-{split}.pkl'
    with open(cache_file, 'rb') as fh:
        raw_images, raw_labels = pickle.load(fh)

    images = torch.tensor(raw_images.astype('float32') / 255.)
    labels = torch.tensor(raw_labels)
    if channels == 1:
        images = images.mean(1, keepdim=True)
    return TensorDataset(images, labels)

In [7]:
# Smoke test: build the un-augmented CIFAR-10 test split.
dataset = load_cifar10('test')

In [9]:
# Draw one random sample to inspect the tensor layout.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)
x, y = next(iter(train_loader))

In [11]:
# Shape of the batch drawn in the previous cell.
x.shape

torch.Size([1, 3, 32, 32])

# MNIST-M

In [11]:
# Smoke test: load the pickled MNIST-M test split.
dataset = load_mnist_m('test')

# SVHN

In [None]:
# Smoke test: download (if needed) and load the SVHN test split.
dataset = load_svhn('test')

Using downloaded and verified file: /home/dongkyu/.pytorch/SVHN/test_32x32.mat


# PACS

In [15]:
# Load the PACS photo domain and pull every sample in one shuffled batch.
dataset = load_pacs(split='photo')
train_size = len(dataset)
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=train_size,
    shuffle=True,
    drop_last=True,
)

x, y = next(iter(train_loader))

In [17]:
# Number of samples in the dataset loaded above.
len(dataset)

1670

In [145]:
from collections import Counter

# Per-class sample counts for the batch drawn above.
# Only convert on the first run: `y` is rebound to a plain list, and a list
# has no .tolist(), so re-executing this cell used to raise AttributeError.
# NOTE(review): the saved Counter output below sums to 3929 while len(dataset)
# printed 1670 -- the cells appear to have been run out of order; re-run top to bottom.
if not isinstance(y, list):
    y = y.tolist()
c = Counter(y)
print(c)

Counter({4: 816, 0: 772, 2: 753, 1: 740, 3: 608, 6: 160, 5: 80})


In [146]:
# First image of the batch and its class index.
sample = x[0]
samplelabel = y[0]

In [147]:
# Class index of the sample selected above.
samplelabel

5

In [148]:
# Convert the sample tensor to a PIL image and write it to disk for a visual check.
topil = transforms.ToPILImage()
image = topil(sample)
image.save('./data/image_test.png')

# Check STL10

In [189]:
def load_stl10(split='train', channels=3, size=(32, 32)):
    """
    Load an STL10 split (downloading it if necessary) as a ``TensorDataset``.

    Pixel values are converted to float32 and divided by 255. The previous
    version passed ``transforms.Resize((32, 32))`` to the dataset, but dataset
    transforms only run inside ``__getitem__``; because the raw ``.data`` array
    is used directly, the images were never resized. The resize is now applied
    to the tensor itself.

    Args:
        split: forwarded to ``torchvision.datasets.STL10``.
        channels: 3 returns all channels; 1 averages over dimension 1 and
            keeps it as a size-one dimension.
        size: target ``(height, width)``; ``None`` keeps the native resolution
            (the behaviour of the previous version).

    Returns:
        ``TensorDataset(x, y)``.
    """
    dataset = STL10(f'{HOME}/.pytorch/STL10', split=split, download=True)
    x = torch.tensor(dataset.data.astype('float32') / 255.)
    y = torch.tensor(dataset.labels)
    if size is not None:
        # 'area' interpolation is adaptive average pooling, which downsamples
        # without aliasing. NOTE(review): transforms.Resize defaults to bilinear,
        # so pixel values differ slightly from what that transform would give.
        x = F.interpolate(x, size=size, mode='area')
    if channels == 1:
        x = x.mean(1, keepdim=True)
    return TensorDataset(x, y)

In [190]:
# Load the STL10 training split and take the first image (no shuffling).
dataset = load_stl10('train')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, drop_last=True)
x, y = next(iter(train_loader))
x = x[-1]  # batch_size is 1, so this is the only image in the batch

Files already downloaded and verified


In [192]:
# Convert the STL10 sample to a PIL image and write it to disk for a visual check.
topil = transforms.ToPILImage()
image = topil(x)
image.save('image_test.png')

# Check CIFAR10

In [196]:
# Load the CIFAR-10 training split and keep the last image of the first batch of two.
dataset = load_cifar10('train')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=2, drop_last=True)
x, y = next(iter(train_loader))
x = x[-1]

In [197]:
# Convert the CIFAR-10 sample to a PIL image and write it to disk for a visual check.
topil = transforms.ToPILImage()
image = topil(x)
image.save('./data/image_test.png')