In [31]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [32]:
from collections import Counter

# Load every readable image in PATH, flatten its transparency onto a white
# background, and record the image sizes.
rest = []
image_list = []

PATH = 'logo_images'
size_list = []
# Use PATH rather than repeating the directory literal, so the folder only
# has to be changed in one place.
file_names = os.listdir(PATH)

for name in file_names:
    try:
        # Image.open is lazy; the context manager closes the file handle once
        # convert() has produced an independent, fully loaded RGBA copy.
        with Image.open(os.path.join(PATH, name)) as raw:
            img = raw.convert('RGBA')
    except Exception:
        # Best effort: report entries that cannot be read as images (e.g.
        # .DS_Store) and carry on. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt / SystemExit.
        print(name)
        continue

    # Composite onto an opaque white canvas. The composited image is what must
    # be kept: the Dataset further down converts to RGB, which discards the
    # alpha channel, so transparent regions would otherwise show whatever
    # colour happens to be stored underneath them.
    bg = Image.new('RGBA', img.size, (255, 255, 255, 255))
    result = Image.alpha_composite(bg, img)
    image_list.append(result)
    size_list.append(result.size)

# Most frequent (width, height) among the loaded images.
Counter(size_list).most_common(1)

.DS_Store


[((250, 250), 19)]

In [33]:
# (images successfully loaded, directory entries found) -- they differ by the
# number of entries that could not be opened as images.
tuple(map(len, (image_list, file_names)))

(437, 438)

# Preprocessing

In [34]:
# Bring every loaded image to a common 224x224 resolution.
resized = [picture.resize((224, 224)) for picture in image_list]

In [35]:
import imgaug.augmenters as iaa

# Augmentation pipeline, applied once per image:
#   - horizontal flip with probability 0.5
#   - crop with percent sampled from (0, 0.1)
#   - affine transform with rotate in (-20, 20) and shear in (-10, 10)
seq = iaa.Sequential(
    [
        iaa.Fliplr(0.5),
        iaa.Crop(percent=(0, 0.1)),
        iaa.Affine(rotate=(-20, 20), shear=(-10, 10)),
    ]
)

# One augmented copy per resized image (PIL -> ndarray -> augment -> PIL).
datas = [Image.fromarray(seq(image=np.array(picture))) for picture in resized]

# Training set = augmented copies followed by the untouched resized originals.
train_data = datas + resized

In [36]:
# Resizing must not have dropped or duplicated any image.
assert len(resized) == len(image_list)

# Total number of training images (augmented + original).
len(train_data)

874

In [37]:
# Write the training images to disk as image1.png, image2.png, ...
# exist_ok=True keeps the cell re-runnable: without it, os.makedirs raises
# FileExistsError as soon as the directory exists from a previous run.
os.makedirs('train_images', exist_ok=True)

for index, img in enumerate(train_data, start=1):
    filename = f"image{index}.png"
    img.save(os.path.join('train_images', filename))

# Modeling

In [38]:
# --- Data -------------------------------------------------------------------
dataroot = 'train_images'   # folder the training images were saved to
image_size = 224            # matches the 224x224 resize used in preprocessing
RGB = 3                     # presumably the number of colour channels -- confirm at use site

# --- Model ------------------------------------------------------------------
dimz = 100                  # presumably the latent (z) vector size -- confirm at use site

# --- Training ---------------------------------------------------------------
batch_size = 32             # passed to the DataLoader
num_epochs = 10
lr = 0.0002                 # presumably the optimizer learning rate -- confirm at use site
beta1 = 0.5                 # presumably an optimizer beta coefficient -- confirm at use site

In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset
import matplotlib.animation as animation


class Dataset(torch.utils.data.Dataset):
    """Unlabelled image-folder dataset that yields one RGB image per file.

    The base class is spelled out in full so that this class no longer
    inherits from a name it immediately shadows.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        image_list: Sorted file names (not full paths) of the images found
            directly inside ``root_dir``.
    """

    # File extensions (lower-case) that are treated as images. Anything else
    # in root_dir -- e.g. a macOS ``.DS_Store`` -- is ignored instead of
    # crashing ``__getitem__`` in the middle of an epoch.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.pgm',
                        '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic across runs and machines.
        self.image_list = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        img_name = os.path.join(self.root_dir, self.image_list[idx])
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
    
# Convert each PIL image to a tensor, then apply (x - 0.5) / 0.5 per channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Shuffled mini-batches over the saved training images.
dataset = Dataset(root_dir=dataroot, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Use the MPS backend when it is available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [54]:
# Display the torch device that was selected when the DataLoader was set up.
device

device(type='mps')

In [None]:
# Preview one batch of training images as a tiled grid.
batch = next(iter(dataloader))

# At most 64 images, 2 px padding between tiles, values rescaled for display.
grid = vutils.make_grid(batch.to(device)[:64], padding=2, normalize=True).cpu()

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# imshow expects channels last, hence the (1, 2, 0) axis permutation.
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()