# Emoji ruDALL-E

![](https://huggingface.co/sberbank-ai/rudalle-Emojich/resolve/main/pics/emojich_rgba_100.png)

The model was trained by [Sber AI](https://github.com/sberbank-ai), starting from the pretrained [ruDALL-E (XL) Malevich](https://www.kaggle.com/shonenkov/rudalle-example-generation) model.
* Task: `text2image generation`
* Num Parameters: `1.3 B`
* Training Data Volume: `120 million text-image pairs` & [`2749 text-emoji pairs`](https://www.kaggle.com/shonenkov/russian-emoji)

[![Telegram](https://img.shields.io/badge/Telegram-Stickers-blue?style=for-the-badge&logo=data:image/svg%2bxml;base64,PHN2ZyBlbmFibGUtYmFja2dyb3VuZD0ibmV3IDAgMCAyNCAyNCIgaGVpZ2h0PSI1MTIiIHZpZXdCb3g9IjAgMCAyNCAyNCIgd2lkdGg9IjUxMiIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48cGF0aCBkPSJtOS40MTcgMTUuMTgxLS4zOTcgNS41ODRjLjU2OCAwIC44MTQtLjI0NCAxLjEwOS0uNTM3bDIuNjYzLTIuNTQ1IDUuNTE4IDQuMDQxYzEuMDEyLjU2NCAxLjcyNS4yNjcgMS45OTgtLjkzMWwzLjYyMi0xNi45NzIuMDAxLS4wMDFjLjMyMS0xLjQ5Ni0uNTQxLTIuMDgxLTEuNTI3LTEuNzE0bC0yMS4yOSA4LjE1MWMtMS40NTMuNTY0LTEuNDMxIDEuMzc0LS4yNDcgMS43NDFsNS40NDMgMS42OTMgMTIuNjQzLTcuOTExYy41OTUtLjM5NCAxLjEzNi0uMTc2LjY5MS4yMTh6IiBmaWxsPSIjMDM5YmU1Ii8+PC9zdmc+)](https://telegram.me/addstickers/SberAI_ruDALLE)

Authors:

+ [Alex Shonenkov](https://kaggle.com/shonenkov)
+ [Daria Bakshandaeva](https://kaggle.com/dariabakshandaeva)
+ [Denis Dimitrov](https://kaggle.com/ddimitrov)

In [1]:
# Install pinned versions of ruDALL-E, bitsandbytes (used below for the 8-bit Adam
# optimizer) and timm; pip output is silenced by redirecting it to /dev/null.
!pip install rudalle==0.4.0 > /dev/null
!pip install bitsandbytes-cuda110==0.25.0 > /dev/null
!pip install timm==0.4.12 > /dev/null

In [1]:
import multiprocessing
import torch
from psutil import virtual_memory

# Report the hardware / software environment of this session.
ram_gb = round(virtual_memory().total / 1024**3, 1)  # total system RAM in GiB, 1 decimal

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
# Use the first GPU when CUDA is available, otherwise the CPU.
# NOTE: the model-loading cell later rebinds `device = 'cuda'` unconditionally.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

In [1]:
# Training hyper-parameters.
N_EPOCHS = 40            # nominal passes over the data; also sizes the OneCycleLR schedule
EARLY_STOP = True        # when True, `train` returns after its first 100 optimisation steps
MAX_LR = 1e-5            # peak learning rate (also the optimizer's initial lr argument)
BATCH_SIZE = 1           # the authors used 2 in their experiments
LOSS_IMG_WEIGHT = 1_000  # forwarded to get_rudalle_model as loss_img_weight
SAVE_EVERY = 2000        # checkpoint interval, in optimisation steps
FINAL_DIV_FACTOR = 500   # OneCycleLR final_div_factor

In [1]:
import os
import PIL
import random
import torchvision
import numpy as np
import pandas as pd
import seaborn as sns
import bitsandbytes as bnb
import torchvision.transforms as T
from PIL import Image
from tqdm.auto import tqdm
from collections import Counter
from wordcloud import WordCloud
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip, convert_emoji_to_rgba, show_rgba
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip, get_emojich_unet, utils
from rudalle.utils import seed_everything
# Seed the random number generators for reproducibility (rudalle helper; presumably
# covers Python's `random`, NumPy and torch — confirm in rudalle.utils).
seed_everything(42)

# [Emoji Dataset](https://www.kaggle.com/shonenkov/russian-emoji)

In [1]:
def merge_pil_images(pil_images, nrow=16, tile_size=128):
    """Tile PIL images into a single grid image for display.

    Args:
        pil_images: non-empty sequence of PIL images to tile.
        nrow: number of images per grid row.
        tile_size: side length (pixels) every image is resized to before
            tiling; grid cells are square.

    Returns:
        A single PIL image containing the grid.

    Raises:
        ValueError: if ``pil_images`` is empty (there is nothing to tile).
    """
    if not pil_images:
        raise ValueError('merge_pil_images() needs at least one image')
    thumbnails = [pil_image.resize((tile_size, tile_size)) for pil_image in pil_images]
    # rudalle helper; presumably stacks the thumbnails into an (N, C, H, W)
    # tensor as make_grid expects — confirm in rudalle.utils.
    batch = utils.pil_list_to_torch_tensors(thumbnails)
    grid = torchvision.utils.make_grid(batch, nrow=nrow)
    return torchvision.transforms.functional.to_pil_image(grid.detach())


class EmojiDataset(Dataset):
    """Caption/emoji pairs for fine-tuning ruDALL-E.

    Each item is ``(encoded_text, image)``. ``encoded_text`` is the result of
    ``tokenizer.encode_text`` on the lower-cased caption. ``image`` is the
    ``T.ToTensor()`` output for a square ``RandomResizedCrop`` of size 256.

    Args:
        df: DataFrame indexed by image id with a ``text`` column; the image
            for id ``X`` is read from ``<data_dir>/X.png``.
        data_dir: directory holding the PNG files.
        tokenizer: ruDALL-E tokenizer exposing ``encode_text``.
        text_seq_length: sequence length passed to ``encode_text``.
        scale_ratio: lower bound of the area fraction kept by
            ``RandomResizedCrop`` (1.0 keeps the full area).
    """

    def __init__(self, df, data_dir, tokenizer, text_seq_length=128, scale_ratio=1.0):
        self.data_dir = data_dir
        self.text_seq_length = text_seq_length
        self.tokenizer = tokenizer
        self.image_size = 256
        self.image_transform = T.Compose([
            # Convert non-RGB images (e.g. RGBA or palette PNGs) to 3-channel RGB.
            # NOTE(review): RGBA -> RGB discards alpha without compositing, so
            # transparent pixels keep whatever RGB they store — confirm the dataset.
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # ratio=(1, 1) forces a square crop.
            T.RandomResizedCrop(self.image_size, scale=(scale_ratio, 1.), ratio=(1.0, 1.)),
            T.ToTensor(),
        ])
        self.texts = df['text'].values
        self.image_ids = df.index.values

    def __len__(self):
        return self.texts.shape[0]

    def __getitem__(self, idx):
        image_id, text = self.image_ids[idx], self.texts[idx]
        # Close the file handle deterministically. The transform must run inside
        # the `with` block because PIL decodes pixel data lazily.
        with PIL.Image.open(f'{self.data_dir}/{image_id}.png') as pil_image:
            image = self.image_transform(pil_image)
        encoded = self.tokenizer.encode_text(text.lower(), text_seq_length=self.text_seq_length)
        return encoded, image

In [1]:
# Load the captions (indexed by image_id), lower-case them and add simple
# per-caption statistics: character count and whitespace-separated word count.
df = pd.read_csv('../input/russian-emoji/marking.csv', index_col='image_id')
df['text'] = df['text'].str.lower()
text_length = df['text'].map(len)
word_count = df['text'].map(lambda caption: len(caption.split()))
df = df.assign(text_length=text_length, word_count=word_count)
del text_length, word_count
df.head()

In [1]:
# Tokenizer, dataset and shuffled loader used for fine-tuning.
tokenizer = get_tokenizer()
train_dataset = EmojiDataset(
    df,
    data_dir='../input/russian-emoji/images',
    tokenizer=tokenizer,
)
# drop_last=True keeps every batch exactly BATCH_SIZE samples long.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [1]:
# Sanity check: decode a random training sample back to text and display its image.
idx = random.randrange(len(train_dataset))
encoded, image = train_dataset[idx]

print(tokenizer.decode_text(encoded))

# (C, H, W) tensor -> (H, W, C) array as matplotlib expects.
plt.imshow(image.permute(1, 2, 0).cpu().numpy());

In [1]:
# Word cloud of the captions: tokenise each caption with WordCloud's own text
# processing, accumulate the per-word counts, then render from the totals.
wc = WordCloud()
c = Counter()
for caption in df['text']:
    c.update(wc.process_text(caption))
wc.fit_words(c)
plt.figure(figsize=(7, 7));
plt.imshow(wc, interpolation='bilinear');
plt.axis("off");

In [1]:
# Word cloud of the `name` column: per-row WordCloud tokenisation, summed counts.
wc = WordCloud()
c = Counter()
for emoji_name in df['name']:
    c.update(wc.process_text(emoji_name))
wc.fit_words(c)
plt.figure(figsize=(7, 7));
plt.imshow(wc, interpolation='bilinear');
plt.axis("off");

In [1]:
# Histogram of how many emojis share each caption.
# Build the count table explicitly so its layout is the same on pandas 1.x and 2.x:
#   - one column named 'text' holding the per-caption counts;
#   - an unnamed index of captions.
# Since pandas 2.0, value_counts() names its result 'count' and its index 'text'.
# That would make x="text" plot the caption strings and break the
# text_value_counts['text'] lookups used further down.
text_value_counts = (
    df['text'].value_counts()
    .rename('text')
    .rename_axis(None)
    .to_frame()
)
ax = sns.histplot(data=text_value_counts, x="text");
ax.set_title('Duplicated text count histogram');
ax.set_xlabel('duplicates count');

In Russian, these entities represent different concepts but share the same text. \
Examples of emojis whose text appears twice:

In [1]:
# Emojis whose caption is shared by exactly two emojis (first 4 such captions),
# shown two per row so each row is one caption.
texts = text_value_counts[text_value_counts['text'] == 2].index.values
pil_images = []
for text in texts[:4]:
    for image_id in df[df['text'] == text].index:
        pil_images.append(Image.open(f'../input/russian-emoji/images/{image_id}.png'))
merge_pil_images(pil_images, 2)

Many of these samples differ only in skin color. \
Examples of emojis whose text appears 6 times:

In [1]:
# Emojis whose caption is shared by exactly six emojis (captions 25..34 of that
# group), shown six per row so each row is one caption.
texts = text_value_counts[text_value_counts['text'] == 6].index.values
pil_images = []
for text in texts[25:35]:
    for image_id in df[df['text'] == text].index:
        pil_images.append(Image.open(f'../input/russian-emoji/images/{image_id}.png'))
merge_pil_images(pil_images, 6)

Text length distribution by emoji group:

In [1]:
# Caption length (characters) per emoji group: one facet per group, 3 facets per row.
g = sns.displot(
    df,
    x="text_length",
    col="group",
    col_wrap=3,
    binwidth=2,
    kde=True,
    height=4,
    aspect=1,
    facet_kws=dict(margin_titles=True),
);

Word count distribution by emoji group:

In [1]:
# Caption word count per emoji group: one facet per group, 3 facets per row.
g = sns.displot(
    df,
    x="word_count",
    col="group",
    col_wrap=3,
    binwidth=0.5,
    kde=True,
    height=4,
    aspect=1,
    facet_kws=dict(margin_titles=True),
);

# Train [Emojich](https://huggingface.co/sberbank-ai/rudalle-Emojich)

In [1]:
def train(model, vae, optimizer, scheduler, train_loader, n_epochs, save_every=2000,
          early_stop_steps=100):
    """Fine-tune a ruDALL-E model on ``(encoded_text, images)`` batches.

    Each batch goes through the following steps:
      1. Images are converted to VAE codebook indices without gradients.
      2. The indices are appended to the text tokens.
      3. The model is optimised on the loss returned by
         ``model.forward(..., return_loss=True)``.

    Gradients are clipped to norm 1.0 before every optimizer step.

    Checkpoints are written to ``/kaggle/working/saved_models``:
      - ``Emojich_<step>.pt`` every ``save_every`` steps;
      - ``Emojich_last.pt`` after the last epoch.

    Args:
        model: ruDALL-E model exposing ``get_param('device')`` and ``forward``.
        vae: VAE exposing ``get_codebook_indices``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimisation step.
        train_loader: iterable of ``(encoded_text, images)`` batches.
        n_epochs: number of passes over ``train_loader``.
        save_every: checkpoint interval in optimisation steps.
        early_stop_steps: when the notebook-level ``EARLY_STOP`` flag is true,
            the function returns after this many optimisation steps. In that
            case it does not plot the final loss curve or write ``Emojich_last.pt``.
    """
    save_dir = '/kaggle/working/saved_models'
    os.makedirs(save_dir, exist_ok=True)
    model.train()
    vae.eval()
    device = model.get_param('device')
    loss_logs = []
    # One tick per optimisation step across *all* epochs.
    progress = tqdm(total=len(train_loader) * n_epochs, desc='finetuning goes brrr')
    step = 0
    try:
        for _ in range(n_epochs):
            for encoded, images in train_loader:
                bs = images.shape[0]
                step += 1
                if EARLY_STOP and step > early_stop_steps:
                    print('Stopped early')
                    return
                model.zero_grad()
                optimizer.zero_grad()
                # Lower-triangular (causal) mask over the whole sequence.
                # NOTE(review): 1152 presumably = 128 text + 1024 image tokens — confirm.
                attention_mask = torch.tril(torch.ones((bs, 1, 1152, 1152), device=device))
                with torch.no_grad():
                    codebooks = vae.get_codebook_indices(images.to(device))
                input_ids = torch.cat((encoded.to(device), codebooks.long()), dim=1)
                loss, _ = model.forward(input_ids, attention_mask, return_loss=True)
                loss.backward()
                # Clip only after backward(): before it there are no fresh gradients to clip.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                if step % save_every == 0:
                    print(f'Saving checkpoint here Emojich_{step}.pt')
                    plt.plot(loss_logs)
                    plt.show()
                    torch.save(model.state_dict(), os.path.join(save_dir, f"Emojich_{step}.pt"))
                loss_value = loss.item()
                loss_logs.append(loss_value)
                progress.update()
                progress.set_postfix({"loss": loss_value})
    finally:
        progress.close()
    plt.plot(loss_logs)
    plt.show()
    torch.save(model.state_dict(), os.path.join(save_dir, "Emojich_last.pt"))

In [1]:
def freeze(
    model,
    freeze_emb=False,
    freeze_ln=False,
    freeze_attn=True,
    freeze_ff=True,
    freeze_other=False,
):
    """Set ``requires_grad`` on the model's parameters, group by group.

    Parameters are grouped by their lower-cased name; the first matching group wins:
      - layer norms: 'ln' or 'norm' in the name;
      - embeddings: 'embeddings';
      - feed-forward: 'mlp';
      - attention: 'attn';
      - everything else.

    A ``True`` flag freezes its group; ``False`` makes it trainable.

    Accepts either a wrapper that exposes the wrapped network as ``.module``
    or a plain ``nn.Module``.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Unwrap a `.module` wrapper when present, otherwise use the model itself.
    network = getattr(model, 'module', model)
    for name, p in network.named_parameters():
        name = name.lower()
        if 'ln' in name or 'norm' in name:
            frozen = freeze_ln
        elif 'embeddings' in name:
            frozen = freeze_emb
        elif 'mlp' in name:
            frozen = freeze_ff
        elif 'attn' in name:
            frozen = freeze_attn
        else:
            frozen = freeze_other
        p.requires_grad = not frozen
    return model

In [1]:
# Load the pretrained Malevich checkpoint (fp16, GELU MLPs) with the image-loss
# weight from the hyper-parameter cell, plus the VAE, both on the GPU.
device = 'cuda'
model = get_rudalle_model(
    'Malevich',
    pretrained=True,
    fp16=True,
    device=device,
    loss_img_weight=LOSS_IMG_WEIGHT,
    mlp_activation='gelu',
)
vae = get_vae()
vae = vae.to(device)

In [1]:
# Freeze parameter groups with freeze()'s defaults (attention and MLP weights frozen).
model = freeze(model)

torch.cuda.empty_cache()

# bitsandbytes Adam with 8-bit optimizer state (presumably to save GPU memory).
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=MAX_LR)
# One-cycle LR schedule sized for the full N_EPOCHS run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=N_EPOCHS,
    steps_per_epoch=len(train_loader),
    final_div_factor=FINAL_DIV_FACTOR,
)

train(
    model, vae, optimizer, scheduler, train_loader,
    n_epochs=N_EPOCHS, save_every=SAVE_EVERY,
)

In [1]:
import gc
def _optimizer_to(optimizer, device):
    for param in optimizer.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

# Release the optimizer and scheduler:
#   1. move the optimizer state to host memory;
#   2. drop the references;
#   3. let the garbage collector and the CUDA caching allocator reclaim memory.
cpu_device = torch.device('cpu')
_optimizer_to(optimizer, cpu_device)
del optimizer, scheduler
gc.collect()
torch.cuda.empty_cache()

# Generation with [Emojich](https://huggingface.co/sberbank-ai/rudalle-Emojich)

In [1]:
# Load the Emojich weights from the Kaggle input dataset into the model.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('../input/emojich/pytorch_model.bin', map_location='cpu')
load_result = model.load_state_dict(checkpoint)
# load_state_dict copies the tensors into the model; drop the host-side copy
# instead of keeping a second full set of weights alive in a notebook global.
del checkpoint
gc.collect()
load_result

In [1]:
# Generate samples for each prompt and collect them in prompt order.
# Each sampling config is (top_k, top_p, images_num).
pil_images = []
prompts = [
    'зима в горах',  # winter in the mountains
    'флаг зомби апокалипсиса',  # zombie apocalypse flag
    'флаг Cбербанка',  # Sberbank flag
    'храм Василия Блаженного',  # St. Basil's Cathedral
    'вишневая девятка',  # cherry Lada 2109
    'Дональд Трамп из лего',  # Donald Trump made of LEGO
    'человек кушает яблоко',  # a human eats an apple
    'ежик в голубой шапке',  # hedgehog in a blue hat
    'волк в овечьей шкуре',  # a wolf in sheep's clothing
    'кролик синего цвета',  # blue rabbit
    'розовая альпака улыбается',  # pink alpaca smiles
    'арфа в форме улитки',  # a snail-shaped harp
]
sampling_configs = [
    (2048, 0.995, 16),
]
for text in prompts:
    # Re-seed per prompt so each prompt's samples do not depend on earlier prompts.
    seed_everything(42)
    for top_k, top_p, images_num in sampling_configs:
        generated = generate_images(
            text, tokenizer, model, vae,
            top_k=top_k, images_num=images_num, top_p=top_p, bs=8,
        )
        pil_images.extend(generated[0])

In [1]:
# All generated samples in one grid: 16 per row, i.e. one prompt per row.
merge_pil_images(pil_images, nrow=16)

In [1]:
# Free GPU memory before loading Real-ESRGAN and the Emojich U-Net:
#   1. move both networks to the CPU;
#   2. drop the references;
#   3. let the garbage collector and the CUDA caching allocator reclaim memory.
model.to(cpu_device)
vae.to(cpu_device)
del model, vae
gc.collect()
torch.cuda.empty_cache()

# [Telegram Stickers](https://telegram.me/stickers)

Preparing emojis for the [Telegram Stickers](https://telegram.me/stickers) format (512x512, RGBA) using a U-Net. The U-Net was trained on pseudo-labeled emojis generated with "[Emojich](https://huggingface.co/sberbank-ai/rudalle-Emojich)".

In [1]:
# Post-processing networks: the x2 Real-ESRGAN upscaler and the U-Net used for
# the RGBA conversion, both on the GPU.
device = 'cuda'
realesrgan = get_realesrgan('x2', device=device)
emojich_unet = get_emojich_unet('unet_effnetb5')
emojich_unet = emojich_unet.to(device)

In [1]:
# Upscale the generated emojis, then convert them to RGBA with the U-Net
# (the second return value is not needed here).
sr_images = super_resolution(pil_images, realesrgan)
rgba_images, _ = convert_emoji_to_rgba(
    sr_images,
    emojich_unet,
    device=device,
)

In [1]:
# Show one RGBA sticker per prompt: the second sample of each consecutive group.
# Stride slicing replaces the hard-coded range(12) / i*16+1 indexing. It follows
# the number of prompts automatically and never indexes past the end of the list.
SAMPLES_PER_PROMPT = 16  # must match images_num used during generation
for rgba_image in rgba_images[1::SAMPLES_PER_PROMPT]:
    show_rgba(rgba_image)