In [None]:
'''
Create a two-class MNIST/FashionMNIST dataset with a 90-degree rotation attribute.
You can create bias in the class label (denoted by Y), in the rotation (denoted by A), or in both.

Config:
    dataset : 'fmnist', 'mnist'
    bias_ratio : major/minor ratio (e.g. 4 gives 4:1); must be >= 1.
    bias_factor: 'None', 'Y', 'A', 'Both'
    class_list : two distinct class labels (int) to filter (e.g. [2, 8] for MNIST).
                Note that order matters: [major, minor].
    random_seed : random seed (default 0)
'''


import torch
from numpy import random

# config
dataset = 'mnist'
data_dir = f"./{dataset}/original"
bias_ratio = 4
bias_factor = 'Y'
class_list = [1, 6]
random_seed = 0

# Validate the config up front so a typo fails here, before any data is
# downloaded, instead of in a later cell.
if bias_factor not in ('None', 'Y', 'A', 'Both'):
    raise ValueError(f"bias_factor should be one of 'None', 'Y', 'A', 'Both', got {bias_factor!r}")
if bias_ratio < 1:
    raise ValueError(f'bias_ratio is a major:minor ratio and should be >= 1, got {bias_ratio}')
if len(class_list) != 2 or class_list[0] == class_list[1]:
    raise ValueError(f'class_list should hold two distinct labels [major, minor], got {class_list}')

# random seed (torch and numpy)
torch.manual_seed(random_seed)
random.seed(random_seed)

In [None]:
import torchvision.transforms as tf
import torchvision.datasets as datasets
from os import path, mkdir, makedirs



# load data
# Map each supported config value to its torchvision dataset class.
dataset_classes = {'mnist': datasets.MNIST, 'fmnist': datasets.FashionMNIST}
if dataset not in dataset_classes:
    raise NotImplementedError("Dataset should be mnist or fmnist")

# Make sure the configured download root exists (the directory actually passed
# as `root` below, not a stray ./original in the working directory).
makedirs(data_dir, exist_ok=True)

DatasetClass = dataset_classes[dataset]
# NOTE: the ToTensor transform is only applied by __getitem__; later cells read
# the raw `.data` / `.targets` tensors, so it does not affect them.
train = DatasetClass(root=data_dir, train=True, transform=tf.ToTensor(), download=True)
test = DatasetClass(root=data_dir, train=False, transform=tf.ToTensor(), download=True)




def get_savename(bias_factor, bias_ratio, class_list, dataset):
    '''
    Build the save-folder name '<dataset>_<classes>_<setting>' and print it.

    <classes> concatenates the labels in class_list (e.g. [1, 6] -> '16').
    <setting> is 'Unbiased' when bias_factor is 'None', otherwise
    '<bias_factor>_<bias_ratio>1' (e.g. 'Y_41' for a 4:1 bias on Y).

    Raises:
        NotImplementedError: for an unknown bias_factor.
    '''
    if bias_factor == 'None':
        setting = 'Unbiased'
    elif bias_factor in ('Y', 'A', 'Both'):
        setting = f'{bias_factor}_{bias_ratio}1'
    else:
        raise NotImplementedError

    savename = f"{dataset}_{''.join(map(str, class_list))}_{setting}"
    print(f'Savename of the dataset: {savename}')
    return savename

savename = get_savename(bias_factor=bias_factor, bias_ratio=bias_ratio,
                        class_list=class_list, dataset=dataset)


In [None]:
# create bias on Y 

# ====== Functions =========================================================
def filter_class_with_bias(data, targets, class_list, ratio = 4):
    '''
    Filter two classes and return a (ratio:1) biased subset of them.

    class_list[0] is the major class and class_list[1] the minor class.
    Normally every major sample is kept and the first len(major) / ratio minor
    samples are taken. If the minor class is too small for that, the major
    class is truncated to len(minor) * ratio samples instead, so the requested
    ratio still holds (e.g. ratio=1 really gives a balanced pair).

    Args:
        data: samples indexed along dim 0 (converted with torch.Tensor if needed).
        targets: class label per sample (converted with torch.Tensor if needed).
        class_list: [major_label, minor_label].
        ratio: major:minor ratio.

    Returns:
        (major_class, major_class_label, minor_class, minor_class_label)
    '''
    # Convert each input independently; previously `targets` was only
    # converted when `data` was not already a tensor.
    if not isinstance(data, torch.Tensor):
        data = torch.Tensor(data)
    if not isinstance(targets, torch.Tensor):
        targets = torch.Tensor(targets)

    mask_major = (targets == class_list[0])
    mask_minor = (targets == class_list[1])
    count_major = int(mask_major.sum().item())
    count_minor = int(mask_minor.sum().item())

    print(f'\nCreated dataset with {ratio}:1 ratio on Y')
    print('original count of major: ', count_major)
    print('original count of minor: ', count_minor)

    # Keep all major samples unless the minor class cannot support ratio:1,
    # then size the minor class from the (possibly truncated) major class.
    num_major = min(count_major, int(count_minor * ratio))
    num_minor = min(count_minor, int(num_major / ratio))

    major_class = data[mask_major][:num_major]
    major_class_label = targets[mask_major][:num_major]

    minor_class = data[mask_minor][:num_minor]
    minor_class_label = targets[mask_minor][:num_minor]

    print('count of major: ', major_class_label.shape[0])
    print('count of minor: ', minor_class_label.shape[0])
    print('\n')

    return major_class, major_class_label, minor_class, minor_class_label
# ============================================================================



# filter class and create bias on train/test dataset
# Only the 'Y' and 'Both' settings bias the label; otherwise a ratio of 1 is used.
if bias_factor in ('Y', 'Both'):
    Y_bias = bias_ratio
else:
    Y_bias = 1

train_img, train_label = train.data, train.targets
test_img, test_label = test.data, test.targets

(train_data_major, train_Y_major,
 train_data_minor, train_Y_minor) = filter_class_with_bias(train_img, train_label, class_list, Y_bias)

(test_data_major, test_Y_major,
 test_data_minor, test_Y_minor) = filter_class_with_bias(test_img, test_label, class_list, Y_bias)

In [None]:
# create bias on A (rotation)

import matplotlib.pyplot as plt
%matplotlib inline


# ====== Functions =========================================================
def plot_img(data, Y, A, num):
    '''
    Visual sanity check: show `num` images drawn at random (with replacement)
    from `data`, each titled with its Y and A values.

    Args:
        data: images indexed along dim 0.
        Y: per-sample class labels aligned with `data`.
        A: per-sample attribute labels aligned with `data`.
        num: number of plots
    '''
    picks = torch.randint(0, data.shape[0], (num,))
    for pick in picks:
        plt.imshow(data[pick], cmap='gray')
        plt.title(f'Y={Y[pick]} A={A[pick]}')
        plt.show()
     

def rotate_90(data, bias):
    '''
    Split `data` into rotated and clean parts at a (bias:1) clean:rotated ratio.

    The first data.shape[0] // (bias + 1) samples are rotated 90 degrees
    counterclockwise and kept at the front; the remaining samples are left
    unchanged and follow them.

    Returns:
        (images, A) where A is a default-dtype float tensor with
        A = 1: clean
        A = 0: rotated
    '''
    n_total = data.shape[0]
    n_rotated = n_total // (bias + 1)
    to_rotate, clean = data[:n_rotated], data[n_rotated:]
    rotated = tf.functional.rotate(to_rotate, 90)

    A = torch.zeros((n_total,))
    A[n_rotated:] = 1

    print(f'Created dataset with {bias}:1 ratio on A')
    print(f'Clean: {n_total - n_rotated}\nRotated: {n_rotated}\n')
    return torch.cat([rotated, clean], dim=0), A

# ============================================================================

# create rotation bias on train/test dataset
# Only the 'A' and 'Both' settings bias the rotation; otherwise a ratio of 1 is used.
if bias_factor in ('A', 'Both'):
    A_bias = bias_ratio
else:
    A_bias = 1

# Each class group gets its own clean/rotated split.
train_data_major, train_A_major = rotate_90(train_data_major, A_bias)
train_data_minor, train_A_minor = rotate_90(train_data_minor, A_bias)

test_data_major, test_A_major = rotate_90(test_data_major, A_bias)
test_data_minor, test_A_minor = rotate_90(test_data_minor, A_bias)

In [None]:
# augment 6x to mimic original MNIST/FMNIST size 

# ====== Functions =========================================================
def apply_per_sample(transform, batch):
    '''
    Apply `transform` to every sample of `batch` independently.

    torchvision's random transforms draw their random parameters once per call,
    so calling one on a whole (N, H, W) tensor applies the *same* rotation/shift
    to every image. Calling it on one-sample slices gives each image its own
    random parameters (at the cost of one transform call per image).

    Args:
        transform: callable taking a tensor whose first dim is 1; assumed to
            return a tensor of the same shape.
        batch: tensor whose first dim indexes samples.

    Returns:
        The per-sample results concatenated along dim 0 (an empty clone for an
        empty batch).
    '''
    if batch.shape[0] == 0:
        return batch.clone()
    # batch[i:i + 1] keeps a leading dim of 1 so the transform sees one image.
    return torch.cat([transform(batch[i:i + 1]) for i in range(batch.shape[0])], dim=0)


def rot_aug(train, Y, A):
    '''
    Rotate clockwise/counterclockwise slightly, with a random angle per image.
    Includes plot test of random data.

    Args:
        train: image tensor, samples along dim 0.
        Y, A: per-sample label tensors aligned with `train`.

    Returns:
        3x augmented data: (data_3x, data_3x_Y, data_3x_A) with
        [original, clockwise, counterclockwise] concatenated along dim 0 and
        Y/A repeated to match.
    '''
    # torchvision rotation angles are counterclockwise for positive degrees.
    transforms_rotate_ccw = tf.RandomRotation(degrees=(5, 10))
    transforms_rotate_cw = tf.RandomRotation(degrees=(-10, -5))
    ccw_rot_data = apply_per_sample(transforms_rotate_ccw, train)
    cw_rot_data = apply_per_sample(transforms_rotate_cw, train)

    data_3x = torch.cat([train, cw_rot_data, ccw_rot_data], dim=0)
    data_3x_Y = Y.repeat(3)
    data_3x_A = A.repeat(3)
    print(f'With Rotation, Augmented data size: {data_3x.shape[0]}')

    def plot_rot_aug(i):
        '''
        Plot test for Rotation Augmentation: one sample and its two rotations.
        '''
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)

        ax1.set_title('original')
        ax1.imshow(train[i].squeeze(), cmap='gray')

        ax2.set_title('rotate_right')
        ax2.imshow(cw_rot_data[i].squeeze(), cmap='gray')

        ax3.set_title('rotate_left')
        ax3.imshow(ccw_rot_data[i].squeeze(), cmap='gray')

        fig.tight_layout()
        plt.show()

    plot_rot_aug(int(torch.randint(0, train.shape[0], (1,)).item()))

    return data_3x, data_3x_Y, data_3x_A


def shift_aug(train, Y, A):
    '''
    Randomly shift horizontally/vertically data a little bit (up to 10% of the
    width/height), with a random offset per image.
    Includes plot test of random data.

    Args:
        train: image tensor, samples along dim 0.
        Y, A: per-sample label tensors aligned with `train`.

    Returns:
        2x augmented data: (data_2x, data_2x_Y, data_2x_A) with
        [original, shifted] concatenated along dim 0 and Y/A repeated to match.
    '''
    transforms_shift = tf.RandomAffine(degrees=0, translate=(0.1, 0.1))
    shift_data = apply_per_sample(transforms_shift, train)

    data_2x = torch.cat([train, shift_data], dim=0)
    data_2x_Y = Y.repeat(2)
    data_2x_A = A.repeat(2)
    print(f'With Random Shift, Augmented data size: {data_2x.shape[0]}')

    def plot_shift_aug(i):
        '''
        Plot test for Shift Augmentation: one sample and its shifted copy.
        '''
        fig, (ax1, ax2) = plt.subplots(1, 2)

        ax1.set_title('original')
        ax1.imshow(train[i].squeeze(), cmap='gray')

        ax2.set_title('shifted')
        ax2.imshow(shift_data[i].squeeze(), cmap='gray')

        fig.tight_layout()
        plt.show()

    plot_shift_aug(int(torch.randint(0, train.shape[0], (1,)).item()))

    return data_2x, data_2x_Y, data_2x_A

def shuffle_with_A(data, Y, A):
    '''
    Shuffle the whole dataset: apply one shared random permutation to data,
    Y and A along dim 0 so they stay aligned.
    '''
    perm = torch.randperm(data.size(0))
    return tuple(tensor.index_select(0, perm) for tensor in (data, Y, A))
# ============================================================================


# putting everything together: merge minor and major groups, then shuffle
train_data = torch.cat([train_data_minor, train_data_major], dim=0)
train_Y = torch.cat([train_Y_minor, train_Y_major], dim=0)
train_A = torch.cat([train_A_minor, train_A_major], dim=0)
train_data, train_Y, train_A = shuffle_with_A(train_data, train_Y, train_A)

test_data = torch.cat([test_data_minor, test_data_major], dim=0)
test_Y = torch.cat([test_Y_minor, test_Y_major], dim=0)
test_A = torch.cat([test_A_minor, test_A_major], dim=0)
test_data, test_Y, test_A = shuffle_with_A(test_data, test_Y, test_A)


# Augment 6x: 3x with small rotations, then 2x with random shifts.
# If a split is still smaller than its original MNIST/FMNIST size
# (60k train / 10k test), e.g. because a Y bias shrank the minor class,
# shift-augment it 2x more.
train_data, train_Y, train_A = rot_aug(train_data, train_Y, train_A)
train_data, train_Y, train_A = shift_aug(train_data, train_Y, train_A)
if train_data.size(0) < 60000:
    train_data, train_Y, train_A = shift_aug(train_data, train_Y, train_A)


test_data, test_Y, test_A = rot_aug(test_data, test_Y, test_A)
test_data, test_Y, test_A = shift_aug(test_data, test_Y, test_A)
# compare against the 10k test target, not the 60k train target
if test_data.size(0) < 10000:
    test_data, test_Y, test_A = shift_aug(test_data, test_Y, test_A)

In [None]:
# random sample to mimic original dataset size of MNIST/FashionMNIST.

def print_data_info(Y, A, class_list):
    '''
    Print the dataset size and the clean/rotated counts of each class.

    Args:
        Y: class label per sample.
        A: attribute per sample (1 = clean, 0 = rotated).
        class_list: [major, minor] class labels.
    '''
    major, minor = class_list

    def count_groups(label):
        in_class = Y == label
        return (in_class & (A == 1)).sum(), (in_class & (A == 0)).sum()

    cln_major, rot_major = count_groups(major)
    cln_minor, rot_minor = count_groups(minor)

    print(f'Total Dataset Size: {Y.shape[0]}')
    for role, label, clean, rotated in (('Major', major, cln_major, rot_major),
                                        ('Minor', minor, cln_minor, rot_minor)):
        print(f'\t{role} Class {label}: {clean + rotated}')
        print(f'\t\tClean: {clean}')
        print(f'\t\tRotated: {rotated}')

def rand_sample_with_A(data, Y, A, size):
    '''
    Random sampling of given size from the dataset, without replacement,
    keeping data, Y and A aligned.

    If the dataset has fewer than `size` samples, a message is printed and the
    inputs are returned unchanged.
    '''
    total = data.size(0)
    if total >= size:
        keep = torch.randperm(total)[:size]
        return tuple(tensor.index_select(0, keep) for tensor in (data, Y, A))

    print("Target size is bigger than current dataset size")
    return data, Y, A


# 60k for train data, 10k for test data (original MNIST/FMNIST split sizes).
# Each split is checked on its own, so one split being too small does not
# stop the other from being sampled down to its target size.
if train_data.size(0) > 60000:
    train_data, train_Y, train_A = rand_sample_with_A(train_data, train_Y, train_A, 60000)
if test_data.size(0) > 10000:
    test_data, test_Y, test_A = rand_sample_with_A(test_data, test_Y, test_A, 10000)

# check final group size with bias is correct
print_data_info(train_Y, train_A, class_list)
print_data_info(test_Y, test_A, class_list)

In [None]:
# save directory
savedir = path.join('rotated', savename)
if not path.exists(savedir):
    makedirs(savedir)

# Save each tensor with the legacy (non-zipfile) serializer: train split first,
# then test, each as <split>_data.pt, <split>_Y.pt, <split>_A.pt.
for split, tensors in (('train', (train_data, train_Y, train_A)),
                       ('test', (test_data, test_Y, test_A))):
    for suffix, tensor in zip(('data', 'Y', 'A'), tensors):
        torch.save(tensor, path.join(savedir, f'{split}_{suffix}.pt'),
                   _use_new_zipfile_serialization=False)