In [1]:
from pretrain import *
import pytorch_lightning as pl
import pathlib
from transformer.transformer import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_folder = pathlib.Path('.').absolute() / 'model'

In [3]:
checkpoint_file = model_folder / 'pretrain/PretrainCLIP-epoch=23.ckpt'

In [4]:
cp = torch.load(checkpoint_file)

In [5]:
print(cp.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])


In [6]:
for key in cp['state_dict'].keys():
    print(key)

logit_scale
img_encoder.cls_embedding
img_encoder.embedding_conv.weight
img_encoder.position_embedding.weight
img_encoder.ln.weight
img_encoder.ln.bias
img_encoder.encoder.blocks.0.self_attention.q_proj.weight
img_encoder.encoder.blocks.0.self_attention.q_proj.bias
img_encoder.encoder.blocks.0.self_attention.k_proj.weight
img_encoder.encoder.blocks.0.self_attention.k_proj.bias
img_encoder.encoder.blocks.0.self_attention.v_proj.weight
img_encoder.encoder.blocks.0.self_attention.v_proj.bias
img_encoder.encoder.blocks.0.linear.weight
img_encoder.encoder.blocks.0.linear.bias
img_encoder.encoder.blocks.0.feedforward.mlp_1.weight
img_encoder.encoder.blocks.0.feedforward.mlp_1.bias
img_encoder.encoder.blocks.0.feedforward.mlp_2.weight
img_encoder.encoder.blocks.0.feedforward.mlp_2.bias
img_encoder.encoder.blocks.0.ln_1.weight
img_encoder.encoder.blocks.0.ln_1.bias
img_encoder.encoder.blocks.0.ln_2.weight
img_encoder.encoder.blocks.0.ln_2.bias
img_encoder.encoder.blocks.1.self_attention.q_proj

In [7]:
te = TextEncoder(
    token_num=45,
    max_seq_length=45,
    token_type_num=0,
    hidden_dim=128,
    intermediate_dim=128*4,
    head_num=8,
    block_num=6,
    drop_rate=0.1,
    ln_eps=1e-5
)
ie = ImageEncoder(
    input_resolution=256,
    patch_size=24,
    hidden_dim=128,
    intermediate_dim=128*4,
    head_num=8,
    block_num=8,
    drop_rate=0.1,
    ln_eps=1e-5
)

In [8]:
model = PretrainCLIP.load_from_checkpoint(
    model_folder / 'pretrain/PretrainCLIP-epoch=23.ckpt',
    img_encoder = ie,
    text_encoder = te
)

In [11]:
cp['state_dict']['text_encoder.embedding.token_embedding.weight']

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9759, -0.8328, -0.0370,  ...,  0.4951,  0.7021,  1.6298],
        [ 0.7853,  0.0558,  1.5444,  ..., -2.4350,  0.4090,  0.1058],
        ...,
        [ 1.0269, -0.1056,  0.2926,  ...,  0.3973, -1.2935, -0.0172],
        [-0.7020, -0.5010,  1.1669,  ...,  0.1498, -1.0544, -0.3841],
        [ 0.8337, -0.8809, -1.1774,  ..., -0.1676,  0.3148, -0.5701]],
       device='cuda:0')

In [None]:
model.text_encoder.state_dict()

OrderedDict([('embedding.token_embedding.weight',
              tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.9759, -0.8328, -0.0370,  ...,  0.4951,  0.7021,  1.6298],
                      [ 0.7853,  0.0558,  1.5444,  ..., -2.4350,  0.4090,  0.1058],
                      ...,
                      [ 1.0269, -0.1056,  0.2926,  ...,  0.3973, -1.2935, -0.0172],
                      [-0.7020, -0.5010,  1.1669,  ...,  0.1498, -1.0544, -0.3841],
                      [ 0.8337, -0.8809, -1.1774,  ..., -0.1676,  0.3148, -0.5701]],
                     device='cuda:0')),
             ('embedding.position_embedding.weight',
              tensor([[ 0.9933,  0.3255, -0.8882,  ..., -0.1889,  0.9677,  0.3070],
                      [ 0.3425,  0.0321,  0.9651,  ...,  1.5556, -0.3256, -1.3271],
                      [-1.2665,  0.9199, -0.0651,  ..., -0.8678,  0.9924, -0.2846],
                      ...,
                      [ 1.5117, -0.2427, -0.29

In [None]:
model.text_encoder.state_dict()['embedding.token_embedding.weight']

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.9759, -0.8328, -0.0370,  ...,  0.4951,  0.7021,  1.6298],
        [ 0.7853,  0.0558,  1.5444,  ..., -2.4350,  0.4090,  0.1058],
        ...,
        [ 1.0269, -0.1056,  0.2926,  ...,  0.3973, -1.2935, -0.0172],
        [-0.7020, -0.5010,  1.1669,  ...,  0.1498, -1.0544, -0.3841],
        [ 0.8337, -0.8809, -1.1774,  ..., -0.1676,  0.3148, -0.5701]],
       device='cuda:0')

In [None]:
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

        print(mn, pn, fpn)

 logit_scale logit_scale
 img_encoder.cls_embedding img_encoder.cls_embedding
 img_encoder.embedding_conv.weight img_encoder.embedding_conv.weight
 img_encoder.position_embedding.weight img_encoder.position_embedding.weight
 img_encoder.ln.weight img_encoder.ln.weight
 img_encoder.ln.bias img_encoder.ln.bias
 img_encoder.encoder.blocks.0.self_attention.q_proj.weight img_encoder.encoder.blocks.0.self_attention.q_proj.weight
 img_encoder.encoder.blocks.0.self_attention.q_proj.bias img_encoder.encoder.blocks.0.self_attention.q_proj.bias
 img_encoder.encoder.blocks.0.self_attention.k_proj.weight img_encoder.encoder.blocks.0.self_attention.k_proj.weight
 img_encoder.encoder.blocks.0.self_attention.k_proj.bias img_encoder.encoder.blocks.0.self_attention.k_proj.bias
 img_encoder.encoder.blocks.0.self_attention.v_proj.weight img_encoder.encoder.blocks.0.self_attention.v_proj.weight
 img_encoder.encoder.blocks.0.self_attention.v_proj.bias img_encoder.encoder.blocks.0.self_attention.v_proj.bias


In [None]:
te = TextEncoder(
    token_num=45,
    max_seq_length=45,
    token_type_num=0,
    hidden_dim=128,
    intermediate_dim=128*4,
    head_num=8,
    block_num=6,
    drop_rate=0.1,
    ln_eps=1e-5
)
ie = ImageEncoder(
    input_resolution=256,
    patch_size=24,
    hidden_dim=128,
    intermediate_dim=128*4,
    head_num=8,
    block_num=8,
    drop_rate=0.1,
    ln_eps=1e-5
)

In [12]:
from img_shape import *
train_epochs = 50
tokens_per_epoch = 550
i2s = Img2Shape(
    img_encoder=model.img_encoder,
    intermediate_dim=128,
    shape_num=5,
    ln_eps=1e-5,
    drop_rate=0.1,
    learning_rate=0.0002,
    warmup_tokens=tokens_per_epoch,
    final_tokens=train_epochs*tokens_per_epoch,
    weight_decay=0.01,
    adamw_betas=(0.9, 0.97)
)
i2s.to(model.device)

Img2Shape(
  (img_encoder): ImageEncoder(
    (embedding_conv): Conv2d(3, 128, kernel_size=(24, 24), stride=(24, 24), bias=False)
    (position_embedding): Embedding(101, 128)
    (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (encoder): Encoder(
      (blocks): ModuleList(
        (0-7): 8 x EncoderBlock(
          (self_attention): MultiHeadAttention(
            (q_proj): Linear(in_features=128, out_features=128, bias=True)
            (k_proj): Linear(in_features=128, out_features=128, bias=True)
            (v_proj): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (linear): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (feedforward): FeedForward(
            (mlp_1): Linear(in_features=128, out_features=512, bias=True)
            (mlp_2): Linear(in_features=512, out_feature

In [None]:
i2s.load_from_pretrain(
    model.img_encoder.state_dict(),
    model.img_proj.state_dict(),
    model.ln_img.state_dict()
)

In [None]:
i2s.state_dict()

OrderedDict([('img_encoder.cls_embedding',
              tensor([ 0.0150,  0.0594, -0.0656, -0.0097,  0.0217, -0.1738, -0.0173,  0.0450,
                       0.0335,  0.0025, -0.0498,  0.1228, -0.1472, -0.0178, -0.0056, -0.0448,
                      -0.0710, -0.0919,  0.0429, -0.0046,  0.0455,  0.0127, -0.0466, -0.1142,
                       0.1130,  0.1039, -0.2352, -0.0520, -0.1319, -0.0257, -0.1303, -0.0430,
                      -0.0396,  0.1187, -0.0646, -0.0246, -0.0257,  0.0638, -0.0473, -0.0769,
                      -0.0227, -0.0170,  0.1427, -0.0447,  0.0208,  0.0850,  0.1054, -0.0241,
                       0.0461,  0.0417,  0.0227, -0.0137,  0.1399,  0.0171, -0.1745, -0.0368,
                      -0.0510,  0.0445,  0.0158, -0.0097,  0.0971,  0.0956, -0.0722,  0.0293,
                       0.0754,  0.0827,  0.0133, -0.0830, -0.1335, -0.0672,  0.1274, -0.0367,
                       0.1640, -0.1487,  0.0522,  0.0089,  0.2692, -0.1478,  0.1535, -0.1153,
                 

In [13]:
for a in cp['state_dict'].keys():
    for b in i2s.state_dict().keys():
        if a == b:
            print('a', a)
            print('b', b)
            result = torch.all(cp['state_dict'][a] == i2s.state_dict()[b])
            if not result.item():
                print('fuck')
        elif a == 'ln_img.weight' and b == 'ln.weight':
            result = torch.all(cp['state_dict'][a] == i2s.state_dict()[b])
            if not result.item():
                print('fuck')
        elif a == 'ln_img.bias' and b == 'ln.bias':
            result = torch.all(cp['state_dict'][a] == i2s.state_dict()[b])
            if not result.item():
                print('fuck')

a img_encoder.cls_embedding
b img_encoder.cls_embedding
a img_encoder.embedding_conv.weight
b img_encoder.embedding_conv.weight
a img_encoder.position_embedding.weight
b img_encoder.position_embedding.weight
a img_encoder.ln.weight
b img_encoder.ln.weight
a img_encoder.ln.bias
b img_encoder.ln.bias
a img_encoder.encoder.blocks.0.self_attention.q_proj.weight
b img_encoder.encoder.blocks.0.self_attention.q_proj.weight
a img_encoder.encoder.blocks.0.self_attention.q_proj.bias
b img_encoder.encoder.blocks.0.self_attention.q_proj.bias
a img_encoder.encoder.blocks.0.self_attention.k_proj.weight
b img_encoder.encoder.blocks.0.self_attention.k_proj.weight
a img_encoder.encoder.blocks.0.self_attention.k_proj.bias
b img_encoder.encoder.blocks.0.self_attention.k_proj.bias
a img_encoder.encoder.blocks.0.self_attention.v_proj.weight
b img_encoder.encoder.blocks.0.self_attention.v_proj.weight
a img_encoder.encoder.blocks.0.self_attention.v_proj.bias
b img_encoder.encoder.blocks.0.self_attention.v_pr