# PyTorch Basics - Samplers

By [Akshaj Verma](https://akshajverma.com)

This notebook walks through using `random_split`, `SubsetRandomSampler`, and `WeightedRandomSampler` in PyTorch on the [Natural Images](https://www.kaggle.com/prasunroy/natural-images) dataset.
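As a quick preview of the three utilities named above, here is a minimal sketch on a synthetic, deliberately imbalanced dataset. The toy tensors, the 80/20 split sizes, and the inverse-class-frequency weighting are illustrative assumptions, not necessarily the settings used on Natural Images.

```python
import torch
from torch.utils.data import (
    DataLoader,
    SubsetRandomSampler,
    TensorDataset,
    WeightedRandomSampler,
    random_split,
)

# Hypothetical toy data: 100 samples, 90 of class 0 and 10 of class 1.
features = torch.randn(100, 4)
labels = torch.cat([
    torch.zeros(90, dtype=torch.long),
    torch.ones(10, dtype=torch.long),
])
toy_dataset = TensorDataset(features, labels)

# random_split: partition the dataset into non-overlapping Subset objects.
generator = torch.Generator().manual_seed(0)
train_subset, val_subset = random_split(toy_dataset, [80, 20], generator=generator)

# SubsetRandomSampler: keep one dataset, but draw (without replacement)
# only from a fixed list of indices, in random order.
indices = torch.randperm(len(toy_dataset), generator=generator).tolist()
train_sampler = SubsetRandomSampler(indices[:80])
val_sampler = SubsetRandomSampler(indices[80:])
train_loader = DataLoader(toy_dataset, batch_size=16, sampler=train_sampler)

# WeightedRandomSampler: one weight per sample. Weighting each sample by the
# inverse of its class frequency (an assumed scheme) makes the rare class
# as likely to be drawn as the common one.
class_counts = torch.bincount(labels)
sample_weights = (1.0 / class_counts.float())[labels]
weighted_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)
weighted_loader = DataLoader(toy_dataset, batch_size=16, sampler=weighted_sampler)

# The subset sampler visits exactly the 80 indices it was given.
n_train_seen = sum(len(batch_labels) for _, batch_labels in train_loader)
print(len(train_subset), len(val_subset), n_train_seen)
```

Note that a `sampler` passed to `DataLoader` is mutually exclusive with `shuffle=True`; the sampler itself decides the order in which indices are drawn.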

In [9]:
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt


import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils, datasets
from torch.utils.data import (
    Dataset,
    DataLoader,
    random_split,
    SubsetRandomSampler,
    WeightedRandomSampler,
)

In [10]:
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f454d64c5f0>

In [11]:
%matplotlib inline
sns.set_style('darkgrid')

## Define Data Path

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("We're using =>", device)

root_dir = "../../data/computer_vision/image_classification/natural-images/"
print("The data lies here =>", root_dir)

We're using => cpu
The data lies here => ../../data/computer_vision/image_classification/natural-images/


## Define Transforms

In [22]:
image_transforms = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]),
    "test": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
}

## Initialize Dataset

In [26]:
natural_img_dataset = datasets.ImageFolder(
    root=root_dir,
    transform=image_transforms["train"],
)

natural_img_dataset

Dataset ImageFolder
    Number of datapoints: 6899
    Root location: ../../data/computer_vision/image_classification/natural-images/
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )

## Class <=> ID Mapping

In [27]:
natural_img_dataset.class_to_idx

{'airplane': 0,
 'car': 1,
 'cat': 2,
 'dog': 3,
 'flower': 4,
 'fruit': 5,
 'motorbike': 6,
 'person': 7}