In [1]:
from __future__ import print_function, division
import os
from io import BytesIO
import bson
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [2]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()

In [28]:
class CdiscountDataset(Dataset):
    """Cdiscount dataset."""

    def __init__(self, offsets_csv, images_csv, bson_file_path, with_label, transform=None):
        self.offsets_df = pd.read_csv(offsets_csv, index_col=0)
        self.images_df = pd.read_csv(images_csv, index_col=0)
        self.bson_file = open(bson_file_path, "rb")
        self.with_label = with_label
        self.transform = transform

    def __len__(self):
        return len(self.images_df)

    def __getitem__(self, idx):
        image_row = self.images_df.iloc[idx]
        product_id = image_row["product_id"]
        offset_row = self.offsets_df.loc[product_id]

        # Read this product's data from the BSON file.
        self.bson_file.seek(offset_row["offset"])
        item_data = self.bson_file.read(offset_row["length"])

        # Grab the image from the product.
        item = bson.BSON.decode(item_data)
        img_idx = image_row["img_idx"]
        bson_img = item["imgs"][img_idx]["picture"]
        img = io.imread(BytesIO(bson_img))
        
        if self.transform:
            img = self.transform(img)
            
        label = torch.FloatTensor(1)        
        if self.with_label:
            label = torch.FloatTensor([image_row["category_idx"].item()])

        sample = {'img': img, 'label': label}
        return sample

In [29]:
class Rescale(object):
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        img = transform.resize(sample, (self.output_size, self.output_size))
        return img
    
class ToTensor(object):
    def __call__(self, sample):
        sample = sample.transpose((2, 0, 1))
        return torch.from_numpy(sample)

# train_dataset = CdiscountDataset(
#     offsets_csv="train_offsets.csv",
#     images_csv="dev_train_images.csv",
#     bson_file_path="/mnt/data/cdiscount/train.bson",
#     with_label=True,
#     transform=transforms.Compose([
#         Rescale(128),
#         ToTensor()
#     ])
# )

val_dataset = CdiscountDataset(
    offsets_csv="train_offsets.csv",
    images_csv="dev_val_images.csv",
    bson_file_path="/mnt/data/cdiscount/train.bson",
    with_label=True,
    transform=transforms.Compose([
        Rescale(256),
        ToTensor()
    ])
)

In [30]:
# for i in range(len(val_dataset)):
#     sample = val_dataset[i]
#     img, label = sample['img'], sample['label']
#     print(label)
#     if i == 30:
#         break

In [31]:
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=1)

In [32]:
import time

time_start = time.clock()
for i_batch, sample_batched in enumerate(val_dataloader):
    print(i_batch)
    print(time.clock() - time_start)
    time_start = time.clock()
    if (i_batch == 10):
        break

Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7f40a6de1dd8>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 241, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 237, in _shutdown_workers
    self.index_queue.put(None)
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 355, in put
    self._writer.send_bytes(obj)
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor


0
0.042221000000012054
1
0.07671400000000972
2
0.07833500000000981
3
0.06738400000000411
4
0.06701400000000035
5
0.06327400000000694
6
0.06129500000000121
7
0.06340099999999893
8
0.0669019999999989
9
0.07105000000001382
10
0.06867000000002577
