<a href="https://colab.research.google.com/github/podoisthebestdog/EasyOCR/blob/master/recapd_full_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RecapD 기반 텍스트-이미지 생성 모델

# **1. library import**

1-1. pytorch 신경망 구조 및 학습 핵심 모듈 import

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

1-2. 이미지 처리 및 coco 캡션 데이터셋 사용을 위한 torchvision tools

In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CocoCaptions
from torchvision.utils import save_image
from torchvision.transforms.functional import resize
from torchvision.models import inception_v3

1-3. normal utilities + 평가지표 계산용 CLIP + GUI용 streamlit

In [None]:
import numpy as np
from PIL import Image
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity
import clip
import streamlit as st
from torchvision.transforms import ToPILImage

# **2. 텍스트 전처리와 vocab 생성**

2-1. 한글을 영어로 번역

2-2. 텍스트를 소문자로 바꾸고 마침표와 쉼표를 제거한 후 단어 단위로 나눔

In [None]:
def simple_tokenizer(text):
    """Lower-case ``text``, drop periods and commas, and split on whitespace.

    Returns:
        list[str]: the word tokens, in order; empty for blank input.
    """
    normalized = text.strip().lower()
    for mark in ('.', ','):
        normalized = normalized.replace(mark, '')
    return normalized.split()

2-3. COCO 캡션 전체에서 자주 나오는 단어를 기반으로 vocab(어휘사전)생성

In [None]:
def build_vocab(dataset, tokenizer, min_freq=1):
    """Build a token -> index vocabulary from every caption in ``dataset``.

    Args:
        dataset: iterable yielding ``(item, captions)`` pairs. Only
            ``captions`` (an iterable of strings) is read; the first element
            of each pair is ignored.
        tokenizer: callable mapping a caption string to an iterable of tokens.
        min_freq: a token is kept only if it occurs at least this many times
            across all captions.

    Returns:
        dict: token -> integer id. Ids 0-3 are always ``<pad>``, ``<sos>``,
        ``<eos>`` and ``<unk>``; remaining tokens get consecutive ids from 4
        in first-seen order.
    """
    word_freq = defaultdict(int)
    for _, captions in dataset:
        for caption in captions:
            for token in tokenizer(caption):
                word_freq[token] += 1

    # The special tokens own the first four ids.
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    for word, freq in word_freq.items():
        # A caption token that happens to spell a special token must not
        # overwrite its reserved id (that would also leave a hole in the ids).
        if freq >= min_freq and word not in vocab:
            vocab[word] = len(vocab)
    return vocab

# **3. coco data set**

이미지와 캡션을 하나의 텍스트-이미지 쌍으로 불러옴

캡션을 토큰화하고 vocab index로 변환하여 고정 길이 시퀀스 생성

In [None]:
class CocoText2ImageDataset(Dataset):
    """Pairs each COCO image with its first caption, encoded as token ids.

    Every item is ``(image, token_ids)`` where ``token_ids`` is a tensor of
    length ``max_length + 2``: ``<sos>``, up to ``max_length`` caption tokens,
    ``<eos>``, then ``<pad>`` up to the fixed length.
    """

    def __init__(self, image_dir, ann_file, tokenizer, vocab, transform=None, max_length=20):
        """Wrap ``CocoCaptions`` and remember how captions are to be encoded.

        Args:
            image_dir: passed to ``CocoCaptions`` as ``root``.
            ann_file: passed to ``CocoCaptions`` as ``annFile``.
            tokenizer: callable mapping a string to a list of tokens.
            vocab: token -> id mapping containing ``<pad>``, ``<sos>``,
                ``<eos>`` and ``<unk>``.
            transform: passed to ``CocoCaptions`` as ``transform``.
            max_length: maximum number of caption tokens kept (excluding
                ``<sos>``/``<eos>``).
        """
        self.dataset = CocoCaptions(root=image_dir, annFile=ann_file, transform=transform)
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        """Return the number of images in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, token_ids)`` for the ``idx``-th image."""
        image, captions = self.dataset[idx]
        # Only the first caption of each image is used.
        words = self.tokenizer(captions[0].lower())[:self.max_length]
        body = [self.vocab.get(word, self.vocab['<unk>']) for word in words]
        sequence = [self.vocab['<sos>'], *body, self.vocab['<eos>']]
        # Right-pad so every sample has length max_length + 2.
        padding = [self.vocab['<pad>']] * (self.max_length + 2 - len(sequence))
        return image, torch.tensor(sequence + padding)

# **4. 텍스트인코더**

BiLSTM 기반 문장 인코더

단어 시퀀스를 받아 양방향으로 처리한 후 마지막 hidden state를 결합하여 문장 임베딩 생성

In [None]:
class TextEncoder(nn.Module):
    """Sentence encoder: token embedding followed by a single-layer BiLSTM.

    The sentence vector is the concatenation of the final hidden states of
    the two LSTM directions, so its size is ``2 * hidden_dim``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        """Create the embedding table and the bidirectional LSTM."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, captions):
        """Encode a ``(batch, seq_len)`` tensor of token ids.

        Returns:
            Tensor of shape ``(batch, 2 * hidden_dim)``.

        Note: the sequence is fed to the LSTM as-is; padding positions are
        not masked or packed.
        """
        _, (final_hidden, _) = self.lstm(self.embedding(captions))
        # final_hidden has one entry per direction of the single layer.
        first_direction, second_direction = final_hidden[0], final_hidden[1]
        return torch.cat((first_direction, second_direction), dim=-1)

# **5. IMAGE GENERATOR**

문장 벡터와 랜덤 노이즈 벡터를 입력받아 이미지를 생성

Linear로 시작해 4×4 텐서를 만든 후 ConvTranspose2d로 업샘플링



In [None]:
class Generator(nn.Module):
    """Text-conditioned image generator.

    A linear layer maps the concatenated sentence vector and noise vector to
    an ``(ngf * 8, 4, 4)`` tensor, which three stride-2 transposed
    convolutions upsample 4 -> 8 -> 16 -> 32. The output is a
    ``(batch, 3, 32, 32)`` image squashed to ``[-1, 1]`` by ``tanh``.
    """

    def __init__(self, text_dim, z_dim, ngf):
        """Build the generator.

        Args:
            text_dim: size of the sentence feature vector.
            z_dim: size of the noise vector.
            ngf: base channel width; the 4x4 seed tensor has ``ngf * 8``
                channels.
        """
        super().__init__()
        self.fc = nn.Linear(text_dim + z_dim, ngf * 8 * 4 * 4)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, text_feat, z):
        """Generate images from ``text_feat`` (batch, text_dim) and ``z`` (batch, z_dim)."""
        x = torch.cat((text_feat, z), dim=1)
        # The seed channel count is ngf * 8. Read it from the first BatchNorm
        # instead of hard-coding 512, which was only correct for ngf == 64.
        seed_channels = self.conv_blocks[0].num_features
        x = self.fc(x).view(-1, seed_channels, 4, 4)
        return self.conv_blocks(x)


# **6.Discriminator + Captioning Head**

6-1. 이미지를 특징 맵으로 인코딩_CNN 기반

In [None]:
class ImageEncoder(nn.Module):
    """CNN that turns an RGB image into a feature map.

    Four stride-2 convolutions each halve the spatial size (16x reduction
    overall) while widening the channels ``3 -> ndf -> 2ndf -> 4ndf -> 8ndf``.
    Every convolution except the first is followed by batch normalisation;
    all use LeakyReLU(0.2).
    """

    def __init__(self, ndf):
        """Build the convolution stack; ``ndf`` is the base channel width."""
        super().__init__()
        # First stage has no batch norm.
        layers = [nn.Conv2d(3, ndf, 4, 2, 1), nn.LeakyReLU(0.2)]
        # Remaining stages double the channel count each time.
        for multiplier in (1, 2, 4):
            in_channels = ndf * multiplier
            out_channels = ndf * multiplier * 2
            layers += [
                nn.Conv2d(in_channels, out_channels, 4, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``(batch, ndf * 8, H / 16, W / 16)`` feature map of ``x``."""
        return self.conv(x)

6-2. 이미지 특징을 입력받아 캡션 단어를 순서대로 생성_Transformer 기반 캡셔닝 모듈

In [None]:
class CaptionDecoder(nn.Module):
    """Transformer decoder that predicts caption tokens from visual memory.

    Token embeddings plus a learned positional embedding are decoded with a
    causal self-attention mask while cross-attending to ``memory``; a linear
    layer maps each position to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, max_len=20):
        """Build the decoder.

        Args:
            vocab_size: number of tokens in the vocabulary.
            embed_dim: transformer model width (``d_model``).
            num_heads: number of attention heads.
            hidden_dim: feed-forward width inside each decoder layer.
            num_layers: number of stacked decoder layers.
            max_len: longest target sequence the positional embedding covers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, tgt_seq, memory):
        """Decode ``tgt_seq`` against ``memory``.

        Args:
            tgt_seq: ``(batch, seq_len)`` tensor of token ids,
                ``seq_len <= max_len``.
            memory: ``(batch, mem_len, embed_dim)`` tensor to cross-attend to.

        Returns:
            ``(batch, seq_len, vocab_size)`` logits.

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-embedding length.
        """
        seq_len = tgt_seq.size(1)
        max_len = self.pos_embedding.size(1)
        if seq_len > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"target sequence length {seq_len} exceeds max_len {max_len}"
            )
        tgt_embed = self.embedding(tgt_seq) + self.pos_embedding[:, :seq_len, :]
        # The decoder layers are not batch_first: switch to (seq, batch, embed).
        tgt_embed = tgt_embed.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        # Causal mask: -inf above the diagonal blocks attention to future
        # positions. Built directly rather than via nn.Transformer(), which
        # would construct an entire default transformer on every call.
        tgt_mask = torch.triu(
            torch.full(
                (seq_len, seq_len),
                float('-inf'),
                dtype=tgt_embed.dtype,
                device=tgt_seq.device,
            ),
            diagonal=1,
        )
        output = self.transformer_decoder(tgt_embed, memory, tgt_mask)
        return self.fc_out(output.permute(1, 0, 2))


6-3. 판별자(Discriminator) + 이미지 캡셔닝을 함께 수행

생성된 이미지가 텍스트와 일치하는지 판별 + 캡션까지 생성하여 일치성 강화

In [None]:
class RecapD(nn.Module):
    """Discriminator with an attached captioning head.

    The image is encoded to a feature map; a projected text feature is tiled
    over that map and a 4x4 convolution yields one text-conditioned score per
    image. The flattened feature map is also projected to a single memory
    token from which the caption decoder predicts caption logits.

    The layer sizes require the encoder to output a 4x4 feature map, i.e.
    64x64 input images (the encoder downsamples by 16).
    """

    def __init__(self, text_dim, ndf, vocab_size, cap_embed=256, max_len=20):
        """Build the discriminator.

        Args:
            text_dim: size of the sentence feature vector.
            ndf: base channel width of the image encoder.
            vocab_size: caption vocabulary size.
            cap_embed: model width of the caption decoder.
            max_len: longest caption sequence the decoder accepts. Must be at
                least the length of the ``tgt_seq`` passed to ``forward``;
                the default matches the previous fixed value.
        """
        super().__init__()
        self.image_encoder = ImageEncoder(ndf)
        self.fc_text = nn.Linear(text_dim, ndf * 8)
        self.disc_head = nn.Conv2d(ndf * 8 + ndf * 8, 1, 4)
        self.caption_decoder = CaptionDecoder(vocab_size, cap_embed, 4, 512, 2, max_len=max_len)
        self.fc_vis2seq = nn.Linear(ndf * 8 * 4 * 4, cap_embed)

    def forward(self, image, text_feat, tgt_seq):
        """Score image/text agreement and predict caption logits.

        Args:
            image: ``(batch, 3, 64, 64)`` image tensor.
            text_feat: ``(batch, text_dim)`` sentence features.
            tgt_seq: ``(batch, seq_len)`` caption token ids for the decoder.

        Returns:
            tuple: ``(disc_score, cap_logits)`` with shapes ``(batch,)`` and
            ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if the encoded feature map is not 4x4.
        """
        v = self.image_encoder(image)
        if tuple(v.shape[-2:]) != (4, 4):
            # disc_head (kernel 4) and fc_vis2seq (ndf*8*4*4 inputs) only fit
            # a 4x4 map; fail with a clear message instead of a shape error.
            raise ValueError(
                "RecapD expects 64x64 images (4x4 feature map), got feature "
                f"map of size {tuple(v.shape[-2:])}"
            )
        # Tile the projected text feature over every spatial location.
        t = self.fc_text(text_feat).unsqueeze(2).unsqueeze(3).expand_as(v)
        joint = torch.cat((v, t), dim=1)
        disc_score = self.disc_head(joint).view(-1)
        # The whole feature map becomes a single memory token for the decoder.
        memory = self.fc_vis2seq(v.view(v.size(0), -1).unsqueeze(1))
        cap_logits = self.caption_decoder(tgt_seq, memory)
        return disc_score, cap_logits

# **7. 평가지표 FID, R-Precision**



7-1.FID_생성 이미지와 실제 이미지의 분포 차이를 측정하는 지표 (Fréchet Inception Distance)

In [None]:
def _trace_sqrt_product(sigma1, sigma2):
    """Return ``trace(sqrt(sigma1 @ sigma2))`` for symmetric PSD matrices.

    ``sigma1 @ sigma2`` has the same eigenvalues as the symmetric matrix
    ``sqrt(sigma1) @ sigma2 @ sqrt(sigma1)``, so the trace of its square root
    is the sum of the square roots of those eigenvalues. Working with the
    symmetric form keeps everything real-valued.
    """
    eigvals, eigvecs = torch.linalg.eigh(sigma1)
    # Clamp tiny negative eigenvalues caused by round-off before the sqrt.
    sqrt_sigma1 = (eigvecs * eigvals.clamp(min=0).sqrt()) @ eigvecs.T
    inner = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    inner = (inner + inner.T) / 2  # remove round-off asymmetry
    return torch.linalg.eigvalsh(inner).clamp(min=0).sqrt().sum()


def compute_fid(fake_features, real_features):
    """Compute the Frechet distance between two sets of feature vectors.

    ``FID = |mu1 - mu2|^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))`` where ``mu`` and
    ``S`` are the mean and covariance of each feature set.

    Args:
        fake_features: ``(n_fake, dim)`` float tensor, at least two rows.
        real_features: ``(n_real, dim)`` float tensor, at least two rows.

    Returns:
        float: the Frechet distance.
    """
    # float64 keeps the eigen-decompositions numerically stable; both inputs
    # are placed on the same device so no term mixes devices.
    fake = fake_features.detach().to(torch.float64)
    real = real_features.detach().to(device=fake.device, dtype=torch.float64)
    dim = fake.size(1)

    mu1, mu2 = fake.mean(0), real.mean(0)
    # torch.cov returns a 0-dim tensor for a single variable; force (dim, dim).
    sigma1 = torch.cov(fake.T).reshape(dim, dim)
    sigma2 = torch.cov(real.T).reshape(dim, dim)

    diff = mu1 - mu2
    tr_covmean = _trace_sqrt_product(sigma1, sigma2)
    fid = diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) - 2 * tr_covmean
    return fid.item()

7-2.생성 이미지와 텍스트의 일치도를 평가

CLIP 모델로 이미지/텍스트 임베딩을 구하고 cosine 유사도로 측정

In [None]:
def compute_r_precision(model, image_tensor_list, text_list, device):
    """Top-1 image-to-text retrieval accuracy using ``model`` embeddings.

    Each image is resized to 224x224 and embedded with
    ``model.encode_image``; each text is tokenised with ``clip.tokenize`` and
    embedded with ``model.encode_text``. For every image the text with the
    highest cosine similarity is retrieved, and the result is the fraction of
    images whose retrieved text sits at the same list index.

    Args:
        model: object exposing ``eval``, ``encode_image`` and ``encode_text``
            (presumably a CLIP model -- confirm against the caller).
        image_tensor_list: iterable of single-image tensors.
        text_list: iterable of strings, aligned by index with the images.
        device: device the inputs are moved to before encoding.

    Returns:
        The mean of the per-image hit indicators.
    """
    model.eval()
    with torch.no_grad():
        image_features = [
            model.encode_image(resize(image, (224, 224)).unsqueeze(0).to(device)).cpu()
            for image in image_tensor_list
        ]
        text_features = [
            model.encode_text(clip.tokenize([caption]).to(device)).cpu()
            for caption in text_list
        ]
    similarity = cosine_similarity(torch.cat(image_features), torch.cat(text_features))
    # Last column of the ascending argsort is the best-matching text per image.
    best_text = similarity.argsort(axis=1)[:, -1]
    hits = [row == match for row, match in enumerate(best_text)]
    return np.mean(hits)


# **8. Streamlit GUI**

8-1. 사용자가 입력한 텍스트로부터 이미지를 생성하는 함수

In [None]:
@torch.no_grad()
def generate_image(model, text_encoder, generator, vocab, tokenizer, text, device):
    """Generate one PIL image from a text description.

    Args:
        model: optional extra module to switch to eval mode; may be ``None``.
            It is not otherwise used for generation.
        text_encoder: callable mapping a ``(1, seq_len)`` id tensor to a
            ``(1, text_dim)`` sentence feature.
        generator: module called as ``generator(text_feat, z)``.
        vocab: token -> id mapping containing ``<pad>``, ``<sos>``,
            ``<eos>`` and ``<unk>``.
        tokenizer: callable mapping a string to a list of tokens.
        text: the description to render.
        device: device on which to run the encoder and generator.

    Returns:
        PIL.Image: the generated image, rescaled from ``[-1, 1]`` to ``[0, 1]``.
    """
    # Callers without a discriminator pass None; don't crash on it.
    if model is not None:
        model.eval()
    generator.eval()

    # Same fixed-length encoding as the dataset default: at most 20 caption
    # tokens, wrapped in <sos>/<eos> and padded to 22 ids.
    max_length = 20
    words = list(tokenizer(text.lower()))[:max_length]
    token_ids = [vocab.get(tok, vocab['<unk>']) for tok in words]
    token_ids = [vocab['<sos>']] + token_ids + [vocab['<eos>']]
    token_ids += [vocab['<pad>']] * (max_length + 2 - len(token_ids))
    tokens = torch.tensor(token_ids).unsqueeze(0).to(device)
    text_feat = text_encoder(tokens)

    # The generator's first linear layer consumes text_dim + z_dim features,
    # so the noise size can be derived from it; fall back to the previous
    # fixed size of 100 when the generator has no such layer.
    first_linear = getattr(generator, 'fc', None)
    if isinstance(first_linear, nn.Linear):
        z_dim = first_linear.in_features - text_feat.size(1)
    else:
        z_dim = 100
    z = torch.randn(1, z_dim).to(device)

    fake_image = generator(text_feat, z)[0].cpu()
    # Map tanh output from [-1, 1] to [0, 1] before converting to PIL.
    return ToPILImage()(fake_image.add(1).div(2).clamp(0, 1))

8-2. Streamlit을 사용해 간단한 웹 앱 GUI 구성

텍스트 입력창, 버튼, 생성 이미지 출력 포함


In [None]:
def run_gui(text_encoder, generator, vocab, tokenizer, device):
    """Render the Streamlit page: a title, a text box, and a generate button.

    When the button is pressed, the entered text is passed to
    ``generate_image`` and the resulting image is displayed.
    """
    st.title("RecapD 텍스트-이미지 생성기")
    prompt = st.text_input("텍스트 설명을 입력하세요:", "a yellow fire hydrant on the sidewalk")
    if not st.button("이미지 생성"):
        return
    # No discriminator is available here, so None is passed as the model;
    # generate_image must tolerate a None model.
    generated = generate_image(None, text_encoder, generator, vocab, tokenizer, prompt, device)
    st.image(generated, caption="생성된 이미지", use_column_width=True)

# **9. 실행예시**

실제 실행을 위해 필요한 코드

학습된 모델 불러오고 run_gui를 실행하면 GUI 사용 가능

In [None]:
# 실행 예시:
# if __name__ == '__main__':
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     run_gui(text_encoder, generator, vocab, simple_tokenizer, device)
