# Get all sets

In [37]:
import itertools
from collections import namedtuple

In [38]:
# card namedtuple type
Card = namedtuple('card', ['number', 'color', 'shape', 'shade'])

In [39]:
# characteristics
numbers = ['one', 'two', 'three']
colors = ['green', 'purple', 'red']
shapes = ['diamond', 'squiggle', 'oval']
shades = ['open', 'solid', 'striped']

In [40]:
# create all cards 
all_cards = {
    Card(number=number, color=color, shape=shape, shade=shade)
    for number in numbers
    for color in colors
    for shape in shapes
    for shade in shades
}

In [41]:
print(f"There are {len(all_cards)} cards in total.")

There are 81 cards in total.


In [42]:
all_cards_list = list(all_cards)
possible_sets = list(itertools.combinations(all_cards_list, 3))

In [43]:
print(f"There are {len(possible_sets):,} different 3 card combinations in total.")

There are 85,320 different 3 card combinations in total.


In [44]:
def check_if_valid(candidate_set):
    set_numbers = set()
    set_colors = set()
    set_shapes = set()
    set_shades = set()
    
    for card in candidate_set:
        set_numbers.add(card.number)
        set_colors.add(card.color)
        set_shapes.add(card.shape)
        set_shades.add(card.shade)

    set_totals = {len(set_numbers), len(set_colors), len(set_shapes), len(set_shades)}
    
    if set_totals in [{1},{3},{1,3}]:
        return True
    
    return False

In [45]:
valid_sets, invalid_sets = set(), set()

for candidate_set in possible_sets:
    (valid_sets if check_if_valid(candidate_set) else invalid_sets).add(candidate_set)

In [46]:
assert len(valid_sets)+len(invalid_sets) == len(possible_sets)

In [47]:
print(f"There are {len(valid_sets):,} valid sets and {len(invalid_sets):,} invalid sets.")

There are 1,080 valid sets and 84,240 invalid sets.


# Get training dataset for DL task 
Dataset obtained from Kaggle: https://www.kaggle.com/datasets/kwisatzhaderach/set-cards

In [73]:
import os

In [76]:
def print_folder_tree(base_dir, indent=''):
    items = os.listdir(base_dir)
    
    items = [item for item in items if os.path.isdir(os.path.join(base_dir, item))]
    
    for index, item in enumerate(items):
        item_path = os.path.join(base_dir, item)
        is_last = index == len(items) - 1
        
        if is_last:
            print(indent + '└── ' + item)
            new_indent = indent + '    '
        else:
            print(indent + '├── ' + item)
            new_indent = indent + '│   '
        
        print_folder_tree(item_path, new_indent)

In [78]:
base_dir = 'dataset'
print_folder_tree(base_dir)

├── one
│   ├── green
│   │   ├── diamond
│   │   │   ├── solid
│   │   │   ├── striped
│   │   │   └── open
│   │   ├── squiggle
│   │   │   ├── solid
│   │   │   ├── striped
│   │   │   └── open
│   │   └── oval
│   │       ├── solid
│   │       ├── striped
│   │       └── open
│   ├── red
│   │   ├── diamond
│   │   │   ├── solid
│   │   │   ├── striped
│   │   │   └── open
│   │   ├── squiggle
│   │   │   ├── solid
│   │   │   ├── striped
│   │   │   └── open
│   │   └── oval
│   │       ├── solid
│   │       ├── striped
│   │       └── open
│   └── purple
│       ├── diamond
│       │   ├── solid
│       │   ├── striped
│       │   └── open
│       ├── squiggle
│       │   ├── solid
│       │   ├── striped
│       │   └── open
│       └── oval
│           ├── solid
│           ├── striped
│           └── open
├── zthree
│   ├── green
│   │   ├── diamond
│   │   │   ├── solid
│   │   │   ├── striped
│   │   │   └── open
│   │   ├── squiggle
│   │   │   ├── solid
│   │   │   ├── str

# Create a Custom Dataset Class

In [79]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

In [80]:
class CardDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for root, _, files in os.walk(class_dir):
                for file in files:
                    if file.endswith('png'):
                        self.image_paths.append(os.path.join(root, file))
                        self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label