In [15]:
import random
from typing import List

import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as T
import torchvision.transforms.functional as TF

# torchvision.transforms

- torchvision.transforms.ToPILImage()
- torchvision.transforms.PILToTensor()
- torchvision.transforms.ToTensor()
- torchvision.transforms.Lambda()
- torchvision.transforms.RandomRotation()
- torchvision.transforms.functional.rotate()

In [16]:
# Build the FashionMNIST train and test splits from the local copy under `root`
# (download is disabled). Both splits share the same configuration and differ
# only in the `train` flag. Images go through T.ToTensor(); each label `y` is
# turned into a length-10 float vector that is zero everywhere except for a 1
# at index `y`.
training_data, test_data = (
    datasets.FashionMNIST(
        root="~/machinelearning/datasets/pytorch_data/",
        train=is_train_split,
        download=False,
        transform=T.ToTensor(),
        target_transform=T.Lambda(
            lambda y: torch.zeros(10, dtype=torch.float).scatter_(
                0, torch.tensor(y), value=1
            )
        ),
    )
    for is_train_split in (True, False)
)

# Functional transforms

In [17]:
def segmentation_transforms(image, segmentation):
    """Randomly rotate an image and its segmentation by the same angle.

    With probability of roughly one half the pair is returned untouched.
    Otherwise a single integer angle is drawn from [-30, 30] and both inputs
    are passed through ``TF.rotate`` with that angle, so the two stay aligned.

    Args:
        image: The image to transform; forwarded to ``TF.rotate``.
        segmentation: The matching segmentation; forwarded to ``TF.rotate``.

    Returns:
        A ``(image, segmentation)`` tuple, either both rotated or both unchanged.
    """
    # Guard clause: skip the augmentation for this sample.
    if random.random() <= 0.5:
        return image, segmentation

    # One shared draw keeps the image and its segmentation in register.
    degrees = random.randint(-30, 30)
    rotated_image = TF.rotate(image, degrees)
    rotated_segmentation = TF.rotate(segmentation, degrees)
    return rotated_image, rotated_segmentation

In [18]:
class RotationTransform:
    """
    Rotate an input by an angle picked at random from a fixed set.

    Each call draws one angle from ``angles`` with ``random.choice`` and
    passes the input together with that angle to ``TF.rotate``.

    Args:
        angles: Candidate rotation angles, forwarded unchanged to
            ``TF.rotate``. Must contain at least one value.

    Raises:
        ValueError: If ``angles`` is empty.
    """
    def __init__(self, angles: List[float]):
        # Copy the input so that later mutation of the caller's list cannot
        # silently change which angles this transform may pick.
        self.angles = list(angles)
        # Fail at construction time with a clear message instead of letting
        # random.choice raise an opaque IndexError on the first call.
        if not self.angles:
            raise ValueError("angles must contain at least one rotation angle")

    def __call__(self, x):
        """Return ``x`` rotated by one randomly chosen angle from ``self.angles``."""
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

# Ready-to-use instance: each call rotates by one of -30, -15, 0, 15 or 30.
rotation_transform = RotationTransform(angles=[-30, -15, 0, 15, 30])

# Scriptable transforms

- torch.nn.Sequential
- torch.jit.script

In [19]:
# Chain the transforms as an nn.Sequential module (centre crop to size 10,
# then per-channel normalisation) and compile the chain with torch.jit.script.
transforms = nn.Sequential(
    T.CenterCrop(size=10),
    T.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
)
scripted_transforms = torch.jit.script(transforms)

# Compositions of transforms

- torchvision.transforms.Compose

In [22]:
# Compose pipeline: centre crop to size 10, convert the PIL image to a tensor,
# then convert the tensor's dtype to torch.float.
# NOTE: this rebinds the name `transforms` used for the nn.Sequential above.
transforms = T.Compose(
    transforms=[
        T.CenterCrop(size=10),
        T.PILToTensor(),
        T.ConvertImageDtype(dtype=torch.float),
    ]
)

# Conversion Transforms

- ToPILImage
- ToTensor
- PILToTensor

## Generic Transforms

- Lambda