In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D, concatenate, Activation, BatchNormalization, LeakyReLU, Dense, Reshape, Flatten
from tensorflow.keras.models import Model
import cv2
import matplotlib.pyplot as plt
import pathlib

In [None]:
# Setup environment
# Both variables are written in one step. CUDA_VISIBLE_DEVICES restricts which
# GPU indices the process may see; TF_GPU_ALLOCATOR picks TensorFlow's GPU
# memory allocator.
# NOTE(review): the GPU index "1" is machine specific — confirm it exists on
# the host that runs this notebook.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "1",
    'TF_GPU_ALLOCATOR': 'cuda_malloc_async',
})

In [None]:
class ReadDataset:
    """Load grayscale images and their segmentation masks from class folders.

    Every entry of ``labels`` names a sub-directory of ``datasetpath``. Each
    ``*.jpg`` file found there is paired with a mask stored next to it under
    the name ``<stem>_mask.png``.

    Args:
        datasetpath: Root directory that contains one sub-directory per label.
        labels: Sequence of sub-directory names. The class id of an image is
            the position of its directory name in this sequence.
        image_shape: Resize target. Only the first two entries are used; they
            are handed to ``cv2.resize`` as ``dsize`` (which OpenCV reads as
            ``(width, height)``). Extra trailing entries such as a channel
            count are accepted and ignored.
    """

    def __init__(self, datasetpath, labels, image_shape=(64, 64)):
        self.datasetpath = datasetpath
        self.labels = labels
        self.image_shape = image_shape

    def _target_size(self):
        """Return the two-element ``dsize`` tuple used for every resize.

        Raises:
            ValueError: If ``image_shape`` has fewer than two entries.
        """
        if len(self.image_shape) < 2:
            raise ValueError(
                f"image_shape needs at least two entries, got {self.image_shape!r}"
            )
        # cv2.resize rejects anything but a 2-tuple, so drop a trailing
        # channel entry instead of failing deep inside OpenCV.
        first, second = self.image_shape[:2]
        return (int(first), int(second))

    def readImages(self):
        """Read, resize and normalise every image/mask pair.

        Returns:
            A tuple ``(images, masks, class_ids)`` of NumPy arrays. Images are
            scaled by ``1 / 255``; each mask is divided by its own maximum
            (all-zero masks stay zero). Arrays are empty when no ``*.jpg``
            file is found.

        Raises:
            ValueError: If an image file cannot be decoded.
            FileNotFoundError: If the mask belonging to an image is missing
                or cannot be decoded.
        """
        dsize = self._target_size()
        images = []
        masks = []
        class_ids = []

        for class_id, label in enumerate(self.labels):
            label_dir = pathlib.Path(self.datasetpath) / label
            # Sorted so the sample order does not depend on the filesystem.
            for img_path in sorted(label_dir.glob('*.jpg')):
                img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # cv2.imread signals failure with None instead of raising.
                    raise ValueError(f"Could not read image: {img_path}")

                # Build the mask name from the file stem only; a plain string
                # replace would also rewrite '.jpg' inside directory names.
                mask_path = img_path.with_name(img_path.stem + '_mask.png')
                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    raise FileNotFoundError(
                        f"Missing or unreadable mask for {img_path}: {mask_path}"
                    )

                img = cv2.resize(img, dsize) / 255.0

                mask = cv2.resize(mask, dsize).astype(np.float64)
                peak = mask.max()
                if peak > 0:
                    mask = mask / peak

                images.append(img)
                masks.append(mask)
                class_ids.append(class_id)

        return np.array(images), np.array(masks), np.array(class_ids)

In [None]:
class ACGAN:
    """Factory for an auxiliary-classifier GAN that emits image/mask pairs.

    The generator maps ``(noise, one-hot label)`` to an image and a
    segmentation map; the discriminator takes such a pair and predicts
    real/fake plus the class.

    Args:
        image_shape: ``(height, width, channels)`` of the images.
        segmentation_shape: ``(height, width, channels)`` of the masks. Height
            and width must equal those of ``image_shape`` because both
            generator outputs share one trunk and the discriminator
            concatenates the two inputs channel-wise.
        num_classes: Number of class labels.
        latent_dim: Length of the generator's noise vector.
    """

    # One stride-2 transposed convolution per entry; each doubles height/width.
    _GENERATOR_FILTERS = (256, 128, 64, 32)
    # One stride-2 convolution per entry; each halves height/width.
    _DISCRIMINATOR_FILTERS = (32, 64, 128, 256)

    def __init__(self, image_shape, segmentation_shape, num_classes, latent_dim=100):
        self.image_shape = image_shape
        self.segmentation_shape = segmentation_shape
        self.num_classes = num_classes
        self.latent_dim = latent_dim

    def _generator_base_size(self):
        """Return the ``(height, width)`` of the generator's seed feature map.

        The seed is upsampled by ``2 ** len(_GENERATOR_FILTERS)``, so the
        target size must be a positive multiple of that factor.

        Raises:
            ValueError: If a shape is not ``(H, W, C)``, the spatial sizes of
                image and segmentation differ, or the size is not reachable
                by the upsampling stack.
        """
        if len(self.image_shape) != 3 or len(self.segmentation_shape) != 3:
            raise ValueError(
                "image_shape and segmentation_shape must be (height, width, channels); "
                f"got {self.image_shape!r} and {self.segmentation_shape!r}"
            )
        if tuple(self.image_shape[:2]) != tuple(self.segmentation_shape[:2]):
            raise ValueError(
                "image_shape and segmentation_shape must share height and width; "
                f"got {self.image_shape!r} and {self.segmentation_shape!r}"
            )
        factor = 2 ** len(self._GENERATOR_FILTERS)
        height, width = self.image_shape[0], self.image_shape[1]
        if height <= 0 or width <= 0 or height % factor or width % factor:
            raise ValueError(
                f"image height and width must be positive multiples of {factor}; "
                f"got {self.image_shape!r}"
            )
        return height // factor, width // factor

    def build_unet_generator(self):
        """Build the generator ``[noise, labels] -> [img_output, seg_output]``.

        Despite the name there are no skip connections: a dense seed is
        upsampled by a stack of transposed convolutions and then split into
        two sigmoid heads. The seed size is derived from ``image_shape`` so
        the outputs always match what the discriminator expects.

        Returns:
            A Keras ``Model`` named ``'generator'``.
        """
        base_h, base_w = self._generator_base_size()

        noise = Input(shape=(self.latent_dim,))
        labels = Input(shape=(self.num_classes,))
        x = layers.concatenate([noise, labels])

        # Base network that learns common features
        x = Dense(base_h * base_w * 256)(x)
        x = Reshape((base_h, base_w, 256))(x)
        for filters in self._GENERATOR_FILTERS:
            x = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)

        # Branch off into two separate paths: one for the image and one for the segmentation map
        img_output = Conv2DTranspose(self.image_shape[-1], (3, 3), activation='sigmoid',
                                     padding='same', name='img_output')(x)
        seg_output = Conv2DTranspose(self.segmentation_shape[-1], (3, 3), activation='sigmoid',
                                     padding='same', name='seg_output')(x)

        return Model([noise, labels], [img_output, seg_output], name='generator')

    def build_discriminator(self):
        """Build the discriminator ``[image, mask] -> [validity, label]``.

        Image and mask are concatenated along the channel axis, downsampled
        by strided convolutions and flattened into two heads: a sigmoid
        real/fake score and a softmax over ``num_classes``.

        Returns:
            A Keras ``Model`` named ``'discriminator'``.
        """
        image_input = Input(shape=self.image_shape, name='img_input')
        segmentation_input = Input(shape=self.segmentation_shape, name='seg_input')

        x = concatenate([image_input, segmentation_input], axis=-1)
        for filters in self._DISCRIMINATOR_FILTERS:
            x = Conv2D(filters, (3, 3), strides=(2, 2), padding='same')(x)
            x = LeakyReLU(alpha=0.2)(x)
            x = BatchNormalization()(x)

        x = Flatten()(x)
        validity = Dense(1, activation='sigmoid')(x)
        label = Dense(self.num_classes, activation='softmax')(x)

        return Model([image_input, segmentation_input], [validity, label], name='discriminator')

    def compile_acgan(self, generator, discriminator):
        """Compile the discriminator and build the stacked generator trainer.

        Side effects: ``discriminator`` is compiled in place and left with
        ``trainable = False``.

        Args:
            generator: Model returned by ``build_unet_generator``.
            discriminator: Model returned by ``build_discriminator``.

        Returns:
            The compiled combined model ``[noise, labels] -> [validity, label]``.
        """
        optimizer_gen = optimizers.Adam(0.0002, 0.5)
        optimizer_disc = optimizers.Adam(0.0002, 0.5)

        discriminator.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'],
                              optimizer=optimizer_disc,
                              metrics=['accuracy'])

        # Freeze the discriminator inside the stacked model so generator
        # updates leave its weights alone.
        # NOTE(review): whether the already-compiled discriminator keeps
        # training after this flag flips depends on the Keras version — confirm.
        discriminator.trainable = False
        noise = Input(shape=(self.latent_dim,))
        labels = Input(shape=(self.num_classes,))
        img, seg = generator([noise, labels])
        valid, label = discriminator([img, seg])
        combined = Model([noise, labels], [valid, label])
        combined.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'],
                         optimizer=optimizer_gen)

        return combined

In [None]:
def _to_plane(arr):
    """Return ``arr`` as float32 with singleton axes (e.g. a channel of 1) removed."""
    return np.squeeze(np.asarray(arr).astype(np.float32))


def _psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB; ``inf`` when the inputs are identical."""
    mse = np.mean((reference - estimate) ** 2, dtype=np.float64)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10((data_range ** 2) / mse))


def _ssim(reference, estimate, data_range=1.0, window=7):
    """Mean structural similarity over all uniform ``window`` x ``window`` patches.

    Uses the usual constants ``C1 = (0.01 * R) ** 2`` and ``C2 = (0.03 * R) ** 2``.
    The window shrinks automatically for images smaller than ``window``.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    ref = reference.astype(np.float64)
    est = estimate.astype(np.float64)
    window = min(window, *ref.shape)

    ref_win = sliding_window_view(ref, (window, window))
    est_win = sliding_window_view(est, (window, window))

    mu_ref = ref_win.mean(axis=(-2, -1))
    mu_est = est_win.mean(axis=(-2, -1))
    var_ref = (ref_win * ref_win).mean(axis=(-2, -1)) - mu_ref * mu_ref
    var_est = (est_win * est_win).mean(axis=(-2, -1)) - mu_est * mu_est
    cov = (ref_win * est_win).mean(axis=(-2, -1)) - mu_ref * mu_est

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    score = ((2 * mu_ref * mu_est + c1) * (2 * cov + c2)) / (
        (mu_ref * mu_ref + mu_est * mu_est + c1) * (var_ref + var_est + c2)
    )
    return float(score.mean())


def _dice_coefficient(reference, estimate, threshold=0.5):
    """Dice overlap of the two masks after thresholding; 1.0 when both are empty."""
    ref = reference > threshold
    est = estimate > threshold
    total = int(ref.sum()) + int(est.sum())
    if total == 0:
        return 1.0
    return float(2.0 * np.logical_and(ref, est).sum() / total)


def _iou(reference, estimate, threshold=0.5):
    """Intersection over union of the thresholded masks; 1.0 when both are empty."""
    ref = reference > threshold
    est = estimate > threshold
    union = int(np.logical_or(ref, est).sum())
    if union == 0:
        return 1.0
    return float(np.logical_and(ref, est).sum() / union)


def calculate_metrics(real_img, generated_img, real_seg, generated_seg):
    """Compare one generated image/mask pair with a real one.

    The metric helpers are implemented here with NumPy because the notebook
    neither imports nor defines ``psnr``, ``ssim``, ``dice_coefficient`` or
    ``iou``. Values are assumed to lie in ``[0, 1]`` (the dataset loader
    normalises to that range and the generator ends in sigmoids), so the data
    range is 1 and masks are binarised at 0.5.

    Args:
        real_img: Reference image, ``(H, W)`` or ``(H, W, 1)``.
        generated_img: Generated image with the same spatial size.
        real_seg: Reference mask, ``(H, W)`` or ``(H, W, 1)``.
        generated_seg: Generated mask with the same spatial size.

    Returns:
        Tuple ``(psnr, ssim, dice, iou)`` of Python floats.

    Raises:
        ValueError: If an input is not single-channel 2-D after removing
            singleton axes, or the paired inputs differ in shape.
    """
    # Drop singleton axes first: mixing (H, W) with (H, W, 1) would otherwise
    # broadcast to (H, W, W) and silently produce wrong numbers.
    real_img = _to_plane(real_img)
    generated_img = _to_plane(generated_img)
    real_seg = _to_plane(real_seg)
    generated_seg = _to_plane(generated_seg)

    for name, ref, est in (("image", real_img, generated_img),
                           ("segmentation", real_seg, generated_seg)):
        if ref.ndim != 2 or est.ndim != 2:
            raise ValueError(f"{name} inputs must be single-channel 2-D arrays")
        if ref.shape != est.shape:
            raise ValueError(f"{name} shapes differ: {ref.shape} vs {est.shape}")

    psnr_val = _psnr(real_img, generated_img)
    ssim_val = _ssim(real_img, generated_img)

    dice_val = _dice_coefficient(real_seg, generated_seg)
    iou_val = _iou(real_seg, generated_seg)

    return psnr_val, ssim_val, dice_val, iou_val

In [None]:
def train_acgan(generator, discriminator, combined, images, segmentations, labels,
                epochs=10000, batch_size=32, log_interval=100):
    """Alternate discriminator and generator updates on random mini-batches.

    Each iteration samples ``batch_size`` real pairs (with replacement),
    generates fake pairs conditioned on the same labels, trains the
    discriminator on both, then trains the generator through ``combined``.

    Args:
        generator: Model with inputs ``[noise, one_hot_labels]`` exposing
            ``input_shape`` and ``predict``.
        discriminator: Compiled model exposing ``train_on_batch``.
        combined: Compiled stacked model exposing ``train_on_batch``.
        images: Real images, ``(N, H, W)`` or ``(N, H, W, C)``.
        segmentations: Real masks, same layout as ``images``.
        labels: Integer class ids, length ``N``.
        epochs: Number of training iterations (one mini-batch each).
        batch_size: Samples per mini-batch.
        log_interval: Print losses and metrics every this many iterations;
            ``0`` or ``None`` disables logging.

    Raises:
        ValueError: If the dataset is empty or the arrays disagree in length.
    """
    images = np.asarray(images)
    segmentations = np.asarray(segmentations)
    labels = np.asarray(labels)

    n_samples = images.shape[0]
    if n_samples == 0:
        raise ValueError("train_acgan received an empty dataset")
    if segmentations.shape[0] != n_samples or labels.shape[0] != n_samples:
        raise ValueError(
            "images, segmentations and labels must have the same length; got "
            f"{n_samples}, {segmentations.shape[0]} and {labels.shape[0]}"
        )

    # Grayscale arrays without a channel axis are given one so they line up
    # with 4-D (batch, H, W, C) model inputs and with the generator's outputs.
    if images.ndim == 3:
        images = images[..., np.newaxis]
    if segmentations.ndim == 3:
        segmentations = segmentations[..., np.newaxis]

    # Both sizes are read from the generator so nothing is hard-coded here.
    latent_dim = generator.input_shape[0][-1]
    num_classes = generator.input_shape[1][-1]
    one_hot_table = np.eye(num_classes, dtype=np.float32)

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        idx = np.random.randint(0, n_samples, batch_size)
        real_imgs, real_segs, real_labels = images[idx], segmentations[idx], labels[idx]
        # Row lookup in an identity matrix is a one-hot encoding.
        label_onehot = one_hot_table[real_labels.astype(int).ravel()]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0 keeps predict from printing a progress bar every iteration.
        gen_imgs, gen_segs = generator.predict([noise, label_onehot], verbose=0)

        # Train discriminator
        d_loss_real = discriminator.train_on_batch([real_imgs, real_segs], [valid, real_labels])
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_segs], [fake, real_labels])
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train generator
        g_loss = combined.train_on_batch([noise, label_onehot], [valid, real_labels])

        if log_interval and epoch % log_interval == 0:
            # NOTE(review): index 3 is assumed to be the validity-head accuracy
            # in the list returned by train_on_batch — confirm for the Keras
            # version in use.
            print(f"Epoch {epoch}/{epochs}: [D loss: {d_loss[0]} - Acc.: {d_loss[3]}], [G loss: {g_loss}]")
            psnr_val, ssim_val, dice_val, iou_val = calculate_metrics(real_imgs[0], gen_imgs[0], real_segs[0], gen_segs[0])
            print(f"Metrics - PSNR: {psnr_val}, SSIM: {ssim_val}, Dice: {dice_val}, IoU: {iou_val}")


In [None]:
# Model configuration and construction. No dataset is needed at this point:
# the networks are defined purely by the shapes and the class count below.
num_classes = 4
image_shape = (64, 64, 1)
segmentation_shape = (64, 64, 1)

acgan = ACGAN(
    image_shape=image_shape,
    segmentation_shape=segmentation_shape,
    num_classes=num_classes,
)
generator = acgan.build_unet_generator()
discriminator = acgan.build_discriminator()
# The combined model stacks generator and discriminator and is used to train
# the generator.
combined = acgan.compile_acgan(generator=generator, discriminator=discriminator)

In [None]:
# Data loading and training.
# The dataset is read at the same resolution the models were built for. Only
# the spatial part of `image_shape` is passed on: the loader resizes with
# OpenCV, whose target size has two entries (a channel count is not a resize
# dimension).
class_names = ['healthy', 'glioma', 'meningioma', 'pituitary']
assert len(class_names) == num_classes, "class_names must match num_classes"

images, segmentations, labels = ReadDataset('tumor_dataset/',
                                            class_names,
                                            image_shape[:2]).readImages()

# The loader reads grayscale files, and the discriminator takes (H, W, 1)
# inputs, so add an explicit channel axis. The ndim guard keeps this cell safe
# to re-run.
if images.ndim == 3:
    images = images[..., np.newaxis]
if segmentations.ndim == 3:
    segmentations = segmentations[..., np.newaxis]

# Fail early with a readable message instead of deep inside Keras.
assert images.shape[1:] == tuple(image_shape), f"unexpected image shape {images.shape}"
assert segmentations.shape[1:] == tuple(segmentation_shape), f"unexpected mask shape {segmentations.shape}"
assert len(images) == len(segmentations) == len(labels) > 0, "dataset is empty or misaligned"

train_acgan(generator, discriminator, combined, images, segmentations, labels, epochs=50000, batch_size=32)