# Download flickr dataset

In [None]:
!wget --no-check-certificate 'https://storage.googleapis.com/kaggle-data-sets/771078/1328792/compressed/flickr8k.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20220828%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220828T123221Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=2da96ff4d583711c087af4482eb38f93034c0c9d78e743b4a3520b6480ef3d25998461288db8e5c23b18ad9270bec9d16d68ddd07aa8a20886f26f69f65afdb17b191cebda77a46b8d7fa9f78ab59359d935b544f1a434cfe472b943ed4293d74a1f821b4ca04c80d745791d9cafceb395cfe81006007afc9ab5f292ebb29954b1c7df69c47b556f1f71fef89ddfb924802df066f704916b81dfa36f978f0d70ebe0d71012b7c1e1d4937b3ce36520bfb71a818f449395297015e2dfec6ed2f29fb2b652469134a90fa4143e5095f77645632c6ebe4814fff501569badecab290d50d8316b3ea48667fdb7ab9f9fc86666e07026bf1470b41a0151de1b0e0323' -o flickr8k.zip
!unzip -q flickr8k.zip -d flickr8k


# Imports

In [1]:
from PIL import Image
import pandas as pd
import copy
from torchvision.datasets import CocoCaptions
from transformers import (
  DistilBertTokenizer, DistilBertForMaskedLM, DistilBertConfig,
  CLIPProcessor, CLIPModel as CLIP, CLIPConfig
)
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
import tqdm
import matplotlib.pyplot as plt
import os

if torch.cuda.is_available():
  dev = "cuda:0"
else:
  dev = "cpu"
device = torch.device(dev)
print("using device: ", dev)


  from .autonotebook import tqdm as notebook_tqdm


using device:  cpu


In [None]:
# download pretrained model and tokenizer
def save_model_tokenizer(tokenizer_class, model_class, name):
  if tokenizer_class is not None:
    tokenizer = tokenizer_class.from_pretrained(name)
    tokenizer.save_pretrained(f"./tokenizers/{name}-local")
  if model_class is not None:
    model = model_class.from_pretrained(name)
    model.save_pretrained(f"./models/{name}-local/")

save_model_tokenizer(CLIPProcessor, CLIP, "openai/clip-vit-base-patch32")


# Hyperparameters

In [3]:
# hyperparameters
BATCH_SIZE = 64
MAX_LENGTH = 128 # max text length
LEARNING_RATE = 5e-5
EPOCH_NUM = 1
ROUNDING_WEIGHT = 0.3 # weight of rounding term, the probability of regenerated sequence 

# diffusion hyperparameter
BETA_MIN = 0.0001
BETA_MAX = 0.02
STEP_TOT = 2000 # total noise adding steps
COSIN_SCHEDULE = True # if alpha sequence is scheduled in cosin instead of linear patten
SAMPLE_SIZE = 3 # number of sample steps in each diffuse sequence
X_0_PREDICTION = True # if model predicts x_0 or x_{t-1}

# Model, trainer and loss function

In [4]:
class DistilBertModel(nn.Module):
  def __init__(self, projection, config=None) -> None:
    '''
    inputs:
      projection: torch.tensor
      config
    '''
    super().__init__()

    self.model = DistilBertForMaskedLM(config).to(device)

    self.projection = nn.Linear(projection.shape[-2], projection.shape[-1], device=device)
    self.projection.weight.data = projection
    self.projection.bias.data = torch.zeros(self.projection.bias.data.shape, device=device)
    self.projection.requires_grad_(False)
    
    self.model.set_input_embeddings(nn.Sequential())
    self.model.set_output_embeddings(nn.Sequential())

  def parameters(self):
    return self.model.parameters()
  
  def forward(self, x, mask):
    '''
    return 
      feature_out, shape: [batch_size, seq_len, dim]
      vocab_out, shape: [batch_size, seq_len, vocab_size]
    '''
    
    x_out = self.model(x, mask)[0]
    return self.projection(x_out), x_out

clip_processor = CLIPProcessor.from_pretrained("./tokenizers/openai/clip-vit-base-patch32-local")
clip = CLIP.from_pretrained("./models/openai/clip-vit-base-patch32-local")

configuration = DistilBertConfig(vocab_size=clip_processor.tokenizer.vocab_size, dim=clip.projection_dim, n_heads=8)
model = DistilBertModel(clip.get_submodule("text_model.embeddings.token_embedding").weight.data, config=configuration)

# parameter only include model, no embedding layer
# trainer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
trainer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)


# Define Dataset

In [52]:
# OLD dataloader, slow load from origin data
class Flickr8kCLIPDataset(torch.utils.data.Dataset):
  def __init__(self, dir, clip_processor, clip) -> None:
    self.dir = dir
    self.caption = pd.read_csv(f"{dir}/captions.txt")

    self.clip = clip
    self.clip_processor = clip_processor

  def collate_fn(self, batch):
    images = []
    captions = []
    for b in batch:
      images.append(Image.open(f"{self.dir}/images/{b['image']}"))
      captions.append(b["caption"])

    inputs = self.clip_processor(text=captions, images=images, return_tensors="pt", padding=True)
    outputs = self.clip(**inputs)

    return outputs.text_embeds, outputs.image_embeds

  def __len__(self):
    return len(self.caption)

  def __getitem__(self, idx):
    return self.caption.loc[idx]

# try on load on request in colect func
train_dataset = Flickr8kCLIPDataset("flickr8k/", clip_processor, clip)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, collate_fn=train_dataset.collate_fn)

In [1]:
train_dataset = torch.utils.data.TensorDataset(torch.load("image_all_final.pickle"), torch.load("text_all_final.pickle"))
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)

NameError: name 'torch' is not defined

In [66]:
len(train_dataset)

40455

In [50]:
with tqdm.tqdm(train_loader) as process:
  for i, _ in enumerate(process):
    if i > 100:
      break

  0%|          | 2/633 [00:22<1:59:09, 11.33s/it]


TypeError: 'NoneType' object is not subscriptable

In [62]:
image_all = torch.tensor([]).reshape((0, 512))
text_all = torch.tensor([]).reshape((0, 512))

In [60]:
start = 1000
end = 1010
cnt = 1
with tqdm.tqdm(range(start, end)) as process:
  for i in process:
    image, text = train_dataset[i]
    image_all = torch.vstack([image_all, image])
    text_all = torch.vstack([text_all, text])
cnt += 1

100%|██████████| 10/10 [00:02<00:00,  3.69it/s]


In [65]:
import sys
sys.argv[1]

'--ip=127.0.0.1'

In [55]:
image_all_temp = image_all.clone().detach()
text_all_temp = text_all.clone().detach()
print(f"saving to cnt = {cnt}")
torch.save(image_all_temp, f"image_all{cnt}.pickle")
torch.save(text_all_temp, f"text_all{cnt}.pickle")

In [59]:
# image_all_final = 
torch.save(image_all1, "image_all1.pickle")
torch.save(text_all1, "text_all1.pickle")

In [123]:
torch.save(image_all, "image_all.pickle")

In [111]:
small_dataset = torch.utils.data.TensorDataset(image_all, text_all)
small_dataloader = DataLoader(small_dataset, shuffle=True, batch_size=BATCH_SIZE)

In [112]:
with tqdm.tqdm(small_dataloader) as process:
  for image, text in process:
    continue
    # image_all = torch.vstack([image_all, image])
    # text_all = torch.vstack([text_all, text])

100%|██████████| 11/11 [00:00<00:00, 527.91it/s]


In [119]:
image_dict = torch.load("image_dict.pickle")
for k in image_dict:
  break
image_dict[k]["pixel_values"].shape

torch.Size([1, 3, 224, 224])

In [114]:
'''TODO: COCO dataset'''

# import torchvision.transforms as transforms
# cap = CocoCaptions(root = 'dir where images are',
#                         annFile = 'json annotation file',
#                         transform=transforms.PILToTensor())

# print('Number of samples: ', len(cap))
# img, target = cap[3] # load 4th sample

# print("Image Size: ", img.size())
# print(target)

'TODO: COCO dataset'

In [45]:
text = ["a photo of a cat", "a photo of a warship", "a photo of a boy", "a photo of a girl"]
image = Image.open("Shropshire.jpeg")
# plt.imshow(image)
# plt.show()
inputs = clip_processor(text=text, images=image, return_tensors="pt", padding=True)
inputs_no = clip_processor(text=text, images=None, return_tensors="pt", padding=True)

print(inputs.keys(), torch.all(inputs.input_ids == inputs_no.input_ids))

outputs = clip.get_text_features(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
# outputs = clip(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], pixel_values=torch.zeros((3,1,1)))

# outputs = clip(**inputs)
print(outputs.text_embeds.shape, outputs.image_embeds.shape, 
outputs.text_model_output["last_hidden_state"].shape, outputs.text_model_output["pooler_output"].shape, 
outputs.vision_model_output["last_hidden_state"].shape, outputs.vision_model_output["pooler_output"].shape, 
)
outputs.text_embeds, outputs.image_embeds, outputs.text_model_output, outputs.vision_model_output
# logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
# probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

# print(probs, text[probs.argmax(dim=-1)[0].item()])

dict_keys(['input_ids', 'attention_mask', 'pixel_values']) tensor(True)
tensor([[ 0.1555,  0.0733, -0.2448,  ..., -0.5327, -0.4588,  0.0346],
        [-0.0583,  0.1287,  0.1936,  ..., -0.2737,  0.0333, -0.3535],
        [-0.0936,  0.3672, -0.1155,  ..., -0.8218, -0.2455, -0.1423],
        [-0.1143,  0.2267, -0.2819,  ..., -0.6799, -0.2747,  0.1227]],
       grad_fn=<MmBackward0>)


ValueError: You have to specify pixel_values