# To Run On Google Colab

In [None]:
# First cell - Clone repository
!git clone https://github.com/utkarsh231/ViT-based-Image-Compression

In [None]:
%cd ViT-based-Image-Compression

In [None]:
# Second cell - Install dependencies
# If this breaks, restart the runtime and don't run this cell again
!pip install -r requirements.txt

In [None]:
!pip install pytorch-msssim

In [None]:
import os
import kagglehub

# Set custom path before downloading
# NOTE(review): kagglehub's documented override for the download location is
# the KAGGLEHUB_CACHE environment variable; KAGGLEHUB_DIR may be ignored, in
# which case the files land in the default cache (~/.cache/kagglehub).
# Confirm against the path printed below before renaming the variable, since
# the dataset-preparation step may depend on the current location.
os.environ["KAGGLEHUB_DIR"] = "/content/ViT-based-Image-Compression"

# Download the dataset
# dataset_download returns the local directory holding the dataset files.
path = kagglehub.dataset_download("crawford/cat-dataset")

print("Path to dataset files:", path)

In [None]:
# Prepare dataset (run after the dataset download cell)
!python prepare_cat_data.py

In [None]:
!python train.py

# Testing

In [None]:
from google.colab import files
from IPython.display import Image, display
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image as PILImage
import os
from inference import run_inference
from tic_vit_encoder import ViTCompressor
from config import CompressionConfig
from torch.serialization import add_safe_globals

# Register CompressionConfig as a safe global for checkpoint loading.
# torch.load only consults this allow-list when weights_only=True; it is kept
# so the load below can be switched to weights_only=True without other edits.
add_safe_globals([CompressionConfig])

# 1. Set image path (using your specific path)
image_path = "/content/ViT-based-Image-Compression/data/val/00000001_000.jpg"
print(f"Using image: {image_path}")

# 2. Load the model
print("\nLoading model...")
config = CompressionConfig()
model = ViTCompressor(
    img_size=config.img_size,
    patch_size=config.patch_size,
    embed_dim=config.embed_dim,
)

# 3. Load checkpoint (replace with your checkpoint path)
checkpoint_path = '/content/ViT-based-Image-Compression/checkpoints/checkpoint_epoch_49.pt'
print(f"Loading checkpoint from: {checkpoint_path}")

# Prefer the GPU, but fall back to the CPU so this cell still runs on a
# runtime without an accelerator instead of failing inside torch.load.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY NOTE: weights_only=False runs the full pickle machinery and can
# execute arbitrary code embedded in the file. Only load checkpoints you
# produced yourself. Since CompressionConfig is registered above, try
# weights_only=True first if the checkpoint contains nothing else exotic.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

# Training checkpoints wrap the weights in a dict under 'model_state_dict';
# anything else is treated as a bare state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()

# 4. Run inference
print("\nRunning inference...")
# NOTE(review): this path lives under /content/drive, but no cell shown here
# mounts Google Drive. Without a mount, makedirs creates a plain local
# directory (outputs are lost when the runtime resets). Mount Drive first if
# persistence is intended.
output_dir = '/content/drive/MyDrive/transformer_compression/outputs/single_test'
os.makedirs(output_dir, exist_ok=True)  # Create output directory if it doesn't exist

# Modified inference function to handle device
def run_inference_cuda(model, image_path, output_dir, img_size=256):
    """Run the model on a single image and save a visual comparison.

    The image is resized to ``img_size`` x ``img_size``, normalized with the
    ImageNet mean/std, and fed to ``model`` on the device that holds the
    model's parameters (CUDA is used if the model exposes no parameters).

    Three files are written to ``output_dir``: ``original.png``,
    ``reconstructed.png`` and ``comparison.png`` (side-by-side plot titled
    with PSNR and bpp).

    Args:
        model: Module whose forward pass returns ``(x_hat, likelihoods)``;
            ``likelihoods`` must be a tensor accepted by ``torch.log2``.
        image_path: Path of the image to compress.
        output_dir: Directory that receives the output files; created if
            missing.
        img_size: Side length the image is resized to before inference.

    Returns:
        dict with float entries ``'mse'``, ``'psnr'`` (dB) and ``'bpp'``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Follow the model's device instead of forcing CUDA, so the same helper
    # works for a model that was left on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # Load and preprocess the image
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = PILImage.open(image_path).convert('RGB')
    x = transform(image).unsqueeze(0).to(device)

    # Run model
    model.eval()
    with torch.no_grad():
        x_hat, likelihoods = model(x)

    # Move tensors back to CPU for metrics and saving
    x = x.cpu()
    x_hat = x_hat.cpu()

    # Undo the ImageNet normalization so pixel values are back in [0, 1]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x_denorm = torch.clamp(x * std + mean, 0, 1)
    x_hat_denorm = torch.clamp(x_hat * std + mean, 0, 1)

    # PSNR assumes a peak value of 1.0, so the error has to be measured on
    # the [0, 1] images; measuring it on the normalized tensors (whose range
    # is roughly [-2.1, 2.6]) inflates the MSE and understates the PSNR.
    mse = torch.mean((x_denorm - x_hat_denorm) ** 2)
    psnr = (10 * torch.log10(1.0 / mse)).item()

    # Bits per pixel = total estimated bits / number of image pixels.
    # (Averaging over likelihood entries before dividing by the pixel count
    # would divide twice and yield a meaningless, near-zero value.)
    num_pixels = x.shape[0] * img_size * img_size
    bpp = (-torch.log2(likelihoods).sum() / num_pixels).item()

    # Save original and reconstructed images
    x_img = _chw_batch_to_uint8(x_denorm)
    x_hat_img = _chw_batch_to_uint8(x_hat_denorm)
    PILImage.fromarray(x_img).save(os.path.join(output_dir, 'original.png'))
    PILImage.fromarray(x_hat_img).save(os.path.join(output_dir, 'reconstructed.png'))

    # Plot the two images side by side with the metrics in the title
    fig, (ax_orig, ax_recon) = plt.subplots(1, 2, figsize=(10, 5))
    ax_orig.imshow(x_img)
    ax_orig.set_title('Original')
    ax_orig.axis('off')

    ax_recon.imshow(x_hat_img)
    ax_recon.set_title(f'Reconstructed\nPSNR: {psnr:.2f} dB, BPP: {bpp:.4f}')
    ax_recon.axis('off')

    fig.savefig(os.path.join(output_dir, 'comparison.png'))
    plt.close(fig)

    return {'mse': mse.item(), 'psnr': psnr, 'bpp': bpp}


def _chw_batch_to_uint8(batch):
    """Convert a (1, C, H, W) tensor with values in [0, 1] to an HWC uint8 array."""
    return (batch.squeeze(0).permute(1, 2, 0).numpy() * 255).astype('uint8')

# Run the device-aware inference helper on the selected image
run_inference_cuda(
    model,
    image_path,
    output_dir,
    img_size=config.img_size,
)

# 5. Display results
print("\nResults saved in:", output_dir)
# Show the side-by-side comparison figure written by the inference helper
comparison_png = os.path.join(output_dir, 'comparison.png')
display(Image(filename=comparison_png))