# Data

## Template: `data.py`

In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torchvision
from torch.utils.data import Dataset
import warnings
warnings.filterwarnings("ignore")


class MyDataset(Dataset):
    """Thin torch ``Dataset`` wrapper around torchvision's MNIST.

    Args:
        data_path: Root directory passed to ``torchvision.datasets.MNIST``;
            the files are downloaded there if they are missing.
        train: Load the training split when True, the test split otherwise.

    Attributes:
        width, height: Spatial size of one image, read from the first sample.
        in_dim: Number of pixels per image (``width * height``).
        out_dim: Number of target classes.
    """

    def __init__(self, data_path, train=True):
        self.train = train
        self.data = torchvision.datasets.MNIST(root=data_path,
                                               train=self.train,
                                               transform=torchvision.transforms.ToTensor(),
                                               download=True)

        # ToTensor yields (C, H, W): the last two dims are height, then width.
        _, self.height, self.width = self.data[0][0].shape
        self.in_dim = self.width * self.height
        self.out_dim = len(self.data.classes)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Delegates to the wrapped MNIST dataset, which applies the ToTensor transform.
        return self.data[idx]

    def get_features(self):
        """Return every image of the split as a single float tensor.

        Reads the ``data`` attribute for both splits (``train_data`` is a deprecated
        alias of it) and casts to float, so train and test features share one dtype.
        NOTE: this bypasses the ToTensor transform, so values are raw pixel
        intensities rather than [0, 1]-scaled ones.
        """
        return self.data.data.float()

    def get_labels(self):
        """Return every label of the split (``train_labels`` is a deprecated alias of ``targets``)."""
        return self.data.targets

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build the MNIST training split (downloaded under data_path if missing).
# NOTE(review): the name `datasets` collides with the Hugging Face `datasets`
# package used later in this notebook; consider a more specific name.
data_path = "../datasets"  # root directory for downloaded data
datasets = MyDataset(
    data_path,
    train=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1076)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5546802.10it/s] 


Extracting ../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1076)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 145338.82it/s]


Extracting ../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1076)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1399972.06it/s]


Extracting ../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1076)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2605022.39it/s]


Extracting ../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../datasets/MNIST/raw



In [5]:
# Fetch the first (image, label) pair from the dataset and report it.
sample = datasets[0]
first_image, first_label = sample[0], sample[1]
print(f"Image shape: {first_image.shape}, Label: {first_label}")

Image shape: torch.Size([1, 28, 28]), Label: 5


## Load the dataset

In [10]:
from datasets import load_dataset
# Load all splits of the reczoo AmazonBooks_m1 dataset from the Hugging Face Hub
# (reused from the local HF cache when already downloaded).
dataset = load_dataset(path="reczoo/AmazonBooks_m1")

Found cached dataset text (/root/.cache/huggingface/datasets/reczoo___text/reczoo--AmazonBooks_m1-3655901e653150e2/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
100%|██████████| 2/2 [00:00<00:00, 938.43it/s]


### Save to disk

In [12]:
# Persist the loaded dataset (all splits) to a local directory.
# NOTE(review): this rebinds `data_path`, replacing the MNIST root set earlier.
data_path = "../datasets/AmazonBooks_m1"
dataset.save_to_disk(data_path)

                                                                                                  

In [13]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from datasets import load_dataset
from torch.utils.data import Dataset
import torch

class MyDataset(Dataset):
    """Torch ``Dataset`` backed by one split of a Hugging Face dataset.

    Note: this cell redefines ``MyDataset``, shadowing the MNIST version above it.

    Args:
        data_path: Local directory the loaded split is written to via ``save_to_disk``.
        hugging_path: Hugging Face Hub dataset id, e.g. ``"reczoo/AmazonBooks_m1"``.
        train: Load the ``"train"`` split when True, the ``"test"`` split otherwise.
        save: When True (the default, matching the original behavior), write the split
            to ``data_path`` on construction; pass False to skip the disk write.
    """

    def __init__(self, data_path, hugging_path, train=True, save=True):
        self.train = train
        self.data = load_dataset(hugging_path,  # huggingface path
                                 split="train" if train else "test")
        if save:
            # NOTE(review): train and test instances given the same data_path may
            # overwrite each other's saved split; consider a per-split subdirectory.
            self.data.save_to_disk(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # torch's base Dataset.__getitem__ raises NotImplementedError, so indexing
        # (and any DataLoader) needs this override; delegate to the loaded split.
        return self.data[idx]


- 특정 파일만 Hugging Face에서 다운로드하기

In [22]:
from huggingface_hub import hf_hub_download

# Fetch only the needed split files (not whole repos) from the reczoo Hugging Face
# dataset repos. Each file is cached under ../datasets/<repo_name>, and
# hf_hub_download returns the local path of the downloaded file.
HF_ORG = "reczoo"
dataset_list = ["AmazonBooks_m1",
                "Yelp18_m1",
                "Gowalla_m1"]

# repo name -> files to fetch. The text-format repos share train/test file names;
# Movielens1M_m1 ships JSON files instead.
files_per_repo = {name: ("train.txt", "test.txt") for name in dataset_list}
files_per_repo["Movielens1M_m1"] = ("train_data.json", "test_data.json")

# (repo name, filename) -> local file path. Collected in a dict rather than
# re-binding `dataset`, which would clobber the DatasetDict loaded earlier.
downloaded_paths = {}
for repo_name, filenames in files_per_repo.items():
    for filename in filenames:
        downloaded_paths[(repo_name, filename)] = hf_hub_download(
            f"{HF_ORG}/{repo_name}",
            filename=filename,
            repo_type="dataset",
            cache_dir=f"../datasets/{repo_name}",
        )

Downloading train.txt: 100%|██████████| 14.1M/14.1M [00:00<00:00, 43.0MB/s]
Downloading train.txt: 100%|██████████| 6.90M/6.90M [00:01<00:00, 6.44MB/s]
Downloading train.txt: 100%|██████████| 4.63M/4.63M [00:00<00:00, 63.8MB/s]
Downloading test.txt: 100%|██████████| 3.85M/3.85M [00:00<00:00, 96.0MB/s]
Downloading test.txt: 100%|██████████| 1.99M/1.99M [00:00<00:00, 46.8MB/s]
Downloading test.txt: 100%|██████████| 1.37M/1.37M [00:00<00:00, 65.2MB/s]
Downloading train_data.json: 100%|██████████| 4.44M/4.44M [00:00<00:00, 30.0MB/s]
Downloading test_data.json: 100%|██████████| 565k/565k [00:00<00:00, 35.3MB/s]
