In [1]:
! pip install torch torchvision sentence-transformers



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from sentence_transformers import SentenceTransformer




In [3]:
# Define the Generator class
class Generator(nn.Module):
    """Text-conditioned MLP generator.

    Projects the text features to ``noise_dim``, concatenates them with the
    noise vector, and maps the result through an MLP to an image tensor of
    shape ``(batch, img_channels, img_size, img_size)``. Output values lie in
    ``[-1, 1]`` because the final layer is ``Tanh``.

    Args:
        noise_dim: Length of each noise vector; the text features are also
            projected to this length before concatenation.
        text_dim: Length of each raw text-feature vector.
        img_channels: Number of channels in the generated image.
        img_size: Height and width of the (square) generated image.
    """

    def __init__(self, noise_dim, text_dim, img_channels, img_size):
        super().__init__()
        # Keep the output geometry so forward() reshapes to the configured size
        # rather than a hard-coded 3x64x64.
        self.img_channels = img_channels
        self.img_size = img_size
        self.text_embed = nn.Linear(text_dim, noise_dim)  # Embed text features
        self.model = nn.Sequential(
            nn.Linear(noise_dim * 2, 128),  # noise + projected text, each noise_dim wide
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, img_size * img_size * img_channels),
            nn.Tanh()
        )

    def forward(self, noise, text_features):
        """Generate a batch of images.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.
            text_features: Tensor of shape ``(batch, text_dim)``.

        Returns:
            Tensor of shape ``(batch, img_channels, img_size, img_size)`` with
            values in ``[-1, 1]``.
        """
        text_embedded = self.text_embed(text_features)
        combined_input = torch.cat([noise, text_embedded], dim=1)
        flat_imgs = self.model(combined_input)
        # Use the explicit batch size (not -1) so a geometry mismatch raises
        # instead of silently folding extra elements into the batch axis.
        return flat_imgs.view(noise.size(0), self.img_channels, self.img_size, self.img_size)

In [4]:
# Define the Discriminator class
class Discriminator(nn.Module):
    """Text-conditioned MLP discriminator.

    Embeds the flattened image and the text features into ``embed_dim``-wide
    vectors, concatenates them, and maps the result to one Sigmoid score in
    ``(0, 1)`` per sample.

    Args:
        img_channels: Number of channels in the input image.
        img_size: Height and width of the (square) input image.
        text_dim: Length of each raw text-feature vector.
        embed_dim: Width of both the image and text embeddings (default 256,
            matching the original hard-coded size).
    """

    def __init__(self, img_channels, img_size, text_dim, embed_dim=256):
        super().__init__()
        self.img_embed = nn.Linear(img_size * img_size * img_channels, embed_dim)
        self.text_embed = nn.Linear(text_dim, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(2 * embed_dim, 128),  # image embedding + text embedding
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, img, text_features):
        """Score a batch of (image, text) pairs.

        Args:
            img: Tensor of shape ``(batch, img_channels, img_size, img_size)``.
            text_features: Tensor of shape ``(batch, text_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with Sigmoid scores in ``(0, 1)``.
        """
        # flatten (unlike view) also accepts non-contiguous inputs such as
        # permuted channel-last batches.
        img_flat = torch.flatten(img, start_dim=1)
        img_embedded = self.img_embed(img_flat)
        text_embedded = self.text_embed(text_features)
        combined_input = torch.cat([img_embedded, text_embedded], dim=1)
        validity = self.model(combined_input)
        return validity

In [5]:
# Initialize models
# Seed torch before the first random operation (weight init below, noise draws
# in the training loop) so a Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)

noise_dim = 100
text_dim = 384  # Using all-MiniLM-L6-v2, which outputs embeddings of size 384
img_channels = 3
img_size = 64

G = Generator(noise_dim, text_dim, img_channels, img_size)
D = Discriminator(img_channels, img_size, text_dim)

# Define loss and optimizers.
# BCELoss expects probabilities; D already ends in a Sigmoid.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [6]:
# Convert text description to embedding
text_model = SentenceTransformer('all-MiniLM-L6-v2')
text_description = "a red circle"  # Your text description
# encode() on a single string yields one vector; unsqueeze adds the batch axis.
text_embedding = torch.tensor(text_model.encode(text_description)).unsqueeze(0)  # Shape: (1, text_dim)
# Fail fast if the encoder's output width disagrees with the text_dim the models
# were built with; otherwise it surfaces later as an opaque matmul shape error.
assert text_embedding.shape == (1, text_dim), (
    f"expected embedding shape (1, {text_dim}), got {tuple(text_embedding.shape)}"
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [7]:
# Training loop
batch_size = 16
epochs = 50
iters_per_epoch = 50  # D/G update pairs per epoch
save_every = 10       # write a sample grid every N epochs
n_samples = 16        # images per saved grid (4x4 with nrow=4)

# The conditioning text is the same for every sample, so build the batched copy
# once instead of re-repeating it for each of the four G/D calls per iteration.
text_batch = text_embedding.repeat(batch_size, 1)
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Fixed inputs for the saved grids, so successive snapshots show how the
# generator evolves rather than a fresh random draw each time.
fixed_noise = torch.randn(n_samples, noise_dim)
fixed_text = text_embedding.repeat(n_samples, 1)

for epoch in range(epochs):
    for _ in range(iters_per_epoch):
        # ---- Train Discriminator: real -> 1, fake -> 0 ----
        # TODO: placeholder "real" data. Gaussian noise is unbounded while G's Tanh
        # output lies in [-1, 1], so D can tell them apart by value range; replace
        # with real images scaled to [-1, 1].
        real_imgs = torch.randn(batch_size, img_channels, img_size, img_size)
        noise = torch.randn(batch_size, noise_dim)
        fake_imgs = G(noise, text_batch).detach()  # detach: the D step must not backprop into G

        optimizer_D.zero_grad()
        real_loss = criterion(D(real_imgs, text_batch), real_labels)
        fake_loss = criterion(D(fake_imgs, text_batch), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator: push D(G(z)) toward the "real" label ----
        noise = torch.randn(batch_size, noise_dim)
        fake_imgs = G(noise, text_batch)

        optimizer_G.zero_grad()
        g_loss = criterion(D(fake_imgs, text_batch), real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Losses shown are from the last iteration of the epoch, not an epoch average.
    print(f"Epoch [{epoch + 1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # Save generated images every `save_every` epochs
    if (epoch + 1) % save_every == 0:
        with torch.no_grad():
            sample_imgs = G(fixed_noise, fixed_text)
        save_image(sample_imgs, f"single_text_generated_epoch_{epoch + 1}.png", nrow=4, normalize=True)

Epoch [1/50] | D Loss: 1.4676 | G Loss: 0.6919
Epoch [2/50] | D Loss: 1.3621 | G Loss: 0.7280
Epoch [3/50] | D Loss: 1.2151 | G Loss: 0.8043
Epoch [4/50] | D Loss: 0.9959 | G Loss: 1.0950
Epoch [5/50] | D Loss: 1.3373 | G Loss: 0.8915
Epoch [6/50] | D Loss: 1.5806 | G Loss: 0.9549
Epoch [7/50] | D Loss: 1.2623 | G Loss: 0.9554
Epoch [8/50] | D Loss: 1.2641 | G Loss: 0.9782
Epoch [9/50] | D Loss: 1.8527 | G Loss: 1.5321
Epoch [10/50] | D Loss: 1.1707 | G Loss: 1.1492
Epoch [11/50] | D Loss: 1.1353 | G Loss: 1.0077
Epoch [12/50] | D Loss: 1.2049 | G Loss: 1.8296
Epoch [13/50] | D Loss: 0.9696 | G Loss: 1.5436
Epoch [14/50] | D Loss: 1.1401 | G Loss: 1.2927
Epoch [15/50] | D Loss: 1.3489 | G Loss: 1.5601
Epoch [16/50] | D Loss: 1.1667 | G Loss: 1.1085
Epoch [17/50] | D Loss: 1.4207 | G Loss: 1.3174
Epoch [18/50] | D Loss: 1.8804 | G Loss: 1.5302
Epoch [19/50] | D Loss: 1.3367 | G Loss: 1.3369
Epoch [20/50] | D Loss: 1.4729 | G Loss: 0.7738
Epoch [21/50] | D Loss: 1.4301 | G Loss: 0.8436
E