**This notebook has been used for creating the datasets (the .txt files)**

# FASTAI

In [None]:
!pip install fastai==2.4

# COCO DOWNLOAD

In [None]:
from fastai.data.external import untar_data, URLs

In [None]:
# Fetch the COCO sample through fastai and build the path of the folder holding its .jpg images.
coco_root = untar_data(URLs.COCO_SAMPLE)
coco_path = f"{coco_root}/train_sample"

In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# Seed NumPy's global RNG: the train split (np.random.permutation) and the final
# shuffle (np.random.shuffle) further down both draw from it.
np.random.seed(1234)

In [None]:
# glob returns its matches in arbitrary (filesystem-dependent) order: sort them so
# that the seeded permutation used for the split selects the same images on every machine.
paths = np.array(sorted(glob.glob(coco_path+"/*.jpg")))
num_images = len(paths)
print(f"# images: {num_images}")

# images: 21837


In [None]:
#TRAIN dataset

# TRAIN dataset: a single random ordering of all the images is drawn and both
# splits are prefixes of it, so the small split is contained in the big one.
n_small = 6800
n_big = 14800

idxs = np.random.permutation(num_images)

train_idxs_small, train_idxs_big = idxs[:n_small], idxs[:n_big]
train_paths_small, train_paths_big = paths[train_idxs_small], paths[train_idxs_big]

print(f"Small dataset : {len(train_paths_small)} images,  Big dataset : {len(train_paths_big)} images")

Small dataset : 6800 images,  Big dataset : 14800 images


In [None]:
# Preview: the first 16 images of the big training split on a 4x4 grid
# (zip stops as soon as the 16 axes are used up).
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.ravel(), train_paths_big):
    preview = Image.open(img_path)
    ax.imshow(preview)
    ax.set_axis_off()

# Create the .txt files

In [None]:
# Convert to plain Python lists of str. A comprehension (instead of ndarray.tolist())
# keeps the cell re-runnable: it works whether the variables still hold numpy
# arrays or have already been converted to lists by a previous run.
train_paths_small = [str(p) for p in train_paths_small]
train_paths_big = [str(p) for p in train_paths_big]

# One image path per line, one file per split.
for txt_name, split_paths in (('coco_small_training.txt', train_paths_small),
                              ('coco_big_training.txt', train_paths_big)):
  with open(txt_name, 'w') as f:
    f.writelines(f"{line}\n" for line in split_paths)

# Code for reading the files

In [None]:
# Name of the file holding the small COCO training split (written in the section above).
filename_small = "coco_small_training.txt"

def read_lines(path):

  """
  Read a text file and return its content line by line
  input:
    - path: path of the text file

  output:
    - lines: list of str, one per line, with the trailing whitespace (newline included) removed
  """

  with open(path) as handle:
    return [raw_line.rstrip() for raw_line in handle]

# Reload both COCO training lists from the .txt files and report their sizes.
tr_small_paths = read_lines(filename_small)
tr_big_paths = read_lines("coco_big_training.txt")

for loaded_paths in (tr_small_paths, tr_big_paths):
  print(f"{len(loaded_paths)} images")



# Code to get the path for a validation/test dataset

Run all the cells above first: this section needs `coco_path`, `tr_small_paths` and `tr_big_paths`.

In [None]:
# sorted: glob returns the files in arbitrary order, sorting makes the test lists deterministic
paths = sorted(glob.glob(coco_path+"/*.jpg"))

# A membership test against a list scans the whole list; with thousands of images
# checked against thousands of training paths that is quadratic work.
# Sets give the same answer with a constant-time lookup.
tr_big_set = set(tr_big_paths)
tr_small_set = set(tr_small_paths)

# validation/test images = every image that is not in the corresponding training split
test_big_paths = [path for path in paths if path not in tr_big_set]
test_small_paths = [path for path in paths if path not in tr_small_set]



In [None]:
# Every image must be in exactly one of train/test. If the paths stored in the .txt
# files do not match the paths found on this machine, no image is recognised as a
# training image and the test lists silently contain the whole dataset.
assert len(test_small_paths) + len(tr_small_paths) == len(paths), "small split: paths in the .txt file do not match the images on disk"
assert len(test_big_paths) + len(tr_big_paths) == len(paths), "big split: paths in the .txt file do not match the images on disk"

print(len(test_small_paths), len(test_big_paths))

# Download animals dataset

In [None]:
!pip install -q kaggle
from google.colab import files

You have to upload a file called kaggle.json. To obtain it you need to follow the first 2 steps described in https://www.kaggle.com/general/74235

In [None]:
# The trailing ";" matters: files.upload() returns the uploaded content, and without it
# the content of kaggle.json (the API key) would be echoed and saved in the cell output.
files.upload();

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! kaggle datasets list

In [None]:
!kaggle datasets download -d piyushkumar18/animal-image-classification-dataset

In [None]:
!mkdir /content/animal_data
!unzip -qq /content/animal-image-classification-dataset.zip -d /content/animal_data/

Now we have to split it in two: a training dataset (used to train the classifier or for the adversarial training) and a validation/test dataset. First upload "val_animals.txt", which lists the validation/test images; the next cells use it to split the dataset into two lists of paths.

In [None]:
# Colab upload dialog for "val_animals.txt"; the trailing ";" keeps the returned value out of the cell output.
files.upload();

In [None]:
import os
#set to None to use all the images (14K)
max_img_per_class = 100

path = "/content/animal_data"
animal_path = path + "/Animal Image Dataset"

animals = ["butterfly", "cats", "cow", "dogs", "elephant", "hen", "horse", "monkey", "panda", "sheep", "spider", "squirrel"]

#collect paths validation/test images
with open("/content/val_animals.txt") as file:
    val_paths = [line.rstrip() for line in file]

# set copy of the validation paths: constant-time membership test instead of
# scanning the whole list for every image (val_paths itself stays a list)
val_path_set = set(val_paths)

tr_animal_paths = []

for animal in animals:

  animal_dir = animal_path+"/"+animal
  counter = 0

  # os.listdir returns the entries in arbitrary order: sort them so that the
  # same (first max_img_per_class) images are selected on every run/machine
  folder = sorted(os.listdir(animal_dir))

  for image in folder:

    # never true when max_img_per_class is None -> all the images are used
    if counter == max_img_per_class:

        break

    image_path = animal_dir+"/"+image

    # validation/test images must not end up in the training list
    if image_path not in val_path_set:

      tr_animal_paths.append(image_path)
      counter +=1

print(len(val_paths))
print(len(tr_animal_paths))


Create a single .txt file with all the training images (COCO + animals)

In [None]:
use_small = False

# pick the COCO split and the name of the output file that goes with it
if use_small:
  coco_paths, filename = tr_small_paths, "coco_animals_small_training.txt"
else:
  coco_paths, filename = tr_big_paths, "coco_animals_big_training.txt"

#shuffle (in place, on a numpy copy) so that animal and COCO images are mixed
numpy_paths = np.array(tr_animal_paths + coco_paths)
np.random.shuffle(numpy_paths)
training_paths = numpy_paths.tolist()

# one image path per line
with open(filename, 'w') as f:
  f.writelines(f"{line}\n" for line in training_paths)

In [None]:
# Read back the file that has just been written. `filename` (set in the previous cell)
# is used instead of a hard-coded name, so the check follows the `use_small` flag.
test = read_lines(filename)
print(len(test))

# Dataset and Dataloader

In [None]:
SIZE = 256

# training: bilinear resize to SIZE x SIZE + random horizontal flip
train_transform = transforms.Compose([
    transforms.Resize(size=(SIZE, SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
])

# validation/test: the same resize, no random flip
test_transform = transforms.Compose([
    transforms.Resize(size=(SIZE, SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
])


class FromGrayToColorDataset(Dataset):

  """
  Dataset that loads an image from disk, converts it to the Lab color space and
  returns the lightness channel (L) and the two color channels (ab) separately
  input:
    - paths: sequence with the paths of the images
    - transform: optional callable applied to the PIL image before the conversion
                 (None = the image is used as it is)

  item:
    - (L, ab): L is an H x W tensor (no channel axis), ab is a 2 x H x W tensor
  """

  def __init__(self, paths, transform = None):

    self.size = SIZE
    self.paths = paths
    self.transform = transform

  def __len__(self):

    return len(self.paths)

  def __getitem__(self, idx):

    img_rgb = Image.open(self.paths[idx]).convert("RGB")

    # the transform is optional (None is the default of the constructor)
    if self.transform is not None:
      img_rgb = self.transform(img_rgb)

    img_rgb = np.array(img_rgb)

    #RGB -> Lab
    img_lab = rgb2lab(img_rgb).astype("float32")
    # H x W x 3 array -> 3 x H x W tensor
    img_lab = transforms.ToTensor()(img_lab)

    #to have values in range [-1,1]
    # indexing with 0 drops the channel axis: L is H x W
    L = img_lab[0,:]/50. - 1.
    ab = img_lab[[1,2],:] / 110.

    return (L,ab)


In [None]:
# `train_paths` and `test_paths` were never defined in this notebook (NameError on a
# fresh run). They are taken here from the COCO split selected by `use_small`.
# NOTE(review): confirm this is the intended pairing; use `training_paths` instead
# to train on COCO + animals.
train_paths = tr_small_paths if use_small else tr_big_paths
test_paths = test_small_paths if use_small else test_big_paths

train_dataset = FromGrayToColorDataset(train_paths, train_transform)
test_dataset = FromGrayToColorDataset(test_paths, test_transform)

In [None]:
PIN_MEMORY = True
N_WORKERS = 2
BATCH_SIZE = 32

# settings shared by the two loaders; only the training loader shuffles
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=N_WORKERS, pin_memory=PIN_MEMORY)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

# UTILS

In [None]:
def convert_lab_to_rgb(L, ab):

  """
  Provided a Lab image or a batch of Lab images, it returns it/them in RGB format 
  input:
    - L: torch.tensor, 1 x H x W (one image) or N x 1 x H x W (batch), values in [-1,1].
         The channel axis may be missing (H x W or N x H x W): it is added here
    - ab: torch.tensor, 2 x H x W (one image) or N x 2 x H x W (batch), values in [-1,1]
  
  output:
    - img: numpy.ndarray (the rgb images), H x W x 3 or N x H x W x 3
  """

  # L without channel axis (this is what the dataset returns): insert it,
  # otherwise torch.cat fails because L and ab have a different number of dimensions
  if L.dim() == ab.dim() - 1:
    L = L.unsqueeze(-3)

  #check shape (one image or a batch)

  is_batch = len(ab.shape) > 3
  
  # undo the normalization to [-1,1]
  L = (L+1.)*50.
  ab = ab*110.

  # detach(): tensors that require grad cannot be converted to numpy directly

  if not is_batch:
    # input tensors: 1 x 256 x 256, 2 x 256 x 256
    Lab_image = torch.cat([L, ab], dim=0).permute(1, 2, 0).detach().cpu().numpy()
    return lab2rgb(Lab_image)

  # input tensors: N x 1 x 256 x 256, N x 2 x 256 x 256
  Lab_images = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()

  # images are converted one at a time and stacked back into a batch
  return np.stack([lab2rgb(image) for image in Lab_images], axis=0)

L,ab = train_dataset[0]
L2, ab2 = train_dataset[1]

# one image: the dataset returns L as H x W, the channel axis has to be added
# (L: 1 x H x W, ab: 2 x H x W) otherwise the two tensors cannot be concatenated
img1 = convert_lab_to_rgb(L.unsqueeze(0), ab)
print(img1.shape)

# batch of two images: L is 2 x 1 x H x W, ab is 2 x 2 x H x W
L = torch.stack([L, L2]).unsqueeze(1)
ab = torch.stack([ab, ab2])

img2 = convert_lab_to_rgb(L, ab)
print(img2.shape)

img3 = img2[0]
img4 = img2[1]

(256, 256, 3)
(2, 256, 256, 3)


In [None]:
# the image converted alone (left) next to the two images converted as a batch
fig = plt.figure(figsize=(15, 8))
for position, image in enumerate((img1, img3, img4), start=1):
  ax = plt.subplot(1, 3, position)
  ax.imshow(image)
  ax.axis("off")

In [None]:
# take one batch from the training loader and print the shape of its two tensors (L, ab)
batch = next(iter(train_dataloader))
for tensor in batch:
  print(tensor.shape)

torch.Size([32, 256, 256])
torch.Size([32, 2, 256, 256])


In [None]:
def show_batch(Ls, abs, n_cols = 4):

  """
  provided a batch of images, visualize them
  input:
    - Ls: batch with L for each image, N x 1 x 256 x 256 tensor
    - abs: batch with ab for each image, N x 2 x 256 x 256 tensor
    - n_cols: number of images per row
  """
  batch_size = Ls.shape[0]
  # ceil division: the last row may be incomplete. With a floor division the
  # images that do not fill a whole row would not be shown at all
  num_rows = (batch_size + n_cols - 1)//n_cols

  rgb_images = convert_lab_to_rgb(Ls, abs)

  
  fig = plt.figure(figsize=(10, 20))

  # row by row, left to right
  for i, image in enumerate(rgb_images):

    ax = plt.subplot(num_rows, n_cols, i + 1)
    ax.imshow(image)
    ax.axis("off")
    
  plt.show()


# the loader yields L as N x 256 x 256: add the channel axis expected by show_batch
show_batch(Ls=batch[0].unsqueeze(1), abs=batch[1])

In [None]:
def show_results(Ls, real_abs, fake_abs):

  """
  visualize a batch as a grid with one column per image and three rows:
  gray image (L only), real image, fake image
  input:
    - Ls: batch with L for each image, N x 1 x 256 x 256 tensor
    - real_abs: batch with ab for each real image, N x 2 x 256 x 256 tensor
    - fake_abs: batch with ab for each fake image, N x 2 x 256 x 256 tensor
  """

  n_images = Ls.shape[0]

  real_rgb = convert_lab_to_rgb(Ls, real_abs)
  fake_rgb = convert_lab_to_rgb(Ls, fake_abs)

  plt.figure(figsize=(15, 15))

  for col in range(n_images):

    # (image, extra imshow arguments) for the three rows of this column
    panels = (
      (Ls[col][0].cpu(), {"cmap": 'gray'}),
      (real_rgb[col], {}),
      (fake_rgb[col], {}),
    )

    for row, (image, imshow_kwargs) in enumerate(panels):

      axis = plt.subplot(3, n_images, row*n_images + col + 1)
      axis.imshow(image, **imshow_kwargs)
      axis.axis("off")

# demo on the first 4 images of the batch: the real ab channels are passed as "fake" too
show_results(Ls=batch[0].unsqueeze(1)[:4], real_abs=batch[1][:4], fake_abs=batch[1][:4])