In [17]:
from pycocotools.coco import COCO
import os

data_dir = '../DATA//coco'
annotation_file = os.path.join(data_dir, 'annotations/instances_train2017.json')

# COCOデータセットを読み込む
coco = COCO(annotation_file)

loading annotations into memory...
Done (t=13.79s)
creating index...
index created!


In [18]:
import cv2
import torch
import numpy as np
import torchvision.transforms as transforms

def reshape_func(image_file, target):#sslに入れるための正方形の画像とそれに対応する計算しなおしたbboxを出力する

    IMAGENET_MEAN = [0.485, 0.456, 0.406] #imagenetの正規化
    IMAGENET_STD = [0.229, 0.224, 0.225]
    IMAGENET_SIZE = 224

    #２枚の画像を比べる(余白なし)

    transform1 = transforms.Resize(224)
    transform2 = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((IMAGENET_MEAN), (IMAGENET_STD))
        ])

    # 画像ファイルのパスを指定する

    img = cv2.imread(image_file)#PILだとshapeがよくわからんからcv2で読み込み
    reshaeped_im = transform1(Image.open(image_file))
    transformed_image = transform2(reshaeped_im).to("cuda")#cv2からだとtorchに入んないからPILで読み込み(いらなそう)

    new_width = reshaeped_im.size[0]
    new_height = reshaeped_im.size[1]


    resize_ratio_x = new_width / img.shape[1]
    resize_ratio_y = new_height / img.shape[0]


    # 画像ファイル名から画像IDを取得する
    image_id = None
    for image_info in coco.dataset['images']:
        if image_info['file_name'] == os.path.basename(image_file):
            image_id = image_info['id']
            break

    if image_id is not None:
        # 画像IDに対応するアノテーション情報を取得する
        annotations_ids = coco.getAnnIds(imgIds=image_id)
        annotations = coco.loadAnns(annotations_ids)

        # BBOXとラベルを表示する
        for annotation in annotations:
            bbox = annotation['bbox']
            resized_bbox = [
            int((bbox[0] * resize_ratio_x) - ((reshaeped_im.size[0] - 224) / 2)),
            int((bbox[1] * resize_ratio_y) - ((reshaeped_im.size[1] - 224) / 2)),
            int(bbox[2] * resize_ratio_x),
            int(bbox[3] * resize_ratio_y)
            ]
            label = coco.loadCats(annotation['category_id'])[0]['name']
            if label == target:
                
                return transformed_image, resized_bbox

In [19]:
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')##############################################################################
model = vitb16.to("cuda")

def extract(target, inputs):#抽出する関数
    feature = None

    def forward_hook(module, inputs, outputs):
        # 順伝搬の出力を features というグローバル変数に記録する
        global features
        # 1. detach でグラフから切り離す。
        # 2. clone() でテンソルを複製する。モデルのレイヤーで ReLU(inplace=True) のように
        #    inplace で行う層があると、値がその後のレイヤーで書き換えられてまい、
        #    指定した層の出力が取得できない可能性があるため、clone() が必要。
        features = outputs.detach().clone()

    # コールバック関数を登録する。
    handle = target.register_forward_hook(forward_hook)

    # 推論する
    model.eval()
    model(inputs)

    # コールバック関数を解除する。
    handle.remove()

    return features

Using cache found in /home/yishido/.cache/torch/hub/facebookresearch_dino_main


In [20]:
def get_id(bbox):#idをとってくるように変更済み。ちゃんととってきているか確認したい場合はcenter_pointsを返すようにする
    # 画像のサイズとグリッドの設定
    image_size = (224, 224)
    grid_size = (14, 14)###################################################################################

    # BBOXの座標

    # グリッドのセルの幅と高さを計算
    cell_width = image_size[0] // grid_size[0]
    cell_height = image_size[1] // grid_size[1]

    # BBOXの範囲内にあるセルの中心点を取得
    center_points = []
    id = []
    for i in range(grid_size[0]):
        for j in range(grid_size[1]):
            center_x = i * cell_width + cell_width // 2
            center_y = j * cell_height + cell_height // 2
            if bbox[0] <= center_x <= bbox[0] + bbox[2] and bbox[1] <= center_y <= bbox[1] + bbox[3]:
                center_points.append((center_x, center_y))
                id.append([i,j])
    
    return id

In [21]:
def dot_sim(v1, v2):#内積の関数
    return torch.dot(v1, v2) / (torch.linalg.norm(v1) * torch.linalg.norm(v2))

In [78]:
def dot_count(id_1,emb_1,id_2,emb_2):
    
    A = 0 #最大値がbboxの中に履いている個数

    for i in range(len(id_1)):
        n = id_1[i][0] #行
        m = id_1[i][1] #列

        id_emb = n*14+m #196の中のどこに当たるのか###########################################
        # print(id_emb)
        emb1 = emb_1[id_emb]
        # print(emb1.shape)

        # inner_product = torch.matmul(emb_2,emb1)

        # print(emb_2.shape)
        inner_product = []
        for j in range(len(emb_2)):
            inner_product.append(dot_sim(emb1,emb_2[j]).item())
        # print(inner_product)
        inner_product = torch.tensor(inner_product)
        # inner_product = torch.tensor([k.item() for k in inner_product])
        # inner_product = torch.cat(inner_product,dim=1)
        # result = torch.cat(inner_product)

        # print(inner_product.shape)
        # print(inner_product)
        # print(torch.argmax(inner_product))
        max_index = torch.argmax(inner_product).item()
        # print(torch.argmax(inner_product).item())

        k = max_index // 14 #最大値に当たるpatchの行##################################################
        l = max_index % 14 #最大値に当たるpatchの列

        max_id = [k, l] #maxのid（行，列）
        # print(f"maxid:{max_id}")

        if max_id in id_2:
            A += 1
    
    # print(f"count数:{A}")
    # print(f"bbox内のpatchの数:{len(id_1)}")

    if A == 0:
        wariai = 0

    wariai = A/len(id_1)

    return wariai

In [70]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

with open('cat.txt', 'r') as file:
    B_C_path = file.read().splitlines()
        
with open('cat_dog.txt', 'r') as file:
    BnC_path = file.read().splitlines()

with open('dog.txt', 'r') as file:
    C_B_path = file.read().splitlines()

In [80]:
#猫ー猫！＿犬

target1 = 'cat'
target2 = 'dog'

allcount=[]

for i in tqdm(range(len(B_C_path))):
    element1 = B_C_path[i]
    for j in range(len(BnC_path)):
        element2 = BnC_path[j]

        image_path1 = f"../DATA/coco/images/train2017/{element1}"
        image_path2 = f"../DATA/coco/images/train2017/{element2}"

        # try:
        im_1 = reshape_func(image_path1, target1) #chanelが１の場合
        im_2 = reshape_func(image_path2, target1)

        # except Exception as e:
        #     continue

        target_module = model.norm
        emb_1 = extract(target_module, im_1[0].unsqueeze(0))[0][1:,]
        emb_2 = extract(target_module, im_2[0].unsqueeze(0))[0][1:,]

        id_in_bbox_1 = get_id(im_1[1])
        id_in_bbox_2 = get_id(im_2[1])
        # print(dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2))

        # try:
        A = dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2) #

        # except Exception as e: #元になるbboxが小さすぎたせいで、idが取れなかった場合
            # continue

        # print(A)
        allcount.append(A)

print(sum(allcount)/len(allcount))

100%|██████████| 10/10 [01:46<00:00, 10.66s/it]

0.17233018106232392





In [81]:
print(sum(allcount)/len(allcount))

0.17233018106232392


In [82]:
#猫ー猫＿犬!

target1 = 'cat'
target2 = 'dog'

allcount=[]

for i in tqdm(range(len(B_C_path))):
    element1 = B_C_path[i]
    for j in range(len(BnC_path)):
        element2 = BnC_path[j]

        image_path1 = f"../DATA/coco/images/train2017/{element1}"
        image_path2 = f"../DATA/coco/images/train2017/{element2}"

        # try:
        im_1 = reshape_func(image_path1, target1) #chanelが１の場合
        im_2 = reshape_func(image_path2, target2)

        # except Exception as e:
        #     continue

        target_module = model.norm
        emb_1 = extract(target_module, im_1[0].unsqueeze(0))[0][1:,]
        emb_2 = extract(target_module, im_2[0].unsqueeze(0))[0][1:,]

        id_in_bbox_1 = get_id(im_1[1])
        id_in_bbox_2 = get_id(im_2[1])
        # print(dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2))

        # try:
        A = dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2) #

        # except Exception as e: #元になるbboxが小さすぎたせいで、idが取れなかった場合
        #     continue

        # print(A)
        allcount.append(A)

print(sum(allcount)/len(allcount))

100%|██████████| 10/10 [01:46<00:00, 10.63s/it]

0.3287963907785338





In [83]:
print(sum(allcount)/len(allcount))

0.3287963907785338


In [15]:
#犬ー猫！＿犬

target1 = 'cat'
target2 = 'dog'

allcount=[]

for i in tqdm(range(len(C_B_path))):
    element1 = C_B_path[i]
    for j in range(len(BnC_path)):
        element2 = BnC_path[j]

        image_path1 = f"../DATA/coco/images/train2017/{element1}"
        image_path2 = f"../DATA/coco/images/train2017/{element2}"

        # try:
        im_1 = reshape_func(image_path1, target2) #chanelが１の場合
        im_2 = reshape_func(image_path2, target1)

        # except Exception as e:
        #     continue

        target_module = model.norm
        emb_1 = extract(target_module, im_1[0].unsqueeze(0))[0][1:,]
        emb_2 = extract(target_module, im_2[0].unsqueeze(0))[0][1:,]

        id_in_bbox_1 = get_id(im_1[1])
        id_in_bbox_2 = get_id(im_2[1])
        # print(dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2))

        # try:
        A = dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2) #

        # except Exception as e: #元になるbboxが小さすぎたせいで、idが取れなかった場合
        #     continue

        # print(A)
        allcount.append(A)

print(sum(allcount)/len(allcount))

100%|██████████| 10/10 [00:10<00:00,  1.04s/it]

0.13984989361519076





In [16]:
print(sum(allcount)/len(allcount))

0.13984989361519076


In [18]:
#犬ー猫＿犬!

target1 = 'cat'
target2 = 'dog'

allcount=[]

for i in tqdm(range(len(C_B_path))):
    element1 = C_B_path[i]
    for j in range(len(BnC_path)):
        element2 = BnC_path[j]

        image_path1 = f"../DATA/coco/images/train2017/{element1}"
        image_path2 = f"../DATA/coco/images/train2017/{element2}"

        # try:
        im_1 = reshape_func(image_path1, target2) #chanelが１の場合
        im_2 = reshape_func(image_path2, target2)

        # except Exception as e:
        #     continue

        target_module = model.norm
        emb_1 = extract(target_module, im_1[0].unsqueeze(0))[0][1:,]
        emb_2 = extract(target_module, im_2[0].unsqueeze(0))[0][1:,]

        id_in_bbox_1 = get_id(im_1[1])
        id_in_bbox_2 = get_id(im_2[1])
        # print(dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2))

        # try:
        A = dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2) #

        # except Exception as e: #元になるbboxが小さすぎたせいで、idが取れなかった場合
        #     continue

        # print(A)
        allcount.append(A)

print(sum(allcount)/len(allcount))

100%|██████████| 10/10 [00:09<00:00,  1.00it/s]

0.3711794342576059





In [19]:
print(sum(allcount)/len(allcount))

0.3711794342576059


犬は猫と犬の違いをわかっているが，猫は犬とみなしてしまう．

In [9]:
#同じ物体同士

target1 = 'cat'
target2 = 'dog'

allcount=[]

for i in tqdm(range(len(B_C_path))):
    element1 = B_C_path[i]
    for j in range(len(B_C_path)):
        element2 = B_C_path[j]

        image_path1 = f"../DATA/coco/images/train2017/{element1}"
        image_path2 = f"../DATA/coco/images/train2017/{element2}"

        # try:
        im_1 = reshape_func(image_path1, target1) #chanelが１の場合
        im_2 = reshape_func(image_path2, target1)

        # except Exception as e:
        #     continue

        target_module = model.norm
        emb_1 = extract(target_module, im_1[0].unsqueeze(0))[0][1:,]
        emb_2 = extract(target_module, im_2[0].unsqueeze(0))[0][1:,]

        id_in_bbox_1 = get_id(im_1[1])
        id_in_bbox_2 = get_id(im_2[1])
        # print(dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2))

        # try:
        A = dot_count(id_in_bbox_1, emb_1, id_in_bbox_2, emb_2) #

        # except Exception as e: #元になるbboxが小さすぎたせいで、idが取れなかった場合
            # continue

        print(A)
        allcount.append(A)

print(sum(allcount)/len(allcount))

  0%|          | 0/10 [00:00<?, ?it/s]

1.0
0.6319444444444444
0.2569444444444444
0.1527777777777778
0.3680555555555556
0.8263888888888888
0.2361111111111111
0.3472222222222222


 10%|█         | 1/10 [00:01<00:11,  1.27s/it]

0.3680555555555556
0.4027777777777778
0.673469387755102
1.0
0.21428571428571427
0.1326530612244898
0.5306122448979592
0.5714285714285714
0.4387755102040816
0.1836734693877551
0.2653061224489796


 20%|██        | 2/10 [00:02<00:08,  1.11s/it]

0.41836734693877553
0.7959183673469388
0.7755102040816326
1.0
0.02040816326530612
0.20408163265306123
0.7755102040816326
0.3673469387755102
0.08163265306122448


 30%|███       | 3/10 [00:02<00:06,  1.13it/s]

0.42857142857142855
0.10204081632653061
0.7708333333333334
0.6041666666666666
0.041666666666666664
1.0
0.14583333333333334
0.9166666666666666
0.3541666666666667
0.10416666666666667
0.7083333333333334


 40%|████      | 4/10 [00:03<00:05,  1.08it/s]

0.1875
0.7777777777777778
0.6
0.2
0.26666666666666666
1.0
0.6888888888888889
0.28888888888888886
0.4888888888888889


 50%|█████     | 5/10 [00:04<00:04,  1.13it/s]

0.4888888888888889
0.4444444444444444
0.7857142857142857
0.44805194805194803
0.2987012987012987
0.2727272727272727
0.2012987012987013
1.0
0.14935064935064934
0.43506493506493504


 60%|██████    | 6/10 [00:05<00:03,  1.23it/s]

0.38961038961038963
0.35714285714285715
0.7551020408163265
0.6938775510204082
0.20408163265306123
0.40816326530612246
0.24489795918367346
0.2857142857142857
1.0
0.2857142857142857
0.20408163265306123


 70%|███████   | 7/10 [00:06<00:02,  1.13it/s]

0.2857142857142857
0.8035714285714286
0.3392857142857143
0.10714285714285714
0.14285714285714285
0.39285714285714285
0.8035714285714286
0.16071428571428573
1.0


 80%|████████  | 8/10 [00:07<00:01,  1.25it/s]

0.42857142857142855
0.30357142857142855
0.9814814814814815
0.4444444444444444
0.4074074074074074
0.16666666666666666
0.37037037037037035
0.8333333333333334
0.4444444444444444
0.6296296296296297


 90%|█████████ | 9/10 [00:08<00:00,  1.12it/s]

1.0
0.3888888888888889
0.95
0.475
0.15
0.125
0.275
0.475
0.45
0.3


100%|██████████| 10/10 [00:09<00:00,  1.07it/s]

0.325
1.0
0.4702646619253762





In [10]:
print(sum(allcount)/len(allcount))

0.4702646619253762


### todolist
1.4パターン試す<br>
2.その間codeがあっているかの確認<br>
3.各モデルで実行する<br>
4.論文かく！