In [None]:
# =========================
# STEP 1: Setup (Colab + Kaggle + Google Drive)
# =========================
# !pip install -q kaggle
import os
from google.colab import files, drive

# Attach Google Drive and make sure the output folder for checkpoints,
# plots and logs exists.
drive.mount('/content/drive')
save_dir = '/content/drive/MyDrive/WGAN_Knee_Xray'
os.makedirs(save_dir, exist_ok=True)

# Prompt for kaggle.json through the Colab upload widget.
files.upload()  # Upload kaggle.json

# Put the uploaded credentials where the Kaggle CLI looks for them and
# restrict their permissions. Shell exit codes are not inspected.
os.makedirs('/root/.kaggle', exist_ok=True)
for shell_cmd in (
    'mv kaggle.json /root/.kaggle/',
    'chmod 600 /root/.kaggle/kaggle.json',
):
    os.system(shell_cmd)

# Fetch and unpack the dataset only when it is not already on disk.
data_root = '/content/knee_xray_data/KneeXray'
if os.path.exists(data_root):
    print("Dataset already exists. Skipping download.")
else:
    print("Dataset not found. Downloading...")
    os.system('kaggle datasets download -d gauravduttakiit/osteoarthritis-knee-xray')
    os.system('unzip -q osteoarthritis-knee-xray.zip -d /content/knee_xray_data')

# =========================
# STEP 2: Imports
# =========================
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tqdm.notebook import tqdm
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import RMSprop
from IPython import display
import gc

# =========================
# CONFIGURABLE PARAMETERS
# =========================
# Shape of one image as fed to the networks: (height, width, channels).
INPUT_SHAPE = (224, 224, 1)
# Length of the generator's noise vector.
LATENT_DIM = 100
# Learning rate used for both RMSprop optimizers in train().
LEARNING_RATE = 2e-5
# Samples per training batch.
BATCH_SIZE = 64
# Upper bound on the epoch counter in train().
EPOCHS = 800
# Discriminator (critic) updates performed per generator update.
N_CRITIC = 10
# Multiplier applied to the gradient-penalty term of the critic loss.
GRADIENT_PENALTY_WEIGHT = 10.0
# Number of real/fake image pairs scored per epoch for SSIM/PSNR.
EVAL_BATCH_SIZE = 64

# =========================
# STEP 3: Data Preprocessing
# =========================
# Load the training manifest that maps image file names to class labels.
train_csv = pd.read_csv(f'{data_root}/Train.csv')

# User-settable column names
IMAGE_COL_NAME = 'filename'
LABEL_COL_NAME = 'label'

# Safety check: fail early (and show the real columns) if the configured
# column names do not match the CSV header.
if IMAGE_COL_NAME not in train_csv.columns or LABEL_COL_NAME not in train_csv.columns:
    print("❌ Train.csv columns are:", train_csv.columns)
    raise KeyError("Please set IMAGE_COL_NAME and LABEL_COL_NAME correctly based on above columns.")

image_folder = f'{data_root}/train'
image_data, labels = [], []

# PIL's resize() expects (width, height) while INPUT_SHAPE is
# (height, width, channels).
_target_size = (INPUT_SHAPE[1], INPUT_SHAPE[0])

for img_file, label in zip(train_csv[IMAGE_COL_NAME], train_csv[LABEL_COL_NAME]):
    img_path = os.path.join(image_folder, img_file)
    if not os.path.exists(img_path):
        print(f"Image not found: {img_path}")
        continue
    try:
        # The context manager releases the file handle even if decoding fails.
        with Image.open(img_path) as img:
            img_array = np.array(img.convert('L').resize(_target_size))
    except Exception as e:
        # Deliberately best-effort: one unreadable file must not abort the
        # whole load, so report it and move on.
        print(f"Error loading image {img_path}: {e}")
        continue
    image_data.append(img_array)
    labels.append(label)

# Validate BEFORE splitting: train_test_split itself raises on an empty
# dataset, so a check placed after the split could never fire.
if not image_data:
    raise ValueError("No images were loaded. Check dataset loading and image processing.")

X = np.array(image_data)
y = np.array(labels)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Scale pixel values from [0, 255] to [-1, 1] (the generator ends in tanh)
# and append the channel axis.
X = (X.astype('float32') - 127.5) / 127.5
X = np.expand_dims(X, axis=-1)

X_train, X_val, y_train, y_val = train_test_split(X, y_encoded, test_size=0.2, random_state=42)

# Verify dataset
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")

NUM_CLASSES = len(np.unique(y_encoded))
# Encoded class index -> original label, used for plot titles.
num_to_class = dict(enumerate(label_encoder.classes_))
print(f"Number of classes: {NUM_CLASSES}")

# =========================
# STEP 4: Define Models
# =========================
def define_discriminator(in_shape=(224, 224, 1), n_classes=7, lr=0.00005):
    """Build the label-conditioned critic network.

    The class label is embedded, projected to one extra image-sized channel
    and concatenated with the input image. Four stride-2 convolutions follow,
    then a single linear output unit (no activation).

    Args:
        in_shape: Image shape as (height, width, channels).
        n_classes: Number of distinct class labels for the embedding.
        lr: Not used inside this function.

    Returns:
        A Keras ``Model`` taking ``[image, label]`` and returning one score
        per sample.
    """
    height, width = in_shape[0], in_shape[1]

    # Label branch: embedding -> dense -> one (height, width, 1) plane.
    label_in = layers.Input(shape=(1,), dtype=tf.int32)
    label_plane = layers.Embedding(n_classes, 50)(label_in)
    label_plane = layers.Dense(height * width, kernel_initializer='he_normal')(label_plane)
    label_plane = layers.Reshape((height, width, 1))(label_plane)

    # Image branch: stack the label plane onto the image as an extra channel.
    image_in = layers.Input(shape=in_shape, dtype=tf.float32)
    features = layers.Concatenate()([image_in, label_plane])

    # Four identical downsampling stages; each one reports its output shape.
    for stage in range(1, 5):
        features = layers.Conv2D(
            64, (5, 5), strides=(2, 2), padding='same', kernel_initializer='he_normal'
        )(features)
        print(f"Conv2D {stage} output shape: {features.shape}")
        features = layers.LeakyReLU(0.2)(features)

    features = layers.Flatten()(features)
    features = layers.Dropout(0.4)(features)
    score = layers.Dense(1, kernel_initializer='he_normal')(features)

    return Model([image_in, label_in], score)

def define_generator(latent_dim, n_classes=7):
    """Build the label-conditioned generator network.

    Both the label and the noise vector are projected to 14x14x64 feature
    maps and concatenated. Four stride-2 transposed convolutions upsample
    them, and a final tanh convolution emits a single-channel image.

    Args:
        latent_dim: Length of the noise vector.
        n_classes: Number of distinct class labels for the embedding.

    Returns:
        A Keras ``Model`` taking ``[latent, label]`` and returning images
        with values in [-1, 1].
    """
    base_shape = (14, 14, 64)
    base_units = 14 * 14 * 64

    # Label branch: embedding -> dense -> 14x14x64 feature map.
    label_in = layers.Input(shape=(1,), dtype=tf.int32)
    label_maps = layers.Embedding(n_classes, 50)(label_in)
    label_maps = layers.Dense(base_units, kernel_initializer='he_normal')(label_maps)
    label_maps = layers.Reshape(base_shape)(label_maps)

    # Noise branch: dense -> LeakyReLU -> 14x14x64 feature map.
    latent_in = layers.Input(shape=(latent_dim,), dtype=tf.float32)
    noise_maps = layers.Dense(base_units, kernel_initializer='he_normal')(latent_in)
    noise_maps = layers.LeakyReLU(0.2)(noise_maps)
    noise_maps = layers.Reshape(base_shape)(noise_maps)

    features = layers.Concatenate()([noise_maps, label_maps])

    # Four identical upsampling stages, each doubling the spatial size.
    for _ in range(4):
        features = layers.Conv2DTranspose(
            64, (4, 4), strides=(2, 2), padding='same', kernel_initializer='he_normal'
        )(features)
        features = layers.LeakyReLU(0.2)(features)

    image_out = layers.Conv2D(
        1, (16, 16), activation='tanh', padding='same', kernel_initializer='he_normal'
    )(features)

    return Model([latent_in, label_in], image_out)

def gradient_penalty(discriminator, batch_size, real_images, fake_images, labels):
    """Compute the WGAN-GP gradient penalty on real/fake interpolations.

    Args:
        discriminator: Critic model called as ``discriminator([images, labels])``.
        batch_size: Number of per-sample mixing coefficients to draw.
        real_images: Batch of real images.
        fake_images: Batch of generated images.
        labels: Class labels passed to the critic with the interpolated images.

    Returns:
        Scalar tensor: mean squared deviation of the per-sample gradient
        norm from 1.
    """
    # One uniform mixing coefficient per sample, broadcast over H, W, C.
    mix = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    blended = real_images + mix * (fake_images - real_images)

    with tf.GradientTape() as tape:
        # `blended` is not a variable, so it must be watched explicitly.
        tape.watch(blended)
        scores = discriminator([blended, labels], training=True)

    (input_grads,) = tape.gradient(scores, [blended])
    squared_norm = tf.reduce_sum(tf.square(input_grads), axis=[1, 2, 3])
    deviation = tf.sqrt(squared_norm) - 1.0
    return tf.reduce_mean(deviation ** 2)

def discriminator_loss(real_output, fake_output, gradient_penalty, gp_weight=20.0):
    """Return the critic loss: mean(fake) - mean(real) + weighted penalty.

    Args:
        real_output: Critic scores for real images.
        fake_output: Critic scores for generated images.
        gradient_penalty: Precomputed gradient-penalty value.
        gp_weight: Multiplier applied to the gradient penalty.

    Returns:
        Scalar loss tensor.
    """
    # Scores are clamped to [-100, 100] before averaging.
    bounded_real = tf.clip_by_value(real_output, -100.0, 100.0)
    bounded_fake = tf.clip_by_value(fake_output, -100.0, 100.0)
    wasserstein_term = tf.reduce_mean(bounded_fake) - tf.reduce_mean(bounded_real)
    return wasserstein_term + gp_weight * gradient_penalty

def generator_loss(fake_output):
    """Return the generator loss: the negated mean critic score on fakes.

    Args:
        fake_output: Critic scores for generated images.

    Returns:
        Scalar loss tensor.
    """
    # Scores are clamped to [-100, 100] before averaging.
    bounded = tf.clip_by_value(fake_output, -100.0, 100.0)
    mean_score = tf.reduce_mean(bounded)
    return -mean_score

def define_gan(generator, discriminator):
    """Chain generator and discriminator into one end-to-end model.

    Args:
        generator: Model whose inputs are ``[latent, label]``.
        discriminator: Model called as ``discriminator([image, label])``.

    Returns:
        A Keras ``Model`` mapping ``[latent, label]`` to the discriminator's
        score for the generated image.
    """
    latent_in, label_in = generator.input
    # The generator's own label input is reused as the critic's label.
    critic_score = discriminator([generator.output, label_in])
    return Model([latent_in, label_in], critic_score)

# =========================
# STEP 5: Training + Save + Plots
# =========================
def generate_real_samples(X, y, n_samples):
    """Draw ``n_samples`` random (image, label) pairs, with replacement.

    Args:
        X: Array of images indexed along axis 0.
        y: Array of labels aligned with ``X``.
        n_samples: Number of pairs to draw.

    Returns:
        Tuple ``(images, labels)`` selected with the same random indices.
    """
    picks = np.random.randint(0, X.shape[0], n_samples)
    images = X[picks]
    targets = y[picks]
    return images, targets

def generate_latent_points(latent_dim, n_samples, n_classes):
    """Sample generator inputs: Gaussian noise plus random class labels.

    Args:
        latent_dim: Length of each noise vector.
        n_samples: Number of noise vectors and labels to draw.
        n_classes: Labels are drawn uniformly from ``[0, n_classes)``.

    Returns:
        A two-element list ``[noise, labels]`` where ``noise`` has shape
        ``(n_samples, latent_dim)`` and ``labels`` has shape ``(n_samples,)``.
    """
    noise = np.random.randn(n_samples, latent_dim)
    class_ids = np.random.randint(0, n_classes, n_samples)
    return [noise, class_ids]

def generate_fake_samples(generator, latent_dim, n_samples, n_classes):
    """Generate a batch of images from freshly sampled noise and labels.

    Args:
        generator: Model called as ``generator([noise, labels])``.
        latent_dim: Length of each noise vector.
        n_samples: Number of images to generate.
        n_classes: Labels are drawn from ``[0, n_classes)``.

    Returns:
        Tuple ``(images, labels)``; ``labels`` is the int32 tensor that was
        fed to the generator.
    """
    noise, class_ids = generate_latent_points(latent_dim, n_samples, n_classes)
    noise_t = tf.convert_to_tensor(noise, dtype=tf.float32)
    class_ids_t = tf.convert_to_tensor(class_ids, dtype=tf.int32)
    # The generator runs in inference mode here.
    images = generator([noise_t, class_ids_t], training=False)
    return images, class_ids_t

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=10, batch_size=32, n_critic=3, gp_weight=20.0):
    """Train the conditional WGAN-GP, evaluating and checkpointing every epoch.

    Each batch runs ``n_critic`` discriminator updates followed by one
    generator update. After each epoch the function saves a sample grid,
    computes mean SSIM/PSNR between random real and generated images, keeps
    the best-SSIM models, writes metric and loss plots, saves per-epoch
    checkpoints and appends a line to the training log. If a
    ``wgan_last_epoch.txt`` marker exists in ``save_dir``, weights are loaded
    from the matching checkpoint and training continues from that epoch.

    Reads the module-level names ``LEARNING_RATE``, ``NUM_CLASSES``,
    ``EVAL_BATCH_SIZE`` and ``save_dir``.

    Args:
        g_model: Generator model called as ``g_model([noise, labels])``.
        d_model: Discriminator model called as ``d_model([images, labels])``.
        gan_model: Not referenced in the body; kept so existing callers work.
        dataset: Pair ``(X, y)`` of image and label arrays.
        latent_dim: Length of the generator's noise vector.
        n_epochs: Epoch index at which training stops.
        batch_size: Samples per training batch.
        n_critic: Discriminator updates per generator update.
        gp_weight: Multiplier for the gradient-penalty term.

    Raises:
        ValueError: If either model has no trainable variables, if the
            dataset holds fewer samples than ``batch_size``, or if
            ``n_critic`` is smaller than 1.
    """
    X, y = dataset
    batches_per_epoch = X.shape[0] // batch_size
    d_loss_hist, g_loss_hist = [], []
    ssim_hist = []
    psnr_hist = []

    # Without at least one batch and one critic step per batch, the loss
    # values reported after each epoch would never be assigned.
    if batches_per_epoch == 0:
        raise ValueError(
            f"Dataset has {X.shape[0]} samples, fewer than batch_size={batch_size}."
        )
    if n_critic < 1:
        raise ValueError("n_critic must be at least 1.")

    # Check trainable variables
    print(f"Discriminator trainable variables: {len(d_model.trainable_variables)}")
    print(f"Generator trainable variables: {len(g_model.trainable_variables)}")
    if len(d_model.trainable_variables) == 0 or len(g_model.trainable_variables) == 0:
        raise ValueError("No trainable variables in discriminator or generator. Check model architecture.")

    # Optimizers
    d_optimizer = RMSprop(learning_rate=LEARNING_RATE)
    g_optimizer = RMSprop(learning_rate=LEARNING_RATE)

    # Resume logic: the marker file stores the number of the last finished epoch.
    start_epoch = 0
    last_epoch_file = os.path.join(save_dir, "wgan_last_epoch.txt")
    if os.path.exists(last_epoch_file):
        with open(last_epoch_file) as f:
            start_epoch = int(f.read())
        g_model.load_weights(os.path.join(save_dir, f"wgan_generator_epoch_{start_epoch}.keras"))
        d_model.load_weights(os.path.join(save_dir, f"wgan_discriminator_epoch_{start_epoch}.keras"))
        print(f"Resumed WGAN training from epoch {start_epoch}")
    else:
        print("Starting new WGAN training")

    for epoch in range(start_epoch, n_epochs):
        batch_bar = tqdm(range(batches_per_epoch), desc=f"Epoch {epoch+1}/{n_epochs}", leave=False)
        for batch in batch_bar:
            # Train discriminator
            d_loss = 0
            for _ in range(n_critic):
                X_real, labels_real = generate_real_samples(X, y, batch_size)
                X_fake, labels_fake = generate_fake_samples(g_model, latent_dim, batch_size, NUM_CLASSES)

                X_real = tf.convert_to_tensor(X_real, dtype=tf.float32)
                labels_real = tf.convert_to_tensor(labels_real, dtype=tf.int32)
                X_fake = tf.convert_to_tensor(X_fake, dtype=tf.float32)
                labels_fake = tf.convert_to_tensor(labels_fake, dtype=tf.int32)

                with tf.GradientTape() as d_tape:
                    real_output = d_model([X_real, labels_real], training=True)
                    fake_output = d_model([X_fake, labels_fake], training=True)
                    # Computed inside the tape so the penalty contributes to
                    # the discriminator gradients.
                    gp = gradient_penalty(d_model, batch_size, X_real, X_fake, labels_real)
                    d_loss = discriminator_loss(real_output, fake_output, gp, gp_weight)

                d_grads = d_tape.gradient(d_loss, d_model.trainable_variables)
                if any(g is None for g in d_grads):
                    print("Warning: Some discriminator gradients are None")
                    print(f"Real output shape: {real_output.shape}, Fake output shape: {fake_output.shape}")
                    print(f"GP: {gp}, d_loss: {d_loss}")

                # Clip each gradient to norm 1; missing gradients become zeros
                # so apply_gradients receives one entry per variable.
                d_grads = [tf.clip_by_norm(g, 1.0) if g is not None else tf.zeros_like(v) for g, v in zip(d_grads, d_model.trainable_variables)]
                d_optimizer.apply_gradients(zip(d_grads, d_model.trainable_variables))

            # Train generator
            x_gan, labels_gan = generate_latent_points(latent_dim, batch_size, NUM_CLASSES)
            x_gan = tf.convert_to_tensor(x_gan, dtype=tf.float32)
            labels_gan = tf.convert_to_tensor(labels_gan, dtype=tf.int32)

            d_model.trainable = False
            with tf.GradientTape() as g_tape:
                fake_images = g_model([x_gan, labels_gan], training=True)
                fake_output = d_model([fake_images, labels_gan], training=True)
                g_loss = generator_loss(fake_output)

            # Only generator variables are differentiated and updated here.
            g_grads = g_tape.gradient(g_loss, g_model.trainable_variables)
            if any(g is None for g in g_grads):
                print("Warning: Some generator gradients are None")
                print(f"Fake output shape: {fake_output.shape}, g_loss: {g_loss}")

            g_grads = [tf.clip_by_norm(g, 1.0) if g is not None else tf.zeros_like(v) for g, v in zip(g_grads, g_model.trainable_variables)]
            g_optimizer.apply_gradients(zip(g_grads, g_model.trainable_variables))
            d_model.trainable = True

            batch_bar.set_postfix(d_loss=float(d_loss), g_loss=float(g_loss), gp=float(gp))

        # The histories record the losses of the epoch's last batch.
        d_loss_hist.append(float(d_loss))
        g_loss_hist.append(float(g_loss))

        display.clear_output(wait=True)
        show_generated_images(g_model, latent_dim, NUM_CLASSES)

        # Clear GPU memory before evaluation
        tf.keras.backend.clear_session()
        gc.collect()

        # SSIM + PSNR evaluation. Sampling is without replacement, so the
        # sample count is capped at the dataset size.
        num_samples = min(EVAL_BATCH_SIZE, X.shape[0])
        idx = np.random.choice(X.shape[0], num_samples, replace=False)
        real_images = X[idx]
        labels = y[idx]
        latent_points, input_labels = generate_latent_points(latent_dim, num_samples, NUM_CLASSES)
        latent_points = tf.convert_to_tensor(latent_points, dtype=tf.float32)
        input_labels = tf.convert_to_tensor(input_labels, dtype=tf.int32)
        fake_images = g_model([latent_points, input_labels], training=False)

        ssim_scores = []
        psnr_scores = []
        for real, fake in zip(real_images, fake_images):
            print(f"Real shape before squeeze: {real.shape}, Fake shape before squeeze: {fake.shape}")
            real_img = tf.squeeze(real, axis=-1).numpy()
            fake_img = tf.squeeze(fake, axis=-1).numpy()
            print(f"Real shape after squeeze: {real_img.shape}, Fake shape after squeeze: {fake_img.shape}")
            # Map [-1, 1] back to 8-bit pixel values before scoring.
            real_img = ((real_img + 1) * 127.5).astype('uint8')
            fake_img = ((fake_img + 1) * 127.5).astype('uint8')
            ssim_score = compare_ssim(real_img, fake_img, data_range=255)
            psnr_score = compare_psnr(real_img, fake_img, data_range=255)
            ssim_scores.append(ssim_score)
            psnr_scores.append(psnr_score)

        mean_ssim = np.mean(ssim_scores)
        mean_psnr = np.mean(psnr_scores)
        print(f"Epoch {epoch+1}: SSIM={mean_ssim:.4f}, PSNR={mean_psnr:.2f}")

        ssim_hist.append(mean_ssim)
        psnr_hist.append(mean_psnr)

        # Save best model based on SSIM (compared against earlier epochs of
        # this run only; the history is not restored on resume).
        if len(ssim_hist) <= 1 or mean_ssim > max(ssim_hist[:-1]):
            print(f"Saving best model with SSIM={mean_ssim:.4f}")
            tf.keras.backend.clear_session()
            gc.collect()
            g_model.save(os.path.join(save_dir, "wgan_generator_best_ssim.keras"))
            d_model.save(os.path.join(save_dir, "wgan_discriminator_best_ssim.keras"))

        # X axis for the plots: the histories begin at start_epoch + 1, so a
        # resumed run is labelled with its real epoch numbers.
        epochs_range = range(start_epoch + 1, epoch + 2)

        # First plot: SSIM and PSNR
        plt.figure(figsize=(10, 5))
        plt.subplot(2, 1, 1)
        plt.plot(epochs_range, ssim_hist, label='SSIM', color='blue')
        plt.title('SSIM and PSNR Over Epochs (Up to {})'.format(epoch + 1))
        plt.ylabel('Metric Value')
        plt.grid(True)
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(epochs_range, psnr_hist, label='PSNR (dB)', color='orange')
        plt.xlabel('Epoch')
        plt.ylabel('Metric Value')
        plt.grid(True)
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'ssim_psnr_epoch_{epoch+1}.png'))
        plt.close()

        # Second plot: Generator vs Discriminator Loss
        plt.figure(figsize=(10, 5))
        plt.plot(epochs_range, g_loss_hist, label='Generator Loss', color='green')
        plt.plot(epochs_range, d_loss_hist, label='Discriminator Loss', color='red')
        plt.title('Generator vs Discriminator Loss Over Epochs (Up to {})'.format(epoch + 1))
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'gen_disc_loss_epoch_{epoch+1}.png'))
        plt.close()

        # Save WGAN models every epoch, then update the resume marker.
        tf.keras.backend.clear_session()
        gc.collect()
        g_model.save(os.path.join(save_dir, f"wgan_generator_epoch_{epoch+1}.keras"))
        d_model.save(os.path.join(save_dir, f"wgan_discriminator_epoch_{epoch+1}.keras"))
        with open(last_epoch_file, "w") as f:
            f.write(str(epoch+1))

        # Log losses, gradient norms, and memory usage. The norms come from
        # the (already clipped) gradients of the epoch's last update.
        d_grad_norm = tf.reduce_mean([tf.norm(g) for g in d_grads if g is not None]).numpy()
        g_grad_norm = tf.reduce_mean([tf.norm(g) for g in g_grads if g is not None]).numpy()
        log_file = os.path.join(save_dir, "wgan_training_log.txt")
        with open(log_file, "a") as f:
            f.write(f"Epoch {epoch+1}: d_loss={float(d_loss):.4f}, g_loss={float(g_loss):.4f}, SSIM={mean_ssim:.4f}, PSNR={mean_psnr:.2f}, GP={float(gp):.4f}, d_grad_norm={d_grad_norm:.4f}, g_grad_norm={g_grad_norm:.4f}\n")
        os.system(f'df -h /content >> {log_file}')
        os.system(f'nvidia-smi >> {log_file}')

def show_generated_images(generator, latent_dim, num_classes):
    """Save a one-row grid with one generated image per class.

    The figure is written to ``generated_images.png`` inside the
    module-level ``save_dir`` and closed afterwards. Panel titles come from
    the module-level ``num_to_class`` mapping.

    Args:
        generator: Model called as ``generator([noise, label])``.
        latent_dim: Length of the noise vector.
        num_classes: Number of classes, i.e. number of panels drawn.
    """
    # squeeze=False keeps `axs` two-dimensional even when num_classes == 1;
    # otherwise a bare Axes is returned and indexing it fails.
    fig, axs = plt.subplots(1, num_classes, figsize=(15, 15), squeeze=False)
    for i in range(num_classes):
        noise = tf.convert_to_tensor(np.random.randn(1, latent_dim), dtype=tf.float32)
        label = tf.convert_to_tensor(np.array([i]), dtype=tf.int32)
        generated = generator([noise, label], training=False)
        panel = axs[0, i]
        # Map the generator's [-1, 1] output back to the [0, 255] range.
        panel.imshow(generated[0, :, :, 0] * 127.5 + 127.5, cmap='gray')
        panel.axis('off')
        panel.set_title(num_to_class[i])
    # Save and close this specific figure rather than the "current" one.
    fig.savefig(os.path.join(save_dir, 'generated_images.png'))
    plt.close(fig)

# =========================
# STEP 6: Train WGAN-GP
# =========================
# Build the critic, the generator and the chained model.
discriminator = define_discriminator(
    in_shape=INPUT_SHAPE,
    n_classes=NUM_CLASSES,
    lr=LEARNING_RATE,
)
generator = define_generator(latent_dim=LATENT_DIM, n_classes=NUM_CLASSES)
gan = define_gan(generator=generator, discriminator=discriminator)

# Verify model trainability
print("Discriminator trainable variables:", [v.name for v in discriminator.trainable_variables])
print("Generator trainable variables:", [v.name for v in generator.trainable_variables])

# Launch training on the training split with the configured parameters.
train(
    g_model=generator,
    d_model=discriminator,
    gan_model=gan,
    dataset=[X_train, y_train],
    latent_dim=LATENT_DIM,
    n_epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    n_critic=N_CRITIC,
    gp_weight=GRADIENT_PENALTY_WEIGHT,
)

Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape before squeeze: (224, 224, 1)
Real shape after squeeze: (224, 224), Fake shape after squeeze: (224, 224)
Real shape before squeeze: (224, 224, 1), Fake shape

Epoch 456/800:   0%|          | 0/195 [00:00<?, ?it/s]