In [1]:
# Fashion MNIST dataset
import torchvision.transforms as transforms
import torch 
from torchvision import datasets
# Load the FashionMNIST test split (train=False); download=True lets torchvision fetch
# the files into ./.data/ when they are not already there.
testset = datasets.FashionMNIST(
    root='./.data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [2]:
# Restrict the test split to the classes used in this evaluation
# (DataExtractor later treats 2 as normal and 4 / 6 as abnormal).
test_label = [2, 4, 6]
actual_testdata = torch.isin(testset.targets, torch.tensor(test_label))  # boolean keep-mask, one entry per sample
testset.data, testset.targets = testset.data[actual_testdata], testset.targets[actual_testdata]

In [3]:
import torch
import numpy as np
import os
import cv2

class DataExtractor:
    """Sample random test images per class and save them as PNG files.

    Samples of ``normal_label`` are written as ``Normal_{i}.png`` and samples of each
    label in ``abnormal_labels`` as ``Abnormal_{label}_{j}.png``. The defaults
    reproduce the original fixed split: label 2 is normal, labels 4 and 6 abnormal.

    Args:
        test_dataset: object exposing ``.data`` (image tensor indexed by sample along
            dim 0; pixel values assumed to be 0-255 integers) and ``.targets``
            (per-sample label tensor), such as a torchvision FashionMNIST split.
        normal_label: label saved under the ``Normal_`` prefix.
        abnormal_labels: labels saved under the ``Abnormal_`` prefix, in order.
        image_shape: shape each flattened sample is reshaped to before saving; its
            product must equal the number of pixels per sample.
        generator: optional ``torch.Generator`` for reproducible sampling; ``None``
            uses torch's global RNG.
    """

    def __init__(self, test_dataset, normal_label=2, abnormal_labels=(4, 6),
                 image_shape=(28, 28), generator=None):
        self.test_dataset = test_dataset
        self.normal_label = normal_label
        self.abnormal_labels = tuple(abnormal_labels)
        self.image_shape = tuple(image_shape)
        self.generator = generator

    def __call__(self, num_samples=2, save_path=None):
        """Pick ``num_samples`` random images per class and save them under ``save_path``."""
        self.num_samples = num_samples
        self.output = self.pick_random_samples(num_samples)
        self.saver(save_path)

    def saver(self, save_path=None):
        """Write the batches in ``self.output`` as PNGs to ``save_path`` (default ``./data/``).

        ``self.output`` is expected in ``pick_random_samples`` order: the normal
        class first, then one batch per abnormal label.

        Raises:
            OSError: if OpenCV fails to write an image (``cv2.imwrite`` returned False).
        """
        normal_batch, *abnormal_batches = self.output

        if save_path is None:
            save_path = './data/'
        os.makedirs(save_path, exist_ok=True)

        # Normal class
        for i, flat in enumerate(normal_batch):
            self._write_png(os.path.join(save_path, f'Normal_{i}.png'), flat)

        # Abnormal classes
        for cls_idx, batch in zip(self.abnormal_labels, abnormal_batches):
            for j, flat in enumerate(batch):
                self._write_png(os.path.join(save_path, f'Abnormal_{cls_idx}_{j}.png'), flat)

        print(f"✅ {self.num_samples} samples per class saved to: {save_path}")

    def _write_png(self, path, flat):
        """Save one flattened sample to ``path``; raise OSError if OpenCV cannot write it."""
        image = self._to_uint8_image(flat, self.image_shape)
        if not cv2.imwrite(path, image):
            raise OSError(f"cv2.imwrite failed to write {path}")

    @staticmethod
    def _to_uint8_image(flat, image_shape):
        """Convert a flattened float tensor in [0, 1] back to a uint8 array of ``image_shape``."""
        pixels = flat.detach().cpu().reshape(image_shape).numpy() * 255.0
        # Round instead of truncating: k / 255 * 255 in float32 may come out a hair
        # below k (e.g. 254.99998), which a bare astype(np.uint8) would turn into k - 1.
        return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    def pick_random_samples(self, num_samples=2):
        """Return up to ``num_samples`` random images per class, flattened and scaled by 1/255.

        Returns:
            list of float tensors, one per label in the order
            ``[normal_label, *abnormal_labels]``. Each has shape
            ``(min(num_samples, class_count), pixels_per_image)``; a label with no
            samples yields an empty ``(0, pixels_per_image)`` tensor.
        """
        data = self.test_dataset.data
        targets = self.test_dataset.targets

        samples = []
        for label in (self.normal_label, *self.abnormal_labels):
            class_images = data[targets == label]
            # The first num_samples entries of a permutation select the same images as
            # shuffling the whole class and slicing, without copying the whole class.
            chosen = torch.randperm(class_images.size(0), generator=self.generator)[:num_samples]
            picked = class_images[chosen]
            # flatten(start_dim=1) also works for an empty class, where view(0, -1) raises.
            samples.append(picked.flatten(start_dim=1).float() / 255.0)
        return samples

In [4]:
SEED = 0
torch.manual_seed(SEED)  # seed torch's global RNG so randperm picks the same images on every run
save_path = './eval/test_images'
data_extractor = DataExtractor(testset)  # sampler over the filtered test split
data_extractor(num_samples=2, save_path=save_path)  # sample 2 images per class and write PNGs; saver creates the folder itself

✅ 2 samples per class saved to: ./eval/test_images
