<a href="https://colab.research.google.com/github/uqmshawn/uqmshawn-4-7-1-8-4-5-3-0-r/blob/main/47184530.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import zipfile
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from google.colab import drive
import zipfile

In [6]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mount Google Drive
drive.mount('/content/drive')

class BrainSlicesDataset(Dataset):
    def __init__(self, image_slices):
        self.image_slices = image_slices

    def __len__(self):
        return len(self.image_slices)

    def __getitem__(self, idx):
        image = self.image_slices[idx]

        # Ensure the image has a channel dimension
        if len(image.shape) == 2:  # If the image is of shape [H, W]
            image = torch.unsqueeze(image, 0)  # Convert it to [1, H, W]

        return image

def get_image_slices():
    zip_path = "/content/drive/MyDrive/Colab_Notebooks_Course/image_process/A3/testgans/GAN_Dataset.zip"
    extraction_path = "/content/GAN_Dataset"
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extraction_path)

    parent_dir = "/content/GAN_Dataset"
    train_path = os.path.join(parent_dir, "keras_png_slices_train")
    test_path = os.path.join(parent_dir, "keras_png_slices_test")
    val_path = os.path.join(parent_dir, "keras_png_slices_validate")

    def load_images_from_folder(folder_path):
            images = []
            for filename in os.listdir(folder_path):
                img = Image.open(os.path.join(folder_path, filename)).convert('L').resize((128, 128))
                if img is not None:
                    images.append(torch.tensor(np.array(img, dtype=np.float32)))
            return torch.stack(images)

    train_images = load_images_from_folder(train_path)
    test_images = load_images_from_folder(test_path)
    validate_images = load_images_from_folder(val_path)

    # Print statements to understand the data
    print(f"Total train images: {len(train_images)}")
    print(f"Shape of a single train image: {train_images[0].shape}")
    print(f"Total test images: {len(test_images)}")
    print(f"Shape of a single test image: {test_images[0].shape}")
    print(f"Total validation images: {len(validate_images)}")
    print(f"Shape of a single validation image: {validate_images[0].shape}")

    return train_images, test_images, validate_images

# Call the function to see the print outputs
get_image_slices()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Total train images: 9664
Shape of a single train image: torch.Size([128, 128])
Total test images: 544
Shape of a single test image: torch.Size([128, 128])
Total validation images: 1120
Shape of a single validation image: torch.Size([128, 128])


(tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 

In [None]:
# Model Definitions: VectorQuantizer, Encoder, Decoder, VQVAE, PixelCNN, etc.

# VectorQuantizer Layer
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        # Learnable codebook embeddings
        self.embeddings = nn.Parameter(torch.randn(embedding_dim, num_embeddings))

    def forward(self, x):
        # Reshape input tensor to [batch_size, num_channels, height, width]
        z_e_x = x.permute(0, 2, 3, 1).contiguous()

        # Flatten the input tensor to [batch_size * height * width, embedding_dim]
        z_e_x_ = z_e_x.view(-1, self.embedding_dim)

        # Calculate distances between input vectors and codebook vectors
        distances = (torch.sum(z_e_x_**2, dim=1, keepdim=True)
                    + torch.sum(self.embeddings**2, dim=0)
                    - 2 * torch.matmul(z_e_x_, self.embeddings))

        # Find the index of the nearest codebook vector for each input vector
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # Create a one-hot encoding based on the indices
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings).to(x.device)
        encodings.scatter_(1, encoding_indices, 1)

         # Reshape encoding_indices to match the original input shape
        encoding_indices = encoding_indices.view(*z_e_x.shape[:-1])

        # Compute the quantized representation using the codebook
        quantized = torch.matmul(encodings, self.embeddings.t()).view(*z_e_x.shape)

        # Compute loss components
        e_latent_loss = F.mse_loss(quantized.detach(), z_e_x)
        q_latent_loss = F.mse_loss(quantized, z_e_x.detach())

        # Calculate the total loss as a combination of the two losses
        loss = q_latent_loss + self.beta * e_latent_loss

        # Ensure that quantized is used in the gradient computation
        quantized = z_e_x + (quantized - z_e_x).detach()

        # Compute the perplexity of the encoding distribution
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # Return the loss, quantized tensor, perplexity, and encoding indices
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices

# Encoder
class Encoder(nn.Module):
    def __init__(self, input_channels, hidden_channels, embedding_dim):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels // 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels // 2, embedding_dim, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.encoder(x)
# Decoder
class Decoder(nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(input_channels, hidden_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels, hidden_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels // 2, 1, kernel_size=3, padding=1)
        )

    def forward(self, x):
        return self.decoder(x)
# VQVAE
# VQVAETrainer
# PixelConvLayer & PixelCNN

In [None]:
# Training Functions: train_vqvae, train_pixelcnn, etc.

# train_vqvae
# train_pixelcn

In [None]:
import matplotlib.pyplot as plt

def visualize_reconstructions(originals, reconstructions, num_samples=3):
    # ... [Code for visualizing reconstructions]

def visualize_samples(samples, num_samples=3):
    # ... [Code for visualizing samples]

In [None]:
def main():
    # ... [Main function code]

if __name__ == "__main__":
    main()