<a href="https://colab.research.google.com/github/sarayu-nar/GATIS-project/blob/main/GATIS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np



In [22]:
# Colab-only cell: calls files.upload() and keeps its return value in `uploaded`.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, and no
# other cell visible in this notebook reads `uploaded` — confirm it is still needed.
from google.colab import files
uploaded = files.upload()

KeyboardInterrupt: 

In [20]:
import os

# Locations of the CUB-200-2011 files used by the rest of the notebook.
# The attributes file name is spelled "attributes.txt" (lower-case extension) to
# match the path the dataset and training cells open; on a case-sensitive
# filesystem "attributes.TXT" would be a different file.
text_file = r"C:\Users\mural\Desktop\ACM GATIS Project\CUB_200_2011\attributes.txt"
img_dir = r"C:\Users\mural\Desktop\ACM GATIS Project\CUB_200_2011\CUB_200_2011\images"

# Report, for each path, whether it exists in the current runtime.
for _found_msg, _missing_msg, _path in (
    ("Attributes file exists.", "Attributes file not found!", text_file),
    ("Image file exists.", "Images file not found!", img_dir),
):
    print(_found_msg if os.path.exists(_path) else _missing_msg)


Attributes file not found!
Images file not found!


In [13]:
class TextToImageDataset(Dataset):
    """Dataset of (image, description) pairs indexed by a tab-separated text file.

    Every non-blank line of ``text_file`` is ``<image name><TAB><description>``.
    The image name is joined onto ``img_dir`` to locate the image file.

    Args:
        img_dir: Directory the image names in ``text_file`` are relative to.
        text_file: Path of the tab-separated index file.
        transform: Optional callable applied to the RGB PIL image before it is
            returned. When it is ``None`` the raw PIL image is returned.
            NOTE(review): PyTorch's default batch collation expects tensors, so
            pass a transform ending in ``ToTensor()`` when batching with DataLoader.
    """

    def __init__(self, img_dir= r"C:\Users\mural\Desktop\ACM GATIS Project\CUB_200_2011\CUB_200_2011\images", text_file= r"C:\Users\mural\Desktop\ACM GATIS Project\CUB_200_2011\attributes.txt", transform=None):
        self.img_dir = img_dir
        self.text_descriptions = self.load_text(text_file)
        # Dict insertion order fixes the index -> image-name mapping.
        self.img_names = list(self.text_descriptions.keys())
        self.transform = transform

    def load_text(self, text_file):
        """Parse ``text_file`` into a ``{image name: description}`` dict.

        Blank lines are skipped. Only the first tab splits a line, so a
        description may itself contain tabs. If an image name occurs more than
        once, the last description wins.

        Raises:
            ValueError: If a non-blank line contains no tab separator.
        """
        descriptions = {}
        with open(text_file, 'r', encoding='utf-8') as file:
            for line_number, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                img_name, separator, description = line.partition('\t')
                if not separator:
                    raise ValueError(
                        f"{text_file}, line {line_number}: expected "
                        f"'<image name>\\t<description>', got {line!r}"
                    )
                descriptions[img_name] = description
        return descriptions

    def __len__(self):
        """Return the number of distinct image names read from the text file."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, description)`` for the ``idx``-th image name.

        The image is opened from ``img_dir``, converted to RGB, and passed
        through ``self.transform`` when one was supplied.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        description = self.text_descriptions[img_name]
        return image, description


PyTorch's `Dataset` class provides an interface for loading and processing data.
The `TextToImageDataset` class is a custom implementation that:
- Loads images from a directory.
- Loads the corresponding text descriptions from a file.
- Applies the optional `transform` to each image; descriptions are returned as plain strings.
- Returns a single image-description pair each time it is indexed.

`load_text` method:
The `load_text` method reads and processes the text file containing the description for each image.
It opens the text file and reads it line by line.
For each line it:
1. Splits the line at the tab character into an image name (`img_name`) and its corresponding description.
2. Stores them in a dictionary (`descriptions`), where the key is the image name and the value is the description.
Once all lines have been read, the method returns the dictionary.

In [3]:
from transformers import BertTokenizer, BertModel

class TextEncoder(torch.nn.Module):
    """Encode text with pre-trained ``bert-base-uncased``.

    ``forward`` returns one vector per input text: the final-layer hidden state
    at sequence position 0 (the first token of each tokenized sequence).
    """

    def __init__(self):
        super(TextEncoder, self).__init__()
        # Load a pre-trained BERT model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, text):
        """Return embeddings of shape ``[batch_size, hidden_size]`` for ``text``.

        Args:
            text: A single string, or a list/tuple of strings (for example the
                description batch produced by a DataLoader).
        """
        if not isinstance(text, str):
            text = list(text)

        # Tokenize and encode the text (padded to the longest item, cut at 512 tokens)
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

        # The tokenizer builds CPU tensors; move them to wherever the BERT weights
        # live so the module also works after `.to('cuda')`.
        model_device = next(self.bert_model.parameters()).device
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

        outputs = self.bert_model(**inputs)

        # Keep only position 0 of the last hidden state as the text embedding
        text_embedding = outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
        return text_embedding

In [4]:
import torch.nn.functional as F
from torchvision import models

class Generator(nn.Module):
    """Fully connected generator mapping (noise, text embedding) to an image.

    Args:
        text_embedding_size: Width of the text embedding concatenated to the noise.
        latent_dim: Width of the noise vector.
        output_channels: Number of channels in the generated image.
        img_size: Height and width of the generated (square) image. Defaults to
            64, the size that was previously hard-coded.
    """

    def __init__(self, text_embedding_size, latent_dim=100, output_channels=3, img_size=64):
        super().__init__()
        # Layer names fc1..fc5 are kept so saved state_dicts stay loadable.
        self.fc1 = nn.Linear(latent_dim + text_embedding_size, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 1024)
        self.fc5 = nn.Linear(1024, img_size * img_size * output_channels)
        self.output_channels = output_channels
        self.img_size = img_size

    def forward(self, noise, text_embedding):
        """Generate images for a batch.

        Args:
            noise: Tensor of shape ``[batch, latent_dim]``.
            text_embedding: Tensor of shape ``[batch, text_embedding_size]``.

        Returns:
            Tensor of shape ``[batch, output_channels, img_size, img_size]``
            with values in ``[-1, 1]`` (tanh output).
        """
        # Concatenate noise and text embedding along the feature axis
        x = torch.cat([noise, text_embedding], dim=1)

        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(hidden(x))

        x = self.fc5(x)
        # Reshape the flat output to image dimensions
        x = x.view(-1, self.output_channels, self.img_size, self.img_size)
        return torch.tanh(x)  # Scale image pixels to [-1, 1]


In [5]:
class Discriminator(nn.Module):
    """Text-conditioned image discriminator.

    Four stride-2 convolutions reduce the image, the flattened feature map is
    concatenated with the text embedding, and two linear layers produce one
    score per sample. ``fc1`` expects ``512 * 4 * 4`` image features, which is
    what the four halving convolutions produce for 64x64 inputs.

    The returned score has already passed through a sigmoid, so it should be
    paired with a probability-based loss rather than a logits-based one.
    """

    def __init__(self, text_embedding_size, input_channels=3):
        super().__init__()

        def downsample(in_channels, out_channels):
            # 4x4 kernel, stride 2, padding 1: halves height and width.
            return nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

        self.conv1 = downsample(input_channels, 64)
        self.conv2 = downsample(64, 128)
        self.conv3 = downsample(128, 256)
        self.conv4 = downsample(256, 512)
        self.fc1 = nn.Linear(512 * 4 * 4 + text_embedding_size, 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, image, text_embedding):
        """Score ``image`` given ``text_embedding``; returns ``[batch, 1]`` in (0, 1)."""
        features = image
        # Convolutional stack, each followed by LeakyReLU(0.2)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.leaky_relu(conv(features), 0.2)

        # Flatten to [batch, 512 * 4 * 4]
        flat = features.view(features.size(0), -1)

        # Condition on the text by appending its embedding to the image features
        joint = torch.cat([flat, text_embedding], dim=1)

        hidden = F.leaky_relu(self.fc1(joint), 0.2)
        return torch.sigmoid(self.fc2(hidden))  # Output between 0 and 1

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import os

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 100
batch_size = 64
num_epochs = 100
learning_rate = 0.0002
text_embedding_size = 768  # Output size of BERT
image_size = 64  # Generator emits 64x64 images and Discriminator.fc1 assumes 64x64 input

# Dataset location. Set the GATIS_DATA_ROOT environment variable to point at the
# CUB_200_2011 folder on other machines (e.g. Colab); the default is the original path.
data_root = os.environ.get(
    "GATIS_DATA_ROOT",
    r"C:\Users\mural\Desktop\ACM GATIS Project\CUB_200_2011",
)

# Initialize the models
text_encoder = TextEncoder().to(device)
text_encoder.eval()  # the text encoder is used as a fixed feature extractor, never trained here
generator = Generator(text_embedding_size, latent_dim).to(device)
discriminator = Discriminator(text_embedding_size).to(device)

# Loss functions and optimizers
# Discriminator.forward already applies a sigmoid, so use the probability-based
# BCELoss; BCEWithLogitsLoss would apply a second sigmoid to its output.
adversarial_loss = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Load dataset
# Images must be same-sized tensors to be batched, and are scaled to [-1, 1]
# to match the range of the generator's tanh output.
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = TextToImageDataset(
    img_dir=os.path.join(data_root, "CUB_200_2011", "images"),
    text_file=os.path.join(data_root, "attributes.txt"),
    transform=image_transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    for images, descriptions in dataloader:
        # The last batch may be smaller than `batch_size`; keep the hyperparameter intact.
        current_batch_size = images.size(0)

        images = images.to(device)
        # `descriptions` is a batch of Python strings, so it has no `.to(device)`.

        # Get text embeddings (no autograd graph: the text encoder is not being trained)
        with torch.no_grad():
            text_embeddings = text_encoder(descriptions)

        # Train Discriminator
        optimizer_d.zero_grad()

        # Real images
        real_labels = torch.ones(current_batch_size, 1, device=device)
        real_preds = discriminator(images, text_embeddings)
        d_loss_real = adversarial_loss(real_preds, real_labels)

        # Fake images (detached so this step does not backpropagate into the generator)
        noise = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = generator(noise, text_embeddings)
        fake_labels = torch.zeros(current_batch_size, 1, device=device)
        fake_preds = discriminator(fake_images.detach(), text_embeddings)
        d_loss_fake = adversarial_loss(fake_preds, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # Train Generator
        optimizer_g.zero_grad()

        # Generator loss: the generator wants its fakes to be scored as real
        fake_preds = discriminator(fake_images, text_embeddings)
        g_loss = adversarial_loss(fake_preds, real_labels)
        g_loss.backward()
        optimizer_g.step()

    # Losses reported are those of the last batch of the epoch
    print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # Save checkpoints
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
        torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pth')


FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\mural\\Desktop\\ACM GATIS Project\\CUB_200_2011\\attributes.txt'