# DataLoader and Dataset Classes

In [1]:
import torch

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print device info
print("Using device:", device)
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Memory allocated:", torch.cuda.memory_allocated(0) / 1024**2, "MB")
    print("Memory cached:", torch.cuda.memory_reserved(0) / 1024**2, "MB")
else:
    print("No GPU found. Using CPU.")


Using device: cuda
GPU name: Tesla P100-PCIE-16GB
Memory allocated: 0.0 MB
Memory cached: 0.0 MB


In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFile
import random
import PIL # Ensure PIL is imported if needed for error handling

ImageFile.LOAD_TRUNCATED_IMAGES = True

class CelebAFromSeparateAnnotationsDataset(Dataset):
    """
    CelebA Dataset loading images and annotations from two separate files
    (one full, one causal), assuming they correspond row-by-row based on
    a shared identifier (e.g., filename in the first column).
    """

    def __init__(self, image_dir, full_annotations_file, causal_annotations_file,
                 transform=None, num_samples=None):
        """
        Args:
            image_dir (string): Directory with all the images.
            full_annotations_file (string): Path to the file with ALL annotations.
                                            Assumes first column is image filename.
            causal_annotations_file (string): Path to the file with ONLY causal annotations.
                                              Assumes first column is image filename.
            transform (callable, optional): Optional transform to be applied on the image sample.
            num_samples (int, optional): Number of images to randomly sample. If None, use all images.
        """
        self.image_dir = image_dir
        self.transform = transform

        try:
            # Load both annotation files
            self.full_annotations_df = self._read_annotation_file(full_annotations_file)
            self.causal_annotations_df = self._read_annotation_file(causal_annotations_file)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Annotation file error: {e}")
        except ValueError as e:
             raise ValueError(f"Annotation file error: {e}")


        # --- Crucial Assumption & Verification ---
        # Assume the first column is the filename/identifier in both files
        self.filename_col_name = self.full_annotations_df.columns[0]
        if self.causal_annotations_df.columns[0] != self.filename_col_name:
            raise ValueError(f"Filename column mismatch: '{self.filename_col_name}' vs "
                             f"'{self.causal_annotations_df.columns[0]}'")

        # Verify that the files have the same number of rows (important for direct indexing)
        if len(self.full_annotations_df) != len(self.causal_annotations_df):
            print(f"Warning: Annotation files have different lengths "
                  f"({len(self.full_annotations_df)} vs {len(self.causal_annotations_df)}). "
                  f"This might lead to errors unless filenames match perfectly.")
            # Consider adding a merge or stricter check here if lengths differ significantly
            # For now, we proceed assuming filenames will align for the sampled indices

        # You could optionally add a check here that the filename columns themselves match
        # if not self.full_annotations_df[self.filename_col_name].equals(self.causal_annotations_df[self.filename_col_name]):
        #    print("Warning: Filename columns do not match exactly between files.")

        # --- Sampling ---
        total_images = len(self.full_annotations_df) # Sample based on the full annotations file length
        if num_samples is None:
            self.num_samples = total_images
            self.sampled_indices = list(range(total_images)) # Use all indices
            print(f"Using all {self.num_samples} images.")
        else:
            self.num_samples = min(num_samples, total_images)
            print(f"Randomly sampling {self.num_samples} images from {total_images} total.")
            self.sampled_indices = random.sample(range(total_images), self.num_samples)

    def _read_annotation_file(self, filepath):
        """Helper function to read CSV or Excel."""
        if filepath.endswith('.xlsx') or filepath.endswith('.xls'):
            return pd.read_excel(filepath)
            
        elif filepath.endswith('.csv'):
            return pd.read_csv(filepath)
        else:
            raise ValueError("Unsupported annotations file format. Use .csv, .xlsx, or .xls")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if not isinstance(idx, int):
             raise TypeError(f"Index must be an integer. Got: {type(idx)}")
        if not (0 <= idx < self.num_samples):
             raise IndexError(f"Index {idx} out of bounds for dataset with length {self.num_samples}")

        # Use the pre-sampled index to access the correct row in the *full* annotations df
        original_idx = self.sampled_indices[idx]
        full_row_data = self.full_annotations_df.iloc[original_idx]

        img_filename = full_row_data.iloc[0] # Get filename from the full annotations row
        img_name = os.path.join(self.image_dir, img_filename)

        # --- Find the corresponding row in the causal annotations file ---
        # This assumes filenames are unique identifiers and present in both files for the sampled indices
        causal_row_data = self.causal_annotations_df[
            self.causal_annotations_df[self.filename_col_name] == img_filename
        ]

        if causal_row_data.empty:
             print(f"Warning: Filename '{img_filename}' found in full annotations (idx {original_idx}) "
                   f"but not found in causal annotations file. Returning None.")
             return None, None, None # Need to handle this in collate_fn

        # Ensure only one matching row was found
        if len(causal_row_data) > 1:
            print(f"Warning: Multiple rows found for filename '{img_filename}' in causal annotations. "
                  f"Using the first match. Returning None.")
            return None, None, None # Treat as error for now


        # --- Load Image ---
        try:
            image = Image.open(img_name).convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Image file not found: {img_name} (Index: {idx}, Original Index: {original_idx}). Returning None.")
            return None, None, None
        except (PIL.UnidentifiedImageError, OSError) as e:
            print(f"Warning: Error loading image {img_name} (Index: {idx}, Original Index: {original_idx}): {e}. Returning None.")
            return None, None, None

        # --- Extract Annotations ---
        try:
            # Extract ALL annotations (from 2nd column onwards in full df row)
            all_annotations_np = full_row_data.iloc[1:].values.astype('float32')

            # Extract CAUSAL annotations (from 2nd column onwards in the matched causal df row)
            # Use .iloc[0] because filtering returns a DataFrame, we need the Series/row
            causal_annotations_np = causal_row_data.iloc[0, 1:].values.astype('float32')
        except (ValueError, IndexError) as e:
             print(f"Warning: Error converting/extracting annotations for image {img_filename} (Index: {idx}, Orig Idx: {original_idx}): {e}. Returning None.")
             return None, None, None

        # Apply image transformation
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image) # Basic conversion if no transform

        # Convert numpy arrays to tensors
        all_annotations_tensor = torch.from_numpy(all_annotations_np)
        causal_annotations_tensor = torch.from_numpy(causal_annotations_np)

        # --- Final Check (Optional but Recommended) ---
        # Add checks here if needed to ensure dimensions match expected values (e.g., 40 and 4)
        # if all_annotations_tensor.shape[0] != 40: print("Warning: Full annotation dim mismatch")
        # if causal_annotations_tensor.shape[0] != 4: print("Warning: Causal annotation dim mismatch")

        return image, all_annotations_tensor, causal_annotations_tensor

# --- Example Usage ---

# Define your image transformations
data_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(), # Scales images to [0, 1]
])

# Define file paths (MAKE SURE THESE ARE CORRECT)
image_directory = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
full_annotations_filepath = "/kaggle/input/annotation/updated_Annotations_Complete.csv" # File with all 40 attributes
causal_annotations_filepath = "/kaggle/input/annotation/updated_Annotations_Causal.csv" # File with only the 4 causal attributes

# Instantiate the dataset
try:
    separate_files_dataset = CelebAFromSeparateAnnotationsDataset(
        image_dir=image_directory,
        full_annotations_file=full_annotations_filepath,
        causal_annotations_file=causal_annotations_filepath,
        transform=data_transform,
        num_samples=30000 # Or None to use all
    )

    # Re-use the safe collate function
    def safe_collate(batch):
        batch = [item for item in batch if item is not None and item[0] is not None]
        if not batch: return None, None, None
        images = torch.stack([item[0] for item in batch])
        full_annotations = torch.stack([item[1] for item in batch])
        causal_annotations = torch.stack([item[2] for item in batch])
        return images, full_annotations, causal_annotations

    # Create DataLoader
    batch_size = 32
    dataloader = DataLoader(
        separate_files_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=safe_collate,
        pin_memory=True
    )

    # --- Test the DataLoader ---
    print(f"\nTesting DataLoader with separate annotation files...")
    num_batches_to_test = 3
    for i, batch_data in enumerate(dataloader):
        if i >= num_batches_to_test: break
        if batch_data[0] is None:
            print(f"Batch {i}: Skipped (collate fn returned None)")
            continue

        images_batch, full_annotations_batch, causal_annotations_batch = batch_data
        print(f"\nBatch {i}:")
        print(f"  Images shape:           {images_batch.shape}, dtype: {images_batch.dtype}")
        # Verify dimensions match expected full/causal counts
        print(f"  Full Annotations shape: {full_annotations_batch.shape}, dtype: {full_annotations_batch.dtype}") # Expect [B, 40]
        print(f"  Causal Annotations shape:{causal_annotations_batch.shape}, dtype: {causal_annotations_batch.dtype}") # Expect [B, 4]
        print(f"  Image range:            min={images_batch.min().item():.4f}, max={images_batch.max().item():.4f}")

    print("\nDataLoader test finished.")

except (ValueError, FileNotFoundError, IndexError) as e:
     print(f"Error initializing dataset: {e}")

Randomly sampling 30000 images from 202599 total.

Testing DataLoader with separate annotation files...

Batch 0:
  Images shape:           torch.Size([32, 3, 64, 64]), dtype: torch.float32
  Full Annotations shape: torch.Size([32, 40]), dtype: torch.float32
  Causal Annotations shape:torch.Size([32, 4]), dtype: torch.float32
  Image range:            min=0.0000, max=1.0000

Batch 1:
  Images shape:           torch.Size([32, 3, 64, 64]), dtype: torch.float32
  Full Annotations shape: torch.Size([32, 40]), dtype: torch.float32
  Causal Annotations shape:torch.Size([32, 4]), dtype: torch.float32
  Image range:            min=0.0000, max=1.0000

Batch 2:
  Images shape:           torch.Size([32, 3, 64, 64]), dtype: torch.float32
  Full Annotations shape: torch.Size([32, 40]), dtype: torch.float32
  Causal Annotations shape:torch.Size([32, 4]), dtype: torch.float32
  Image range:            min=0.0000, max=1.0000

DataLoader test finished.


In [3]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFile
import random

ImageFile.LOAD_TRUNCATED_IMAGES = True

class CelebADataset(Dataset):
    """CelebA Dataset with image and annotations from Excel, randomly sampled."""

    def __init__(self, image_dir, annotations_file, transform=None, num_samples=30000):
        """
        Args:
            image_dir (string): Directory with all the images.
            annotations_file (string): Path to the CSV/Excel file with annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            num_samples (int, optional): Number of images to randomly sample.
        """
        self.image_dir = image_dir
        self.annotations = pd.read_excel(annotations_file)  # or pd.read_csv
        self.transform = transform
        self.num_samples = min(num_samples, len(self.annotations))  # Ensure we don't sample more than available
        self.sampled_indices = random.sample(range(len(self.annotations)), self.num_samples) # Store the randomly sampled indices

    def __len__(self):
        return self.num_samples  # Return the number of sampled images

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Use the pre-sampled indices to access the correct rows in the annotation file
        original_idx = self.sampled_indices[idx]
        img_name = os.path.join(self.image_dir, self.annotations.iloc[original_idx, 0])  # Assuming filename is the first column

        try:
            image = Image.open(img_name).convert('RGB')  # Ensure RGB format
        except FileNotFoundError:
            print(f"Warning: Image file not found: {img_name}")
            return None, None  # Return None if the image is not found
        except PIL.UnidentifiedImageError:
            print(f"Warning: Could not open or read image file (corrupted?): {img_name}")
            return None, None
        except OSError as e:  # Catch other potential image opening errors
            print(f"Warning: Could not open image due to OSError: {e}")
            return None, None

        # Assuming annotations are in the remaining columns.  Adjust indices if needed.
        annotations = self.annotations.iloc[original_idx, 1:].values.astype('float32')  # Convert annotations to numpy array of float32

        if self.transform:
            image = self.transform(image)

        annotations = torch.tensor(annotations)  # convert numpy array to tensor

        return (image, annotations)



# Convolutional Encoder

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# =============================================================================
# Placeholder for the ut.gaussian_parameters function
# This assumes it splits the channels (dim=1) into mean and logvar/softplus(var)
# Add a small epsilon to variance for numerical stability.
# =============================================================================
def gaussian_parameters(x, dim=1):
    """
    Splits the tensor x along the specified dimension into two halves,
    interpreting them as parameters (mean and variance) for a Gaussian.
    Applies Softplus to the variance part to ensure positivity.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int): Dimension along which to split.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Mean and Variance tensors.
    """
    n_channels = x.size(dim)
    if n_channels % 2 != 0:
        raise ValueError(f"Number of channels ({n_channels}) must be even to split into mu and var.")
    mu, var_input = torch.chunk(x, 2, dim=dim)
    # Ensure variance is positive and numerically stable
    var = F.softplus(var_input) + 1e-6
    return mu, var
# =============================================================================


class ConvEncoder(nn.Module):
    """
    Convolutional Encoder for 64x64x3 images.
    Uses a specific convolutional stack (`conv6`) to produce
    parameters (mean and variance) for a latent distribution.

    Input: Batch of images, shape [B, 3, 64, 64]
    Output: mu, var (mean and variance tensors), each shape [B, z_dim]
            where z_dim is determined by the final layer output channels / 2.
    """
    def __init__(self, z_dim_target=64):
        super().__init__()

        # Define the convolutional stack
        # Output channels of the second-to-last conv layer determine features
        # Final conv layer outputs 2 * z_dim channels (for mu and var)
        final_out_channels = 2 * z_dim_target

        self.conv6 = nn.Sequential(
            # Input: B x 3 x 64 x 64
            nn.Conv2d(3, 32, 4, 2, 1),  # B x 32 x 32 x 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1), # B x 32 x 16 x 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # B x 64 x 8 x 8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1), # B x 64 x 4 x 4
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1), # B x 64 x 2 x 2
            nn.ReLU(True),
            # --- Correction: Added padding=1 to handle 2x2 input with 4x4 kernel ---
            nn.Conv2d(64, 256, 4, 1, padding=1), # B x 256 x 1 x 1
            nn.ReLU(True),
            # --- Final 1x1 convolution to get desired output channels ---
            nn.Conv2d(256, final_out_channels , 1) # B x (2*z_dim) x 1 x 1
        )

        # Store the latent dimension
        self.z_dim = z_dim_target
        print(f"ConvEncoder initialized for z_dim = {self.z_dim}")
        print(f"Expected input shape: [B, 3, 64, 64]")
        print(f"Output shape: mu=[B, {self.z_dim}], var=[B, {self.z_dim}]")


    def forward(self, x):
        """
        Encodes the input image batch x into latent distribution parameters.

        Args:
            x (torch.Tensor): Input image batch, shape [B, 3, 64, 64].

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Mean and Variance tensors (mu, var),
                                                each with shape [B, z_dim].
        """
        # Pass input through the convolutional layers
        features = self.conv6(x) # Shape: B x (2*z_dim) x 1 x 1

        # Extract mean and variance parameters
        # Assumes gaussian_parameters splits along dim=1 (channels)
        # and handles the (1x1) spatial dimensions appropriately (e.g., squeezes)
        mu, var = gaussian_parameters(features, dim=1) # mu/var shape: [B, z_dim, 1, 1]

        # Remove trailing spatial dimensions if they exist
        mu = mu.squeeze(-1).squeeze(-1) # Shape: [B, z_dim]
        var = var.squeeze(-1).squeeze(-1) # Shape: [B, z_dim]

        return mu, var

# Conv_Conditional_Encoder

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvEncoderConditional(nn.Module):
    """
    Conditional Convolutional Encoder for 64x64x3 images and annotations u.
    Uses a convolutional stack to process the image, concatenates features
    with annotations, and then uses linear layers to produce parameters
    (mean and log-variance) for the conditional latent distribution q(z|x, u).

    Input:
        x: Batch of images, shape [B, 3, 64, 64]
        u: Batch of annotations, shape [B, annotation_dim]
    Output:
        mu, logvar: Mean and log-variance tensors, each shape [B, z_dim]
    """
    def __init__(self, z_dim_target=64, annotation_dim=4, hidden_fc_dim=128):
        super().__init__()

        # --- Convolutional part (processing image x) ---
        # We'll stop before the final 1x1 conv that produced 2*z_dim channels
        # The output here will be the flattened image features
        image_feature_dim = 256 # Output channels of the last conv layer before flattening

        # Define convolutional layers (excluding the original final 1x1 conv)
        self.conv_layers = nn.Sequential(
            # Input: B x 3 x 64 x 64
            nn.Conv2d(3, 32, 4, 2, 1),  # B x 32 x 32 x 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1), # B x 32 x 16 x 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1), # B x 64 x 8 x 8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1), # B x 64 x 4 x 4
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1), # B x 64 x 2 x 2
            nn.ReLU(True),
            # --- Correction: Added padding=1 to handle 2x2 input with 4x4 kernel ---
            nn.Conv2d(64, image_feature_dim, 4, 1, padding=1), # B x 256 x 1 x 1
            nn.ReLU(True)
            # Removed the final Conv2d(256, 2*z_dim, 1)
        )

        # --- Fully Connected part (combining image features and annotation u) ---
        self.fc1 = nn.Linear(image_feature_dim + annotation_dim, hidden_fc_dim)

        # Output layers for mean (mu) and log-variance (logvar)
        self.fc_mu = nn.Linear(hidden_fc_dim, z_dim_target)
        self.fc_logvar = nn.Linear(hidden_fc_dim, z_dim_target)

        # Store dimensions
        self.z_dim = z_dim_target
        self.annotation_dim = annotation_dim
        self.image_feature_dim = image_feature_dim

        print(f"ConvEncoderConditional initialized for z_dim = {self.z_dim}, annotation_dim = {self.annotation_dim}")
        print(f"Expected input shapes: x=[B, 3, 64, 64], u=[B, {self.annotation_dim}]")
        print(f"Output shape: mu=[B, {self.z_dim}], logvar=[B, {self.z_dim}]")


    def forward(self, x, u):
        """
        Encodes the input image batch x and annotations u into latent distribution parameters.

        Args:
            x (torch.Tensor): Input image batch, shape [B, 3, 64, 64].
            u (torch.Tensor): Input annotation batch, shape [B, annotation_dim].

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Mean and Log-variance tensors (mu, logvar),
                                                each with shape [B, z_dim].
        """
        if u.shape[0] != x.shape[0]:
            raise ValueError("Batch size mismatch between image input x and annotation u.")
        if u.shape[1] != self.annotation_dim:
             raise ValueError(f"Input annotation dimension ({u.shape[1]}) does not match "
                              f"encoder's expected annotation_dim ({self.annotation_dim})")

        # 1. Process image x through convolutional layers
        img_features = self.conv_layers(x) # Shape: [B, image_feature_dim, 1, 1]

        # 2. Flatten image features
        img_features_flat = img_features.view(img_features.size(0), -1) # Shape: [B, image_feature_dim]

        # 3. Concatenate flattened image features and annotations u
        combined_features = torch.cat([img_features_flat, u], dim=1) # Shape: [B, image_feature_dim + annotation_dim]

        # 4. Pass combined features through fully connected layers
        hidden = F.relu(self.fc1(combined_features))

        # 5. Compute mu and logvar
        mu = self.fc_mu(hidden)           # Shape: [B, z_dim]
        logvar = self.fc_logvar(hidden)   # Shape: [B, z_dim]

        
        return mu, logvar
  
def sample_z(mu, logvar):
  """
  Samples z from the distribution N(mu, exp(logvar)) using the
  reparameterization trick.

  Args:
      mu (torch.Tensor): Mean tensor from the encoder. Shape [B, z_dim].
      logvar (torch.Tensor): Log-variance tensor from the encoder. Shape [B, z_dim].

  Returns:
      torch.Tensor: Sampled latent variable z. Shape [B, z_dim].
  """
  std = torch.exp(0.5 * logvar)  # Calculate standard deviation sigma
  eps = torch.randn_like(std)   # Sample epsilon ~ N(0, I)
  z = mu + eps * std            # Apply the reparameterization trick
  return z

# --- Example Usage (assuming you have encoder output) ---
# mu, logvar = encoder(x, u) # Get output from your ConvEncoderConditional
# z = sample_z(mu, logvar)   # Sample z

# Now 'z' is the latent vector you can pass to your decoder or classifier      

# Resnet_Convolutional_Encoder


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsampleResBlock(nn.Module):
    """
    A Residual Block for Downsampling using Conv2d with stride.
    Applies BN and ReLU before convolutions (pre-activation style).
    Handles changes in channels and spatial dimensions via the skip connection.
    """
    def __init__(self, in_channels, out_channels, stride=2, kernel_size=4, padding=1):
        super().__init__()
        self.stride = stride
        self.channels_match = (in_channels == out_channels)

        # Main (Residual) Path - F(x)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # Downsamples spatially if stride > 1, also changes channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        # Second conv (kernel 3x3) refines features without changing size/channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        # Skip Connection Path
        self.needs_skip_proj = (stride > 1) or not self.channels_match
        if self.needs_skip_proj:
            # Use a single Conv2d layer for both downsampling (via stride)
            # and channel matching (via out_channels).
            self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, # Kernel 1 only changes channels
                                       stride=stride, padding=0, bias=False)      # Stride handles downsampling


    def forward(self, x):
        # --- Skip Connection Path ---
        identity = x
        if self.needs_skip_proj:
            identity = self.skip_conv(identity) # Downsample and/or change channels

        # --- Main Path ---
        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out) # Downsample and change channels
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out) # Refine features

        # --- Add Skip Connection ---
        out = out + identity # Element-wise addition

        return out


class ConvEncoderConditionalResNet(nn.Module):
    """
    Conditional Convolutional Encoder using DownsampleResBlocks for 64x64x3 images and annotations u.
    Outputs parameters (mean and log-variance) for the conditional latent distribution q(epsilon|x, u).

    Input:
        x: Batch of images, shape [B, 3, 64, 64]
        u: Batch of annotations, shape [B, annotation_dim] (e.g., 40)
    Output:
        mu, logvar: Mean and log-variance tensors, each shape [B, z_dim]
    """
    def __init__(self, z_dim_target=64, annotation_dim=40, start_channels=16, dropout_p=0.3, hidden_fc_dim=128):
        super().__init__()

        self.z_dim = z_dim_target
        self.annotation_dim = annotation_dim

        # --- Initial Convolution ---
        # Map input channels (3) to starting channels for ResBlocks
        self.initial_conv = nn.Conv2d(3, start_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.initial_bn = nn.BatchNorm2d(start_channels)
        self.initial_relu = nn.ReLU(inplace=True)

        # --- Sequence of Downsampling ResBlocks ---
        layers = []
        current_channels = start_channels
        # Define channel progression (example, mirroring decoder approximately)
        channel_config = [32, 64, 128, 256, 256] # Channels *after* each stage

        # Stage 1: 64x64 -> 32x32
        out_ch = channel_config[0] # 32
        layers.append(DownsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 2: 32x32 -> 16x16
        out_ch = channel_config[1] # 64
        layers.append(DownsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 3: 16x16 -> 8x8
        out_ch = channel_config[2] # 128
        layers.append(DownsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 4: 8x8 -> 4x4
        out_ch = channel_config[3] # 256
        layers.append(DownsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 5: 4x4 -> 2x2
        out_ch = channel_config[4] # 256 (keep same or increase)
        layers.append(DownsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 6: 2x2 -> 1x1 (Optional - can use AdaptiveAvgPool later)
        # Instead of a ResBlock, use a Conv layer with kernel size matching input spatial dim
        # This effectively acts like a flatten + linear projection
        self.final_conv_collapse = nn.Conv2d(current_channels, current_channels, kernel_size=2, stride=1, padding=0)
        # Alternatively, could use Adaptive Average Pooling:
        # self.final_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.res_blocks = nn.Sequential(*layers)
        self.final_bn_relu = nn.Sequential( # BN/ReLU after blocks, before final pooling/conv
            nn.BatchNorm2d(current_channels),
            nn.ReLU(inplace=True)
        )

        # --- Fully Connected part (combining image features and annotation u) ---
        self.image_feature_dim = current_channels # Features after pooling/collapsing
        self.fc1 = nn.Linear(self.image_feature_dim + annotation_dim, hidden_fc_dim)
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc_mu = nn.Linear(hidden_fc_dim, z_dim_target)
        self.fc_logvar = nn.Linear(hidden_fc_dim, z_dim_target)

        print(f"ConvEncoderConditionalResNet initialized for z_dim = {self.z_dim}, annotation_dim = {self.annotation_dim}")
        print(f"Output shape (mu, logvar): [B, {self.z_dim}]")

    def forward(self, x, u):
        """
        Encodes the input image batch x and annotations u into latent distribution parameters.
        """
        if u.shape[0] != x.shape[0]:
            raise ValueError("Batch size mismatch between image input x and annotation u.")
        if u.shape[1] != self.annotation_dim:
             raise ValueError(f"Input annotation dimension ({u.shape[1]}) does not match "
                              f"encoder's expected annotation_dim ({self.annotation_dim})")

        # 1. Initial Conv
        out = self.initial_conv(x)
        out = self.initial_bn(out)
        out = self.initial_relu(out) # Shape: [B, start_channels, 64, 64]

        # 2. Pass through ResBlocks
        out = self.res_blocks(out) # Shape: [B, final_block_channels, 2, 2]

        # 3. Final BN/ReLU and spatial collapse
        out = self.final_bn_relu(out)
        # out = self.final_pool(out) # Option 1: Use Pooling
        out = self.final_conv_collapse(out) # Option 2: Use Conv
        # Output shape after collapse/pool: [B, final_block_channels, 1, 1]

        # 4. Flatten image features
        img_features_flat = out.view(out.size(0), -1) # Shape: [B, final_block_channels]

        # 5. Concatenate flattened image features and annotations u
        combined_features = torch.cat([img_features_flat, u], dim=1)

        # 6. Pass combined features through fully connected layers
        hidden = F.relu(self.fc1(combined_features))
        hidden_dropout = self.dropout(hidden) # Apply dropout

        # 7. Compute mu and logvar
        mu = self.fc_mu(hidden_dropout)
        logvar = self.fc_logvar(hidden_dropout)

        return mu, logvar


# --- Example Usage ---
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# z_dim = 64
# full_annotation_dim = 40

# # Instantiate
# encoder = ConvEncoderConditionalResNet(
#     z_dim_target=z_dim,
#     annotation_dim=full_annotation_dim,
#     start_channels=32 # Example start channels
# ).to(device)

# # Create dummy inputs
# dummy_x = torch.randn(4, 3, 64, 64).to(device)       # Batch size 4
# dummy_u_full = torch.randn(4, full_annotation_dim).to(device)

# # Forward pass
# with torch.no_grad():
#     mu_out, logvar_out = encoder(dummy_x, dummy_u_full)

# print(f"Output mu shape: {mu_out.shape}")         # Should be [4, 64]
# print(f"Output logvar shape: {logvar_out.shape}")   # Should be [4, 64]

# causalclassifier

In [7]:
class CausalMLPClassifier(nn.Module):
    def __init__(self, latent_dim=64, annotation_dim=4, hidden_dim=256, output_dim=4):
        """
        Args:
            latent_dim (int): Dimension of the latent code from the encoder (z_dim), here 64.
            annotation_dim (int): Dimension of the annotation vector u, e.g., 4.
            hidden_dim (int): Number of neurons in the hidden layer.
            output_dim (int): Dimension of the classifier output (typically same as annotation_dim).
        """
        super(CausalMLPClassifier, self).__init__()
        self.fc1 = nn.Linear(latent_dim + annotation_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, z, u):
        # Concatenate the latent code z and the annotation vector u along the feature dimension
        x = torch.cat([z, u], dim=1)  # shape: [B, latent_dim + annotation_dim]
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

# causal Matrix A

In [8]:
import torch

def compute_causal_matrix(latent_features, u, classifier):
    """
    Computes a 4x4 causal matrix A from classifier outputs.
    
    Args:
        latent_features: Tensor of shape [batch_size, feature_dim] from the encoder.
        u: Tensor of shape [batch_size, 4] representing the full annotation vector.
        classifier: A model that takes (latent_features, u) and outputs logits for u (shape [batch_size, 4]).
    
    Returns:
        A_estimated: Tensor of shape [4, 4] where each row i corresponds to the averaged Total Direct Effect
                     when the i-th attribute is zeroed out.
    """
    # Get factual predictions using the full annotation vector
    factual_logits = classifier(latent_features, u)  # shape: [batch_size, 4]
    
    TDE_list = []  # To store the Total Direct Effect for each attribute
    for i in range(u.shape[1]):  # Loop over 4 attributes
        # Create a counterfactual annotation by zeroing out the i-th attribute
        u_cf = u.clone()
        u_cf[:, i] = 0.0
        # Compute counterfactual predictions
        cf_logits = classifier(latent_features, u_cf)
        # Compute the Total Direct Effect for attribute i by averaging the difference over the batch
        TDE_i = (factual_logits - cf_logits).mean(dim=0)  # shape: [4]
        TDE_list.append(TDE_i)
    
    # Stack the TDE vectors along a new dimension to form the 4x4 causal matrix A
    A_estimated = torch.stack(TDE_list, dim=0)  # shape: [4, 4]
    return A_estimated



# Partitioning Into Es and Er

In [9]:
def partition_latent(z, causal_dim=4):
    """
    Splits the latent code z into es and er.
    
    Args:
        z (torch.Tensor): The latent code with shape [B, z_dim].
        causal_dim (int): Number of dimensions to allocate for the causal factors (es).
        
    Returns:
        tuple[torch.Tensor, torch.Tensor]: es (causal factors) and er (residual information).
    """
    es = z[:, :causal_dim]   # First `causal_dim` dimensions as es.
    er = z[:, causal_dim:]   # Remaining dimensions as er.
    return es, er

# MLP based Partitioner

In [10]:
class LatentPartitioner(nn.Module):
    """
    A learnable partitioner that takes the latent vector z
    (shape [B, z_dim]) and outputs:
      - es: causal factors (learned projection to causal_dim, e.g., 4)
      - er: residual factors (learned projection to z_dim - causal_dim)
    """
    def __init__(self, z_dim, causal_dim=4):
        super(LatentPartitioner, self).__init__()
        self.z_dim = z_dim
        self.causal_dim = causal_dim
        self.residual_dim = z_dim - causal_dim
        
        # Instead of hard slicing, use two linear layers to extract each part
        self.es_layer = nn.Linear(z_dim, causal_dim)
        self.er_layer = nn.Linear(z_dim, self.residual_dim)
    
    def forward(self, z):
        es = self.es_layer(z)  # shape: [B, causal_dim]
        er = self.er_layer(z)  # shape: [B, z_dim - causal_dim]
        return es, er

# Linear SCM


In [11]:
def linear_scm(es, A, epsilon=1e-3):
    """
    Computes z_l = (I - Aᵀ + epsilon*I)⁻¹ * es using the pseudo-inverse for stability.
    
    Args:
        es (torch.Tensor): Causal part of latent code, shape [B, n].
        A (torch.Tensor): Estimated causal matrix, shape [n, n].
        epsilon (float): Small constant added to the diagonal.
    
    Returns:
        torch.Tensor: Linear SCM output z_l, shape [B, n].
    """
    n = A.size(0)
    I = torch.eye(n, device=A.device, dtype=A.dtype)
    M = I - A.t() + epsilon * I
    # Use pseudo-inverse to handle singularity
    C_inv = torch.linalg.pinv(M)
    z_l = torch.matmul(es, C_inv.t())
    return z_l


# Non-Linearity Through NN

In [12]:
class NonlinearSCM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Args:
            input_dim (int): Dimension of the input zₛ from the linear SCM.
            hidden_dim (int): Number of neurons in the hidden layer.
            output_dim (int): Desired output dimension (typically same as input_dim).
        """
        super(NonlinearSCM, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Merging Zc and Er

In [13]:
def merge_latent(z_causal, er):
    """
    Merges the causal representation and residual latent information.
    
    Args:
        z_causal (torch.Tensor): Nonlinear causal representation, shape [B, n]
        er (torch.Tensor): Residual latent information, shape [B, m]
        
    Returns:
        torch.Tensor: Merged latent representation, shape [B, n+m]
    """
    return torch.cat([z_causal, er], dim=1)

# Merging Through MLP

In [14]:
class LearnedFusion(nn.Module):
    """
    A fusion network that learns to combine the causal representation
    (after SCM transformation, z_l) and the residual latent factors (er)
    into a final latent representation.
    """
    def __init__(self, es_dim, er_dim, output_dim):
        """
        Args:
            es_dim (int): Dimension of the causal part (from SCM), e.g. 4.
            er_dim (int): Dimension of the residual part, e.g. (z_dim - 4).
            output_dim (int): Final latent dimension (usually equal to z_dim).
        """
        super(LearnedFusion, self).__init__()
        self.fc1 = nn.Linear(es_dim + er_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim)
    
    def forward(self, z_l, er):
        # Concatenate the SCM-transformed causal part and residual part
        x = torch.cat([z_l, er], dim=1)
        x = F.gelu(self.fc1(x))
        z_fused = self.fc2(x)
        return z_fused

# Conv Decoder

In [15]:
class ConvDecoder(nn.Module):
    """
    Convolutional Decoder for reconstructing 64x64x3 images from a latent vector.
    Designed to be roughly the inverse of the provided ConvEncoder's 'conv6' path.

    Input: Latent vector z, shape [B, z_dim]
    Output: Reconstructed image, shape [B, 3, 64, 64] (values typically in [0, 1] via Sigmoid)
    """
    def __init__(self, z_dim=64, target_channels=3):
        super().__init__()
        self.z_dim = z_dim
        self.target_channels = target_channels # Should be 3 for RGB images

        # Calculate the number of channels needed before the final 1x1 conv in the encoder
        # This was 256 in the encoder's conv6
        initial_decoder_channels = 256

        # 1. Initial dense layer to project z and reshape
        # Projects z_dim to the shape needed to start the transposed conv sequence (B x 256 x 1 x 1)
        self.fc = nn.Linear(z_dim, initial_decoder_channels * 1 * 1)

        # 2. Transposed Convolutional Layers (roughly mirroring encoder's conv6 in reverse)
        self.deconv_layers = nn.Sequential(
            # Input: B x 256 x 1 x 1
            nn.ReLU(True),
            # Reverse Conv2d(64, 256, 4, 1, padding=1) -> Output: B x 64 x 2 x 2
            nn.ConvTranspose2d(initial_decoder_channels, 64, 4, 1, 1), # k=4, s=1, p=1 -> H_out = (1-1)*1 - 2*1 + 4 = 2
            nn.ReLU(True),

            # Reverse Conv2d(64, 64, 4, 2, 1) -> Output: B x 64 x 4 x 4
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # k=4, s=2, p=1 -> H_out = (2-1)*2 - 2*1 + 4 = 4
            nn.ReLU(True),

            # Reverse Conv2d(64, 64, 4, 2, 1) -> Output: B x 64 x 8 x 8
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # k=4, s=2, p=1 -> H_out = (4-1)*2 - 2*1 + 4 = 8
            nn.ReLU(True),

            # Reverse Conv2d(32, 64, 4, 2, 1) -> Output: B x 32 x 16 x 16
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # k=4, s=2, p=1 -> H_out = (8-1)*2 - 2*1 + 4 = 16
            nn.ReLU(True),

            # Reverse Conv2d(32, 32, 4, 2, 1) -> Output: B x 32 x 32 x 32
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # k=4, s=2, p=1 -> H_out = (16-1)*2 - 2*1 + 4 = 32
            nn.ReLU(True),

            # Reverse Conv2d(3, 32, 4, 2, 1) -> Output: B x 3 x 64 x 64
            nn.ConvTranspose2d(32, self.target_channels, 4, 2, 1) # k=4, s=2, p=1 -> H_out = (32-1)*2 - 2*1 + 4 = 64
        )

        print(f"ConvDecoder initialized for z_dim = {self.z_dim}")
        print(f"Expected input shape: [B, {self.z_dim}]")
        print(f"Output shape: [B, {self.target_channels}, 64, 64]")

    def forward(self, z):
        """
        Decodes the latent vector z into an image.

        Args:
            z (torch.Tensor): Latent vector batch, shape [B, z_dim].

        Returns:
            torch.Tensor: Reconstructed image batch, shape [B, 3, 64, 64].
                          Values are mapped to [0, 1] by the final Sigmoid activation.
        """
        # 1. Project and reshape
        # Ensure z is flat [B, z_dim]
        z_flat = z.view(z.shape[0], -1)
        if z_flat.shape[1] != self.z_dim:
             raise ValueError(f"Input z dimension ({z_flat.shape[1]}) does not match decoder's expected z_dim ({self.z_dim})")

        x = self.fc(z_flat)
        # Reshape to [B, Channels, Height, Width] for convolutional layers
        x = x.view(x.shape[0], 256, 1, 1)

        # 2. Pass through transposed convolutional layers
        x = self.deconv_layers(x)

        # 3. Apply final activation function (Sigmoid for output in [0, 1])
        # If your images are normalized differently (e.g., [-1, 1]), use torch.tanh instead.
        reconstruction = torch.sigmoid(x)

        return reconstruction

# Resnet_Decoder

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsampleResBlock(nn.Module):
    """
    A Residual Block for Upsampling using ConvTranspose2d.
    Applies BN and ReLU before convolutions (pre-activation style).
    Handles changes in channels and spatial dimensions via the skip connection.
    """
    def __init__(self, in_channels, out_channels, stride=2, kernel_size=4, padding=1, output_padding=0):
        super().__init__()
        self.stride = stride
        self.channels_match = (in_channels == out_channels)

        # Main (Residual) Path - F(x)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # Upsamples spatially if stride > 1, also changes channels
        self.conv_t1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                          stride=stride, padding=padding, output_padding=output_padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        # Second conv (kernel 3x3) refines features without changing size/channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        # Skip Connection Path
        self.needs_skip_proj = (stride > 1) or not self.channels_match
        if self.needs_skip_proj:
            # This ConvTranspose handles both spatial upsampling (if stride>1)
            # and channel matching needed for the addition.
            # Using kernel=1 for stride=1 upsampling essentially becomes a 1x1 conv for channels.
            # For stride=2, a kernel > 1 (like 2 or 4) might be needed depending on desired output alignment.
            # Let's use a simpler approach: Upsample then 1x1 Conv if needed.
            self.skip_upsample = nn.Identity() # Default if only channels change
            if stride > 1:
                 # Use interpolation + 1x1 conv for clarity and flexibility
                 self.skip_upsample = nn.Upsample(scale_factor=stride, mode='nearest') # Or 'bilinear'
                 # If only channels change (stride=1, C_in!=C_out), skip_in_channels=in_channels
                 # If only stride changes (stride>1, C_in==C_out), skip_in_channels=in_channels
                 # If both change, skip_in_channels=in_channels
                 skip_in_channels = in_channels
            elif not self.channels_match: # stride is 1, only channels change
                 skip_in_channels = in_channels
            else: # stride = 1, channels match = identity skip
                 skip_in_channels = in_channels # Value not used but set for consistency

            # 1x1 Conv to adjust channels after upsampling (if needed) or if only channels changed
            self.skip_conv = nn.Conv2d(skip_in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)


    def forward(self, x):
        # --- Skip Connection Path ---
        identity = x
        if self.needs_skip_proj:
            identity = self.skip_upsample(identity) # Upsample spatially if needed
            identity = self.skip_conv(identity)   # Adjust channels if needed

        # --- Main Path ---
        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv_t1(out) # Upsample and change channels
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out) # Refine features

        # --- Add Skip Connection ---
        out = out + identity # Element-wise addition

        return out

class ConvDecoderResNet(nn.Module):
    """
    Convolutional Decoder using UpsampleResBlocks.

    Input: Latent vector z, shape [B, z_dim]
    Output: Reconstructed image logits, shape [B, 3, 64, 64]
    """
    def __init__(self, z_dim=64, target_channels=3, start_channels=256, blocks_per_stage=[1,1,1,1,1,1]):
        super().__init__()
        self.z_dim = z_dim
        self.target_channels = target_channels
        self.start_channels = start_channels # Channels after initial dense layer

        # 1. Initial dense layer and reshape
        self.fc = nn.Linear(z_dim, self.start_channels * 1 * 1)

        # 2. Sequence of Upsampling ResBlocks
        layers = []
        current_channels = self.start_channels
        # Define channel progression (example, customize as needed)
        channel_config = [256, 128, 64, 32, 16] # Channels *after* each stage

        # Stage 1: 1x1 -> 2x2 (Special stride/padding)
        out_ch = channel_config[0]
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1, output_padding=0)) # 1->2 requires careful params or adjustment
        # For stride=2: output = (input-1)*stride + kernel - 2*padding + output_padding
        # Output for k=4,s=2,p=1,op=0: (1-1)*2 + 4 - 2*1 + 0 = 2. Seems okay.
        current_channels = out_ch

        # Stage 2: 2x2 -> 4x4
        out_ch = channel_config[0] # Keep 256
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 3: 4x4 -> 8x8
        out_ch = channel_config[1] # 128
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 4: 8x8 -> 16x16
        out_ch = channel_config[2] # 64
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 5: 16x16 -> 32x32
        out_ch = channel_config[3] # 32
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        # Stage 6: 32x32 -> 64x64
        out_ch = channel_config[4] # 16
        layers.append(UpsampleResBlock(current_channels, out_ch, stride=2, kernel_size=4, padding=1))
        current_channels = out_ch

        self.res_blocks = nn.Sequential(*layers)

        # 3. Final layers to get to target channels
        self.final_bn = nn.BatchNorm2d(current_channels)
        self.final_relu = nn.ReLU(inplace=True)
        # Use a standard 3x3 conv for final feature mapping
        self.final_conv = nn.Conv2d(current_channels, self.target_channels, kernel_size=3, stride=1, padding=1, bias=True) # Can use bias here

        print(f"ConvDecoderResNet initialized for z_dim = {self.z_dim}")
        print(f"Upsampling blocks built.")
        print(f"Output shape (logits): [B, {self.target_channels}, 64, 64]")


    def forward(self, z):
        """
        Decodes the latent vector z into image logits.
        """
        # 1. Project and reshape
        z_flat = z.view(z.shape[0], -1)
        if z_flat.shape[1] != self.z_dim:
            raise ValueError(f"Input z dimension ({z_flat.shape[1]}) does not match decoder's expected z_dim ({self.z_dim})")

        x = self.fc(z_flat)
        x = x.view(x.shape[0], self.start_channels, 1, 1) # Reshape B x C x 1 x 1

        # 2. Pass through Residual Blocks
        x = self.res_blocks(x) # Output shape B x final_block_channels x 64 x 64

        # 3. Apply final BN, ReLU, and Conv
        x = self.final_bn(x)
        x = self.final_relu(x)
        logits = self.final_conv(x) # Final layer, no activation

        # Return LOGITS
        return logits


# --- Example Usage and Notes ---

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# z_dim = 64
# target_channels = 3

# # Instantiate
# decoder = ConvDecoderResNet(z_dim=z_dim, target_channels=target_channels).to(device)

# # Create dummy input
# dummy_z = torch.randn(4, z_dim).to(device) # Batch size 4

# # Forward pass
# with torch.no_grad():
#     output_logits = decoder(dummy_z)

# print(f"Output logits shape: {output_logits.shape}") # Should be [4, 3, 64, 64]

# # Remember to apply sigmoid OUTSIDE for visualization or standard BCE loss
# output_image = torch.sigmoid(output_logits)
# print(f"Output image shape after sigmoid: {output_image.shape}")

# # Recommended Loss:
# # L_rec = F.binary_cross_entropy_with_logits(output_logits, target_images, ...)

# loss_functions

# KL_divergence_loss

In [17]:
def kl_divergence_conditional(mu, logvar, u):
    """
    Computes the KL divergence between q(z|x,u) ~ N(mu, sigma^2)
    and the conditional prior p(z|u) ~ N(u, I).
    
    Args:
        mu (torch.Tensor): Mean from the encoder, shape [B, z_dim].
        logvar (torch.Tensor): Log variance from the encoder, shape [B, z_dim].
        u (torch.Tensor): Conditional prior mean (e.g., from annotations), shape [B, z_dim].
                        If u originally has lower dimension, you must map it to z_dim.
    
    Returns:
        torch.Tensor: KL divergence loss averaged over the batch.
    """
    # Convert logvar to variance
    sigma_sq = torch.exp(logvar)
    # Compute KL divergence per latent dimension:
    # KL = 0.5 * [sigma_sq + (mu - u)^2 - 1 - logvar]
    kl = 0.5 * torch.sum(sigma_sq + (mu - u)**2 - 1 - logvar, dim=1)
    return torch.mean(kl)



def kl_Divergence_conditional(mu_q, log_var_q, mu_p, log_var_p):
    var_q = log_var_q.exp()
    var_p = log_var_p.exp()
    kl = 0.5 * (
        (var_q / var_p) + 
        ((mu_p - mu_q).pow(2) / var_p) - 
        1 + 
        log_var_p - log_var_q
    ).sum(1)
    return kl.mean()

# Reconstruction_Loss

In [18]:
def reconstruction_loss1(x, x_reconstructed):
    """
    Computes the reconstruction loss between the original images and the reconstructed images.
    
    Args:
        x (torch.Tensor): Original images with shape [B, C, H, W]. Expected to be in [0, 1].
        x_reconstructed (torch.Tensor): Reconstructed images with shape [B, C, H, W]. Expected to be in [0, 1] if using Sigmoid.
    
    Returns:
        torch.Tensor: The scalar reconstruction loss averaged over the batch.
    """
    # Compute Binary Cross-Entropy loss over all pixels and average over the batch.
    return F.binary_cross_entropy(x_reconstructed, x, reduction='mean')

In [19]:
def reconstruction_loss(target_images, reconstructed_logits):
     """ Calculates BCE loss between images in [0,1] and decoder logits. """
     # Ensure target_images are indeed [0,1]
     # You might add assert statements here in debug mode
     # assert target_images.min() >= 0.0 and target_images.max() <= 1.0

     # reduction='mean' averages over all elements (pixels, channels, batch)
     # Often VAEs sum over features and average over batch:
     loss_per_element = F.binary_cross_entropy_with_logits(reconstructed_logits, target_images, reduction='none')
     loss_summed = loss_per_element.view(loss_per_element.size(0), -1).sum(1) # Sum over all dims except batch
     return loss_summed.mean() # Average over batch

# Causal_Loss

In [20]:
def reweighted_loss(factual_loss, counterfactual_loss, r, k, gamma=1.0):
    reweight_factor = (1 - r) / (1 - r ** k)
    loss = reweight_factor * (factual_loss - gamma * counterfactual_loss)
    return loss
def compute_bce_loss(logits, annotations):
    """
    Computes the binary cross-entropy loss (with logits) between logits and annotations.
    
    Args:
        logits (torch.Tensor): Classifier output logits, shape [B, num_attributes].
        annotations (torch.Tensor): Ground truth annotations, shape [B, num_attributes]. 
                                    Expected to be in [0, 1].
    
    Returns:
        torch.Tensor: Scalar loss (mean over the batch).
    """
    loss = F.binary_cross_entropy_with_logits(logits, annotations, reduction='mean')
    return loss

# Scm_Consistency_Loss

In [21]:
def scm_consistency_loss(z_s, f_out, eps_s, kappa=0.1):
    """
    Computes the SCM consistency loss with threshold kappa:
        L_f = max(0, MSE(z_s, f_out + eps_s) - kappa)
    
    Args:
        z_s (torch.Tensor): Nonlinear causal representation from SCM, shape [B, causal_dim].
        f_out (torch.Tensor): Output of applying the nonlinear function f to A^T * z_s, shape [B, causal_dim].
        eps_s (torch.Tensor): Exogenous causal factors from the encoder (the causal part), shape [B, causal_dim].
        kappa (float): Threshold constant. Loss is zero if MSE is below kappa. Default is 0.1.
    
    Returns:
        torch.Tensor: SCM consistency loss (scalar).
    """
    # Compute the mean squared error between z_s and the sum (f_out + eps_s)
    mse = F.mse_loss(z_s, f_out + eps_s, reduction='mean')
    # Use a hinge (ReLU) to enforce the threshold: loss is (mse - kappa) if mse > kappa, else 0.
    loss = F.relu(mse - kappa)
    return loss

# Training Loop

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Hyperparameters and device settings
num_epochs = 150
lr = 1e-5
z_dim = 64
causal_dim = 4
full_annotation_dim = 40      # Full annotation vector dimension for encoder and annotation mapper
causal_annotation_dim = 4     # Causal annotation subset for classifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Beta annealing schedule for KL divergence loss
start_anneal_epoch = 80
end_anneal_epoch = 150
target_beta_kl = 0.01  # Target beta value for KL loss
initial_beta_kl = 0.0001

# Instantiate modules with updated dimensions
encoder = ConvEncoderConditionalResNet(z_dim_target=z_dim, annotation_dim=full_annotation_dim).to(device)
decoder = ConvDecoderResNet(z_dim=z_dim, target_channels=3).to(device)
classifier = CausalMLPClassifier(latent_dim=z_dim, annotation_dim=causal_annotation_dim,
                                 hidden_dim=256, output_dim=causal_annotation_dim).to(device)
nonlinear_scm = NonlinearSCM(input_dim=causal_dim, hidden_dim=16, output_dim=causal_dim).to(device)
fusion_module = LearnedFusion(es_dim=causal_dim, er_dim=z_dim - causal_dim, output_dim=z_dim).to(device)
annotation_mapper = nn.Linear(full_annotation_dim, z_dim).to(device)

# Optimizer: update all parameters together
optimizer = optim.Adam(
    list(encoder.parameters()) +
    list(decoder.parameters()) +
    list(classifier.parameters()) +
    list(nonlinear_scm.parameters()) +
    list(fusion_module.parameters()) +
    list(annotation_mapper.parameters()),
    lr=lr
)

# Loss hyperparameters
r = 0.5
k = 10
gamma = 1.0
weight_causal = 1.0
weight_scm = 1.0
kappa = 0.1

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (images, full_annotations, causal_annotations) in enumerate(dataloader):
        if images is None or full_annotations is None or causal_annotations is None:
            continue

        images = images.to(device)                        # [B, 3, 64, 64]
        full_annotations = full_annotations.to(device)    # [B, 40]
        causal_annotations = causal_annotations.to(device)  # [B, 4]

        optimizer.zero_grad()

        # Encoder: use images and full annotations to get latent parameters
        mu, logvar = encoder(images, full_annotations)    # [B, 64]
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std                                  # [B, 64]

        # Partition z into causal part (Es) and residual (Er)
        es = z[:, :causal_dim]                              # [B, 4]
        er = z[:, causal_dim:]                              # [B, 60] if z_dim=64

        # Compute causal matrix A using classifier outputs; classifier uses causal_annotations (4-dim)
        A_est = compute_causal_matrix(z, causal_annotations, classifier)  # [4, 4]
        if epoch == 0 and batch_idx == 0:
            print(f"\n--- A_estimated (Epoch: {epoch+1}, Batch: {batch_idx}) ---")
            print(A_est)

        if torch.isnan(A_est).any() or torch.isinf(A_est).any():
            raise ValueError("A_estimated contains non-finite values")

        # Linear SCM: transform Es using A_est
        z_l = linear_scm(es, A_est)                         # [B, 4]

        # Nonlinear SCM: map z_l to a nonlinear causal representation z_s
        z_s = nonlinear_scm(z_l)                            # [B, 4]

        # Fusion: combine z_s and Er via the fusion module to get final latent representation z_c
        z_c = fusion_module(z_s, er)                        # [B, 64]

        # Decoder: reconstruct images from z_c
        x_reconstructed_logits = decoder(z_c)# [B, 3, 64, 64]

        x_reconstructed = torch.sigmoid(x_reconstructed_logits)
        # Compute Reconstruction Loss
        L_rec = reconstruction_loss1(images, x_reconstructed)

        # For KL divergence, map full annotations (40 dims) to latent dimension using annotation_mapper.
        u_mapped = annotation_mapper(full_annotations)      # [B, 64]
        L_kl = kl_divergence_conditional(mu, logvar, u_mapped)

        # Compute Causal Loss:
        # Factual: classifier with (z, causal_annotations)
        factual_logits_full = classifier(z, causal_annotations)  # [B,4]

        L_causal_terms = []
        for j in range(causal_annotation_dim):
           # isolate uj and its factual logits
          uj = causal_annotations[:, j]              # [B]
          factual_logits_j = factual_logits_full[:, j]# [B]
          factual_loss_j = F.binary_cross_entropy_with_logits(factual_logits_j, uj, reduction='mean')

               # build counterfactual losses for this target j
    cf_losses_j = []
    for i in range(causal_annotation_dim):
        u_cf = causal_annotations.clone()
        u_cf[:, i] = 0.0
        cf_logits_full = classifier(z, u_cf)       # [B,4]
        cf_logits_i_j = cf_logits_full[:, j]       # [B]
        cf_loss_i_j = F.binary_cross_entropy_with_logits(cf_logits_i_j, uj, reduction='mean')
        cf_losses_j.append(cf_loss_i_j)
    counterfactual_loss_j = sum(cf_losses_j) / causal_annotation_dim

    # compute k_j (number of positives of uj in this batch)
    k_j = (uj == 1).sum().item()
    epsilon_k = 1e-8
    denom = 1 - r**(k_j + epsilon_k)
    if abs(denom) < epsilon_k:
        factor_j = 1.0
    else:
        factor_j = (1 - r) / denom

    # one causal‐loss term per j
    L_causal_terms.append(factor_j * (factual_loss_j - gamma * counterfactual_loss_j))

# final causal loss = mean over dimensions
L_causal = sum(L_causal_terms) / causal_annotation_dim

        # Compute SCM Consistency Loss:
        A_transpose = A_est.t()
        z_s_transformed = torch.matmul(z_s, A_transpose)     # [B, 4]
        f_out = nonlinear_scm(z_s_transformed)               # [B, 4]
        L_scm = scm_consistency_loss(z_s, f_out, es, kappa)

        # Beta annealing schedule for KL divergence:
        if epoch >= start_anneal_epoch:
            progress = (epoch - start_anneal_epoch) / (end_anneal_epoch - start_anneal_epoch)
            current_beta_kl = initial_beta_kl + progress * (target_beta_kl - initial_beta_kl)
            current_beta_kl = min(target_beta_kl, current_beta_kl)
        else:
            current_beta_kl = initial_beta_kl

        # Total Loss with annealed beta for KL divergence
        total_loss = L_rec + current_beta_kl * L_kl + weight_causal * L_causal + weight_scm * L_scm
        total_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: Total Loss: {total_loss.item():.4f} | L_rec: {L_rec.item():.4f} | "
          f"L_kl: {L_kl.item():.4f} | L_causal: {L_causal.item():.4f} | L_scm: {L_scm.item():.4f}")

    # GPU memory usage tracking at the end of each epoch
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / 1024**2  # Convert to MB
        reserved = torch.cuda.memory_reserved(0) / 1024**2
        print(f"[GPU] Memory Allocated: {allocated:.2f} MB | Memory Reserved: {reserved:.2f} MB")

    # Display 5 sample original and reconstructed images inline using matplotlib
    num_samples_to_display = 5
    orig_samples = images[:num_samples_to_display].detach().cpu()
    recon_samples = x_reconstructed[:num_samples_to_display].detach().cpu()

    grid_orig = make_grid(orig_samples, nrow=num_samples_to_display, normalize=True, scale_each=True)
    grid_recon = make_grid(recon_samples, nrow=num_samples_to_display, normalize=True, scale_each=True)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Original Images")
    plt.imshow(grid_orig.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.title("Reconstructed Images")
    plt.imshow(grid_recon.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 158)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Hyperparameters and device settings
num_epochs = 150
lr = 1e-5
z_dim = 64
causal_dim = 4
full_annotation_dim = 40      # Full annotation vector dimension for encoder and annotation mapper
causal_annotation_dim = 4     # Causal annotation subset for classifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Beta annealing schedule for KL divergence loss
start_anneal_epoch = 80
end_anneal_epoch = 150
target_beta_kl = 0.01  # Target beta value for KL loss
initial_beta_kl = 0.0001

# Instantiate modules with updated dimensions
encoder = ConvEncoderConditionalResNet(z_dim_target=z_dim, annotation_dim=full_annotation_dim).to(device)
decoder = ConvDecoderResNet(z_dim=z_dim, target_channels=3).to(device)
classifier = CausalMLPClassifier(latent_dim=z_dim, annotation_dim=causal_annotation_dim,
                                 hidden_dim=256, output_dim=causal_annotation_dim).to(device)
nonlinear_scm = NonlinearSCM(input_dim=causal_dim, hidden_dim=16, output_dim=causal_dim).to(device)
fusion_module = LearnedFusion(es_dim=causal_dim, er_dim=z_dim - causal_dim, output_dim=z_dim).to(device)
annotation_mapper = nn.Linear(full_annotation_dim, z_dim).to(device)

# Optimizer: update all parameters together
optimizer = optim.Adam(
    list(encoder.parameters()) +
    list(decoder.parameters()) +
    list(classifier.parameters()) +
    list(nonlinear_scm.parameters()) +
    list(fusion_module.parameters()) +
    list(annotation_mapper.parameters()),
    lr=lr
)

# Loss hyperparameters
r = 0.5
k = 10
gamma = 1.0
weight_causal = 1.0
weight_scm = 1.0
kappa = 0.1

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (images, full_annotations, causal_annotations) in enumerate(dataloader):
        if images is None or full_annotations is None or causal_annotations is None:
            continue

        images = images.to(device)                        # [B, 3, 64, 64]
        full_annotations = full_annotations.to(device)    # [B, 40]
        causal_annotations = causal_annotations.to(device)  # [B, 4]

        optimizer.zero_grad()

        # Encoder: use images and full annotations to get latent parameters
        mu, logvar = encoder(images, full_annotations)    # [B, 64]
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std                                  # [B, 64]

        # Partition z into causal part (Es) and residual (Er)
        es = z[:, :causal_dim]                              # [B, 4]
        er = z[:, causal_dim:]                              # [B, 60] if z_dim=64

        # Compute causal matrix A using classifier outputs; classifier uses causal_annotations (4-dim)
        A_est = compute_causal_matrix(z, causal_annotations, classifier)  # [4, 4]
        if epoch == 0 and batch_idx == 0:
            print(f"\n--- A_estimated (Epoch: {epoch+1}, Batch: {batch_idx}) ---")
            print(A_est)

        if torch.isnan(A_est).any() or torch.isinf(A_est).any():
            raise ValueError("A_estimated contains non-finite values")

        # Linear SCM: transform Es using A_est
        z_l = linear_scm(es, A_est)                         # [B, 4]

        # Nonlinear SCM: map z_l to a nonlinear causal representation z_s
        z_s = nonlinear_scm(z_l)                            # [B, 4]

        # Fusion: combine z_s and Er via the fusion module to get final latent representation z_c
        z_c = fusion_module(z_s, er)                        # [B, 64]

        # Decoder: reconstruct images from z_c
        x_reconstructed_logits = decoder(z_c)# [B, 3, 64, 64]

        x_reconstructed = torch.sigmoid(x_reconstructed_logits)
        # Compute Reconstruction Loss
        L_rec = reconstruction_loss1(images, x_reconstructed)

        # For KL divergence, map full annotations (40 dims) to latent dimension using annotation_mapper.
        u_mapped = annotation_mapper(full_annotations)      # [B, 64]
        L_kl = kl_divergence_conditional(mu, logvar, u_mapped)
        
        # Compute Causal Loss:
        # Factual: classifier with (z, causal_annotations)
        factual_logits = classifier(z, causal_annotations)   # [B, 4]
        factual_loss = F.binary_cross_entropy_with_logits(factual_logits, causal_annotations, reduction='mean')
        cf_losses = []
        for i in range(causal_annotations.shape[1]):
            u_cf = causal_annotations.clone()
            u_cf[:, i] = 0.0
            cf_logits = classifier(z, u_cf)
            cf_loss = F.binary_cross_entropy_with_logits(cf_logits, causal_annotations, reduction='mean')
            cf_losses.append(cf_loss)
        counterfactual_loss = sum(cf_losses) / causal_annotations.shape[1]
        L_causal = reweighted_loss(factual_loss, counterfactual_loss, r, k, gamma)

        # Compute SCM Consistency Loss:
        A_transpose = A_est.t()
        z_s_transformed = torch.matmul(z_s, A_transpose)     # [B, 4]
        f_out = nonlinear_scm(z_s_transformed)               # [B, 4]
        L_scm = scm_consistency_loss(z_s, f_out, es, kappa)

        # Beta annealing schedule for KL divergence:
        if epoch >= start_anneal_epoch:
            progress = (epoch - start_anneal_epoch) / (end_anneal_epoch - start_anneal_epoch)
            current_beta_kl = initial_beta_kl + progress * (target_beta_kl - initial_beta_kl)
            current_beta_kl = min(target_beta_kl, current_beta_kl)
        else:
            current_beta_kl = initial_beta_kl

        # Total Loss with annealed beta for KL divergence
        total_loss = L_rec + current_beta_kl * L_kl + weight_causal * L_causal + weight_scm * L_scm
        total_loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: Total Loss: {total_loss.item():.4f} | L_rec: {L_rec.item():.4f} | "
          f"L_kl: {L_kl.item():.4f} | L_causal: {L_causal.item():.4f} | L_scm: {L_scm.item():.4f}")

    # GPU memory usage tracking at the end of each epoch
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated(0) / 1024**2  # Convert to MB
        reserved = torch.cuda.memory_reserved(0) / 1024**2
        print(f"[GPU] Memory Allocated: {allocated:.2f} MB | Memory Reserved: {reserved:.2f} MB")

    # Display 5 sample original and reconstructed images inline using matplotlib
    num_samples_to_display = 5
    orig_samples = images[:num_samples_to_display].detach().cpu()
    recon_samples = x_reconstructed[:num_samples_to_display].detach().cpu()

    grid_orig = make_grid(orig_samples, nrow=num_samples_to_display, normalize=True, scale_each=True)
    grid_recon = make_grid(recon_samples, nrow=num_samples_to_display, normalize=True, scale_each=True)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Original Images")
    plt.imshow(grid_orig.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.title("Reconstructed Images")
    plt.imshow(grid_recon.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()