<a target="_blank" href="https://colab.research.google.com/github/vchabaux/recherche_images_gallica/blob/main/recherche_images_gallica.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
# Mount Google Drive at /content/gdrive; the corpus, index and temporary
# folders configured below all live under gdrive/MyDrive.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [1]:
try:import gradio as gr
except:
  %pip install gradio
  import gradio as gr

try:import xmltodict
except:
  %pip install xmltodict
  import xmltodict

try:from lavis.models import load_model_and_preprocess
except:
  %pip install salesforce-lavis
  from lavis.models import load_model_and_preprocess

try:from transformers import MarianMTModel, MarianTokenizer
except:
  %pip install transformers
  from transformers import MarianMTModel, MarianTokenizer

try:from sentence_transformers import SentenceTransformer, util
except:
  %pip install sentence-transformers
  from sentence_transformers import SentenceTransformer, util

import os
import numpy as np
import pandas as pd
from PIL import Image
import csv
from math import dist
import urllib.request, urllib.error, urllib.parse
from urllib.error import HTTPError, URLError
import requests
import torch
from torch import tensor
import time

# Run the models on the first CUDA GPU when one is available, otherwise on the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Image corpus folder (relative path; created on demand by search_ark_fn)
DIR_BASE = "gdrive/MyDrive/recherche_images_gallica/IMG_CORPUS"
# Folder of the embedding CSV files (relative path; created on demand by search_ark_fn)
DIR_INDEX = "gdrive/MyDrive/recherche_images_gallica/IMG_INDEXS"
# Temporary folder used during collection (relative path; created on demand by search_ark_fn)
DIR_TMP = "gdrive/MyDrive/recherche_images_gallica/IMG_TMP"

# Initialise BLIP-2 (LAVIS framework): feature-extractor model plus its image and text pre-processors
model, vis_processors, txt_processors = load_model_and_preprocess(name="blip2_feature_extractor", model_type="coco", device=device)

# Initialise the encoder-decoder transformer used for French -> English translation (Opus-MT, based on Marian-NMT)
model_name = 'Helsinki-NLP/opus-mt-fr-en'
tokenizer = MarianTokenizer.from_pretrained(model_name)
marian = MarianMTModel.from_pretrained(model_name)
# Move the translation model to the same device as BLIP-2
marian = marian.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 1.89G/1.89G [00:29<00:00, 67.6MB/s]


Position interpolate from 16x16 to 26x26


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

100%|██████████| 4.37G/4.37G [01:27<00:00, 53.6MB/s]


tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/778k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/301M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

In [6]:

import torch.nn.functional as F

# Retourne le nombre de pages du document
def nombre_pages(ark):
  """Return the number of image views (pages) of a Gallica document.

  Calls the Gallica Pagination service for the given ark and reads the
  ``livre > structure > nbVueImages`` element of the XML answer.

  Args:
    ark: ark identifier of the document (str).

  Returns:
    The number of pages as an int.

  Raises:
    requests.exceptions.RequestException: on network failure, timeout or
      HTTP error status.
    KeyError: if the answer does not have the expected structure.
  """
  PAGINATION_BASEURL = 'https://gallica.bnf.fr/services/Pagination?ark='
  url = PAGINATION_BASEURL + ark
  # A timeout prevents the whole collection from hanging on a stalled connection.
  response = requests.get(url, timeout=60)
  # Fail with an explicit HTTP error instead of an obscure XML parsing error.
  response.raise_for_status()
  paginationdic = xmltodict.parse(response.text)
  return int(paginationdic["livre"]["structure"]["nbVueImages"])


# Calcule la distance entre le milieu-bas du rectangle supérieur (image) et le milieu-haut du rectangle inférieur (légende)
def rect_distance(rect1, rect2):
    """Distance between the bottom-centre of rect1 and the top-centre of rect2.

    Both rectangles are 4-item sequences ``(left, top, right, bottom)``.
    rect1 is meant to be the upper rectangle (image) and rect2 the lower
    one (caption).

    Returns:
        The Euclidean distance between the two anchor points (float).
    """
    left_img, _top_img, right_img, bottom_img = rect1
    left_txt, top_txt, right_txt, _bottom_txt = rect2

    # Anchor points: middle of the facing horizontal edges
    bottom_center_img = ((left_img + right_img) / 2, bottom_img)
    top_center_txt = ((left_txt + right_txt) / 2, top_txt)

    return dist(top_center_txt, bottom_center_img)

def translate_legend(row):
    """Translate the French caption of a corpus row to English.

    Args:
        row: mapping (DataFrame row) with a "legend" entry.

    Returns:
        The English translation, or "" when the caption is missing
        (None / NaN) or blank.
    """
    legend = row["legend"]

    # pd.isna catches every NaN flavour (np.nan, float("nan"), numpy floats);
    # an identity test against np.nan misses the NaN objects pandas builds
    # for an all-empty column, which then crashed on .strip().
    is_missing = legend is None or (not isinstance(legend, str) and pd.isna(legend))
    text = "" if is_missing else str(legend)

    if text.strip() == "":
        print(legend, "")
        return ""

    # truncation=True keeps very long captions within the model's input limit
    inputs = tokenizer.encode(text, return_tensors="pt", truncation=True).to(device)
    outputs = marian.generate(inputs, num_beams=4, max_length=200, early_stopping=True)
    translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(legend, translated_text)
    return translated_text

# Retourne les embeddings
def get_embedding(row):
    """Compute the BLIP-2 embeddings of one corpus row.

    Args:
        row: mapping (DataFrame row) with "img_path" (path of the image file)
            and "en_legend" (English caption, np.nan when absent).

    Returns:
        A list ``[text_emb, image_emb, mean_emb]``: the L2-normalised text
        embedding, the L2-normalised image embedding and the element-wise
        mean of the two (each of size (1, 256) per the original notes).
    """
    def _normalise(vector):
        # In-place L2 normalisation along the last dimension
        vector /= vector.norm(dim=-1, keepdim=True)
        return vector

    img_path = row["img_path"]
    caption = "" if row["en_legend"] is np.nan else row["en_legend"]

    # Load the picture and run the BLIP-2 image pre-processing (batch of one)
    picture = Image.open(img_path).convert("RGB")
    picture_batch = vis_processors["eval"](picture).unsqueeze(0).to(device)

    # Pre-process the (English) caption
    caption_input = txt_processors["eval"](caption)

    sample = {"image": picture_batch, "text_input": caption_input}

    # Caption embedding: first token of the projected text features
    text_features = model.extract_features(sample, mode="text")
    text_vector = _normalise(text_features.text_embeds_proj[:, 0, :])

    # Image embedding: first query of the projected image features
    image_features = model.extract_features(sample, mode="image")
    image_vector = _normalise(image_features.image_embeds_proj[:, 0, :])

    # Average of the image and caption embeddings
    mean_vector = torch.stack([image_vector, text_vector]).mean(dim=0)

    print("Embedding :",img_path, "\t", caption)

    return [text_vector, image_vector, mean_vector]


# Global variable (sorry) holding the currently selected embeddings CSV (corpus);
# stays None until load_csv has been called
base = None


# Charger le csv d'embedding selectionné dans la liste déroulante
def load_csv(corpus_select):
    """Load the embeddings CSV chosen in the dropdown into the global ``base``.

    Args:
        corpus_select: file name of a CSV located in DIR_INDEX. An empty
            selection (None, "", 0) is ignored and leaves ``base`` untouched.

    The "embedding" column holds the textual repr of a list of tensors; it
    is rebuilt with ``eval`` and every tensor is moved to the current
    ``device``.
    """
    global base

    # The dropdown can fire a change event with no valid selection
    if not corpus_select:
        return

    print(corpus_select)
    loaded = pd.read_csv(DIR_INDEX+"/"+corpus_select)

    def _parse_embedding(cell):
        # SECURITY: eval executes whatever the CSV contains; only load index
        # files produced by this notebook.
        tensors = eval(cell.replace("'cuda:0'", "device"))
        # Tensors saved from a CPU session carry no device tag: move them
        # explicitly so they match the device of the query embedding.
        return [emb.to(device) for emb in tensors]

    loaded['embedding'] = loaded['embedding'].apply(_parse_embedding)

    # Publish the corpus only once it has been fully converted
    base = loaded


# Lance la collecte et l'embedding d'une liste d'arks sous forme de string("a,b,c") avec un nom de collection coll_name
def _as_list(value):
    """Normalise an xmltodict value (missing, single element or list) to a list."""
    if value is None:
        return []
    return value if isinstance(value, list) else [value]


def _block_rect(block):
    """Return the (left, top, right, bottom) rectangle of an ALTO block."""
    left, top = int(block["@HPOS"]), int(block["@VPOS"])
    return (left, top, left + int(block["@WIDTH"]), top + int(block["@HEIGHT"]))


def _text_block_entry(text_block):
    """Return (rectangle, text) for an ALTO TextBlock, joining its words with spaces."""
    words = []
    for text_line in _as_list(text_block.get("TextLine")):
        for string in _as_list(text_line.get("String")):
            word = string.get("@CONTENT")
            # Strings without a CONTENT attribute would break the join
            if word is not None:
                words.append(word)
    return (_block_rect(text_block), " ".join(words))


def _parse_alto_page(altodic):
    """Return (images, texts) of an ALTO page as lists of (rectangle, payload) tuples."""
    print_space = altodic["alto"]["Layout"]["Page"].get("PrintSpace") or {}

    # Blocks placed directly in the print space
    texts = [_text_block_entry(tb) for tb in _as_list(print_space.get("TextBlock"))]
    images = [(_block_rect(il), il) for il in _as_list(print_space.get("Illustration"))]

    # Blocks nested inside composed blocks
    for composed in _as_list(print_space.get("ComposedBlock")):
        texts.extend(_text_block_entry(tb) for tb in _as_list(composed.get("TextBlock")))
        images.extend((_block_rect(il), il) for il in _as_list(composed.get("Illustration")))

    return images, texts


def _fetch_alto(alto_url, max_attempts=3, wait_seconds=15):
    """Download the ALTO XML of a page; return its text, or None after max_attempts failures."""
    for attempt in range(1, max_attempts + 1):
        try:
            return requests.get(alto_url, timeout=60).text
        except requests.exceptions.RequestException:
            print("Echec de collecte de l'alto avec l'url :", alto_url, "Echecs:", attempt)
            if attempt == max_attempts:
                print("Top d'echecs, page is skiped")
                return None
            print("Nouvel essai dans", wait_seconds, "secondes")
            time.sleep(wait_seconds)
    return None


def _download_image(url, dest_path, max_attempts=3, wait_seconds=10):
    """Download url to dest_path with retries; return True on success, False otherwise."""
    for attempt in range(1, max_attempts + 1):
        try:
            urllib.request.urlretrieve(url, dest_path)
            return True
        except (HTTPError, URLError) as erreur:
            print(str(erreur.reason))
            if attempt < max_attempts:
                print("wait", wait_seconds, "seconds")
                time.sleep(wait_seconds)
    # Do not leave a partial file behind
    if os.path.exists(dest_path):
        os.remove(dest_path)
    return False


def _find_legend(img_rect, texts, max_distance=100):
    """Join the texts whose top-centre is closer than max_distance to the image bottom-centre."""
    return " ".join(text for rect, text in texts if rect_distance(img_rect, rect) < max_distance)


# Lance la collecte et l'embedding d'une liste d'arks sous forme de string("a,b,c") avec un nom de collection coll_name
def search_ark_fn(search_ark, coll_name):
    """Collect the illustrations of Gallica documents and index them.

    For every ark, each page's ALTO file is parsed to locate illustrations
    and nearby text blocks (used as captions). Illustrations are downloaded
    through the IIIF API into DIR_BASE/<coll_name>, captions are translated
    to English, BLIP-2 embeddings are computed and the result is saved as
    DIR_INDEX/<coll_name>.csv.

    Args:
        search_ark: comma-separated ark identifiers ("a,b,c").
        coll_name: name of the new collection (must not exist yet).

    Returns:
        A status message (str) meant for the UI console.
    """
    # Parse the ark list, ignoring blanks around and between the commas
    arks = [ark.strip() for ark in (search_ark or "").split(",") if ark.strip()]
    if not arks:
        return "Erreur : aucun identifiant ark fourni !"

    # The collection name becomes a folder and a file name: refuse path components
    coll_name = (coll_name or "").strip()
    if not coll_name or coll_name in (".", "..") or os.path.basename(coll_name) != coll_name:
        return "Erreur : ce nom de collection exite déjà ou n'est pas valide !"

    # Create the working folders if needed
    os.makedirs(DIR_BASE, exist_ok=True)
    os.makedirs(DIR_TMP, exist_ok=True)
    os.makedirs(DIR_INDEX, exist_ok=True)

    # The collection folder must be new
    try:
        os.makedirs(DIR_BASE+"/"+coll_name)
    except (OSError, ValueError):
        return "Erreur : ce nom de collection exite déjà ou n'est pas valide !"

    tmp_csv_path = DIR_TMP+"/"+coll_name+".csv"
    collected = 0

    # "w" truncates any leftover file from a previous run
    with open(tmp_csv_path, "w", encoding='utf-8', newline="") as tmp_file:
        csv_writer = csv.writer(tmp_file)
        for ark in arks:
            pages = nombre_pages(ark)
            print("\nARK :", ark)

            for page in range(1, pages+1):
                alto_url = 'https://gallica.bnf.fr/RequestDigitalElement?O={}&E=ALTO&Deb={}'.format(ark, page)

                # Request the page's ALTO; after repeated failures only this page is skipped
                alto_text = _fetch_alto(alto_url)
                if alto_text is None:
                    continue

                # An answer that is not XML means the document has no OCR: skip the document
                try:
                    altodic = xmltodict.parse(alto_text)
                except Exception:
                    print(ark, "Document non océrisé. Skiped")
                    break

                # Collect the illustrations and text blocks of the page
                print("==========", "Page", page,"==========")
                try:
                    images, texts = _parse_alto_page(altodic)
                except (KeyError, TypeError, AttributeError, ValueError) as erreur:
                    print("ALTO inattendu, page is skiped :", alto_url, repr(erreur))
                    continue

                # Download every illustration found on the page, with its caption
                for i, (img_rect, illustration) in enumerate(images):
                    url = "https://gallica.bnf.fr/iiif/ark:/12148/{}/f{}/{},{},{},{}/{}/0/native.jpg".format(ark,page,illustration["@HPOS"],illustration["@VPOS"],illustration["@WIDTH"],illustration["@HEIGHT"],"full")
                    nomfichier = ark+"_"+str(page)+"_"+illustration["@ID"]+".jpg"
                    cheminout = DIR_BASE+"/"+coll_name+"/"+nomfichier

                    # An image that could not be downloaded must not be indexed:
                    # the embedding step would fail on the missing file.
                    if not _download_image(url, cheminout):
                        print("Echec du téléchargement, image ignorée :", url)
                        continue

                    # Caption = text blocks located right below the image (if any)
                    txt_legned = _find_legend(img_rect, texts)
                    print("Image :", i, "| Description :",txt_legned)

                    csv_writer.writerow([cheminout, txt_legned])
                    collected += 1

    if collected == 0:
        return "Aucune image collectée : aucun index n'a été créé."

    # Load the image <> caption list into a pandas DataFrame
    index_df = pd.read_csv(tmp_csv_path, names=["img_path", "legend"], encoding="utf-8")

    # Translate the captions to English (for BLIP-2)
    index_df["en_legend"] = index_df.apply(translate_legend, axis = 1)

    # Compute the BLIP-2 embeddings of the corpus
    index_df["embedding"] = index_df.apply(get_embedding, axis = 1)

    # Save the embeddings CSV
    index_path = DIR_INDEX+"/"+coll_name+".csv"
    index_df.to_csv(index_path, index=False, encoding="utf-8")

    return "Collecte terminée : {} image(s) indexée(s) dans {}".format(collected, index_path)


# Met à jour la liste des corpus disponibles (fichiers csv du dossier d'index) dans la liste déroulante
def update_corpus():
    """Refresh the corpus dropdown with the CSV files found in DIR_INDEX.

    Returns:
        A ``gr.update`` carrying the sorted list of ``.csv`` file names;
        the list is empty when the index folder does not exist yet (no
        collection has been run).
    """
    # os.listdir would raise before the first collection created the folder
    if not os.path.isdir(DIR_INDEX):
        return gr.update(choices=[])

    corpus = sorted(
        csv_file for csv_file in os.listdir(DIR_INDEX)
        if csv_file.lower().endswith('.csv')
    )
    return gr.update(choices=corpus)


# Deactivate UI
def deactivate():
    """Return nine ``gr.update(interactive=False)`` values, one per UI control to lock."""
    locked_controls = 9  # number of components listed as outputs in the UI wiring
    return tuple(gr.update(interactive=False) for _ in range(locked_controls))


# Activate UI
def activate():
    """Return nine ``gr.update(interactive=True)`` values, one per UI control to unlock."""
    unlocked_controls = 9  # number of components listed as outputs in the UI wiring
    return tuple(gr.update(interactive=True) for _ in range(unlocked_controls))


# Lancer la recherche d'images similaires depuis l'UI
def search(input_search_txt, input_search_type, img_count):
    """Search the loaded corpus for the images most similar to a French text query.

    The query is translated to English, embedded with BLIP-2 and compared
    (cosine similarity) with the stored embeddings of every corpus row.

    Args:
        input_search_txt: query text (French).
        input_search_type: "Texte & Image" (average of the mean, text and
            image similarities), "Image" (image embedding only); any other
            value selects the text embedding only.
        img_count: number of images to return.

    Returns:
        A list of ``(PIL image, "similarity - legend")`` tuples for the
        gallery, best match first. Empty when no corpus is loaded.
    """
    # Nothing to search until a corpus has been selected
    if base is None or base.empty:
        print("Aucun corpus chargé : recherche ignorée")
        return []

    # Translate, tokenize and embed the query text
    inputs = tokenizer.encode(input_search_txt, return_tensors="pt").to(device)
    outputs = marian.generate(inputs, num_beams=4, max_length=50, early_stopping=True)
    input_search_tk = tokenizer.decode(outputs[0], skip_special_tokens=True)
    input_search = txt_processors["eval"](input_search_tk)
    input_search_sample = {"image": None, "text_input": [input_search]}
    input_search_embedding = model.extract_features(input_search_sample, mode="text").text_embeds_proj[:,0,:] # size (1, 256)
    input_search_embedding /= input_search_embedding.norm(dim=-1, keepdim=True)

    # Map the UI choice to the similarity mode
    if input_search_type == "Texte & Image":
        sim_mode = "total_mean"
    elif input_search_type == "Image":
        sim_mode = "image_sim"
    else:
        sim_mode = "text_sim"

    def cosine(embedding):
        # Plain float (not a 1-element array) so sorting and display are clean
        return float(util.cos_sim(embedding, input_search_embedding).item())

    def has_legend(value):
        # pd.isna handles every NaN flavour, unlike an identity test on np.nan
        if value is None or (not isinstance(value, str) and pd.isna(value)):
            return False
        return value != ""

    def text_similarity(row):
        # Rows without caption get the lowest possible score
        return cosine(row["embedding"][0]) if has_legend(row["en_legend"]) else -1.0

    def sim_calc(row):
        if sim_mode == "text_sim":
            return text_similarity(row)
        if sim_mode == "image_sim":
            return cosine(row["embedding"][1])
        sim_mean = cosine(row["embedding"][2])
        sim_image = cosine(row["embedding"][1])
        return (sim_mean + text_similarity(row) + sim_image) / 3

    # Similarity of every row for the chosen mode: Text<>Text, Text<>Image or the average
    base['sim'] = base.apply(sim_calc, axis = 1)

    # Build the gallery content from the best matches
    images = []
    best_rows = base.sort_values("sim", ascending=False).head(int(img_count))
    for _, row in best_rows.iterrows():
        # Older index files store paths relative to the project folder, newer
        # ones already contain it: try the prefixed location first.
        prefixed_path = "gdrive/MyDrive/recherche_images_gallica/"+row["img_path"]
        img_path = prefixed_path if os.path.exists(prefixed_path) else row["img_path"]
        try:
            image = Image.open(img_path).convert("RGB")
        except OSError as erreur:
            # A missing or unreadable file must not abort the whole search
            print("Image illisible, ignorée :", img_path, erreur)
            continue
        images.append((image, str(row["sim"]) +" - "+str(row["legend"])))

    return images


In [None]:
# WebUI by Gradio
with gr.Blocks() as demo:
    # "Recherche" tab: pick a corpus and query it
    with gr.Tab("Recherche"):
        with gr.Row():
            # No corpus is selected until the list has been refreshed
            corpus_select = gr.Dropdown([], label="Corpus", info="Choisissez un corpus", value=None)
            update_btn = gr.Button("Rafraîchir la liste des corpus")
        with gr.Row():
            # The default must be one of the choices; an index such as 0 is
            # not a valid value and made the search silently fall back to "Texte"
            search_type = gr.Dropdown(["Texte & Image", "Image", "Texte"], label="Type de recherche", info="Choisissez un type de recherche", value="Texte & Image", interactive = False)
            search_txt = gr.Textbox(label="Recherche", info="Texte pour la recherche d'images", interactive = False)
        with gr.Row():
            img_count = gr.Slider(1, 50, value=5, label="Nombre d'images", info="Choisissez le nombre d'images à rechercher", step=1, interactive = False)
        with gr.Row():
            search_btn = gr.Button("Rechercher des images", interactive = False)
        with gr.Row():
            gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="fill", height="auto", interactive = False)
    # "Collecte" tab: collect and index a new corpus from a list of arks
    with gr.Tab("Collecte"):
        coll_name = gr.Textbox(label="Nom de collection", info="Entrez le nom de votre collection", interactive = True)
        search_ark = gr.Textbox(label="Arks", info="Identifiants Arks à collecter", interactive = True)
        search_ark_btn = gr.Button("Lancer la collecte et les traitements", interactive = True)
        output_ark = gr.Textbox(label="Console", info="Logs de la collecte", interactive = False)

    # The nine controls locked while a handler runs and unlocked afterwards;
    # deactivate() and activate() return exactly one update per entry.
    controls = [corpus_select, update_btn, search_type, search_txt, img_count, search_btn, coll_name, search_ark, search_ark_btn]

    # Every event follows the same chain: lock the UI -> do the work -> unlock the UI
    search_btn.click(fn=deactivate, outputs=controls).then(search, inputs=[search_txt,search_type, img_count], outputs=[gallery]).then(fn=activate, outputs=controls)
    update_btn.click(fn=deactivate, outputs=controls).then(update_corpus, outputs=[corpus_select]).then(fn=activate, outputs=controls)
    search_ark_btn.click(fn=deactivate, outputs=controls).then(search_ark_fn, inputs=[search_ark, coll_name], outputs=[output_ark]).then(fn=activate, outputs=controls)
    corpus_select.change(fn=deactivate, outputs=controls).then(load_csv, inputs=[corpus_select], show_progress=True).then(fn=activate, outputs=controls)

# debug=True keeps the cell running so that errors and logs stay visible in the notebook
demo.launch(debug=True)



Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://cd0d7b1cd2e4df0859.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


manuels_histoire.csv
manuels_histoire.csv
