# Run the cell below to use CLIP

In [None]:
!pip install ftfy regex tqdm
!git clone https://github.com/openai/CLIP.git
%cd CLIP

# CLIP input : WC folder of sub-images
## The code to produce the CLIP statistics (clip_stats.csv) is provided at the bottom of this notebook

In [3]:
import os
import sys
import torch
import clip         # need to go to the CLIP repo (this is done in the previous cell)
import numpy as np
import matplotlib.pyplot as plt
sys.path.append('..')
from img_bbox import quickview
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the ViT-B/32 CLIP model and the `preprocess` callable that predict_class
# applies to each image before it is passed to the model.
model, preprocess = clip.load("ViT-B/32", device=device)

# Text prompts scored against every sub-image by predict_class.
# Order matters: the first six entries (indices 0-5) are the fruit prompts and the
# remaining ones describe background; the evaluation cells rely on this split
# through the hard-coded `< 6`, `[:6]` and `[6:]`.
prompt_list = ["a picture of an orange in a tree",
               "a picture of an unripe orange in a tree",
               "a picture of a part of an orange in a tree",
               "a picture of a part of an unripe orange in a tree",
               "a picture of a lemon in a tree",
               "a picture of an unripe lemon in a tree",
               # from here on, the labels are not related to fruits but background
               "a picture of leaves and tree branches without any ripe or unripe fruit",
               "a picture of leaves and tree branches",
               "a picture of parts of leaves and tree branches",
               "a picture of parts of leaves and tree branches without any ripe or unripe fruit",
               "a picture of a building",
               "a picture of a part of a building",
               "a picture of a part of the sky",
               "a picture of dead leaves on the ground without any ripe or unripe fruit",
               "a picture of dead leaves on the ground",
               "a picture of the roots of a tree",
               "a picture of the roots of a tree without any ripe or unripe fruit"]

def predict_class(img_name):
    """Score one image against every prompt in `prompt_list` with CLIP.

    Args:
        img_name: Path of the image file to classify.

    Returns:
        A NumPy array with one row (the single input image) and one column per
        prompt, holding the softmax of the image-to-text logits.
    """
    # Open the file in a context manager so the handle is released as soon as
    # the image has been converted into a tensor.
    with Image.open(img_name) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    text = clip.tokenize(prompt_list).to(device)
    with torch.no_grad():
        # One forward pass is enough: the former separate encode_image /
        # encode_text calls produced features that were never used.
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return probs

# ========================= the next five functions were used to inspect and edit each sub-image and its bounding boxes
 
def get_lines(label_path:str):
    """Return all lines of the file at `label_path`, line endings included.

    An empty file yields an empty list.
    """
    with open(label_path, 'r') as label_file:
        return label_file.readlines()

def remove_bboxes(line_number:int, label_path:str):
    """Rewrite the label file without the line at index `line_number`.

    Args:
        line_number: Zero-based index of the line to drop.
        label_path: Path of the label file, one line per bounding box.

    If the file is empty, a message is printed and the file is left untouched.
    An index that matches no line leaves the content unchanged.
    """
    current = get_lines(label_path)
    if not current:
        print("No bounding boxes to remove")
        return
    kept = [entry for idx, entry in enumerate(current) if idx != line_number]
    with open(label_path, 'w') as label_file:
        label_file.writelines(kept)

def keep_only_bboxes(line_number:int, label_path:str):
    """Rewrite the label file so that only the line at index `line_number` remains.

    Args:
        line_number: Zero-based index of the line to keep.
        label_path: Path of the label file, one line per bounding box.

    If the file is empty, a message is printed and the file is left untouched.
    An index that matches no line leaves the file empty.
    """
    current = get_lines(label_path)
    if not current:
        print("No bounding boxes to keep")
        return
    selected = [entry for idx, entry in enumerate(current) if idx == line_number]
    with open(label_path, 'w') as label_file:
        label_file.writelines(selected)

def delete_img_label(path, filename):
    """Delete an image and its label file from a dataset folder.

    The image is removed from `<path>/images/<filename>` and the label from
    `<path>/labels/<stem>.txt`, where `<stem>` is `filename` without its
    extension.

    Args:
        path: Dataset root containing the `images` and `labels` sub-folders;
            a trailing slash is optional.
        filename: Name of the image file (for example a .jpg, .jpeg or .png).

    Raises:
        OSError: Propagated from os.remove when either file cannot be removed
            (for example FileNotFoundError).
    """
    # Swap only the extension: a blanket replace("jpg", "txt") also rewrites
    # "jpg" inside the stem and leaves .jpeg / .png names unchanged.
    label_name = os.path.splitext(filename)[0] + ".txt"
    image_path = os.path.join(path, "images", filename)
    label_path = os.path.join(path, "labels", label_name)
    os.remove(image_path)
    os.remove(label_path)

def remove_all_bboxes(label_path:str):
    """Empty the label file at `label_path`, creating it if it does not exist."""
    # Opening in write mode truncates the file; nothing is written afterwards.
    with open(label_path, "w"):
        pass

# Example : check on afternoon sunny

In [None]:
folder_orange = "../FINAL_SUBIMAGES/AS_CLIP_AFT_SUN/images/"
mean_list, not_mean_list = [], []
argmax_list, not_argmax_list = [], []
for filename in os.listdir(folder_orange):
    # Only image files are scored; any other entry in the folder is skipped.
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        continue
    pred_class = predict_class(folder_orange + filename)
    # Argmax rule: the most probable prompt is one of the first six (fruit) prompts.
    (argmax_list if np.argmax(pred_class) < 6 else not_argmax_list).append(filename)
    # Mean rule: the fruit prompts have a higher mean probability than the others.
    fruit_wins = np.mean(pred_class[0][:6]) > np.mean(pred_class[0][6:])
    (mean_list if fruit_wins else not_mean_list).append(filename)

# Images on which both rules agree (accepted by both / rejected by both).
both_list = list(set(mean_list) & set(argmax_list))
not_both_list = list(set(not_mean_list) & set(not_argmax_list))

print(f"Argmax accuracy : {len(argmax_list) / (len(argmax_list) + len(not_argmax_list))}")
print(f"Mean accuracy : {len(mean_list) / (len(mean_list) + len(not_mean_list))}")
print(f"Both accuracy : {len(both_list) / (len(both_list) + len(not_both_list))} with {len(both_list) + len(not_both_list)} chosen images")

In [106]:
# Scratch cell of the manual review loop: moves the review cursor `ind` back by one.
# NOTE(review): `ind` is not initialised anywhere in the visible notebook, so this
# cell raises NameError on a fresh kernel — confirm where `ind` is meant to be set.
# delete_img_label("data/PUT_TO_CLIP/", tmp)
# ind-=5
ind-=1

In [118]:
# Drop the first line (index 0) of the label file referenced by `tmp_text`,
# then move the review cursor `ind` back by one.
# NOTE(review): depends on `tmp_text` and `ind` already existing in the kernel
# (they are assigned in other cells, e.g. the viewing cell), so this cell only
# works after those have been run.
# NOTE(review): `tmp_text` is built with the same "jpg" -> "txt" replacement
# elsewhere, so the replace below is presumably a no-op — confirm.
remove_bboxes(0, tmp_text.replace("jpg", "txt"))
ind -=1

In [201]:
# Empty the label file referenced by `tmp_text`, then move the review cursor
# `ind` back by one.
# NOTE(review): depends on `tmp_text` and `ind` already existing in the kernel
# (assigned in other cells), so this cell fails on a fresh kernel.
remove_all_bboxes(tmp_text)
ind-=1

In [None]:
# Viewing cell of the manual review loop: pick the file at position `ind` of
# `not_argmax_list`, pass its image and label paths to quickview, then advance `ind`.
# NOTE(review): `ind` must already exist in the kernel; it is not initialised in
# the visible notebook.
tmp = not_argmax_list[ind]
print(ind)
# Label path: the image file name under the labels folder with "jpg" replaced by "txt".
# NOTE(review): the image is read from `folder_orange` while the label is read from
# "data/CLIP_AFT_SUN/labels/" — confirm both refer to the same dataset copy.
tmp_text = ("data/CLIP_AFT_SUN/labels/"+tmp).replace("jpg", "txt")
# quickview comes from img_bbox; presumably it displays the image with its boxes — confirm.
quickview(folder_orange+tmp, tmp_text)
ind += 1

### Update of the sub-images classified as false negatives (FN) by the argmax rule (the mean rule was already checked)

In [None]:
# Delete every sub-image of `not_argmax_list` whose label file has no line left.
initial_length = len(not_argmax_list)
nb_deleted = 0
for filename in not_argmax_list:
    tmp_text = "data/CLIP_AFT_SUN/labels/" + filename.replace("jpg", "txt")
    if get_lines(tmp_text):
        # At least one bounding box remains: keep the image and its label.
        continue
    delete_img_label("data/CLIP_AFT_SUN/", filename)
    nb_deleted += 1
# `not_argmax_list` itself is not modified, so the new size is derived from the count.
final_length = initial_length - nb_deleted
print(f"Deleted {nb_deleted} images, {initial_length} -> {final_length}")

In [None]:
folder_orange = "data/CLIP_AFT_SUN/images/"
mean_list, not_mean_list = [], []
argmax_list, not_argmax_list = [], []
for filename in os.listdir(folder_orange):
    # Only image files are scored; any other entry in the folder is skipped.
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        continue
    pred_class = predict_class(folder_orange + filename)
    # Argmax rule: the most probable prompt is one of the first six (fruit) prompts.
    (argmax_list if np.argmax(pred_class) < 6 else not_argmax_list).append(filename)
    # Mean rule: the fruit prompts have a higher mean probability than the others.
    fruit_wins = np.mean(pred_class[0][:6]) > np.mean(pred_class[0][6:])
    (mean_list if fruit_wins else not_mean_list).append(filename)

# Images on which both rules agree (accepted by both / rejected by both).
both_list = list(set(mean_list) & set(argmax_list))
not_both_list = list(set(not_mean_list) & set(not_argmax_list))

print(f"Argmax accuracy : {len(argmax_list) / (len(argmax_list) + len(not_argmax_list))}")
print(f"Mean accuracy : {len(mean_list) / (len(mean_list) + len(not_mean_list))}")
print(f"Both accuracy : {len(both_list) / (len(both_list) + len(not_both_list))} with {len(both_list) + len(not_both_list)} chosen images")

# CLIP stat plots
One needs to run this from the CLIP folder (run the first two cells at the top of the notebook)

In [9]:
import pandas as pd
# Collect the per-prompt probabilities of every sub-image into one table.
all_stats_list = []
ALL_WC_PATH = "../FINAL_SUBIMAGES/" # same structure as in the paper
for folders in os.listdir(ALL_WC_PATH):
    if folders == ".DS_Store":
        continue
    folder_path = ALL_WC_PATH + folders + "/images/"
    for filename in os.listdir(folder_path):
        if not filename.endswith((".jpg", ".jpeg", ".png")):
            continue
        pred_class = predict_class(folder_path + filename)
        # One row per image: the first two characters of the file name go into
        # the "WC" column, followed by one "p<index>" column per prompt.
        stats = {"WC": filename[:2]}
        stats.update({f"p{i}": prob for i, prob in enumerate(pred_class[0])})
        all_stats_list.append(stats)
all_stats = pd.DataFrame(all_stats_list)

# Written one level above the CLIP checkout, without the index column.
all_stats.to_csv("../clip_stats.csv", index=False)