In [1]:
!unzip smiley_ccr_128.zip

Archive:  smiley_ccr_128.zip
   creating: resizeopenmoji128/
  inflating: resizeopenmoji128/0023-FE0F-20E3.png  
  inflating: resizeopenmoji128/002A-FE0F-20E3.png  
  inflating: resizeopenmoji128/0030-FE0F-20E3.png  
  inflating: resizeopenmoji128/0031-FE0F-20E3.png  
  inflating: resizeopenmoji128/0032-FE0F-20E3.png  
  inflating: resizeopenmoji128/0033-FE0F-20E3.png  
  inflating: resizeopenmoji128/0034-FE0F-20E3.png  
  inflating: resizeopenmoji128/0035-FE0F-20E3.png  
  inflating: resizeopenmoji128/0036-FE0F-20E3.png  
  inflating: resizeopenmoji128/0037-FE0F-20E3.png  
  inflating: resizeopenmoji128/0038-FE0F-20E3.png  
  inflating: resizeopenmoji128/0039-FE0F-20E3.png  
  inflating: resizeopenmoji128/00A9.png  
  inflating: resizeopenmoji128/00AE.png  
  inflating: resizeopenmoji128/1F004.png  
  inflating: resizeopenmoji128/1F0CF.png  
  inflating: resizeopenmoji128/1F170.png  
  inflating: resizeopenmoji128/1F171.png  
  inflating: resizeopenmoji128/1F17E.png  
  inflating: res

In [2]:
import pandas as pd
import os

# Path to your CSV
file_path = '/content/metadata128.csv'  # change to your actual CSV path

# Load CSV
df = pd.read_csv(file_path)

# New base directory
new_base = "/content/resizeopenmoji128"

# Update the image_path column to use the new base + original filename
df['image_path'] = df['image_path'].apply(lambda p: os.path.join(new_base, os.path.basename(str(p))))

# Save updated CSV
updated_path = '/content/metadata128_updated.csv'
df.to_csv(updated_path, index=False)

print(f"Updated CSV saved to: {updated_path}")


Updated CSV saved to: /content/metadata128_updated.csv


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import pandas as pd
import os
import numpy as np

# --- 1. DCGAN MODEL ARCHITECTURES FOR 128x128 ---
class Generator(nn.Module):
    def __init__(self, z_dim, num_classes, embedding_dim, img_channels, features_g=64):
        super(Generator, self).__init__()
        self.z_dim, self.embedding_dim = z_dim, embedding_dim
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.net = nn.Sequential(
            self._block(z_dim + embedding_dim, features_g * 16, 4, 1, 0), # 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),      # 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),       # 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),       # 32x32
            self._block(features_g * 2, features_g, 4, 2, 1),           # 64x64
            nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1),      # 128x128
            nn.Tanh(),
        )
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True))
    def forward(self, noise, labels):
        label_embeddings = self.label_embedding(labels)
        x = torch.cat((noise, label_embeddings), -1).view(-1, self.z_dim + self.embedding_dim, 1, 1)
        return self.net(x)

class Discriminator(nn.Module):
    def __init__(self, num_classes, embedding_dim, img_channels, features_d=64):
        super(Discriminator, self).__init__()
        self.embedding_dim = embedding_dim
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.image_net = nn.Sequential(
            nn.Conv2d(img_channels + embedding_dim, features_d, 4, 2, 1), # Combine image and label at the start
            nn.LeakyReLU(0.2, inplace=True),
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            self._block(features_d * 8, features_d * 16, 4, 2, 1),
            nn.Conv2d(features_d * 16, 1, 4, 1, 0),
            nn.Sigmoid() # Use Sigmoid with BCELoss
        )
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True))
    def forward(self, img, labels):
        label_embeddings = self.label_embedding(labels).view(labels.shape[0], self.embedding_dim, 1, 1)
        label_features = label_embeddings.expand(-1, -1, img.shape[2], img.shape[3])
        x = torch.cat([img, label_features], dim=1)
        output = self.image_net(x)
        return output.view(x.size(0), -1)

# --- 2. DATASET AND UTILITY FUNCTIONS ---
class EmojiDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.metadata = dataframe.reset_index(drop=True)
        self.transform = transform
        self.shortcodes = self.metadata['shortcode'].unique()
        self.shortcode_to_idx = {c: i for i, c in enumerate(self.shortcodes)}
        self.num_classes = len(self.shortcodes)
    def __len__(self): return len(self.metadata)
    def __getitem__(self, idx):
        # Use the 'image_path' column which now points to the resized images
        img_path = self.metadata.loc[idx, 'image_path']
        shortcode = self.metadata.loc[idx, 'shortcode']
        try:
            image = Image.open(img_path).convert('RGB')
        except FileNotFoundError:
            return None, None
        label = self.shortcode_to_idx[shortcode]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)

def collate_fn(batch):
    batch = list(filter(lambda x: x is not None and x[0] is not None, batch))
    if not batch: return torch.Tensor(), torch.Tensor()
    return torch.utils.data.dataloader.default_collate(batch)

def save_labeled_image_grid(images_tensor, labels, epoch, idx_to_shortcode, file_path="generated_images_128"):
    os.makedirs(file_path, exist_ok=True)
    num_images, img_size = images_tensor.shape[0], images_tensor.shape[2]
    nrow = int(np.sqrt(num_images))
    padding, text_h = 10, 20
    grid_size = nrow * (img_size + padding) + padding
    total_height = grid_size + nrow * text_h
    grid_img = Image.new('RGB', (grid_size, total_height), 'white')
    draw = ImageDraw.Draw(grid_img)
    try: font = ImageFont.truetype("Arial.ttf", 15)
    except IOError: font = ImageFont.load_default()
    for i in range(num_images):
        img_tensor = (images_tensor[i].detach().cpu() * 0.5) + 0.5
        img_pil = transforms.ToPILImage()(img_tensor)
        row, col = i // nrow, i % nrow
        x = col * (img_size + padding) + padding
        y = row * (img_size + padding + text_h) + padding
        grid_img.paste(img_pil, (x, y))
        shortcode = idx_to_shortcode[labels[i].item()]
        draw.text((x, y + img_size + 2), shortcode[:20], fill="black", font=font)
    grid_img.save(os.path.join(file_path, f"epoch_{epoch:03d}.png"))

# --- 3. MAIN EXECUTION BLOCK ---
if __name__ == '__main__':
    # Configuration
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    LR, BATCH_SIZE, IMAGE_SIZE = 4e-4, 16, 128 # Smaller batch size for larger images
    CHANNELS_IMG, Z_DIM, EMBEDDING_DIM, NUM_EPOCHS = 3, 100, 50, 200

    # Use the new 128x128 metadata file
    CSV_FILE = '/content/metadata128_updated.csv'

    # Step 1: Load and filter metadata
    print("Step 1: Loading and cleaning metadata...")
    try:
        metadata_df = pd.read_csv(CSV_FILE)
        metadata_df.dropna(subset=['unicode', 'shortcode', 'group'], inplace=True)
        groups_to_keep = [0.0]
        mask = metadata_df['group'].isin(groups_to_keep)
        metadata_df = metadata_df[mask].copy()
        print(f"  - Filtered dataset has {len(metadata_df)} samples.")
    except FileNotFoundError: print(f"FATAL: Metadata file not found at {CSV_FILE}"); exit()

    # Step 2: Setup Dataloader
    print("\nStep 2: Setting up data loader...")
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])
    dataset = EmojiDataset(dataframe=metadata_df, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
    if len(dataset) == 0: print("FATAL: The dataset is empty."); exit()
    NUM_CLASSES = dataset.num_classes
    idx_to_shortcode = {v: k for k, v in dataset.shortcode_to_idx.items()}
    print(idx_to_shortcode)

    # Step 3: Initialize Models
    print("\nStep 3: Initializing models...")
    generator = Generator(Z_DIM, NUM_CLASSES, EMBEDDING_DIM, CHANNELS_IMG).to(DEVICE)
    discriminator = Discriminator(NUM_CLASSES, EMBEDDING_DIM, CHANNELS_IMG).to(DEVICE)
    optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))
    criterion = nn.BCELoss() # Using BCELoss as requested

    # Visualization setup
    fixed_noise = torch.randn(64, Z_DIM, device=DEVICE)
    fixed_labels = torch.randint(0, NUM_CLASSES, (64,), device=DEVICE)

    # Step 4: Training Loop
    print(f"\nStep 4: Starting training on {DEVICE}...")
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real_imgs, labels) in enumerate(dataloader):
            if real_imgs.nelement() == 0: continue
            real_imgs, labels = real_imgs.to(DEVICE), labels.to(DEVICE)
            # Label Smoothing
            valid = torch.full((real_imgs.shape[0], 1), 0.9, device=DEVICE)
            fake = torch.full((real_imgs.shape[0], 1), 0.1, device=DEVICE)

            # Train Discriminator
            optimizer_D.zero_grad()
            noise = torch.randn(real_imgs.shape[0], Z_DIM, device=DEVICE)
            fake_imgs = generator(noise, labels)
            d_real_loss = criterion(discriminator(real_imgs, labels), valid)
            d_fake_loss = criterion(discriminator(fake_imgs.detach(), labels), fake)
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward(); optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs, labels), valid)
            g_loss.backward(); optimizer_G.step()

            if batch_idx % 50 == 0:
                print(f"[Epoch {epoch+1:03d}/{NUM_EPOCHS}] [Batch {batch_idx:04d}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        # Save visualization grid
        generator.eval()
        with torch.no_grad():
            sample_imgs = generator(fixed_noise, fixed_labels)
            save_labeled_image_grid(sample_imgs, fixed_labels, epoch + 1, idx_to_shortcode)
        generator.train()

    # Save final models
    torch.save(generator.state_dict(), 'generator_128.pth')
    torch.save(discriminator.state_dict(), 'discriminator_128.pth')
    print("\nTraining finished. 🎉")


Step 1: Loading and cleaning metadata...
  - Filtered dataset has 166 samples.

Step 2: Setting up data loader...

Step 3: Initializing models...





Step 4: Starting training on cuda...
[Epoch 001/200] [Batch 0000/11] [D loss: 0.6992] [G loss: 7.6160]
[Epoch 002/200] [Batch 0000/11] [D loss: 0.7896] [G loss: 0.7149]
[Epoch 003/200] [Batch 0000/11] [D loss: 0.7905] [G loss: 0.9946]
[Epoch 004/200] [Batch 0000/11] [D loss: 0.7380] [G loss: 1.0112]
[Epoch 005/200] [Batch 0000/11] [D loss: 0.7509] [G loss: 0.7890]
[Epoch 006/200] [Batch 0000/11] [D loss: 0.7403] [G loss: 0.6857]
[Epoch 007/200] [Batch 0000/11] [D loss: 0.7149] [G loss: 0.6695]
[Epoch 008/200] [Batch 0000/11] [D loss: 0.7182] [G loss: 0.6544]
[Epoch 009/200] [Batch 0000/11] [D loss: 0.7070] [G loss: 0.7273]
[Epoch 010/200] [Batch 0000/11] [D loss: 0.6941] [G loss: 0.6750]
[Epoch 011/200] [Batch 0000/11] [D loss: 0.7131] [G loss: 0.7542]
[Epoch 012/200] [Batch 0000/11] [D loss: 0.8459] [G loss: 0.8688]
[Epoch 013/200] [Batch 0000/11] [D loss: 0.8113] [G loss: 0.8476]
[Epoch 014/200] [Batch 0000/11] [D loss: 0.7055] [G loss: 0.7314]
[Epoch 015/200] [Batch 0000/11] [D los

In [5]:
if __name__ == '__main__':
    # Configuration
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    LR, BATCH_SIZE, IMAGE_SIZE = 4e-4, 16, 128 # Smaller batch size for larger images
    CHANNELS_IMG, Z_DIM, EMBEDDING_DIM, NUM_EPOCHS = 3, 100, 50, 200

    # Use the new 128x128 metadata file
    CSV_FILE = '/content/metadata128_updated.csv'

    # Step 1: Load and filter metadata
    print("Step 1: Loading and cleaning metadata...")
    try:
        metadata_df = pd.read_csv(CSV_FILE)
        metadata_df.dropna(subset=['unicode', 'shortcode', 'group'], inplace=True)
        groups_to_keep = [0.0]
        mask = metadata_df['group'].isin(groups_to_keep)
        metadata_df = metadata_df[mask].copy()
        print(f"  - Filtered dataset has {len(metadata_df)} samples.")
    except FileNotFoundError: print(f"FATAL: Metadata file not found at {CSV_FILE}"); exit()

    # Step 2: Setup Dataloader
    print("\nStep 2: Setting up data loader...")
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])
    dataset = EmojiDataset(dataframe=metadata_df, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
    if len(dataset) == 0: print("FATAL: The dataset is empty."); exit()
    NUM_CLASSES = dataset.num_classes
    idx_to_shortcode = {v: k for k, v in dataset.shortcode_to_idx.items()}
    print(idx_to_shortcode)

Step 1: Loading and cleaning metadata...
  - Filtered dataset has 166 samples.

Step 2: Setting up data loader...
{0: 'grinning', 1: 'smiley', 2: 'smile', 3: 'grin', 4: "['laughing', 'satisfied']", 5: 'sweat_smile', 6: 'rofl', 7: 'joy', 8: 'slightly_smiling_face', 9: 'upside_down_face', 10: 'melting_face', 11: 'wink', 12: 'blush', 13: 'innocent', 14: 'smiling_face_with_three_hearts', 15: 'heart_eyes', 16: 'star_struck', 17: 'kissing_heart', 18: 'kissing', 19: 'relaxed', 20: 'kissing_closed_eyes', 21: 'kissing_smiling_eyes', 22: 'smiling_face_with_tear', 23: 'yum', 24: 'stuck_out_tongue', 25: 'stuck_out_tongue_winking_eye', 26: 'zany_face', 27: 'stuck_out_tongue_closed_eyes', 28: 'money_mouth_face', 29: 'hugs', 30: 'hand_over_mouth', 31: 'face_with_open_eyes_and_hand_over_mouth', 32: 'face_with_peeking_eye', 33: 'shushing_face', 34: 'thinking', 35: 'saluting_face', 36: 'zipper_mouth_face', 37: 'raised_eyebrow', 38: 'neutral_face', 39: 'expressionless', 40: 'no_mouth', 41: 'dotted_line_f



In [9]:
import json

# Save the idx_to_shortcode dictionary to a JSON file
with open('idx_to_shortcode.json', 'w') as f:
    json.dump(idx_to_shortcode, f, indent=4)

print("idx_to_shortcode dictionary saved to idx_to_shortcode.json")

idx_to_shortcode dictionary saved to idx_to_shortcode.json


In [16]:
import json, ast, difflib, os
import torch
from torchvision.utils import save_image
# using emoji labels


# -----------------------------
# Config you likely already have
# -----------------------------
# Assumes you’ve already defined/loaded:
#   generator : your trained conditional generator
#   DEVICE    : e.g., torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   Z_DIM     : latent noise dim used during training
#   (EMBEDDING_DIM is inferred from the model below)

# Example (uncomment / edit for your env):
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# generator = YourGeneratorClass(...)
# generator.load_state_dict(torch.load('generator_128.pth', map_location=DEVICE))
# generator.to(DEVICE).eval()
# Z_DIM = 128  # set to your training z-dim

# -----------------------------
# Text → Label-ID mapping utils
# -----------------------------
def _norm(s: str) -> str:
    """normalize user text & shortcodes."""
    s = s.strip().lower()
    if s.startswith(':') and s.endswith(':'):
        s = s[1:-1]
    s = s.replace('-', '_')
    s = ' '.join(s.replace('_', ' ').split())  # collapse spaces/underscores
    return s

def build_name_to_id_map(json_path: str):
    """
    Load idx_to_shortcode.json and build a dict:
      name_variant -> int_id
    Supports single strings or list-like strings in the json values.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        raw = json.load(f)

    name2id = {}
    for k, v in raw.items():
        idx = int(k)
        candidates = []

        if isinstance(v, list):
            candidates = v
        elif isinstance(v, str):
            # Some entries are stringified lists like "['laughing', 'satisfied']"
            parsed = None
            try:
                parsed = ast.literal_eval(v)
            except Exception:
                parsed = None
            if isinstance(parsed, list):
                candidates = parsed
            else:
                candidates = [v]
        else:
            continue

        for cand in candidates:
            if not isinstance(cand, str):
                continue
            # add multiple variants (underscore and space)
            norm = _norm(cand)
            if norm:
                name2id[norm] = idx
                name2id[norm.replace(' ', '_')] = idx
                name2id[norm.replace(' ', '')] = idx
                # also add colon-wrapped shortcode variant
                name2id[f":{norm.replace(' ', '_')}:"] = idx
    return name2id

def text_to_id(text: str, name2id: dict, topn_fuzzy: int = 3):
    """Exact match first; otherwise fuzzy match on keys."""
    t = _norm(text)
    # quick hits
    for key in (t, t.replace(' ', '_'), t.replace(' ', ''), f":{t.replace(' ', '_')}:"):
        if key in name2id:
            return name2id[key], key, None
    # fuzzy
    keys = list(name2id.keys())
    matches = difflib.get_close_matches(t, keys, n=topn_fuzzy, cutoff=0.6)
    if matches:
        best = matches[0]
        return name2id[best], best, matches
    return None, None, []

# -----------------------------
# Interpolation by text names
# -----------------------------
def interpolate_by_text(
    start_text: str,
    end_text: str,
    idx_json_path: str = "idx_to_shortcode.json",
    steps: int = 30,
    outdir: str = "interpolations",
    outfile_prefix: str = None
):
    os.makedirs(outdir, exist_ok=True)

    name2id = build_name_to_id_map(idx_json_path)

    start_id, start_key, start_suggestions = text_to_id(start_text, name2id)
    end_id, end_key, end_suggestions = text_to_id(end_text, name2id)

    def _suggest_msg(txt, sugg):
        return f" (did you mean: {', '.join(sugg)})" if sugg else ""

    if start_id is None:
        raise ValueError(f"Could not resolve start_text='{start_text}' to a label id."
                         + _suggest_msg(start_text, start_suggestions))
    if end_id is None:
        raise ValueError(f"Could not resolve end_text='{end_text}' to a label id."
                         + _suggest_msg(end_text, end_suggestions))

    print(f"Resolved: '{start_text}' -> id {start_id} ({start_key}), "
          f"'{end_text}' -> id {end_id} ({end_key})")

    # Infer embedding dim directly from the model
    # Assumes generator has attribute .label_embedding similar to nn.Embedding
    embedding_weight = generator.label_embedding.weight
    EMBEDDING_DIM = embedding_weight.shape[1]

    # Prepare noise (same noise across the strip)
    noise = torch.randn(1, Z_DIM, device=DEVICE).repeat(steps, 1)

    # Get embeddings for start & end
    labels_start = torch.tensor([start_id], device=DEVICE)
    labels_end   = torch.tensor([end_id],   device=DEVICE)
    with torch.no_grad():
        embedding_start = generator.label_embedding(labels_start)  # [1, E]
        embedding_end   = generator.label_embedding(labels_end)    # [1, E]

    # Interpolate label embeddings
    interpolated_embeddings = torch.zeros(steps, EMBEDDING_DIM, device=DEVICE)
    for i in range(steps):
        ratio = i / (steps - 1) if steps > 1 else 0.0
        interpolated_embeddings[i] = torch.lerp(embedding_start, embedding_end, ratio)

    # Generate without grads
    with torch.no_grad():
        gan_input = torch.cat((noise, interpolated_embeddings), dim=-1)  # [steps, Z+E]
        gan_input = gan_input.view(-1, Z_DIM + EMBEDDING_DIM, 1, 1)
        # If your generator uses .net as the main forward block:
        if hasattr(generator, "net"):
            interpolated_images = generator.net(gan_input)
        else:
            # or just call forward if your forward expects (z+embed)
            interpolated_images = generator(gan_input)

    # Filename
    if outfile_prefix is None:
        outfile_prefix = f"{start_key}_to_{end_key}".replace(':','').replace(' ','_')
    outpath = os.path.join(outdir, f"interpolation_{outfile_prefix}.png")

    save_image(
        interpolated_images,
        outpath,
        nrow=steps,
        normalize=True
    )
    print(f"Saved: {outpath}")
    return outpath

# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Edit these to whatever the user types:
    START_TEXT = "rofl"
    END_TEXT   = "money_mouth_face"
    STEPS      = 10

    # Make sure generator, DEVICE, Z_DIM are defined/loaded before this call.
    interpolate_by_text(START_TEXT, END_TEXT, "idx_to_shortcode.json", STEPS)


Resolved: 'rofl' -> id 6 (rofl), 'money_mouth_face' -> id 28 (money mouth face)
Saved: interpolations/interpolation_rofl_to_money_mouth_face.png


In [17]:
import torch
from torchvision.utils import save_image

# using emoji labels code

generator.load_state_dict(torch.load('generator_128.pth'))
generator.eval()

TARGET_LABEL_ID_START = 5  # The ID of the starting emoji
TARGET_LABEL_ID_END = 55 # The ID of the ending emoji
INTERPOLATION_STEPS = 30   # Number of images to generate in the sequence

# --- Interpolation Logic ---
print("Generating latent space interpolation...")

# Use the same noise for the whole sequence to see the label's effect
noise = torch.randn(1, Z_DIM, device=DEVICE).repeat(INTERPOLATION_STEPS, 1)

# Get the learned embeddings for the start and end labels from the generator
labels_start = torch.tensor([TARGET_LABEL_ID_START], device=DEVICE)
labels_end = torch.tensor([TARGET_LABEL_ID_END], device=DEVICE)
embedding_start = generator.label_embedding(labels_start)
embedding_end = generator.label_embedding(labels_end)

# Linearly interpolate between the two label embeddings
interpolated_embeddings = torch.zeros(INTERPOLATION_STEPS, EMBEDDING_DIM, device=DEVICE)
for i in range(INTERPOLATION_STEPS):
    ratio = i / (INTERPOLATION_STEPS - 1)
    interpolated_embeddings[i] = torch.lerp(embedding_start, embedding_end, ratio)

# Generate the sequence without calculating gradients
with torch.no_grad():
    # Manually combine noise and the interpolated embeddings
    gan_input = torch.cat((noise, interpolated_embeddings), -1).view(-1, Z_DIM + EMBEDDING_DIM, 1, 1)
    # Pass through the generator's network
    interpolated_images = generator.net(gan_input)

# Save the resulting image strip
save_image(
    interpolated_images,
    f"interpolation_from_{TARGET_LABEL_ID_START}_to_{TARGET_LABEL_ID_END}.png",
    nrow=INTERPOLATION_STEPS,
    normalize=True
)

print(f"Saved interpolation image strip.")

Generating latent space interpolation...
Saved interpolation image strip.


In [None]:
# using label id
import torch
from torchvision.utils import save_image

generator.load_state_dict(torch.load('generator_128.pth'))
generator.eval()
# --- Configuration ---
TARGET_LABEL_ID = 5 # The ID of the emoji you want to see variations of
NUM_SAMPLES = 16     # The number of images to generate (should be a perfect square, e.g., 9, 16, 25)

# --- Diversity Generation Logic ---
print(f"Generating {NUM_SAMPLES} diverse samples for label ID: {TARGET_LABEL_ID}...")

# Create different noise for each sample
noise = torch.randn(NUM_SAMPLES, Z_DIM, device=DEVICE)
# Use the same label for all samples
labels = torch.full((NUM_SAMPLES,), TARGET_LABEL_ID, dtype=torch.long, device=DEVICE)

# Generate images
with torch.no_grad():
    generated_images = generator(noise, labels)

# Save the resulting image grid
save_image(
    generated_images,
    f"diversity_for_label_{TARGET_LABEL_ID}.png",
    nrow=int(NUM_SAMPLES**0.5), # Arrange in a square grid
    normalize=True
)

print(f"Saved visual diversity grid.")

In [25]:
# using label name
import os
import torch
from torchvision.utils import save_image

# If the helpers are in another module, import them instead:
# from your_helpers import build_name_to_id_map, text_to_id

def generate_diversity_by_text_or_id(
    target,                 # e.g. 5 or "sweat_smile" or ":sweat_smile:"
    num_samples=1,
    idx_json_path="idx_to_shortcode.json",
    outdir="diversity"
):
    """
    Generate a grid of samples for a single label specified by text or id.
    Reuses the generator(noise, labels) interface when available; otherwise
    falls back to concatenating noise+embedding and calling generator.net.
    """
    os.makedirs(outdir, exist_ok=True)

    # Resolve label id ---------------------------------------------------------
    if isinstance(target, int):
        target_id = target
        resolved_key = str(target_id)
    else:
        # Build mapping and resolve the text
        name2id = build_name_to_id_map(idx_json_path)
        target_id, key_used, suggestions = text_to_id(str(target), name2id)
        if target_id is None:
            msg = f"Could not resolve '{target}' to a label id."
            if suggestions:
                msg += f" Did you mean: {', '.join(suggestions)} ?"
            raise ValueError(msg)
        resolved_key = key_used

    print(f"Generating {num_samples} samples for label id {target_id} ({resolved_key})")

    # Noise + labels -----------------------------------------------------------
    noise  = torch.randn(num_samples, Z_DIM, device=DEVICE)
    labels = torch.full((num_samples,), target_id, dtype=torch.long, device=DEVICE)

    # Generate -----------------------------------------------------------------
    with torch.no_grad():
        try:
            # Preferred path: your model supports (noise, labels)
            generated_images = generator(noise, labels)
        except TypeError:
            # Fallback: manually concat noise + label embedding and call main net
            emb = generator.label_embedding(labels)                     # [N, E]
            gan_in = torch.cat([noise, emb], dim=1).view(num_samples, -1, 1, 1)
            if hasattr(generator, "net"):
                generated_images = generator.net(gan_in)
            else:
                generated_images = generator(gan_in)  # if forward accepts concatenated input

    # Save grid ----------------------------------------------------------------
    nrow = int(num_samples ** 0.5)
    out_path = os.path.join(outdir, f"diversity_{resolved_key}_{target_id}.png")
    save_image(generated_images, out_path, nrow=nrow, normalize=True)
    print(f"Saved visual diversity grid to {out_path}")
    return out_path


# ----------------- Example usage (replace to your needs) -----------------
# Either text:
generate_diversity_by_text_or_id("smirk", idx_json_path="idx_to_shortcode.json")

# Or raw id (old behavior):
# generate_diversity_by_text_or_id(5, num_samples=16, idx_json_path="idx_to_shortcode.json")


Generating 1 samples for label id 43 (smirk)
Saved visual diversity grid to diversity/diversity_smirk_43.png


'diversity/diversity_smirk_43.png'