In [1]:
import numpy as np
import torch
import copy
import torch.nn.functional as F
from torch import nn, optim

In [2]:
import pandas as pd

import ast
import logging

In [3]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [4]:
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display
import io

In [14]:
from google.cloud import storage
# GCS bucket that stores the images referenced by the dataset's 'key' column.
# Change it here to point the notebook at a different bucket.
BUCKET_NAME = 'urbandetection'

# storage.Client() is created without explicit credentials, so it relies on
# whatever credentials the environment provides; no secrets are hard-coded.
# `bucket` is a module-level global that FashionDataset.__getitem__ uses to
# download images.
client = storage.Client()
bucket = client.get_bucket(BUCKET_NAME)

In [None]:
def get_transform(train=True, resize_size=(300, 300), crop_size=224):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Both modes resize every image to ``resize_size`` first and finish with
    ``ToTensor()`` followed by per-channel normalization. They differ only in
    how the final ``crop_size`` patch is taken:

    * ``train=True``  -> random crop + random horizontal flip (augmentation)
    * ``train=False`` -> deterministic center crop

    The defaults reproduce the original pipeline exactly
    (Resize 300x300, then a 224x224 crop).

    Args:
        train: True for the augmenting training pipeline, False for the
            deterministic evaluation pipeline.
        resize_size: Size passed to ``transforms.Resize``. A ``(h, w)`` pair
            resizes to exactly that size (aspect ratio is not preserved); an
            int resizes the shorter edge instead.
        crop_size: Size of the crop taken from the resized image.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized tensor.
    """
    # NOTE(review): these mean/std values (given on a 0-255 scale and scaled
    # to 0-1) look like the commonly quoted CIFAR-10 statistics rather than
    # stats computed for this dataset -- confirm they are intended.
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    # Both pipelines start from the same resize, so it is shared.
    steps = [transforms.Resize(resize_size)]
    if train:
        steps += [transforms.RandomCrop(crop_size),
                  transforms.RandomHorizontalFlip()]
    else:
        steps.append(transforms.CenterCrop(crop_size))

    steps += [transforms.ToTensor(), normalize]
    return transforms.Compose(steps)

class FashionDataset(Dataset):
    """Image dataset indexed by a pickled DataFrame, with images stored in GCS.

    The DataFrame must contain at least these columns:

    * ``'key'`` -- object name of the image inside the bucket
    * ``'패턴'`` -- the label (Korean for "pattern")

    Images are downloaded on demand in ``__getitem__``; nothing is cached.

    Args:
        pkl_file: Path (or anything ``pd.read_pickle`` accepts) to the pickled
            DataFrame.
        transform: Optional callable applied to the RGB PIL image (e.g. the
            result of ``get_transform``). If None, the image is returned as a
            NumPy array.
        bucket: Optional object exposing ``get_blob(name)`` (e.g. a
            ``google.cloud.storage.Bucket``). When None, the notebook's
            module-level ``bucket`` is used, as before. Pass one explicitly to
            decouple the dataset from that global (other buckets, tests).
    """

    def __init__(self, pkl_file, transform=None, bucket=None):
        self.transform = transform
        # SECURITY: pd.read_pickle unpickles arbitrary objects and can execute
        # code -- only load index files you trust.
        self.items = pd.read_pickle(pkl_file)
        self.bucket = bucket

    def __len__(self):
        return len(self.items)

    def _get_bucket(self):
        """Return the injected bucket, falling back to the module-level one."""
        injected = getattr(self, 'bucket', None)
        return injected if injected is not None else bucket

    def __getitem__(self, idx):
        """Download, decode and transform the image at position ``idx``.

        Returns:
            ``(sample, label)``. ``sample`` is ``transform(img)`` when a
            transform was given, otherwise an ``(H, W, 3)`` uint8 NumPy array.
            ``label`` is the row's ``'패턴'`` value.

        Raises:
            FileNotFoundError: if no blob exists for the row's ``'key'``.
        """
        row = self.items.iloc[idx]
        name = row['key']
        label = row['패턴']

        # get_blob returns None for a missing object; fail with a clear message
        # instead of an AttributeError on None further down.
        blob = self._get_bucket().get_blob(name)
        if blob is None:
            raise FileNotFoundError(f"blob {name!r} not found in bucket")
        img_bytes = blob.download_as_string()

        # Force 3-channel RGB. Grayscale, RGBA or palette images would otherwise
        # fail the 3-channel Normalize used by get_transform, or yield arrays of
        # inconsistent shape. convert() loads the pixel data, so the decoder
        # can be closed right afterwards.
        with Image.open(io.BytesIO(img_bytes)) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            sample = self.transform(img)
        else:
            sample = np.asarray(img)
        return sample, label