In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import timm
import numpy as np
from torch.utils.data import DataLoader

from timm.data import Mixup

def kl_divergence(p, q):
    epsilon = 1e-10
    p = p + epsilon
    q = q + epsilon
    kl = torch.sum(p * torch.log(p / q))
    return kl.item()

# Set default dtype to fp32
torch.set_default_dtype(torch.float32)

# Define mixup function
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=1000)

# Data transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Normalization with mean and std for ImageNet
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load dataset
train_dataset = datasets.ImageFolder('/data/ILSVRC2012/train', transform=train_transform)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

# Load ViT model
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.eval()
model.cuda()

# Initialize clusters
clusters = []

# Set threshold k (adjust as needed)
k = 0.1

# Process data
for batch_idx, (images, labels) in enumerate(train_loader):
    # Limit number of batches for testing
    if batch_idx >= 10:
        break

    # Apply mixup
    images, labels = mixup_fn(images, labels)

    images = images.cuda()

    with torch.no_grad():
        # Get patch embeddings
        x = model.patch_embed(images)  # x is [B, num_patches, embed_dim]

        B, num_patches, embed_dim = x.shape

        # For each image in the batch
        for i in range(B):
            # For each patch in the image
            for j in range(num_patches):
                patch_vec = x[i, j, :]  # Shape: [embed_dim]

                # Normalize to get distribution
                p = patch_vec
                p = p - p.min()  # Ensure non-negative
                p = p / p.sum()  # Sum to 1

                # Ensure no zeros
                p = p + 1e-10
                p = p / p.sum()

                assigned = False

                # For each cluster
                for cluster in clusters:
                    q = cluster['center']
                    kl = kl_divergence(p, q)

                    if kl <= k:
                        # Assign to cluster
                        cluster['elements'].append(p)
                        cluster['distances'].append(kl)
                        # Update cluster center
                        n = len(cluster['elements'])
                        cluster['center'] = (cluster['center'] * (n - 1) + p) / n
                        assigned = True
                        break

                if not assigned:
                    # Create new cluster
                    clusters.append({
                        'center': p,
                        'elements': [p],
                        'distances': []
                    })

    print(f"Processed batch {batch_idx+1}")

# After processing, rank clusters by number of elements
clusters_sorted = sorted(clusters, key=lambda x: len(x['elements']), reverse=True)

# For the top cluster
top_cluster = clusters_sorted[0]
num_elements = len(top_cluster['elements'])
print(f"Top cluster has {num_elements} elements.")

# Compute average and standard deviation of distances
if top_cluster['distances']:
    distances = top_cluster['distances']
    avg_distance = sum(distances) / len(distances)
    std_distance = np.std(distances)
else:
    avg_distance = 0
    std_distance = 0

print(f"Average distance to center: {avg_distance}")
print(f"Standard deviation of distances: {std_distance}")

# Output the 384-dimensional cluster center
cluster_center = top_cluster['center']
print("Cluster center values:")
print(cluster_center.cpu().numpy())


In [2]:
import os
import random
import numpy as np
from collections import Counter
from PIL import Image
from tqdm import tqdm

# 데이터 경로 설정
data_path = '/data/ILSVRC2012/train'
output_path = 'output_patch.npy'

# 클래스 폴더 확인
class_folders = [os.path.join(data_path, folder) for folder in os.listdir(data_path) if os.path.isdir(os.path.join(data_path, folder))]


# 각 클래스에서 10개의 랜덤 샘플 선택
random.seed(42)  # 재현성을 위해 시드 설정
selected_images = []
for folder in tqdm(class_folders, desc="Selecting images from class folders"):
    images = [os.path.join(folder, file) for file in os.listdir(folder) if file.endswith(('.JPEG', '.png'))]
    selected_images.extend(random.sample(images, min(10, len(images))))



# 최빈값 계산 함수 정의
def calculate_mode_pixel(image_paths):
    pixel_list = []

    for image_path in tqdm(image_paths, desc="Processing images for mode calculation"):
        with Image.open(image_path) as img:
            img = img.convert('RGB')  # RGB로 변환
            pixels = np.array(img).reshape(-1, 3)  # 픽셀을 [R, G, B] 형태로 변환
            pixel_list.extend([tuple(pixel) for pixel in pixels])

    # 픽셀 최빈값 계산
    pixel_counter = Counter(pixel_list)
    print("Pixel frequency table:")
    for pixel, count in pixel_counter.most_common(30):  # 상위 10개만 출력
        print(f"Pixel: {pixel}, Count: {count}")

    most_common_pixel = pixel_counter.most_common(1)[0][0]
    return most_common_pixel

# 최빈값 계산
total_most_common_pixel = calculate_mode_pixel(selected_images)
print(f"Most common pixel: {total_most_common_pixel}")

# 최빈값으로 이루어진 이미지 패치 생성
def create_patch(most_common_pixel, size=(256, 256)):
    patch = np.full((size[0], size[1], 3), most_common_pixel, dtype=np.uint8)
    return patch

patch = create_patch(total_most_common_pixel)

# 넘파이 파일로 저장
np.save(output_path, patch)
print(f"Patch saved to {output_path}")

# 시각화를 위해 PIL로 저장 (옵션)
Image.fromarray(patch).save('patch_visualization.png')
print("Visualization image saved as 'patch_visualization.png'")


Selecting images from class folders: 100%|██████████| 1000/1000 [00:01<00:00, 915.50it/s]
Processing images for mode calculation: 100%|██████████| 10000/10000 [33:55<00:00,  4.91it/s] 


Pixel frequency table:
Pixel: (255, 255, 255), Count: 34984175
Pixel: (0, 0, 0), Count: 20129399
Pixel: (1, 1, 1), Count: 5483754
Pixel: (254, 254, 254), Count: 4642243
Pixel: (2, 2, 2), Count: 2891657
Pixel: (253, 253, 253), Count: 2116617
Pixel: (3, 3, 3), Count: 1753231
Pixel: (4, 4, 4), Count: 1593433
Pixel: (5, 5, 5), Count: 1371965
Pixel: (6, 6, 6), Count: 1283805
Pixel: (7, 7, 7), Count: 1217280
Pixel: (8, 8, 8), Count: 1199120
Pixel: (252, 252, 252), Count: 1086824
Pixel: (255, 255, 253), Count: 1011988
Pixel: (0, 0, 2), Count: 992489
Pixel: (9, 9, 9), Count: 977431
Pixel: (10, 10, 10), Count: 906612
Pixel: (13, 13, 13), Count: 898993
Pixel: (11, 11, 11), Count: 869029
Pixel: (246, 239, 247), Count: 824321
Pixel: (254, 255, 255), Count: 789671
Pixel: (12, 12, 12), Count: 786587
Pixel: (251, 251, 251), Count: 739224
Pixel: (16, 16, 16), Count: 737744
Pixel: (14, 14, 14), Count: 727069
Pixel: (15, 15, 15), Count: 682276
Pixel: (17, 17, 17), Count: 654971
Pixel: (1, 0, 0), Count: 