In [None]:
# These codes are adapted from Sagawa and Hino's work: https://github.com/ssgw320/gdacnf

In [1]:
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
from scipy import ndimage
from torchvision.datasets import MNIST

In [2]:
# Load the MNIST training split from a local directory; download=False means the
# files must already exist under this path.
your_dataset_path = '../../../dataset/'
dataset = MNIST(root=your_dataset_path, train=True, download=False)

In [3]:
# Images as float32 scaled by 1/255 (assumes 8-bit pixel intensities); labels as a NumPy array.
x = np.array(dataset.data, dtype=np.float32) / 255
y = np.array(dataset.targets)

In [4]:
# Sanity check of the image tensor: (n_samples, height, width).
np.shape(x)

(60000, 28, 28)

In [5]:
# Number of rotated domains to build. Domain 0 is the labelled source used by
# fit_umap; domain domain_num-1 is the last (most rotated) domain.
domain_num = 10

In [6]:
# One rotation angle (in degrees, as ndimage.rotate expects) per domain, evenly spaced from 0 to 45.
angles = np.linspace(start=0, stop=45, num=domain_num)

In [7]:
# Inspect the per-domain rotation angles (degrees).
angles

array([ 0.,  5., 10., 15., 20., 25., 30., 35., 40., 45.])

In [8]:
# Seed the global NumPy RNG, then draw a random ordering of all sample indices.
# For an integer argument, RandomState.permutation is arange + shuffle, so this
# consumes exactly the same random draws as shuffling np.arange in place.
np.random.seed(516)
index = np.random.permutation(len(x))

In [9]:
# Per-domain sample budget: 4000 images for each of the domain_num domains.
# Sized from domain_num (previously a hard-coded 10) so changing the number of
# domains cannot silently desynchronise the split from the angle list.
each_domain_samples = np.full(shape=domain_num, fill_value=4000)
# The shuffled pool must be large enough to give every domain its full quota;
# otherwise np.split would silently return short or empty chunks.
assert each_domain_samples.sum() <= index.size, 'not enough images for the requested domains'

In [10]:
# Cut the shuffled indices into consecutive, disjoint chunks of the requested sizes;
# np.split also returns a final chunk holding whatever indices are left over.
split_points = each_domain_samples.cumsum()
split_index = np.split(index, split_points)

In [11]:
# Show the index chunks: one per domain, plus a trailing chunk of leftover indices
# that the rotation loop never uses (zip with `angles` stops after domain_num entries).
split_index

[array([38972, 44305, 16520, ..., 35665, 41294, 20358]),
 array([13491, 54710,  5359, ..., 50538, 58244, 43819]),
 array([37143, 44241, 56857, ..., 50086,  6239, 21692]),
 array([49735, 52681, 47817, ..., 44538,  7937, 43258]),
 array([13271, 16526, 57466, ..., 10584, 10332, 12208]),
 array([19937, 17041, 30413, ..., 17114, 17120, 55964]),
 array([ 4708, 57633, 23555, ..., 37473, 36202, 18364]),
 array([43671, 57248, 19413, ..., 30287, 14199, 36874]),
 array([30268,  7589,  7231, ..., 18310, 12280, 53294]),
 array([25719, 43509, 19243, ...,  7842, 36761, 19516]),
 array([43708, 18817, 11261, ..., 19405, 48205, 51967])]

In [12]:
x_all, y_all = [], []
# zip stops at the shorter sequence, so the leftover index chunk (no matching angle) is skipped.
for domain_indices, domain_angle in zip(split_index, angles):
    # Every image in a domain is rotated by the same angle; reshape=False keeps the
    # original canvas size so all domains stay shape-compatible.
    domain_images = np.array([
        ndimage.rotate(image, domain_angle, reshape=False)
        for image in x[domain_indices]
    ])
    # Add a singleton channel axis: (n_samples, 1, 28, 28).
    x_all.append(domain_images.reshape(-1, 1, 28, 28))
    y_all.append(y[domain_indices])

In [13]:
# Sanity check: label count of the last domain and image tensor shape of the first domain.
y_all[-1].shape, x_all[0].shape

((4000,), (4000, 1, 28, 28))

In [14]:
# Persist every rotated domain (images + labels) before any dimensionality reduction.
obj = dict(data=x_all, label=y_all)
pd.to_pickle(obj, 'mnist45_original.pkl')

In [15]:
import umap

In [16]:
def fit_umap(x_all, y_all, **umap_kwargs) -> tuple:
    """Embed several domains jointly with semi-supervised UMAP.

    All domains are flattened and stacked into one matrix. Only the first
    domain (the source) contributes its labels; every other sample is marked
    as unlabelled with -1.

    Args:
        x_all: sequence of per-domain arrays; each is flattened to
            (n_samples_i, n_features) after stacking.
        y_all: sequence of per-domain label arrays aligned with ``x_all``.
            Only ``y_all[0]`` is used as supervision.
        **umap_kwargs: overrides for the UMAP defaults
            (``n_components=2, n_neighbors=15, metric='cosine',
            random_state=1234``).

    Returns:
        ``(z_all, encoder)``: ``z_all`` is a list with one embedding array per
        domain, in the same order as ``x_all``; ``encoder`` is the fitted
        ``umap.UMAP`` instance.
    """
    # random_state lives in the overridable defaults. Passing it separately next to
    # **umap_settings made random_state=... in umap_kwargs raise a TypeError.
    umap_settings = dict(n_components=2, n_neighbors=15, metric='cosine', random_state=1234)
    umap_settings.update(umap_kwargs)
    X = np.vstack(x_all)
    X = X.reshape(X.shape[0], -1)
    # Semi-supervised UMAP: keep the source labels, mark all other domains as -1 (unlabelled).
    Y_semi_supervised = [np.full(shape=y.shape[0], fill_value=-1) for y in y_all]
    Y_semi_supervised[0] = y_all[0].copy()
    Y_semi_supervised = np.hstack(Y_semi_supervised)
    # Fit UMAP on the stacked domains.
    encoder = umap.UMAP(**umap_settings)
    Z = encoder.fit_transform(X, Y_semi_supervised)
    # Split the joint embedding back into one block per domain (vsplit leaves a
    # trailing empty block after the last cut point, which is dropped).
    z_idx = np.cumsum([i.shape[0] for i in x_all])
    z_all = np.vsplit(Z, z_idx)[:-1]
    return z_all, encoder

In [17]:
# Embed only the first (labelled source) and the last domain. Index the lists directly:
# np.array(x_all)[...] would copy every domain into one large array just to keep two
# of them, and it fails outright if domain sizes ever differ.
z_all, encoder = fit_umap(
    [x_all[0], x_all[domain_num - 1]],
    [y_all[0], y_all[domain_num - 1]],
    n_components=8,
)

In [18]:
# Save the UMAP embeddings together with the labels of the same two domains (first and last).
obj = dict(data=z_all, label=np.array(y_all)[[0, domain_num - 1]])
pd.to_pickle(obj, 'mnist45.pkl')