# 実行環境の確認

In [1]:
!python -V

Python 3.7.12


## Library

In [2]:
!pip install -r ../requirements-training.txt

Collecting webdataset>=0.2.5
  Downloading webdataset-0.2.5-py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.9/46.9 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting regex
  Downloading regex-2022.4.24-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (749 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m749.7/749.7 kB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting braceexpand
  Downloading braceexpand-0.1.7-py2.py3-none-any.whl (5.9 kB)
Installing collected packages: braceexpand, webdataset, regex, ftfy
Successfully installed braceexpand-0.1.7 ftfy-6.1.1 regex-2022.4.24 webdataset-0.2.5


In [3]:
import torch

print(torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


1.11.0


# 検証

## データセット

In [4]:
from open_clip import (
    create_model_and_transforms,
    image_transform,
    tokenize,
)
from training import data as data_module
from training.data import (
    get_data,
    get_wds_dataset,
)

from PIL import Image
import requests
import io

In [42]:
import pandas as pd

# data_df = pd.read_table("./Train_GCC-training.tsv", header=None)
data_df = pd.read_table("./Validation_GCC-1.1.0-Validation.tsv", header=None)
data_df = data_df.sample(frac=1, random_state=0, ignore_index=True)

In [45]:
data_df.head()

Unnamed: 0,0,1
0,cherry blossoms and a field of roses make for ...,https://static1.squarespace.com/static/504c2ec...
1,a sq ft tiny house made from reclaimed barn wo...,https://s-media-cache-ak0.pinimg.com/originals...
2,person at the fashion show,https://media.gettyimages.com/photos/martha-wa...
3,hospitality business also boasts a bird 's eye...,http://i.dailymail.co.uk/i/pix/2016/12/27/21/3...
4,person in hooded sweater using a laptop on woo...,https://media.gettyimages.com/photos/person-in...


In [19]:
model, _, preprocess = create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

batch_num = 10
data_cnt = 0
images = []
texts = []
for row_cnt, (text, image_url) in data_df.iterrows():
    try:
        image = Image.open(io.BytesIO(requests.get(image_url).content))
    except:
        print("画像の読み込み失敗")
        print(image_url)
        continue

    image = preprocess(image).unsqueeze(0)
    images.append(image)
    texts.append(text)

    if data_cnt == batch_num - 1:
        break
    else:
        data_cnt += 1

images = torch.cat(images)
texts = tokenize(texts)

with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)
    print(image_features.shape)
    print(text_features.shape)


画像の読み込み失敗
http://www.standard.net/image/2015/02/04/800x_a16-9_b0_q81_p1/winter-fly-fishing.jpg
画像の読み込み失敗
http://indianapolis-photos.funcityfinder.com/files/2009/12/Clearwater-Crossing-Shopping-Center-sign-Indianapolis-Indiana.jpg
画像の読み込み失敗
https://www.featurepics.com/StockImage/20090316/carrying-globe-stock-image-1115085.jpg
画像の読み込み失敗
http://www.waste360.com/sites/waste360.com/files/styles/article_featured_standard/public/Trista%2002%20007_0.jpg?itok=F1eJZsX3
画像の読み込み失敗
https://media.gettyimages.com/photos/young-woman-seated-on-the-beach-picture-id97545987?s=612x612
画像の読み込み失敗
http://piquemagazine.uk/wp-content/uploads/2017/10/LPO-24-Feb-Albrecht-Menzel-%C2%AE-Anne-Hornemann-300dpi.jpg
torch.Size([10, 512])
torch.Size([10, 512])


# ネットワークの実装

In [9]:
import torch.nn as nn
import torch.nn.functional as F

## Loss

In [34]:
class CustomClipLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, image_features, text_features, logit_scale):
        # print("logit_scale: ", logit_scale)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T
        
        # print(logits_per_image.shape)
        # print(logits_per_text.shape)

        num_logits = logits_per_image.shape[0]
        # print("num_logits", num_logits)
        labels = torch.arange(num_logits, device="cpu", dtype=torch.long)
        # print("label: ", labels)

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
            ) / 2

        return total_loss

clip_loss = CustomClipLoss()
clip_loss(image_features, text_features, logit_scale)

tensor(2.3547, grad_fn=<DivBackward0>)

## Skip Lanyer Network

In [21]:
class Block(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        h = self.fc1(x)
        h = torch.add(h, x)
        h = self.relu1(h)

        return h

class SkipNetwork(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, input_dim)
        self.relu1 = nn.ReLU()

        self.block1 = Block(input_dim)
        self.block2 = Block(input_dim)

        self.fc2 = nn.Linear(input_dim, output_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)

        x = self.block1(x)
        x = self.block2(x)

        x = self.fc2(x)
        x = self.relu2(x)

        return x

In [51]:
text_network = SkipNetwork(512, 512)
image_network = SkipNetwork(512, 512)
clip_loss = CustomClipLoss()

optimizer = torch.optim.Adam(list(text_network.parameters()) + 
                             list(image_network.parameters()), lr=0.001)
model, _, preprocess = create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')

batch_num = 50
epoch_num = 5

for epoch_cnt in range(epoch_num):
    if epoch_cnt == 0:
        print(epoch_cnt)

    data_cnt = 0
    train_images = []
    train_texts = []
    test_images = []
    test_texts = []
    for row_cnt, (text, image_url) in data_df.iterrows():
        if row_cnt == 0:
            print("データの確認", text)

        try:
            image = Image.open(io.BytesIO(requests.get(image_url, timeout=(3.0, 7.5)).content))
        except:
            # print("画像の読み込み失敗")
            # print(image_url)
            continue

        image = preprocess(image).unsqueeze(0)

        if data_cnt < (batch_num * 2 / 3):
            train_images.append(image)
            train_texts.append(text)
        else:
            test_images.append(image)
            test_texts.append(text)

        if data_cnt == batch_num - 1:
            break
        else:
            data_cnt += 1

    with torch.no_grad():
        train_images = torch.cat(train_images)
        train_texts = tokenize(train_texts)
        train_image_features, train_text_features, train_logit_scale = model(train_images, train_texts)

        test_images = torch.cat(test_images)
        test_texts = tokenize(test_texts)
        test_image_features, test_text_features, test_logit_scale = model(test_images, test_texts)

    train_image_features = image_network(train_image_features)
    train_text_features = text_network(train_text_features)
    train_loss = clip_loss(train_image_features, train_text_features, train_logit_scale)

    train_loss.backward()
    optimizer.step()

    with torch.no_grad():
        test_image_features = image_network(test_image_features)
        test_text_features = text_network(test_text_features)
        test_loss = clip_loss(test_image_features, test_text_features, test_logit_scale)

    print("train loss: ", train_loss)
    print("test loss: ", test_loss)

0
データの確認 cherry blossoms and a field of roses make for a classic feminine scene celebrating perfume .
train loss:  tensor(3.6071, grad_fn=<DivBackward0>)
test loss:  tensor(2.8353)
データの確認 cherry blossoms and a field of roses make for a classic feminine scene celebrating perfume .
train loss:  tensor(3.4811, grad_fn=<DivBackward0>)
test loss:  tensor(2.7836)
データの確認 cherry blossoms and a field of roses make for a classic feminine scene celebrating perfume .
train loss:  tensor(3.0681, grad_fn=<DivBackward0>)
test loss:  tensor(2.7679)
データの確認 cherry blossoms and a field of roses make for a classic feminine scene celebrating perfume .
train loss:  tensor(2.7798, grad_fn=<DivBackward0>)
test loss:  tensor(2.7503)
データの確認 cherry blossoms and a field of roses make for a classic feminine scene celebrating perfume .
train loss:  tensor(2.2913, grad_fn=<DivBackward0>)
test loss:  tensor(2.7236)
