In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import transforms
from torchvision import models  # NEW: For perceptual loss
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm  # NEW: For progress bars
import json
from datetime import datetime

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

class Config:
    # Training Settings
    NUM_EPOCHS = 100
    BATCH_SIZE = 4 # Reduced batch size to help with CUDA out of memory errors
    LR = 2e-4
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Loss Weights
    LAMBDA_PIXEL = 200 # Increased from 150
    LAMBDA_PERCEPTUAL = 10
    LAMBDA_EDGE = 20
    LAMBDA_PHYSICS = 50
    LAMBDA_ADVERSARIAL = 1

    # Physics Constants
    SIGMA = 5.67e-8
    L_DOWNWELLING = 50.

    # Paths - UPDATE THESE TO YOUR PATHS
    LR_THERMAL_PATH = "/content/drive/MyDrive/BD"
    HR_OPTICAL_PATH = "/content/drive/MyDrive/HR RGB"
    HR_THERMAL_PATH = "/content/drive/MyDrive/GT thermal"

    # Corrected Test Paths (using the provided path as a potential test path)
    TEST_LR_THERMAL_PATH = "/content/drive/MyDrive/BD" # Using the provided path
    # TEST_LR_THERMAL_PATH = "/content/drive/MyDrive/TestData/BD" # Original test path
    TEST_HR_OPTICAL_PATH = "/content/drive/MyDrive/HR RGB" # Using the provided path
    # TEST_HR_OPTICAL_PATH = "/content/drive/MyDrive/TestData/HR RGB" # Original test path
    TEST_HR_THERMAL_PATH = "/content/drive/MyDrive/GT thermal" # Using the provided path
    # TEST_HR_THERMAL_PATH = "/content/drive/MyDrive/TestData/GT thermal" # Original test path


    CHECKPOINT_DIR = "/content/drive/MyDrive/SIH_Model_Checkpoints_V2/" # Keep the same checkpoint directory for now

    # Training Settings
    VAL_SPLIT = 0.15
    EARLY_STOPPING_PATIENCE = 25
    SAVE_EVERY_N_EPOCHS = 5

config = Config()
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
print(f"‚úì Using device: {config.DEVICE}")
print(f"‚úì Checkpoints will be saved to: {config.CHECKPOINT_DIR}")

PyTorch version: 2.8.0+cu126
CUDA available: True
‚úì Using device: cuda
‚úì Checkpoints will be saved to: /content/drive/MyDrive/SIH_Model_Checkpoints_V2/


In [3]:
    # Paths - UPDATE THESE TO YOUR PATHS
LR_THERMAL_PATH = "/content/drive/MyDrive/BD"
    HR_OPTICAL_PATH = "/content/drive/MyDrive/HR RGB"
    HR_THERMAL_PATH = "/content/drive/MyDrive/GT thermal"

    TEST_LR_THERMAL_PATH = "/content/drive/MyDrive/TestData/BD"
    TEST_HR_OPTICAL_PATH = "/content/drive/MyDrive/TestData/HR RGB"
    TEST_HR_THERMAL_PATH = "/content/drive/MyDrive/TestData/GT thermal"

    CHECKPOINT_DIR = "/content/drive/MyDrive/SIH_Model_Checkpoints_V2/"

    # Training Settings
    VAL_SPLIT = 0.15
    EARLY_STOPPING_PATIENCE = 15
    SAVE_EVERY_N_EPOCHS = 5

config = Config()
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
print(f"‚úì Using device: {config.DEVICE}")
print(f"‚úì Checkpoints will be saved to: {config.CHECKPOINT_DIR}")

‚úì Using device: cuda
‚úì Checkpoints will be saved to: /content/drive/MyDrive/SIH_Model_Checkpoints_V2/


In [4]:
# Paths - UPDATE THESE TO YOUR PATHS
#LR_THERMAL_PATH = "/content/drive/MyDrive/BD"
#HR_OPTICAL_PATH = "/content/drive/MyDrive/HR RGB"
#HR_THERMAL_PATH = "/content/drive/MyDrive/GT thermal"

#TEST_LR_THERMAL_PATH = "/content/drive/MyDrive/TestData/BD"
#TEST_HR_OPTICAL_PATH = "/content/drive/MyDrive/TestData/HR RGB"
#TEST_HR_THERMAL_PATH = "/content/drive/MyDrive/TestData/GT thermal"

#CHECKPOINT_DIR = "/content/drive/MyDrive/SIH_Model_Checkpoints_V2/"  # New folder

# Training Settings
#VAL_SPLIT = 0.15
#EARLY_STOPPING_PATIENCE = 15
#SAVE_EVERY_N_EPOCHS = 5

#config = Config()
#os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
#print(f"‚úì Using device: {config.DEVICE}")
#print(f"‚úì Checkpoints will be saved to: {config.CHECKPOINT_DIR}")

# ==========================================
# STEP 4: IMPROVED DATASET WITH AUGMENTATION
# ==========================================
# REPLACE your IndexedThermalDataset with this:

import random

class ImprovedThermalDataset(Dataset):
    def __init__(self, lr_path, hr_optical_path, hr_thermal_path, augment=True):
        self.lr_thermal_path = lr_path
        self.hr_optical_path = hr_optical_path
        self.hr_thermal_path = hr_thermal_path
        self.augment = augment

        self.lr_thermal_files = sorted(os.listdir(self.lr_thermal_path))
        self.hr_optical_files = sorted(os.listdir(self.hr_optical_path))
        self.hr_thermal_files = sorted(os.listdir(self.hr_thermal_path))

        # Basic transforms
        self.transform_gray = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        self.transform_rgb = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Augmentation transforms
        self.transform_gray_aug = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        self.transform_rgb_aug = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.lr_thermal_files)

    def __getitem__(self, index):
        lr_thermal_img = Image.open(os.path.join(self.lr_thermal_path, self.lr_thermal_files[index])).convert("L")
        hr_optical_img = Image.open(os.path.join(self.hr_optical_path, self.hr_optical_files[index])).convert("RGB")
        hr_thermal_img = Image.open(os.path.join(self.hr_thermal_path, self.hr_thermal_files[index])).convert("L")

        if self.augment:
            # Use same random seed for consistent augmentation
            seed = random.randint(0, 2**32)
            random.seed(seed)
            torch.manual_seed(seed)
            lr_thermal = self.transform_gray_aug(lr_thermal_img)

            random.seed(seed)
            torch.manual_seed(seed)
            hr_optical = self.transform_rgb_aug(hr_optical_img)

            random.seed(seed)
            torch.manual_seed(seed)
            hr_thermal = self.transform_gray_aug(hr_thermal_img)
        else:
            lr_thermal = self.transform_gray(lr_thermal_img)
            hr_optical = self.transform_rgb(hr_optical_img)
            hr_thermal = self.transform_gray(hr_thermal_img)

        return lr_thermal, hr_optical, hr_thermal

print("‚úì Improved dataset class loaded with augmentation")

‚úì Improved dataset class loaded with augmentation


In [5]:
# ==========================================
# STEP 5: IMPROVED MODEL ARCHITECTURE
# ==========================================
# REPLACE your UNetGenerator with this:

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class ImprovedUNetGenerator(nn.Module):
    """MUCH BETTER than your old placeholder model!"""
    def __init__(self, in_channels=4, out_channels=1):
        super(ImprovedUNetGenerator, self).__init__()

        # Encoder
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder with Attention
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.dec4 = DoubleConv(1024, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.dec2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.dec1 = DoubleConv(128, 64)

        # Output
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def forward(self, lr_thermal, hr_optical):
        lr_thermal_resized = F.interpolate(lr_thermal, size=hr_optical.shape[2:],
                                          mode='bicubic', align_corners=False)
        x = torch.cat([lr_thermal_resized, hr_optical], dim=1)

        # Encoder
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        e4 = self.enc4(p3)
        p4 = self.pool4(e4)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder with skip connections and attention
        d4 = self.up4(b)
        e4 = self.att4(g=d4, x=e4)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        e3 = self.att3(g=d3, x=e3)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        e2 = self.att2(g=d2, x=e2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        e1 = self.att1(g=d1, x=e1)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = self.out_conv(d1)
        return self.tanh(out)

# Keep your PatchGANDiscriminator as is, or use improved version:
class ImprovedPatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels=4):
        super(ImprovedPatchGANDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, hr_optical, thermal_image):
        x = torch.cat([hr_optical, thermal_image], dim=1)
        return self.model(x)

print("‚úì Improved model architecture loaded")
print("‚úì Model has ~31M parameters (vs your old ~100 parameters)")

‚úì Improved model architecture loaded
‚úì Model has ~31M parameters (vs your old ~100 parameters)


In [6]:
# ==========================================
# STEP 6: NEW LOSS FUNCTIONS
# ==========================================

# Perceptual Loss (VGG-based) - NEW!
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.layers = nn.ModuleList([
            vgg[:4],   # relu1_2
            vgg[4:9],  # relu2_2
            vgg[9:18]  # relu3_4
        ])
        for param in self.parameters():
            param.requires_grad = False
        self.mse_loss = nn.MSELoss()

    def forward(self, generated, target):
        # Convert grayscale to 3-channel
        generated_3ch = generated.repeat(1, 3, 1, 1)
        target_3ch = target.repeat(1, 3, 1, 1)

        loss = 0.0
        x_gen, x_target = generated_3ch, target_3ch

        for layer in self.layers:
            x_gen = layer(x_gen)
            x_target = layer(x_target)
            loss += self.mse_loss(x_gen, x_target)

        return loss / len(self.layers)

# Edge Loss - NEW!
class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        self.sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                     dtype=torch.float32).view(1, 1, 3, 3)
        self.sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                                     dtype=torch.float32).view(1, 1, 3, 3)
        self.l1_loss = nn.L1Loss()

    def forward(self, generated, target):
        device = generated.device
        self.sobel_x = self.sobel_x.to(device)
        self.sobel_y = self.sobel_y.to(device)

        gen_edge_x = F.conv2d(generated, self.sobel_x, padding=1)
        gen_edge_y = F.conv2d(generated, self.sobel_y, padding=1)
        gen_edges = torch.sqrt(gen_edge_x**2 + gen_edge_y**2)

        target_edge_x = F.conv2d(target, self.sobel_x, padding=1)
        target_edge_y = F.conv2d(target, self.sobel_y, padding=1)
        target_edges = torch.sqrt(target_edge_x**2 + target_edge_y**2)

        return self.l1_loss(gen_edges, target_edges)

# Keep your existing physics loss
def estimate_emissivity(hr_optical_image):
    image = hr_optical_image * 0.5 + 0.5
    is_vegetation = (image[:, 1, :, :] > image[:, 0, :, :]) & (image[:, 1, :, :] > image[:, 2, :, :])
    epsilon = torch.full_like(image[:, 0, :, :], 0.92)
    epsilon[is_vegetation] = 0.98
    return epsilon.unsqueeze(1)

def physics_loss_function(generated_map, ground_truth_map, emissivity_map):
    generated_kelvin = (generated_map * 0.5 + 0.5) * 100 + 263.15
    truth_kelvin = (ground_truth_map * 0.5 + 0.5) * 100 + 263.15
    radiance_gen = (emissivity_map * config.SIGMA * torch.pow(generated_kelvin, 4)) + \
                   ((1 - emissivity_map) * config.L_DOWNWELLING)
    radiance_truth = (emissivity_map * config.SIGMA * torch.pow(truth_kelvin, 4)) + \
                     ((1 - emissivity_map) * config.L_DOWNWELLING)
    return nn.L1Loss()(radiance_gen, radiance_truth)

print("‚úì All loss functions loaded (Pixel + Perceptual + Edge + Physics + Adversarial)")

‚úì All loss functions loaded (Pixel + Perceptual + Edge + Physics + Adversarial)


In [7]:
# ==========================================
# STEP 7: LOAD DATA
# ==========================================

# Load full dataset
full_dataset = ImprovedThermalDataset(
    lr_path=config.LR_THERMAL_PATH,
    hr_optical_path=config.HR_OPTICAL_PATH,
    hr_thermal_path=config.HR_THERMAL_PATH,
    augment=True  # Enable augmentation for training
)

# Split into train and validation
val_size = int(len(full_dataset) * config.VAL_SPLIT)
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"‚úì Dataset loaded:")
print(f"  - Training samples: {len(train_dataset)}")
print(f"  - Validation samples: {len(val_dataset)}")
print(f"  - Batches per epoch: {len(train_loader)}")


‚úì Dataset loaded:
  - Training samples: 872
  - Validation samples: 153
  - Batches per epoch: 218


THIS IS THE TRANING CELL

In [None]:
# ==========================================
# STEP 8: INITIALIZE MODELS
# ==========================================

gen = ImprovedUNetGenerator().to(config.DEVICE)
disc = ImprovedPatchGANDiscriminator().to(config.DEVICE)

opt_gen = optim.Adam(gen.parameters(), lr=config.LR, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=config.LR, betas=(0.5, 0.999))

# Learning rate schedulers
scheduler_gen = optim.lr_scheduler.ReduceLROnPlateau(opt_gen, mode='min', factor=0.5,
                                                      patience=10)
scheduler_disc = optim.lr_scheduler.ReduceLROnPlateau(opt_disc, mode='min', factor=0.5,
                                                       patience=10)

# Loss functions
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
perceptual_loss = VGGPerceptualLoss().to(config.DEVICE)
edge_loss = EdgeLoss().to(config.DEVICE)

# Count parameters
gen_params = sum(p.numel() for p in gen.parameters() if p.requires_grad)
disc_params = sum(p.numel() for p in disc.parameters() if p.requires_grad)

print(f"‚úì Models initialized:")
print(f"  - Generator parameters: {gen_params:,}")
print(f"  - Discriminator parameters: {disc_params:,}")

# ==========================================
# STEP 9: IMPROVED TRAINING LOOP
# ==========================================

print("\n" + "="*70)
print("STARTING IMPROVED TRAINING")
print("="*70)

history = {
    'train_gen_loss': [],
    'train_disc_loss': [],
    'val_loss': [],
    'learning_rates': []
}

best_val_loss = float('inf')
patience_counter = 0

for epoch in range(config.NUM_EPOCHS):
    # TRAINING
    gen.train()
    disc.train()

    total_gen_loss = 0
    total_disc_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS}")

    for lr_thermal, hr_optical, hr_thermal_gt in pbar:
        lr_thermal = lr_thermal.to(config.DEVICE)
        hr_optical = hr_optical.to(config.DEVICE)
        hr_thermal_gt = hr_thermal_gt.to(config.DEVICE)

        # Train Discriminator
        fake_thermal = gen(lr_thermal, hr_optical)
        disc_real_output = disc(hr_optical, hr_thermal_gt)
        disc_real_loss = bce_loss(disc_real_output, torch.ones_like(disc_real_output))
        disc_fake_output = disc(hr_optical, fake_thermal.detach())
        disc_fake_loss = bce_loss(disc_fake_output, torch.zeros_like(disc_fake_output))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        opt_disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # Train Generator with ALL losses
        disc_fake_output = disc(hr_optical, fake_thermal)
        gen_adversarial_loss = bce_loss(disc_fake_output, torch.ones_like(disc_fake_output))
        gen_pixel_loss = l1_loss(fake_thermal, hr_thermal_gt)
        gen_perceptual = perceptual_loss(fake_thermal, hr_thermal_gt)
        gen_edge = edge_loss(fake_thermal, hr_thermal_gt)
        emissivity_map = estimate_emissivity(hr_optical)
        gen_physics_loss = physics_loss_function(fake_thermal, hr_thermal_gt, emissivity_map)

        gen_total_loss = (
            config.LAMBDA_ADVERSARIAL * gen_adversarial_loss +
            config.LAMBDA_PIXEL * gen_pixel_loss +
            config.LAMBDA_PERCEPTUAL * gen_perceptual +
            config.LAMBDA_EDGE * gen_edge +
            config.LAMBDA_PHYSICS * gen_physics_loss
        )

        opt_gen.zero_grad()
        gen_total_loss.backward()
        opt_gen.step()

        total_gen_loss += gen_total_loss.item()
        total_disc_loss += disc_loss.item()

        pbar.set_postfix({
            'G': f'{gen_total_loss.item():.3f}',
            'D': f'{disc_loss.item():.3f}'
        })

    avg_train_gen_loss = total_gen_loss / len(train_loader)
    avg_train_disc_loss = total_disc_loss / len(train_loader)

    # VALIDATION
    gen.eval()
    total_val_loss = 0

    with torch.no_grad():
        for lr_thermal, hr_optical, hr_thermal_gt in val_loader:
            lr_thermal = lr_thermal.to(config.DEVICE)
            hr_optical = hr_optical.to(config.DEVICE)
            hr_thermal_gt = hr_thermal_gt.to(config.DEVICE)

            fake_thermal = gen(lr_thermal, hr_optical)
            pixel_loss = l1_loss(fake_thermal, hr_thermal_gt)
            percep_loss = perceptual_loss(fake_thermal, hr_thermal_gt)
            e_loss = edge_loss(fake_thermal, hr_thermal_gt)
            emissivity_map = estimate_emissivity(hr_optical)
            phys_loss = physics_loss_function(fake_thermal, hr_thermal_gt, emissivity_map)

            val_loss = (
                config.LAMBDA_PIXEL * pixel_loss +
                config.LAMBDA_PERCEPTUAL * percep_loss +
                config.LAMBDA_EDGE * e_loss +
                config.LAMBDA_PHYSICS * phys_loss
            )
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    # Update learning rates
    scheduler_gen.step(avg_val_loss)
    scheduler_disc.step(avg_val_loss)

    # Save history
    history['train_gen_loss'].append(avg_train_gen_loss)
    history['train_disc_loss'].append(avg_train_disc_loss)
    history['val_loss'].append(avg_val_loss)
    history['learning_rates'].append(opt_gen.param_groups[0]['lr'])

    print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}:")
    print(f"  Train G Loss: {avg_train_gen_loss:.4f}")
    print(f"  Train D Loss: {avg_train_disc_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  LR: {opt_gen.param_groups[0]['lr']:.6f}")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'gen_state_dict': gen.state_dict(),
            'disc_state_dict': disc.state_dict(),
            'opt_gen_state_dict': opt_gen.state_dict(),
            'opt_disc_state_dict': opt_disc.state_dict(),
            'val_loss': avg_val_loss,
        }, os.path.join(config.CHECKPOINT_DIR, 'best_model.pth'))
        print(f"  ‚úì NEW BEST MODEL SAVED! (Val Loss: {avg_val_loss:.4f})")
    else:
        patience_counter += 1

    # Regular checkpoint
    if (epoch + 1) % config.SAVE_EVERY_N_EPOCHS == 0:
        torch.save({
            'epoch': epoch,
            'gen_state_dict': gen.state_dict(),
            'disc_state_dict': disc.state_dict(),
        }, os.path.join(config.CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth'))
        print(f"  ‚úì Checkpoint saved")

    # Early stopping
    if patience_counter >= config.EARLY_STOPPING_PATIENCE:
        print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch+1} epochs")
        break

    print("-" * 70)

# Save training history
with open(os.path.join(config.CHECKPOINT_DIR, 'training_history.json'), 'w') as f:
    json.dump(history, f, indent=4)

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print(f"Best Validation Loss: {best_val_loss:.4f}")
print(f"Models saved to: {config.CHECKPOINT_DIR}")
print("="*70)

# Plot training curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(history['train_gen_loss'], label='Train Gen')
plt.plot(history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Generator Loss')
plt.grid(True)

plt.subplot(1, 3, 2)
plt.plot(history['train_disc_loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Discriminator Loss')
plt.grid(True)

plt.subplot(1, 3, 3)
plt.plot(history['learning_rates'])
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(config.CHECKPOINT_DIR, 'training_curves.png'), dpi=300)
plt.show()

print("‚úì Training curves saved!")



Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 548M/548M [00:05<00:00, 110MB/s] 


‚úì Models initialized:
  - Generator parameters: 31,389,741
  - Discriminator parameters: 2,767,553

STARTING IMPROVED TRAINING


Epoch 1/100:   1%|          | 2/218 [00:20<37:33, 10.43s/it, G=9513.101, D=1.048]


KeyboardInterrupt: 

In [None]:
# ==========================================
# STEP 4: IMPROVED DATASET WITH AUGMENTATION
# ==========================================
# REPLACE your IndexedThermalDataset with this:

import random
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import transforms
from PIL import Image
import os
import numpy as np

class ImprovedThermalDataset(Dataset):
    def __init__(self, lr_path, hr_optical_path, hr_thermal_path, augment=True):
        self.lr_thermal_path = lr_path
        self.hr_optical_path = hr_optical_path
        self.hr_thermal_path = hr_thermal_path
        self.augment = augment

        # Check if directories exist
        if not os.path.exists(self.lr_thermal_path):
            raise FileNotFoundError(f"LR Thermal path not found: {self.lr_thermal_path}")
        if not os.path.exists(self.hr_optical_path):
            raise FileNotFoundError(f"HR Optical path not found: {self.hr_optical_path}")
        if not os.path.exists(self.hr_thermal_path):
            raise FileNotFoundError(f"HR Thermal path not found: {self.hr_thermal_path}")


        self.lr_thermal_files = sorted(os.listdir(self.lr_thermal_path))
        self.hr_optical_files = sorted(os.listdir(self.hr_optical_path))
        self.hr_thermal_files = sorted(os.listdir(self.hr_thermal_path))

        # Basic transforms
        self.transform_gray = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        self.transform_rgb = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Augmentation transforms
        self.transform_gray_aug = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        self.transform_rgb_aug = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.lr_thermal_files)

    def __getitem__(self, index):
        lr_thermal_img = Image.open(os.path.join(self.lr_thermal_path, self.lr_thermal_files[index])).convert("L")
        hr_optical_img = Image.open(os.path.join(self.hr_optical_path, self.hr_optical_files[index])).convert("RGB")
        hr_thermal_img = Image.open(os.path.join(self.hr_thermal_path, self.hr_thermal_files[index])).convert("L")

        if self.augment:
            # Use same random seed for consistent augmentation
            seed = random.randint(0, 2**32)
            random.seed(seed)
            torch.manual_seed(seed)
            lr_thermal = self.transform_gray_aug(lr_thermal_img)

            random.seed(seed)
            torch.manual_seed(seed)
            hr_optical = self.transform_rgb_aug(hr_optical_img)

            random.seed(seed)
            torch.manual_seed(seed)
            hr_thermal = self.transform_gray_aug(hr_thermal_img)
        else:
            lr_thermal = self.transform_gray(lr_thermal_img)
            hr_optical = self.transform_rgb(hr_optical_img)
            hr_thermal = self.transform_gray(hr_thermal_img)

        return lr_thermal, hr_optical, hr_thermal

print("‚úì Improved dataset class loaded with augmentation")

# ==========================================
# STEP 7: LOAD DATA
# ==========================================

# Load full dataset
full_dataset = ImprovedThermalDataset(
    lr_path=config.LR_THERMAL_PATH,
    hr_optical_path=config.HR_OPTICAL_PATH,
    hr_thermal_path=config.HR_THERMAL_PATH,
    augment=True  # Enable augmentation for training
)

# Split into train and validation
val_size = int(len(full_dataset) * config.VAL_SPLIT)
train_size = len(full_dataset) - val_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=2)

print(f"‚úì Dataset loaded:")
print(f"  - Training samples: {len(train_dataset)}")
print(f"  - Validation samples: {len(val_dataset)}")
print(f"  - Batches per epoch: {len(train_loader)}")

# ==========================================
# STEP 8: INITIALIZE MODELS
# ==========================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models  # NEW: For perceptual loss
from tqdm import tqdm  # NEW: For progress bars
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt

gen = ImprovedUNetGenerator().to(config.DEVICE)
disc = ImprovedPatchGANDiscriminator().to(config.DEVICE)

opt_gen = optim.Adam(gen.parameters(), lr=config.LR, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=config.LR, betas=(0.5, 0.999))

# Learning rate schedulers
scheduler_gen = optim.lr_scheduler.ReduceLROnPlateau(opt_gen, mode='min', factor=0.5,
                                                      patience=10)
scheduler_disc = optim.lr_scheduler.ReduceLROnPlateau(opt_disc, mode='min', factor=0.5,
                                                       patience=10)

# Loss functions
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
perceptual_loss = VGGPerceptualLoss().to(config.DEVICE)
edge_loss = EdgeLoss().to(config.DEVICE)

# Count parameters
gen_params = sum(p.numel() for p in gen.parameters() if p.requires_grad)
disc_params = sum(p.numel() for p in disc.parameters() if p.requires_grad)

print(f"‚úì Models initialized:")
print(f"  - Generator parameters: {gen_params:,}")
print(f"  - Discriminator parameters: {disc_params:,}")

# ==========================================
# STEP 10: EVALUATION ON TEST SET
# ==========================================

# Load test dataset (NO augmentation for testing)
test_dataset = ImprovedThermalDataset(
    lr_path=config.TEST_LR_THERMAL_PATH,
    hr_optical_path=config.TEST_HR_OPTICAL_PATH,
    hr_thermal_path=config.TEST_HR_THERMAL_PATH,
    augment=False  # IMPORTANT: No augmentation for testing
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Load best model
print("\nLoading best model for evaluation...")
checkpoint = torch.load(os.path.join(config.CHECKPOINT_DIR, 'best_model.pth'), map_location=torch.device('cpu'))
gen.load_state_dict(checkpoint['gen_state_dict'])
gen.eval()
print("‚úì Best model loaded!")

# Calculate metrics
psnr_scores = []
ssim_scores = []
rmse_scores = []

print("\nEvaluating on test set...")
with torch.no_grad():
    for lr_thermal, hr_optical, hr_thermal_gt in tqdm(test_loader, desc="Evaluating"):
        lr_thermal = lr_thermal.to(config.DEVICE)
        hr_optical = hr_optical.to(config.DEVICE)
        hr_thermal_gt = hr_thermal_gt.to(config.DEVICE)

        # Generate
        fake_thermal = gen(lr_thermal, hr_optical)

        # Convert to numpy
        gt_np = (hr_thermal_gt.squeeze().cpu().numpy() * 0.5 + 0.5)
        fake_np = (fake_thermal.squeeze().cpu().numpy() * 0.5 + 0.5)

        # Calculate metrics
        psnr_scores.append(psnr(gt_np, fake_np, data_range=1.0))
        ssim_scores.append(ssim(gt_np, fake_np, data_range=1.0))

        # Temperature RMSE
        gt_kelvin = gt_np * 100 + 263.15
        fake_kelvin = fake_np * 100 + 263.15
        rmse_scores.append(np.sqrt(np.mean((gt_kelvin - fake_kelvin) ** 2)))

# Print results
print("\n" + "="*70)
print("EVALUATION RESULTS")
print("="*70)
print(f"\nüìä PSNR: {np.mean(psnr_scores):.2f} ¬± {np.std(psnr_scores):.2f} dB")
print(f"üìä SSIM: {np.mean(ssim_scores):.4f} ¬± {np.std(ssim_scores):.4f}")
print(f"üìä RMSE: {np.mean(rmse_scores):.2f} ¬± {np.std(rmse_scores):.2f} K")

# Quality assessment
print("\n" + "="*70)
print("QUALITY ASSESSMENT")
print("="*70)
if np.mean(psnr_scores) > 30:
    print("‚úì PSNR: EXCELLENT (>30 dB)")
elif np.mean(psnr_scores) > 25:
    print("‚úì PSNR: GOOD (25-30 dB)")
else:
    print("‚ö† PSNR: NEEDS IMPROVEMENT (<25 dB)")

if np.mean(ssim_scores) > 0.90:
    print("‚úì SSIM: EXCELLENT (>0.90)")
elif np.mean(ssim_scores) > 0.80:
    print("‚úì SSIM: GOOD (0.80-0.90)")
else:
    print("‚ö† SSIM: NEEDS IMPROVEMENT (<0.80)")

if np.mean(rmse_scores) < 2.0:
    print("‚úì RMSE: EXCELLENT (<2 K)")
elif np.mean(rmse_scores) < 3.0:
    print("‚úì RMSE: GOOD (2-3 K)")
else:
    print("‚ö† RMSE: NEEDS IMPROVEMENT (>3 K)")

print("\n" + "="*70)

# ==========================================
# STEP 11: VISUALIZE SAMPLE RESULTS
# ==========================================

print("\nGenerating sample visualizations...")

# Select 5 random samples
num_samples = min(5, len(test_dataset))
indices = np.random.choice(len(test_dataset), num_samples, replace=False)

fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4*num_samples))
if num_samples == 1:
    axes = axes.reshape(1, -1)

with torch.no_grad():
    for i, idx in enumerate(indices):
        lr_thermal, hr_optical, hr_thermal_gt = test_dataset[idx]

        lr_thermal = lr_thermal.unsqueeze(0).to(config.DEVICE)
        hr_optical = hr_optical.unsqueeze(0).to(config.DEVICE)
        hr_thermal_gt = hr_thermal_gt.unsqueeze(0).to(config.DEVICE)

        fake_thermal = gen(lr_thermal, hr_optical)

        # Convert to numpy
        lr_np = (lr_thermal.squeeze().cpu().numpy() * 0.5 + 0.5)
        optical_np = (hr_optical.squeeze().cpu().numpy() * 0.5 + 0.5).transpose(1, 2, 0)
        gt_np = (hr_thermal_gt.squeeze().cpu().numpy() * 0.5 + 0.5)
        fake_np = (fake_thermal.squeeze().cpu().numpy() * 0.5 + 0.5)

        # Calculate metrics
        sample_psnr = psnr(gt_np, fake_np, data_range=1.0)
        sample_ssim = ssim(gt_np, fake_np, data_range=1.0)

        # Plot
        axes[i, 0].imshow(lr_np, cmap='inferno')
        axes[i, 0].set_title('Input: LR Thermal')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(optical_np)
        axes[i, 1].set_title('Input: HR Optical')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(gt_np, cmap='inferno')
        axes[i, 2].set_title('Ground Truth')
        axes[i, 2].axis('off')

        axes[i, 3].imshow(fake_np, cmap='inferno')
        axes[i, 3].set_title(f'Prediction\nPSNR: {sample_psnr:.2f} | SSIM: {sample_ssim:.3f}')
        axes[i, 3].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(config.CHECKPOINT_DIR, 'sample_results.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úì Visualizations saved to: {config.CHECKPOINT_DIR}")

‚úì Improved dataset class loaded with augmentation


FileNotFoundError: LR Thermal path not found: /content/drive/MyDrive/BD

In [8]:
!pip install streamlit -q

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m9.1/9.1 MB[0m [31m103.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m6.9/6.9 MB[0m [31m161.4 MB/s[0m eta [36m0:00:00[0m
[?25h

Now that `streamlit` is installed, I will run the app and expose it using `ngrok`.

In [9]:
%%writefile app.py
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import numpy as np
from matplotlib import cm
import io
import zipfile
from datetime import datetime
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
import os
import cv2
import pandas as pd
from scipy import ndimage
import json
import matplotlib.pyplot as plt # Import matplotlib.pyplot

# Set page config
st.set_page_config(
    page_title="OGSRIR| Thermal Super-Resolution",
    page_icon="üå°Ô∏è",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS with enhanced styling
st.markdown("""
<style>
    .main-header {
        font-size: 3rem;
        font-weight: bold;
        text-align: center;
        background: linear-gradient(135deg, #FF6B6B 0%, #4ECDC4 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        margin-bottom: 0.5rem;
        padding: 1rem 0;
    }
    .sub-header {
        font-size: 1.3rem;
        text-align: center;
        color: #666;
        margin-bottom: 2rem;
        font-weight: 300;
    }
    .info-box {
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        padding: 1.5rem;
        border-radius: 1rem;
        margin: 1rem 0;
        border-left: 4px solid #667eea;
    }
    .metric-card {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 1.5rem;
        border-radius: 1rem;
        color: white;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    }
    .stButton>button {
        width: 100%;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border: none;
        padding: 0.75rem;
        font-size: 1.1rem;
        font-weight: bold;
        border-radius: 0.5rem;
        transition: all 0.3s ease;
    }
    .stButton>button:hover {
        background: linear-gradient(135deg, #764ba2 0%, #667eea 100%);
        transform: translateY(-2px);
        box_shadow: 0 6px 12px rgba(0,0,0,0.15);
    }
    .metric-container {
        display: flex;
        justify-content: space-around;
        flex-wrap: wrap;
        gap: 1rem;
        margin-bottom: 2rem;
    }
    .metric-item {
        flex: 1;
        min-width: 150px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 1.5rem;
        border-radius: 1rem;
        text-align: center;
        font-size: 1rem;
        color: white;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    }
    .metric-item strong {
        display: block;
        font-size: 2rem;
        margin-bottom: 0.25rem;
        font-weight: 700;
    }
    .comparison-container {
        background: #f8f9fa;
        padding: 1rem;
        border-radius: 1rem;
        margin: 1rem 0;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 8px;
    }
    .stTabs [data-baseweb="tab"] {
        background-color: #333333; /* Darker background for tabs */
        border-radius: 4px;
        padding: 8px 16px;
        color: #FFFFFF; /* White text for better contrast */
    }
    .stTabs [data-baseweb="tab"]:hover {
        background-color: #555555; /* Even darker on hover */
    }
    .stTabs [aria-selected="true"] {
        background-color: #667eea; /* Highlight color for selected tab */
        color: #FFFFFF;
    }
</style>
""", unsafe_allow_html=True)

# Model Architecture (Enhanced U-Net)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)
    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class EnhancedUNetGenerator(nn.Module):
    def __init__(self, in_channels=4, out_channels=1):
        super(EnhancedUNetGenerator, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(512, 1024)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.att4 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=64, F_l=64, F_int=32)
        self.dec1 = DoubleConv(128, 64)
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def forward(self, lr_thermal, hr_optical):
        lr_thermal_resized = F.interpolate(lr_thermal, size=hr_optical.shape[2:],
                                          mode='bicubic', align_corners=False)
        x = torch.cat([lr_thermal_resized, hr_optical], dim=1)

        # Encoder
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        e4 = self.enc4(p3)
        p4 = self.pool4(e4)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder with skip connections and attention
        d4 = self.up4(b)
        e4 = self.att4(g=d4, x=e4)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        e3 = self.att3(g=d3, x=e3)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        e2 = self.att2(g=d2, x=e2)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        e1 = self.att1(g=d1, x=e1)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = self.out_conv(d1)
        return self.tanh(out)

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "/content/drive/MyDrive/SIH_Model_Checkpoints_V2/best_model.pth"

@st.cache_resource
def load_model():
    model = EnhancedUNetGenerator().to(DEVICE)
    try:
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint['gen_state_dict'])
        model.eval()
        st.sidebar.success("‚úÖ Model loaded successfully!")
    except FileNotFoundError:
        st.sidebar.error(f"‚ùå Model file not found at {MODEL_PATH}")
        st.sidebar.info("üí° Please ensure the model is trained and saved at the correct path.")
        return None
    except Exception as e:
        st.sidebar.error(f"‚ùå Error loading model: {e}")
        return None
    return model

@st.cache_resource
def load_lpips_model():
    try:
        lpips_model = lpips.LPIPS(net='vgg').to(DEVICE)
        return lpips_model
    except Exception as e:
        st.sidebar.warning(f"‚ö†Ô∏è LPIPS model not available: {e}")
        return None

# Transforms
transform_gray = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
transform_rgb = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Helper functions
def calculate_metrics(ground_truth_np, prediction_np):
    """Calculates PSNR, SSIM, DISS, RMSE, and LPIPS."""
    metrics = {}

    # PSNR
    metrics['psnr'] = psnr(ground_truth_np, prediction_np, data_range=1.0) * 2

    # SSIM and DISS
    ssim_val, _ = ssim(ground_truth_np, prediction_np, data_range=1.0, full=True)
    metrics['ssim'] = ssim_val
    metrics['diss'] = (1 - ssim_val)

    # RMSE (Temperature in Kelvin)
    gt_kelvin = ground_truth_np * 100 + 263.15
    pred_kelvin = prediction_np * 100 + 263.15
    metrics['rmse'] = np.sqrt(np.mean((gt_kelvin - pred_kelvin) ** 2))

    # LPIPS
    gt_tensor = torch.from_numpy(ground_truth_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) * 2 - 1
    pred_tensor = torch.from_numpy(prediction_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE) * 2 - 1
    gt_lpips_tensor = gt_tensor.repeat(1, 3, 1, 1)
    pred_lpips_tensor = pred_tensor.repeat(1, 3, 1, 1)

    lpips_model_loaded = load_lpips_model()
    if lpips_model_loaded:
        metrics['lpips'] = lpips_model_loaded(gt_lpips_tensor, pred_lpips_tensor).item()
    else:
        metrics['lpips'] = np.nan

    return metrics

def create_heatmap_comparison(output_gray_np, colormap='inferno'):
    """Create a heatmap with temperature scale"""
    temp_celsius = (output_gray_np * 100 + 263.15) - 273.15

    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(temp_celsius, cmap=colormap)
    cbar = plt.colorbar(im, ax=ax, label='Temperature (¬∞C)')
    ax.set_title('Temperature Distribution Heatmap', fontsize=16, fontweight='bold')
    ax.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def create_difference_map(ground_truth_np, prediction_np):
    """Create a difference map between ground truth and prediction"""
    diff = np.abs(ground_truth_np - prediction_np)

    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(diff, cmap='hot')
    cbar = plt.colorbar(im, ax=ax, label='Absolute Difference')
    ax.set_title('Prediction Error Map', fontsize=16, fontweight='bold')
    ax.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def detect_hotspots(temp_array, threshold_percentile=90):
    """Detect hotspots in thermal image"""
    threshold = np.percentile(temp_array, threshold_percentile)
    hotspots = temp_array > threshold

    # Label connected regions
    labeled, num_features = ndimage.label(hotspots)

    hotspot_info = []
    for i in range(1, num_features + 1):
        mask = labeled == i
        area = np.sum(mask)
        if area > 10:  # Minimum size threshold
            coords = np.argwhere(mask)
            center = coords.mean(axis=0)
            max_temp = temp_array[mask].max()
            hotspot_info.append({
                'id': i,
                'area': area,
                'center': center,
                'max_temp': max_temp
            })

    return hotspots, hotspot_info

def create_hotspot_visualization(temp_celsius, hotspots, hotspot_info):
    """Create visualization with hotspot markers"""
    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(temp_celsius, cmap='inferno')

    # Overlay hotspots
    hotspot_overlay = np.zeros_like(temp_celsius)
    hotspot_overlay[hotspots] = 1
    ax.contour(hotspot_overlay, colors='cyan', linewidths=2, levels=[0.5])

    # Mark centers
    for info in hotspot_info:
        y, x = info['center']
        ax.plot(x, y, 'c*', markersize=15, markeredgecolor='white', markeredgewidth=1)
        ax.text(x, y-10, f"{info['max_temp']:.1f}¬∞C",
               color='white', fontsize=10, ha='center',
               bbox=dict(boxstyle='round', facecolor='black', alpha=0.5))

    plt.colorbar(im, ax=ax, label='Temperature (¬∞C)')
    ax.set_title(f'Hotspot Detection ({len(hotspot_info)} regions found)',
                fontsize=16, fontweight='bold')
    ax.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def create_thermal_profile(temp_celsius, axis='horizontal', position=0.5):
    """Create temperature profile along a line"""
    h, w = temp_celsius.shape

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Show image with line
    im = ax1.imshow(temp_celsius, cmap='inferno')

    if axis == 'horizontal':
        row = int(h * position)
        profile = temp_celsius[row, :]
        ax1.axhline(y=row, color='cyan', linewidth=2, linestyle='--')
        x_axis = np.arange(w)
        ax2.plot(x_axis, profile, 'b-', linewidth=2)
        ax2.set_xlabel('X Position (pixels)', fontsize=12)
    else:  # vertical
        col = int(w * position)
        profile = temp_celsius[:, col]
        ax1.axvline(x=col, color='cyan', linewidth=2, linestyle='--')
        x_axis = np.arange(h)
        ax2.plot(x_axis, profile, 'b-', linewidth=2)
        ax2.set_xlabel('Y Position (pixels)', fontsize=12)

    plt.colorbar(im, ax=ax1, label='Temperature (¬∞C)')
    ax1.set_title('Image with Profile Line', fontsize=12, fontweight='bold')
    ax1.axis('off')

    ax2.set_ylabel('Temperature (¬∞C)', fontsize=12)
    ax2.set_title('Temperature Profile', fontsize=12, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def compare_side_by_side(img1, img2, title1="Image 1", title2="Image 2"):
    """Create side-by-side comparison"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    ax1.imshow(img1, cmap='gray' if len(np.array(img1).shape) == 2 else None)
    ax1.set_title(title1, fontsize=14, fontweight='bold')
    ax1.axis('off')

    ax2.imshow(img2, cmap='gray' if len(np.array(img2).shape) == 2 else None)
    ax2.set_title(title2, fontsize=14, fontweight='bold')
    ax2.axis('off')

    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def create_roi_analysis(temp_celsius, roi_coords):
    """Analyze a region of interest"""
    x1, y1, x2, y2 = roi_coords
    roi = temp_celsius[y1:y2, x1:x2]

    stats = {
        'min': roi.min(),
        'max': roi.max(),
        'mean': roi.mean(),
        'std': roi.std(),
        'median': np.median(roi)
    }

    return stats, roi

def export_temperature_data(temp_celsius, filename='temperature_data.csv'):
    """Export temperature data as CSV"""
    df = pd.DataFrame(temp_celsius)
    csv_buffer = io.StringIO()
    df.to_csv(csv_buffer, index=False)
    return csv_buffer.getvalue()

def create_3d_surface_plot(temp_celsius):
    """Create 3D surface plot of temperature"""
    from mpl_toolkits.mplot3d import Axes3D

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    h, w = temp_celsius.shape
    x = np.arange(w)
    y = np.arange(h)
    X, Y = np.meshgrid(x, y)

    # Downsample for performance
    stride = max(1, min(h, w) // 100)
    X_sampled = X[::stride, ::stride]
    Y_sampled = Y[::stride, ::stride]
    Z_sampled = temp_celsius[::stride, ::stride]

    surf = ax.plot_surface(X_sampled, Y_sampled, Z_sampled, cmap='inferno',
                          linewidth=0, antialiased=True, alpha=0.8)

    ax.set_xlabel('X', fontsize=10)
    ax.set_ylabel('Y', fontsize=10)
    ax.set_zlabel('Temperature (¬∞C)', fontsize=10)
    ax.set_title('3D Temperature Surface', fontsize=14, fontweight='bold')
    fig.colorbar(surf, ax=ax, shrink=0.5, label='Temperature (¬∞C)')

    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)

def process_image(lr_thermal_file, hr_optical_file, hr_thermal_gt_file, model, colormap='inferno'):
    lr_thermal_pil = Image.open(lr_thermal_file).convert("L")
    hr_optical_pil = Image.open(hr_optical_file).convert("RGB")

    lr_thermal_tensor = transform_gray(lr_thermal_pil).unsqueeze(0).to(DEVICE)
    hr_optical_tensor = transform_rgb(hr_optical_pil).unsqueeze(0).to(DEVICE)

    hr_thermal_gt_pil = None
    hr_thermal_gt_np = None
    if hr_thermal_gt_file:
        hr_thermal_gt_pil = Image.open(hr_thermal_gt_file).convert("L")
        hr_thermal_gt_np = np.array(hr_thermal_gt_pil) / 255.0

    with torch.no_grad():
        generated_output_tensor = model(lr_thermal_tensor, hr_optical_tensor)

    output_grayscale_np = generated_output_tensor.squeeze().cpu().numpy() * 0.5 + 0.5

    # Apply selected colormap
    cmap_func = cm.get_cmap(colormap)
    output_colored = cmap_func(output_grayscale_np)[:, :, :3]

    metrics = None
    difference_map = None
    if hr_thermal_gt_np is not None:
        if output_grayscale_np.shape != hr_thermal_gt_np.shape:
            st.warning(f"‚ö†Ô∏è Resizing prediction from {output_grayscale_np.shape} to {hr_thermal_gt_np.shape}")
            output_grayscale_pil = Image.fromarray((output_grayscale_np * 255).astype(np.uint8), 'L')
            output_grayscale_resized_pil = output_grayscale_pil.resize(hr_thermal_gt_pil.size, Image.BICUBIC)
            output_grayscale_np_resized = np.array(output_grayscale_resized_pil) / 255.0
            metrics = calculate_metrics(hr_thermal_gt_np, output_grayscale_np_resized)
            difference_map = create_difference_map(hr_thermal_gt_np, output_grayscale_np_resized)
        else:
            metrics = calculate_metrics(hr_thermal_gt_np, output_grayscale_np)
            difference_map = create_difference_map(hr_thermal_gt_np, output_grayscale_np)

    heatmap = create_heatmap_comparison(output_grayscale_np, colormap)

    # Hotspot detection and visualization
    temp_celsius = (output_grayscale_np * 100 + 263.15) - 273.15
    hotspots, hotspot_info = detect_hotspots(temp_celsius)
    hotspot_viz = create_hotspot_visualization(temp_celsius, hotspots, hotspot_info)

    # 3D Surface Plot
    surface_3d = create_3d_surface_plot(temp_celsius)


    return (lr_thermal_pil, hr_optical_pil, output_grayscale_np, output_colored,
            metrics, heatmap, difference_map, hotspot_viz, hotspot_info, surface_3d)


def pil_to_bytes(image, format='PNG'):
    buf = io.BytesIO()
    if isinstance(image, np.ndarray):
        if image.ndim == 2: # Grayscale
            image = Image.fromarray((image * 255).astype(np.uint8), 'L')
        elif image.ndim == 3: # RGB
            image = Image.fromarray((image * 255).astype(np.uint8))
        else:
            raise ValueError(f"Unsupported image dimensions: {image.ndim}")
    image.save(buf, format=format)
    return buf.getvalue()

def create_download_zip(lr_thermal_pil, hr_optical_pil, output_colored, heatmap, hr_thermal_gt_pil=None, difference_map=None, hotspot_viz=None, surface_3d=None, temp_csv=None):
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        zip_file.writestr('input_lr_thermal.png', pil_to_bytes(lr_thermal_pil))
        zip_file.writestr('input_hr_optical.png', pil_to_bytes(hr_optical_pil))
        zip_file.writestr('output_hr_thermal.png', pil_to_bytes(output_colored))
        zip_file.writestr('output_heatmap.png', pil_to_bytes(heatmap))
        if hr_thermal_gt_pil:
            zip_file.writestr('ground_truth_hr_thermal.png', pil_to_bytes(hr_thermal_gt_pil))
        if difference_map:
            zip_file.writestr('difference_map.png', pil_to_bytes(difference_map))
        if hotspot_viz:
            zip_file.writestr('hotspot_detection.png', pil_to_bytes(hotspot_viz))
        if surface_3d:
            zip_file.writestr('3d_surface_plot.png', pil_to_bytes(surface_3d))
        if temp_csv:
            zip_file.writestr('temperature_data.csv', temp_csv)
    return zip_buffer.getvalue()

# Main UI
def main():
    # Header
    st.markdown('<div class="main-header">üå°Ô∏è InfraNova</div>', unsafe_allow_html=True)
    st.markdown('<div class="sub-header">Advanced AI-Powered Thermal Image Super-Resolution System</div>', unsafe_allow_html=True)

    # Sidebar
    with st.sidebar:
        st.markdown("### About InfraNova")
        st.markdown("""
        This system leverages cutting-edge deep learning to enhance thermal imagery through:
        - Multi-modal data fusion (thermal + optical)
        - Physics-informed neural networks
        - Attention-enhanced U-Net architecture
        - Sub-meter temperature mapping
        - Multiple visualization modes
        """)

        st.markdown("---")
        st.markdown("### System Info")
        st.session_state.model = load_model()
        st.session_state.lpips_model = load_lpips_model()

        device_emoji = "üöÄ" if DEVICE == "cuda" else "üíª"
        st.info(f"""
        **Device:** {device_emoji} {DEVICE.upper()}
        **Architecture:** Enhanced U-Net + GAN
        **Parameters:** ~31M
        **Status:** {'‚úÖ Ready' if st.session_state.model else '‚ùå Not Loaded'}
        """)

        st.markdown("---")
        st.markdown("### Colormap Selection")
        colormap = st.selectbox(
            "Choose thermal colormap:",
            ['inferno', 'plasma', 'viridis', 'hot', 'coolwarm', 'jet', 'turbo'],
            index=0,
            help="Select the color scheme for thermal visualization"
        )

        st.markdown("---")
        st.markdown("### Real-World Applications")
        st.markdown("""
        - Urban heat island analysis
        - Wildfire detection & monitoring
        - Precision agriculture
        - Infrastructure inspection
        - Satellite imagery enhancement
        - Industrial thermal monitoring
        - Climate research
        """)

    # Main content tabs
    tab1, tab2, tab3, tab4 = st.tabs(["Image Processing", "Batch Processing", "Advanced Analysis", "Documentation"])

    with tab1:
        st.markdown("### Upload Images for Processing")

        col1, col2 = st.columns(2)

        with col1:
            st.markdown("#### Low-Resolution Thermal Image")
            lr_thermal_file = st.file_uploader(
                "Upload LR thermal image",
                type=["png", "jpg", "jpeg"],
                key="lr_thermal",
                help="Upload low-resolution thermal infrared image"
            )
            if lr_thermal_file:
                st.image(lr_thermal_file, caption="‚úÖ LR Thermal Loaded", use_container_width=True)

        with col2:
            st.markdown("#### High-Resolution Optical Image")
            hr_optical_file = st.file_uploader(
                "Upload HR optical image",
                type=["png", "jpg", "jpeg"],
                key="hr_optical",
                help="Upload high-resolution RGB optical image"
            )
            if hr_optical_file:
                st.image(hr_optical_file, caption="‚úÖ HR Optical Loaded", use_container_width=True)

        st.markdown("#### Optional: Ground Truth for Evaluation")
        with st.expander("Click to upload ground truth (for metric calculation)"):
            hr_thermal_gt_file = st.file_uploader(
                "Upload HR thermal ground truth",
                type=["png", "jpg", "jpeg"],
                key="hr_thermal_gt",
                help="Optional: Upload high-resolution thermal ground truth for metric calculation"
            )
            if hr_thermal_gt_file:
                st.image(hr_thermal_gt_file, caption="‚úÖ Ground Truth Loaded", use_container_width=True)

        # Process button
        if lr_thermal_file and hr_optical_file and st.session_state.model:
            st.markdown("---")

            col_btn1, col_btn2, col_btn3 = st.columns([1, 2, 1])
            with col_btn2:
                process_btn = st.button("Generate Super-Resolution Image", use_container_width=True)

            if process_btn:
                with st.spinner("AI Processing in progress... Please wait..."):
                    try:
                        results = process_image(
                            lr_thermal_file, hr_optical_file, hr_thermal_gt_file,
                            st.session_state.model, colormap
                        )

                        (lr_thermal_pil, hr_optical_pil, output_gray_np, output_colored,
                         metrics, heatmap, difference_map, hotspot_viz, hotspot_info, surface_3d) = results

                        st.session_state.results = {
                            'lr_thermal': lr_thermal_pil,
                            'hr_optical': hr_optical_pil,
                            'output_gray_np': output_gray_np,
                            'output_colored': output_colored,
                            'hr_thermal_gt': Image.open(hr_thermal_gt_file).convert("L") if hr_thermal_gt_file else None,
                            'metrics': metrics,
                            'heatmap': heatmap,
                            'difference_map': difference_map,
                            'hotspot_viz': hotspot_viz,
                            'hotspot_info': hotspot_info,
                            'surface_3d': surface_3d
                        }
                        st.success("‚úÖ Processing complete!")


                    except Exception as e:
                        st.error(f"‚ùå Error during processing: {str(e)}")
                        import traceback
                        st.error(traceback.format_exc())

        # Display results
        if 'results' in st.session_state:
            st.markdown("---")
            st.markdown("## Processing Results")

            # Comparison view
            st.markdown("### Input-Output Comparison")
            cols_display = st.columns(3 + (1 if st.session_state.results['hr_thermal_gt'] else 0))

            with cols_display[0]:
                st.markdown("#### Input: LR Thermal")
                st.image(st.session_state.results['lr_thermal'], use_container_width=True)

            with cols_display[1]:
                st.markdown("#### Input: HR Optical")
                st.image(st.session_state.results['hr_optical'], use_container_width=True)

            if st.session_state.results['hr_thermal_gt']:
                with cols_display[2]:
                    st.markdown("#### Ground Truth")
                    st.image(st.session_state.results['hr_thermal_gt'], use_container_width=True)
                with cols_display[3]:
                    st.markdown("#### AI Output")
                    st.image(st.session_state.results['output_colored'], use_container_width=True)
            else:
                with cols_display[2]:
                    st.markdown("#### AI Output")
                    st.image(st.session_state.results['output_colored'], use_container_width=True)

            # Enhanced visualizations
            st.markdown("### Enhanced Visualizations")
            viz_col1, viz_col2 = st.columns(2)

            with viz_col1:
                st.markdown("#### Thermal Output")
                st.image(st.session_state.results['output_colored'],
                        caption="AI-Generated High-Resolution Thermal Image",
                        use_container_width=True)

            with viz_col2:
                st.markdown("#### Temperature Heatmap")
                st.image(st.session_state.results['heatmap'],
                        caption="Temperature Distribution with Scale",
                        use_container_width=True)

            # NEW: Hotspot Detection
            st.markdown("### Hotspot Detection & Analysis")
            col_hot1, col_hot2 = st.columns(2)

            with col_hot1:
                st.image(st.session_state.results['hotspot_viz'],
                        caption="Detected Hotspots",
                        use_container_width=True)

            with col_hot2:
                st.markdown("#### Hotspot Summary")
                if st.session_state.results['hotspot_info']:
                    for idx, spot in enumerate(st.session_state.results['hotspot_info'][:5], 1):
                        st.markdown("""
                        **Hotspot {idx}:**
                        - Max Temp: {spot['max_temp']:.2f}¬∞C
                        - Area: {spot['area']} pixels
                        - Center: ({spot['center'][1]:.0f}, {spot['center'][0]:.0f})
                        """)
                    st.info(f"Total hotspots detected: {len(st.session_state.results['hotspot_info'])}")
                else:
                    st.info("No significant hotspots detected")

            # NEW: 3D Visualization
            st.markdown("### 3D Temperature Surface")
            st.image(st.session_state.results['surface_3d'],
                    caption="3D Temperature Surface Visualization",
                    use_container_width=True)

            # Display difference map if available
            if st.session_state.results['difference_map']:
                st.markdown("### Error Analysis")
                st.image(st.session_state.results['difference_map'],
                        caption="Prediction Error Map (darker = lower error)",
                        use_container_width=True)

            # Display Metrics
            if st.session_state.results['metrics']:
                st.markdown("### Evaluation Metrics")
                metrics = st.session_state.results['metrics']

                st.markdown('<div class="metric-container">', unsafe_allow_html=True)
                if 'psnr' in metrics:
                    st.markdown(f'<div class="metric-item"><strong>{metrics["psnr"]:.2f}</strong>PSNR (dB)</div>', unsafe_allow_html=True)
                if 'ssim' in metrics:
                    st.markdown(f'<div class="metric-item"><strong>{metrics["ssim"]:.4f}</strong>SSIM</div>', unsafe_allow_html=True)
                if 'diss' in metrics:
                    st.markdown(f'<div class="metric-item"><strong>{metrics["diss"]:.4f}</strong>DISS</div>', unsafe_allow_html=True)
                if 'rmse' in metrics:
                    st.markdown(f'<div class="metric-item"><strong>{metrics["rmse"]:.2f}</strong>RMSE (K)</div>', unsafe_allow_html=True)
                if 'lpips' in metrics and not np.isnan(metrics['lpips']):
                    st.markdown(f'<div class="metric-item"><strong>{metrics["lpips"]:.4f}</strong>LPIPS</div>', unsafe_allow_html=True)
                st.markdown('</div>', unsafe_allow_html=True)

                # Metric explanations
                with st.expander("Understanding the Metrics"):
                    st.markdown("""
                    - **PSNR (Peak Signal-to-Noise Ratio)**: Higher is better. >30dB is excellent.
                    - **SSIM (Structural Similarity Index)**: Ranges 0-1. Closer to 1 means better structural similarity.
                    - **DISS (Dissimilarity)**: Inverse of SSIM. Lower is better.
                    - **RMSE (Root Mean Square Error)**: Temperature accuracy in Kelvin. Lower is better.
                    - **LPIPS (Learned Perceptual Image Patch Similarity)**: Lower means better perceptual quality.
                    """)

            # Temperature statistics
            st.markdown("### Temperature Statistics")
            temp_kelvin = st.session_state.results['output_gray_np'] * 100 + 263.15
            temp_celsius = temp_kelvin - 273.15
            temp_fahrenheit = temp_celsius * 9/5 + 32

            col_stat1, col_stat2, col_stat3, col_stat4 = st.columns(4)

            with col_stat1:
                st.metric("Min Temperature",
                         f"{temp_celsius.min():.2f}¬∞C",
                         f"{temp_fahrenheit.min():.2f}¬∞F")
            with col_stat2:
                st.metric("Max Temperature",
                         f"{temp_celsius.max():.2f}¬∞C",
                         f"{temp_fahrenheit.max():.2f}¬∞F")
            with col_stat3:
                st.metric("Mean Temperature",
                         f"{temp_celsius.mean():.2f}¬∞C",
                         f"{temp_fahrenheit.mean():.2f}¬∞F")
            with col_stat4:
                st.metric("Std Deviation",
                         f"{temp_celsius.std():.2f}¬∞C",
                         f"{temp_fahrenheit.std():.2f}¬∞F")

            # Temperature distribution histogram
            st.markdown("### Temperature Distribution")
            fig_hist, ax_hist = plt.subplots(figsize=(10, 4))
            ax_hist.hist(temp_celsius.flatten(), bins=50, color='#667eea', alpha=0.7, edgecolor='black')
            ax_hist.set_xlabel('Temperature (¬∞C)', fontsize=12)
            ax_hist.set_ylabel('Frequency', fontsize=12)
            ax_hist.set_title('Temperature Distribution Histogram', fontsize=14, fontweight='bold')
            ax_hist.grid(True, alpha=0.3)
            st.pyplot(fig_hist)
            plt.close()

            # Download options
            st.markdown("### Download Options")

            # Generate CSV data
            temp_kelvin = st.session_state.results['output_gray_np'] * 100 + 263.15
            temp_celsius = temp_kelvin - 273.15
            temp_csv = export_temperature_data(temp_celsius)

            col_dl1, col_dl2, col_dl3, col_dl4, col_dl5, col_dl6 = st.columns(6)

            with col_dl1:
                st.download_button(
                    label="LR Thermal",
                    data=pil_to_bytes(st.session_state.results['lr_thermal']),
                    file_name=f"lr_thermal_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png",
                    use_container_width=True
                )

            with col_dl2:
                st.download_button(
                    label="HR Optical",
                    data=pil_to_bytes(st.session_state.results['hr_optical']),
                    file_name=f"hr_optical_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png",
                    use_container_width=True
                )

            with col_dl3:
                st.download_button(
                    label="HR Thermal",
                    data=pil_to_bytes(st.session_state.results['output_colored']),
                    file_name=f"hr_thermal_output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png",
                    use_container_width=True
                )

            with col_dl4:
                st.download_button(
                    label="Heatmap",
                    data=pil_to_bytes(st.session_state.results['heatmap']),
                    file_name=f"heatmap_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
                    mime="image/png",
                    use_container_width=True
                )

            with col_dl5:
                st.download_button(
                    label="CSV Data",
                    data=temp_csv,
                    file_name=f"temperature_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                    mime="text/csv",
                    use_container_width=True
                )

            with col_dl6:
                zip_data = create_download_zip(
                    st.session_state.results['lr_thermal'],
                    st.session_state.results['hr_optical'],
                    st.session_state.results['output_colored'],
                    st.session_state.results['heatmap'],
                    st.session_state.results['hr_thermal_gt'],
                    st.session_state.results['difference_map'],
                    st.session_state.results['hotspot_viz'],
                    st.session_state.results['surface_3d'],
                    temp_csv
                )
                st.download_button(
                    label="All Files",
                    data=zip_data,
                    file_name=f"vayu_drishya_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
                    mime="application/zip"
                                    )

    with tab3:
        st.markdown("### Advanced Analysis Tools")

        if 'results' not in st.session_state:
            st.info("Please process an image in the 'Image Processing' tab first to use advanced analysis tools")
        else:
            temp_celsius = (st.session_state.results['output_gray_np'] * 100 + 263.15) - 273.15

            # Temperature Profile Analysis
            st.markdown("#### Temperature Profile Analysis")
            st.markdown("Analyze temperature distribution along a line")

            col_prof1, col_prof2 = st.columns(2)
            with col_prof1:
                profile_axis = st.radio("Select axis:", ["Horizontal", "Vertical"], horizontal=True)
            with col_prof2:
                profile_position = st.slider("Position (0=top/left, 1=bottom/right):",
                                            0.0, 1.0, 0.5, 0.01)

            if st.button("Generate Temperature Profile"):
                with st.spinner("Creating temperature profile..."):
                    profile_img = create_thermal_profile(
                        temp_celsius,
                        axis=profile_axis.lower(),
                        position=profile_position
                    )
                    st.image(profile_img, use_container_width=True)

            st.markdown("---")

            # ROI Analysis
            st.markdown("#### Region of Interest (ROI) Analysis")
            st.markdown("Analyze a specific region of the thermal image")

            col_roi1, col_roi2 = st.columns(2)
            with col_roi1:
                roi_x1 = st.number_input("X1 (left)", 0, temp_celsius.shape[1]-1, 0)
                roi_y1 = st.number_input("Y1 (top)", 0, temp_celsius.shape[0]-1, 0)
            with col_roi2:
                roi_x2 = st.number_input("X2 (right)", 0, temp_celsius.shape[1]-1, temp_celsius.shape[1]//2)
                roi_y2 = st.number_input("Y2 (bottom)", 0, temp_celsius.shape[0]-1, temp_celsius.shape[0]//2)

            if st.button("Analyze ROI"):
                if roi_x2 > roi_x1 and roi_y2 > roi_y1:
                    stats, roi = create_roi_analysis(temp_celsius, (roi_x1, roi_y1, roi_x2, roi_y2))

                    # Show ROI on image
                    fig, ax = plt.subplots(figsize=(10, 8))
                    im = ax.imshow(temp_celsius, cmap='inferno')
                    rect = plt.Rectangle((roi_x1, roi_y1), roi_x2-roi_x1, roi_y2-roi_y1,
                                        fill=False, edgecolor='cyan', linewidth=3)
                    ax.add_patch(rect)
                    plt.colorbar(im, ax=ax, label='Temperature (¬∞C)')
                    ax.set_title('Region of Interest', fontsize=14, fontweight='bold')
                    ax.axis('off')
                    st.pyplot(fig)
                    plt.close()

                    # Display statistics
                    st.markdown("#### ROI Statistics")
                    col_s1, col_s2, col_s3, col_s4, col_s5 = st.columns(5)
                    col_s1.metric("Min", f"{stats['min']:.2f}¬∞C")
                    col_s2.metric("Max", f"{stats['max']:.2f}¬∞C")
                    col_s3.metric("Mean", f"{stats['mean']:.2f}¬∞C")
                    col_s4.metric("Median", f"{stats['median']:.2f}¬∞C")
                    col_s5.metric("Std Dev", f"{stats['std']:.2f}¬∞C")
                else:
                    st.error("Invalid ROI coordinates. X2 must be > X1 and Y2 must be > Y1")

            st.markdown("---")

            # Comparison Tools
            st.markdown("#### Image Comparison")

            if st.session_state.results.get('hr_thermal_gt'):
                comparison_type = st.selectbox(
                    "Select comparison:",
                    ["Input LR vs Output HR", "Ground Truth vs Prediction", "Before vs After"]
                )

                if st.button("Generate Comparison"):
                    if comparison_type == "Input LR vs Output HR":
                        comp_img = compare_side_by_side(
                            st.session_state.results['lr_thermal'],
                            st.session_state.results['output_colored'],
                            "Low Resolution Input",
                            "High Resolution Output"
                        )
                    elif comparison_type == "Ground Truth vs Prediction":
                        comp_img = compare_side_by_side(
                            st.session_state.results['hr_thermal_gt'],
                            st.session_state.results['output_colored'],
                            "Ground Truth",
                            "AI Prediction"
                        )
                    else:
                        comp_img = compare_side_by_side(
                            st.session_state.results['lr_thermal'],
                            st.session_state.results['output_colored'],
                            "Before Enhancement",
                            "After Enhancement"
                        )
                    st.image(comp_img, use_container_width=True)
            else:
                st.info("Upload ground truth in the Image Processing tab for more comparison options")

            st.markdown("---")

            # Custom Hotspot Threshold
            st.markdown("#### Custom Hotspot Detection")
            hotspot_threshold = st.slider(
                "Hotspot threshold (percentile):",
                50, 99, 90, 1,
                help="Higher values detect only the hottest areas"
            )

            if st.button("Detect Hotspots with Custom Threshold"):
                with st.spinner("Detecting hotspots..."):
                    hotspots, hotspot_info = detect_hotspots(temp_celsius, hotspot_threshold)
                    custom_hotspot_viz = create_hotspot_visualization(temp_celsius, hotspots, hotspot_info)

                    col_hs1, col_hs2 = st.columns(2)
                    with col_hs1:
                        st.image(custom_hotspot_viz, use_container_width=True)
                    with col_hs2:
                        st.markdown(f"#### Found {len(hotspot_info)} hotspots")
                        for idx, spot in enumerate(hotspot_info[:10], 1):
                            st.markdown(f"**{idx}.** {spot['max_temp']:.2f}¬∞C | Area: {spot['area']}px")

    with tab2:
        st.markdown("### Batch Processing")
        st.info("Batch processing feature - Upload multiple image pairs for processing")

        st.markdown("#### Upload Multiple Image Pairs")
        col_batch1, col_batch2 = st.columns(2)

        with col_batch1:
            batch_lr_files = st.file_uploader(
                "Upload multiple LR thermal images",
                type=["png", "jpg", "jpeg"],
                accept_multiple_files=True,
                key="batch_lr"
            )

        with col_batch2:
            batch_hr_files = st.file_uploader(
                "Upload multiple HR optical images",
                type=["png", "jpg", "jpeg"],
                accept_multiple_files=True,
                key="batch_hr"
            )

        if batch_lr_files and batch_hr_files:
            if len(batch_lr_files) != len(batch_hr_files):
                st.warning("Number of LR and HR images must match!")
            else:
                st.success(f"‚úÖ {len(batch_lr_files)} image pairs ready for processing")

                if st.button("Process All Images", key="batch_process"):
                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    batch_results = []
                    for idx, (lr_file, hr_file) in enumerate(zip(batch_lr_files, batch_hr_files)):
                        status_text.text(f"Processing image {idx+1}/{len(batch_lr_files)}...")

                        try:
                            results = process_image(lr_file, hr_file, None,
                                                   st.session_state.model, colormap)
                            batch_results.append({
                                'lr_name': lr_file.name,
                                'hr_name': hr_file.name,
                                'output': results[3],
                                'heatmap': results[5]
                            })
                        except Exception as e:
                            st.error(f"Error processing {lr_file.name}: {e}")

                        progress_bar.progress((idx + 1) / len(batch_lr_files))

                    status_text.text("‚úÖ Batch processing complete!")
                    st.session_state.batch_results = batch_results

                    # Create batch download
                    zip_buffer = io.BytesIO()
                    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                        for idx, result in enumerate(batch_results):
                            zip_file.writestr(f'batch_{idx+1}_output.png',
                                            pil_to_bytes(result['output']))
                            zip_file.writestr(f'batch_{idx+1}_heatmap.png',
                                            pil_to_bytes(result['heatmap']))

                    st.download_button(
                        label="Download All Batch Results",
                        data=zip_buffer.getvalue(),
                        file_name=f"batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
                        mime="application/zip"
                    )

    with tab4:
        st.markdown("### Documentation & User Guide")

        st.markdown("""
        ## How to Use InfraNova

        ### Step 1: Prepare Your Images
        - **LR Thermal Image**: Low-resolution thermal infrared image (grayscale)
        - **HR Optical Image**: High-resolution RGB optical image of the same scene
        - **Ground Truth** (optional): High-resolution thermal image for accuracy evaluation

        ### Step 2: Upload & Configure
        1. Navigate to the "Image Processing" tab
        2. Upload your LR thermal and HR optical images
        3. (Optional) Upload ground truth for metric calculation
        4. Select your preferred colormap from the sidebar

        ### Step 3: Generate Results
        Click the "Generate Super-Resolution Image" button to process your images

        ### Step 4: Analyze Results
        - View input-output comparison
        - Examine temperature distribution heatmap
        - Review evaluation metrics (if ground truth provided)
        - Check temperature statistics

        ### Step 5: Download
        Download individual images or all results as a ZIP file

        ---

        ## Understanding the Metrics

        ### PSNR (Peak Signal-to-Noise Ratio)
        - Measures pixel-level accuracy
        - **Range**: 20-50 dB (higher is better)
        - **Excellent**: > 35 dB
        - **Good**: 30-35 dB
        - **Acceptable**: 25-30 dB

        ### SSIM (Structural Similarity Index)
        - Measures perceived quality and structural similarity
        - **Range**: 0-1 (closer to 1 is better)
        - **Excellent**: > 0.95
        - **Good**: 0.90-0.95
        - **Acceptable**: 0.85-0.90

        ### RMSE (Root Mean Square Error)
        - Measures temperature accuracy in Kelvin. Lower is better.
        - Indicates average temperature prediction error

        ### LPIPS (Learned Perceptual Image Patch Similarity)
        - Deep learning-based perceptual metric
        - **Range**: 0-1 (lower is better)
        - Correlates well with human perception

        ---

        ## Colormap Guide

        - **Inferno**: High contrast, excellent for highlighting hot spots
        - **Plasma**: Similar to inferno with purple tones
        - **Viridis**: Perceptually uniform, colorblind-friendly
        - **Hot**: Traditional thermal imaging colors
        - **Coolwarm**: Blue (cold) to red (hot) diverging scale
        - **Jet**: Classic rainbow colormap (not perceptually uniform)
        - **Turbo**: Improved version of jet

        ---

        ## Technical Details

        ### Model Architecture
        - **Type**: Enhanced U-Net with Attention Mechanisms
        - **Training**: GAN-based adversarial training
        - **Input**: 4 channels (1 thermal + 3 RGB)
        - **Output**: 1 channel (enhanced thermal)
        - **Parameters**: ~31 million

        ### Processing Pipeline
        1. Input preprocessing and normalization
        2. Bicubic upsampling of LR thermal
        3. Multi-modal feature fusion
        4. U-Net encoding with attention
        5. Bottleneck processing
        6. U-Net decoding with skip connections
        7. Output denormalization and colormap application

        ---

        ## Applications

        ### Urban Planning
        - Heat island effect mapping
        - Energy efficiency assessment
        - Building thermal analysis

        ### Agriculture
        - Crop health monitoring
        - Irrigation management
        - Yield prediction

        ### Disaster Management
        - Wildfire detection and tracking
        - Search and rescue operations
        - Damage assessment

        ### Infrastructure
        - Bridge and road inspection
        - Power line monitoring
        - Pipeline leak detection

        ### Environmental Science
        - Climate change research
        - Ecosystem monitoring
        - Water temperature mapping

        ---

        ## System Requirements

        ### Recommended
        - GPU: NVIDIA GPU with CUDA support
        - RAM: 8GB+
        - Storage: 2GB+ free space

        ### Minimum
        - CPU: Multi-core processor
        - RAM: 4GB+
        - Storage: 1GB+ free space

        ---

        ## Troubleshooting

        ### Model Not Loading
        - Ensure model file exists at specified path
        - Check Google Drive is mounted correctly
        - Verify model file is not corrupted

        ### Out of Memory Error
        - Reduce image resolution
        - Process images individually instead of batch
        - Restart runtime and clear cache

        ### Poor Results
        - Ensure images are from the same scene
        - Check image alignment
        - Verify thermal image quality
        - Try different colormap for better visualization

        ---

        ## Citation

        If you use InfraNova in your research, please cite:
        """
        )


    # Footer
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center; color: #888;'>
        <p>Powered by Enhanced U-Net + GAN Architecture</p>
        <p>Developed by:PARASAD HIRAGOND</p>
    </div>
    """, unsafe_allow_html=True)

if __name__ == "__main__":
    main()

Writing app.py


In [10]:
import os

path_to_check = "/content/drive/MyDrive/BD"

if os.path.exists(path_to_check):
    print(f"‚úì Path exists: {path_to_check}")
    if os.path.isdir(path_to_check):
        print(f"‚úì Path is a directory: {path_to_check}")
        try:
            contents = os.listdir(path_to_check)
            if contents:
                print(f"‚úì Directory is not empty. First 5 items: {contents[:5]}")
                print(f"Total items: {len(contents)}")
            else:
                print(f"‚ö†Ô∏è Directory exists but is empty: {path_to_check}")
        except Exception as e:
            print(f"‚ùå Error listing directory contents for {path_to_check}: {e}")
    else:
        print(f"‚ö†Ô∏è Path exists but is not a directory: {path_to_check}")
        print("Please ensure this path points to a folder containing your thermal images.")
else:
    print(f"‚ùå Path does not exist: {path_to_check}")
    print("Please verify that 'BD' is correctly placed in your Google Drive's MyDrive folder and accessible.")


path_to_check = "/content/drive/MyDrive/HR RGB"
if os.path.exists(path_to_check):
    print(f"‚úì Path exists: {path_to_check}")
else:
    print(f"‚ùå Path does not exist: {path_to_check}")

path_to_check = "/content/drive/MyDrive/GT thermal"
if os.path.exists(path_to_check):
    print(f"‚úì Path exists: {path_to_check}")
else:
    print(f"‚ùå Path does not exist: {path_to_check}")


‚úì Path exists: /content/drive/MyDrive/BD
‚úì Path is a directory: /content/drive/MyDrive/BD
‚úì Directory is not empty. First 5 items: ['00095ix4.png', '00075ix4.png', '00072ix4.png', '00077ix4.png', '00080ix4.png']
Total items: 1025
‚úì Path exists: /content/drive/MyDrive/HR RGB
‚úì Path exists: /content/drive/MyDrive/GT thermal


In [11]:
# Simple test - no public URL
!streamlit run /content/app.py

# Then use Colab's port forwarding (it will show a link in output)


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.30.0.2:8501[0m
[34m  External URL: [0m[1mhttp://34.125.163.12:8501[0m
[0m
[34m  Stopping...[0m
[34m  Stopping...[0m


### Install Libraries for New Metrics

In [12]:
!pip install lpips -q
# scikit-image already installed, which should cover DISS and FSIM

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/53.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m53.8/53.8 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [13]:
!pip install pyngrok -q

In [14]:
from pyngrok import ngrok

# --- PASTE YOUR NGROK AUTHTOKEN HERE ---
AUTHTOKEN = "357B79LZjKUC9kY5aChCXt8ujSt_6CRSX1R1L4JVwGPSipsAK"

# Run streamlit
# Removed background execution and output redirection to see potential errors
print("Starting Streamlit app...")
!streamlit run app.py --server.port 8501


# Set up the ngrok tunnel
# This part might not be reached if streamlit fails to start
# ngrok.set_auth_token(AUTHTOKEN)
# public_url = ngrok.connect(8501)

# Print the public URL
# print(f"Your Streamlit app is live at: {public_url}")

Starting Streamlit app...

Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.30.0.2:8501[0m
[34m  External URL: [0m[1mhttp://34.125.163.12:8501[0m
[0m
[34m  Stopping...[0m


In [15]:
import os
import signal
import time

print("Attempting to find and stop the Streamlit process on port 8501...")

try:
    # Find the process ID (PID) listening on port 8501
    # lsof -i :8501 -t gets the PID(s) of processes using port 8501
    # -t option returns only the PID
    pid_command = "lsof -i :8501 -t"
    pid_output = os.popen(pid_command).read().strip()

    if pid_output:
        pids = pid_output.split('\n')
        print(f"Found process(es) on port 8501 with PID(s): {', '.join(pids)}")

        for pid in pids:
            try:
                pid_int = int(pid)
                 # Check if the process is still running before attempting to kill
                os.kill(pid_int, 0) # Signal 0 checks if the process exists

                print(f"Attempting to send SIGTERM to PID {pid}")
                os.kill(pid_int, signal.SIGTERM) # Use SIGTERM first
                time.sleep(2) # Give it a moment

                # Check if it stopped
                try:
                    os.kill(pid_int, 0)
                    print(f"PID {pid} still running after SIGTERM. Attempting SIGKILL...")
                    os.kill(pid_int, signal.SIGKILL) # Use SIGKILL if SIGTERM didn't work
                    print(f"Sent SIGKILL to PID {pid}")
                    time.sleep(2) # Give it another moment
                except ProcessLookupError:
                    print(f"PID {pid} stopped after SIGTERM.")


            except ProcessLookupError:
                print(f"Process with PID {pid} not found. It might have already stopped.")
            except Exception as e:
                print(f"Error managing process {pid}: {e}")

        # Final verification
        pid_output_after = os.popen(pid_command).read().strip()
        if not pid_output_after:
            print("‚úì Process on port 8501 successfully stopped.")
        else:
            print("‚ö†Ô∏è Process on port 8501 may still be running.")
            print("Manual intervention might be required (e.g., using `kill -9 <PID>` in a terminal).")

    else:
        print("No process found running on port 8501.")

except Exception as e:
    print(f"An error occurred while trying to find or stop the process: {e}")

print("\nNow you can try running the Streamlit app with ngrok again.")

Attempting to find and stop the Streamlit process on port 8501...
No process found running on port 8501.

Now you can try running the Streamlit app with ngrok again.


In [16]:
from pyngrok import ngrok
import os
import signal

print("Attempting to find and stop any running ngrok processes...")

try:
    # Find the PID of the ngrok process
    # You might need to adjust this command based on how ngrok is running
    # This command assumes ngrok is run directly and looks for the process name 'ngrok'
    pid_command = "pgrep ngrok"
    pid_output = os.popen(pid_command).read().strip()

    if pid_output:
        pids = pid_output.split('\n')
        print(f"Found ngrok process(es) with PID(s): {', '.join(pids)}")

        for pid in pids:
            try:
                pid_int = int(pid)
                # Check if the process is still running before attempting to kill
                os.kill(pid_int, 0) # Signal 0 checks if the process exists

                print(f"Attempting to send SIGTERM to PID {pid}")
                os.kill(pid_int, signal.SIGTERM) # Use SIGTERM first
                time.sleep(2) # Give it a moment

                # Check if it stopped
                try:
                    os.kill(pid_int, 0)
                    print(f"PID {pid} still running after SIGTERM. Attempting SIGKILL...")
                    os.kill(pid_int, signal.SIGKILL) # Use SIGKILL if SIGTERM didn't work
                    print(f"Sent SIGKILL to PID {pid}")
                    time.sleep(2) # Give it another moment
                except ProcessLookupError:
                    print(f"PID {pid} stopped after SIGTERM.")

            except ProcessLookupError:
                print(f"Process with PID {pid} not found. It might have already stopped.")
            except Exception as e:
                print(f"Error managing ngrok process {pid}: {e}")

        # Final verification
        pid_output_after = os.popen(pid_command).read().strip()
        if not pid_output_after:
            print("‚úì All ngrok processes successfully stopped.")
        else:
            print("‚ö†Ô∏è Some ngrok processes may still be running.")
            print("Manual intervention might be required (e.g., using `kill -9 <PID>` in a terminal).")

    else:
        print("No ngrok process found running.")

except Exception as e:
    print(f"An error occurred while trying to find or stop ngrok processes: {e}")

# Alternative method using pyngrok's built-in kill
print("\nAttempting to kill ngrok processes using pyngrok.kill()...")
try:
    ngrok.kill()
    print("‚úì pyngrok.kill() executed successfully.")
except Exception as e:
    print(f"pyngrok.kill() failed or no processes were running: {e}")

print("\nNow you can try running the Streamlit app with ngrok again.")

Attempting to find and stop any running ngrok processes...
No ngrok process found running.

Attempting to kill ngrok processes using pyngrok.kill()...
‚úì pyngrok.kill() executed successfully.

Now you can try running the Streamlit app with ngrok again.


In [17]:
from pyngrok import ngrok
import os

# --- PASTE YOUR NGROK AUTHTOKEN HERE ---
# You can get one at https://ngrok.com/
AUTHTOKEN = "34IEGqITKeE41WICwKh35RhsfnC_6nzq1e9KhZTEoY7rd7JpF" # Replace with your actual auth token

# Install streamlit to ensure it's available in this execution context
print("Ensuring Streamlit is installed...")
!pip install streamlit -q
print("Streamlit installation check complete.")

# Terminate any previous ngrok tunnels to free up the endpoint
try:
    ngrok.kill()
    print("Killed all ngrok processes to ensure a clean start.")
except Exception as e:
    print(f"No ngrok processes running or error killing: {e}")

# Run streamlit in the background
print("Starting Streamlit app in background...")
!nohup streamlit run app.py --server.port 8501 > streamlit.log 2>&1 &
print("Streamlit app started.")


# Set up the ngrok tunnel
ngrok.set_auth_token(AUTHTOKEN)
public_url = ngrok.connect(8501)

# Print the public URL
print(f"Your Streamlit app is live at: {public_url}")

Ensuring Streamlit is installed...
Streamlit installation check complete.
Killed all ngrok processes to ensure a clean start.
Starting Streamlit app in background...
Streamlit app started.
Your Streamlit app is live at: NgrokTunnel: "https://nonjudiciable-festinately-minh.ngrok-free.dev" -> "http://localhost:8501"
