In [42]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np

In [43]:
# Training metadata: where the images live and which class each one belongs to.
image_folder = "../data/Train/"
labels = pd.read_csv("../data/train_labels.csv")

# Swap the raw class values in the "label" column for integer codes; the
# fitted encoder is kept around so the codes can be mapped back later.
label_encoder = LabelEncoder()
label_encoder.fit(labels["label"])
labels["label"] = label_encoder.transform(labels["label"])

In [44]:
def crop_pokemon(img_path, threshold=3, pad_ratio=0.05, out_size=64):
    """Crop an image to its dark silhouette and centre it on a transparent canvas.

    The image is converted to grayscale, every pixel darker than ``threshold``
    is treated as part of the silhouette, and the image is cropped to the
    silhouette's bounding box plus a small padding. The crop is then pasted in
    the middle of a fully transparent ``out_size`` x ``out_size`` RGBA canvas.

    Args:
        img_path: Image location, passed straight to ``Image.open``.
        threshold: Grayscale values strictly below this count as silhouette.
            Values of 3-5 seemed to work best; 3 handled roughly 95% of the
            images, with one or two odd-looking results.
        pad_ratio: Padding added on each side, as a fraction of the shorter
            side of the silhouette's bounding box.
        out_size: Side length, in pixels, of the square output canvas.

    Returns:
        A new RGBA image of size ``(out_size, out_size)`` containing the
        centred crop. If no pixel is below ``threshold`` the whole input
        image is used as the crop.
    """
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(img_path) as img:
        gray = np.asarray(img.convert("L"))
        mask = gray < threshold

        # Indices of the rows / columns that contain at least one dark pixel.
        rows = np.flatnonzero(mask.any(axis=1))
        cols = np.flatnonzero(mask.any(axis=0))

        if rows.size == 0:
            # Nothing is dark enough: keep the full frame instead of failing
            # on the min/max of an empty coordinate array.
            left, upper, right, lower = 0, 0, img.width, img.height
        else:
            # Crop boxes are right/lower exclusive, hence the + 1 on the
            # last silhouette row and column so they are not cut off.
            left, right = int(cols[0]), int(cols[-1]) + 1
            upper, lower = int(rows[0]), int(rows[-1]) + 1

            # Pad relative to the shorter side, clamped to the image bounds.
            pad = int(pad_ratio * min(right - left, lower - upper))
            left, upper = max(0, left - pad), max(0, upper - pad)
            right, lower = min(img.width, right + pad), min(img.height, lower + pad)

        cropped = img.crop((left, upper, right, lower))

        # Centre the crop on a transparent square canvas.
        width, height = cropped.size
        canvas = Image.new("RGBA", (out_size, out_size), (0, 0, 0, 0))
        canvas.paste(cropped, ((out_size - width) // 2, (out_size - height) // 2))

    return canvas

In [45]:
class PokemonDataset(Dataset):
    """Dataset that serves silhouette-cropped images listed in a DataFrame.

    Column 0 of ``df`` holds the image identifier (the file name without the
    ``.png`` extension). Column 1, when present, holds the label; a frame with
    a single column is treated as an unlabeled (test) set.
    """

    def __init__(self, df, img_dir, transform=None):
        """Store the index frame, the image directory and an optional transform.

        Args:
            df: DataFrame with the image identifier in column 0 and,
                optionally, the label in column 1.
            img_dir: Directory containing the ``<identifier>.png`` files.
            transform: Optional callable applied to each cropped image.
        """
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the index frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, file_name)`` for the row at position ``idx``.

        The label is ``-1`` when the frame has no label column.
        """
        # Build the file name once, before branching, so that both the
        # labeled and the unlabeled path can return it.
        name = str(self.df.iloc[idx, 0]) + ".png"
        img_path = os.path.join(self.img_dir, name)

        image = crop_pokemon(img_path)

        if self.transform:
            image = self.transform(image)

        if len(self.df.columns) > 1:  # Train set has labels, test set does not.
            label = self.df.iloc[idx, 1]
        else:
            label = -1  # Placeholder: there is no label for test images.

        return image, label, name

In [47]:
# Wrap the label table and the image folder in a dataset (no transform applied).
dataset = PokemonDataset(df=labels, img_dir=image_folder)

In [55]:
# Crop every training image and write the result to ../data/TrainCropped.
output_dir = "../data/TrainCropped"
os.makedirs(output_dir, exist_ok=True)  # make sure the target directory exists before saving

dataset = PokemonDataset(labels, image_folder)
for i in range(len(dataset)):
    # Index the dataset once per image: each dataset[i] access re-opens and
    # re-crops the file, so indexing twice doubled the work.
    image, _label, name = dataset[i]
    image.save(os.path.join(output_dir, name))