<a href="https://colab.research.google.com/github/staerkjoe/AdvNLP_miniproject/blob/main/NLP_MiniProject.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive so the sample pictures and generated images below can be read/written.
drive.mount(mountpoint='/content/drive')

In [None]:
# Install the notebook's dependencies: transformers provides CLIP, pytorch-pretrained-biggan
# provides BigGAN / truncated_noise_sample used further down.
# NOTE(review): versions are unpinned; pin them (e.g. `%pip install -q pkg==x.y.z`) so a
# fresh runtime reproduces the same results.
!pip install torch torchvision transformers ftfy regex tqdm
!pip install accelerate
!pip install pytorch-pretrained-biggan

## CLIP - Load and Test

In [None]:
from transformers import CLIPProcessor, CLIPModel

import torch

# Load pretrained CLIP (text + image encoders)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Fall back to CPU when no GPU runtime is attached; a hard-coded "cuda" crashes there.
device = "cuda" if torch.cuda.is_available() else "cpu"
# CLIP is only used as a fixed scorer in this notebook (the optimizer below only updates
# the latent z), so disable gradients on its weights: backward passes then do not allocate
# and accumulate .grad buffers for ~150M parameters on every step.
clip_model = clip_model.to(device).eval().requires_grad_(False)


In [None]:
from PIL import Image

import os
import torch

SAMPLE_DIR = "/content/drive/MyDrive/Colab Notebooks/NLP/MiniProject/Sample Pictures"

text = ["a red apple", "a blue car"]
# Load each sample fully inside a `with` block so the file handle is closed afterwards;
# convert("RGB") forces the pixel data to be read before the file closes.
image = []
for file_name in ("redapple.jpg", "airplane.jpg"):
    with Image.open(os.path.join(SAMPLE_DIR, file_name)) as im:
        image.append(im.convert("RGB"))

# Pure inference: skip building an autograd graph that would otherwise be kept alive in `outputs`.
with torch.no_grad():
    inputs = clip_processor(text=text, images=image, return_tensors="pt", padding=True).to(device)
    outputs = clip_model(**inputs)
    # logits_per_text: one row per text, one column per image (per transformers' CLIPOutput docs);
    # softmax over the last dim gives, for each text, a distribution over the images.
    similarity = outputs.logits_per_text.softmax(dim=-1)


In [None]:
# Text-vs-image probabilities from the CLIP smoke test (rows follow `text`, columns follow `image`).
print(str(similarity))

## GAN - Load and Test

In [None]:
from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained 'biggan-deep-128' generator and put it in inference mode.
# Its weights are never trained in this notebook (only the latent z is optimized), so
# gradients on them are disabled; backward passes then skip their .grad buffers.
# Ending the cell on an assignment also avoids printing the whole module repr.
biggan = BigGAN.from_pretrained('biggan-deep-128').to(device).eval().requires_grad_(False)


In [None]:
# BigGAN takes: noise (z), class vector, truncation
truncation = 0.4
# Sample z with the library's truncated-noise sampler at the same truncation value
# (imported above but previously unused). The truncation trick pairs the generator's
# `truncation` argument with z drawn from a correspondingly truncated/scaled normal;
# an unbounded torch.randn draw does not match it.
# NOTE(review): relies on the sampler's default latent size (128) matching this checkpoint.
z = torch.as_tensor(
    truncated_noise_sample(truncation=truncation, batch_size=1),
    dtype=torch.float32,
    device=device,
)
# All-zero class vector used as a "generic" input; note it is not a one-hot label.
class_vector = torch.zeros(1, 1000, device=device)

with torch.no_grad():
    output = biggan(z, class_vector, truncation)


In [None]:
from torchvision.utils import save_image
# Shift/scale the generator output from [-1, 1] to [0, 1] before writing it to Drive.
save_image(output.add(1).div(2), "/content/drive/MyDrive/Colab Notebooks/NLP/MiniProject/Sample Pictures/GANgenerated.jpg")


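## Optimizing one latent vector $z$ per text prompt

For each prompt, BigGAN's latent $z$ is optimized so that CLIP's embedding of the generated image moves toward the prompt's text embedding, i.e. minimizing $1 - \cos(f_\text{image}, f_\text{text})$.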

In [None]:
def generate_image(z, truncation=0.4):
    """Render BigGAN image(s) for latent(s) ``z`` without tracking gradients.

    Args:
        z: Latent batch; the first dimension is the batch size (presumably
            (batch, 128) for the generator loaded in this notebook — confirm).
        truncation: Truncation value passed straight through to ``biggan``.

    Returns:
        The generator output rescaled from [-1, 1] to [0, 1].
    """
    # One all-zero ("generic") class vector per latent, so batched z works too
    # (a fixed batch size of 1 broke any z with more than one row).
    class_vector = torch.zeros(z.shape[0], 1000, device=device)
    with torch.no_grad():
        out = biggan(z, class_vector, truncation)
    return (out + 1) / 2  # scale from [-1,1] → [0,1]

In [None]:
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

def optimize_latent_for(prompt, steps=300, lr=0.07, truncation=0.4, seed=None):
    """Optimize a BigGAN latent ``z`` so the generated image matches ``prompt`` under CLIP.

    The loss is ``1 - cosine_similarity(CLIP(image(z)), CLIP(prompt))`` and only
    ``z`` is updated (Adam); CLIP and BigGAN weights stay fixed.

    Args:
        prompt: Text prompt, encoded once with CLIP's text encoder.
        steps: Number of optimization steps.
        lr: Adam learning rate for ``z``.
        truncation: BigGAN truncation value; also scales the initial sample of ``z``.
        seed: Optional RNG seed for the initial ``z``. ``0`` is a valid seed.

    Returns:
        Detached image tensor in [0, 1] produced by ``generate_image`` for the final ``z``.
    """
    # `if seed:` used to silently ignore seed=0.
    if seed is not None:
        torch.manual_seed(seed)

    # Encode the text prompt once; it is a fixed target. truncation=True guards
    # against prompts longer than CLIP's context window.
    text_tokens = clip_processor(text=[prompt], return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**text_tokens)

    # Apply CLIP's pixel normalization in torch. The image processor works on
    # numpy arrays, so passing the generated tensor through `clip_processor`
    # detached it from the graph (z never received a gradient) and rescaled the
    # already-[0, 1] pixels by 1/255 a second time.
    image_processor = clip_processor.image_processor
    clip_mean = torch.tensor(image_processor.image_mean, device=device).view(1, -1, 1, 1)
    clip_std = torch.tensor(image_processor.image_std, device=device).view(1, -1, 1, 1)

    # Start z from a truncated normal scaled by `truncation` (BigGAN truncation trick).
    z = torch.empty(1, 128, device=device)
    torch.nn.init.trunc_normal_(z, mean=0.0, std=1.0, a=-2.0, b=2.0)
    z = (z * truncation).requires_grad_(True)
    optimizer = torch.optim.Adam([z], lr=lr)

    # Loop-invariant "generic" (all-zero) class vector.
    class_vector = torch.zeros(1, 1000, device=device)

    pbar = tqdm(range(steps), desc=f"Optimizing for: '{prompt}'")
    for _ in pbar:
        optimizer.zero_grad()
        img = (biggan(z, class_vector, truncation) + 1) / 2  # [-1, 1] -> [0, 1]
        img = torch.clamp(img, 0, 1)

        # Resize to CLIP's expected input size, then normalize; every op is differentiable.
        img = F.interpolate(img, size=(224, 224), mode="bilinear")
        img = (img - clip_mean) / clip_std

        img_features = clip_model.get_image_features(pixel_values=img)
        loss = 1 - F.cosine_similarity(img_features, text_features).mean()

        # Only z is optimized: restrict backward to it so no .grad buffers are
        # allocated/accumulated on CLIP or BigGAN parameters every step.
        loss.backward(inputs=[z])
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})

    return generate_image(z.detach(), truncation).detach()

In [None]:
prompts = [
    "a red apple on a table",
    "a blue car in the snow",
    "a yellow bird on a branch",
]

# One optimization run per prompt; each result is written to the working directory
# as "<prompt with spaces replaced by underscores>.png".
for text in prompts:
    filename = "{}.png".format(text.replace(" ", "_"))
    result = optimize_latent_for(text, steps=250)
    save_image(result, filename)
    print("Saved: " + filename)


In [None]:
import matplotlib.pyplot as plt
from torchvision.io import read_image

# Display each saved image titled with its prompt; the permute reorders the decoded
# tensor from channels-first to the channels-last layout imshow expects.
for text in prompts:
    img = read_image("{}.png".format(text.replace(" ", "_")))
    img = img.permute(1, 2, 0)
    plt.imshow(img.cpu())
    plt.title(text)
    plt.axis("off")
    plt.show()
