# 実行環境の確認

In [1]:
!python -V

Python 3.7.12


## Library

In [2]:
!pip install -r ../requirements-training.txt

Collecting webdataset>=0.2.5
  Downloading webdataset-0.2.5-py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.9/46.9 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting regex
  Downloading regex-2022.4.24-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (749 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m749.7/749.7 kB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
Collecting braceexpand
  Downloading braceexpand-0.1.7-py2.py3-none-any.whl (5.9 kB)
Installing collected packages: braceexpand, webdataset, regex, ftfy
Successfully installed braceexpand-0.1.7 ftfy-6.1.1 regex-2022.4.24 webdataset-0.2.5


In [18]:
import torch

print(torch.__version__)

1.11.0


# 検証

## データセット

In [89]:
from open_clip import (
    create_model_and_transforms,
    image_transform,
    tokenize,
)
from training import data as data_module
from training.data import (
    get_data,
    get_wds_dataset,
)

from PIL import Image
import requests
import io

In [72]:
import pandas as pd

data_df = pd.read_table("./Train_GCC-training.tsv", header=None)

In [None]:
data_df.head()

Unnamed: 0,0,1
0,a very typical bus station,http://lh6.ggpht.com/-IvRtNLNcG8o/TpFyrudaT6I/...
1,sierra looked stunning in this top and this sk...,http://78.media.tumblr.com/3b133294bdc7c7784b7...
2,young confused girl standing in front of a war...,https://media.gettyimages.com/photos/young-con...
3,interior design of modern living room with fir...,https://thumb1.shutterstock.com/display_pic_wi...
4,cybernetic scene isolated on white background .,https://thumb1.shutterstock.com/display_pic_wi...


In [109]:
model, _, preprocess = create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

# preprocess = image_transform(128, is_train=False)

batch_num = 10
data_cnt = 0
images = []
texts = []
for row_cnt, (text, image_url) in data_df.iterrows():
    try:
        image = Image.open(io.BytesIO(requests.get(image_url).content))
    except:
        print("画像の読み込み失敗")
        print(image_url)
        continue

    image = preprocess(image).unsqueeze(0)
    images.append(image)
    texts.append(text)

    if data_cnt == batch_num - 1:
        break
    else:
        data_cnt += 1

images = torch.cat(images)
texts = tokenize(texts)

with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)
    print(image_features.shape)
    print(text_features.shape)


torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
画像の読み込み失敗
http://www.robinhoodshow.com/clients/17668/8642054_org.jpg
torch.Size([1, 3, 224, 224])
torch.Size([10, 512])
torch.Size([10, 512])


# ネットワークの実装

In [88]:
import torch.nn as nn
import torch.nn.functional as F

## Loss

In [116]:
class CustomClipLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, image_features, text_features, logit_scale):
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T

        num_logits = logits_per_image.shape[0]
        print(num_logits)
        labels = torch.arange(num_logits, device="cpu", dtype=torch.long)
        print(labels)

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
            ) / 2
        
        print(total_loss)

loss = CustomClipLoss()
loss(image_features, text_features, logit_scale)

10
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor(0.0001)
