# Setup

In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

from PIL import Image

import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

# Loading the data into a dataset class

In [None]:
# Load the processed Pokemon table; later cells read its 'type1', 'type2'
# and 'image' columns.
pokemon_info = pd.read_csv('../datasets/processed/pokemon.csv')
# Last expression of the cell: shows the first rows as a quick sanity check.
pokemon_info.head()

## OneHot encoding the Types

In [None]:
def onehot(df):
    """Build a multi-hot encoding of the Pokemon types in `df`.

    Arguments:
        df (pandas.DataFrame): Table with 'type1' and 'type2' columns. Missing
            types (NaN / any non-string value) are ignored.

    Returns:
        tuple: ``(encoding, uniques)`` where ``encoding`` is a float numpy
        array of shape ``(len(df), len(uniques))`` holding 1 in the columns of
        the row's type(s), and ``uniques`` is the alphabetically sorted list
        of type names; ``uniques[j]`` names column ``j``.
    """
    # Keep only strings so NaN placeholders for a missing type are dropped.
    # Sorting makes the column order reproducible: plain set order depends on
    # string hash randomisation and can change between kernel restarts.
    uniques = sorted(
        name
        for name in set(df['type1'].to_list() + df['type2'].to_list())
        if isinstance(name, str)
    )
    column_of = {name: column for column, name in enumerate(uniques)}

    encoding = np.zeros((len(df), len(uniques)))
    # Enumerate positions instead of using the index labels from iterrows(),
    # so filtered or re-indexed frames are encoded row by row correctly.
    for row, pair in enumerate(zip(df['type1'], df['type2'])):
        for type_name in pair:
            if isinstance(type_name, str):
                encoding[row, column_of[type_name]] = 1

    return encoding, uniques

In [None]:
# Encode every row of the table; `uniques` lists the type name behind each
# column of `onehot_enc`.
onehot_enc, uniques = onehot(pokemon_info)
print(onehot_enc.shape)  # (number of rows, number of distinct types)
print(uniques)

Helper function for showing a Pokémon image with its type encoding as the plot title

In [None]:
def show_pokemon(image, types):
    """Draw `image` with pyplot and use `types` as the plot title.

    Arguments:
        image: Image data accepted by ``plt.imshow``.
        types: Value shown as the title (the callers in this notebook pass
            the one-hot type vector).
    """
    pause_seconds = 0.001
    plt.imshow(image)
    plt.title(types)
    # A very short pause lets the figure refresh so the plot is updated.
    plt.pause(pause_seconds)

# Smoke test of the helper: plot the first Pokemon of the table together
# with its type encoding.
show_pokemon(
    image=io.imread(pokemon_info['image'][0]),
    types=onehot_enc[0],
)

In [None]:
class PokemonDataset(Dataset):
    """Pokemon dataset pairing each image with a multi-hot type encoding."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images. It is stored on
                the instance, but image paths are currently taken as-is from
                the csv's 'image' column.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.pokemon_info = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Encode the types from this dataset's own table instead of reading
        # the notebook-global `onehot_enc`, so the labels always belong to
        # the csv this instance was built from.
        self.type_encoding, self.type_names = onehot(self.pokemon_info)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.pokemon_info)

    def __getitem__(self, idx):
        """Return the sample at position `idx`.

        Returns:
            dict: ``{'image': <array read from disk>, 'types': <encoding row>}``,
            passed through ``self.transform`` when one was given.

        Raises:
            IndexError: If `idx` is out of range, which is what lets a plain
                ``for sample in dataset`` loop terminate.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Positional lookup keeps the image row aligned with the encoding row
        # and raises IndexError (not KeyError) past the end of the table.
        img_name = self.pokemon_info['image'].iloc[idx]
        image = io.imread(img_name)

        sample = {'image': image, 'types': self.type_encoding[idx]}

        if self.transform:
            sample = self.transform(sample)

        return sample

Let’s instantiate this class and iterate through the data samples. We will plot the first 4 samples, each titled with its one-hot type encoding.

In [None]:
pokemon_dataset = PokemonDataset(
    csv_file='../datasets/processed/pokemon.csv',
    root_dir='../datasets/raw/renders_2d/images/',
)

fig = plt.figure()

# Preview the first four samples side by side in a 1 x 4 grid.
for i, sample in enumerate(pokemon_dataset):
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title(f'Sample #{i}')
    ax.axis('off')
    show_pokemon(image=sample['image'], types=sample['types'])

    # Stop after the fourth sample (i counts from 0).
    if i == 3:
        plt.show()
        break

# Transform to Tensors

In [None]:
class ToTensor:
    """Turn the numpy arrays of a sample dict into torch Tensors."""

    def __call__(self, sample):
        """Return a new dict whose 'image' and 'types' values are Tensors."""
        pixels = sample['image']
        labels = sample['types']

        # numpy keeps images as H x W x C while torch expects C x H x W,
        # so the colour axis is moved to the front before converting.
        channels_first = pixels.transpose(2, 0, 1)
        return {
            'image': torch.from_numpy(channels_first),
            'types': torch.from_numpy(labels),
        }

In [None]:
transformed_dataset = PokemonDataset(
    csv_file='../datasets/processed/pokemon.csv',
    root_dir='../datasets/raw/renders_2d/images/',
    transform=transforms.Compose([ToTensor()]),
)

# Print the tensor sizes of the first four transformed samples.
for i, sample in enumerate(transformed_dataset):
    print(i, sample['image'].size(), sample['types'].size())

    # i counts from 0, so this ends the loop after the fourth sample.
    if i == 3:
        break

# Save dataloader and encoding mappings

In [None]:
import json

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=0)

# save dataloader
# NOTE(review): this pickles the whole DataLoader object (dataset included),
# so whoever loads the file needs the dataset/transform class definitions
# available -- confirm the consumers of this file can unpickle it.
dataloader_path = '../dataloaders/pokemon_dataloader.pth'
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(dataloader_path), exist_ok=True)
torch.save(dataloader, dataloader_path)

# save onehot encoding mapping: type name -> column index in the encoding
mapping = {type_name: column for column, type_name in enumerate(uniques)}

with open('../datasets/processed/onehot_mapping.json', 'w') as f:
    json.dump(mapping, f)