In [7]:
from datasets import load_dataset, DatasetDict
from torchvision import transforms
from PIL import Image
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [8]:
# Load the dataset and perform the first split (80% train, 20% test + valid)
ds = load_dataset('huggan/wikiart', split='train').train_test_split(train_size=0.80, test_size=0.20, seed=51)

# Split the 20% (test + valid) into 50% test and 50% valid
test_valid = ds['test'].train_test_split(test_size=0.5, seed=51)

# Combine the splits into a new DatasetDict
train_test_valid_dataset = DatasetDict({
    'train': ds['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']
})

# Function to display images from a dataset
def show_images(dataset, num_images=5):
    plt.figure(figsize=(15, 10))
    
    for i in range(num_images):
        # Load image from dataset and apply the transformation
        img = Image.open(dataset[i]['image']).convert('RGB')
        img = transform(img)
        
        # Convert tensor back to image for display
        img = img.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
        img = img * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])  # Denormalize
        img = img.clamp(0, 1)  # Ensure the values are between 0 and 1
        
        # Plot the image
        plt.subplot(1, num_images, i+1)
        plt.imshow(img)
        plt.axis('off')
    
    plt.show()

# Display a few images from the training split
show_images(train_test_valid_dataset['train'], num_images=5)

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

AttributeError: read

<Figure size 1500x1000 with 0 Axes>

In [5]:

# Define the transformation pipeline (resize, convert to tensor, normalize)
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images to a common size (e.g., 256x256)
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize based on ImageNet stats
])

# Apply the transformation to each dataset split
def apply_transform(example):
    # Ensure that 'example['image']' contains the correct file path or image object
    img = Image.open(example['image']).convert('RGB')  # Open image and ensure it's in RGB mode
    example['image'] = transform(img)  # Apply the transformation to the image
    return example

# Apply the transformation to each split (train, test, validation)
train_test_valid_dataset = train_test_valid_dataset.map(apply_transform)

# Verify the transformation by inspecting a sample (e.g., from the train split)
print(train_test_valid_dataset['train'][0])  # Check the first example in the train split

Map:   0%|          | 0/65155 [00:00<?, ? examples/s]

AttributeError: read