<a href="https://colab.research.google.com/github/your-username/flux-transparent-png/blob/main/colab/train_transparent_png.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flux.1 Transparent PNG Training

This notebook trains a modified VAE on transparent PNG images using Flux.1 dev. The trained VAE can then be used to generate new PNG images with transparent backgrounds.

## Setup

First, mount Google Drive (when running in Colab), clone the repository, and install the required dependencies.

In [None]:
# Check if running in Google Colab
import sys
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Colab: {IN_COLAB}")

# Mount Google Drive if in Colab
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted at /content/drive")

In [None]:
# Clone the repository
!git clone https://github.com/your-username/flux-transparent-png.git
%cd flux-transparent-png/python

In [None]:
# Install dependencies
!pip install torch torchvision diffusers transformers pillow numpy matplotlib tqdm
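
Before moving on, it can help to confirm that the packages are actually importable from the active kernel. The sketch below uses only the standard library. `missing_packages` is a hypothetical helper, not part of the repository, and note that `pillow` is imported under the name `PIL`.

```python
import importlib.util

# Import names for the packages installed above (pillow is imported as PIL).
REQUIRED = ["torch", "torchvision", "diffusers", "transformers",
            "PIL", "numpy", "matplotlib", "tqdm"]

def missing_packages(names):
    """Return the top-level module names that cannot be found (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
if missing:
    print(f"Missing packages: {', '.join(missing)}")
else:
    print("All required packages are importable.")
```

If anything is reported missing, re-running the install cell (and restarting the runtime if Colab asks for it) is usually the first thing to try.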

## Configuration

Set up the training configuration parameters.

In [None]:
# Configuration parameters
DATA_DIR = "/content/drive/MyDrive/SD-Data/TrainData/4000_PNG/TEST"
OUTPUT_DIR = "/content/drive/MyDrive/VAE-DECODER"
CHECKPOINT_DIR = "/content/drive/MyDrive/VAE-DECODER/checkpoints"
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
ALPHA_WEIGHT = 2.0
IMAGE_SIZE = 512

# Create output directories (os.makedirs also handles paths that contain spaces)
import os
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
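
A quick check that the data directory exists and actually contains PNG files can save a failed training run later. This is a minimal sketch: `count_png_files` is a hypothetical helper, and both the case-insensitive `.png` match and the non-recursive listing are assumptions about how the dataset files are laid out.

```python
from pathlib import Path

def count_png_files(data_dir):
    """Count files with a .png extension (any case) directly inside data_dir."""
    root = Path(data_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Data directory not found: {root}")
    return sum(1 for p in root.iterdir()
               if p.is_file() and p.suffix.lower() == ".png")

# Example with the notebook's DATA_DIR (uncomment after Drive is mounted):
# print(f"{count_png_files(DATA_DIR)} PNG files in {DATA_DIR}")
```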

## Explore Training Data

Let's load a few samples and display their RGB and alpha channels side by side to confirm the dataset loads correctly.

In [None]:
# Import the dataset class
from train_transparent_png import TransparentPNGDataset
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import DataLoader

# Create dataset
dataset = TransparentPNGDataset(DATA_DIR, image_size=IMAGE_SIZE)
print(f"Found {len(dataset)} images in {DATA_DIR}")

# Create dataloader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Visualize a few samples
fig, axes = plt.subplots(3, 2, figsize=(12, 18))

for i, batch in enumerate(dataloader):
    if i >= 3:
        break
        
    # Get image
    image = batch['image'][0]  # Remove batch dimension
    
    # Convert to numpy for visualization
    image_np = image.numpy()
    
    # Scale from [-1, 1] to [0, 1]
    image_np = (image_np + 1) / 2
    
    # Clip to valid range
    image_np = np.clip(image_np, 0, 1)
    
    # RGB channels
    axes[i, 0].imshow(np.transpose(image_np[:3], (1, 2, 0)))
    axes[i, 0].set_title(f"Sample {i+1} - RGB Channels")
    axes[i, 0].axis('off')
    
    # Alpha channel
    axes[i, 1].imshow(image_np[3], cmap='gray')
    axes[i, 1].set_title(f"Sample {i+1} - Alpha Channel")
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()
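
Because training depends on a real alpha channel, it can also be worth confirming that the source files store one. The sketch below reads the color type from a PNG's IHDR header using only the standard library. `png_has_alpha_channel` is a hypothetical helper, not part of the repository, and palette images that carry transparency in a `tRNS` chunk are deliberately reported as having no alpha channel.

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# PNG color types that include a full alpha channel:
# 4 = grayscale + alpha, 6 = RGB + alpha (RGBA).
ALPHA_COLOR_TYPES = {4, 6}

def png_has_alpha_channel(path):
    """Return True if the PNG at `path` declares an alpha channel in IHDR."""
    with open(path, "rb") as f:
        # 8-byte signature, 4-byte chunk length, 4-byte chunk type,
        # then width (4), height (4), bit depth (1), color type (1).
        header = f.read(26)
    if len(header) < 26 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"Not a valid PNG file: {path}")
    _width, _height, _bit_depth, color_type = struct.unpack(">IIBB", header[16:26])
    return color_type in ALPHA_COLOR_TYPES
```

Files for which this returns `False` may be worth inspecting before they are used for training.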

## Train the Model

Now let's train the transparent VAE model.

In [None]:
# Run the training
!python train_transparent_png.py \
  --data_dir="{DATA_DIR}" \
  --output_dir="{OUTPUT_DIR}" \
  --checkpoint_dir="{CHECKPOINT_DIR}" \
  --batch_size={BATCH_SIZE} \
  --num_epochs={NUM_EPOCHS} \
  --learning_rate={LEARNING_RATE} \
  --alpha_weight={ALPHA_WEIGHT} \
  --image_size={IMAGE_SIZE}
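
The shell-magic call above relies on IPython interpolating the variables into the command line. An equivalent that builds the argument list in Python, so that paths containing spaces need no extra quoting, could look like the sketch below. The flags are the ones used above; `build_train_command` is a hypothetical helper, not part of the repository.

```python
import shlex
import sys

def build_train_command(data_dir, output_dir, checkpoint_dir, batch_size,
                        num_epochs, learning_rate, alpha_weight, image_size):
    """Return the argv list for train_transparent_png.py (hypothetical helper)."""
    return [
        sys.executable, "train_transparent_png.py",
        f"--data_dir={data_dir}",
        f"--output_dir={output_dir}",
        f"--checkpoint_dir={checkpoint_dir}",
        f"--batch_size={batch_size}",
        f"--num_epochs={num_epochs}",
        f"--learning_rate={learning_rate}",
        f"--alpha_weight={alpha_weight}",
        f"--image_size={image_size}",
    ]

# Same values as the configuration cell.
cmd = build_train_command(
    data_dir="/content/drive/MyDrive/SD-Data/TrainData/4000_PNG/TEST",
    output_dir="/content/drive/MyDrive/VAE-DECODER",
    checkpoint_dir="/content/drive/MyDrive/VAE-DECODER/checkpoints",
    batch_size=8,
    num_epochs=100,
    learning_rate=1e-4,
    alpha_weight=2.0,
    image_size=512,
)
print(shlex.join(cmd))  # printable form of the command, useful for logging
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would then launch training and raise an exception if the script exits with an error.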

## Save the Trained Models

After training, let's save the VAE and decoder models.

In [None]:
# Save the models
!python save_vae_decoder.py \
  --checkpoint_dir="{CHECKPOINT_DIR}" \
  --output_dir="{OUTPUT_DIR}" \
  --vae_filename="transparent_vae.pt" \
  --decoder_filename="transparent_decoder.pt" \
  --verify
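
To confirm that both files were written, their sizes can be listed afterwards. This sketch assumes the script writes the two filenames passed above directly into the output directory; `saved_model_sizes` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

def saved_model_sizes(output_dir,
                      filenames=("transparent_vae.pt", "transparent_decoder.pt")):
    """Map each expected filename to its size in bytes, or None if it is missing."""
    sizes = {}
    for name in filenames:
        path = Path(output_dir) / name
        sizes[name] = path.stat().st_size if path.is_file() else None
    return sizes

# Example (uncomment in the notebook once the save step has finished):
# for name, size in saved_model_sizes(OUTPUT_DIR).items():
#     print(name, "missing" if size is None else f"{size / 1e6:.1f} MB")
```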

## Visualize Training Results

Let's compare a few original images with their reconstructions from the most recent visualization directory.

In [None]:
import glob
from PIL import Image
import matplotlib.pyplot as plt

# Find visualization files
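# Sort numerically by epoch so that epoch_10 comes after epoch_9
# (assumes each directory name ends in an integer epoch number)
vis_dirs = sorted(
    glob.glob(f"{OUTPUT_DIR}/visualizations_epoch_*"),
    key=lambda p: int(p.rsplit("_", 1)[-1]),
)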
if vis_dirs:
    latest_vis_dir = vis_dirs[-1]
    print(f"Showing visualizations from {latest_vis_dir}")
    
    # Get original and reconstructed images
    original_files = sorted(glob.glob(f"{latest_vis_dir}/original_*.png"))
    recon_files = sorted(glob.glob(f"{latest_vis_dir}/reconstructed_*.png"))
    
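    # Display up to three original/reconstruction pairs
    n = min(3, len(original_files), len(recon_files))
    # squeeze=False keeps `axes` two-dimensional even when n == 1
    fig, axes = plt.subplots(n, 2, figsize=(12, 6*n), squeeze=False)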
    
    for i in range(n):
        # Original
        original = Image.open(original_files[i])
        axes[i, 0].imshow(original)
        axes[i, 0].set_title(f"Original {i+1}")
        axes[i, 0].axis('off')
        
        # Reconstructed
        recon = Image.open(recon_files[i])
        axes[i, 1].imshow(recon)
        axes[i, 1].set_title(f"Reconstructed {i+1}")
        axes[i, 1].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("No visualization files found.")

## Next Steps

Now that you have trained the VAE and decoder models, you can use them to generate transparent PNG images. See the `generate_transparent_png.ipynb` notebook for details on how to generate images using your trained models.