# Text Encoder

In [3]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class RNNEncoder(nn.Module):
    """LSTM caption encoder that returns sentence-level and word-level features.

    Args:
        ntoken: Size of the vocabulary (number of embedding rows).
        ninput: Size of each token embedding vector.
        drop_prob: Dropout probability applied to the token embeddings and,
            when ``nlayers > 1``, between stacked LSTM layers.
        nhidden: Total feature size of the output (split across directions).
        nlayers: Number of stacked recurrent layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
                 nhidden=128, nlayers=1, bidirectional=True):
        super().__init__()
        self.n_steps = 18
        self.ntoken = ntoken  # size of the dictionary
        self.ninput = ninput  # size of each embedding vector
        self.drop_prob = drop_prob  # probability of an element to be zeroed
        self.nlayers = nlayers  # Number of recurrent layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        # number of features in the hidden state (per direction)
        self.nhidden = nhidden // self.num_directions

        self.encoder = nn.Embedding(self.ntoken, self.ninput)
        self.drop = nn.Dropout(self.drop_prob)

        # nn.LSTM only applies dropout *between* stacked layers. Passing a
        # non-zero value with a single layer does nothing except emit a
        # UserWarning, so it is disabled in that case.
        rnn_dropout = self.drop_prob if self.nlayers > 1 else 0.0
        self.rnn = nn.LSTM(self.ninput, self.nhidden,
                           self.nlayers, batch_first=True,
                           dropout=rnn_dropout,
                           bidirectional=self.bidirectional)

    def forward(self, captions, cap_lens):
        """Encode a batch of padded captions.

        Args:
            captions: LongTensor of token ids, batch x n_steps.
            cap_lens: Caption lengths (tensor or sequence of ints). They are
                passed to ``pack_padded_sequence`` with its default
                ``enforce_sorted=True``, so they must be in decreasing order.

        Returns:
            ``(sent_emb, words_emb)`` where ``sent_emb`` is
            batch x (nhidden * num_directions) and ``words_emb`` is
            batch x (nhidden * num_directions) x max(cap_lens).
        """
        # input: torch.LongTensor of size batch x n_steps
        # --> emb: batch x n_steps x ninput
        emb = self.drop(self.encoder(captions))

        # pack_padded_sequence expects the lengths on the CPU.
        if torch.is_tensor(cap_lens):
            cap_lens = cap_lens.detach().cpu().tolist()
        emb = pack_padded_sequence(emb, cap_lens, batch_first=True)

        output, hidden = self.rnn(emb)
        # PackedSequence object
        # --> (batch, seq_len, hidden_size * num_directions)
        output = pad_packed_sequence(output, batch_first=True)[0]
        words_emb = output.transpose(1, 2)

        # hidden[0] is (nlayers * num_directions, batch, nhidden), ordered
        # layer by layer. Keep only the last layer so the sentence embedding
        # has exactly one row per sample, also when nlayers > 1.
        last_layer = hidden[0][-self.num_directions:]
        sent_emb = last_layer.transpose(0, 1).contiguous()
        sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
        return sent_emb, words_emb

    @staticmethod
    def load(weights_path: str, ntoken: int) -> 'RNNEncoder':
        """Build an encoder with ``nhidden=256`` and load weights from disk.

        The weights are mapped to the CPU; move the returned module to the
        target device afterwards.
        """
        text_encoder = RNNEncoder(ntoken, nhidden=256)
        state_dict = torch.load(weights_path, map_location="cpu")
        text_encoder.load_state_dict(state_dict)
        return text_encoder


# Generator

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def conv1x1(in_planes, out_planes):
    """Build a 1x1 ``nn.Conv2d`` with stride 1, zero padding and no bias."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,
    )

class GlobalAttentionGeneral(nn.Module):
    """Attention of image-region queries over a sequence of context vectors.

    Args:
        idf: Channel count of the image feature map (query dimension).
        cdf: Channel count of the context features; they are projected to
            ``idf`` channels with a 1x1 convolution.
    """

    def __init__(self, idf, cdf):
        super().__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax(dim=1)
        self.mask = None

    def applyMask(self, mask):
        """Store a mask (batch x sourceL); True marks positions to ignore."""
        self.mask = mask  # batch x sourceL

    def forward(self, input, context):
        """
        input: batch x idf x ih x iw (queryL=ihxiw)
        context: batch x cdf x sourceL

        Returns ``(weightedContext, attn)`` with shapes
        batch x idf x ih x iw and batch x sourceL x ih x iw.
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()

        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        attn = torch.bmm(targetT, sourceT)  # batch x queryL x sourceL

        if self.mask is not None:
            # Mask while the scores are still 3-D so that every sample's mask
            # row is broadcast over that sample's own queries. Repeating the
            # mask after flattening to (batch * queryL) rows pairs rows with
            # the wrong samples. The out-of-place fill keeps autograd aware
            # of the masking.
            mask = self.mask.unsqueeze(1).bool()  # batch x 1 x sourceL
            attn = attn.masked_fill(mask, -float('inf'))

        attn = attn.view(batch_size * queryL, sourceL)
        attn = self.sm(attn)
        attn = attn.view(batch_size, queryL, sourceL)
        attn = torch.transpose(attn, 1, 2).contiguous()

        # Apply attention
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext, attn

class AffineBlock(nn.Module):
    """Channel-wise affine modulation conditioned on a sentence embedding.

    Two small MLPs map the sentence embedding to a per-channel scale and a
    per-channel shift, which are applied to the feature map as
    ``scale * x + shift``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.gamma_mlp = self._build_mlp(input_dim, hidden_dim, output_dim)
        self.beta_mlp = self._build_mlp(input_dim, hidden_dim, output_dim)
        self._initialize_weights()

    @staticmethod
    def _build_mlp(input_dim: int, hidden_dim: int, output_dim: int) -> nn.Sequential:
        """Return a Linear -> ReLU -> Linear stack."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim)
        )

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every linear layer."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor, sentence_embed: Tensor) -> Tensor:
        # Append two singleton axes so the per-channel values broadcast over
        # the two trailing (spatial) axes of ``x``.
        gamma = self.gamma_mlp(sentence_embed)[..., None, None]
        beta = self.beta_mlp(sentence_embed)[..., None, None]
        return gamma * x + beta

class ResidualBlockG(nn.Module):
    """Text-conditioned residual block used by the generator.

    The main path is affine -> LeakyReLU -> 3x3 conv, applied twice. Its
    output is scaled by the learnable ``gamma`` (initialised to zero) and
    added to the skip path, which is a 1x1 conv when the channel count
    changes and the identity otherwise.
    """

    def __init__(self, in_channels: int, out_channels: int, text_dim: int = 256):
        super().__init__()
        hidden_dim = text_dim // 2

        self.affine1 = AffineBlock(text_dim, hidden_dim, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.affine2 = AffineBlock(text_dim, hidden_dim, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: Tensor, sentence_embed: Tensor) -> Tensor:
        shortcut = self.skip(x)

        hidden = F.leaky_relu(self.affine1(x, sentence_embed), 0.2)
        hidden = self.conv1(hidden)
        hidden = F.leaky_relu(self.affine2(hidden, sentence_embed), 0.2)
        hidden = self.conv2(hidden)

        return shortcut + self.gamma * hidden


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class Generator(nn.Module):
    """Text-conditioned image generator.

    A noise vector is projected to a 4x4 feature map, passed through four
    sentence-conditioned residual blocks (each followed by 2x nearest
    upsampling), refined with attention over the word features, and
    upsampled twice more before the final 3-channel ``tanh`` convolution.
    """

    def __init__(self, n_channels: int = 32, latent_dim: int = 100):
        super().__init__()
        self.n_channels = n_channels
        widest = 8 * n_channels
        half, quarter = n_channels // 2, n_channels // 4

        # Initial projection (4x4 spatial size)
        self.linear_in = nn.Linear(latent_dim, widest * 4 * 4)

        # Residual blocks; channel widths shrink 8c -> 8c -> 4c -> 2c -> c
        stage_widths = [widest, widest, 4 * n_channels, 2 * n_channels, n_channels]
        self.res_blocks = nn.ModuleList([
            ResidualBlockG(w_in, w_out)
            for w_in, w_out in zip(stage_widths[:-1], stage_widths[1:])
        ])

        # Attention over the word features (context channels: 256)
        self.att = GlobalAttentionGeneral(n_channels, 256)
        self.att_conv = nn.Conv2d(2 * n_channels, n_channels, kernel_size=1)

        # Two further 2x upsampling stages
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(n_channels, half, kernel_size=3, padding=1),
            nn.BatchNorm2d(half),
            nn.LeakyReLU(0.2, True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(half, quarter, kernel_size=3, padding=1),
            nn.BatchNorm2d(quarter),
            nn.LeakyReLU(0.2, True),
        )

        # RGB output squashed to [-1, 1]
        self.conv_out = nn.Sequential(
            nn.Conv2d(quarter, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, noise: Tensor, sentence_embed: Tensor, word_embed: Tensor) -> Tensor:
        # Project the noise and reshape it to a 4x4 feature map
        feat = self.linear_in(noise).view(-1, 8 * self.n_channels, 4, 4)

        # Each stage: conditioned residual block, then double the resolution
        # (4 -> 8 -> 16 -> 32 -> 64)
        for block in self.res_blocks:
            feat = block(feat, sentence_embed)
            feat = F.interpolate(feat, scale_factor=2, mode='nearest')

        # Fuse the word-attention context back into the feature map
        attended, _ = self.att(feat, word_embed)
        feat = self.att_conv(torch.cat([feat, attended], dim=1))

        # 64 -> 128 -> 256, then map to an RGB image
        return self.conv_out(self.upsample(feat))
        
     


In [5]:
import torch
import torchinfo

# Smoke test: build a generator and summarise one forward pass with torchinfo.
gen = Generator(n_channels=32, latent_dim=100)

# Random stand-in inputs for a batch of 24 samples.
noise = torch.rand((24, 100))  # second dim matches latent_dim=100
sent = torch.rand((24, 256))  # stand-in sentence embeddings
word = torch.rand((24, 256, 18))  # stand-in word embeddings

# Runs the generator on the inputs above and reports per-layer output shapes
# and parameter counts.
torchinfo.summary(gen, input_data=(noise, sent, word))

Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [24, 3, 256, 256]         --
├─Linear: 1-1                            [24, 4096]                413,696
├─ModuleList: 1-2                        --                        --
│    └─ResidualBlockG: 2-1               [24, 256, 4, 4]           1
│    │    └─Identity: 3-1                [24, 256, 4, 4]           --
│    │    └─AffineBlock: 3-2             [24, 256, 4, 4]           131,840
│    │    └─Conv2d: 3-3                  [24, 256, 4, 4]           590,080
│    │    └─AffineBlock: 3-4             [24, 256, 4, 4]           131,840
│    │    └─Conv2d: 3-5                  [24, 256, 4, 4]           590,080
│    └─ResidualBlockG: 2-2               [24, 128, 8, 8]           1
│    │    └─Conv2d: 3-6                  [24, 128, 8, 8]           32,896
│    │    └─AffineBlock: 3-7             [24, 256, 8, 8]           131,840
│    │    └─Conv2d: 3-8                  [24, 128, 8,

# Discriminator

In [6]:
import torch
import torch.nn as nn
import torchvision.models as models

class ShallowResNetDiscriminator(nn.Module):
    """Discriminator built on the first two stages of a pretrained ResNet-18.

    Args:
        n_c: Base channel multiplier; image features have ``n_c * 16`` channels.
        sentence_embed_dim: Size of the sentence embedding concatenated with
            the image features before the logit head.
    """

    def __init__(self, n_c=32, sentence_embed_dim=256):
        super().__init__()
        self.sentence_embed_dim = sentence_embed_dim

        # ``pretrained=True`` is deprecated in recent torchvision releases in
        # favour of the ``weights`` argument. Use the new API when available
        # (IMAGENET1K_V1 is what ``pretrained=True`` maps to) and fall back
        # to the old flag otherwise.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)

        # Use only the stem and the first two stages of ResNet18
        self.feature_extractor = nn.Sequential(
            resnet.conv1,  # Initial conv layer
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,  # First ResNet stage
            resnet.layer2,  # Second ResNet stage
        )

        # Output channels from layer2 in ResNet-18 is 128
        self.extra_layers = nn.Sequential(
            nn.Conv2d(128, n_c * 16, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(n_c * 16, n_c * 16, kernel_size=4, stride=2, padding=1),
            nn.Conv2d(n_c * 16, n_c * 16, kernel_size=4, stride=2, padding=1),
        )

        # Final classification layers
        in_c_logit = 16 * n_c + sentence_embed_dim
        self.img_sentence_forward = nn.Sequential(
            nn.Conv2d(in_c_logit, n_c * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(n_c * 2, 1, kernel_size=2, stride=1, padding=0, bias=False)
        )

    def build_embeds(self, image: torch.Tensor) -> torch.Tensor:
        """
        Extract feature embeddings from the input image using ResNet and additional conv layers.

        With the standard ResNet-18 strides the extractor downsamples by 8
        and the extra layers by a further 4, so a 256x256 input yields
        [batch_size, n_c*16, 8, 8].
        """
        out = self.feature_extractor(image)  # Shape: [batch, 128, h/8, w/8]
        out = self.extra_layers(out)  # Shape: [batch, n_c*16, h/32, w/32]
        return out

    def get_logits(self, image_embed: torch.Tensor, sentence_embed: torch.Tensor) -> torch.Tensor:
        """
        Combine image features and sentence embeddings, then compute real/fake logits.

        The result is a map of logits, one spatial position smaller than
        ``image_embed`` in each dimension because of the final 2x2 convolution.
        """
        # [batch, dim] -> [batch, dim, 1, 1], broadcast over the image grid.
        height, width = image_embed.shape[2], image_embed.shape[3]
        sentence_embed = sentence_embed.view(-1, self.sentence_embed_dim, 1, 1)
        sentence_embed = sentence_embed.expand(-1, -1, height, width)
        h_c_code = torch.cat((image_embed, sentence_embed), 1)  # Concatenate along channel dimension
        logits = self.img_sentence_forward(h_c_code)
        return logits

    def forward(self, image: torch.Tensor, sentence_embed: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the shallow discriminator.
        """
        image_embed = self.build_embeds(image)
        logits = self.get_logits(image_embed, sentence_embed)
        return logits


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class ResidualBlockD(nn.Module):
    """Downsampling residual block used by the discriminator.

    Both paths halve the spatial size: the main path with a stride-2 4x4
    convolution, the shortcut with 2x2 average pooling. The main path is
    scaled by the learnable ``gamma``, which starts at zero.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Main path: strided 4x4 conv followed by a 3x3 conv
        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*main_layers)

        # Shortcut: optional 1x1 channel projection, then average pooling
        if in_channels == out_channels:
            projection = nn.Identity()
        else:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.shortcut = nn.Sequential(projection, nn.AvgPool2d(2))

        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: Tensor) -> Tensor:
        skip = self.shortcut(x)
        residual = self.conv(x)
        return skip + self.gamma * residual

class Discriminator(nn.Module):
    """Text-conditioned discriminator trained from scratch.

    Args:
        n_c: Base channel multiplier.
        sentence_embed_dim: Size of the sentence embedding fused with the
            image features before the logit head.
    """

    def __init__(self, n_c: int = 32, sentence_embed_dim: int = 256):
        super().__init__()
        self.sentence_embed_dim = sentence_embed_dim

        # Image processing pathway
        self.img_encoder = nn.Sequential(
            # Initial convolution (no residual)
            nn.Conv2d(3, n_c, kernel_size=3, stride=1, padding=1),

            # Residual downsampling blocks
            ResidualBlockD(n_c * 1, n_c * 2),  # 256x256 -> 128x128
            ResidualBlockD(n_c * 2, n_c * 4),  # 128x128 -> 64x64
            ResidualBlockD(n_c * 4, n_c * 8),  # 64x64 -> 32x32
            ResidualBlockD(n_c * 8, n_c * 16), # 32x32 -> 16x16

            # Final downsampling
            nn.Conv2d(n_c*16, n_c*16, 4, 2, 1),  # 16x16 -> 8x8
            nn.Conv2d(n_c*16, n_c*16, 4, 2, 1)    # 8x8 -> 4x4
        )

        # Text-image fusion
        self.judge_net = nn.Sequential(
            nn.Conv2d(n_c*16 + sentence_embed_dim, n_c*2, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(n_c*2, 1, 4, 1, 0)  # 4x4 -> 1x1 logit
        )

    def build_embeds(self, image: Tensor) -> Tensor:
        """Extract image features; [batch, n_c*16, 4, 4] for 256x256 inputs."""
        return self.img_encoder(image)

    def get_logits(self, image_embed: Tensor, sentence_embed: Tensor) -> Tensor:
        """Fuse image and text features and return the real/fake logits."""
        # Broadcast the text vector over the spatial grid of the image
        # features instead of assuming a fixed 256-d / 4x4 layout.
        height, width = image_embed.shape[2], image_embed.shape[3]
        sentence_embed = sentence_embed.view(-1, self.sentence_embed_dim, 1, 1)
        sentence_embed = sentence_embed.expand(-1, -1, height, width)

        # Concatenate and classify
        combined = torch.cat((image_embed, sentence_embed), dim=1)
        return self.judge_net(combined)

    def forward(self, image: Tensor, sentence_embed: Tensor) -> Tensor:
        """Score ``image`` against ``sentence_embed``.

        Needed so the module can be called directly, like
        ``ShallowResNetDiscriminator``; without it ``nn.Module.__call__``
        raises because no ``forward`` is defined.
        """
        return self.get_logits(self.build_embeds(image), sentence_embed)

# Dataset

In [8]:
import os
import pickle
from typing import Dict, List, Optional, Tuple

import numpy as np
import numpy.random as random
import pandas as pd
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import Compose


class DFGANDataset(Dataset):
    """CUB-200-2011 image/caption dataset.

    Each item is a bounding-box crop of a bird image together with one
    randomly chosen, fixed-length encoded caption.

    Args:
        data_dir: Root directory holding ``captions.pickle``, the split
            directories and the ``CUB_200_2011`` folder.
        split: Name of the split directory; ``"train"`` selects the training
            captions, anything else the test captions.
        transform: Optional callable applied to the cropped PIL image before
            it is converted to a normalised tensor.
    """

    # Captions are padded or subsampled to exactly this many tokens.
    max_caption_len = 18

    def __init__(self, data_dir: str, split: str = "train", transform: Optional["Compose"] = None):
        self.split = split
        self.data_dir = data_dir

        self.split_dir = os.path.join(data_dir, split)
        self.captions_path = os.path.join(self.data_dir, "captions.pickle")
        self.filenames_path = os.path.join(self.split_dir, "filenames.pickle")

        self.transform = transform

        # Number of captions stored per image.
        self.embeddings_num = 10

        # PIL image -> float tensor scaled to [-1, 1].
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.images_dir = os.path.join(self.data_dir, "CUB_200_2011/CUB_200_2011/images")
        self.bbox_path = os.path.join(self.data_dir, "CUB_200_2011/CUB_200_2011/bounding_boxes.txt")
        self.images_path = os.path.join(self.data_dir, "CUB_200_2011/CUB_200_2011/images.txt")

        self.bbox = self._load_bbox()

        self.file_names, self.captions, self.code2word, self.word2code = self._load_text_data()

        self.n_words = len(self.code2word)
        self.num_examples = len(self.file_names)

        self._print_info()

    def _print_info(self):
        """Print a short summary of what was loaded."""
        print(f"Total filenames: {len(self.bbox)}")
        print(f"Load captions from: {self.captions_path}")
        print(f"Load file names from: {self.filenames_path} ({self.num_examples})")
        print(f"Dictionary size: {self.n_words}")
        print(f"Embeddings number: {self.embeddings_num}")

    def _load_bbox(self) -> Dict[str, List[int]]:
        """Map each image name (without its 4-character extension) to its bbox.

        Column 0 of the bounding-box file is dropped; the remaining values are
        used by ``_get_image`` as ``[x, y, width, height]``.
        """
        # ``sep=r"\s+"`` is the non-deprecated spelling of delim_whitespace=True.
        df_bbox = pd.read_csv(self.bbox_path, sep=r"\s+", header=None).astype(int)
        df_image_names = pd.read_csv(self.images_path, sep=r"\s+", header=None)

        image_names = df_image_names[1].tolist()
        # Convert all rows at once instead of one ``iloc`` lookup per image.
        boxes = df_bbox.iloc[:, 1:].values.tolist()

        return {file_name[:-4]: bbox for file_name, bbox in zip(image_names, boxes)}

    def _load_text_data(self) -> Tuple[List[str], List[List[int]],
                                       Dict[int, str], Dict[str, int]]:
        """Load captions and vocabulary and pick the captions for this split."""
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # caption files from a trusted source.
        with open(self.captions_path, 'rb') as file:
            train_captions, test_captions, code2word, word2code = pickle.load(file)

        filenames = self._load_filenames()

        captions = train_captions if self.split == 'train' else test_captions
        return filenames, captions, code2word, word2code

    def _load_filenames(self) -> List[str]:
        """Load the pickled list of file names for this split.

        Raises:
            ValueError: If the file-name pickle does not exist.
        """
        if not os.path.isfile(self.filenames_path):
            raise ValueError(f"File {self.filenames_path} does not exist")

        # NOTE: same pickle trust caveat as for the captions file.
        with open(self.filenames_path, 'rb') as file:
            return pickle.load(file)

    def _get_caption(self, caption_idx: int) -> Tuple[np.ndarray, int]:
        """Return caption ``caption_idx`` as a (max_caption_len, 1) int64 array.

        Short captions are zero-padded. Longer ones are reduced to a random
        subset of ``max_caption_len`` tokens kept in their original order.
        The second element is the number of valid tokens.
        """
        max_len = self.max_caption_len
        caption = np.array(self.captions[caption_idx])
        pad_caption = np.zeros((max_len, 1), dtype='int64')

        if len(caption) <= max_len:
            pad_caption[:len(caption), 0] = caption
            return pad_caption, len(caption)

        indices = list(np.arange(len(caption)))
        np.random.shuffle(indices)
        pad_caption[:, 0] = caption[np.sort(indices[:max_len])]

        return pad_caption, max_len

    def _get_image(self, image_path: str, bbox: List[int]) -> Tensor:
        """Load an image, crop a square around its bbox and normalise it."""
        image = Image.open(image_path).convert('RGB')
        width, height = image.size

        # Half-side of the crop: 75% of the longer bbox edge.
        r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
        center_x = int((2 * bbox[0] + bbox[2]) / 2)
        center_y = int((2 * bbox[1] + bbox[3]) / 2)

        # Clamp the crop window to the image borders.
        y1 = np.maximum(0, center_y - r)
        y2 = np.minimum(height, center_y + r)
        x1 = np.maximum(0, center_x - r)
        x2 = np.minimum(width, center_x + r)

        image = image.crop((x1, y1, x2, y2))
        # ``transform`` defaults to None, so it must be optional here.
        if self.transform is not None:
            image = self.transform(image)

        return self.normalize(image)

    def _get_random_caption(self, idx: int) -> Tuple[np.ndarray, int]:
        """Pick one of the ``embeddings_num`` captions of image ``idx``."""
        # Call numpy explicitly: its upper bound is exclusive, so every
        # caption slot 0..embeddings_num-1 can be drawn, and the result no
        # longer depends on whether the bare name ``random`` refers to
        # ``numpy.random`` or to the stdlib module (whose bound is inclusive).
        caption_shift = np.random.randint(0, self.embeddings_num)
        caption_idx = idx * self.embeddings_num + caption_shift

        # Fall back to the last caption if the computed index runs past the end.
        caption_idx = min(caption_idx, len(self.captions) - 1)

        return self._get_caption(caption_idx)

    def __getitem__(self, idx: int) -> Tuple[Tensor, np.ndarray, int, str]:
        """Return ``(image, encoded_caption, caption_len, file_name)``."""
        file_name = self.file_names[idx]
        image = self._get_image(f"{self.images_dir}/{file_name}.jpg", self.bbox[file_name])

        encoded_caption, caption_len = self._get_random_caption(idx)

        return image, encoded_caption, caption_len, file_name

    def __len__(self) -> int:
        return self.num_examples

In [9]:
from typing import List, Tuple

import torch
from torch import Tensor


def prepare_data(batch: Tuple[Tensor, Tensor, Tensor, Tuple[str]],
                 device: torch.device) -> Tuple[Tensor, Tensor, Tensor, List[str]]:
    """Sort a batch by decreasing caption length and move it to ``device``.

    Args:
        batch: ``(images, captions, caption_lengths, file_names)`` as produced
            by the data loader; captions carry a trailing singleton axis.
        device: Target device for the tensors.

    Returns:
        ``(images, captions, caption_lengths, file_names)``, all reordered so
        that caption lengths are decreasing. The captions have their trailing
        singleton axis removed.
    """
    images, captions, captions_len, file_names = batch

    sorted_cap_lens, sorted_cap_indices = torch.sort(captions_len, 0, True)
    sorted_cap_lens = sorted_cap_lens.to(device)

    sorted_images = images[sorted_cap_indices].to(device)
    # Drop only the trailing axis: a bare squeeze() would also remove the
    # batch axis when the batch holds a single sample.
    sorted_captions = captions[sorted_cap_indices].squeeze(-1).to(device)
    sorted_file_names = [file_names[i] for i in sorted_cap_indices.tolist()]

    return sorted_images, sorted_captions, sorted_cap_lens, sorted_file_names


In [10]:
import random
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader,Subset
from torchvision.transforms import transforms




def create_loader(imsize: int, batch_size: int, data_dir: str, split: str,
                  subset_size: int = 6000) -> Tuple[DataLoader, int]:
    """Build a shuffled loader over a random subset of the dataset.

    Args:
        imsize: Side length of the square crops fed to the model.
        batch_size: Number of samples per batch; incomplete batches are dropped.
        data_dir: Dataset root directory passed to ``DFGANDataset``.
        split: Either ``"train"`` or ``"test"``.
        subset_size: Maximum number of randomly chosen examples to keep.

    Returns:
        ``(loader, n_words)`` where ``n_words`` is the vocabulary size of the
        full dataset.
    """
    assert split in ["train", "test"], "Wrong split type, expected train or test"
    image_transform = transforms.Compose([
        # Resize slightly larger than the target so the random crop has slack.
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()
    ])

    dataset = DFGANDataset(data_dir, split, image_transform)

    # Read the vocabulary size before wrapping: Subset does not forward it.
    n_words = dataset.n_words

    # The slice is a no-op when the dataset has fewer examples than requested.
    shuffled_indices = torch.randperm(len(dataset))[:subset_size].tolist()
    dataset = Subset(dataset, shuffled_indices)

    print(len(dataset))

    loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, shuffle=True)
    return loader, n_words


def fix_seed(seed: int = 123321):
    """Seed the ``random``, NumPy and PyTorch (CPU and CUDA) generators."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    print(f"Seed {seed} fixed")


In [11]:
import os
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.auto import trange

class DeepFusionGAN:
    """Training wrapper around the generator, discriminator and text encoder.

    Args:
        n_words: Vocabulary size used to build the text encoder.
        encoder_weights_path: Path to the text-encoder weights.
        image_save_path: Directory for the per-epoch sample images.
        gen_path_save: Directory for the per-epoch checkpoints.
        use_pretrained_discriminator: Use ``ShallowResNetDiscriminator`` when
            True, otherwise the from-scratch ``Discriminator``.
    """

    # Checkpoint that fit() resumes from when no checkpoint_path is given.
    DEFAULT_CHECKPOINT_PATH = "/kaggle/input/workin/gennormal_200.pth"

    def __init__(self, n_words, encoder_weights_path: str, image_save_path: str, gen_path_save: str, use_pretrained_discriminator: bool = True):
        super().__init__()
        self.image_save_path = image_save_path
        self.gen_path_save = gen_path_save
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.generator = Generator(n_channels=32, latent_dim=100).to(self.device)

        if use_pretrained_discriminator:
            self.discriminator = ShallowResNetDiscriminator(n_c=32).to(self.device)
        else:
            self.discriminator = Discriminator(n_c=32).to(self.device)

        # The text encoder is frozen: no gradients, evaluation mode.
        self.text_encoder = RNNEncoder.load(encoder_weights_path, n_words).to(self.device)
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        self.text_encoder.eval()

        self.g_optim = torch.optim.Adam(self.generator.parameters(), lr=0.0003, betas=(0.5, 0.999))
        self.d_optim = torch.optim.Adam(self.discriminator.parameters(), lr=5e-6, betas=(0.5, 0.999))

    def _compute_gp(self, images: Tensor, sentence_embeds: Tensor) -> Tensor:
        """Return the per-sample gradient norm of the logits.

        The gradient is taken with respect to both the images and the
        sentence embeddings; the two parts are flattened and concatenated
        before the L2 norm is computed.
        """
        batch_size = images.shape[0]
        # Fresh leaf tensors so the gradient w.r.t. the inputs can be taken.
        images_interpolated = images.detach().requires_grad_()
        sentences_interpolated = sentence_embeds.detach().requires_grad_()

        embeds = self.discriminator.build_embeds(images_interpolated)
        logits = self.discriminator.get_logits(embeds, sentences_interpolated)

        grad_outputs = torch.ones_like(logits)
        # create_graph=True keeps the penalty differentiable w.r.t. the
        # discriminator parameters.
        grads = torch.autograd.grad(
            outputs=logits,
            inputs=(images_interpolated, sentences_interpolated),
            grad_outputs=grad_outputs,
            retain_graph=True,
            create_graph=True
        )

        grad_0 = grads[0].reshape(batch_size, -1)
        grad_1 = grads[1].reshape(batch_size, -1)

        grad = torch.cat((grad_0, grad_1), dim=1)
        grad_norm = grad.norm(2, dim=1)
        return grad_norm

    def fit(self, train_loader: DataLoader, num_epochs: int = 500, checkpoint_path: str = None) -> Tuple[List[float], List[float], List[float]]:
        """Resume from a checkpoint and train for ``num_epochs`` more epochs.

        Args:
            train_loader: Loader yielding batches accepted by ``prepare_data``.
            num_epochs: Number of epochs to run after the resumed epoch.
            checkpoint_path: Checkpoint to resume from. When omitted,
                ``DEFAULT_CHECKPOINT_PATH`` is used.

        Returns:
            Per-epoch mean generator loss, discriminator loss and
            gradient-penalty loss.

        Raises:
            RuntimeError: If ``train_loader`` yields no batches in an epoch.
        """
        g_losses_epoch, d_losses_epoch, d_gp_losses_epoch = [], [], []
        lambda_gp = 1.0

        # Honour the caller's checkpoint; fall back to the default otherwise.
        if checkpoint_path is None:
            checkpoint_path = self.DEFAULT_CHECKPOINT_PATH
        start_epoch = self._load_gen_weights(checkpoint_path)

        for epoch in trange(start_epoch, num_epochs + start_epoch, desc="Training DeepFusionGAN"):
            g_losses, d_losses, d_gp_losses = [], [], []
            for batch in train_loader:
                images, captions, captions_len, _ = prepare_data(batch, self.device)
                batch_size = images.shape[0]

                sentence_embeds, words_embs = self.text_encoder(captions, captions_len)
                sentence_embeds, words_embs = sentence_embeds.detach(), words_embs.detach()

                # Three discriminator updates per generator update.
                for _ in range(3):
                    real_logits = self.discriminator(images, sentence_embeds)
                    d_loss_real = -torch.mean(real_logits)

                    noise = torch.randn(batch_size, 100, device=self.device)
                    # The fakes are detached for the discriminator step, so
                    # there is no need to build the generator's autograd graph.
                    with torch.no_grad():
                        fake_images = self.generator(noise, sentence_embeds, words_embs)
                    fake_logits = self.discriminator(fake_images.detach(), sentence_embeds)
                    d_loss_fake = torch.mean(fake_logits)

                    d_loss_gp = torch.mean((self._compute_gp(images, sentence_embeds) - 1) ** 2)
                    d_loss = d_loss_real + d_loss_fake + lambda_gp * d_loss_gp

                    self.d_optim.zero_grad()
                    d_loss.backward()
                    self.d_optim.step()

                    d_losses.append(d_loss.item())
                    d_gp_losses.append(d_loss_gp.item())

                # Generator update
                noise = torch.randn(batch_size, 100, device=self.device)
                fake_images = self.generator(noise, sentence_embeds, words_embs)
                fake_logits = self.discriminator(fake_images, sentence_embeds)
                g_loss = -torch.mean(fake_logits)

                self.g_optim.zero_grad()
                g_loss.backward()
                self.g_optim.step()

                g_losses.append(g_loss.item())

            if not g_losses:
                # Without this guard the averages below would fail with an
                # opaque ZeroDivisionError.
                raise RuntimeError("train_loader yielded no batches; cannot train")

            g_losses_epoch.append(sum(g_losses) / len(g_losses))
            d_losses_epoch.append(sum(d_losses) / len(d_losses))
            d_gp_losses_epoch.append(sum(d_gp_losses) / len(d_gp_losses))

            # Save the images from the last batch of the epoch.
            self._save_fake_image(fake_images, epoch)
            self._save_gen_weights(epoch)
            print(f"Epoch {epoch}: G Loss {g_losses_epoch[-1]:.4f}, D Loss {d_losses_epoch[-1]:.4f}, GP Loss {d_gp_losses_epoch[-1]:.4f}")

        return g_losses_epoch, d_losses_epoch, d_gp_losses_epoch

    def _save_fake_image(self, fake_images: Tensor, epoch: int):
        """Write a normalised image grid for ``epoch`` to ``image_save_path``."""
        img_path = os.path.join(self.image_save_path, f"fake_wgangp_epoch_{epoch}.png")
        vutils.save_image(fake_images.data, img_path, normalize=True)

    def _save_gen_weights(self, epoch: int):
        """Save model and optimizer states for ``epoch`` to ``gen_path_save``."""
        gen_path = os.path.join(self.gen_path_save, f"gennormal_{epoch}.pth")
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'g_optim_state_dict': self.g_optim.state_dict(),
            'd_optim_state_dict': self.d_optim.state_dict(),
        }
        torch.save(checkpoint, gen_path)
        print(f"Model checkpoint saved at epoch {epoch} to {gen_path}")

    def _load_gen_weights(self, checkpoint_path: str) -> int:
        """Restore model and optimizer states; return the next epoch index."""
        # map_location lets a checkpoint saved on one device be loaded on another.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.g_optim.load_state_dict(checkpoint['g_optim_state_dict'])
        self.d_optim.load_state_dict(checkpoint['d_optim_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Checkpoint loaded from {checkpoint_path} - Starting from epoch {start_epoch}")
        return start_epoch


In [14]:
def train() -> Tuple[List[float], List[float], List[float]]:
    """Resume training from the saved checkpoint for 20 more epochs.

    Returns:
        Per-epoch generator, discriminator and gradient-penalty losses as
        returned by ``DeepFusionGAN.fit``.
    """
    fix_seed()

    data_path = "/kaggle/input/cv-seattention/cv_seattention/data"
    encoder_weights_path = "/kaggle/input/cv-seattention/cv_seattention/text_encoder_weights/text_encoder200.pth"
    image_save_path = "/kaggle/working/gen_images"
    gen_path_save = "/kaggle/working/gen_weights"
    # The checkpoint lives in the read-only input directory. Joining this
    # absolute path onto gen_path_save would just discard gen_path_save, so
    # it is used directly.
    checkpoint_path = "/kaggle/input/workin/gennormal_200.pth"

    os.makedirs(image_save_path, exist_ok=True)
    os.makedirs(gen_path_save, exist_ok=True)

    train_loader, n_words = create_loader(256, 32, data_path, "train")
    model = DeepFusionGAN(
        n_words=n_words,
        encoder_weights_path=encoder_weights_path,
        image_save_path=image_save_path,
        gen_path_save=gen_path_save,
        use_pretrained_discriminator=True  # ResNet-based discriminator
    )

    checkpoint = torch.load(checkpoint_path, map_location=model.device)

    # Load model states (non-strict for the discriminator so that missing or
    # unexpected keys do not abort here).
    model.generator.load_state_dict(checkpoint['generator_state_dict'])
    model.discriminator.load_state_dict(checkpoint['discriminator_state_dict'], strict=False)

    # Load optimizer states
    model.g_optim.load_state_dict(checkpoint['g_optim_state_dict'])
    model.d_optim.load_state_dict(checkpoint['d_optim_state_dict'])

    # Report the epoch stored in the checkpoint rather than a hard-coded one.
    loaded_epoch = checkpoint['epoch']
    print(f"Loaded checkpoint from epoch {loaded_epoch} , resuming from epoch {loaded_epoch + 1}")

    num_epochs = 20

    return model.fit(train_loader, num_epochs, checkpoint_path=checkpoint_path)

import warnings
# Silence FutureWarning and UserWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [15]:
import os
import sys
import csv

# Make the parent of the working directory importable.
current_cwd = os.getcwd()
src_path = '/'.join(current_cwd.split('/')[:-1])
sys.path.append(src_path)

# Run training and collect the per-epoch loss curves.
g_losses_epoch, d_losses_epoch, d_gp_losses_epoch = train()

path = "/kaggle/working/loss"
os.makedirs(path, exist_ok=True)

# One CSV per curve: generator, discriminator, gradient penalty.
filenames = [f'{path}/loss1.csv', f'{path}/loss2.csv', f'{path}/loss3.csv']
loss_values = [g_losses_epoch, d_losses_epoch, d_gp_losses_epoch]

for filename, epoch_losses in zip(filenames, loss_values):
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Loss'])
        writer.writerows([value] for value in epoch_losses)



In [17]:
import os

import numpy as np
import torch
from PIL import Image
from torch import Tensor




@torch.no_grad()
def generate_images(generator: Generator, sentence_embeds: Tensor,
                    device: torch.device) -> Tensor:
    batch_size = sentence_embeds.shape[0]
    noise = torch.randn(batch_size, 100).to(device)
    return generator(noise, sentence_embeds)


def save_image(image: np.ndarray, save_dir: str, file_name: str):
    """Write a channel-first float image to ``<save_dir>/<file_name>.png``.

    Args:
        image: array indexed (channel, height, width) with values nominally
            in [-1, 1].
        save_dir: destination directory; created if it does not exist.
        file_name: base name of the file; any '/' is replaced by '_' so the
            name cannot introduce sub-directories.
    """
    os.makedirs(save_dir, exist_ok=True)
    # [-1, 1] --> [0, 255]. Clip before the integer cast: without it, values
    # slightly outside the nominal range wrap around instead of saturating.
    image = np.clip((image + 1.0) * 127.5, 0.0, 255.0)
    image = image.astype(np.uint8)
    # (C, H, W) --> (H, W, C), the layout Image.fromarray is given here.
    image = np.transpose(image, (1, 2, 0))
    image = Image.fromarray(image)
    fullpath = os.path.join(save_dir, f"{file_name.replace('/', '_')}.png")
    image.save(fullpath)


def sample(generator: Generator, text_encoder: RNNEncoder, batch, save_dir: str):
    """Generate one image per caption in ``batch`` and save them as PNG files.

    Args:
        generator: image generator, called as
            ``generator(noise, sent_emb, words_emb)`` like the other cells of
            this notebook do.
        text_encoder: caption encoder; its result is unpacked as
            ``(sent_emb, words_emb)``, matching its other call sites here.
        batch: raw batch handed to ``prepare_data``.
        save_dir: output directory; created if missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(save_dir, exist_ok=True)

    _images, captions, captions_len, file_names = prepare_data(batch, device)

    # Inference only. The encoder result is a pair, so it must be unpacked
    # rather than detached as if it were a single tensor.
    with torch.no_grad():
        sent_emb, words_emb = text_encoder(captions, captions_len)
        noise = torch.randn(sent_emb.shape[0], 100).to(device)
        fake_images = generator(noise, sent_emb, words_emb)

    for fake_image, file_name in zip(fake_images, file_names):
        save_image(fake_image.cpu().numpy(), save_dir, file_name)


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import sqrtm
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from scipy.stats import entropy

# === Generator checkpoints ===
# Each entry is (epoch label used in the report/plots, checkpoint path).
# NOTE(review): the entry labelled 160 loads gennormal_158.pth — confirm the
# label is intentional.
gen_ckpts = [
    (150, "/kaggle/input/workin/gennormal_150.pth"),
    (160, "/kaggle/input/workin/gennormal_158.pth"),
    (180, "/kaggle/input/workin/gennormal_180.pth"),
    (200, "/kaggle/input/workin/gennormal_200.pth"),
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32

# Test-split loader; the first argument (256) is presumably the image size —
# TODO confirm against create_loader.
test_loader, n_words = create_loader(
    256,
    batch_size,
    "/kaggle/input/cv-seattention/cv_seattention/data",
    "test",
)

# Build the generator once and seed it with the last checkpoint in the list;
# the evaluation loops below reload weights per checkpoint.
generator = Generator(n_channels=32, latent_dim=100)
generator = generator.to(device)
dummy_ckpt = torch.load(gen_ckpts[-1][1], map_location=device)
generator.load_state_dict(dummy_ckpt['generator_state_dict'])
generator.eval()

# Text encoder: load the weights, switch to eval mode and freeze every
# parameter.
text_encoder = RNNEncoder.load(
    "/kaggle/input/cv-seattention/cv_seattention/text_encoder_weights/text_encoder200.pth",
    n_words
)
text_encoder = text_encoder.to(device)
text_encoder = text_encoder.eval()
for p in text_encoder.parameters():
    p.requires_grad = False

# InceptionV3 feature extractor
class InceptionV3(nn.Module):
    """Pretrained torch.hub Inception-v3 split into a backbone and its head.

    ``model`` is the network with its ``fc`` and ``dropout`` attributes
    replaced by identity modules; ``linear`` is the original ``fc`` layer,
    kept so logits can still be computed from the backbone output.
    """

    def __init__(self):
        super().__init__()
        backbone = torch.hub.load(
            'pytorch/vision:v0.6.0',
            'inception_v3',
            pretrained=True
        )
        self.model = backbone.to(device).eval()
        # Grab the classification head before it is swapped out below.
        classification_head = self.model.fc
        self.linear = classification_head
        self.model.fc = nn.Identity()
        self.model.dropout = nn.Identity()

    @torch.no_grad()
    def get_last_layer(self, x):
        """Resize ``x`` to 299x299 (bilinear) and return the backbone output.

        NOTE(review): ``x`` is passed through without any value rescaling —
        confirm the input range matches what the pretrained weights expect.
        """
        resized = F.interpolate(
            x, size=(299, 299), mode='bilinear', align_corners=False
        )
        return self.model(resized)


classifier = InceptionV3()
classifier = classifier.to(device).eval()

# FID & IS functions
def calculate_fid(real_feats, fake_feats):
    """Return the Fréchet distance between two sets of feature vectors.

    Both arguments are 2-D arrays with one sample per row. Each set is
    summarised by its mean and covariance, and the result is
    ``|mu1 - mu2|^2 + trace(s1 + s2 - 2 * sqrt(s1 @ s2))`` as a Python float.
    """
    mean_real = real_feats.mean(0)
    mean_fake = fake_feats.mean(0)
    cov_real = np.cov(real_feats, rowvar=False)
    cov_fake = np.cov(fake_feats, rowvar=False)

    sqrt_prod = sqrtm(cov_real @ cov_fake)
    if not np.isfinite(sqrt_prod).all():
        # Non-finite square root: retry with a tiny diagonal offset added to
        # both covariance matrices.
        jitter = np.eye(cov_real.shape[0]) * 1e-8
        sqrt_prod = sqrtm((cov_real + jitter) @ (cov_fake + jitter))
    if np.iscomplexobj(sqrt_prod):
        # Keep only the real part of a complex result.
        sqrt_prod = sqrt_prod.real

    mean_gap = mean_real - mean_fake
    return float(mean_gap @ mean_gap
                 + np.trace(cov_real + cov_fake - 2 * sqrt_prod))

def inception_score(fake_feats, bs=32):
    """Compute an Inception Score from pre-extracted backbone features.

    The features are pushed through ``classifier.linear`` in chunks of ``bs``
    rows and soft-maxed; the per-sample KL divergence between each prediction
    and the mean prediction is then computed.

    Args:
        fake_feats: 2-D array of features, one generated image per row.
        bs: number of rows classified per forward pass.

    Returns:
        ``(score, spread)`` where ``score`` is ``exp(mean per-sample KL)``
        and ``spread`` is the standard deviation of the per-sample KL values.
        Note that ``spread`` is therefore in KL units; it is not the standard
        deviation of the score across splits.

    Raises:
        ValueError: if ``fake_feats`` is empty.
    """
    if len(fake_feats) == 0:
        raise ValueError("fake_feats is empty; nothing to score")

    preds = []
    # Pure inference: skip building an autograd graph for the linear head.
    with torch.no_grad():
        for i in range(0, len(fake_feats), bs):
            b = torch.tensor(fake_feats[i:i+bs], device=device)
            preds.append(F.softmax(classifier.linear(b), dim=1).cpu().numpy())
    preds = np.vstack(preds)

    py = preds.mean(0)  # mean prediction over all samples
    # 1e-10 guards the logarithms against exact zeros.
    kl_per_sample = (preds * (np.log(preds+1e-10) - np.log(py+1e-10))).sum(1)
    return float(np.exp(np.mean(kl_per_sample))), float(np.std(kl_per_sample))

def build_representations(generator, text_encoder, max_batches=None):
    """Collect classifier features for real and generated images.

    Iterates over the global ``test_loader``. For every batch it encodes the
    captions, generates images from fresh noise, and runs both the real and
    the generated images through ``classifier.get_last_layer``.

    Args:
        generator: called as ``generator(noise, se, we)``.
        text_encoder: called as ``text_encoder(caps, lens)``; its result is
            unpacked into two values.
        max_batches: when truthy, stop after this many batches; ``None`` (or
            0) processes the whole loader.

    Returns:
        ``(real, fake)``: two arrays with the per-batch features stacked
        row-wise.
    """
    real_chunks = []
    fake_chunks = []
    progress = tqdm(test_loader, desc="Building repr")
    for batch_idx, batch in enumerate(progress):
        if max_batches and batch_idx >= max_batches:
            break
        imgs, caps, lens, _ = prepare_data(batch, device)
        with torch.no_grad():
            se, we = text_encoder(caps, lens)
            noise = torch.randn(imgs.size(0), 100, device=device)
            fakes = generator(noise, se, we)
            real_out = classifier.get_last_layer(imgs)
            fake_out = classifier.get_last_layer(fakes)
        real_chunks.append(real_out.cpu().numpy())
        fake_chunks.append(fake_out.cpu().numpy())
    return np.vstack(real_chunks), np.vstack(fake_chunks)

# === 1) Dry-run each checkpoint on 1 batch ===
# Smoke test: load every checkpoint and push a single batch through the whole
# pipeline, so a broken checkpoint fails here rather than mid-evaluation.
# The loop variable is named ckpt_path (not path) so it does not overwrite
# the notebook-global `path` output directory that other cells rely on.
for epoch, ckpt_path in gen_ckpts:
    print(f"Dry-run Epoch {epoch}...")
    ckpt = torch.load(ckpt_path, map_location=device)
    generator.load_state_dict(ckpt['generator_state_dict'])
    try:
        r, f = build_representations(generator, text_encoder, max_batches=1)
    except Exception as e:
        # Add the checkpoint path to the error and keep the original cause.
        raise RuntimeError(f"Error at checkpoint {ckpt_path}: {e}") from e
    else:
        print(f"  ✓ shapes: {r.shape}, {f.shape}")

print("\nAll dry-runs passed. Proceeding to full evaluation...\n")

# === 2) Full evaluation ===
# For every checkpoint: reload the generator weights, extract features over
# the whole test loader, then record FID and Inception Score.
# The loop variable is named ckpt_path (not path) so it does not overwrite
# the notebook-global `path` output directory that other cells rely on.
epochs, fid_scores, is_means, is_stds = [], [], [], []
for epoch, ckpt_path in gen_ckpts:
    print(f"\n=== Full run Epoch {epoch} ===")
    ckpt = torch.load(ckpt_path, map_location=device)
    generator.load_state_dict(ckpt['generator_state_dict'])
    real_feats, fake_feats = build_representations(generator, text_encoder)
    fid = calculate_fid(real_feats, fake_feats)
    ism, iss = inception_score(fake_feats, batch_size)
    print(f"  FID={fid:.2f} | IS={ism:.2f} ± {iss:.2f}")
    epochs.append(epoch)
    fid_scores.append(fid)
    is_means.append(ism)
    is_stds.append(iss)

# === Plot & summary ===
# Text table first, one line per evaluated checkpoint.
print("\nResults per epoch:")
for ep, fid_val, is_mean, is_std in zip(epochs, fid_scores, is_means, is_stds):
    print(f" Epoch {ep}: FID {fid_val:.2f}, IS {is_mean:.2f}±{is_std:.2f}")

# Two side-by-side panels: FID on the left, IS with error bars on the right.
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs, fid_scores, 'b-o')
plt.title("FID")
plt.xlabel("Epoch")

plt.subplot(1, 2, 2)
plt.errorbar(epochs, is_means, yerr=is_stds, fmt='r-o', capsize=3)
plt.title("IS")
plt.xlabel("Epoch")

plt.tight_layout()
plt.show()

# Best checkpoints: lowest FID and highest IS.
best_f = epochs[np.argmin(fid_scores)]
best_i = epochs[np.argmax(is_means)]
print(f"\nBest FID @ epoch {best_f}, Best IS @ epoch {best_i}")


Total filenames: 11788
Load captions from: /kaggle/input/cv-seattention/cv_seattention/data/captions.pickle
Load file names from: /kaggle/input/cv-seattention/cv_seattention/data/test/filenames.pickle (2933)
Dictionary size: 5450
Embeddings number: 10
2933


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


Dry-run Epoch 150...


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  ✓ shapes: (32, 2048), (32, 2048)
Dry-run Epoch 160...


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  ✓ shapes: (32, 2048), (32, 2048)
Dry-run Epoch 180...


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  ✓ shapes: (32, 2048), (32, 2048)
Dry-run Epoch 200...


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  ✓ shapes: (32, 2048), (32, 2048)

All dry-runs passed. Proceeding to full evaluation...


=== Full run Epoch 150 ===


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  FID=280.10 | IS=2.99 ± 0.50

=== Full run Epoch 160 ===


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

  FID=299.19 | IS=3.21 ± 0.43

=== Full run Epoch 180 ===


Building repr:   0%|          | 0/91 [00:00<?, ?it/s]

In [None]:
# import os
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from scipy.linalg import sqrtm  # <-- use SciPy’s matrix sqrt
# from tqdm.auto import tqdm
# import matplotlib.pyplot as plt
# from scipy.stats import entropy

# # === Generator checkpoints ===
# gen_ckpts = [
#     (150, "/kaggle/input/workin/gennormal_150.pth"),
#     (160, "/kaggle/input/workin/gennormal_158.pth"),
#     (180, "/kaggle/input/workin/gennormal_180.pth"),
#     (200, "/kaggle/input/workin/gennormal_200.pth"),
# ]

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # === Load data and text encoder ===
# batch_size = 32
# test_loader, n_words = create_loader(
#     256, batch_size,
#     "/kaggle/input/cv-seattention/cv_seattention/data",
#     "test"
# )

# # Load one generator checkpoint just to initialize the model architecture
# checkpoint = torch.load("/kaggle/input/workin/gennormal_200.pth", map_location=device)
# generator = Generator(n_channels=32, latent_dim=100).to(device)
# generator.load_state_dict(checkpoint['generator_state_dict'])

# # Load and freeze your RNN text encoder
# text_encoder = RNNEncoder.load(
#     "/kaggle/input/cv-seattention/cv_seattention/text_encoder_weights/text_encoder200.pth",
#     n_words
# ).to(device).eval()
# for p in text_encoder.parameters():
#     p.requires_grad = False

# # === InceptionV3 for FID/IS ===
# class InceptionV3(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         self.model = torch.hub.load(
#             'pytorch/vision:v0.6.0',
#             'inception_v3',
#             pretrained=True
#         ).to(self.device).eval()
#         self.linear = self.model.fc
#         self.model.fc = nn.Identity()
#         self.model.dropout = nn.Identity()

#     @torch.no_grad()
#     def get_last_layer(self, x):
#         x = F.interpolate(
#             x,
#             size=(299, 299),
#             mode='bilinear',
#             align_corners=False
#         )
#         return self.model(x)

# classifier = InceptionV3().to(device).eval()

# # === FID calculation ===
# def calculate_fid(real_feats: np.ndarray, fake_feats: np.ndarray) -> float:
#     mu1, mu2 = real_feats.mean(axis=0), fake_feats.mean(axis=0)
#     sigma1 = np.cov(real_feats, rowvar=False)
#     sigma2 = np.cov(fake_feats, rowvar=False)
#     diff = mu1 - mu2

#     covmean = sqrtm(sigma1 @ sigma2)
#     if not np.isfinite(covmean).all():  # numerical stability
#         offset = np.eye(sigma1.shape[0]) * 1e-8
#         covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))
#     if np.iscomplexobj(covmean):
#         covmean = covmean.real

#     fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
#     return float(fid)

# # === Inception Score ===
# def inception_score(fake_feats: np.ndarray, batch_size: int = 32) -> Tuple[float, float]:
#     preds = []
#     for i in range(0, len(fake_feats), batch_size):
#         batch = torch.tensor(
#             fake_feats[i:i + batch_size],
#             dtype=torch.float,
#             device=device
#         )
#         logits = classifier.linear(batch)
#         # detach before moving to CPU/Numpy
#         p = F.softmax(logits, dim=1).detach().cpu().numpy()
#         preds.append(p)
#     preds = np.vstack(preds)

#     py = preds.mean(axis=0)
#     kl = preds * (np.log(preds + 1e-10) - np.log(py + 1e-10))
#     scores = np.exp(np.mean(np.sum(kl, axis=1)))
#     return float(scores), float(np.std(np.sum(kl, axis=1)))


# # === Compute Features ===
# def build_representations(generator, text_encoder):
#     real_feats, fake_feats = [], []
#     generator.eval()

#     for batch in tqdm(test_loader, desc="Generating features"):
#         images, captions, cap_lens, _ = prepare_data(batch, device)
#         with torch.no_grad():
#             sent_emb, word_emb = text_encoder(captions, cap_lens)
#             noise = torch.randn(images.size(0), 100, device=device)
#             fakes = generator(noise, sent_emb, word_emb)

#             real_arr = classifier.get_last_layer(images).cpu().numpy()
#             fake_arr = classifier.get_last_layer(fakes).cpu().numpy()

#             real_feats.append(real_arr)
#             fake_feats.append(fake_arr)

#     real_feats = np.vstack(real_feats)
#     fake_feats = np.vstack(fake_feats)
#     return real_feats, fake_feats

# # === Run analysis ===
# epochs, fid_scores, is_means, is_stds = [], [], [], []

# for epoch, ckpt_path in gen_ckpts:
#     print(f"\n--- Epoch {epoch} ---")
#     gen_ckpt = torch.load(ckpt_path, map_location=device)
#     generator.load_state_dict(gen_ckpt['generator_state_dict'])

#     real_feats, fake_feats = build_representations(generator, text_encoder)

#     fid = calculate_fid(real_feats, fake_feats)
#     is_mean, is_std = inception_score(fake_feats, batch_size)

#     print(f"FID: {fid:.2f} | IS: {is_mean:.2f} ± {is_std:.2f}")
#     epochs.append(epoch)
#     fid_scores.append(fid)
#     is_means.append(is_mean)
#     is_stds.append(is_std)

# # === Print all results ===
# print("\n=== All Epoch Results ===")
# for e, f, im, isd in zip(epochs, fid_scores, is_means, is_stds):
#     print(f"Epoch {e}: FID = {f:.2f}, IS = {im:.2f} ± {isd:.2f}")

# # === Plot results ===
# plt.figure(figsize=(12, 5))

# plt.subplot(1, 2, 1)
# plt.plot(epochs, fid_scores, 'b-o', label='FID')
# plt.xlabel("Epoch")
# plt.ylabel("FID Score")
# plt.title("FID vs Epoch")
# plt.grid(True)
# plt.legend()

# plt.subplot(1, 2, 2)
# plt.errorbar(epochs, is_means, yerr=is_stds, fmt='r-o', capsize=5, label='Inception Score')
# plt.xlabel("Epoch")
# plt.ylabel("Inception Score")
# plt.title("Inception Score vs Epoch")
# plt.grid(True)
# plt.legend()

# plt.tight_layout()
# plt.show()

# # === Summary ===
# best_fid_epoch = epochs[np.argmin(fid_scores)]
# best_is_epoch = epochs[np.argmax(is_means)]
# print("\n--- Summary ---")
# print(f"Best FID: {min(fid_scores):.2f} at Epoch {best_fid_epoch}")
# print(f"Best IS: {max(is_means):.2f} at Epoch {best_is_epoch}")


In [None]:
import os
import sys
import numpy as np
import torch
import time

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Put the parent of the current working directory on the import path.
current_cwd = os.getcwd()
cwd_parts = current_cwd.split('/')
src_path = '/'.join(cwd_parts[:-1])
sys.path.append(src_path)


# Read the generator checkpoint, then build the generator and restore its
# weights in evaluation mode.
checkpoint_path = "/kaggle/input/workin/gennormal_200.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)

generator = Generator(n_channels=32, latent_dim=100)
generator = generator.to(device)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

# Data loader. NOTE: despite its name, train_loader is built on the "test"
# split; the name is kept because gen_own_bird reads it.
train_loader, n_words = create_loader(
    256, 24, "/kaggle/input/cv-seattention/cv_seattention/data", "test"
)

# Text encoder: load the weights, move to the device, switch to eval mode and
# freeze every parameter.
text_encoder = RNNEncoder.load(
    "/kaggle/input/cv-seattention/cv_seattention/text_encoder_weights/text_encoder200.pth",
    n_words,
)
text_encoder.to(device)
text_encoder.eval()
for p in text_encoder.parameters():
    p.requires_grad = False

import matplotlib.pyplot as plt
import numpy as np

def _caption_to_padded_codes(word_caption, word2code, max_len=18):
    """Turn a caption into a zero-padded array of vocabulary codes.

    Returns ``(codes, length)`` where ``codes`` is an int64 array of
    ``max_len`` entries and ``length`` is the number of real (non-padding)
    entries. Raises ``ValueError`` for an empty caption and ``KeyError``
    listing every word that is missing from ``word2code``.
    """
    words = word_caption.lower().split()
    if not words:
        raise ValueError("word_caption must contain at least one word")

    codes, unknown = [], []
    for w in words:
        try:
            codes.append(word2code[w])
        except KeyError:
            unknown.append(w)
    if unknown:
        raise KeyError(f"words not in the dataset vocabulary: {unknown}")

    codes = np.array(codes)
    padded = np.zeros(max_len, dtype='int64')
    if len(codes) <= max_len:
        padded[:len(codes)] = codes
        return padded, len(codes)

    # Caption too long: keep a random subset of max_len words, in their
    # original order.
    indices = list(np.arange(len(codes)))
    np.random.shuffle(indices)
    padded[:] = codes[np.sort(indices[:max_len])]
    return padded, max_len


def gen_own_bird(word_caption, name, i, save_dir=None):
    """Generate, save and display one image for a free-text caption.

    Args:
        word_caption: caption text; it is lower-cased and split on
            whitespace, and every word must exist in the dataset vocabulary.
        name: base name for the saved file.
        i: suffix appended to ``name`` in the file name.
        save_dir: directory for the PNG. When ``None`` the notebook-global
            ``path`` is used, as before. NOTE(review): other cells rebind
            ``path`` (including as a loop variable), so passing ``save_dir``
            explicitly is safer.

    Raises:
        ValueError: if the caption contains no words.
        KeyError: if the caption contains words outside the vocabulary.
    """
    dataset = train_loader.dataset.dataset
    padded, len_ = _caption_to_padded_codes(word_caption, dataset.word2code)

    tensor1 = torch.tensor(padded).reshape(1, -1).to(device)
    tensor2 = torch.tensor([len_]).to(device)

    # Inference only: no autograd graph is needed for encoder or generator.
    with torch.no_grad():
        embed, word = text_encoder(tensor1, tensor2)
        noise = torch.randn(embed.shape[0], 100).to(device)
        img = generator(noise, embed, word)

    chw = img[0].cpu().numpy()

    # Save image
    if save_dir is None:
        save_dir = path
    os.makedirs(save_dir, exist_ok=True)
    save_image(chw, save_dir, name + str(i))

    # Rescale from [-1, 1] to [0, 1] and go CHW -> HWC for plotting.
    img_np = np.transpose((chw + 1.0) / 2.0, (1, 2, 0))

    # Plot image
    plt.figure(figsize=(4, 4))
    plt.imshow(img_np)
    plt.axis("off")
    plt.title(f"Generated: '{word_caption}'")
    plt.show()


# Run generation and report the elapsed wall-clock time.
start_time = time.time()
caption = "a small bird with a red breast eyebrow and crown black and white wings with two white wing bars and black feet and tarsus"
i = 1
# The caption doubles as the base name of the saved file.
gen_own_bird(word_caption=caption, name=caption, i=i)

elapsed = time.time() - start_time
print(f"Time taken: {elapsed:.2f} seconds")
