In [None]:
# Install dependencies
!pip install torch torchvision transformers tqdm pillow scipy numpy matplotlib torchmetrics


In [None]:
# Import libraries
import torch, torchvision
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image
import numpy as np, matplotlib.pyplot as plt, os
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from PIL import Image
from scipy.io import loadmat


In [None]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the torch and numpy RNGs so weight init, noise sampling and shuffling
# start from the same state on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

IMG_SIZE = 64      # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 10
LR = 2e-4          # Adam learning rate for both networks
Z_DIM = 100        # length of the generator's noise vector
EMBED_DIM = 768  # BERT base hidden size

# Dataset location. Defaults to the original Google Drive path; set the
# FLOWERS_ROOT environment variable to point at another copy of the data.
DATA_ROOT = os.environ.get('FLOWERS_ROOT', '/content/drive/MyDrive/102flowers')
DATA_PATH = f'{DATA_ROOT}/jpg/'
LABEL_PATH = f'{DATA_ROOT}/imagelabels.mat'


In [None]:
# Dataset: Oxford Flowers
class FlowersDataset(Dataset):
    """Flower images paired with a BERT embedding of a per-class caption.

    Each item is ``(embedding, image)``:

    * ``embedding`` is the mean over the token axis of BERT's last hidden
      state for the caption ``"A photo of a flower of class <label>"``,
      stored on the CPU.
    * ``image`` is the RGB image, passed through ``transform`` if one is given.

    The caption depends only on the class label, so one embedding per distinct
    label is computed up front in ``__init__`` (in the main process). This keeps
    ``__getitem__`` free of model calls, which matters when a ``DataLoader``
    uses worker processes and the BERT model lives on a CUDA device.

    Args:
        img_folder: Directory containing the ``.jpg`` files.
        label_path: Path to a ``.mat`` file whose ``'labels'`` entry holds one
            label per image in its first row.
        transform: Optional callable applied to each PIL image.
        tokenizer: Callable turning a caption into model inputs (called with
            ``return_tensors='pt'``); its result must support ``.to(device)``.
        bert_model: Model whose output exposes ``last_hidden_state``.

    Raises:
        ValueError: If the folder holds more ``.jpg`` files than there are labels.
    """

    def __init__(self, img_folder, label_path, transform=None, tokenizer=None, bert_model=None):
        self.img_folder = img_folder
        self.transform = transform
        self.tokenizer = tokenizer
        self.bert = bert_model
        self.labels = loadmat(label_path)['labels'][0]
        self.imgs = sorted(
            os.path.join(img_folder, name)
            for name in os.listdir(img_folder)
            if name.endswith('.jpg')
        )

        # Images and labels are paired purely by sorted position, so indexing
        # past the end of the label array would fail later inside a worker.
        # NOTE(review): a folder with *fewer* images than labels is accepted, but
        # the pairing is then only correct if the missing files are the last ones.
        if len(self.imgs) > len(self.labels):
            raise ValueError(
                f"Found {len(self.imgs)} images in {img_folder!r} but only "
                f"{len(self.labels)} labels in {label_path!r}"
            )

        # label -> CPU embedding tensor, filled once per distinct class.
        self._caption_embeddings = {}
        if self.tokenizer is not None and self.bert is not None:
            for label in sorted({int(value) for value in self.labels}):
                self._caption_embeddings[label] = self._embed_label(label)

    def _embed_label(self, label):
        """Return the CPU caption embedding for one class label."""
        caption = f"A photo of a flower of class {label}"
        tokens = self.tokenizer(caption, return_tensors='pt', truncation=True, padding='max_length', max_length=16)
        # Run on whatever device the model already lives on instead of relying
        # on a notebook-level global.
        bert_device = next(self.bert.parameters()).device
        with torch.no_grad():
            hidden = self.bert(**tokens.to(bert_device)).last_hidden_state
        return hidden.mean(dim=1).squeeze(0).cpu()

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        # The context manager closes the file handle once the pixels are converted.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        if self.transform:
            image = self.transform(image)

        label = int(self.labels[idx])
        emb = self._caption_embeddings.get(label)
        if emb is None:
            # Only reached if tokenizer/bert were attached after construction.
            emb = self._caption_embeddings[label] = self._embed_label(label)
        return emb, image


In [None]:
# Text encoder: tokenizer and weights are loaded from the same checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model = bert_model.to(device)

# Image pipeline: resize to a square, convert to a tensor, then shift/scale
# every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = FlowersDataset(
    img_folder=DATA_PATH,
    label_path=LABEL_PATH,
    transform=transform,
    tokenizer=tokenizer,
    bert_model=bert_model,
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Model definitions
class Generator(nn.Module):
    """Text-conditioned generator built from transposed convolutions.

    ``forward`` takes a noise batch of width ``z_dim`` and a text-embedding
    batch of width ``embed_dim``. The embedding is projected down to ``z_dim``,
    concatenated with the noise, and upsampled from 1x1 to a 3 x 64 x 64 image
    whose values pass through ``tanh``.
    """

    def __init__(self, z_dim, embed_dim):
        super().__init__()

        # Bring the text embedding to the same width as the noise vector.
        self.embed_proj = nn.Linear(embed_dim, z_dim)

        def upsample(c_in, c_out, stride=2, padding=1):
            """ConvTranspose -> BatchNorm -> ReLU stage with a 4x4 kernel."""
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        self.layers = nn.Sequential(
            *upsample(z_dim * 2, 512, stride=1, padding=0),  # (z_dim*2) x 1 x 1 -> 512 x 4 x 4
            *upsample(512, 256),                             # -> 256 x 8 x 8
            *upsample(256, 128),                             # -> 128 x 16 x 16
            *upsample(128, 64),                              # -> 64 x 32 x 32
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),  # -> 3 x 64 x 64
            nn.Tanh(),
        )

    def forward(self, noise, embeddings):
        text_code = self.embed_proj(embeddings)
        latent = torch.cat((noise, text_code), dim=1)
        # Add two trailing singleton axes: [batch, z_dim*2] -> [batch, z_dim*2, 1, 1]
        return self.layers(latent[:, :, None, None])

class Discriminator(nn.Module):
    """Text-conditioned discriminator for 3 x 64 x 64 images.

    Four stride-2 convolutions reduce the image to a 512 x 4 x 4 feature map.
    The flattened features are concatenated with the raw text embedding and
    mapped by a single linear layer plus sigmoid to one score per sample.
    """

    def __init__(self, embed_dim):
        super().__init__()

        def downsample(c_in, c_out, batch_norm=True):
            """4x4 stride-2 conv, optional BatchNorm, then LeakyReLU(0.2)."""
            stage = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)]
            if batch_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.conv_layers = nn.Sequential(
            *downsample(3, 64, batch_norm=False),  # 3 x 64 x 64 -> 64 x 32 x 32
            *downsample(64, 128),                  # -> 128 x 16 x 16
            *downsample(128, 256),                 # -> 256 x 8 x 8
            *downsample(256, 512),                 # -> 512 x 4 x 4
        )

        # The classifier sees the flattened image features followed by the embedding.
        self.fc = nn.Linear(512 * 4 * 4 + embed_dim, 1)

    def forward(self, x, embeddings):
        feature_map = self.conv_layers(x)
        flat = feature_map.view(feature_map.shape[0], -1)  # [batch, 512*4*4]
        joint = torch.cat((flat, embeddings), dim=1)
        return self.fc(joint).sigmoid()

In [None]:
# Training loop
# Alternates one discriminator step and one generator step per mini-batch,
# both using binary cross-entropy on the discriminator's sigmoid output.
gen = Generator(Z_DIM, EMBED_DIM).to(device)
disc = Discriminator(EMBED_DIM).to(device)
criterion = nn.BCELoss()
opt_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))

for epoch in range(EPOCHS):
    # Set train mode explicitly so a re-run after any cell that called .eval()
    # still updates the BatchNorm statistics.
    gen.train()
    disc.train()

    # Running sums kept on-device; the logged value is the epoch mean rather
    # than the loss of whichever (possibly smaller) batch happened to come last.
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    n_batches = 0

    for embeds, imgs in tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        embeds, imgs = embeds.to(device), imgs.to(device)
        bs = imgs.size(0)
        real_targets = torch.ones(bs, 1, device=device)
        fake_targets = torch.zeros(bs, 1, device=device)

        noise = torch.randn(bs, Z_DIM, device=device)
        fake_imgs = gen(noise, embeds)

        # Train Discriminator (fake batch detached so no gradient reaches the generator)
        real_loss = criterion(disc(imgs, embeds), real_targets)
        fake_loss = criterion(disc(fake_imgs.detach(), embeds), fake_targets)
        d_loss = (real_loss + fake_loss) / 2
        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        # Train Generator: it is rewarded when the discriminator labels fakes as real
        g_loss = criterion(disc(fake_imgs, embeds), real_targets)
        opt_gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

        d_loss_sum += d_loss.detach()
        g_loss_sum += g_loss.detach()
        n_batches += 1

    mean_d_loss = (d_loss_sum / max(n_batches, 1)).item()
    mean_g_loss = (g_loss_sum / max(n_batches, 1)).item()
    print(f"Epoch [{epoch+1}/{EPOCHS}]  D_loss: {mean_d_loss:.4f}  G_loss: {mean_g_loss:.4f}")
    # Undo the [-1, 1] scaling before writing the sample grid.
    save_image(fake_imgs[:16].detach()*0.5+0.5, f"gen_samples_epoch{epoch+1}.png")


In [None]:
# Evaluation: FID & Inception Score
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

# normalize=True makes both metrics expect float images in [0, 1].
# (InceptionScore's default expects uint8 images instead.)
fid = FrechetInceptionDistance(normalize=True).to(device)
is_metric = InceptionScore(normalize=True).to(device)


def to_unit_range(images):
    """Map images from the [-1, 1] range (Normalize(0.5, 0.5) / tanh) to [0, 1]."""
    return (images * 0.5 + 0.5).clamp(0.0, 1.0)


# Accumulate statistics over the whole loader instead of a single batch, and
# generate a fresh fake for every real image using that image's caption embedding.
was_training = gen.training
gen.eval()
try:
    with torch.no_grad():
        for embeds, real_images in tqdm(loader, desc="Evaluating"):
            embeds, real_images = embeds.to(device), real_images.to(device)
            noise = torch.randn(real_images.size(0), Z_DIM, device=device)
            fake_images = gen(noise, embeds)

            fid.update(to_unit_range(real_images), real=True)
            fid.update(to_unit_range(fake_images), real=False)
            is_metric.update(to_unit_range(fake_images))
finally:
    # Leave the generator in whatever mode it was in before evaluation.
    gen.train(was_training)

print("FID:", fid.compute().item())

is_score = is_metric.compute()  # (mean, std) over the metric's splits
print("Inception Score:", is_score[0].item(), "+/-", is_score[1].item())
