In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import os
import json
import re
import pickle
from collections import Counter
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
import seaborn as sns
from tqdm import tqdm
import random


In [2]:
# Seed every RNG this notebook draws from (NumPy, PyTorch, Python's `random`)
# with the same value so Restart & Run All reproduces identical results.
# `random.seed` is kept last: it returns None, so the cell shows no output
# (torch.manual_seed returns a Generator object that would otherwise be echoed).
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

In [3]:
class TextPreprocessor:
    """Tokenize captions and map them to fixed-length index tensors.

    Text is lower-cased, stripped of every character except ASCII letters,
    digits, whitespace, ``-`` and ``_``, and split on whitespace. A vocabulary
    (4 special tokens plus the most frequent words) is built with
    :meth:`build_vocabulary`. :meth:`encode_text` then turns a string into a
    ``torch.long`` tensor of exactly ``max_length`` indices.

    Index layout: 0 = <PAD>, 1 = <UNK>, 2 = <START>, 3 = <END>, 4.. = words.
    """

    def __init__(self, vocab_size=1000, max_length=20):
        """
        Args:
            vocab_size: Target vocabulary size, counting the 4 special tokens.
            max_length: Length of every encoded sequence after padding/truncation.
        """
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.word_to_idx = {}
        self.idx_to_word = {}
        self.vocab_built = False

        # Special tokens. They contain '<' and '>', which clean_text strips,
        # so no caption word can ever collide with one of them.
        self.PAD_TOKEN = '<PAD>'
        self.UNK_TOKEN = '<UNK>'
        self.START_TOKEN = '<START>'
        self.END_TOKEN = '<END>'

    def clean_text(self, text):
        """Lower-case ``text``, drop unsupported characters and collapse whitespace.

        Only ASCII letters, digits, whitespace, hyphens and underscores survive;
        all other punctuation (and any non-ASCII character) is removed.
        """
        text = text.lower()

        # After lower() no ASCII upper-case letters remain, so a-z is sufficient.
        text = re.sub(r'[^a-z0-9\s\-_]', '', text)

        # Collapse runs of whitespace into single spaces and trim the ends.
        return ' '.join(text.split())

    def tokenize(self, text):
        """Return the whitespace-separated tokens of the cleaned ``text``."""
        return self.clean_text(text).split()

    def build_vocabulary(self, texts):
        """Build the word<->index mappings from an iterable of strings.

        The special tokens take indices 0-3. The ``vocab_size - 4`` most frequent
        words follow from index 4; ties keep first-seen order (Counter.most_common).
        """
        word_counts = Counter(
            token for text in texts for token in self.tokenize(text)
        )

        # Reserve 4 slots for the special tokens.
        most_common = word_counts.most_common(max(self.vocab_size - 4, 0))

        specials = [self.PAD_TOKEN, self.UNK_TOKEN, self.START_TOKEN, self.END_TOKEN]
        self.word_to_idx = {token: idx for idx, token in enumerate(specials)}
        for idx, (word, _) in enumerate(most_common, start=len(specials)):
            self.word_to_idx[word] = idx

        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.vocab_built = True

        print(f"Built vocabulary with {len(self.word_to_idx)} words")

    def encode_text(self, text):
        """Encode ``text`` as a ``torch.long`` tensor of length ``max_length``.

        The layout is ``<START> words... <END>``, right-padded with <PAD>.
        Out-of-vocabulary words become <UNK>. Sequences that are too long are
        truncated, but <END> is kept as the last token so every encoded caption
        stays terminated.

        Raises:
            ValueError: if :meth:`build_vocabulary` has not been called.
        """
        if not self.vocab_built:
            raise ValueError("Vocabulary not built. Call build_vocabulary first.")

        pad_idx = self.word_to_idx[self.PAD_TOKEN]
        unk_idx = self.word_to_idx[self.UNK_TOKEN]
        end_idx = self.word_to_idx[self.END_TOKEN]

        tokens = [self.START_TOKEN] + self.tokenize(text) + [self.END_TOKEN]
        indices = [self.word_to_idx.get(token, unk_idx) for token in tokens]

        if len(indices) > self.max_length:
            # Plain slicing would silently drop <END>; keep it as the final token.
            if self.max_length > 0:
                indices = indices[:self.max_length - 1] + [end_idx]
            else:
                indices = []
        else:
            indices += [pad_idx] * (self.max_length - len(indices))

        return torch.tensor(indices, dtype=torch.long)

    def decode_text(self, indices):
        """Turn a 1-D sequence of indices (list or tensor) back into text.

        <PAD>, <START> and <END> are dropped, and unknown indices are ignored.
        <UNK> is kept verbatim.
        """
        if isinstance(indices, torch.Tensor):
            indices = indices.tolist()

        skipped = {self.PAD_TOKEN, self.START_TOKEN, self.END_TOKEN}
        words = []
        for idx in indices:
            word = self.idx_to_word.get(idx)
            if word is not None and word not in skipped:
                words.append(word)

        return ' '.join(words)

In [4]:
class TextEmbedding(nn.Module):
    """Encode padded token-index sequences into one fixed-size vector each.

    Pipeline: embedding (index 0 is padding) -> bidirectional LSTM ->
    multi-head self-attention that ignores padded keys -> mean over the
    non-padding positions -> linear projection to ``embedding_dim``.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_heads=8):
        """
        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of the token embeddings and of the output vector.
            hidden_dim: LSTM hidden size per direction. The attention width is
                ``2 * hidden_dim`` and must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
        """
        super(TextEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Word embedding layer; padding_idx=0 keeps the pad vector at zero.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # LSTM for sequence processing
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Attention mechanism
        self.attention = nn.MultiheadAttention(hidden_dim * 2, num_heads=num_heads, batch_first=True)

        # Final projection
        self.projection = nn.Linear(hidden_dim * 2, embedding_dim)

    def forward(self, text_indices):
        """
        Args:
            text_indices: ``(batch, seq_len)`` long tensor; 0 marks padding.

        Returns:
            ``(batch, embedding_dim)`` tensor.
        """
        is_pad = text_indices == 0  # (batch, seq_len)

        embedded = self.embedding(text_indices)  # (batch, seq_len, embed_dim)

        # NOTE(review): the LSTM still runs over padded positions, so with
        # right padding the backward direction starts from pad tokens. Packing
        # (pack_padded_sequence) would avoid that if it matters.
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden_dim * 2)

        # Keep padded positions out of the attention keys. A row made entirely
        # of padding would be fully masked (softmax over no keys -> NaN), so
        # such rows are left unmasked; their pooled value is zeroed below anyway.
        key_padding_mask = is_pad & ~is_pad.all(dim=1, keepdim=True)
        attended, _ = self.attention(
            lstm_out, lstm_out, lstm_out,
            key_padding_mask=key_padding_mask,
            need_weights=False,  # the weights are discarded
        )

        # Mean over non-padding positions.
        keep = (~is_pad).unsqueeze(-1).to(attended.dtype)  # (batch, seq_len, 1)
        # No keepdim here: a (batch, 1, 1) divisor would broadcast the
        # (batch, 2*hidden) sum to (batch, batch, 2*hidden).
        seq_lengths = keep.sum(dim=1)  # (batch, 1)
        pooled = (attended * keep).sum(dim=1) / seq_lengths.clamp(min=1.0)

        # Final projection
        return self.projection(pooled)


In [5]:
class Generator(nn.Module):
    """Text-conditioned generator: (noise, text embedding) -> image in [-1, 1].

    The text embedding is projected to ``latent_dim`` and concatenated with the
    noise vector. A linear layer maps the result to a 256x8x8 feature map.
    Stride-2 transposed convolutions then upsample it, doubling the side and
    halving the channels each time, until it is ``img_size`` x ``img_size``.
    """

    def __init__(self, latent_dim=100, text_embedding_dim=128, img_channels=3, img_size=64):
        """
        Args:
            latent_dim: Size of the noise vector ``z``.
            text_embedding_dim: Size of the incoming text embedding.
            img_channels: Number of channels in the generated image.
            img_size: Output side length; must be 8 * 2**k with k >= 1
                (16, 32, 64, 128, ...).

        Raises:
            ValueError: if ``img_size`` cannot be reached from the 8x8 seed map
                by repeated doubling.
        """
        super(Generator, self).__init__()
        num_upsamples = self._num_upsamples(img_size)

        self.latent_dim = latent_dim
        self.text_embedding_dim = text_embedding_dim
        self.img_channels = img_channels
        self.img_size = img_size

        # Text conditioning layer
        self.text_projection = nn.Sequential(
            nn.Linear(text_embedding_dim, latent_dim),
            nn.ReLU(inplace=True)
        )

        layers = [
            # Input: latent_dim + latent_dim (for text)
            nn.Linear(latent_dim * 2, 256 * 8 * 8),
            nn.BatchNorm1d(256 * 8 * 8),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 8, 8)),
        ]

        # Hidden upsampling stages: each doubles the side and halves the channels
        # (for img_size=64: 256->128 at 16x16, then 128->64 at 32x32).
        channels = 256
        for _ in range(num_upsamples - 1):
            out_channels = max(channels // 2, 1)
            layers += [
                nn.ConvTranspose2d(channels, out_channels, 4, 2, 1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels

        # Final stage reaches img_size and maps to image channels.
        layers += [
            nn.ConvTranspose2d(channels, img_channels, 4, 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        ]

        # For img_size=64 this is the same module order (and Sequential indices)
        # as the original fixed 64x64 stack, so existing state_dicts still load.
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _num_upsamples(img_size):
        """Return how many stride-2 doublings take the 8x8 seed map to ``img_size``."""
        num_upsamples, size = 0, img_size
        while size > 8 and size % 2 == 0:
            size //= 2
            num_upsamples += 1
        if size != 8 or num_upsamples == 0:
            raise ValueError(
                f"img_size must be 8 * 2**k with k >= 1 (e.g. 16, 32, 64, 128); got {img_size!r}"
            )
        return num_upsamples

    def forward(self, z, text_embedding):
        """
        Args:
            z: ``(batch, latent_dim)`` noise.
            text_embedding: ``(batch, text_embedding_dim)`` text features.

        Returns:
            ``(batch, img_channels, img_size, img_size)`` image in [-1, 1].
        """
        # Project text embedding to match latent dimension
        text_proj = self.text_projection(text_embedding)

        # Concatenate latent vector and text projection
        combined_input = torch.cat((z, text_proj), dim=1)

        # Generate image
        return self.model(combined_input)