In [1]:
%matplotlib inline


Transforms
===================

Data does not always come in the final processed form required for
training machine learning algorithms. We use **transforms** to manipulate
the data and make it suitable for training.

All TorchVision datasets have two parameters, ``transform`` to modify the features and
``target_transform`` to modify the labels, that accept callables containing the transformation logic.
The [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html) module offers
several commonly used transforms out of the box.
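
To make the role of the two callables concrete, here is a minimal sketch of how a dataset might apply them when a sample is fetched. `TinyDataset` is a hypothetical stand-in written for illustration, not torchvision's actual implementation, and the sample data is made up.

```python
# Hypothetical stand-in for a dataset class; not torchvision's real code.
class TinyDataset:
    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples  # list of (feature, label) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        feature, label = self.samples[idx]
        # Each callable is applied only if the caller supplied one.
        if self.transform is not None:
            feature = self.transform(feature)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return feature, label


tiny_ds = TinyDataset(
    samples=[([0, 128, 255], 2)],  # made-up pixel values and label
    transform=lambda pixels: [p / 255 for p in pixels],  # scale to [0, 1]
    target_transform=lambda y: [1.0 if i == y else 0.0 for i in range(3)],
)
feature, label = tiny_ds[0]
print(feature[0], feature[2])  # 0.0 1.0
print(label)  # [0.0, 0.0, 1.0]
```

The point of the sketch is that the transforms run lazily, once per sample access, rather than once over the whole dataset up front.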

The FashionMNIST features are in PIL Image format, and the labels are integers.
For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors.
To make these transformations, we use ``ToTensor`` and ``Lambda``.

Each transform is broken down and explained in the sections below.

In [3]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(), # can also take multiple transforms using `Compose`!
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

## ToTensor()

[ToTensor](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor)
converts a PIL image or NumPy ``ndarray`` into a ``FloatTensor`` and scales
the image's pixel intensity values to the range [0., 1.].




## Lambda Transforms (Custom Transforms!)

Lambda transforms apply any user-defined lambda function. Here, we define a function
to turn the integer label into a one-hot encoded tensor.
It first creates a zero tensor of size 10 (the number of labels in our dataset) and then calls
[scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html), which assigns ``value=1`` at the index given by the label ``y``.



In [4]:
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
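
To see what this produces, the following sketch runs the same scatter logic on an example label (3, chosen arbitrarily) and cross-checks it against `torch.nn.functional.one_hot`. The helper name `one_hot_scatter` is introduced here for illustration only.

```python
import torch
import torch.nn.functional as F

num_classes = 10  # the number of labels in the dataset


def one_hot_scatter(y):
    # Same logic as the Lambda body: start from zeros, write a 1 at index y.
    return torch.zeros(num_classes, dtype=torch.float).scatter_(
        dim=0, index=torch.tensor(y), value=1
    )


encoded = one_hot_scatter(3)
print(encoded)  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

# Cross-check: F.one_hot returns an int64 tensor, so cast it to float first.
reference = F.one_hot(torch.tensor(3), num_classes=num_classes).float()
print(torch.equal(encoded, reference))  # True
```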

--------------




## Exercise
Let's try creating a custom transformation for the Kuzushiji-MNIST dataset, abbreviated as [KMNIST](https://pytorch.org/vision/stable/generated/torchvision.datasets.KMNIST.html?highlight=kmnist#torchvision.datasets.KMNIST).

You will be applying not one, but **2** transformations using `Compose`:
1. [CenterCrop](https://pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html?highlight=centercrop#torchvision.transforms.CenterCrop) with a size of 20
2. `ToTensor()`

The target integer value should also be one-hot encoded.

Hint: The [Compose](https://pytorch.org/vision/stable/generated/torchvision.transforms.Compose.html#torchvision.transforms.Compose) class will come in handy! Let's get used to reading the documentation!


In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, CenterCrop, Compose

##### TODO #######
# transform_sequential = [] # fill in this list!
# KMNIST_transforms = # Hint: use `Compose`.
# target_onehot = # Hint: hmm.. We have seen it just now!

# Hint: follow the FashionMNIST example above, but it's a different dataset!
# ds = 

In [5]:
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, CenterCrop, Compose


transform_sequential = [CenterCrop(20), ToTensor()]
KMNIST_transforms = Compose(transform_sequential)

ds = datasets.KMNIST(
    root="data",
    train=True,
    download=True,
    transform=KMNIST_transforms,
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz to data/KMNIST/raw/train-images-idx3-ubyte.gz


Extracting data/KMNIST/raw/train-images-idx3-ubyte.gz to data/KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz to data/KMNIST/raw/train-labels-idx1-ubyte.gz


Extracting data/KMNIST/raw/train-labels-idx1-ubyte.gz to data/KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz to data/KMNIST/raw/t10k-images-idx3-ubyte.gz


Extracting data/KMNIST/raw/t10k-images-idx3-ubyte.gz to data/KMNIST/raw

Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz
Downloading http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz to data/KMNIST/raw/t10k-labels-idx1-ubyte.gz


Extracting data/KMNIST/raw/t10k-labels-idx1-ubyte.gz to data/KMNIST/raw



### Further Reading
- [torchvision.transforms API](https://pytorch.org/vision/stable/transforms.html)

