**Extract the target-class (digit 8) subset of MNIST**

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Load the MNIST train/test splits from /data/mnist, converting each image to a
# float tensor with ToTensor. `download` is left at its default (off), so the
# dataset files must already be present at that root.
trainset = torchvision.datasets.MNIST("/data/mnist", train=True, transform=transforms.ToTensor())
testset = torchvision.datasets.MNIST("/data/mnist", train=False, transform=transforms.ToTensor())
# batch_size=1: every iteration yields exactly one (image, label) sample.
trainloader = DataLoader(trainset, batch_size=1, num_workers=10)
testloader = DataLoader(testset, batch_size=1, num_workers=10)

In [None]:
# Keep only the training samples of a single digit class and save them as an
# (images, labels) tuple of concatenated tensors.
import os

target_class = 8
images_train, labels_train = [], []
for images, labels in trainloader:
    # The loader yields one sample per batch, so .item() reads its label.
    if labels.item() == target_class:
        images_train.append(images)
        labels_train.append(labels)
train_pt = (torch.cat(images_train), torch.cat(labels_train))
# torch.save does not create missing parent directories.
os.makedirs("./data", exist_ok=True)
torch.save(train_pt, "./data/training.pt")

In [None]:
# Keep only the test samples of a single digit class and save them as an
# (images, labels) tuple of concatenated tensors.
import os

target_class = 8
images_test, labels_test = [], []
for images, labels in testloader:
    # The loader yields one sample per batch, so .item() reads its label.
    if labels.item() == target_class:
        images_test.append(images)
        labels_test.append(labels)
test_pt = (torch.cat(images_test), torch.cat(labels_test))
# torch.save does not create missing parent directories.
os.makedirs("./data", exist_ok=True)
torch.save(test_pt, "./data/test.pt")

**Build a class-balanced MNIST training subset (500 images per digit)**

In [None]:
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

# Reload the full MNIST training split. It is iterated one sample at a time
# (batch_size=1) when drawing the per-class subset.
trainset = torchvision.datasets.MNIST("/data/mnist", train=True, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=1, num_workers=10)

In [None]:
# Keep the first `size` training images of each digit class (0-9), in loader
# order. counter[d] is the number of images of digit d kept so far.
counter = [0 for _ in range(10)]
collection = [[] for _ in range(10)]
size = 500
for image, label in trainloader:
    idx = label.item()
    if counter[idx] < size:
        collection[idx].append(image)
        counter[idx] += 1
    # Stop as soon as every class is full; the rest of the dataset would
    # only be skipped anyway.
    if all(n == size for n in counter):
        break
counter

In [None]:
# For each digit, concatenate its collected images into one tensor and build
# an int64 label vector whose length matches the number of images actually
# collected (rather than assuming a fixed count).
data, labels = [], []
for label in range(10):
    n_images = len(collection[label])
    data.append(torch.cat(collection[label]))
    labels.append(torch.full((n_images,), label, dtype=torch.int64))

In [None]:
# Merge the per-class chunks into a single image tensor and label vector.
data = torch.cat(data)
labels = torch.cat(labels)

In [None]:
# Sanity check: images and labels should share the same first dimension.
(data.shape, labels.shape)

In [None]:
import os

# Output *file* path (not a directory, despite the name) for the balanced
# (data, labels) tuple.
out_dir = "./data/mnist_balanced_subset"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(out_dir), exist_ok=True)
torch.save((data, labels), out_dir)

**Compare original images x with their transformed versions x_t**

In [None]:
# For the first `demo_steps` training images, plot the original x (top-left),
# each transform's output x_t = t(x) along the top row, and the pixel-wise
# difference x_t - x in the bottom row beneath it.
# NOTE(review): assumes each t in T (defined elsewhere) returns a CPU tensor
# with the same (batch, channel, H, W) layout as its input — confirm.
N = len(T)
demo_steps = 5
# Column titles: the original followed by one name per transform. This never
# changes, so build it once; transforms beyond the named ones get a generic
# title instead of raising IndexError.
label_list = ["Original", "Gaussian", "Rotate", "Contrast", "Brightness"]
label_list += [f"Transform {k}" for k in range(len(label_list), N + 1)]
for i, (image, label) in enumerate(trainloader, 1):
    original = image[0][0].numpy()
    xt_list = [t(image) for t in T]
    fig = plt.figure(figsize=(8, 4))
    ax = plt.subplot(2, N + 1, 1)
    ax.set_title(label_list[0])
    plt.imshow(original, cmap="gray")
    plt.axis("off")
    # Subplot columns 2..N+1 hold the transforms, in the order of T.
    for c, xt in enumerate(xt_list, 2):
        transformed = xt[0][0].numpy()
        # Top row: the transformed image.
        ax = plt.subplot(2, N + 1, c)
        ax.set_title(label_list[c - 1])
        plt.imshow(transformed, cmap="gray")
        plt.axis("off")
        # Bottom row, same column: difference from the original.
        ax = plt.subplot(2, N + 1, c + N + 1)
        ax.set_title(f"Diff: {label_list[c - 1]}")
        plt.imshow(transformed - original, cmap="gray")
        plt.axis("off")
    plt.show()
    print("=======================================================")
    if i == demo_steps:
        break