In [1]:
# Path prefix under which the converted TF2 weights are saved (see the final cell).
TF2_FILE = 'clip-vitb-32-tf2/weights'

In [2]:
import sys
import os

# install libraries
#!pip install ftfy regex tqdm
# import clip code
if not os.path.exists('CLIP/'):
    !git clone https://github.com/openai/CLIP.git
# add cloned git directories to path (otherwise colab can't find them)
sys.path.insert(0, 'CLIP/')

In [3]:
import torch
import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Reference PyTorch inference; the TF2 port below is checked against these numbers.
# The `with` block closes the image file once `preprocess` has produced the tensor.
with Image.open("CLIP.png") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

Label probs: [[0.9927937  0.00421067 0.0029957 ]]


In [4]:
model.token_embedding.weight.detach().cpu().numpy()  # .cpu() so this also works when `model` is on CUDA

array([[-3.90530936e-03, -6.32541208e-03,  7.35073257e-03, ...,
        -1.06596146e-02, -2.27636080e-02, -1.09076742e-02],
       [-2.60810871e-02,  8.79530329e-03, -1.17371846e-02, ...,
        -1.20190559e-02, -2.40586717e-02, -2.19290163e-02],
       [-1.96484327e-02, -6.67110318e-03, -9.05930717e-03, ...,
         4.57819877e-03, -2.06920728e-02, -8.71497300e-03],
       ...,
       [ 8.50279909e-03,  1.02194364e-03,  2.03663092e-02, ...,
         1.48675805e-02,  1.76269542e-02, -1.47524709e-03],
       [-1.67414418e-03,  7.30483516e-05, -4.19964641e-03, ...,
        -3.40962294e-03, -3.92947206e-03, -5.52894235e-05],
       [-6.02601049e-03,  2.02100957e-03,  4.96737135e-04, ...,
        -3.34585714e-03, -9.85872559e-03, -2.33900530e-04]], dtype=float32)

In [5]:
# TF2 replica of CLIP

from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as keras


class QuickGELU(keras.layers.Layer):
    """Sigmoid approximation of GELU used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def call(self, x):
        gate = keras.activations.sigmoid(1.702 * x)
        return x * gate


class ResidualAttentionBlock(keras.layers.Layer):
    """Pre-LayerNorm transformer block: self-attention, then a 4x-wide MLP, each added residually.

    Args:
        d_model: model (embedding) width.
        n_head: number of attention heads; each head gets ``d_model // n_head`` dims.
        attn_mask: optional mask forwarded as ``attention_mask`` to
            ``keras.layers.MultiHeadAttention`` (None = attend everywhere).

    NOTE: sublayers are created and assigned in the order attn, ln_1, mlp, ln_2.
    ``list_res_block`` emits converted weights in that same order for
    ``set_weights``, so do not reorder or rename them.
    """

    def __init__(self, d_model, n_head, attn_mask=None):
        super().__init__()

        head_dim = d_model // n_head
        hidden_dim = 4 * d_model
        self.attn = keras.layers.MultiHeadAttention(n_head, head_dim)
        self.ln_1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential([
            keras.layers.Dense(hidden_dim),
            QuickGELU(),
            keras.layers.Dense(d_model),
        ])
        self.ln_2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn_mask = attn_mask

    def attention(self, x):
        """Self-attention: query, value and key are all ``x``."""
        return self.attn(query=x, value=x, key=x, attention_mask=self.attn_mask)

    def call(self, x):
        after_attn = x + self.attention(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))


class Transformer(keras.layers.Layer):
    """Stack of ``layers`` ResidualAttentionBlocks applied one after another.

    Args:
        width: model width passed to every block.
        layers: number of blocks (stored as an int, not a list of layers).
        heads: attention heads per block.
        attn_mask: optional mask shared by every block.
    """

    def __init__(self, width, layers, heads, attn_mask=None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = keras.Sequential(blocks)

    def call(self, x):
        out = self.resblocks(x)
        return out


class VisionTransformer(keras.layers.Layer):
    """ViT image encoder: patch embedding + class token + transformer, projected to ``output_dim``.

    Assumes channels-last (NHWC) images of size ``input_resolution`` (Conv2D's
    default data format); the positional embedding is sized for exactly
    ``(input_resolution // patch_size) ** 2`` patches plus the class token.

    NOTE: the creation order of sublayers/weights and the weight names are kept
    as-is because the weight-transfer cell addresses them by attribute and saved
    checkpoints depend on them.
    """

    def __init__(self, input_resolution, patch_size, width, layers, heads, output_dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patches: kernel size == stride == patch_size.
        self.conv1 = keras.layers.Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        self.class_embedding = self.add_weight(name='cl_emb', shape=(1, 1, width,), trainable=True)
        self.positional_embedding = self.add_weight(name='pos_emb', shape=((input_resolution // patch_size) ** 2 + 1, width), trainable=True)
        self.ln_pre = keras.layers.LayerNormalization(epsilon=1e-5)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = keras.layers.LayerNormalization(epsilon=1e-5)
        self.proj = self.add_weight(name='proj', shape=(width, output_dim), trainable=True)

    def call(self, x):
        # Dynamic batch size, so the layer also works when the batch dimension is
        # unknown (tf.function tracing, symbolic Keras inputs) instead of only eagerly.
        batch_size = tf.shape(x)[0]
        x = self.conv1(x)  # shape = [*, grid, grid, width]
        width = x.shape[-1]
        x = tf.reshape(x, (batch_size, -1, width))  # shape = [*, grid ** 2, width]
        cls_tokens = tf.tile(self.class_embedding, (batch_size, 1, 1))  # shape = [*, 1, width]
        x = tf.concat((cls_tokens, x), axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding
        x = self.ln_pre(x)

        x = self.transformer(x)

        # The class-token output (position 0) is the image representation.
        x = self.ln_post(x[:, 0, :])

        return tf.linalg.matmul(x, self.proj)


class CLIPTF2(keras.Model):
    """TF2/Keras replica of CLIP: a ViT image encoder plus a causal text transformer.

    ``call(image, text)`` returns ``(logits_per_image, logits_per_text)``: the
    scaled cosine similarities between L2-normalised image and text embeddings.

    NOTE: keep the creation order of sublayers/weights in ``__init__`` and their
    attribute and weight names stable. The weight-transfer cell addresses them by
    attribute, and checkpoints written by ``save_weights`` depend on them.
    """

    def __init__(self,
                 embed_dim,
                 # vision
                 image_resolution,
                 vision_layers,
                 vision_width,
                 vision_patch_size,
                 # text
                 context_length,
                 vocab_size,
                 transformer_width,
                 transformer_heads,
                 transformer_layers
                 ):
        super().__init__()

        self.context_length = context_length
        # BPE tokenizer from the reference CLIP package, used by `tokenize`.
        self._tokenizer = _Tokenizer()

        # One attention head per 64 channels of the vision width.
        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim
        )

        # The text encoder is causal: every position only attends to itself and earlier tokens.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = keras.layers.Embedding(vocab_size, transformer_width)
        self.positional_embedding = self.add_weight(name='pos_emb', shape=(self.context_length, transformer_width), trainable=True)
        self.ln_final = keras.layers.LayerNormalization(epsilon=1e-5)

        self.text_projection = self.add_weight(name='text_proj', shape=(transformer_width, embed_dim), trainable=True)
        # CLIP's learned logit scale is not stored as a weight here; `call` takes a
        # fixed `logit_scale` argument instead (default 100).

    def build_attention_mask(self):
        """Causal mask of shape (context_length, context_length).

        Entry (i, j) is 1 when j <= i (query i may attend key j) and 0 otherwise.
        This is the "1 = attend" convention of keras MultiHeadAttention.
        """
        mask = tf.ones(shape=(self.context_length, self.context_length))
        # lower triangular attention (diagonal included)
        mask = tf.linalg.band_part(mask, -1, 0)
        return mask

    def encode_image(self, image):
        """Embed a batch of channels-last images into ``embed_dim`` features (not normalised)."""
        return self.visual(image)

    def encode_text(self, text):
        """Embed token ids (e.g. from `tokenize`) into ``embed_dim`` features (not normalised)."""
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        x = self.transformer(x)
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = tf.linalg.matmul(tf.gather(x, tf.math.argmax(text, axis=-1), axis=1, batch_dims=1), self.text_projection)

        return x

    def tokenize(self, texts, context_length=77, truncate=False):
        """Tokenize a string or list of strings into an int32 tensor of shape (len(texts), context_length).

        Each row is <|startoftext|> + BPE tokens + <|endoftext|>, zero-padded.

        Args:
            texts: a single string or a list of strings.
            context_length: number of token slots per row.
            truncate: if True, over-long inputs are cut to ``context_length``
                and their last token is forced to <|endoftext|>.

        Raises:
            RuntimeError: if an input is longer than ``context_length`` and ``truncate`` is False.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        # Fill a host-side buffer and convert once. The previous per-row
        # tf.tensor_scatter_nd_update copied the whole tensor for every text.
        result = np.zeros((len(all_tokens), context_length), dtype=np.int32)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = tokens

        return tf.constant(result)

    def call(self, image, text, logit_scale=100):
        """Return (logits_per_image, logits_per_text) of shapes [n_images, n_texts] and [n_texts, n_images]."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=-1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=-1, keepdims=True)

        # scaled cosine similarity as logits
        logits_per_image = logit_scale * tf.linalg.matmul(image_features, text_features, transpose_b=True)
        logits_per_text = tf.transpose(logits_per_image)

        # shapes: [n_images, n_texts] and [n_texts, n_images]
        return logits_per_image, logits_per_text

In [6]:
# ViT-B/32 hyper-parameters. Weights are still randomly initialised here, so these
# logits are meaningless until the weight-transfer cell runs. This first forward pass
# also builds the lazily-created Keras sublayer weights that set_weights needs later.
clip_tf = CLIPTF2(embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768, vision_patch_size=32,
                  context_length=77, vocab_size=49408, transformer_width=512, transformer_layers=12, transformer_heads=8)

text_tf = clip_tf.tokenize(["a diagram", "a dog", "a cat"])
# PyTorch image is NCHW; the Keras patch-embedding Conv2D expects NHWC.
# .detach().cpu() so this also works when `image` lives on CUDA.
image_tf = tf.transpose(tf.constant(image.detach().cpu().numpy()), (0, 2, 3, 1))

logits_per_image_tf, logits_per_text_tf = clip_tf(image_tf, text_tf)
print(logits_per_image_tf)

tf.Tensor([[3.304957  3.8573132 4.0554957]], shape=(1, 3), dtype=float32)


In [7]:
#########################
# ## WEIGHT TRANSFER ## #
#########################

# helper originally intended for spectral-norm computations
def reshape_weight_to_matrix(weight_in):
    """Flatten a weight array to 2-D, keeping its first dimension as rows.

    Works on anything exposing ``.shape`` and ``.reshape(rows, -1)`` (e.g. numpy
    arrays, torch tensors). Not used by the visible weight-transfer code.
    """
    rows = weight_in.shape[0]
    return weight_in.reshape(rows, -1)

# function to copy linear layer
def list_linear_weights(layer_pt):
    """Convert a torch linear layer into the weight list expected by ``keras.layers.Dense``.

    Returns ``[kernel]`` or ``[kernel, bias]``. The kernel is the transposed torch
    weight (torch stores (out, in); Dense uses (in, out)). Tensors are detached and
    moved to CPU before ``.numpy()``, so this also works for models on CUDA.
    """
    weight_list = [np.transpose(layer_pt.weight.detach().cpu().numpy())]
    if layer_pt.bias is not None:
        weight_list.append(layer_pt.bias.detach().cpu().numpy())
    return weight_list

# function to copy conv layer
def list_conv_weights(layer_pt):
    """Convert a torch 2-D conv layer into the weight list expected by ``keras.layers.Conv2D``.

    The torch kernel (out, in, kh, kw) is transposed to Keras' (kh, kw, in, out).
    The bias is appended only when present. Tensors are detached and moved to CPU
    first, so this also works for models on CUDA.
    """
    weight_list = [np.transpose(layer_pt.weight.detach().cpu().numpy(), (2, 3, 1, 0))]
    if layer_pt.bias is not None:
        weight_list.append(layer_pt.bias.detach().cpu().numpy())
    return weight_list

# copy batchnorm layer
def list_batch_norm(layer_pt):
    """Convert a torch batch-norm layer into the weight list for ``keras.layers.BatchNormalization``.

    Order: [gamma, beta, moving_mean, moving_variance]. gamma/beta are omitted when
    the torch layer has none (``affine=False``), which matches a Keras layer built with
    ``scale=False, center=False``. Tensors are detached and moved to CPU first.
    """
    weight_list = []

    # learned scaling and bias for standard batch norm (absent when affine=False)
    if layer_pt.weight is not None:
        weight_list.append(layer_pt.weight.detach().cpu().numpy())
    if layer_pt.bias is not None:
        weight_list.append(layer_pt.bias.detach().cpu().numpy())
    # running mean and var records
    weight_list.append(layer_pt.running_mean.detach().cpu().numpy())
    weight_list.append(layer_pt.running_var.detach().cpu().numpy())

    return weight_list

# copy layer norm layer
def list_layer_norm(layer_pt):
    """Convert a torch layer norm into the weight list for ``keras.layers.LayerNormalization``.

    Returns [gamma, beta]. Either entry is omitted when the torch layer has none
    (e.g. ``elementwise_affine=False``), matching a Keras layer without scale/center.
    Tensors are detached and moved to CPU first, so this also works for CUDA models.
    """
    weight_list = []

    # layer norm scale and shift
    if layer_pt.weight is not None:
        weight_list.append(layer_pt.weight.detach().cpu().numpy())
    if layer_pt.bias is not None:
        weight_list.append(layer_pt.bias.detach().cpu().numpy())

    return weight_list

# copy multiheaded attention
def list_attn_weights(layer_pt, num_heads=12):
    """Convert a torch multi-head attention layer into ``keras.layers.MultiHeadAttention`` weights.

    torch keeps the query/key/value projections stacked in ``in_proj_weight`` /
    ``in_proj_bias``. They are split per projection and per head.

    Returns, in order:
        q kernel (d_model, num_heads, head_dim), q bias (num_heads, head_dim),
        k kernel, k bias, v kernel, v bias (same shapes),
        output kernel (num_heads, head_dim, d_model), output bias (d_model,).

    Args:
        layer_pt: torch attention module exposing ``in_proj_weight``,
            ``in_proj_bias`` and ``out_proj``.
        num_heads: head count of the target Keras layer.

    Tensors are detached and moved to CPU first, so this also works for CUDA models.
    """
    weights = []

    # input weight: (3 * d_model, d_model) -> (d_model, 3 * num_heads, head_dim)
    weight_attn = layer_pt.in_proj_weight.detach().cpu().numpy()
    weight_attn = np.transpose(weight_attn, (1, 0))
    d_model = weight_attn.shape[0]
    weight_attn = np.reshape(weight_attn, (d_model, 3 * num_heads, -1))

    # input bias: (3 * d_model,) -> (3 * num_heads, head_dim)
    bias_attn = layer_pt.in_proj_bias.detach().cpu().numpy()
    bias_attn = np.reshape(bias_attn, (3 * num_heads, -1))

    # slice out query, key, value (in that order)
    for i in range(3):
        heads = slice(i * num_heads, (i + 1) * num_heads)
        weights += [weight_attn[:, heads]]
        weights += [bias_attn[heads]]

    # out weight: (d_model, d_model) -> (num_heads, head_dim, d_model)
    weight_out = layer_pt.out_proj.weight.detach().cpu().numpy()
    weight_out = np.transpose(weight_out, (1, 0))
    weight_out = np.reshape(weight_out, (num_heads, -1, d_model))
    weights += [weight_out]
    # out bias
    weights += [layer_pt.out_proj.bias.detach().cpu().numpy()]

    return weights

# mlp weights
def list_mlp_weights(layer_pt):
    """Dense weights for the MLP ``layer_pt[0]`` -> activation -> ``layer_pt[2]``.

    Index 1 (the activation) is skipped. The result is the first linear layer's
    weights followed by the second's, each as produced by ``list_linear_weights``.
    """
    collected = []
    for position in (0, 2):
        collected.extend(list_linear_weights(layer_pt[position]))
    return collected

# copy gen block
def list_res_block(layer_pt, num_heads=12):
    """Flat weight list for one residual attention block, ready for Keras ``set_weights``.

    Order: attention, ln_1, mlp, ln_2. This must match the order in which the
    Keras ResidualAttentionBlock creates its sublayers.
    """
    return (
        list_attn_weights(layer_pt.attn, num_heads=num_heads)
        + list_layer_norm(layer_pt.ln_1)
        + list_mlp_weights(layer_pt.mlp)
        + list_layer_norm(layer_pt.ln_2)
    )


# vision transformer
# NOTE: tensors go through .detach().cpu() before .numpy() so the transfer also works
# when `model` was loaded on CUDA. Each block's head count is read from its PyTorch
# attention layer instead of being hard-coded.
print('Patch Embeddings for ViT.')
# patch conv
patch_weights = list_conv_weights(model.visual.conv1)
clip_tf.visual.conv1.set_weights(patch_weights)
# class embedding (reshaped to the TF variable's (1, 1, width) shape)
class_embedding = np.reshape(model.visual.class_embedding.detach().cpu().numpy(), (1, 1, -1))
clip_tf.visual.class_embedding.assign(class_embedding)
# positional embedding
positional_embedding = model.visual.positional_embedding.detach().cpu().numpy()
clip_tf.visual.positional_embedding.assign(positional_embedding)
# layer norm pre transformer
ln_weights = list_layer_norm(model.visual.ln_pre)
clip_tf.visual.ln_pre.set_weights(ln_weights)

# transformer: gather every block's weights in order, then load them into the
# Keras Sequential of ResidualAttentionBlocks with a single set_weights call
resblock_weights = []
for i, block in enumerate(model.visual.transformer.resblocks):
    print('ViT ResBlock {}'.format(i+1))
    resblock_weights += list_res_block(block, num_heads=block.attn.num_heads)
clip_tf.visual.transformer.resblocks.set_weights(resblock_weights)

# layer norm post transformer
ln_weights = list_layer_norm(model.visual.ln_post)
clip_tf.visual.ln_post.set_weights(ln_weights)
# final proj weight
proj_weight = model.visual.proj.detach().cpu().numpy()
clip_tf.visual.proj.assign(proj_weight)


# text transformer
print('Embeddings for Text Transformer.')
# token embedding
token_embedding = [model.token_embedding.weight.detach().cpu().numpy()]
clip_tf.token_embedding.set_weights(token_embedding)
# positional embedding
positional_embedding = model.positional_embedding.detach().cpu().numpy()
clip_tf.positional_embedding.assign(positional_embedding)

# transformer
resblock_weights = []
for i, block in enumerate(model.transformer.resblocks):
    print('Text ResBlock {}'.format(i+1))
    resblock_weights += list_res_block(block, num_heads=block.attn.num_heads)
clip_tf.transformer.resblocks.set_weights(resblock_weights)

# final layer norm
ln_weights = list_layer_norm(model.ln_final)
clip_tf.ln_final.set_weights(ln_weights)
# final proj weight
proj_weight = model.text_projection.detach().cpu().numpy()
clip_tf.text_projection.assign(proj_weight)

Patch Embeddings for ViT.
ViT ResBlock 1
ViT ResBlock 2
ViT ResBlock 3
ViT ResBlock 4
ViT ResBlock 5
ViT ResBlock 6
ViT ResBlock 7
ViT ResBlock 8
ViT ResBlock 9
ViT ResBlock 10
ViT ResBlock 11
ViT ResBlock 12
Embeddings for Text Transformer.
Text ResBlock 1
Text ResBlock 2
Text ResBlock 3
Text ResBlock 4
Text ResBlock 5
Text ResBlock 6
Text ResBlock 7
Text ResBlock 8
Text ResBlock 9
Text ResBlock 10
Text ResBlock 11
Text ResBlock 12


<tf.Variable 'UnreadVariable' shape=(512, 512) dtype=float32, numpy=
array([[-0.01043701,  0.01422882, -0.00836945, ..., -0.00688171,
        -0.01246643,  0.00120258],
       [ 0.00536346,  0.00133991, -0.00360298, ...,  0.0026207 ,
         0.01364899, -0.02009583],
       [ 0.00286102,  0.00314713,  0.01821899, ...,  0.00336075,
         0.00518799, -0.00634384],
       ...,
       [ 0.00939178,  0.03060913,  0.0135498 , ...,  0.01597595,
         0.00135612, -0.0109787 ],
       [-0.01131439,  0.00466537,  0.0016737 , ..., -0.00434875,
        -0.01869202, -0.00491714],
       [ 0.00764084, -0.00674438,  0.01120758, ...,  0.00359917,
        -0.00379562,  0.01699829]], dtype=float32)>

In [8]:
# test vit component: max abs difference between the PyTorch and TF2 image encoders.
# encode_image casts the input to the model's dtype. no_grad + .cpu() let this run
# when `model` is on CUDA.
with torch.no_grad():
    im_out_pt = model.encode_image(image)
im_out_tf = clip_tf.visual(image_tf)

print('Error for image encoding: ', np.max(np.abs((im_out_pt.cpu().numpy() - im_out_tf.numpy()))))

# test text component
with torch.no_grad():
    text_out_pt = model.encode_text(text)
text_out_tf = clip_tf.encode_text(text_tf)

print('Error for text encoding: ', np.max(np.abs((text_out_pt.cpu().numpy() - text_out_tf.numpy()))))

Error for image encoding:  3.0994415e-06
Error for text encoding:  2.3841858e-06


In [9]:
# Softmax over the text prompts for the converted TF2 model. Index the returned tuple
# rather than unpacking into `_`, which would shadow IPython's last-output variable.
# Use `probs_tf` so the PyTorch `probs` from earlier stays available for comparison.
logits_per_image_tf = clip_tf(image_tf, text_tf)[0]
probs_tf = tf.nn.softmax(logits_per_image_tf, axis=-1).numpy()

print("Label probs:", probs_tf)  # should match pytorch output: [[0.9927937  0.00421068 0.00299572]]

Label probs: [[0.99279356 0.00421069 0.00299576]]


In [10]:
# Persist the converted TF2 weights under the TF2_FILE prefix
# (reload later with clip_tf.load_weights(TF2_FILE)).
clip_tf.save_weights(filepath=TF2_FILE)