# Import Libraries

In [1]:
pip install pytorch_msssim

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import time
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import numpy as np
from pytorch_msssim import ssim

# Define hyperparameters

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Training classifier params
CLS_EPOCHS = 5
CLS_BATCH = 128           # NOTE(review): the recorded classifier run ended in CUDA OOM; lower this if it recurs
CLS_LR = 0.1              # NOTE(review): the optimizer cell passes its own lr to Adam — confirm which is intended

# Generator / inversion params
GEN_EPOCHS = 10
GEN_BATCH = 64
latent_dim = 100          # size of the noise vector z
cond_dim = 10             # size of the conditioning vector (one entry per class)
F_ch = 64                 # base channel width of the generator (earlier stages use multiples of it)
# Initial loss weights; the inversion-training cell assigns its own tuned values later.
alpha, beta, gamma, delta = 100.0, 100.0, 1000.0, 1000.0

# Other
NUM_CLASSES = 10
SAMPLES_PER_CLASS = 20    # for evaluation
# Checkpoint of the MNIST-trained classifier. The name must identify the dataset:
# the training cell loads this file whenever it exists, so a stale checkpoint from a
# different dataset stored under the same name would be picked up silently.
SAVE_MODEL_PATH = "vit_mnist.pth"
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

Device: cuda


# Data loaders

In [4]:
# MNIST digits are resized to 32x32 (the resolution the generator produces) and
# rescaled from [0, 1] to [-1, 1] so real images share the generator's tanh range.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
# The training batch size comes from the config cell rather than a duplicated literal.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=CLS_BATCH,
                                          shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
# Evaluation runs under no_grad, so it keeps its own fixed batch size and order.
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False)


In [5]:
# Number of mini-batches the training loader yields per epoch.
# NOTE(review): the output recorded under this cell may predate the current dataset — re-run to refresh.
len(trainloader)

391

# Classifier (ViT)

In [6]:
class ViTClassifier(nn.Module):
    """
    ViT-B/16 backbone with a fresh linear classification head.

    The backbone's own head is replaced by ``nn.Identity`` so that calling the
    backbone returns the penultimate feature vector; ``classifier_head`` maps
    those features to ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        pretrained: when True, load the ImageNet-1k weights for the backbone.
    """

    def __init__(self, num_classes=10, pretrained=True):
        super().__init__()
        # Load Vision Transformer, optionally with pretrained ImageNet weights
        weights = models.ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = models.vit_b_16(weights=weights)

        # Strip the original head so the backbone returns features
        in_features = self.vit.heads.head.in_features
        self.vit.heads.head = nn.Identity()

        # Small inputs (e.g. 32x32) are upsampled to the 224x224 grid the backbone expects
        self.resize = nn.Upsample(size=(224, 224), mode="bilinear", align_corners=False)
        self.classifier_head = nn.Linear(in_features, num_classes)  # new classification head

    def forward(self, x, return_features=False):
        """
        Classify a batch of images.

        Args:
            x: (B, C, H, W) images with C == 1 (greyscale) or C == 3.
            return_features: when True, also return the penultimate features.

        Returns:
            ``logits`` of shape (B, num_classes), or ``(logits, features)`` when
            ``return_features`` is True.
        """
        # Upsample the small input image to 224x224
        x = self.resize(x)

        # The backbone's patch embedding takes 3 input channels. Greyscale batches
        # (MNIST images, 1-channel generator output) are broadcast across the
        # channel axis; expand() is a view, and gradients still flow back to the
        # single source channel.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)

        # Penultimate features (the backbone head is Identity), then the new head
        features = self.vit(x)
        logits = self.classifier_head(features)

        if return_features:
            return logits, features
        return logits



# Build the classifier with one logit per class, then move it to the active device
classifier = ViTClassifier(num_classes=NUM_CLASSES)
classifier = classifier.to(device)

# Classifier training utilities


In [7]:
# Objective, optimiser and learning-rate schedule for classifier training.
# The scheduler is stepped once per epoch by the training driver.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=classifier.parameters(), lr=3e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20)
    
def train_classifier_one_epoch(epoch):
    """
    Run one optimisation pass over ``trainloader`` and report the results.

    Args:
        epoch: epoch number, used only in the log line.

    Returns:
        (train_loss, train_acc): sample-weighted mean loss and accuracy in percent.
    """
    classifier.train()
    start = time.time()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in trainloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = classifier(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch mean is per-sample
        n_batch = labels.size(0)
        loss_sum += batch_loss.item() * n_batch
        n_seen += n_batch
        n_correct += logits.max(1)[1].eq(labels).sum().item()

    train_loss = loss_sum / n_seen
    train_acc = 100. * n_correct / n_seen
    print(f"[CLS] Epoch {epoch} Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% Time: {time.time()-start:.1f}s")
    return train_loss, train_acc

def test_classifier(epoch):
    """
    Evaluate the classifier on ``testloader`` without tracking gradients.

    Args:
        epoch: epoch number, used only in the log line.

    Returns:
        (test_loss, acc): sample-weighted mean loss and accuracy in percent.
    """
    classifier.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)

            logits = classifier(images)
            n_batch = labels.size(0)
            loss_sum += criterion(logits, labels).item() * n_batch
            n_seen += n_batch
            n_correct += logits.max(1)[1].eq(labels).sum().item()

    test_loss = loss_sum / n_seen
    acc = 100. * n_correct / n_seen
    print(f"[CLS] Epoch {epoch} Test Loss: {test_loss:.4f} Acc: {acc:.2f}%")
    return test_loss, acc

# Train classifier


In [8]:
def train_classifier(epochs=CLS_EPOCHS):
    """
    Train the classifier for ``epochs`` epochs, checkpointing on improvement.

    After every epoch the model is evaluated and the scheduler stepped; the
    state dict is written to ``SAVE_MODEL_PATH`` whenever test accuracy beats
    the best seen so far. Accuracy and loss curves are plotted at the end.
    """
    best_acc = 0.0
    history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

    for ep in range(1, epochs + 1):
        tr_loss, tr_acc = train_classifier_one_epoch(ep)
        te_loss, te_acc = test_classifier(ep)
        scheduler.step()

        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["test_loss"].append(te_loss)
        history["test_acc"].append(te_acc)

        if te_acc > best_acc:
            best_acc = te_acc
            torch.save(classifier.state_dict(), SAVE_MODEL_PATH)
            print(f"[CLS] Saved best model (acc={best_acc:.2f}%)")

    # Training curves: accuracy in the left panel, loss in the right one
    panels = (
        ("Accuracy", (("Train Acc", history["train_acc"]), ("Test Acc", history["test_acc"]))),
        ("Loss", (("Train Loss", history["train_loss"]), ("Test Loss", history["test_loss"]))),
    )
    plt.figure(figsize=(10,4))
    for position, (title, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for curve_label, values in curves:
            plt.plot(range(1, len(values) + 1), values, label=curve_label)
        plt.title(title)
        plt.legend()
    plt.show()

# Only train the classifier if no checkpoint is present
if not os.path.isfile(SAVE_MODEL_PATH):
    print("Training classifier from scratch...")
    train_classifier(epochs=CLS_EPOCHS)
else:
    print("Found saved classifier. Loading...")

# Load the checkpoint on both paths. After a fresh training run the in-memory model
# holds the weights of the LAST epoch, whereas the file holds the BEST epoch; loading
# here makes a first run and any rerun use exactly the same classifier.
# (The file can be missing only if training never produced a checkpoint.)
if os.path.isfile(SAVE_MODEL_PATH):
    classifier.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=device))

# Freeze the classifier's parameters. Gradients can still flow THROUGH its ops
# back to the generator's output; only the classifier's own weights stop updating.
classifier.requires_grad_(False)
classifier.eval()

Training classifier from scratch...


OutOfMemoryError: CUDA out of memory. Tried to allocate 222.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 74.12 MiB is free. Process 8534 has 14.67 GiB memory in use. Of the allocated memory 14.52 GiB is allocated by PyTorch, and 24.12 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Conditioning utils & precompute hot matrices


In [None]:
# The generator is conditioned on a probability vector (softmax output), not a one-hot code.
def sample_vector_condition(batch, num_classes, device):
    """
    Draw a random batch of conditioning distributions.

    Returns:
        (vector_cond, P_target, labels): the first two are the SAME tensor of
        shape (batch, num_classes) whose rows sum to 1; ``labels`` holds the
        argmax class index of each row.
    """
    raw_scores = torch.randn(batch, num_classes, device=device)
    cond = raw_scores.softmax(dim=1)       # each row is a distribution over classes
    hard_labels = cond.argmax(dim=1)       # integer class used as the CE target
    return cond, cond, hard_labels

# Table of per-class 'hot' masks, shape (NUM_CLASSES, 1, NUM_CLASSES, NUM_CLASSES):
# entry k is 1.0 along row k and column k and 0.0 everywhere else.
_precomputed_hot = torch.zeros(NUM_CLASSES, 1, NUM_CLASSES, NUM_CLASSES, device=device)
for k in range(NUM_CLASSES):
    mat = _precomputed_hot[k, 0]   # view into the table; the writes below land in it
    mat[k, :] = 1.0
    mat[:, k] = 1.0

def build_hot_matrix(labels, N=NUM_CLASSES, device=device):
    """
    Build one 'hot' cross-shaped mask per label.

    For a label k the mask is an N x N matrix that is 1.0 on row k and on
    column k and 0.0 elsewhere.

    Args:
        labels: 1-D integer tensor of class indices, expected in [0, N).
        N: side length of each mask (number of classes).
        device: device on which the result is created.

    Returns:
        Tensor of shape (batch, 1, N, N) in the default floating dtype.
    """
    labels = torch.as_tensor(labels, device=device).reshape(-1).long()
    class_ids = torch.arange(N, device=device)

    # hit[b, j] is True exactly where j equals the label of sample b
    hit = labels.unsqueeze(1) == class_ids                 # (B, N)
    # A cell (i, j) is hot when its row OR its column matches the label.
    # Broadcasting builds the whole batch at once, with no per-sample Python loop.
    cross = hit.unsqueeze(2) | hit.unsqueeze(1)            # (B, N, N)

    return cross.unsqueeze(1).to(torch.get_default_dtype())

# Generator



In [None]:
# Conditional generator: (noise, conditioning vector, hot mask) -> 32x32 image
class Generator(nn.Module):
    """
    Transposed-convolution generator.

    The noise and conditioning vectors are concatenated and projected to a 4x4
    feature map; the hot mask is resized to 4x4 and appended as one extra
    channel; three stride-2 transposed convolutions then upsample
    4x4 -> 8x8 -> 16x16 -> 32x32 and a 3x3 convolution followed by tanh
    produces ``out_ch`` image channels in [-1, 1].
    """

    def __init__(self, latent_dim=latent_dim, cond_dim=cond_dim, F=F_ch, out_ch=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.cond_dim = cond_dim

        # Inside __init__ the parameter `F` is the base channel width (an int)
        width = F
        double_size = dict(kernel_size=4, stride=2, padding=1)  # each such layer doubles H and W

        # 1x1 -> 4x4 projection of the concatenated (noise, condition) vector
        self.convT1 = nn.ConvTranspose2d(latent_dim + cond_dim, 8 * width,
                                         kernel_size=4, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(8 * width)

        # 4x4 -> 8x8; one extra input channel carries the resized hot mask
        self.convT2 = nn.ConvTranspose2d(8 * width + 1, 4 * width, **double_size)
        self.bn2 = nn.BatchNorm2d(4 * width)

        # 8x8 -> 16x16
        self.convT3 = nn.ConvTranspose2d(4 * width, 2 * width, **double_size)
        self.bn3 = nn.BatchNorm2d(2 * width)

        # 16x16 -> 32x32
        self.convT4 = nn.ConvTranspose2d(2 * width, width, **double_size)
        self.bn4 = nn.BatchNorm2d(width)

        # Size-preserving convolution down to the requested number of image channels
        self.final_conv = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, z, vec_cond, hot_mat):
        """
        Args:
            z: (B, latent_dim) noise.
            vec_cond: (B, cond_dim) conditioning vector.
            hot_mat: (B, 1, H, W) mask; it is resized to 4x4 here.

        Returns:
            (B, out_ch, 32, 32) images in [-1, 1].
        """
        joint = torch.cat((z, vec_cond), dim=1)                 # (B, latent+cond)
        h = joint.view(joint.size(0), joint.size(1), 1, 1)      # (B, latent+cond, 1, 1)

        # 1x1 -> 4x4
        h = F.leaky_relu(self.bn1(self.convT1(h)), 0.2)

        # Attach the hot mask at the 4x4 stage as an additional channel
        hot_4x4 = F.interpolate(hot_mat, size=(4, 4), mode='bilinear', align_corners=False)
        h = torch.cat((h, hot_4x4), dim=1)

        # 4x4 -> 8x8, with channel dropout applied after this stage only
        h = self.dropout(F.leaky_relu(self.bn2(self.convT2(h)), 0.2))

        # 8x8 -> 16x16 -> 32x32
        for up, norm in ((self.convT3, self.bn3), (self.convT4, self.bn4)):
            h = F.leaky_relu(norm(up(h)), 0.2)

        return torch.tanh(self.final_conv(h))                   # range [-1, 1]

# Loss utilities

In [None]:
def kl_divergence_targetP(logits, P_target):
    """
    KL divergence D_KL(P_target || softmax(logits)), averaged over the batch.

    Args:
        logits: (B, C) raw classifier scores.
        P_target: (B, C) target probabilities (not log-probabilities).
    """
    return F.kl_div(
        F.log_softmax(logits, dim=1),   # kl_div takes the prediction in log-space
        P_target,
        reduction='batchmean',
        log_target=False,
    )

def cosine_similarity_loss(features):
    """
    Mean pairwise cosine similarity between DIFFERENT samples of a batch.

    Args:
        features: (B, D) feature vectors.

    Returns:
        Scalar tensor: the average of the off-diagonal entries of the cosine
        similarity matrix. With fewer than two samples there are no pairs and
        the result is 0.
    """
    B = features.size(0)
    if B < 2:
        # No off-diagonal entries exist; dividing rounding noise by a tiny
        # denominator would otherwise yield an arbitrary large value.
        return features.new_zeros(())

    eps = 1e-8
    f = features / (features.norm(dim=1, keepdim=True) + eps)  # row-normalise
    sim = f @ f.T  # (B, B)

    # Subtract the actual diagonal rather than assuming it is exactly B:
    # the eps in the normalisation (and zero rows) makes it slightly less.
    off_diag_sum = sim.sum() - sim.diagonal().sum()
    return off_diag_sum / (B * (B - 1))

def orthogonality_loss(features):
    """
    Penalise deviation of the batch Gram matrix from the identity.

    Rows of ``features`` (B, D) are L2-normalised; the loss is the mean squared
    difference between their (B, B) Gram matrix and the identity, taken over
    all B*B entries.
    """
    eps = 1e-8
    unit = features / (features.norm(dim=1, keepdim=True) + eps)
    gram = unit @ unit.T
    identity = torch.eye(features.size(0), device=features.device)
    return (gram - identity).pow(2).mean()
def ssim_loss(imgs, ref_imgs):
    """
    Structural Similarity (SSIM) loss between generated and reference images.

    Both tensors are expected in the range [-1, 1]. The SSIM implementation
    works on non-negative intensities in [0, data_range], so the images are
    first mapped to [0, 1] and scored with ``data_range=1.0``.

    Returns:
        (1 - SSIM) averaged over the batch.
    """
    # [-1, 1] -> [0, 1]; an affine map, so gradients pass through unchanged in direction
    imgs_01 = (imgs + 1.0) / 2.0
    ref_01 = (ref_imgs + 1.0) / 2.0
    return 1 - ssim(imgs_01, ref_01, data_range=1.0, size_average=True)
# Weighted combination of all inversion loss terms
def compute_inversion_loss(
    logits, features, P_target, labels,
    imgs=None, ref_imgs=None,
    alpha=1.0, beta=1.0, gamma=0.01, delta=0.01, eta=0.5
):
    """
    Total inversion loss = α * KL + β * CE + γ * CosSim + δ * Orthogonality + η * (1 - SSIM)

    The SSIM term is computed only when both ``imgs`` and ``ref_imgs`` are
    given; otherwise it is a zero tensor on the device of ``logits``.

    Returns:
        dict with the individual terms ("L_kl", "L_ce", "L_cos", "L_ortho",
        "L_ssim") and their weighted sum under "total".
    """
    terms = {
        "L_kl": kl_divergence_targetP(logits, P_target),
        "L_ce": F.cross_entropy(logits, labels),
        "L_cos": cosine_similarity_loss(features),
        "L_ortho": orthogonality_loss(features),
    }

    # Structural term needs a reference batch to compare against
    if imgs is None or ref_imgs is None:
        terms["L_ssim"] = torch.tensor(0.0, device=logits.device)
    else:
        terms["L_ssim"] = ssim_loss(imgs, ref_imgs)

    # Accumulate the weighted terms one by one, in the order of the formula above
    weighted = alpha * terms["L_kl"]
    weighted = weighted + beta * terms["L_ce"]
    weighted = weighted + gamma * terms["L_cos"]
    weighted = weighted + delta * terms["L_ortho"]
    weighted = weighted + eta * terms["L_ssim"]

    terms["total"] = weighted
    return terms



# Inversion training


In [None]:
# Tuned loss weights for inversion training (KL, CE, cosine, orthogonality, SSIM)
alpha = 100.0
beta = 200.0
gamma = 1000.0
delta = 1000.0
eta = 1000.0

# Generator with its default architecture and the Adam optimiser that trains it
gen = Generator().to(device)
gen_opt = optim.Adam(gen.parameters(), lr=1e-4, betas=(0.5, 0.999))

def train_generator(epochs=GEN_EPOCHS, batch_size=GEN_BATCH):
    """
    Train the generator against the frozen classifier.

    Each epoch runs a fixed number of optimisation steps. A step samples a
    conditioning distribution, noise and the matching hot mask, generates
    images, classifies them, and minimises the weighted inversion loss with
    gradient-norm clipping. Every 10th epoch the inversion accuracy is
    printed, a sample grid is shown and the generator weights are saved to
    ``gen_epoch{ep}.pth``.

    Args:
        epochs: number of epochs to run.
        batch_size: number of images generated per step.
    """
    steps_per_epoch = 200
    stat_keys = ("total", "L_kl", "L_ce", "L_cos", "L_ortho", "L_ssim")

    for ep in range(1, epochs + 1):
        # Enter training mode at the start of EVERY epoch: the periodic evaluation
        # helpers below put the generator into eval mode, and without this the
        # following epochs would run with dropout off and frozen batch-norm statistics.
        gen.train()

        running = dict.fromkeys(stat_keys, 0.0)
        t0 = time.time()

        for _ in range(steps_per_epoch):
            # Sample conditioning
            vec_cond, P_target, labels = sample_vector_condition(batch_size, NUM_CLASSES, device)
            z = torch.randn(batch_size, latent_dim, device=device)
            hot = build_hot_matrix(labels, N=NUM_CLASSES, device=device)

            # Generate images
            imgs = gen(z, vec_cond, hot)

            # Structural reference: a slightly perturbed copy of the batch.
            # NOTE(review): the reference is not detached, so the SSIM term sends
            # gradients through both of its arguments — confirm this is intended.
            ref_imgs = imgs + 0.02 * torch.randn_like(imgs)

            # Forward through the frozen classifier
            logits, features = classifier(imgs, return_features=True)

            # Combined inversion + SSIM loss
            loss_dict = compute_inversion_loss(
                logits, features, P_target, labels,
                imgs=imgs, ref_imgs=ref_imgs,
                alpha=alpha, beta=beta, gamma=gamma, delta=delta, eta=eta
            )
            total_loss = loss_dict["total"]

            gen_opt.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(gen.parameters(), max_norm=1.0)
            gen_opt.step()

            # Accumulate stats
            for key in stat_keys:
                running[key] += loss_dict[key].item()

        # Epoch summary (per-step means)
        mean = {key: value / steps_per_epoch for key, value in running.items()}
        print(f"[GEN] Ep {ep}/{epochs} | Loss {mean['total']:.4f} | "
              f"KL {mean['L_kl']:.4f} | CE {mean['L_ce']:.4f} | "
              f"COS {mean['L_cos']:.4f} | ORT {mean['L_ortho']:.4f} | "
              f"SSIM {mean['L_ssim']:.4f} | time {time.time()-t0:.1f}s")

        # Evaluate & visualize every few epochs
        if ep % 10 == 0:
            acc = inversion_accuracy(gen, classifier)
            print(f"[GEN] Inversion Accuracy: {acc:.2f}%")
            show_generated_images(gen, classifier, samples_per_class=4)
            torch.save(gen.state_dict(), f"gen_epoch{ep}.pth")

# Evaluation


In [None]:
def inversion_accuracy(
    gen,
    classifier,
    num_classes=NUM_CLASSES,
    n_per_class=100,
    latent_dim=latent_dim,
    batch_size=32
):
    """
    Percentage of generated images that the classifier assigns to the class
    they were conditioned on.

    For every class, ``n_per_class`` images are generated in chunks of at most
    ``batch_size`` from a one-hot conditioning vector and the matching hot
    mask. Both models are switched to eval mode.
    """
    gen.eval()
    classifier.eval()

    hits = 0
    seen = 0

    with torch.no_grad():
        for cls in range(num_classes):
            for offset in range(0, n_per_class, batch_size):
                # The last chunk of a class may be smaller than batch_size
                bsz = min(batch_size, n_per_class - offset)

                z = torch.randn(bsz, latent_dim, device=device)
                onehot = torch.zeros(bsz, num_classes, device=device)
                onehot[:, cls] = 1.0
                wanted = torch.full((bsz,), cls, device=device)
                hot = build_hot_matrix(wanted, N=num_classes, device=device)

                logits, _ = classifier(gen(z, onehot, hot), return_features=True)

                hits += logits.argmax(dim=1).eq(cls).sum().item()
                seen += bsz

    return 100.0 * hits / seen


def show_generated_images(
    gen,
    classifier,
    latent_dim=latent_dim,
    num_classes=NUM_CLASSES,
    samples_per_class=8
):
    """
    Plot a grid of generated images: one row per class, one column per sample.

    Each image is generated from a one-hot conditioning vector and the
    matching hot mask, and titled with the target class (T) and the
    classifier's prediction (P), green when they agree and red otherwise.
    Both models are switched to eval mode.
    """
    gen.eval()
    classifier.eval()

    with torch.no_grad():
        # squeeze=False keeps `axes` two-dimensional even for a single row or
        # column, so axes[cls, i] is valid for every grid size.
        fig, axes = plt.subplots(num_classes, samples_per_class,
                                 figsize=(samples_per_class, num_classes),
                                 squeeze=False)

        for cls in range(num_classes):
            # 1. Sample latent + conditioning
            z = torch.randn(samples_per_class, latent_dim, device=device)
            c = torch.zeros(samples_per_class, num_classes, device=device)
            c[:, cls] = 1.0  # one-hot class conditioning
            hot = build_hot_matrix(torch.full((samples_per_class,), cls, device=device),
                                   N=num_classes, device=device)

            # 2. Generate images
            imgs = gen(z, c, hot)

            # 3. Classifier prediction for the whole row at once
            logits, _ = classifier(imgs, return_features=True)
            preds = logits.argmax(dim=1).cpu()

            # 4. Map [-1, 1] -> [0, 1] for display
            imgs = ((imgs + 1.0) / 2.0).clamp(0, 1).cpu()

            for i in range(samples_per_class):
                ax = axes[cls, i]
                img = imgs[i]
                if img.size(0) == 1:
                    # Greyscale: draw on THIS axes with a grey colormap and a fixed
                    # intensity scale, so brightness is comparable across cells.
                    ax.imshow(img[0].numpy(), cmap='gray', vmin=0.0, vmax=1.0)
                else:
                    # Multi-channel: matplotlib expects channels last
                    ax.imshow(img.permute(1, 2, 0).numpy())
                ax.axis("off")

                true_label = cls
                pred_label = preds[i].item()
                color = "green" if true_label == pred_label else "red"

                ax.set_title(f"T:{true_label} | P:{pred_label}", fontsize=8, color=color)

        plt.tight_layout()
        plt.show()


def eval_tsne_grid(
    gen,
    classifier,
    samples_per_class=100,
    num_classes=NUM_CLASSES,
    latent_dim=latent_dim,
    seed=42
):
    """
    Generate images for every class, report inversion accuracy, and plot a
    2-D t-SNE embedding of the classifier features coloured by target class.

    The conditioning vector is a random softmax distribution whose largest
    entry sits at the target class, so it agrees with the hot mask in the same
    way the training samples do. Features are reduced with PCA before t-SNE.
    Both models are switched to eval mode, and the torch and numpy global
    seeds are set to ``seed``.
    """
    gen.eval()
    classifier.eval()
    all_features, all_labels, all_preds = [], [], []

    torch.manual_seed(seed)
    np.random.seed(seed)

    with torch.no_grad():
        for label in range(num_classes):
            labels = torch.full((samples_per_class,), label, dtype=torch.long, device=device)
            v_raw = F.softmax(torch.randn(samples_per_class, num_classes, device=device), dim=1)

            # Make the conditioning consistent with `labels`: exchange each row's
            # largest probability with the entry at the target class. Previously the
            # vector pointed at a random class while the hot mask and the accuracy
            # used `label`, so the two conditioning signals contradicted each other.
            target_col = labels.unsqueeze(1)
            top_col = v_raw.argmax(dim=1, keepdim=True)
            top_val = v_raw.gather(1, top_col)
            target_val = v_raw.gather(1, target_col)
            vec_cond = v_raw.scatter(1, top_col, target_val).scatter(1, target_col, top_val)

            hot_mat = build_hot_matrix(labels, N=num_classes, device=device)
            z = torch.randn(samples_per_class, latent_dim, device=device)

            imgs = gen(z, vec_cond, hot_mat)
            logits, feats = classifier(imgs, return_features=True)
            preds = logits.argmax(dim=1)

            all_features.append(feats.cpu())
            all_labels.append(labels.cpu())
            all_preds.append(preds.cpu())

    # Concatenate
    all_features = torch.cat(all_features, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()
    all_preds = torch.cat(all_preds, dim=0).numpy()

    acc = np.mean(all_labels == all_preds) * 100
    print(f"[t-SNE] Inversion Accuracy: {acc:.2f}%")

    # Dimensionality reduction. PCA cannot return more components than there are
    # samples or features, and t-SNE needs perplexity < number of samples.
    n_samples, n_features = all_features.shape
    pca = PCA(n_components=min(50, n_features, n_samples))
    features_pca = pca.fit_transform(all_features)

    # The iteration count is left at the library default (1000): the keyword for it
    # was renamed between scikit-learn releases, so passing it breaks on some versions.
    tsne = TSNE(n_components=2, perplexity=min(30, max(n_samples - 1, 1)), random_state=seed)
    features_2d = tsne.fit_transform(features_pca)
    plt.figure(figsize=(8, 6))
    for label in range(num_classes):
        idx = all_labels == label
        plt.scatter(features_2d[idx, 0], features_2d[idx, 1], label=f"Class {label}", alpha=0.7, s=12)
    plt.legend()
    plt.title("t-SNE of Generator Feature Embeddings")
    plt.xlabel("Dim 1")
    plt.ylabel("Dim 2")
    plt.show()

# Run generator training


In [None]:
# Epoch count and batch size come from the config cell instead of repeated literals
print("Starting generator training...")
train_generator(epochs=GEN_EPOCHS, batch_size=GEN_BATCH)



In [None]:
# Final t-SNE visualization of classifier features for generated images,
# using SAMPLES_PER_CLASS generated samples per class
print("Running t-SNE evaluation...")
eval_tsne_grid(gen, classifier, samples_per_class=SAMPLES_PER_CLASS)

In [None]:
# Final grid of generated images: one row per class, 8 samples per row,
# each titled with the target class and the classifier's prediction
print("Generating final image grid...")
show_generated_images(gen, classifier, samples_per_class=8)

In [None]:
# Final inversion accuracy, measured on 100 generated images per class
final_acc = inversion_accuracy(gen, classifier, n_per_class=100)
print(f"FINAL INVERSION ACCURACY: {final_acc:.2f}%")