## Downloading and extracting the Fashionpedia dataset

In [None]:
# This creates a git filter that strips the notebook output when committing
!git config filter.strip-notebook-output.clean 'jupyter nbconvert --ClearOutputPreprocessor.enabled=True --to=notebook --stdin --stdout --log-level=ERROR'

# Download dataset (4 GB)
# -p creates both directory levels at once and does not fail when the cell is re-run
!mkdir -p fashionpedia/img

# --fail makes curl stop on an HTTP error instead of saving the error page as the archive
!curl --fail https://s3.amazonaws.com/ifashionist-dataset/images/train2020.zip -o fashionpedia/train.zip
# -q keeps the per-file listing out of the cell output; -n skips files that already exist
# so a re-run never blocks on unzip's interactive overwrite prompt
!unzip -q -n fashionpedia/train.zip -d fashionpedia/img
!rm fashionpedia/train.zip

!curl --fail https://s3.amazonaws.com/ifashionist-dataset/annotations/instances_attributes_train2020.json -o fashionpedia/attributes.json

## Preprocess data
### Load data

In [None]:
import json
import lib.fashionpedia_type as fpt

# Parse the annotation file into `att`, the dict every later cell reads from.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on the
# platform's default encoding.
with open('fashionpedia/attributes.json', encoding='utf-8') as attributes_file:
    att: fpt.FashionPedia = json.load(attributes_file)

In [None]:
def get_by_id(group, id):
    '''Return the first entry of ``att[group]`` whose ``'id'`` equals ``id``.

    The group is scanned in order; ``None`` is returned when nothing matches.
    '''
    candidates = (entry for entry in att[group] if entry['id'] == id)
    return next(candidates, None)

### Preprocessing functions

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def show_segmentation(annotation: fpt.Annotation):
    '''Show the annotated image with its segmentation outline drawn in red.

    Every polygon listed in ``annotation['segmentation']`` is drawn (not only the
    first one), and each outline is closed back to its starting vertex.

    Raises:
        TypeError: if the segmentation is not a list of polygons (e.g. an RLE mask).
    '''
    segmentation = annotation['segmentation']
    if not isinstance(segmentation, list):
        raise TypeError('show_segmentation only supports polygon segmentations')

    file_name = get_by_id('images', annotation['image_id'])['file_name']
    img = mpimg.imread(f'fashionpedia/img/{file_name}')
    plt.imshow(img)

    for polygon in segmentation:
        # Each polygon is a flat [x0, y0, x1, y1, ...] list
        xs = polygon[::2]
        ys = polygon[1::2]
        # Repeat the first vertex so the outline is closed
        plt.plot(xs + xs[:1], ys + ys[:1], c='red')

    plt.show()

# Smoke test: fetch the annotation with id 11, print its 'area' field and draw its outline
test = get_by_id('annotations', 11)
print(test['area'])
show_segmentation(test)
        

In [None]:
from PIL import Image, ImageDraw

def crop_segmentation(annotation: fpt.Annotation):
    '''Cut the annotated garment out of its image on a white background.

    Everything outside the annotation's segmentation polygons is painted white,
    then the image is cropped to the annotation's bounding box.

    Args:
        annotation: entry with ``'image_id'``, ``'segmentation'`` (a list of flat
            ``[x0, y0, x1, y1, ...]`` polygons) and ``'bbox'`` (``x, y, width, height``).

    Returns:
        The cropped RGB ``PIL.Image``.

    Raises:
        TypeError: if the segmentation is not a list of polygons (e.g. an RLE mask).
    '''
    segmentation = annotation['segmentation']
    if not isinstance(segmentation, list):
        raise TypeError('crop_segmentation only supports polygon segmentations')

    file_name = get_by_id('images', annotation['image_id'])['file_name']
    # The context manager closes the file; convert() returns an independent copy
    with Image.open(f'fashionpedia/img/{file_name}') as source:
        img = source.convert('RGB')

    # 1-bit mask: 1 marks pixels to paint white, 0 marks the garment
    mask = Image.new('1', img.size, 1)
    draw = ImageDraw.Draw(mask)
    for polygon in segmentation:
        points = list(zip(polygon[::2], polygon[1::2]))
        if len(points) < 3:
            # Fewer than three vertices does not enclose any area
            continue
        draw.polygon(points, outline=0, fill=0)

    # Set all but the masked area white (255 is the maximum 8-bit channel value)
    img.paste((255, 255, 255), mask=mask)

    # Crop to the bounding box
    x, y, width, height = map(int, annotation['bbox'])
    return img.crop((x, y, x + width, y + height))

# Preview the cropped result for the annotation with id 11
test = get_by_id('annotations', 11)
plt.imshow(crop_segmentation(test))
plt.show()

In [None]:
from torchvision import transforms

def resize(img, size=(224, 224)):
    '''Convert an image to a tensor and resize it.

    Args:
        img: image accepted by ``transforms.ToTensor`` (this notebook passes PIL images).
        size: target size handed to ``transforms.Resize``; defaults to ``(224, 224)``,
            the value that was previously hard-coded.

    Returns:
        The resized image tensor.
    '''
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(size)
    ])

    return transform(img)

### Sample data from original dataset

In [None]:
# Dataset sampling conditions
SUPERCATEGORY = 'upperbody'
DATASET_SIZE = 5000
MIN_GARMENT_AREA = 1

# Categories that belong to the chosen supercategory, and their ids
selected_cat = [category for category in att['categories'] if category['supercategory'] == SUPERCATEGORY]
selected_cat_ids = [category['id'] for category in selected_cat]

# One line for the heading, then one line per selected category name
print('\n'.join(["Selected categories:"] + [str(category['name']) for category in selected_cat]))

In [None]:
import numpy as np

# Number of attribute classes to keep (the most frequent ones)
NUM_SELECTED_ATTRIBUTES = 32

# Every attribute id attached to an annotation of one of the selected categories
attribute_ids = [att_id for a in att['annotations'] if a['category_id'] in selected_cat_ids for att_id in a['attribute_ids']]
if not attribute_ids:
    # Fail with a readable message instead of a cryptic unpacking error further down
    raise ValueError(f'No attributes found for supercategory {SUPERCATEGORY!r}')

# np.unique returns the distinct ids in ascending order together with their counts
attribute_freqs = np.unique(attribute_ids, return_counts=True)
unique_att_ids, att_frequencies = attribute_freqs

# A stable sort on the negated counts puts the most frequent first and keeps tied
# attributes in ascending-id order. Nothing is assigned to `_`, so IPython's
# "last output" variable stays intact.
top_positions = np.argsort(-att_frequencies, kind='stable')[:NUM_SELECTED_ATTRIBUTES]
selected_att_ids = tuple(unique_att_ids[top_positions])
sorted_selected_att_ids = sorted(selected_att_ids)

print("Most frequent attributes:", *[get_by_id('attributes', att_id)['name'] for att_id in selected_att_ids], sep='\n')

In [None]:
import torch
from tqdm.autonotebook import tqdm

# Build up to DATASET_SIZE samples from the annotations that pass every filter below.
dataset = []
# How often each selected attribute occurs in the sampled dataset
att_counts = {int(i): 0 for i in selected_att_ids}
with tqdm(total=DATASET_SIZE) as pbar:
    for ann in att['annotations']:
        if ann['category_id'] not in selected_cat_ids:
            continue

        if ann['area'] < MIN_GARMENT_AREA:
            continue

        if not isinstance(ann['segmentation'], list):
            # skip images with RLE segmentation masks
            continue

        # Keep only the attributes that are among the selected (most frequent) ones
        attributes = [att_id for att_id in ann['attribute_ids'] if att_id in selected_att_ids]
        if not attributes:
            continue

        # One slot per selected attribute, ordered by ascending attribute id
        attributes_one_hot = [int(att_id in attributes) for att_id in sorted_selected_att_ids]

        for a in attributes:
            att_counts[a] += 1

        # Crop, set to tensor and resize image
        img = crop_segmentation(ann)
        img = resize(img)

        dataset.append({
                # Annotation and image ids let a sample be traced back to its source
                # annotation; the inspection cell at the end of the notebook reads them.
                'id': ann['id'],
                'image_id': ann['image_id'],
                'fn': get_by_id('images', ann['image_id'])['file_name'],
                'img': img,
                'cat': ann['category_id'],
                'att_oh': torch.tensor(attributes_one_hot, dtype=torch.float)
            })
        pbar.update()

        if len(dataset) >= DATASET_SIZE:
            break

In [None]:
def normalize_dataset(data):
    '''Standardize every ``d['img']`` tensor of ``data`` in place.

    The statistics are reduced over dims 0, 2 and 3 of the stacked images, giving
    one mean and one standard deviation per index of dim 1 (presumably the colour
    channel).

    Returns:
        ``(means, stds)`` as used for the normalization.
    '''
    # Stack all images into a single tensor so the statistics span the whole dataset
    batch = torch.stack([sample['img'] for sample in data])

    reduce_dims = [0, 2, 3]
    channel_means = batch.mean(dim=reduce_dims)
    channel_stds = batch.std(dim=reduce_dims)

    # inplace=True makes the transform overwrite each stored tensor
    standardize = transforms.Normalize(mean=channel_means, std=channel_stds, inplace=True)
    for sample in tqdm(data):
        standardize(sample['img'])

    return channel_means, channel_stds


# Normalize every stored image in place and keep the statistics that were used
means, stds = normalize_dataset(dataset)
# Show the third sample; permute reorders the axes to (1, 2, 0) — presumably (C, H, W) -> (H, W, C) for imshow, TODO confirm
plt.imshow(dataset[2]['img'].permute(1, 2, 0))

## Save dataset

In [None]:
import torch

# save dataset
with open("fashionpedia/processed_dataset.pt", "wb") as file:
    torch.save({"dataset": dataset, "means": means, "stds": stds}, file)

# Resolve the attribute names before the JSON file is opened for writing, so a failed
# lookup cannot leave a truncated file behind.
selected_atts = {int(i): get_by_id('attributes', i)['name'] for i in selected_att_ids}

# Maps attribute id -> attribute name; sort_keys orders the entries by id
with open("fashionpedia/selected_attributes.json", "w") as file:
    json.dump(selected_atts, file, indent=1, sort_keys=True)

In [None]:
# Horizontal bar chart of how often each selected attribute occurs in the sampled dataset
plt.barh(
    [get_by_id('attributes', att_id)['name'] for att_id in att_counts],
    list(att_counts.values()),
)
plt.show()

In [None]:
def get_image_and_attributes(img):
    '''Print a readable summary of one annotation and return that annotation.

    Prints the category name, the attribute names (apostrophes removed), the image
    file name and the image's original URL.

    Args:
        img: mapping with an ``'id'`` (annotation id) and an ``'image_id'`` key.

    Returns:
        The annotation entry whose id is ``img['id']``.
    '''
    style_annotations = get_by_id('annotations', img['id'])
    style_image = get_by_id('images', img['image_id'])

    # Look the category up by its id rather than by its position in the list
    cat_names = get_by_id('categories', style_annotations['category_id'])['name']

    # Names follow the order of the annotation's attribute_ids; unknown ids are skipped
    att_names = ', '.join([
        att_['name']
        for attri_id in style_annotations['attribute_ids']
        for att_ in att['attributes']
        if att_['id'] == attri_id
    ]).replace("'", '')

    prompt = f'Categories: {cat_names}, Attributes: {att_names}'

    print(prompt)
    print(style_image['file_name'])
    print(style_image['original_url'])
    # Blank lines between consecutive summaries
    print('\n')

    return style_annotations
    
# Preview the first five dataset entries: print their description and show the cropped garment.
# NOTE(review): get_image_and_attributes reads d['id'] and d['image_id'] — verify that every dataset entry carries both keys.
for d in dataset[:5]:
    test = get_image_and_attributes(d)
    plt.imshow(crop_segmentation(test))
    plt.axis('off')
    plt.show()