In [337]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy, RandAugment
import torchvision
from PIL import Image

from aug_utils_show import RandomCropPaste, CutMix, MixUp  # 自定义增强方法

In [None]:
import random  # only needed here, for seeding Python's built-in RNG

# Seed every RNG the augmentations might draw from so the saved figures are
# reproducible under Restart & Run All. torchvision's AutoAugment/RandAugment
# sample through torch; the custom aug_utils_show transforms may use torch,
# numpy or `random` (not visible here), so all three are seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

save_dir = "augment_images_all"
os.makedirs(save_dir, exist_ok=True)

# Raw CIFAR-10 test-split images as PIL images (no transform applied), so every
# augmentation below starts from the same untouched source image.
raw_dataset = CIFAR10(root='./data', train=False, download=True)
classes = raw_dataset.classes

# CIFAR-10 images are 32x32, so this Resize is presumably a no-op; it keeps the
# pipeline working if the target size is ever changed.
resize = transforms.Resize((32, 32))
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

autoaugment = AutoAugment(policy=AutoAugmentPolicy.CIFAR10)
# num_ops / magnitude are above torchvision's defaults (2 / 9), presumably so the
# effect is clearly visible in the comparison figures.
randaugment = RandAugment(num_ops=4, magnitude=15)
# Custom transform from aug_utils_show; meaning of size/alpha/flip_p is defined
# there — TODO document once confirmed.
crop_paste = RandomCropPaste(size=32,alpha=1.0,flip_p=0.8)

Files already downloaded and verified


In [None]:
def imshow(img, label, subplot_pos, title):
    """Render one panel of the 2x3 comparison grid on the current figure.

    Args:
        img: anything ``Axes.imshow`` accepts (PIL image or array).
        label: class name shown on the second line of the panel title.
        subplot_pos: 1-based panel index within the 2x3 grid.
        title: panel heading (e.g. the augmentation name).
    """
    # Work on the Axes object directly instead of the implicit "current axes".
    panel_ax = plt.subplot(2, 3, subplot_pos)
    panel_ax.imshow(img)
    panel_ax.set_title(f"{title}\nLabel: {label}")
    panel_ax.axis('off')

In [None]:
# Uses `imshow` from the previous cell; it is deliberately not redefined here,
# since an identical second definition would silently shadow the first.
NUM_PREVIEW = 5  # how many test images get a side-by-side comparison figure

for idx in range(NUM_PREVIEW):
    img, label = raw_dataset[idx]
    label_name = classes[label]

    resized_img = resize(img)
    img_tensor = to_tensor(resized_img)

    # torchvision policies operate on the PIL image.
    aa_img = autoaugment(resized_img)
    ra_img = randaugment(resized_img)
    # The custom RandomCropPaste is applied to the tensor form, then converted
    # back to PIL for display.
    rcp_img_tensor = crop_paste(img_tensor)
    rcp_img = to_pil(rcp_img_tensor)

    # ToTensor -> ToPILImage round trip, shown as its own panel.
    tensor_img = to_pil(img_tensor)

    # (image, panel title) in 2x3 grid order; positions are 1-based.
    panels = [
        (img, "Original"),
        (resized_img, "Resized"),
        (aa_img, "AutoAugment"),
        (ra_img, "RandAugment"),
        (rcp_img, "RandomCropPaste"),
        (tensor_img, "ToTensor"),
    ]

    fig = plt.figure(figsize=(12, 6))
    for pos, (panel_img, panel_title) in enumerate(panels, start=1):
        imshow(panel_img, label_name, pos, panel_title)

    fig.suptitle(f"Image {idx+1} - Class: {label_name}", fontsize=14)
    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'aug_compare_{idx+1}_{label_name}.png'))
    plt.close(fig)  # close this exact figure so memory doesn't grow per iteration

In [None]:
# Custom batch-level augmentations from aug_utils_show. beta / alpha presumably
# parameterize the Beta distribution that samples the mixing ratio — TODO confirm.
cutmix = CutMix(size=32, beta=0.8)
mixup = MixUp(alpha=2.5)

BATCH_SIZE = 8  # number of test images in the demo batch

# Fetch each (PIL image, label) pair once rather than indexing the dataset twice.
batch_samples = [raw_dataset[i] for i in range(BATCH_SIZE)]
batch_imgs = torch.stack([to_tensor(resize(sample_img)) for sample_img, _ in batch_samples])
batch_labels = torch.tensor([sample_label for _, sample_label in batch_samples])

In [None]:
def save_batch_images(batch, labels, title, filename, nrow=4, max_caption_labels=4):
    """Save a batch of images as a single grid figure in ``save_dir``.

    Args:
        batch: image tensor (or list of tensors) accepted by
            ``torchvision.utils.make_grid``; presumably (B, C, H, W) with
            values in [0, 1] — TODO confirm for the custom augmentations.
        labels: either integer class indices (Python ints, numpy integers,
            0-d tensors, or a 1-D tensor), mapped to names via the global
            ``classes``, or ready-made label strings.
        title: first line of the figure title.
        filename: output file name, joined with the global ``save_dir``.
        nrow: number of images per grid row.
        max_caption_labels: how many label names to list in the title;
            " ..." is appended only when some labels are left out.
    """
    grid = torchvision.utils.make_grid(batch, nrow=nrow)
    # detach()/cpu() so grad-carrying or GPU tensors can still go to numpy.
    npimg = grid.detach().cpu().permute(1, 2, 0).numpy()

    if len(labels) == 0:
        label_names = []
    elif isinstance(labels[0], (int, np.integer, torch.Tensor)):
        # Class indices -> human-readable names.
        label_names = [classes[int(l)] for l in labels]
    else:
        label_names = [str(l) for l in labels]

    caption = ", ".join(label_names[:max_caption_labels])
    if len(label_names) > max_caption_labels:
        caption += " ..."

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.imshow(npimg)
    ax.set_title(f"{title}\nLabels: {caption}")
    ax.axis('off')
    fig.savefig(os.path.join(save_dir, filename))
    plt.close(fig)


In [343]:
# Reference grid: the batch before any mixing.
save_batch_images(batch_imgs, batch_labels, "Original Batch", "batch_original.png")

# Both mixers receive cloned inputs so batch_imgs / batch_labels stay untouched.
# Each presumably returns (mixed images, labels, partner labels, mixing ratio)
# — TODO confirm against aug_utils_show.

# CutMix
cutmix_imgs, label_a, label_b, lam = cutmix((batch_imgs.clone(), batch_labels.clone()))
combined_labels_cutmix = [
    f"{classes[first]} + {classes[second]}"
    for first, second in zip(label_a, label_b)
]
save_batch_images(
    cutmix_imgs,
    combined_labels_cutmix,
    f"CutMix (λ={lam:.2f})",
    "batch_cutmix.png",
)

# MixUp (rebinds label_a / label_b to MixUp's output)
mixup_imgs, label_a, label_b, lam2 = mixup((batch_imgs.clone(), batch_labels.clone()))
combined_labels_mixup = [
    f"{classes[first]} + {classes[second]}"
    for first, second in zip(label_a, label_b)
]
save_batch_images(
    mixup_imgs,
    combined_labels_mixup,
    f"MixUp (λ={lam2:.2f})",
    "batch_mixup.png",
)