# Importations

In [17]:
import sys, os, multiprocessing, csv
from PIL import Image
from io import BytesIO
from urllib.request import urlopen
import tqdm
from tqdm import tnrange
from tqdm.contrib.concurrent import process_map
import numpy as np

from matplotlib import pyplot as plt

from skimage import data

import PIL.Image as IMG

from imageio import imread
import glob

import torch
from torchvision.utils import save_image

# Définition de quelques constantes

In [18]:
#Choix de l'intervalle du nombre d'images par label autorisé
debut_lab = 99
fin_lab = 106
# On prend cet intervalle pour "choisir" le nombre de données et par exemple comparer les dataset entre eux
# avec autant d'images dans le premier que le deuxième par exemple

#Size des photos après le reshape
size = 60

#dossier de sortie des images
folder = "trans100/"

# Fonctions de parsing et de téléchargement 

In [19]:
def ParseData(data_file):
  csvfile = open(data_file, 'r')
  csvreader = csv.reader(csvfile)
  dataset = [line for line in csvreader]
  return dataset[1:]  # Chop off header


def DownloadImage(data):
  (key, url, label) = data

  try:
    response = urlopen(url)
    image_data = response.read()
    #print("read")
  except:
    #print('Warning: Could not download image %s from %s' % (key, url))
    return

  try:
    pil_image = Image.open(BytesIO(image_data))
    #print("parsed")
    return pil_image
  except:
    #print('Warning: Failed to parse image %s' % key)
    return

def Create_labels (data_file):
    dataset_url = ParseData(data_file)
    dataset = []
    for data in tqdm.tqdm(dataset_url[:len(dataset_url)], total=len(dataset_url)) :
        (key, url, label) = data 
        if label != "None" :
            dataset.append(int(label))
    return dataset

# Sélection des labels

In [20]:
list_labels = Create_labels("train.csv")
num_labels = np.histogram(list_labels, bins=range(max(list_labels)+2))[0]

#fichier contenant tous les urls des images
dataset_url = ParseData("train.csv")

#selection des labels/images
image_per_label = [0]*(15000)
for data in tqdm.tqdm(dataset_url[:len(dataset_url)], total=len(dataset_url)):
    (key, url, label) = data
    if (label != "None") and (num_labels[int(label)]>debut_lab and num_labels[int(label)]<fin_lab) and image_per_label[int(label)]<fin_lab :
        image_per_label[int(label)]+=1

100%|██████████| 1225029/1225029 [00:00<00:00, 2332194.52it/s]
100%|██████████| 1225029/1225029 [00:01<00:00, 931339.99it/s]


# Téléchargement effectif des données

In [21]:
#Récupération des transformations du TD2, contenues dans preprocessing.py
%run couches

In [22]:
#Rescale des images aux tailles souhaitées
def rescale_reshape(img, size):
    img_t = to_float32(img)
    img_t = rescale(img,size, size)
    return img_t

#Calcul des différentes transformations
def ajout_transfo(img, high1=0.5, low1=0.1, high2=0.2, low2=0.05) :
    r,g,b = rgb(img) 
    bandw = rgb_to_bandw(img)
    vis_grad_g, vis_grad_b = vis_grad(bandw)
    vis_hessian_g, vis_hessian_b = vis_hessian(bandw)
    return [r,g,b,
          vis_grad_g, vis_grad_b,
          canny_edge_detection(bandw, high1, low1),
          canny_edge_detection(bandw, high2, low2),
          vis_hessian_g, vis_hessian_b]

#Téléchargement de l'image `t`
def create_and_register(t):
    i,data = t
    (key, url, label) = data
    if (label != "None") and (num_labels[int(label)]>debut_lab and num_labels[int(label)]<fin_lab) :
        pil_image = DownloadImage(data)
        if pil_image!= None :
            pil_image = np.array(pil_image)
            if len(pil_image.shape) < 3 :
                return
            #image_per_label[int(label)]+=1
            pil_image = rescale_reshape(pil_image, size)
            pil_image_li = ajout_transfo(pil_image)
            couche = 0
            for img in pil_image_li :
                img = img.reshape((size, size))
                plt.imsave(folder + str(i) + 'l' + str(label) + 'c' + str(couche) +".png", img, cmap="Greys")
                couche += 1
            
#Lancement du téléchargement des images en multiprocessing
def CreateDataset(data_file, num_labels):
    arg = [(i,dataset_url[i]) for i in range(len(dataset_url))]
    with multiprocessing.Pool() as p :
        list(tqdm.tqdm(p.imap(create_and_register, arg), total=len(dataset_url)))

In [23]:
CreateDataset("train.csv", num_labels)

100%|██████████| 1225029/1225029 [2:50:53<00:00, 119.47it/s] 
