<a href="https://colab.research.google.com/github/taras-musakovskyi/colab-jupyter-fish-models/blob/main/Species_Name_Conditioned_Single_VAE_for_Multi_Scale_Fish_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
Species-Conditioned Single VAE for Multi-Scale Fish Dataset
Dataset: 9,498 fish crops stored in 3 size folders (small/medium/wide), labelled with 6 species categories
Training: Google Colab A100 40GB
Purpose: Part 5 - Fish inpainting into apartment scenes
"""

# ============================================================================
# CELL 1: Environment Setup & Installation
# ============================================================================
!pip install -q torch torchvision --upgrade
!pip install -q pillow numpy tqdm matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import zipfile
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import numpy as np
from dataclasses import dataclass
from typing import Optional
import time
import gc
import matplotlib.pyplot as plt

# Report the runtime environment so the log records what the notebook ran on.
cuda_ready = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")

# GPU details are only queried when a CUDA device is actually present.
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_total_bytes / 1024**3:.2f} GB")

PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
VRAM: 39.56 GB


In [None]:
# ============================================================================
# CELL 2: Google Drive Authentication (for uploading results)
# ============================================================================
# Google Drive client setup. The names created here (`drive_service`,
# `GDRIVE_RESULTS_FOLDER_ID`, `MediaIoBaseUpload`, `io`) are used later by
# `upload_model_to_gdrive`.
from google.colab import auth
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseUpload
import io

print("🔐 Authenticating with Google Drive for results upload...")

# Authenticate first, then fetch credentials.
# NOTE(review): presumably `authenticate_user()` runs the interactive Colab
# sign-in and must come before `default()` - confirm before reordering.
auth.authenticate_user()
from google.auth import default
creds, _ = default()  # second return value is not needed here

# Drive API v3 client bound to the credentials obtained above.
drive_service = build('drive', 'v3', credentials=creds)

print("✓ Authentication successful!")

# Set your results upload folder ID (the Drive folder that receives checkpoints).
GDRIVE_RESULTS_FOLDER_ID = "1aIbv7Jgad_BAcJpwTQp9sYn4_2oZToXp"  # VAE - for inpaint - size conditioned

# Warn if the ID still looks like an unedited template placeholder
# (contains "YOUR_" or "_HERE"); otherwise confirm the upload target.
if "YOUR_" in GDRIVE_RESULTS_FOLDER_ID or "_HERE" in GDRIVE_RESULTS_FOLDER_ID:
    print("\n⚠️  WARNING: Set GDRIVE_RESULTS_FOLDER_ID to upload results!")
else:
    print(f"✓ Results will be uploaded to folder ID: {GDRIVE_RESULTS_FOLDER_ID}")

🔐 Authenticating with Google Drive for results upload...
✓ Authentication successful!
✓ Results will be uploaded to folder ID: 1aIbv7Jgad_BAcJpwTQp9sYn4_2oZToXp


In [None]:
@dataclass
class VAEConfig:
    """
    Hyper-parameters, schedule and paths for the species-conditioned VAE.

    `hidden_dims` and `species_categories` default to None and are filled in
    by `__post_init__`, so every instance gets its own fresh list/dict.

    Raises:
        ValueError: if `hidden_dims` is empty, or `image_size` is not
            divisible by 2 ** len(hidden_dims). Each encoder stage halves the
            spatial size, so any other combination would make the flattened
            feature size disagree with the encoder output.
    """
    # Model architecture
    image_size: int = 256
    latent_dim: int = 256
    hidden_dims: Optional[list] = None   # channels per encoder stage; default set below
    species_embed_dim: int = 128

    # Training settings
    num_epochs: int = 250
    batch_size: int = 32
    learning_rate: float = 1e-4
    beta: float = 1.0                    # weight of the KL term in the loss

    # Data settings
    num_workers: int = 0

    # Checkpointing
    save_every_n_epochs: int = 50
    validation_every_n_epochs: int = 50
    upload_after_epoch: int = 90         # no Drive uploads before this epoch
    upload_interval_epochs: int = 50     # minimum epochs between Drive uploads

    # Paths
    output_dir: str = "/content/species_vae_output"
    dataset_dir: str = "/content/fish_dataset"

    # Species name -> integer label; default set below
    species_categories: Optional[dict] = None

    def __post_init__(self):
        if self.hidden_dims is None:
            self.hidden_dims = [32, 64, 128, 256, 512]

        if self.species_categories is None:
            self.species_categories = {
                'guppy': 0,  # Combined guppy_female + guppy_male
                'gold_molly': 1,
                'gold_fish': 2,
                'ancistrus': 3,
                'black_molly': 4,
                'dalmatian_molly': 5
            }

        # Fail early with a clear message instead of a shape error deep
        # inside the model.
        if len(self.hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one channel size")

        downsample_factor = 2 ** len(self.hidden_dims)
        if self.image_size % downsample_factor != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"2 ** len(hidden_dims) = {downsample_factor}"
            )

# Instantiate the default configuration used by every later cell.
config = VAEConfig()

# Build the whole summary first, then emit it with a single print call.
config_summary_lines = [
    "="*60,
    "SPECIES-CONDITIONED VAE TRAINING CONFIGURATION",
    "="*60,
    f"Image Size: {config.image_size}×{config.image_size}",
    f"Latent Dim: {config.latent_dim}",
    f"Species Embedding Dim: {config.species_embed_dim}",
    f"Epochs: {config.num_epochs}",
    f"Batch Size: {config.batch_size}",
    f"Learning Rate: {config.learning_rate}",
    f"Beta (KL weight): {config.beta}",
    f"\nSpecies Categories: {config.species_categories}",
    "="*60,
]
print("\n".join(config_summary_lines))

SPECIES-CONDITIONED VAE TRAINING CONFIGURATION
Image Size: 256×256
Latent Dim: 256
Species Embedding Dim: 128
Epochs: 250
Batch Size: 32
Learning Rate: 0.0001
Beta (KL weight): 1.0

Species Categories: {'guppy': 0, 'gold_molly': 1, 'gold_fish': 2, 'ancistrus': 3, 'black_molly': 4, 'dalmatian_molly': 5}


In [None]:
# ============================================================================
# CELL 4: Upload Dataset from Local Machine
# ============================================================================
from google.colab import files

print("📦 Upload your fish dataset ZIP file")
print("   Expected structure inside ZIP:")
print("   - small/")
print("   - medium/")
print("   - wide/")
print()

uploaded = files.upload()

# A cancelled upload returns nothing; fail with a clear message instead of an
# IndexError on the lookup below.
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run this cell and select the dataset ZIP.")

# Get the uploaded filename (only the first uploaded file is used)
zip_filename = next(iter(uploaded))
print(f"\n✓ Uploaded: {zip_filename}")

# Extract dataset
print("📂 Extracting dataset...")
os.makedirs(config.dataset_dir, exist_ok=True)

with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(config.dataset_dir)

print("✓ Dataset extracted!")

# Verify folder structure. The extension check is case-insensitive so these
# counts agree with what SpeciesConditionedFishDataset will actually load.
image_extensions = ('.jpg', '.jpeg', '.png')
dataset_stats = {}
for size_folder in ['small', 'medium', 'wide']:
    folder_path = os.path.join(config.dataset_dir, size_folder)
    if os.path.exists(folder_path):
        num_images = sum(
            1 for f in os.listdir(folder_path)
            if f.lower().endswith(image_extensions)
        )
        dataset_stats[size_folder] = num_images
        print(f"  ✓ {size_folder}: {num_images} images")
    else:
        print(f"  ⚠️  {size_folder}: NOT FOUND")
        dataset_stats[size_folder] = 0

total_images = sum(dataset_stats.values())
print(f"\n📊 Total images: {total_images}")

  ✓ small: 1771 images
  ✓ medium: 4038 images
  ✓ wide: 3689 images

📊 Total images: 9498


In [None]:
# ============================================================================
# CELL 5: Custom Dataset with Species Conditioning
# ============================================================================
class SpeciesConditionedFishDataset(Dataset):
    """
    Fish-crop dataset that pairs every image with an integer species label.

    Images are read from the `small/`, `medium/` and `wide/` sub-folders of
    `dataset_dir`. The species is parsed from each filename and mapped to an
    id through `species_categories`. Files whose species is not in that
    mapping are excluded and reported in a warning.

    Args:
        dataset_dir: root folder containing the size sub-folders.
        species_categories: dict mapping species name -> integer label.
        image_size: side length of the square tensors returned by __getitem__.
    """

    SIZE_FOLDERS = ('small', 'medium', 'wide')
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, dataset_dir, species_categories, image_size=256):
        self.image_size = image_size
        self.species_categories = species_categories

        # NOTE(review): per the torchvision docs, Resize with an int scales the
        # shorter side to image_size and CenterCrop then trims the longer side,
        # so strongly non-square crops lose their ends - confirm this is intended.
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])

        # (image path, species id) pairs, plus a tally of what was left out
        self.samples, skipped_counts = self._collect_samples(dataset_dir)

        # Reverse lookup built once; on duplicate ids the first name wins.
        id_to_name = {}
        for name, category_id in species_categories.items():
            id_to_name.setdefault(category_id, name)

        species_counts = {}
        for _, species_id in self.samples:
            species_name = id_to_name[species_id]
            species_counts[species_name] = species_counts.get(species_name, 0) + 1

        print(f"\n✓ Dataset ready: {len(self.samples)} total samples")
        print(f"\nSpecies distribution:")
        for species_name, count in sorted(species_counts.items()):
            print(f"  {species_name}: {count} images")

        # Make dropped files visible instead of discarding them silently.
        if skipped_counts:
            total_skipped = sum(skipped_counts.values())
            print(f"\n⚠️  Skipped {total_skipped} images with unrecognised species: "
                  f"{dict(sorted(skipped_counts.items()))}")

    def _collect_samples(self, dataset_dir):
        """
        Scan the size folders and build the sample list.

        Returns:
            (samples, skipped_counts): `samples` is a list of
            (image_path, species_id) tuples in sorted filename order per
            folder; `skipped_counts` maps each unrecognised species name to
            the number of files excluded because of it.
        """
        samples = []
        skipped_counts = {}

        for size_folder in self.SIZE_FOLDERS:
            folder_path = os.path.join(dataset_dir, size_folder)

            if not os.path.exists(folder_path):
                print(f"⚠️  Skipping {size_folder} - folder not found")
                continue

            # Sorted so the sample order does not depend on the filesystem.
            image_files = sorted(
                f for f in os.listdir(folder_path)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )

            for img_file in image_files:
                species_name = self._extract_species(img_file)

                if species_name in self.species_categories:
                    img_path = os.path.join(folder_path, img_file)
                    samples.append((img_path, self.species_categories[species_name]))
                else:
                    skipped_counts[species_name] = skipped_counts.get(species_name, 0) + 1

        return samples, skipped_counts

    def _extract_species(self, filename):
        """
        Extract species name from filename.
        Format: {species}_{timestamp}_{number}_conf{confidence}.jpg
        Combines guppy_female and guppy_male into 'guppy'.
        Returns 'unknown' when the name starts directly with the timestamp.
        """
        # Remove extension
        name_without_ext = filename.rsplit('.', 1)[0]

        # Everything before the first 8-digit token (the timestamp) is the species
        species_parts = []
        for part in name_without_ext.split('_'):
            if part.isdigit() and len(part) == 8:
                break
            species_parts.append(part)

        if not species_parts:
            return 'unknown'

        species_name = '_'.join(species_parts)

        # Combine guppy variants
        if species_name in ('guppy_female', 'guppy_male'):
            return 'guppy'

        return species_name

    def __len__(self):
        """Number of usable (image, label) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return {'image': transformed tensor, 'species_id': int} for sample idx."""
        img_path, species_id = self.samples[idx]

        # The context manager closes the file handle as soon as the pixels
        # have been copied into the RGB image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)

        return {
            'image': image,
            'species_id': species_id
        }

# Build the species-labelled dataset from the extracted folders
for header_line in ("="*60, "CREATING SPECIES-CONDITIONED DATASET", "="*60):
    print(header_line)

dataset = SpeciesConditionedFishDataset(
    dataset_dir=config.dataset_dir,
    species_categories=config.species_categories,
    image_size=config.image_size,
)

# Wrap it in a shuffling loader; batch size and worker count come from config
loader_options = {
    'batch_size': config.batch_size,
    'shuffle': True,
    'num_workers': config.num_workers,
    'pin_memory': True,
}
train_dataloader = DataLoader(dataset, **loader_options)

batches_per_epoch = len(train_dataloader)
print(f"\n✓ DataLoader created: {batches_per_epoch} batches per epoch")

CREATING SPECIES-CONDITIONED DATASET

✓ Dataset ready: 9498 total samples

Species distribution:
  ancistrus: 1664 images
  black_molly: 877 images
  dalmatian_molly: 500 images
  gold_fish: 1938 images
  gold_molly: 2413 images
  guppy: 2106 images

✓ DataLoader created: 297 batches per epoch


In [None]:
# ============================================================================
# CELL 6: Species-Conditioned VAE Architecture
# ============================================================================
class SpeciesConditionedVAE(nn.Module):
    """
    Convolutional VAE with species category conditioning.

    A learned species embedding is concatenated to the flattened encoder
    features (before the mu / log-variance heads) and to the latent vector
    (before decoding), so all species share a single latent space.

    Args:
        config: object exposing `latent_dim`, `species_embed_dim`,
            `species_categories`, `hidden_dims` and `image_size`
            (for example a `VAEConfig`). All shape information is copied onto
            the module, so the model does not depend on any global config.
    """
    def __init__(self, config):
        super().__init__()

        self.latent_dim = config.latent_dim
        self.species_embed_dim = config.species_embed_dim

        # One embedding row per species category
        num_species_categories = len(config.species_categories)
        self.species_embedding = nn.Embedding(num_species_categories, config.species_embed_dim)

        # Encoder: one stride-2 conv stage per entry of hidden_dims
        modules = []
        in_channels = 3
        for h_dim in config.hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        # Bottleneck geometry, stored here so decode() can reshape without
        # reaching for a module-level `config`.
        self.bottleneck_channels = config.hidden_dims[-1]
        self.bottleneck_size = config.image_size // (2 ** len(config.hidden_dims))
        self.flatten_size = self.bottleneck_channels * self.bottleneck_size ** 2

        # Latent heads (conditioned on species embedding)
        self.fc_mu = nn.Linear(self.flatten_size + config.species_embed_dim, config.latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + config.species_embed_dim, config.latent_dim)

        # Decoder input (latent + species embedding)
        self.decoder_input = nn.Linear(config.latent_dim + config.species_embed_dim, self.flatten_size)

        # Decoder: mirrors the encoder, one stride-2 transposed conv per stage
        modules = []
        hidden_dims_reversed = config.hidden_dims[::-1]

        for i in range(len(hidden_dims_reversed) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims_reversed[i],
                        hidden_dims_reversed[i + 1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1
                    ),
                    nn.BatchNorm2d(hidden_dims_reversed[i + 1]),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*modules)

        # Final layer: last 2x upsample, then project to 3 channels in [0, 1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                hidden_dims_reversed[-1],
                hidden_dims_reversed[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1
            ),
            nn.BatchNorm2d(hidden_dims_reversed[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims_reversed[-1], 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, species_id):
        """
        Encode input image conditioned on species category.

        Returns:
            (mu, log_var): the outputs of the two linear heads, each with
            `latent_dim` features per sample.
        """
        species_embed = self.species_embedding(species_id)

        # Encode image and flatten everything except the batch dimension
        encoded = self.encoder(x)
        encoded = torch.flatten(encoded, start_dim=1)

        # Concatenate with species embedding
        encoded = torch.cat([encoded, species_embed], dim=1)

        mu = self.fc_mu(encoded)
        log_var = self.fc_var(encoded)

        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Reparameterization trick: return mu + eps * std with eps ~ N(0, 1)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, species_id):
        """
        Decode latent vector conditioned on species category.

        Returns:
            Reconstructed image batch; the final Sigmoid keeps values in [0, 1].
        """
        species_embed = self.species_embedding(species_id)

        # Concatenate latent with species embedding
        z = torch.cat([z, species_embed], dim=1)

        # Project back to the bottleneck feature map and upsample
        result = self.decoder_input(z)
        result = result.view(-1, self.bottleneck_channels,
                             self.bottleneck_size,
                             self.bottleneck_size)
        result = self.decoder(result)
        result = self.final_layer(result)

        return result

    def forward(self, x, species_id):
        """Forward pass through entire VAE; returns (reconstruction, mu, log_var)."""
        mu, log_var = self.encode(x, species_id)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decode(z, species_id)

        return reconstruction, mu, log_var

# Pick the GPU when one is available, then build the model on that device
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = SpeciesConditionedVAE(config)
model = model.to(device)

print("="*60)
print("SPECIES-CONDITIONED VAE ARCHITECTURE")
print("="*60)
print(model)

# Count all parameters and the trainable subset in one pass
total_params = 0
trainable_params = 0
for parameter in model.parameters():
    parameter_count = parameter.numel()
    total_params += parameter_count
    if parameter.requires_grad:
        trainable_params += parameter_count

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("="*60)

SPECIES-CONDITIONED VAE ARCHITECTURE
SpeciesConditionedVAE(
  (species_embedding): Embedding(6, 128)
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=Tr

In [None]:
# ============================================================================
# CELL 7: Loss Function and Optimizer
# ============================================================================
def vae_loss(recon_x, x, mu, log_var, beta=1.0):
    """
    Compute the VAE objective for one batch.

    The reconstruction term is the summed squared error between `recon_x` and
    `x`. The KL term is the summed closed-form divergence computed from `mu`
    and `log_var`. Both are sums over the batch, not means.

    Returns:
        (total, reconstruction, kl) where total = reconstruction + beta * kl.
    """
    # Summed (not averaged) squared error over every element
    reconstruction_term = F.mse_loss(recon_x, x, reduction='sum')

    # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    kl_elements = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * kl_elements.sum()

    weighted_total = reconstruction_term + beta * kl_term
    return weighted_total, reconstruction_term, kl_term

# Adam over every model parameter, at the configured learning rate
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
)

# Halve the learning rate when the monitored loss has not improved for 5 steps
plateau_settings = {'mode': 'min', 'factor': 0.5, 'patience': 5}
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_settings)

for status_line in (
    "✓ Loss function and optimizer configured",
    f"  Optimizer: Adam (lr={config.learning_rate})",
    f"  Scheduler: ReduceLROnPlateau",
    f"  Beta (KL weight): {config.beta}",
):
    print(status_line)

✓ Loss function and optimizer configured
  Optimizer: Adam (lr=0.0001)
  Scheduler: ReduceLROnPlateau
  Beta (KL weight): 1.0


In [None]:
# ============================================================================
# CELL 8: Create Output Directories
# ============================================================================
# Create the output root first, then one sub-folder per artifact type.
for output_folder in (
    config.output_dir,
    f"{config.output_dir}/checkpoints",
    f"{config.output_dir}/samples",
    f"{config.output_dir}/reconstructions",
):
    os.makedirs(output_folder, exist_ok=True)

print(f"✓ Output directories created at {config.output_dir}")

✓ Output directories created at /content/species_vae_output


In [None]:
# ============================================================================
# CELL 9: Training Loop with Periodic Upload
# ============================================================================
def upload_model_to_gdrive(epoch, model_path, max_retries=3):
    """
    Upload a checkpoint file to the configured Google Drive folder.

    Uses the notebook-level `drive_service` client and
    `GDRIVE_RESULTS_FOLDER_ID`. Failed attempts are retried with exponential
    back-off (1s, 2s, 4s, ...).

    Args:
        epoch: epoch number embedded in the remote filename.
        model_path: local path of the checkpoint to upload.
        max_retries: maximum number of upload attempts.

    Returns:
        True if an attempt succeeded. False if the file is missing or every
        attempt failed, including the `max_retries <= 0` case.
    """
    # A missing file will not appear between retries, so check once up front.
    if not os.path.exists(model_path):
        print("⚠️ Model not found, skipping upload")
        return False

    for attempt in range(max_retries):
        try:
            print(f"\n📤 Uploading model to Google Drive (attempt {attempt+1}/{max_retries})...")

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            remote_filename = f'size_vae_epoch_{epoch}_{timestamp}.pth'

            file_metadata = {
                'name': remote_filename,
                'parents': [GDRIVE_RESULTS_FOLDER_ID]
            }

            print(f"   Uploading: {os.path.basename(model_path)}")
            # Stream from the open file rather than copying the checkpoint
            # into memory; the handle stays open until the request finishes.
            with open(model_path, 'rb') as model_file:
                media = MediaIoBaseUpload(
                    model_file,
                    mimetype='application/octet-stream',
                    resumable=True
                )

                drive_service.files().create(
                    body=file_metadata,
                    media_body=media,
                    fields='id'
                ).execute()

            print(f"✅ Model uploaded: {remote_filename}")
            return True

        except Exception as e:
            # Best-effort: an upload failure must never abort training.
            print(f"⚠️ Upload attempt {attempt+1} failed: {e}")
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"   Retrying in {wait_time} seconds...")
                time.sleep(wait_time)

    # Reached when all attempts failed, or when max_retries allowed none.
    print(f"❌ Upload failed after {max_retries} attempts")
    return False

def save_reconstructions(model, dataloader, epoch, device, num_samples=8, output_dir=None):
    """
    Save a side-by-side grid of original images and their reconstructions.

    Takes the first batch from `dataloader`, reconstructs up to `num_samples`
    images and writes `reconstructions/epoch_XXX.png` under `output_dir`.

    Args:
        model: VAE called as `model(images, species_ids)`.
        dataloader: yields dicts with 'image' and 'species_id' entries.
        epoch: epoch number used in the output filename.
        device: device the batch is moved to before the forward pass.
        num_samples: maximum number of images to show.
        output_dir: root output folder; defaults to the notebook-level
            `config.output_dir` when omitted.

    The model's train/eval mode is restored to its previous value on exit,
    even if plotting fails.
    """
    if output_dir is None:
        output_dir = config.output_dir

    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            # Get a batch
            batch = next(iter(dataloader))
            images = batch['image'][:num_samples].to(device)
            species_ids = batch['species_id'][:num_samples].to(device)

            # Only the reconstructions are needed for the figure
            recons, _, _ = model(images, species_ids)

        # The batch may hold fewer images than requested (small final batch).
        num_shown = images.size(0)

        # squeeze=False keeps `axes` 2-D even when a single column is drawn.
        fig, axes = plt.subplots(2, num_shown, figsize=(num_shown*2, 4), squeeze=False)

        for i in range(num_shown):
            # Original (channels-first -> channels-last for imshow)
            axes[0, i].imshow(images[i].cpu().permute(1, 2, 0).numpy())
            axes[0, i].axis('off')
            if i == 0:
                axes[0, i].set_title('Original', fontsize=10)

            # Reconstruction
            axes[1, i].imshow(recons[i].cpu().permute(1, 2, 0).numpy())
            axes[1, i].axis('off')
            if i == 0:
                axes[1, i].set_title('Reconstructed', fontsize=10)

        fig.tight_layout()
        save_path = f"{output_dir}/reconstructions/epoch_{epoch:03d}.png"
        fig.savefig(save_path, dpi=100, bbox_inches='tight')
        plt.close(fig)

        print(f"✓ Saved reconstructions: {save_path}")
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

# Training state
best_loss = float('inf')
best_model_path = f"{config.output_dir}/checkpoints/vae_best.pth"
last_uploaded_epoch = 0


def _save_checkpoint(path, epoch_number, loss_value):
    """Write model, optimizer and scheduler state plus bookkeeping to `path`."""
    torch.save({
        'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss_value,
        'config': config
    }, path)


print("="*60)
print("STARTING TRAINING")
print("="*60)

start_epoch = 0
checkpoint_path = f"{config.output_dir}/checkpoints/vae_best.pth"

if os.path.exists(checkpoint_path):
    # SECURITY NOTE: weights_only=False un-pickles arbitrary Python objects
    # (needed because the checkpoint stores the config object). Only resume
    # from checkpoints this notebook wrote itself.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Older checkpoints were written without scheduler state.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    # Restore the best loss so a worse epoch right after resuming cannot
    # overwrite the best checkpoint.
    best_loss = checkpoint.get('loss', best_loss)
    del checkpoint  # release the duplicate copy of the weights
    print(f"✓ Resuming from epoch {start_epoch}")


# Losses are summed per batch, so the epoch average divides by sample count.
dataset_size = len(train_dataloader.dataset)

for epoch in range(start_epoch, config.num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_recon_loss = 0.0
    epoch_kl_loss = 0.0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

    for batch in progress_bar:
        images = batch['image'].to(device)
        species_ids = batch['species_id'].to(device)

        # Forward pass
        recon_images, mu, log_var = model(images, species_ids)

        # Calculate loss
        loss, recon_loss, kl_loss = vae_loss(
            recon_images, images, mu, log_var, config.beta
        )

        # Backpropagation (gradient norm clipped to 1.0)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Logging
        epoch_loss += loss.item()
        epoch_recon_loss += recon_loss.item()
        epoch_kl_loss += kl_loss.item()

        progress_bar.set_postfix({
            "loss": f"{loss.item():.2f}",
            "recon": f"{recon_loss.item():.2f}",
            "kl": f"{kl_loss.item():.2f}"
        })

    # End of epoch
    avg_loss = epoch_loss / dataset_size
    avg_recon = epoch_recon_loss / dataset_size
    avg_kl = epoch_kl_loss / dataset_size

    print(f"\nEpoch {epoch+1}/{config.num_epochs}")
    print(f"  Average Loss: {avg_loss:.4f}")
    print(f"  Recon Loss: {avg_recon:.4f}")
    print(f"  KL Loss: {avg_kl:.4f}")

    # Update learning rate
    scheduler.step(avg_loss)

    # Save best model
    if avg_loss < best_loss:
        best_loss = avg_loss
        _save_checkpoint(best_model_path, epoch + 1, avg_loss)
        print(f"  🌟 New best model saved! Loss: {avg_loss:.4f}")

    # Upload to Google Drive: only after the warm-up epoch, and at most once
    # per upload interval. A failed upload is retried on the next epoch.
    if (epoch + 1) >= config.upload_after_epoch and \
        (epoch + 1 - last_uploaded_epoch) >= config.upload_interval_epochs:
        if upload_model_to_gdrive(epoch + 1, best_model_path):
            last_uploaded_epoch = epoch + 1

    # Save checkpoint periodically
    if (epoch + 1) % config.save_every_n_epochs == 0:
        checkpoint_path = f"{config.output_dir}/checkpoints/vae_epoch_{epoch+1}.pth"
        _save_checkpoint(checkpoint_path, epoch + 1, avg_loss)
        print(f"  ✓ Checkpoint saved: epoch {epoch+1}")

    # Generate reconstruction samples
    if (epoch + 1) % config.validation_every_n_epochs == 0:
        save_reconstructions(model, train_dataloader, epoch + 1, device)

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"Best loss: {best_loss:.4f}")
print(f"Best model saved at: {best_model_path}")

# Final upload
print("\n📤 Uploading final best model...")
upload_model_to_gdrive(config.num_epochs, best_model_path)

print("\n✅ All done! Check your Google Drive for uploaded models.")

STARTING TRAINING
✓ Resuming from epoch 50


Epoch 51/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 51/250
  Average Loss: 1417.6947
  Recon Loss: 1014.8883
  KL Loss: 402.8063
  🌟 New best model saved! Loss: 1417.6947


Epoch 52/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 52/250
  Average Loss: 1408.4788
  Recon Loss: 1006.3965
  KL Loss: 402.0823
  🌟 New best model saved! Loss: 1408.4788


Epoch 53/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 53/250
  Average Loss: 1403.6086
  Recon Loss: 1001.0158
  KL Loss: 402.5928
  🌟 New best model saved! Loss: 1403.6086


Epoch 54/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 54/250
  Average Loss: 1395.4114
  Recon Loss: 993.0039
  KL Loss: 402.4075
  🌟 New best model saved! Loss: 1395.4114


Epoch 55/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 55/250
  Average Loss: 1392.2717
  Recon Loss: 989.5688
  KL Loss: 402.7030
  🌟 New best model saved! Loss: 1392.2717


Epoch 56/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 56/250
  Average Loss: 1386.7327
  Recon Loss: 984.1250
  KL Loss: 402.6077
  🌟 New best model saved! Loss: 1386.7327


Epoch 57/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 57/250
  Average Loss: 1379.9048
  Recon Loss: 977.6155
  KL Loss: 402.2893
  🌟 New best model saved! Loss: 1379.9048


Epoch 58/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 58/250
  Average Loss: 1374.1920
  Recon Loss: 971.5107
  KL Loss: 402.6814
  🌟 New best model saved! Loss: 1374.1920


Epoch 59/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 59/250
  Average Loss: 1369.3444
  Recon Loss: 966.8255
  KL Loss: 402.5189
  🌟 New best model saved! Loss: 1369.3444


Epoch 60/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 60/250
  Average Loss: 1361.1551
  Recon Loss: 958.3951
  KL Loss: 402.7600
  🌟 New best model saved! Loss: 1361.1551


Epoch 61/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()    if w.is_alive():
   Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>  
 Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' ^
^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ 
^  File "/us


Epoch 61/250
  Average Loss: 1356.9350
  Recon Loss: 954.2271
  KL Loss: 402.7079
  🌟 New best model saved! Loss: 1356.9350


Epoch 62/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 62/250
  Average Loss: 1351.6320
  Recon Loss: 949.3210
  KL Loss: 402.3111
  🌟 New best model saved! Loss: 1351.6320


Epoch 63/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 63/250
  Average Loss: 1344.8314
  Recon Loss: 942.8518
  KL Loss: 401.9796
  🌟 New best model saved! Loss: 1344.8314


Epoch 64/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 64/250
  Average Loss: 1341.9572
  Recon Loss: 940.4152
  KL Loss: 401.5420
  🌟 New best model saved! Loss: 1341.9572


Epoch 65/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 65/250
  Average Loss: 1334.6348
  Recon Loss: 932.6317
  KL Loss: 402.0030
  🌟 New best model saved! Loss: 1334.6348


Epoch 66/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 66/250
  Average Loss: 1328.3088
  Recon Loss: 926.3699
  KL Loss: 401.9389
  🌟 New best model saved! Loss: 1328.3088


Epoch 67/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 67/250
  Average Loss: 1322.3378
  Recon Loss: 920.5292
  KL Loss: 401.8086
  🌟 New best model saved! Loss: 1322.3378


Epoch 68/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 68/250
  Average Loss: 1318.9850
  Recon Loss: 917.1652
  KL Loss: 401.8198
  🌟 New best model saved! Loss: 1318.9850


Epoch 69/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 69/250
  Average Loss: 1314.3260
  Recon Loss: 912.6241
  KL Loss: 401.7019
  🌟 New best model saved! Loss: 1314.3260


Epoch 70/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 70/250
  Average Loss: 1310.0781
  Recon Loss: 907.7578
  KL Loss: 402.3204
  🌟 New best model saved! Loss: 1310.0781


Epoch 71/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 71/250
  Average Loss: 1303.7259
  Recon Loss: 901.6778
  KL Loss: 402.0481
  🌟 New best model saved! Loss: 1303.7259


Epoch 72/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 72/250
  Average Loss: 1300.8681
  Recon Loss: 899.3314
  KL Loss: 401.5366
  🌟 New best model saved! Loss: 1300.8681


Epoch 73/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

Traceback (most recent call last):
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
if w.is_alive():    
self._shutdown_workers() 
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive(): 
  ^ ^ Exception ignored in: ^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>^ 
^ Traceback (most recent call last):
^   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataload


Epoch 73/250
  Average Loss: 1295.7352
  Recon Loss: 894.0196
  KL Loss: 401.7156
  🌟 New best model saved! Loss: 1295.7352


Epoch 74/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 74/250
  Average Loss: 1292.5897
  Recon Loss: 890.8380
  KL Loss: 401.7517
  🌟 New best model saved! Loss: 1292.5897


Epoch 75/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 75/250
  Average Loss: 1288.9444
  Recon Loss: 886.6890
  KL Loss: 402.2553
  🌟 New best model saved! Loss: 1288.9444


Epoch 76/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 76/250
  Average Loss: 1284.8188
  Recon Loss: 882.8594
  KL Loss: 401.9594
  🌟 New best model saved! Loss: 1284.8188


Epoch 77/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 77/250
  Average Loss: 1281.5572
  Recon Loss: 880.0862
  KL Loss: 401.4710
  🌟 New best model saved! Loss: 1281.5572


Epoch 78/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 78/250
  Average Loss: 1273.4451
  Recon Loss: 871.7686
  KL Loss: 401.6765
  🌟 New best model saved! Loss: 1273.4451


Epoch 79/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 79/250
  Average Loss: 1274.6195
  Recon Loss: 873.2853
  KL Loss: 401.3343


Epoch 80/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 80/250
  Average Loss: 1269.9077
  Recon Loss: 868.2419
  KL Loss: 401.6658
  🌟 New best model saved! Loss: 1269.9077


Epoch 81/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 81/250
  Average Loss: 1266.2487
  Recon Loss: 864.7681
  KL Loss: 401.4806
  🌟 New best model saved! Loss: 1266.2487


Epoch 82/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 82/250
  Average Loss: 1263.0579
  Recon Loss: 861.8697
  KL Loss: 401.1882
  🌟 New best model saved! Loss: 1263.0579


Epoch 83/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 83/250
  Average Loss: 1256.3691
  Recon Loss: 854.9658
  KL Loss: 401.4033
  🌟 New best model saved! Loss: 1256.3691


Epoch 84/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 84/250
  Average Loss: 1257.4669
  Recon Loss: 855.8219
  KL Loss: 401.6450


Epoch 85/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
    Traceback (most recent call last):
if w.is_alive():  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

     self._shutdown_workers() 
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
       if w.is_alive(): 
^ ^ ^ ^  ^^ ^ ^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^  ^ 
   File "/usr/lib/p


Epoch 85/250
  Average Loss: 1255.8302
  Recon Loss: 854.1099
  KL Loss: 401.7203
  🌟 New best model saved! Loss: 1255.8302


Epoch 86/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 86/250
  Average Loss: 1247.0921
  Recon Loss: 845.8163
  KL Loss: 401.2758
  🌟 New best model saved! Loss: 1247.0921


Epoch 87/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 87/250
  Average Loss: 1247.2200
  Recon Loss: 845.8485
  KL Loss: 401.3715


Epoch 88/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 88/250
  Average Loss: 1243.8504
  Recon Loss: 842.3766
  KL Loss: 401.4738
  🌟 New best model saved! Loss: 1243.8504


Epoch 89/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 89/250
  Average Loss: 1241.3095
  Recon Loss: 840.0959
  KL Loss: 401.2135
  🌟 New best model saved! Loss: 1241.3095


Epoch 90/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 90/250
  Average Loss: 1239.5141
  Recon Loss: 837.8747
  KL Loss: 401.6394
  🌟 New best model saved! Loss: 1239.5141

📤 Uploading model to Google Drive (attempt 1/3)...
   Uploading: vae_best.pth
✅ Model uploaded: size_vae_epoch_90_20251010_155306.pth


Epoch 91/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 91/250
  Average Loss: 1231.0353
  Recon Loss: 830.3125
  KL Loss: 400.7227
  🌟 New best model saved! Loss: 1231.0353


Epoch 92/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 92/250
  Average Loss: 1230.3747
  Recon Loss: 829.0228
  KL Loss: 401.3518
  🌟 New best model saved! Loss: 1230.3747


Epoch 93/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 93/250
  Average Loss: 1224.1261
  Recon Loss: 823.4067
  KL Loss: 400.7194
  🌟 New best model saved! Loss: 1224.1261


Epoch 94/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 94/250
  Average Loss: 1222.7548
  Recon Loss: 822.0997
  KL Loss: 400.6551
  🌟 New best model saved! Loss: 1222.7548


Epoch 95/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 95/250
  Average Loss: 1219.8665
  Recon Loss: 818.7308
  KL Loss: 401.1357
  🌟 New best model saved! Loss: 1219.8665


Epoch 96/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
       self._shutdown_workers() 
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ ^ ^ ^ ^ 
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ ^ ^ ^ ^ ^ 
   File "/usr/


Epoch 96/250
  Average Loss: 1216.8601
  Recon Loss: 816.4915
  KL Loss: 400.3685
  🌟 New best model saved! Loss: 1216.8601


Epoch 97/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 97/250
  Average Loss: 1213.9003
  Recon Loss: 813.1076
  KL Loss: 400.7927
  🌟 New best model saved! Loss: 1213.9003


Epoch 98/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 98/250
  Average Loss: 1211.6552
  Recon Loss: 811.3027
  KL Loss: 400.3524
  🌟 New best model saved! Loss: 1211.6552


Epoch 99/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 99/250
  Average Loss: 1210.1537
  Recon Loss: 809.2031
  KL Loss: 400.9505
  🌟 New best model saved! Loss: 1210.1537


Epoch 100/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 100/250
  Average Loss: 1205.7237
  Recon Loss: 805.3439
  KL Loss: 400.3798
  🌟 New best model saved! Loss: 1205.7237
  ✓ Checkpoint saved: epoch 100
✓ Saved reconstructions: /content/species_vae_output/reconstructions/epoch_100.png


Epoch 101/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 101/250
  Average Loss: 1203.7632
  Recon Loss: 803.2618
  KL Loss: 400.5014
  🌟 New best model saved! Loss: 1203.7632


Epoch 102/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 102/250
  Average Loss: 1201.3427
  Recon Loss: 800.6666
  KL Loss: 400.6761
  🌟 New best model saved! Loss: 1201.3427


Epoch 103/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 103/250
  Average Loss: 1198.6148
  Recon Loss: 798.3868
  KL Loss: 400.2280
  🌟 New best model saved! Loss: 1198.6148


Epoch 104/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 104/250
  Average Loss: 1196.5145
  Recon Loss: 796.2726
  KL Loss: 400.2419
  🌟 New best model saved! Loss: 1196.5145


Epoch 105/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
     self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
     Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' 
    ^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^^
^  File 


Epoch 105/250
  Average Loss: 1194.5053
  Recon Loss: 794.2146
  KL Loss: 400.2908
  🌟 New best model saved! Loss: 1194.5053


Epoch 106/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 106/250
  Average Loss: 1196.0110
  Recon Loss: 795.6124
  KL Loss: 400.3986


Epoch 107/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 107/250
  Average Loss: 1189.0582
  Recon Loss: 789.2433
  KL Loss: 399.8149
  🌟 New best model saved! Loss: 1189.0582


Epoch 108/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 108/250
  Average Loss: 1187.8025
  Recon Loss: 787.2832
  KL Loss: 400.5193
  🌟 New best model saved! Loss: 1187.8025


Epoch 109/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 109/250
  Average Loss: 1186.0049
  Recon Loss: 785.9863
  KL Loss: 400.0186
  🌟 New best model saved! Loss: 1186.0049


Epoch 110/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 110/250
  Average Loss: 1185.8174
  Recon Loss: 785.8384
  KL Loss: 399.9790
  🌟 New best model saved! Loss: 1185.8174


Epoch 111/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 111/250
  Average Loss: 1180.8161
  Recon Loss: 780.5348
  KL Loss: 400.2814
  🌟 New best model saved! Loss: 1180.8161


Epoch 112/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 112/250
  Average Loss: 1177.5141
  Recon Loss: 777.8831
  KL Loss: 399.6310
  🌟 New best model saved! Loss: 1177.5141


Epoch 113/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 113/250
  Average Loss: 1176.1893
  Recon Loss: 776.3323
  KL Loss: 399.8571
  🌟 New best model saved! Loss: 1176.1893


Epoch 114/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 114/250
  Average Loss: 1174.7245
  Recon Loss: 775.0933
  KL Loss: 399.6312
  🌟 New best model saved! Loss: 1174.7245


Epoch 115/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 115/250
  Average Loss: 1174.5618
  Recon Loss: 774.4601
  KL Loss: 400.1017
  🌟 New best model saved! Loss: 1174.5618


Epoch 116/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 116/250
  Average Loss: 1171.0349
  Recon Loss: 771.7411
  KL Loss: 399.2938
  🌟 New best model saved! Loss: 1171.0349


Epoch 117/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
Exception ignored in:    <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^^
 ^ ^^  ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
     ^^assert self._parent_pid == os.getpid(), 'can only test a child process'
^ ^ ^ ^ ^ ^ ^ ^ ^ ^ 
   File "/usr


Epoch 117/250
  Average Loss: 1167.9485
  Recon Loss: 768.3068
  KL Loss: 399.6417
  🌟 New best model saved! Loss: 1167.9485


Epoch 118/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 118/250
  Average Loss: 1166.9540
  Recon Loss: 767.5568
  KL Loss: 399.3972
  🌟 New best model saved! Loss: 1166.9540


Epoch 119/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 119/250
  Average Loss: 1163.2982
  Recon Loss: 763.8438
  KL Loss: 399.4544
  🌟 New best model saved! Loss: 1163.2982


Epoch 120/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 120/250
  Average Loss: 1163.6293
  Recon Loss: 764.2504
  KL Loss: 399.3788


Epoch 121/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 121/250
  Average Loss: 1162.0296
  Recon Loss: 762.6050
  KL Loss: 399.4246
  🌟 New best model saved! Loss: 1162.0296


Epoch 122/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 122/250
  Average Loss: 1158.4916
  Recon Loss: 759.7264
  KL Loss: 398.7653
  🌟 New best model saved! Loss: 1158.4916


Epoch 123/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 123/250
  Average Loss: 1157.2324
  Recon Loss: 758.0379
  KL Loss: 399.1946
  🌟 New best model saved! Loss: 1157.2324


Epoch 124/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 124/250
  Average Loss: 1155.3266
  Recon Loss: 755.8925
  KL Loss: 399.4341
  🌟 New best model saved! Loss: 1155.3266


Epoch 125/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 125/250
  Average Loss: 1153.8334
  Recon Loss: 754.8007
  KL Loss: 399.0326
  🌟 New best model saved! Loss: 1153.8334


Epoch 126/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 126/250
  Average Loss: 1153.6140
  Recon Loss: 754.6012
  KL Loss: 399.0128
  🌟 New best model saved! Loss: 1153.6140


Epoch 127/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 127/250
  Average Loss: 1153.1010
  Recon Loss: 753.7431
  KL Loss: 399.3578
  🌟 New best model saved! Loss: 1153.1010


Epoch 128/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 128/250
  Average Loss: 1148.0270
  Recon Loss: 749.1049
  KL Loss: 398.9221
  🌟 New best model saved! Loss: 1148.0270


Epoch 129/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
Exception ignored in:    <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^^ ^ ^ ^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ ^ ^ ^ ^ ^ ^ 
   File "/usr


Epoch 129/250
  Average Loss: 1143.0398
  Recon Loss: 744.3827
  KL Loss: 398.6571
  🌟 New best model saved! Loss: 1143.0398


Epoch 130/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 130/250
  Average Loss: 1144.3829
  Recon Loss: 745.6998
  KL Loss: 398.6831


Epoch 131/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 131/250
  Average Loss: 1142.5007
  Recon Loss: 743.9543
  KL Loss: 398.5464
  🌟 New best model saved! Loss: 1142.5007


Epoch 132/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 132/250
  Average Loss: 1143.8408
  Recon Loss: 745.1088
  KL Loss: 398.7320


Epoch 133/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 133/250
  Average Loss: 1139.8991
  Recon Loss: 740.8418
  KL Loss: 399.0573
  🌟 New best model saved! Loss: 1139.8991


Epoch 134/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 134/250
  Average Loss: 1136.7944
  Recon Loss: 738.1698
  KL Loss: 398.6246
  🌟 New best model saved! Loss: 1136.7944


Epoch 135/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 135/250
  Average Loss: 1135.2503
  Recon Loss: 736.6729
  KL Loss: 398.5774
  🌟 New best model saved! Loss: 1135.2503


Epoch 136/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 136/250
  Average Loss: 1134.7652
  Recon Loss: 736.2586
  KL Loss: 398.5066
  🌟 New best model saved! Loss: 1134.7652


Epoch 137/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 137/250
  Average Loss: 1132.2853
  Recon Loss: 733.2691
  KL Loss: 399.0162
  🌟 New best model saved! Loss: 1132.2853


Epoch 138/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 138/250
  Average Loss: 1129.8673
  Recon Loss: 731.3913
  KL Loss: 398.4759
  🌟 New best model saved! Loss: 1129.8673


Epoch 139/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 139/250
  Average Loss: 1131.3893
  Recon Loss: 733.1541
  KL Loss: 398.2352


Epoch 140/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
 ^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
     ^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^ ^ ^ ^ ^ ^ ^ ^  ^ 
   File "/usr


Epoch 140/250
  Average Loss: 1127.2138
  Recon Loss: 729.1210
  KL Loss: 398.0928
  🌟 New best model saved! Loss: 1127.2138

📤 Uploading model to Google Drive (attempt 1/3)...
   Uploading: vae_best.pth
✅ Model uploaded: size_vae_epoch_140_20251010_160913.pth


Epoch 141/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 141/250
  Average Loss: 1128.7836
  Recon Loss: 730.5337
  KL Loss: 398.2499


Epoch 142/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 142/250
  Average Loss: 1124.7303
  Recon Loss: 726.3198
  KL Loss: 398.4105
  🌟 New best model saved! Loss: 1124.7303


Epoch 143/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 143/250
  Average Loss: 1127.2523
  Recon Loss: 728.3816
  KL Loss: 398.8707


Epoch 144/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 144/250
  Average Loss: 1123.5384
  Recon Loss: 725.2027
  KL Loss: 398.3357
  🌟 New best model saved! Loss: 1123.5384


Epoch 145/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 145/250
  Average Loss: 1120.8719
  Recon Loss: 722.4490
  KL Loss: 398.4229
  🌟 New best model saved! Loss: 1120.8719


Epoch 146/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 146/250
  Average Loss: 1120.3429
  Recon Loss: 722.2456
  KL Loss: 398.0973
  🌟 New best model saved! Loss: 1120.3429


Epoch 147/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 147/250
  Average Loss: 1117.3962
  Recon Loss: 719.4243
  KL Loss: 397.9719
  🌟 New best model saved! Loss: 1117.3962


Epoch 148/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 148/250
  Average Loss: 1115.9093
  Recon Loss: 717.6675
  KL Loss: 398.2418
  🌟 New best model saved! Loss: 1115.9093


Epoch 149/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 149/250
  Average Loss: 1115.0876
  Recon Loss: 716.6706
  KL Loss: 398.4169
  🌟 New best model saved! Loss: 1115.0876


Epoch 150/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 150/250
  Average Loss: 1114.9651
  Recon Loss: 716.6155
  KL Loss: 398.3495
  🌟 New best model saved! Loss: 1114.9651
  ✓ Checkpoint saved: epoch 150
✓ Saved reconstructions: /content/species_vae_output/reconstructions/epoch_150.png


Epoch 151/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 151/250
  Average Loss: 1111.7056
  Recon Loss: 713.4292
  KL Loss: 398.2764
  🌟 New best model saved! Loss: 1111.7056


Epoch 152/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 152/250
  Average Loss: 1110.6417
  Recon Loss: 712.6319
  KL Loss: 398.0098
  🌟 New best model saved! Loss: 1110.6417


Epoch 153/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 153/250
  Average Loss: 1108.8537
  Recon Loss: 710.6274
  KL Loss: 398.2263
  🌟 New best model saved! Loss: 1108.8537


Epoch 154/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 154/250
  Average Loss: 1107.6809
  Recon Loss: 709.7674
  KL Loss: 397.9136
  🌟 New best model saved! Loss: 1107.6809


Epoch 155/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 155/250
  Average Loss: 1107.1463
  Recon Loss: 709.2815
  KL Loss: 397.8648
  🌟 New best model saved! Loss: 1107.1463


Epoch 156/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 156/250
  Average Loss: 1106.3172
  Recon Loss: 708.6769
  KL Loss: 397.6403
  🌟 New best model saved! Loss: 1106.3172


Epoch 157/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 157/250
  Average Loss: 1103.3071
  Recon Loss: 705.6064
  KL Loss: 397.7008
  🌟 New best model saved! Loss: 1103.3071


Epoch 158/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 158/250
  Average Loss: 1104.2350
  Recon Loss: 706.4388
  KL Loss: 397.7962


Epoch 159/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 159/250
  Average Loss: 1101.7834
  Recon Loss: 704.1851
  KL Loss: 397.5983
  🌟 New best model saved! Loss: 1101.7834


Epoch 160/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 160/250
  Average Loss: 1102.2158
  Recon Loss: 704.7681
  KL Loss: 397.4477


Epoch 161/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
   Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
 Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ 
^  File "/us


Epoch 161/250
  Average Loss: 1096.9471
  Recon Loss: 699.2131
  KL Loss: 397.7339
  🌟 New best model saved! Loss: 1096.9471


Epoch 162/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 162/250
  Average Loss: 1097.4406
  Recon Loss: 699.9968
  KL Loss: 397.4438


Epoch 163/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 163/250
  Average Loss: 1096.2094
  Recon Loss: 699.1520
  KL Loss: 397.0574
  🌟 New best model saved! Loss: 1096.2094


Epoch 164/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 164/250
  Average Loss: 1095.0158
  Recon Loss: 697.4535
  KL Loss: 397.5623
  🌟 New best model saved! Loss: 1095.0158


Epoch 165/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 165/250
  Average Loss: 1093.7350
  Recon Loss: 696.2720
  KL Loss: 397.4629
  🌟 New best model saved! Loss: 1093.7350


Epoch 166/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 166/250
  Average Loss: 1091.5352
  Recon Loss: 694.1153
  KL Loss: 397.4199
  🌟 New best model saved! Loss: 1091.5352


Epoch 167/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 167/250
  Average Loss: 1091.6291
  Recon Loss: 694.1610
  KL Loss: 397.4681


Epoch 168/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 168/250
  Average Loss: 1091.2179
  Recon Loss: 693.7789
  KL Loss: 397.4390
  🌟 New best model saved! Loss: 1091.2179


Epoch 169/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 169/250
  Average Loss: 1088.3469
  Recon Loss: 691.3129
  KL Loss: 397.0339
  🌟 New best model saved! Loss: 1088.3469


Epoch 170/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 170/250
  Average Loss: 1090.3952
  Recon Loss: 693.5323
  KL Loss: 396.8629


Epoch 171/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 171/250
  Average Loss: 1086.9950
  Recon Loss: 689.9010
  KL Loss: 397.0941
  🌟 New best model saved! Loss: 1086.9950


Epoch 172/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 172/250
  Average Loss: 1087.3806
  Recon Loss: 690.1581
  KL Loss: 397.2225


Epoch 173/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1


Epoch 173/250
  Average Loss: 1085.0426
  Recon Loss: 688.1503
  KL Loss: 396.8923
  🌟 New best model saved! Loss: 1085.0426


Epoch 174/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 174/250
  Average Loss: 1084.1004
  Recon Loss: 687.1872
  KL Loss: 396.9132
  🌟 New best model saved! Loss: 1084.1004


Epoch 175/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 175/250
  Average Loss: 1082.8833
  Recon Loss: 685.5992
  KL Loss: 397.2840
  🌟 New best model saved! Loss: 1082.8833


Epoch 176/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 176/250
  Average Loss: 1082.7788
  Recon Loss: 685.8490
  KL Loss: 396.9298
  🌟 New best model saved! Loss: 1082.7788


Epoch 177/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 177/250
  Average Loss: 1081.6057
  Recon Loss: 684.8185
  KL Loss: 396.7872
  🌟 New best model saved! Loss: 1081.6057


Epoch 178/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 178/250
  Average Loss: 1080.7758
  Recon Loss: 683.3110
  KL Loss: 397.4649
  🌟 New best model saved! Loss: 1080.7758


Epoch 179/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 179/250
  Average Loss: 1077.6592
  Recon Loss: 680.9973
  KL Loss: 396.6619
  🌟 New best model saved! Loss: 1077.6592


Epoch 180/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 180/250
  Average Loss: 1078.5159
  Recon Loss: 681.4379
  KL Loss: 397.0780


Epoch 181/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 181/250
  Average Loss: 1076.7841
  Recon Loss: 679.8895
  KL Loss: 396.8946
  🌟 New best model saved! Loss: 1076.7841


Epoch 182/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 182/250
  Average Loss: 1076.2087
  Recon Loss: 679.5510
  KL Loss: 396.6576
  🌟 New best model saved! Loss: 1076.2087


Epoch 183/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 183/250
  Average Loss: 1077.2084
  Recon Loss: 679.9559
  KL Loss: 397.2525


Epoch 184/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 184/250
  Average Loss: 1073.8350
  Recon Loss: 677.1883
  KL Loss: 396.6467
  🌟 New best model saved! Loss: 1073.8350


Epoch 185/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
     Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
 Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    self._shutdown_workers()^^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'
^ ^ ^^  ^ ^ ^ ^ ^ ^ ^ ^^
^  File "/u


Epoch 185/250
  Average Loss: 1071.3300
  Recon Loss: 674.7517
  KL Loss: 396.5784
  🌟 New best model saved! Loss: 1071.3300


Epoch 186/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 186/250
  Average Loss: 1072.4469
  Recon Loss: 675.5839
  KL Loss: 396.8631


Epoch 187/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 187/250
  Average Loss: 1071.7533
  Recon Loss: 675.1069
  KL Loss: 396.6464


Epoch 188/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 188/250
  Average Loss: 1069.3594
  Recon Loss: 673.1589
  KL Loss: 396.2005
  🌟 New best model saved! Loss: 1069.3594


Epoch 189/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 189/250
  Average Loss: 1067.6835
  Recon Loss: 671.5037
  KL Loss: 396.1798
  🌟 New best model saved! Loss: 1067.6835


Epoch 190/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 190/250
  Average Loss: 1067.1529
  Recon Loss: 670.4041
  KL Loss: 396.7488
  🌟 New best model saved! Loss: 1067.1529

📤 Uploading model to Google Drive (attempt 1/3)...
   Uploading: vae_best.pth
✅ Model uploaded: size_vae_epoch_190_20251010_162426.pth


Epoch 191/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 191/250
  Average Loss: 1065.8109
  Recon Loss: 668.9241
  KL Loss: 396.8868
  🌟 New best model saved! Loss: 1065.8109


Epoch 192/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 192/250
  Average Loss: 1066.6250
  Recon Loss: 670.4698
  KL Loss: 396.1552


Epoch 193/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 193/250
  Average Loss: 1065.7669
  Recon Loss: 669.1335
  KL Loss: 396.6334
  🌟 New best model saved! Loss: 1065.7669


Epoch 194/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 194/250
  Average Loss: 1062.4947
  Recon Loss: 666.3992
  KL Loss: 396.0956
  🌟 New best model saved! Loss: 1062.4947


Epoch 195/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 195/250
  Average Loss: 1061.0684
  Recon Loss: 664.7120
  KL Loss: 396.3564
  🌟 New best model saved! Loss: 1061.0684


Epoch 196/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 196/250
  Average Loss: 1061.9009
  Recon Loss: 665.6709
  KL Loss: 396.2300


Epoch 197/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 197/250
  Average Loss: 1062.5649
  Recon Loss: 666.2778
  KL Loss: 396.2870


Epoch 198/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 198/250
  Average Loss: 1059.3806
  Recon Loss: 663.3120
  KL Loss: 396.0686
  🌟 New best model saved! Loss: 1059.3806


Epoch 199/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 199/250
  Average Loss: 1061.1352
  Recon Loss: 664.8389
  KL Loss: 396.2963


Epoch 200/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 200/250
  Average Loss: 1057.3952
  Recon Loss: 661.2409
  KL Loss: 396.1544
  🌟 New best model saved! Loss: 1057.3952
  ✓ Checkpoint saved: epoch 200
✓ Saved reconstructions: /content/species_vae_output/reconstructions/epoch_200.png


Epoch 201/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 201/250
  Average Loss: 1058.8809
  Recon Loss: 662.6215
  KL Loss: 396.2594


Epoch 202/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 202/250
  Average Loss: 1055.6567
  Recon Loss: 659.5442
  KL Loss: 396.1126
  🌟 New best model saved! Loss: 1055.6567


Epoch 203/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 203/250
  Average Loss: 1055.5747
  Recon Loss: 659.7994
  KL Loss: 395.7754
  🌟 New best model saved! Loss: 1055.5747


Epoch 204/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 204/250
  Average Loss: 1053.5818
  Recon Loss: 657.3834
  KL Loss: 396.1984
  🌟 New best model saved! Loss: 1053.5818


Epoch 205/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
    Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>  
^Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' 
  ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^^^^
^  File "/


Epoch 205/250
  Average Loss: 1054.8504
  Recon Loss: 658.8226
  KL Loss: 396.0278


Epoch 206/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 206/250
  Average Loss: 1050.8608
  Recon Loss: 654.9908
  KL Loss: 395.8700
  🌟 New best model saved! Loss: 1050.8608


Epoch 207/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 207/250
  Average Loss: 1050.9000
  Recon Loss: 654.9184
  KL Loss: 395.9816


Epoch 208/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 208/250
  Average Loss: 1051.2092
  Recon Loss: 655.4754
  KL Loss: 395.7338


Epoch 209/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 209/250
  Average Loss: 1050.5412
  Recon Loss: 654.5560
  KL Loss: 395.9851
  🌟 New best model saved! Loss: 1050.5412


Epoch 210/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 210/250
  Average Loss: 1050.0933
  Recon Loss: 654.2666
  KL Loss: 395.8268
  🌟 New best model saved! Loss: 1050.0933


Epoch 211/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 211/250
  Average Loss: 1047.9102
  Recon Loss: 652.2958
  KL Loss: 395.6143
  🌟 New best model saved! Loss: 1047.9102


Epoch 212/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 212/250
  Average Loss: 1047.0093
  Recon Loss: 651.5965
  KL Loss: 395.4128
  🌟 New best model saved! Loss: 1047.0093


Epoch 213/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 213/250
  Average Loss: 1048.0298
  Recon Loss: 652.0089
  KL Loss: 396.0208


Epoch 214/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 214/250
  Average Loss: 1047.7397
  Recon Loss: 651.4580
  KL Loss: 396.2817


Epoch 215/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 215/250
  Average Loss: 1044.6640
  Recon Loss: 649.1745
  KL Loss: 395.4895
  🌟 New best model saved! Loss: 1044.6640


Epoch 216/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 216/250
  Average Loss: 1044.7985
  Recon Loss: 649.1269
  KL Loss: 395.6715


Epoch 217/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
      Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^^if w.is_alive():^

   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' 
      ^ ^ ^ ^ ^ ^ ^ ^ ^^^^^^^^
^  File 


Epoch 217/250
  Average Loss: 1044.9997
  Recon Loss: 649.4343
  KL Loss: 395.5654


Epoch 218/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 218/250
  Average Loss: 1043.5823
  Recon Loss: 648.0804
  KL Loss: 395.5018
  🌟 New best model saved! Loss: 1043.5823


Epoch 219/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 219/250
  Average Loss: 1043.7908
  Recon Loss: 647.7511
  KL Loss: 396.0397


Epoch 220/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 220/250
  Average Loss: 1040.1538
  Recon Loss: 644.8631
  KL Loss: 395.2907
  🌟 New best model saved! Loss: 1040.1538


Epoch 221/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 221/250
  Average Loss: 1040.9028
  Recon Loss: 645.6281
  KL Loss: 395.2747


Epoch 222/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 222/250
  Average Loss: 1042.3122
  Recon Loss: 646.5419
  KL Loss: 395.7703


Epoch 223/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 223/250
  Average Loss: 1040.7583
  Recon Loss: 645.1370
  KL Loss: 395.6213


Epoch 224/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 224/250
  Average Loss: 1040.5252
  Recon Loss: 645.0134
  KL Loss: 395.5119


Epoch 225/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 225/250
  Average Loss: 1040.6813
  Recon Loss: 645.2015
  KL Loss: 395.4797


Epoch 226/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 226/250
  Average Loss: 1037.7999
  Recon Loss: 642.3640
  KL Loss: 395.4360
  🌟 New best model saved! Loss: 1037.7999


Epoch 227/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 227/250
  Average Loss: 1038.5804
  Recon Loss: 642.7359
  KL Loss: 395.8446


Epoch 228/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 228/250
  Average Loss: 1036.5009
  Recon Loss: 641.2510
  KL Loss: 395.2498
  🌟 New best model saved! Loss: 1036.5009


Epoch 229/250:   0%|          | 0/297 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
    Exception ignored in:   <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00> 
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
      assert self._parent_pid == os.getpid(), 'can only test a child process' 
  ^ ^ ^ ^ ^ ^  ^ ^ ^ ^^^^^^
^^  File 


Epoch 229/250
  Average Loss: 1035.9911
  Recon Loss: 640.3380
  KL Loss: 395.6531
  🌟 New best model saved! Loss: 1035.9911


Epoch 230/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 230/250
  Average Loss: 1034.8756
  Recon Loss: 639.6890
  KL Loss: 395.1866
  🌟 New best model saved! Loss: 1034.8756


Epoch 231/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 231/250
  Average Loss: 1033.6619
  Recon Loss: 638.4358
  KL Loss: 395.2261
  🌟 New best model saved! Loss: 1033.6619


Epoch 232/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 232/250
  Average Loss: 1033.2901
  Recon Loss: 638.4483
  KL Loss: 394.8418
  🌟 New best model saved! Loss: 1033.2901


Epoch 233/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 233/250
  Average Loss: 1033.0891
  Recon Loss: 638.0444
  KL Loss: 395.0447
  🌟 New best model saved! Loss: 1033.0891


Epoch 234/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 234/250
  Average Loss: 1034.3110
  Recon Loss: 639.1481
  KL Loss: 395.1629


Epoch 235/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 235/250
  Average Loss: 1031.0593
  Recon Loss: 636.4809
  KL Loss: 394.5784
  🌟 New best model saved! Loss: 1031.0593


Epoch 236/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 236/250
  Average Loss: 1030.1944
  Recon Loss: 635.2544
  KL Loss: 394.9401
  🌟 New best model saved! Loss: 1030.1944


Epoch 237/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 237/250
  Average Loss: 1030.9637
  Recon Loss: 635.7297
  KL Loss: 395.2340


Epoch 238/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 238/250
  Average Loss: 1030.3237
  Recon Loss: 635.5733
  KL Loss: 394.7505


Epoch 239/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 239/250
  Average Loss: 1027.6374
  Recon Loss: 632.4655
  KL Loss: 395.1718
  🌟 New best model saved! Loss: 1027.6374


Epoch 240/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 240/250
  Average Loss: 1027.5550
  Recon Loss: 632.3013
  KL Loss: 395.2538
  🌟 New best model saved! Loss: 1027.5550

📤 Uploading model to Google Drive (attempt 1/3)...
   Uploading: vae_best.pth
✅ Model uploaded: size_vae_epoch_240_20251010_163833.pth


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7fc1cc1ace00>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^self._shutdown_workers()
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^  ^ ^ 

Epoch 241/250:   0%|          | 0/297 [00:00<?, ?it/s]

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^ ^ ^ 
 AssertionError^: ^can only test a child process^
^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process



Epoch 241/250
  Average Loss: 1027.4769
  Recon Loss: 632.6118
  KL Loss: 394.8651
  🌟 New best model saved! Loss: 1027.4769


Epoch 242/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 242/250
  Average Loss: 1026.5253
  Recon Loss: 631.6209
  KL Loss: 394.9044
  🌟 New best model saved! Loss: 1026.5253


Epoch 243/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 243/250
  Average Loss: 1025.9378
  Recon Loss: 630.9739
  KL Loss: 394.9639
  🌟 New best model saved! Loss: 1025.9378


Epoch 244/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 244/250
  Average Loss: 1024.7338
  Recon Loss: 629.5209
  KL Loss: 395.2129
  🌟 New best model saved! Loss: 1024.7338


Epoch 245/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 245/250
  Average Loss: 1025.2096
  Recon Loss: 630.4653
  KL Loss: 394.7443


Epoch 246/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 246/250
  Average Loss: 1025.3781
  Recon Loss: 630.7219
  KL Loss: 394.6562


Epoch 247/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 247/250
  Average Loss: 1024.2995
  Recon Loss: 629.6770
  KL Loss: 394.6225
  🌟 New best model saved! Loss: 1024.2995


Epoch 248/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 248/250
  Average Loss: 1022.8848
  Recon Loss: 627.9399
  KL Loss: 394.9449
  🌟 New best model saved! Loss: 1022.8848


Epoch 249/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 249/250
  Average Loss: 1021.3020
  Recon Loss: 626.6019
  KL Loss: 394.7001
  🌟 New best model saved! Loss: 1021.3020


Epoch 250/250:   0%|          | 0/297 [00:00<?, ?it/s]


Epoch 250/250
  Average Loss: 1021.7989
  Recon Loss: 627.0103
  KL Loss: 394.7886
  ✓ Checkpoint saved: epoch 250
✓ Saved reconstructions: /content/species_vae_output/reconstructions/epoch_250.png

TRAINING COMPLETE!
Best loss: 1021.3020
Best model saved at: /content/species_vae_output/checkpoints/vae_best.pth

📤 Uploading final best model...

📤 Uploading model to Google Drive (attempt 1/3)...
   Uploading: vae_best.pth
✅ Model uploaded: size_vae_epoch_250_20251010_164503.pth

✅ All done! Check your Google Drive for uploaded models.
