In [None]:
from torch.utils.data import Dataset
import boto3
from botocore.client import Config
from botocore.exceptions import ClientError
from pathlib import Path
import logging
import cv2
import numpy as np
import random
import collections
import matplotlib.pyplot as plt
import boto3
import boto3.session
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, Executor
import threading

logging.getLogger().setLevel(logging.INFO)
from torch.utils.data.sampler import Sampler
from collections import defaultdict
from tqdm import tqdm

In [None]:
import os

# Connection settings for the S3-compatible object store used by the dataset below.
StorageConfig = collections.namedtuple(
    "StorageConfig",
    [
        "endpoint_url",
        "aws_access_key_id",
        "aws_secret_access_key",
        "region_name",
        "bucket_name",
        "config",
    ],
)

# Values are read from STORAGE_* environment variables when set, so real secrets never
# have to be written into (and committed with) the notebook. The fallbacks are the
# previous hard-coded local-development values, so behaviour is unchanged by default.
config = StorageConfig(
    endpoint_url=os.environ.get("STORAGE_ENDPOINT_URL", "http://localhost:9000"),
    aws_access_key_id=os.environ.get("STORAGE_ACCESS_KEY_ID", "username"),
    aws_secret_access_key=os.environ.get("STORAGE_SECRET_ACCESS_KEY", "password"),
    config=Config(signature_version="s3v4"),
    region_name=os.environ.get("STORAGE_REGION_NAME", "us-east-1"),
    bucket_name=os.environ.get("STORAGE_BUCKET_NAME", "data"),
)


def get_bucket(w, h):
    """Map an image size to the (width, height) cluster key it is grouped under.

    Width: values above 900 are capped at 900; values below 120 are bumped to the
    next multiple of 10 (an exact multiple of 10 also moves up one step); everything
    else is floored to a multiple of 15.
    Height: values above 45 are floored to a multiple of 10; all others map to 32.

    Args:
        w: Image width.
        h: Image height.

    Returns:
        Tuple ``(bucket_w, bucket_h)``.
    """
    if w > 900:
        width_key = 900
    elif w < 120:
        width_key = 10 * (w // 10) + 10
    else:
        width_key = 15 * (w // 15)

    height_key = 10 * (h // 10) if h > 45 else 32
    return width_key, height_key

In [None]:
class RemoteDataset(Dataset):
    """Map-style dataset of images stored in an S3-compatible bucket.

    On construction the dataset lists the objects under the first prefix found
    directly below ``/{project_name}/`` and groups their keys by
    ``get_bucket(width, height)``. The width and height come from each object's
    user metadata (``head_object(...)["Metadata"]``). Images are downloaded and
    decoded with OpenCV on access.

    Attributes:
        client: boto3 S3 client built from ``storage_config``.
        bucket_name: Bucket that keys are read from.
        cluster_indices: ``defaultdict(list)`` mapping ``(bucket_w, bucket_h)`` to
            the object keys in that cluster.
        total_count: Number of keys across all clusters.
    """

    def __init__(
        self,
        storage_config: "StorageConfig",
        project_name: str,
    ) -> None:
        """Connect to the store and index every image of ``project_name``.

        Args:
            storage_config: Endpoint, credentials and bucket to read from.
            project_name: Project prefix; objects are looked up under ``/{project_name}/``.

        Raises:
            ValueError: if an S3 call fails with ``botocore``'s ``ClientError``.
        """
        super().__init__()

        self.client = boto3.client(
            "s3",
            endpoint_url=storage_config.endpoint_url,
            aws_access_key_id=storage_config.aws_access_key_id,
            aws_secret_access_key=storage_config.aws_secret_access_key,
            config=storage_config.config,
            region_name=storage_config.region_name,
        )
        # Use the bucket from the config that was passed in, not a notebook global.
        self.bucket_name = storage_config.bucket_name
        # Initialised up front so a project without data still yields a usable,
        # empty dataset instead of one that lacks ``cluster_indices``.
        self.cluster_indices = defaultdict(list)
        self.total_count = 0
        # Never populated; kept only so code that reads ``.buckets`` keeps working.
        self.buckets = {}

        try:
            logging.info(
                f"Fetching data information from {storage_config.endpoint_url}, please wait..."
            )
            keys = self._list_project_keys(project_name)
            for key in tqdm(keys):
                info = self.client.head_object(Bucket=self.bucket_name, Key=key)
                metadata = info["Metadata"]
                cluster = get_bucket(int(metadata["width"]), int(metadata["height"]))
                self.cluster_indices[cluster].append(key)
            self.total_count = len(keys)
        except ClientError as e:
            logging.critical(e)
            raise ValueError(
                "Something went wrong when fetching data information from "
                f"{storage_config.endpoint_url} with provided credentials"
            ) from e

    def _list_project_keys(self, project_name: str) -> list:
        """Return every object key under the first sub-prefix of ``/{project_name}/``.

        Returns an empty list when the project has no sub-prefixes.
        """
        result = self.client.list_objects(
            Bucket=self.bucket_name,
            Prefix=f"/{project_name}/",
            Delimiter="/",
        )
        if not result.get("CommonPrefixes"):
            return []
        # Hotfix: only the first prefix (first "version") is used.
        prefix = result["CommonPrefixes"][0]["Prefix"]

        response = self.client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
        keys = self._get_keys(response)
        # list_objects_v2 is paginated; follow continuation tokens until exhausted.
        while response.get("NextContinuationToken"):
            response = self.client.list_objects_v2(
                Bucket=self.bucket_name,
                Prefix=prefix,
                ContinuationToken=response["NextContinuationToken"],
            )
            keys.extend(self._get_keys(response))
        return keys

    def __len__(self):
        """Total number of indexed images."""
        return self.total_count

    def __getitem__(self, index):
        """Download and decode the image at flat position ``index``.

        Positions enumerate the clusters in ``cluster_indices`` iteration order, then
        the keys within each cluster. Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
            ValueError: if the downloaded bytes cannot be decoded as an image.
        """
        position = index + len(self) if index < 0 else index
        if position < 0:
            raise IndexError(f"Cannot get item with index {index}")

        for cluster_keys in self.cluster_indices.values():
            if position < len(cluster_keys):
                key = cluster_keys[position]
                break
            position -= len(cluster_keys)
        else:
            raise IndexError(f"Cannot get item with index {index}")

        response = self.client.get_object(Bucket=self.bucket_name, Key=key)
        contents = response["Body"].read()
        img_np = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        # cv2.imdecode signals failure by returning None instead of raising.
        if img_np is None:
            raise ValueError(f"Could not decode object {key!r} as an image")
        return img_np

    def _get_keys(self, response: dict) -> list:
        """Extract object keys from a ``list_objects*`` response (empty page -> [])."""
        if "Contents" in response:
            return [item["Key"] for item in response["Contents"]]
        else:
            return []

In [None]:
# Builds the index: lists every object and reads its metadata, so this cell can be slow.
dataset = RemoteDataset(storage_config=config, project_name="online_fetching")
len(dataset)

In [None]:
class GetDataTask(threading.Thread):
    """Background thread that downloads and decodes one batch of images.

    Construct as ``GetDataTask(args=(storage_config, keys))``. After ``join()``,
    ``self.result`` holds the decoded images in the same order as ``keys``.
    """

    def run(self):
        # CPython's threading.Thread keeps the ``args`` constructor argument in ``_args``.
        storage_config, keys = self._args
        client = boto3.client(
            "s3",
            endpoint_url=storage_config.endpoint_url,
            aws_access_key_id=storage_config.aws_access_key_id,
            aws_secret_access_key=storage_config.aws_secret_access_key,
            config=storage_config.config,
            region_name=storage_config.region_name,
        )

        # Read the bucket from the config passed in, not from a notebook-level global.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(get_single_file, storage_config.bucket_name, client, key)
                for key in keys
            ]

        # Collect in submission order (not completion order) so result[i] is the
        # image for keys[i] and batch order is deterministic.
        self.result = []
        for future in futures:
            self.result.append(future.result())


def get_single_file(bucket, client, key):
    """Download object ``key`` from ``bucket`` and decode it with OpenCV.

    Args:
        bucket: Bucket name.
        client: boto3 S3 client (anything with a compatible ``get_object``).
        key: Object key.

    Returns:
        The array returned by ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` (BGR channel order).

    Raises:
        ValueError: if the downloaded bytes cannot be decoded as an image.
    """
    response = client.get_object(Bucket=bucket, Key=key)
    contents = response["Body"].read()
    image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode signals failure by returning None instead of raising; surface it
    # here rather than letting a None slip into a batch.
    if image is None:
        raise ValueError(f"Could not decode object {key!r} in bucket {bucket!r} as an image")
    return image

In [None]:
class ClusterRandomSampler(Sampler):
    """Yield batches of decoded images, each batch drawn from a single size cluster.

    ``data_source`` must expose ``cluster_indices`` (mapping cluster -> list of object
    keys) and ``__len__``. Each cluster's keys are split into batches of
    ``batch_size``; incomplete batches are dropped. While the caller uses one batch,
    the next one is downloaded by a background ``GetDataTask`` thread.

    Args:
        data_source: Dataset with a ``cluster_indices`` mapping (e.g. ``RemoteDataset``).
        batch_size: Number of keys per batch.
        shuffle: If True, shuffle keys within each cluster, batches within each
            cluster, and the final batch order on every ``iter()``.
        storage_config: Storage settings handed to ``GetDataTask``. When ``None``,
            the notebook-level ``config`` is used (the previous behaviour).
    """

    def __init__(self, data_source, batch_size, shuffle=True, storage_config=None):
        super().__init__(data_source)
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.storage_config = storage_config

    @staticmethod
    def flatten_list(lst):
        """Concatenate a list of lists into one flat list."""
        return [item for sublist in lst for item in sublist]

    def _make_batches(self):
        """Build this epoch's list of full, single-cluster batches of keys."""
        batch_lists = []
        for cluster_keys in self.data_source.cluster_indices.values():
            # Shuffle a copy: shuffling the dataset's own lists in place would
            # silently change which key each dataset index refers to.
            keys = list(cluster_keys)
            if self.shuffle:
                random.shuffle(keys)

            batches = [
                keys[i : i + self.batch_size]
                for i in range(0, len(keys), self.batch_size)
            ]
            batches = [batch for batch in batches if len(batch) == self.batch_size]

            if self.shuffle:
                random.shuffle(batches)

            batch_lists.append(batches)

        batch_lists = self.flatten_list(batch_lists)
        if self.shuffle:
            random.shuffle(batch_lists)
        return batch_lists

    def _start_fetch(self, position):
        """Start downloading batch ``position`` in a background thread."""
        storage_config = self.storage_config if self.storage_config is not None else config
        self.thread = GetDataTask(args=(storage_config, self.batch_lists[position]))
        self.thread.start()

    def __iter__(self):
        self.batch_lists = self._make_batches()
        self.index = 0
        self.thread = None
        # Prefetch the first batch; nothing to fetch if no cluster fills a batch.
        if self.batch_lists:
            self._start_fetch(0)
        return self

    def __next__(self):
        if self.index >= len(self.batch_lists):
            raise StopIteration

        self.thread.join()
        result = self.thread.result
        self.index += 1

        # Prefetch the following batch (if any) while the caller works on this one.
        if self.index < len(self.batch_lists):
            self._start_fetch(self.index)
        return result

    def __len__(self):
        """Number of full batches yielded per epoch."""
        return sum(
            len(keys) // self.batch_size
            for keys in self.data_source.cluster_indices.values()
        )

In [None]:
# OpenCV decodes to BGR channel order; convert to RGB so matplotlib shows true colours.
plt.imshow(cv2.cvtColor(dataset[1], cv2.COLOR_BGR2RGB));

In [None]:
# Batches of 16 same-size images, reshuffled each time the sampler is iterated.
sampler = ClusterRandomSampler(data_source=dataset, shuffle=True, batch_size=16)

In [None]:
# Pull up to four batches; `batch` keeps the last one fetched for inspection below.
for i, batch in enumerate(sampler):
    if i >= 3:
        break

In [None]:
# Largest pixel value of the first image in the batch.
batch[0].max()

In [None]:
# Shape of the first image in the batch.
first_image = batch[0]
first_image.shape

In [None]:
# OpenCV decodes to BGR channel order; convert to RGB so matplotlib shows true colours.
plt.imshow(cv2.cvtColor(batch[2], cv2.COLOR_BGR2RGB));