# PyTorch Basics - Samplers

By [Akshaj Verma](https://akshajverma.com)

This notebook takes you through an implementation of `random_split`, `SubsetRandomSampler`, and `WeightedRandomSampler` on [Natural Images](https://www.kaggle.com/prasunroy/natural-images) data using PyTorch.

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt


import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils, datasets
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler

In [2]:
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f6b5f43f0f0>

In [3]:
%matplotlib inline
sns.set_style('darkgrid')

## Define Data Path

Set the root directory for the dataset.

In [8]:
root_dir = "../../data/computer_vision/image_classification/natural-images/"
print("The data lies here =>", root_dir)

The data lies here => ../../data/computer_vision/image_classification/natural-images/


## Define Transforms

Resize the images to `(224, 224)` and convert them to tensors.

In [5]:
image_transforms = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
}

## Initialize Dataset

Using `ImageFolder`, we will create our dataset. We'll only use the train folder for this blog post.

In [6]:
natural_img_dataset = datasets.ImageFolder(root = root_dir,
                                      transform = image_transforms["train"]
                                     )

natural_img_dataset

FileNotFoundError: [Errno 2] No such file or directory: '../../data/computer_vision/image_classification/natural-images/'

## Class <=> ID Mapping

The `.class_to_idx` attribute holds the mapping from class names to the integer IDs used in the dataset.

In [None]:
natural_img_dataset.class_to_idx

We will create a dictionary called `idx2class`, which is the reverse of the `class_to_idx` attribute.

In [None]:
idx2class = {v: k for k, v in natural_img_dataset.class_to_idx.items()}
idx2class

## Observe Class Distribution

To observe the distribution of different classes in a dataset object, we create a function called `get_class_distribution()`. This function takes a dataset as an input argument and returns a dictionary which contains the count of all classes in the dataset object.

1. To do this, we first initialize our `count_dict` where all the class counts are 0.
2. Then we iterate over our dataset object to extract the class labels. The dataset object contains elements in the form of a tuple (x,y). So, we need to extract the item at position 1 from the tuple.
3. Then we use the `idx2class` to get the class name from the class id.
4. Finally, we update the count in our `count_dict` by 1 for the relevant class-key.

In [None]:
def get_class_distribution(dataset_obj):
    count_dict = {k:0 for k,v in dataset_obj.class_to_idx.items()}
    
    for element in dataset_obj:
        y_lbl = element[1]
        y_lbl = idx2class[y_lbl]
        count_dict[y_lbl] += 1
            
    return count_dict

In [None]:
print("Distribution of classes: \n", get_class_distribution(natural_img_dataset))
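
Iterating over the dataset object loads and transforms every image just to read its label. A lighter sketch, assuming the dataset exposes its integer labels as a list (as `ImageFolder` does through `.targets`, which this notebook uses later), counts the labels directly. The helper name and the `toy_*` values below are made-up stand-ins so the snippet runs without the image folder.

```python
from collections import Counter

def get_class_distribution_from_targets(targets, class_to_idx):
    """Count samples per class name from a list of integer labels."""
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    counts = Counter(targets)
    # Keep classes with zero samples in the result, in index order.
    return {idx_to_class[idx]: counts.get(idx, 0) for idx in sorted(idx_to_class)}

# Hypothetical stand-ins for natural_img_dataset.class_to_idx and .targets.
toy_class_to_idx = {"cat": 0, "dog": 1, "flower": 2}
toy_targets = [0, 0, 1, 2, 2, 2]

toy_distribution = get_class_distribution_from_targets(toy_targets, toy_class_to_idx)
print(toy_distribution)
```

With the real dataset the call would be `get_class_distribution_from_targets(natural_img_dataset.targets, natural_img_dataset.class_to_idx)`.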

To plot our dictionary, we use the Seaborn library. We first convert our dictionary to a dataframe and then melt it. Finally, we use the function `sns.barplot()` to construct our plot.

In [None]:
plt.figure(figsize=(15,8))
sns.barplot(data = pd.DataFrame.from_dict([get_class_distribution(natural_img_dataset)]).melt(), x = "variable", y="value", hue="variable").set_title('Natural Images Class Distribution')

From the above graph, we observe that the classes are imbalanced.

## `random_split()`

`random_split(dataset, lengths)` works directly on the dataset. The function expects 2 input arguments. The first argument is the dataset. The second is a tuple of lengths. If we want to split our dataset into 2 parts, we will provide a tuple with 2 numbers. These numbers are the sizes of the corresponding datasets after the split. 

Our dataset has 6899 images. If we want to split this into 2 parts (*train/test, train/val*) of size (6000, 899), we will call random split as `random_split(natural_img_dataset, (6000, 899))`. The lengths must add up to the size of the dataset.

Let's split our dataset into train and val sets.

In [None]:
train_dataset, val_dataset = random_split(natural_img_dataset, (6000, 899))
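
Hard-coded lengths only work for one dataset size. A small sketch of deriving them from a validation fraction instead, so that they always sum to the dataset size; the helper name `split_lengths` and the 13% fraction are illustrative choices, not part of the notebook.

```python
def split_lengths(dataset_size, val_fraction):
    """Return (train_len, val_len) that always sum to dataset_size."""
    val_len = int(dataset_size * val_fraction)
    train_len = dataset_size - val_len  # the remainder goes to train
    return train_len, val_len

train_len, val_len = split_lengths(6899, 0.13)
print(train_len, val_len)
```

The result can be passed straight through, e.g. `random_split(natural_img_dataset, split_lengths(len(natural_img_dataset), 0.13))`. `random_split` also accepts a `generator` argument (a seeded `torch.Generator`) if the split should be reproducible independently of the global seed.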

Pass data to the dataloader.

In [None]:
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=1)

In [None]:
print("Length of the train_loader:", len(train_loader))
print("Length of the val_loader:", len(val_loader))

Note that we have used a `batch_size = 1`. If we increase the `batch_size`, the number of images stays the same, but the length of the train/val loaders (the number of batches) shrinks.
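
The relationship between batch size and loader length can be checked with plain arithmetic. With the default `drop_last=False`, a `DataLoader` over a dataset of known size yields `ceil(n_samples / batch_size)` batches; the helper below is only an illustration of that formula, not a PyTorch function.

```python
import math

def expected_num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a sized dataset."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

print(expected_num_batches(6000, 1))   # 6000
print(expected_num_batches(6000, 8))   # 750
print(expected_num_batches(899, 8))    # 113 (the last batch holds 3 images)
```

Keep in mind that `get_class_distribution_loaders`, defined below, calls `.item()` on the label tensor, so it only works with `batch_size=1`.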

Let's take a look at the distribution of classes in the train and val loaders.

In [None]:
def get_class_distribution_loaders(dataloader_obj, dataset_obj):
    count_dict = {k:0 for k,v in dataset_obj.class_to_idx.items()}
    
    for _,j in dataloader_obj:
        y_idx = j.item()
        y_lbl = idx2class[y_idx]
        count_dict[str(y_lbl)] += 1
            
    return count_dict

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(18,7))

sns.barplot(data = pd.DataFrame.from_dict([get_class_distribution_loaders(train_loader, natural_img_dataset)]).melt(), x = "variable", y="value", hue="variable",  ax=axes[0]).set_title('Train Set')
sns.barplot(data = pd.DataFrame.from_dict([get_class_distribution_loaders(val_loader, natural_img_dataset)]).melt(), x = "variable", y="value", hue="variable",  ax=axes[1]).set_title('Val Set')

## `SubsetRandomSampler()`

`SubsetRandomSampler(indices)` takes a list of dataset indices and draws from them in random order, without replacement.

We first create our samplers and then pass them to our dataloaders.

1. Create a list of indices. 
2. Shuffle the indices. 
3. Split the indices based on train-val percentage.
4. Create `SubsetRandomSampler`.

Create a list of indices from 0 to length of dataset.

In [None]:
dataset_size = len(natural_img_dataset)
dataset_indices = list(range(dataset_size))

Shuffle the list of indices using `np.random.shuffle`.

In [None]:
np.random.shuffle(dataset_indices)

Create the split index. We choose the split index to be 20% (0.2) of the dataset size.

In [None]:
val_split_index = int(np.floor(0.2 * dataset_size))

Slice the list to obtain 2 lists of indices, one for train and the other for val.

> `0`-----------`val_split_index`------------------------------`n`.

Train => `val_split_index` to `n`


Val => `0` to `val_split_index`

In [None]:
train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]
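
The slicing step can be sanity-checked in isolation. The sketch below repeats it on a toy index list of length 10 (an arbitrary size chosen for illustration), using Python's `random` module in place of `np.random.shuffle`, and the checks that follow confirm the two index lists are disjoint and together cover the dataset.

```python
import random

toy_size = 10
toy_indices = list(range(toy_size))

rng = random.Random(0)       # seeded so the shuffle is repeatable
rng.shuffle(toy_indices)

toy_split = int(0.2 * toy_size)          # 20% of the data goes to validation
toy_val_idx = toy_indices[:toy_split]    # first 20% of the shuffled indices
toy_train_idx = toy_indices[toy_split:]  # remaining 80%

# No index may appear in both lists, and none may be lost.
no_overlap = set(toy_val_idx).isdisjoint(toy_train_idx)
full_cover = set(toy_val_idx) | set(toy_train_idx) == set(range(toy_size))
print("val:", toy_val_idx)
print("train:", toy_train_idx)
print("disjoint:", no_overlap, "| covers dataset:", full_cover)
```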

Finally, create samplers.

In [None]:
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

Now, we will pass the samplers to our dataloader. Note that `shuffle=True` cannot be used together with a sampler such as `SubsetRandomSampler`: the two options are mutually exclusive, and the sampler already randomizes the order.

In [None]:
train_loader = DataLoader(dataset=natural_img_dataset, shuffle=False, batch_size=1, sampler=train_sampler)
val_loader = DataLoader(dataset=natural_img_dataset, shuffle=False, batch_size=1, sampler=val_sampler)

Now, we'll plot the class distribution in our dataloaders.

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(18,7))

sns.barplot(data = pd.DataFrame.from_dict([get_class_distribution_loaders(train_loader, natural_img_dataset)]).melt(), x = "variable", y="value", hue="variable",  ax=axes[0]).set_title('Train Set')
sns.barplot(data = pd.DataFrame.from_dict([get_class_distribution_loaders(val_loader, natural_img_dataset)]).melt(), x = "variable", y="value", hue="variable",  ax=axes[1]).set_title('Val Set')

As we can observe, the number of samples per class in the validation set is roughly proportional to the number in the train set. The split is random, so the proportions match only approximately.
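
If per-class proportions need to match more closely than a random split provides, one option is a stratified split: shuffle the indices of each class separately and take the same fraction from each. This is a plain-Python sketch, not something the notebook does; `stratified_split` is a hypothetical helper and the targets are toy data.

```python
import random
from collections import defaultdict

def stratified_split(targets, val_fraction, seed=0):
    """Split indices so each class sends about val_fraction of its samples to val."""
    rng = random.Random(seed)
    indices_by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        indices_by_class[label].append(idx)

    train_idx, val_idx = [], []
    for indices in indices_by_class.values():
        rng.shuffle(indices)
        n_val = int(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return train_idx, val_idx

toy_targets = [0] * 10 + [1] * 40   # imbalanced toy labels
strat_train_idx, strat_val_idx = stratified_split(toy_targets, val_fraction=0.2)
print(len(strat_train_idx), len(strat_val_idx))
```

The resulting lists can be passed to `SubsetRandomSampler` in the same way as `train_idx` and `val_idx` above.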

## `WeightedRandomSampler()`

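Unlike `random_split` and `SubsetRandomSampler`, `WeightedRandomSampler` draws each sample with a probability proportional to a weight we assign to it. With suitable weights, each batch contains a roughly equal number of samples from every class, even when the dataset is imbalanced.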


1. Get the target class of every sample, in dataset order.
2. Count the number of samples in each class.
3. Get the class weights. Class weights are the reciprocal of the number of items per class.
4. Obtain the corresponding weight for each target sample.
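
Why the reciprocal is a sensible choice can be checked with plain arithmetic: if every sample of a class gets weight `1 / class_count`, the weights of each class sum to 1, so every class is equally likely to be drawn. The sketch below uses exact fractions and the same 2-versus-9 toy counts as the list example that follows; the `toy_*` names are illustrative.

```python
from collections import Counter
from fractions import Fraction

toy_targets = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
toy_counts = Counter(toy_targets)   # class 0 appears twice, class 1 nine times

# Per-sample weight = reciprocal of the size of the sample's class.
toy_sample_weights = [Fraction(1, toy_counts[t]) for t in toy_targets]

# Probability of drawing each class = its share of the total weight.
total_weight = sum(toy_sample_weights)
class_probability = {
    c: sum(w for w, t in zip(toy_sample_weights, toy_targets) if t == c) / total_weight
    for c in toy_counts
}
print(class_probability)
```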

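First, an example with lists. The targets are shuffled here only so that the two classes are interleaved, as they would be in an arbitrary dataset; the shuffled list then plays the role of the dataset.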

In [7]:
classes = [0, 1]
print(f"Classes = {classes}")

target = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
print(f"\nTargets = {target}")

shuffled_targets = torch.tensor(target)[torch.randperm(len(target))]
print(f"\nShuffled Targets = {shuffled_targets}")

idx_of_class_0 = [idx for idx, element in enumerate(shuffled_targets) if element==0]
idx_of_class_1 = [idx for idx, element in enumerate(shuffled_targets) if element==1]

print(f"\nIndices for class-0 = {idx_of_class_0}")
print(f"Indices for class-1 = {idx_of_class_1}")

class_count = torch.tensor([len(idx_of_class_0), len(idx_of_class_1)])
class_weights = 1/class_count
print(f"\nWeights for each class = {class_weights}")

weights_for_each_sample = class_weights[shuffled_targets]
print(f"\nAssign class-weights to each sample = {weights_for_each_sample}")

weighted_sampler = WeightedRandomSampler(
    weights=weights_for_each_sample,
    num_samples=len(weights_for_each_sample),
    replacement=True
)
print(f"\nWeighted Random Sampler   = {list(weighted_sampler)}")
print(f"Output classes in sampler = {[0 if i in idx_of_class_0 else 1 for i in list(weighted_sampler)]}")

Classes = [0, 1]

Targets = [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Shuffled Targets = tensor([1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1])

Indices for class-0 = [1, 7]
Indices for class-1 = [0, 2, 3, 4, 5, 6, 8, 9, 10]

Weights for each class = tensor([0.5000, 0.1111])

Assign class-weights to each sample = tensor([0.1111, 0.5000, 0.1111, 0.1111, 0.1111, 0.1111, 0.1111, 0.5000, 0.1111,
        0.1111, 0.1111])

Weighted Random Sampler   = [7, 1, 2, 6, 1, 4, 1, 2, 7, 5, 4]
Output classes in sampler = [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0]


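Note that the last two output lines come from separate draws: the cell calls `list(weighted_sampler)` once per `print`, and each call re-samples, so the classes in the second line do not correspond to the indices in the first. Both lines still show class 0 being drawn far more often than its 2-in-11 share of the data.

Now, let's go back to the dataset example.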

Obtain the list of target classes. The targets must stay in dataset order, because the sampler uses the weight at position `i` for sample `i`; shuffling them here would assign weights to the wrong images.

In [None]:
target_list = torch.tensor(natural_img_dataset.targets)

Get the class counts and compute the weight of each class as the reciprocal of its count.

In [None]:
class_count = [i for i in get_class_distribution(natural_img_dataset).values()]
class_weights = 1./torch.tensor(class_count, dtype=torch.float) 
class_weights

Assign the weight of each class to all the samples. 

In [None]:
class_weights_all = class_weights[target_list]
class_weights_all

Pass the per-sample weights and the number of samples to draw to the `WeightedRandomSampler`.

In [None]:
weighted_sampler = WeightedRandomSampler(
    weights=class_weights_all,
    num_samples=len(class_weights_all),
    replacement=True
)
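
To get a feel for what the sampler does to class balance, the sketch below imitates it in plain Python. `random.choices` draws indices with replacement in proportion to the given weights, and is used here as a stand-in for `WeightedRandomSampler(..., replacement=True)`; the 100-versus-900 toy labels are made up.

```python
import random
from collections import Counter

toy_targets = [0] * 100 + [1] * 900               # heavily imbalanced toy labels
toy_counts = Counter(toy_targets)
toy_weights = [1.0 / toy_counts[t] for t in toy_targets]

rng = random.Random(0)
# Draw as many indices as there are samples, with replacement.
drawn_indices = rng.choices(range(len(toy_targets)), weights=toy_weights, k=len(toy_targets))
drawn_counts = Counter(toy_targets[i] for i in drawn_indices)

print("original:", dict(toy_counts))
print("sampled: ", dict(drawn_counts))
```

Both classes should end up close to 500 draws. Because sampling is with replacement, minority-class samples are repeated within an epoch and some majority-class samples are never drawn.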

Pass the sampler to the dataloader.

In [None]:
train_loader = DataLoader(dataset=natural_img_dataset, shuffle=False, batch_size=8, sampler=weighted_sampler)

In [None]:
len(train_loader)

In [None]:
batch_count_dict = {k:[] for k,v in natural_img_dataset.class_to_idx.items()}
batch_count_dict

In [None]:
for _, batch in train_loader:
    temp_batch_count_dict = {k:0 for k,v in natural_img_dataset.class_to_idx.items()}
    
    for item in batch:
        op_id = item.item()
        op_class = idx2class[op_id]
        temp_batch_count_dict[op_class] += 1
        
    for k, v in temp_batch_count_dict.items():
        batch_count_dict[k].append(v)

In [None]:
plt.figure(figsize=(15,8))


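# sns.distplot is deprecated; kdeplot draws the same density curve as distplot(hist=False).
for c in batch_count_dict.keys():
    sns.kdeplot(batch_count_dict[c], label=c)
plt.legend()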
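
What the curves in this plot should look like can be estimated in advance. If the sampler balances the classes perfectly, the number of times a given class appears in a batch follows a binomial distribution with `batch_size` trials and success probability `1 / n_classes`. The sketch below computes that distribution; `n_classes=8` is only an example value (substitute `len(natural_img_dataset.classes)`), and the helper name is hypothetical.

```python
import math

def per_batch_count_pmf(batch_size, n_classes):
    """P(a given class appears k times in a batch) under perfectly balanced sampling."""
    p = 1.0 / n_classes
    return [
        math.comb(batch_size, k) * p**k * (1 - p) ** (batch_size - k)
        for k in range(batch_size + 1)
    ]

pmf = per_batch_count_pmf(batch_size=8, n_classes=8)   # example values
expected_count = sum(k * prob for k, prob in enumerate(pmf))
print(f"expected count per class per batch: {expected_count:.2f}")
for k, prob in enumerate(pmf[:4]):
    print(f"P(count = {k}) = {prob:.3f}")
```

If the sampler is doing its job, the per-class curves should look alike and be centred near `batch_size / n_classes`.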