In [1]:
import torch.nn as nn
import numpy as np
import torch
import yaml
from torch.autograd import Variable
from torch.utils.data import DataLoader
from txt2image_dataset import Text2ImageDataset
from utils import Utils, Logger
from PIL import Image
import os



import matplotlib.pyplot as plt
%matplotlib inline

from transformers import BertTokenizer, BertModel

In [2]:
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

config

{'birds_images_path': 'data/cvpr2016_cub/images/',
 'birds_embedding_path': 'data/cub_icml/',
 'birds_text_path': 'data/cvpr2016_cub/cvpr2016_cub/text_c10/',
 'birds_dataset_path': 'data/cvpr2016_cub/text2image/birds.hdf5',
 'val_split_path': 'data/cvpr2016_cub/valclasses.txt',
 'train_split_path': 'data/cvpr2016_cub/trainclasses.txt',
 'test_split_path': 'data/cvpr2016_cub/testclasses.txt',
 'flowers_images_path': 'data/cvpr2016_flowers/images/',
 'flowers_embedding_path': 'data/flowers_icml/',
 'flowers_text_path': 'data/cvpr2016_flowers/text_c10/',
 'flowers_dataset_path': 'data/cvpr2016_flowers/text2image/flowers.hdf5',
 'flowers_val_split_path': 'data/cvpr2016_flowers/valclasses.txt',
 'flowers_train_split_path': 'data/cvpr2016_flowers/trainclasses.txt',
 'flowers_test_split_path': 'data/cvpr2016_flowers/testclasses.txt'}

In [9]:
dataset = Text2ImageDataset(config['flowers_dataset_path'], split=0)


data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

for sample in data_loader:
    print(sample['right_embed'], sample['txt'])
    break

tensor([[ 0.0538,  0.0227, -0.1054,  ..., -0.1206,  0.0756,  0.0535],
        [ 0.1383, -0.0386,  0.0401,  ...,  0.0193,  0.0976,  0.0465],
        [-0.0472,  0.1055,  0.1817,  ..., -0.0108,  0.1234,  0.0780],
        ...,
        [ 0.1191, -0.0131,  0.1247,  ...,  0.0598,  0.0572,  0.1181],
        [ 0.1276,  0.0083,  0.0543,  ..., -0.0836,  0.0926,  0.0449],
        [-0.0887,  0.1301,  0.1238,  ..., -0.3147,  0.2946,  0.1418]]) ['this flower has pink flowers and white streaks, white center.\n', 'this flower has petals that are pink with yellow stamen\n', 'this flower has petals that are yellow and has a stringy stamen\n', 'this flower resembles an ear of corn, with long yellow petals at the base and a cluster of tight, dark red petals extending from the center.\n', 'this flower has green sepals with large smooth petals that are orange in color.\n', 'this flower has lightly colored purple petals with a smooth surface and slightly uneven edges.\n', 'this particular flower has petals th

In [12]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


for sample in data_loader:
    text = sample['txt'][0]
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    
    for tup in zip(tokenized_text, indexed_tokens):
        print('{:<12} {:>6,}'.format(tup[0], tup[1]))
    
    break

[CLS]           101
the           1,996
flower        6,546
is            2,003
white         2,317
and           1,998
funnel       25,102
shaped        5,044
and           1,998
also          2,036
the           1,996
pet           9,004
##al          2,389
is            2,003
supported     3,569
by            2,011
green         2,665
sep          19,802
##al          2,389
[SEP]           102


In [13]:
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/440M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exact

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          