In [None]:
import os
import random
from PIL import Image, ImageDraw
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt

# === Parameters ===
IMAGE_SIZE = 128            # width and height of every generated image, in pixels
SHAPE_SIZE = 60             # side of the shape's bounding box, in pixels
NUM_IMAGES_PER_CLASS = 300  # number of PNG files written per shape class
# Kept in sync with the directory the dataset-loading cell reads
# ("data/fixed_length_fixed_rotation"); a mismatch makes a fresh
# Restart & Run All fail with FileNotFoundError when the dataset is built.
OUTPUT_DIR = os.path.join("data", "fixed_length_fixed_rotation")
SHAPES = ["circle", "square", "triangle"]
assert 0 < SHAPE_SIZE <= IMAGE_SIZE, "shape must fit inside the canvas"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# === Drawing Functions ===
def draw_shape(shape, image_size, shape_size, rotation=0, fill="black", background="white"):
    """Render one filled shape centred on a square RGB canvas.

    Args:
        shape: Shape name: "circle", "square" or "triangle".
        image_size: Width and height of the returned image, in pixels.
        shape_size: Nominal side of the shape's bounding box, in pixels; the
            shape extends ``shape_size // 2`` pixels either side of the centre.
        rotation: Rotation in degrees, applied with ``PIL.Image.rotate`` about the
            image centre. Ignored for "circle" (rotation-invariant).
        fill: Colour of the shape (any colour PIL accepts). Defaults to "black".
        background: Canvas colour. Defaults to "white".

    Returns:
        A ``PIL.Image.Image`` in "RGB" mode of size ``(image_size, image_size)``.

    Raises:
        ValueError: If ``shape`` is not one of the supported names (instead of
            silently returning a blank canvas).
    """
    supported = ("circle", "square", "triangle")
    if shape not in supported:
        raise ValueError(f"Unknown shape {shape!r}; expected one of {supported}")

    cx = cy = image_size // 2
    half = shape_size // 2
    bbox = [cx - half, cy - half, cx + half, cy + half]

    # Draw into a binary mask first so the shape can be rotated on its own and
    # then composited onto the background in a single paste.
    mask = Image.new("L", (image_size, image_size), 0)
    mask_draw = ImageDraw.Draw(mask)
    if shape == "circle":
        mask_draw.ellipse(bbox, fill=255)
    elif shape == "square":
        mask_draw.rectangle(bbox, fill=255)
    else:  # "triangle": apex at top-centre, base along the bottom edge of bbox
        mask_draw.polygon(
            [(cx, cy - half), (cx - half, cy + half), (cx + half, cy + half)],
            fill=255,
        )

    if shape != "circle":  # a circle looks the same at any angle; skip resampling
        mask = mask.rotate(rotation, expand=False)

    img = Image.new("RGB", (image_size, image_size), background)
    img.paste(fill, mask=mask)
    return img

# === Generate Images ===
# draw_shape is deterministic and every image of a class uses the same size and
# rotation, so all images in a class are pixel-identical: render each shape once
# and write that image NUM_IMAGES_PER_CLASS times.
for shape in SHAPES:
    shape_dir = os.path.join(OUTPUT_DIR, shape)
    os.makedirs(shape_dir, exist_ok=True)
    shape_img = draw_shape(shape, IMAGE_SIZE, SHAPE_SIZE, rotation=0)
    for i in range(NUM_IMAGES_PER_CLASS):
        shape_img.save(os.path.join(shape_dir, f"{shape}_{i}.png"))


In [None]:
import sys
# Report which Python interpreter backs this kernel (useful when a package
# appears to be missing from the environment).
interpreter_path = sys.executable
print(interpreter_path)


In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import os
from PIL import Image

# === Transformations ===
# Pipeline: PIL image -> resized to 128x128 -> float tensor in [0, 1]
# -> (x - 0.5) / 0.5, i.e. values in [-1, 1].
# The single-element mean/std is broadcast by torchvision over all RGB channels.
_preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

# === Custom Dataset ===
class ShapeDataset(Dataset):
    """Classification dataset over ``<root_dir>/<shape_name>/*.png``.

    Each sub-directory named after a shape holds that class's PNG images; the
    sub-directory name maps to an integer label through ``self.labels``.

    Args:
        root_dir: Directory containing ``circle/``, ``square/`` and ``triangle/``.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.

    Attributes:
        samples: List of ``(image_path, label)`` tuples, ordered by class and then
            by file name so index order (and hence any seeded split) is the same
            on every file system.
        labels: Mapping of shape name -> integer class index.
        transform: The ``transform`` argument.

    Raises:
        FileNotFoundError: If a class sub-directory is missing under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None):
        self.samples = []
        self.labels = {"circle": 0, "square": 1, "triangle": 2}
        self.transform = transform
        for shape_name, label in self.labels.items():
            shape_dir = os.path.join(root_dir, shape_name)
            if not os.path.isdir(shape_dir):
                raise FileNotFoundError(
                    f"Missing class directory {shape_dir!r}; generate the images "
                    "first or check root_dir."
                )
            # os.listdir order is filesystem-dependent; sort for a stable index order.
            for img_file in sorted(os.listdir(shape_dir)):
                if img_file.endswith(".png"):
                    self.samples.append((os.path.join(shape_dir, img_file), label))

    def __len__(self):
        """Number of PNG samples found across all classes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is transformed if a transform was given."""
        img_path, label = self.samples[idx]
        # Close the file handle promptly; convert() returns a new, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# === Load Full Dataset ===
# Read from the directory the generation cell wrote to (OUTPUT_DIR), so the two
# cells cannot drift apart.
dataset = ShapeDataset(OUTPUT_DIR, transform=transform)

# === Split Dataset ===
# 70 / 15 / 15 split. The test split absorbs the rounding remainder so the three
# sizes always sum to len(dataset), as random_split requires.
SPLIT_SEED = 42
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
# A seeded generator makes the split reproducible across Restart & Run All
# (without one, random_split draws from torch's global RNG).
train_set, val_set, test_set = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): the generation cell calls draw_shape with identical arguments
# for every image of a class, so all images within a class are pixel-identical;
# val/test therefore contain exact copies of training images and measure
# memorisation rather than generalisation.

# === Create DataLoaders ===
train_loader = DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible epoch order
)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)


In [None]:
import matplotlib.pyplot as plt

def show_samples(dataset, num_samples=5, labels_map=None):
    """Plot the first ``num_samples`` images of each class as a grid.

    Rows are classes (in ``labels_map`` order), columns are samples. Iteration
    stops as soon as every row is full, so only as much of ``dataset`` is read
    as needed.

    Args:
        dataset: Iterable of ``(image, label)`` pairs, e.g. ``ShapeDataset`` or a
            ``Subset`` of it. Images are assumed to be CHW tensors normalized to
            [-1, 1] (as produced by ``Normalize(mean=0.5, std=0.5)``); labels are
            integer class indices.
        num_samples: Images to show per class (number of columns). Must be >= 1.
        labels_map: Optional mapping of class index -> row title. Defaults to
            ``{0: 'Circle', 1: 'Square', 2: 'Triangle'}``. Samples whose label is
            not a key are skipped.

    Raises:
        ValueError: If ``num_samples`` < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if labels_map is None:
        labels_map = {0: 'Circle', 1: 'Square', 2: 'Triangle'}

    rows = {label: row for row, label in enumerate(labels_map)}
    counts = {label: 0 for label in labels_map}
    # squeeze=False keeps axs 2-D even when there is a single column or row,
    # so axs[row, col] indexing always works (plain subplots(3, 1) is 1-D).
    fig, axs = plt.subplots(
        len(labels_map), num_samples,
        figsize=(num_samples * 2, 2 * len(labels_map)),
        squeeze=False,
    )
    for ax in axs.flat:
        ax.axis('off')  # cells that never receive an image stay blank

    for img, label in dataset:
        label = int(label)  # accept 0-d tensors as well as plain ints
        if label not in counts or counts[label] >= num_samples:
            continue
        col = counts[label]
        ax = axs[rows[label], col]
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1]; clamp guards float round-off.
        ax.imshow((img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
        if col == 0:
            ax.set_title(labels_map[label])
        counts[label] += 1
        if all(c == num_samples for c in counts.values()):
            break
    plt.tight_layout()
    plt.show()

# Preview five training images per class to sanity-check loading and normalization.
show_samples(dataset=train_set, num_samples=5)
