In [3]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Channel-width multipliers (relative to the generator's `in_channels`) for
# each resolution stage: four full-width stages, then the width halves at
# every further stage (1/2, 1/4, ..., 1/32).
factors = [1] * 4 + [1 / 2 ** k for k in range(1, 6)]

class WSConv2d(nn.Module):
    """
    Conv2d with equalized learning rate (weight scaling).

    Instead of rescaling the stored weights, the *input* is multiplied by
    ``sqrt(gain / fan_in)`` with ``fan_in = in_channels * kernel_size ** 2``.
    Because convolution is linear this gives the same result as scaling the
    weights. The bias is detached from the wrapped convolution and added
    afterwards, so it is not affected by the scale.

    Inspired and looked at:
    https://github.com/nvnbny/progressive_growing_of_gans/blob/master/modelUtils.py
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2
    ):
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias out of the conv so it is added after scaling. The
        # attribute names `conv` and `bias` define the state_dict keys, so
        # they must stay stable for existing checkpoints.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Unit-variance weights; the runtime `scale` supplies the fan-in factor.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, -1, 1, 1)
        return self.conv(scaled_input) + channel_bias


class PixelNorm(nn.Module):
    """Scale the feature vector at every spatial position to unit RMS over the channel dim (dim=1)."""

    def __init__(self):
        super(PixelNorm, self).__init__()
        # Keeps the square root (and the division) away from zero.
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()


class ConvBlock(nn.Module):
    """
    Two WSConv2d layers (default 3x3, stride 1, padding 1), each followed by
    LeakyReLU(0.2) and, when ``use_pixelnorm`` is true, PixelNorm.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        # Registration order/names define state_dict keys; keep them stable.
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x


class Generator(nn.Module):
    """
    Progressive-growing generator.

    ``initial`` maps a 1x1 latent (``z_dim`` channels) to a 4x4 feature map
    with ``in_channels`` channels. Each of the first ``steps`` entries of
    ``prog_blocks`` doubles the resolution (nearest-neighbour upsampling)
    and then applies a ConvBlock whose widths come from the module-level
    ``factors`` list. ``rgb_layers[k]`` projects the output of stage ``k``
    to ``img_channels`` channels; ``rgb_layers[0]`` is ``initial_rgb``.

    Attribute names and module registration order determine the state_dict
    keys and the parameter order an optimizer checkpoint relies on; keep
    them unchanged.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()

        # 1x1 latent -> 4x4 feature map.
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        # prog_blocks must be registered before rgb_layers (parameter order).
        self.prog_blocks = nn.ModuleList()
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # Each consecutive pair of factors gives one block's input/output width.
        for factor_in, factor_out in zip(factors[:-1], factors[1:]):
            block_in = int(in_channels * factor_in)
            block_out = int(in_channels * factor_out)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.rgb_layers.append(
                WSConv2d(block_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """
        Blend ``generated`` over ``upscaled`` by ``alpha`` and squash with tanh.

        ``alpha`` should be a scalar in [0, 1]; both tensors must have the same shape.
        """
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        out = self.initial(x)

        # NOTE(review): this path skips the tanh that fade_in applies for
        # steps > 0; changing it would alter outputs of existing checkpoints.
        if steps == 0:
            return self.initial_rgb(out)

        for stage in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[stage](upscaled)

        # `upscaled` still carries the previous stage's channel count, so it
        # needs that stage's RGB layer (steps - 1); `out` uses the current one.
        rgb_upscaled = self.rgb_layers[steps - 1](upscaled)
        rgb_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, rgb_upscaled, rgb_out)
    
def load_checkpoint(checkpoint_file, model, optimizer, lr, map_location=None):
    """
    Restore model and optimizer state from a checkpoint and reset the learning rate.

    The checkpoint must be a dict with a "state_dict" entry (model weights)
    and an "optimizer" entry (optimizer state).

    Args:
        checkpoint_file: path to a file written with ``torch.save``.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate written into every param group after loading.
        map_location: device to map stored tensors onto. Defaults to the
            device of the model's first parameter (or CUDA if available when
            the model has no parameters), so checkpoints saved on a GPU also
            load on CPU-only machines.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
    """
    print("=> Loading checkpoint")
    if map_location is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device
        else:
            map_location = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading optimizer state also restores the learning rate that was active
    # when the checkpoint was saved; override it with the requested value.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [None]:
def denormalize(image, min_val=None, max_val=None):
    """
    Linearly map values from [0, 1] back to [min_val, max_val].

    This inverts a min-max normalization: ``image * (max_val - min_val) + min_val``.

    Args:
        image: scalar or array of normalized values.
        min_val: value that 0 maps to. Defaults to the notebook-level
            ``MIN_HU_VALUE`` constant (presumably the lower Hounsfield-unit
            bound used for normalization; it must be defined before calling
            without arguments).
        max_val: value that 1 maps to. Defaults to ``MAX_HU_VALUE``.

    Returns:
        The rescaled value(s), same shape as ``image``.
    """
    # Resolve the globals lazily so explicit bounds work without them.
    if min_val is None:
        min_val = MIN_HU_VALUE
    if max_val is None:
        max_val = MAX_HU_VALUE
    return image * (max_val - min_val) + min_val

In [4]:
def generate_images(
    gen_checkpoint,
    n,
    hist_type,
    path,
    alpha=1.0,
    step=4,
    z_dim=512,
    device="cuda" if torch.cuda.is_available() else "cpu",
    learning_rate=1e-3,
    in_channels=512,
    channels_img=1,
):
    """
    Sample ``n`` images from a trained generator and save them as DICOM files.

    Args:
        gen_checkpoint: path to the generator checkpoint (read by load_checkpoint).
        n: number of images to generate.
        hist_type: tag used in the file names ``GAN_{hist_type}_{i:04d}.dcm``.
        path: output directory; created if it does not exist.
        alpha: fade-in factor passed to the generator (1.0 = newest stage only).
        step: number of progressive blocks to run; output side length is
            ``4 * 2 ** step`` pixels.
        z_dim, in_channels, channels_img: generator architecture; must match
            the checkpoint.
        device: device the generator runs on.
        learning_rate: only used to build the optimizer that load_checkpoint
            requires; the optimizer is never stepped.
    """
    gen = Generator(z_dim, in_channels, img_channels=channels_img).to(device)
    opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.0, 0.99))
    load_checkpoint(gen_checkpoint, gen, opt_gen, learning_rate)
    # Sampling, not training. The Generator has no dropout/batch-norm, so this
    # does not change outputs, but it is the correct mode for inference.
    gen.eval()
    os.makedirs(path, exist_ok=True)
    with torch.no_grad():
        for i in range(n):
            noise = torch.randn(1, z_dim, 1, 1).to(device)
            fake_image = gen(noise, alpha, step)
            # (1, C, H, W) tensor -> first channel as a 2-D numpy array.
            fake_image = fake_image.cpu().numpy()[0, 0]
            # tanh range [-1, 1] -> [0, 1], then undo intensity normalization.
            fake_image = fake_image * 0.5 + 0.5
            fake_image = np.around(denormalize(fake_image))
            image_file = f'GAN_{hist_type}_{str(i).zfill(4)}.dcm'
            image_path = os.path.join(path, image_file)
            write_dicom(fake_image, image_path)

In [None]:
# Delete images
def delete_gan_images(directories=None, prefix="GAN"):
    """
    Delete generated images (regular files whose name starts with ``prefix``).

    Args:
        directories: iterable of directory paths to clean. Defaults to the
            VGG-GAN train/val image folders. Duplicate entries are visited once.
        prefix: file-name prefix identifying generated files (generate_images
            writes ``GAN_...`` files).

    Raises:
        FileNotFoundError: if one of the directories does not exist.
    """
    if directories is None:
        directories = [
            "/Storage/PauloOctavioDir/splitted_folders/VGG-GAN/train/images/",
            "/Storage/PauloOctavioDir/splitted_folders/VGG-GAN/val/images/",
        ]

    # dict.fromkeys drops duplicate paths while keeping their order.
    for directory in dict.fromkeys(directories):
        for file_name in os.listdir(directory):
            file_path = os.path.join(directory, file_name)
            # Skip sub-directories: os.remove cannot delete them.
            if file_name.startswith(prefix) and os.path.isfile(file_path):
                os.remove(file_path)

In [None]:
import pydicom
from pydicom.dataset import Dataset, FileDataset
from pydicom.uid import ExplicitVRLittleEndian
import pydicom._storage_sopclass_uids

def write_dicom(image, filename, rescale_intercept="0", rescale_slope="1", pixel_spacing=r"1\1"):
    """
    Save a 2-D image as a single-slice CT DICOM file (Explicit VR Little Endian).

    Pixel data is written as signed 16-bit integers, matching the
    ``PixelRepresentation = 1`` (two's complement) declared in the header,
    so negative pixel values are stored correctly.

    Args:
        image: 2-D array-like of pixel values. Float input is clipped to the
            int16 range and truncated toward zero; integer input is cast to
            int16 (values outside that range wrap around).
        filename: destination path.
        rescale_intercept: RescaleIntercept value (decimal string).
        rescale_slope: RescaleSlope value (decimal string).
        pixel_spacing: PixelSpacing value (row and column spacing separated
            by a backslash).

    Raises:
        ValueError: if ``image`` is not 2-D.
    """
    # ref: https://stackoverflow.com/questions/14350675/create-pydicom-file-from-numpy-array
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"write_dicom expects a 2-D image, got shape {image.shape}")
    if image.dtype != np.int16:
        if np.issubdtype(image.dtype, np.floating):
            # Casting out-of-range floats to an integer type is undefined in
            # NumPy (platform dependent), so clip to the int16 range first.
            int16_info = np.iinfo(np.int16)
            image = np.clip(image, int16_info.min, int16_info.max)
        image = image.astype(np.int16)

    # NOTE: pydicom._storage_sopclass_uids is a private module that newer
    # pydicom versions drop in favour of the same UIDs in pydicom.uid.
    meta = pydicom.dataset.FileMetaDataset()
    meta.MediaStorageSOPClassUID = pydicom._storage_sopclass_uids.CTImageStorage
    meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
    meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

    ds = Dataset()
    ds.file_meta = meta

    ds.is_little_endian = True
    ds.is_implicit_VR = False

    ds.SOPClassUID = pydicom._storage_sopclass_uids.CTImageStorage
    # The dataset's SOP Instance UID must match the one in the file meta header.
    ds.SOPInstanceUID = meta.MediaStorageSOPInstanceUID

    ds.Modality = "CT"
    ds.SeriesInstanceUID = pydicom.uid.generate_uid()
    ds.StudyInstanceUID = pydicom.uid.generate_uid()
    ds.FrameOfReferenceUID = pydicom.uid.generate_uid()

    ds.BitsStored = 16
    ds.BitsAllocated = 16
    ds.SamplesPerPixel = 1
    ds.HighBit = 15

    ds.ImagesInAcquisition = "1"

    ds.Rows = image.shape[0]
    ds.Columns = image.shape[1]
    ds.InstanceNumber = 1

    ds.RescaleIntercept = rescale_intercept
    ds.RescaleSlope = rescale_slope
    ds.PixelSpacing = pixel_spacing
    ds.PhotometricInterpretation = "MONOCHROME2"
    # 1 = signed two's complement; the pixel array above is int16 to match.
    ds.PixelRepresentation = 1

    pydicom.dataset.validate_file_meta(ds.file_meta, enforce_standard=True)

    ds.PixelData = image.tobytes()
    ds.save_as(filename, write_like_original=False)