In [None]:
import urllib.request
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [None]:
import torch
import torchvision.transforms as T

In [None]:
# Fetch the sample image only if it is not already on disk, so Restart & Run All
# does not re-download it. urllib follows GitHub's raw-content redirect (as
# `curl -L` did) and does not require a curl binary.
IMG_URL = 'https://github.com/pytorch/vision/raw/main/gallery/assets/astronaut.jpg'
IMG_PATH = Path('astronaut.jpg')
if not IMG_PATH.exists():
  urllib.request.urlretrieve(IMG_URL, IMG_PATH)

In [None]:
# Sanity check that the download cell produced the image. This is pure Python,
# so it also works where an `ls` binary is unavailable.
assert Path('astronaut.jpg').is_file(), 'astronaut.jpg not found -- run the download cell first'
sorted(p.name for p in Path('.').iterdir())

In [None]:
plt.rcParams['savefig.bbox'] = 'tight'
# Image.open is lazy and keeps the file handle open; load the pixels inside a
# `with` block so the handle is closed, and keep an in-memory RGB copy.
with Image.open('astronaut.jpg') as img_file:
  orig_img = img_file.convert('RGB')
# Trailing `;` suppresses the Generator repr that manual_seed returns.
torch.manual_seed(0);

In [None]:
def plot(imgs, with_orig=True, row_title=None, orig=None, **imshow_kwargs):
  """Display images in a grid, optionally with a reference image in column 0.

  Args:
    imgs: Either a flat sequence of images (rendered as a single row) or a
      list of lists (one grid row per inner list). Each image is passed
      through ``np.asarray`` and then ``Axes.imshow``.
    with_orig: If truthy, prepend ``orig`` as the first column of every row
      and title the top-left panel 'Original image'.
    row_title: Optional sequence of y-axis labels indexed by row; it must
      have at least one entry per row.
    orig: Image shown in the first column when ``with_orig`` is set.
      Defaults to the notebook-level ``orig_img`` (the previous behaviour).
    **imshow_kwargs: Extra keyword arguments forwarded to every ``imshow``.

  Raises:
    ValueError: If ``imgs`` is empty.
  """
  if len(imgs) == 0:
    raise ValueError('plot() needs at least one image')
  if not isinstance(imgs[0], list):
    imgs = [imgs]
  if with_orig and orig is None:
    # Backward-compatible fallback to the global loaded earlier in the notebook.
    orig = orig_img

  num_rows = len(imgs)
  # Size the grid by the longest row so ragged rows never index past `axs`.
  num_cols = max(len(row) for row in imgs) + with_orig
  _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
  for row_idx, row in enumerate(imgs):
    # Unpacking (instead of `[orig] + row`) also accepts tuple rows.
    row = [orig, *row] if with_orig else list(row)
    for col_idx in range(num_cols):
      ax = axs[row_idx, col_idx]
      if col_idx >= len(row):
        ax.axis('off')  # padding cell in a row shorter than the longest one
        continue
      ax.imshow(np.asarray(row[col_idx]), **imshow_kwargs)
      ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

  if with_orig:
    axs[0, 0].set(title='Original image')
    axs[0, 0].title.set_size(8)
  if row_title is not None:
    for row_idx in range(num_rows):
      axs[row_idx, 0].set(ylabel=row_title[row_idx])

  plt.tight_layout()


In [None]:
# RandAugment demo: call one RandAugment transform on the original image four
# separate times (each call presumably samples its own ops -- see torchvision
# docs) and plot the results beside the original.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _sample in range(4)]
plot(imgs)