In [6]:
import random
from urllib.parse import urlparse

import torch
import torchvision

In [7]:
# Pick the accelerator once; later cells refer to this string ("cuda" or "cpu").
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [8]:
def set_seeds(seed: int = 42) -> None:
    """Seed the random number generators used by this notebook.

    ``torch.manual_seed`` seeds PyTorch's default generator on every device
    (CPU and all CUDA devices), so no separate CUDA call is required.
    Python's built-in ``random`` module is seeded as well so that any
    stdlib-based shuffling is reproducible too.

    Args:
        seed: Seed value applied to every generator. Defaults to 42.
    """
    torch.manual_seed(seed)
    random.seed(seed)

In [14]:
import os
import zipfile
from pathlib import Path
import requests

def get_data(
    source: str,
    destination: str,
    remove_source: bool = True,
    data_dir: "str | os.PathLike[str]" = "data",
    timeout: float = 60.0,
) -> Path:
    """Download a zip archive and extract it into ``<data_dir>/<destination>``.

    The download is skipped when the target directory already exists and is
    non-empty, so re-running the notebook does not fetch the archive again.

    Args:
        source: URL of the zip archive to download.
        destination: Name of the sub-directory of ``data_dir`` to extract into.
        remove_source: Delete the downloaded archive after extraction.
        data_dir: Root data directory. Defaults to ``"data"``, relative to the
            current working directory.
        timeout: Timeout in seconds passed to ``requests.get``. It applies to
            connecting and to each socket read.

    Returns:
        Path to the dataset directory (``data_dir / destination``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    data_path = Path(data_dir)
    image_path = data_path / destination

    # An empty directory (e.g. left behind by an interrupted run) does not count
    # as a downloaded dataset; only skip when something was actually extracted.
    if image_path.is_dir() and any(image_path.iterdir()):
        print("Dataset sudah ada")
        return image_path

    print("Dataset belum ada,mendownload dataset...")
    data_path.mkdir(parents=True, exist_ok=True)

    # Name the archive after the URL *path* so query strings / fragments are not
    # baked into the file name. Fall back to "<destination>.zip" if the path has
    # no final component.
    archive_name = Path(urlparse(source).path).name or f"{destination}.zip"
    archive_path = data_path / archive_name

    # Check the HTTP status before relying on the body. Otherwise an error page
    # would be saved as the "zip" and only fail later inside ZipFile. The body
    # is streamed to disk in chunks instead of being held in memory.
    with requests.get(source, stream=True, timeout=timeout) as res:
        res.raise_for_status()
        with open(archive_path, "wb") as f:
            for chunk in res.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    image_path.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path, "r") as zf:
        print("Mengektrak data...")
        zf.extractall(image_path)

    if remove_source:
        archive_path.unlink()

    return image_path

In [15]:
# Source archive: `pizza_steak_sushi_20_percent.zip` from the
# mrdbourke/pytorch-deep-learning GitHub repository.
DATASET_URL = (
    "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/"
    "pizza_steak_sushi_20_percent.zip"
)

image_path = get_data(source=DATASET_URL, destination="pizza_steak_sushi")

image_path

Dataset belum ada,mendownload dataset...
Mengektrak data...


WindowsPath('data/pizza_steak_sushi')

In [16]:
from torchvision.transforms import v2
from torchvision.transforms.functional import InterpolationMode

manual_transforms = v2.Compose([
    # Convert the input (e.g. a PIL image) into a torchvision Image tensor.
    v2.ToImage(),
    # Resize to an exact 256x256 square (aspect ratio is not preserved),
    # then keep the central 224x224 patch.
    v2.Resize(size=(256, 256), interpolation=InterpolationMode.BICUBIC),
    v2.CenterCrop(size=(224, 224)),
    # Cast to float32; scale=True rescales integer pixel values into [0, 1].
    v2.ToDtype(torch.float32, scale=True),
    # Per-channel normalisation with the commonly used ImageNet mean/std
    # (presumably to match an ImageNet-pretrained backbone — confirm).
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [20]:
from going_modular.data_setup import create_dataloaders

BATCH_SIZE = 32
train_dir, test_dir = image_path / "train", image_path / "test"

# NOTE(review): arguments are passed positionally; presumably the order is
# (train_dir, test_dir, train_transform, test_transform, batch_size) —
# confirm against going_modular.data_setup.create_dataloaders.
# The same deterministic pipeline is used for both splits (no augmentation).
train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir,
    test_dir,
    manual_transforms,
    manual_transforms,
    BATCH_SIZE,
)

train_dataloader, test_dataloader, class_names

(<torch.utils.data.dataloader.DataLoader at 0x1f68a5783d0>,
 <torch.utils.data.dataloader.DataLoader at 0x1f68a57a5c0>,
 ['pizza', 'steak', 'sushi'])

<torch.utils.data.dataloader.DataLoader at 0x1f68a6e64a0>