# 🤗 Multimodal generation - Part 2: text and image representation dataset 🤗

In [1]:
# Move to the repository root so that `./data/` (DATA_DIR) and the local
# `text2img` package resolve. A bare `cd ..` climbs one more level every time
# the cell is re-run; guarding on the presence of the `text2img/` folder makes
# the cell idempotent.
# NOTE(review): assumes the notebook lives one level below the repo root.
import os

if not os.path.isdir("text2img"):
    os.chdir("..")


In [2]:
import os
import torch
import transformers
from pytorch_pretrained_biggan import BigGAN
from transformers import AutoTokenizer, AutoModel

## Parameter definition

In [3]:
# --- Paths -------------------------------------------------------------------
# Folder the input tensors are read from and the computed representations
# are written to.
DATA_DIR = "./data/"

# --- Pretrained checkpoints --------------------------------------------------
PRETRAINED_LM_NAME = "distilbert-base-uncased"  # language model (text side)
PRETRAINED_GAN_NAME = "biggan-deep-128"  # BigGAN checkpoint (image side)

## Load the data

In [4]:
# Load the label and token tensors stored in DATA_DIR.
# `map_location="cpu"` places them on the CPU, the same device as the
# pretrained models instantiated below (which are never moved elsewhere).
# Without it, tensors that were saved from a GPU session would fail to load
# on a CPU-only machine.
# NOTE(review): torch.load unpickles its input — only load trusted files.
labels_tensor = torch.load(os.path.join(DATA_DIR, "labels_tensor.bin"), map_location="cpu")
tokens_tensor = torch.load(os.path.join(DATA_DIR, "tokens_tensor.bin"), map_location="cpu")

## Instantiate the pretrained models: GAN and language model

In [5]:
# Text side: tokenizer and encoder from the same pretrained checkpoint.
# NOTE(review): lm_tokenizer is not used by any visible cell — confirm it is
# still needed.
lm_tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM_NAME)
lm_model = AutoModel.from_pretrained(PRETRAINED_LM_NAME)

# Image side: pretrained BigGAN; this notebook only uses `gan_model.embeddings`.
gan_model = BigGAN.from_pretrained(PRETRAINED_GAN_NAME)

## Compute the representations

In [6]:
from text2img.data.transform import text_tokens_to_representation, image_label_to_representation

In [7]:
# Map every label through BigGAN's embedding module. The batch size equals
# the number of labels, so the whole set is processed in one pass.
labels_representations_tensor = image_label_to_representation(
    labels_tensor=labels_tensor,
    embedding_module=gan_model.embeddings,
    batch_size=len(labels_tensor),
)

100%|██████████| 1/1 [00:00<00:00,  7.81it/s]


In [8]:
# Encode the tokenized texts with the pretrained language model, 32 sequences
# at a time (this is the slow step of the notebook).
text_representations_tensor = text_tokens_to_representation(
    tokens_tensor=tokens_tensor,
    language_model=lm_model,
    batch_size=32,
)

100%|██████████| 187/187 [03:02<00:00,  1.02it/s]


## Save the representations

In [9]:
# Persist both representation tensors next to the input data, labels first.
representations_to_save = {
    "labels_representations_tensor.bin": labels_representations_tensor,
    "text_representations_tensor.bin": text_representations_tensor,
}
for file_name, representation in representations_to_save.items():
    torch.save(representation, os.path.join(DATA_DIR, file_name))

# Discussion on how to improve the generation

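Our label representation tensors lack diversity: each label is passed through the same BigGAN embedding module, so every example that shares a label gets exactly the same representation vector. This will hurt the training of the downstream model.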

To overcome this, we can either:
- Add noise to the representation vectors to make them less similar
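- Ideally, compute representations of real images gathered from the Internet. BigGAN only provides a generator, not an encoder, so this would require a separate image encoder (or a GAN-inversion step) to map those images into BigGAN's embedding space.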