# Import Libraries

In [None]:
!pip install -q diffusers

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

# Load Data

In [None]:
# Read the Grimm's fairy-tales table; the CSV's first column becomes the index.
df = pd.read_csv(
    filepath_or_buffer='/kaggle/input/grimms-fairy-tales/grimms_fairytales.csv',
    index_col=0,
)
df.head(n=5)

In [None]:
# (number of rows, number of columns) of the loaded table
(len(df.index), len(df.columns))

# Load Models (Stable Diffusion pipeline and BART summarizer)

In [None]:
# Stable Diffusion v1.4 in half precision, moved to the GPU.
# `variant='fp16'` selects the fp16 weight files; the older `revision='fp16'`
# (separate git branch) form is deprecated in current diffusers releases,
# which is what the unpinned `pip install diffusers` above pulls in.
pipe1 = StableDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4', variant='fp16', torch_dtype=torch.float16
)
pipe1 = pipe1.to('cuda')

In [None]:
# BART summarization checkpoint used below to condense each tale into a prompt.
# No `.to(...)` call here, so the model stays on its default device.
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Generate

In [None]:
# Shared settings for the generation cells below.
seed = 7  # fixed RNG seed so a given prompt reproduces the same image
device = 'cuda'  # device name for the image-generation RNG

In [None]:
# Summarize the first tale with BART, then render the summary with Stable Diffusion.
for _, item in df.iloc[:1].iterrows():
    # BART's encoder takes at most 1024 tokens; truncate long tales explicitly
    # rather than relying on the tokenizer's implicit (warning-emitting) fallback.
    # Inputs follow the model's device, so this works whether BART is on CPU or GPU.
    inputs = bart_tokenizer(
        [item['Text']], max_length=1024, truncation=True, return_tensors="pt"
    ).to(bart_model.device)
    summary_ids = bart_model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        num_beams=4,
        max_length=100,
        early_stopping=True,
    )
    # Batch of one, so the first generated sequence is the prompt.
    # NOTE(review): a 100-token BART summary may exceed the SD text encoder's
    # prompt length; the pipeline presumably truncates it — confirm.
    prompt = bart_tokenizer.decode(
        summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Seeded generator on the configured device for reproducible images.
    images = pipe1(
        prompt=prompt,
        generator=torch.Generator(device).manual_seed(seed),
    ).images

    plt.figure(figsize=(5, 5))
    plt.imshow(images[0])
    plt.title(item['Title'])
    plt.axis('off')
    plt.show()

In [None]:
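# Run over the entire dataset (commented out: one BART summary plus one
# Stable Diffusion image per tale is slow).
# NOTE(review): the 8x8 grid below holds at most 64 images, and `idx` from
# `df.iterrows()` is the index label, not the row position. Switch to
# `enumerate(df.iterrows())` if the index is not 0..n-1.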
#device = 'cuda'
#fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(20, 20))
#for idx, item in tqdm(df.iterrows()):
#    test = item['Text']
#    inputs = bart_tokenizer([test], max_length=1024, return_tensors="pt")
#    summary_ids = bart_model.generate(inputs['input_ids'], num_beams=4, max_length=100, early_stopping=True)
#    summary = ([bart_tokenizer.decode(i, skip_special_tokens=True, clean_up_tokenization_spaces=False) for i in summary_ids])
#    prompt = summary[0]
#
#    images = pipe1(
#        prompt=prompt,
#        generator=torch.Generator('cuda').manual_seed(seed)
#    ).images
#
#    axes[idx // 8][idx % 8].imshow(images[0])
#    axes[idx // 8][idx % 8].set_xticks([])
#    axes[idx // 8][idx % 8].set_yticks([])
#    axes[idx // 8][idx % 8].set_title(item['Title'])
#plt.show()