In [1]:
import sys
import os
from pathlib import Path

# Make the project's `src/` directory importable (e.g. `models.glip_loc`).
# Assumes the notebook runs from a directory one level below the project root.
# Guard on the exact path being added so re-running this cell never appends a
# duplicate entry, and `src/` is still added even if the project root itself
# is already on sys.path.
project_root = Path.cwd().parent
src_dir = str(project_root / 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [2]:
import torch
from models.glip_loc import GLIPLocModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNG so the random dummy inputs (and the outputs below) are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prepare dummy data: B random images per view, created directly on `device`.
# NOTE(review): 3x224x224 presumably matches the backbones' expected input size — confirm.
B = 4
ground_images = torch.randn(B, 3, 224, 224, device=device)
sat_images = torch.randn(B, 3, 224, 224, device=device)
# For testing separate captions (one caption per image in the batch):
ground_captions = [
    "A ground-level view of a busy street",
    "A ground photo of a forest trail",
    "A panorama of a rural village at ground level",
    "A ground-level shot of a modern building"
]
sat_captions = [
    "A satellite view of a city center",
    "A satellite image of a large forest",
    "A top-down satellite shot of farmland",
    "A satellite image of coastal lines"
]
assert len(ground_captions) == B and len(sat_captions) == B, "need exactly one caption per image"


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
###################################
# 1. CLIP Vision Only
###################################
# With use_text=False the model is called with one or both image batches;
# passing both returns a (ground, satellite) tuple.
print("=== Test 1: CLIP Vision Only ===")
clip_vision_model = GLIPLocModel(
    model_name="openai/clip-vit-base-patch32",
    pretrained=True,
    use_text=False
).to(device)
clip_vision_model.eval()

with torch.no_grad():
    # Get ground embeddings only
    ground_emb = clip_vision_model(ground_image=ground_images)
    # Get satellite embeddings only
    sat_emb = clip_vision_model(satellite_image=sat_images)

    # Get both ground and satellite embeddings from a single (paired) call
    ground_emb_, sat_emb_ = clip_vision_model(ground_image=ground_images, satellite_image=sat_images)

# CLIP ViT-B/32 projects image features to 512 dims (768 is only its internal
# hidden width), so every embedding here is expected to be [B, 512].
print("Ground Embeddings (CLIP Vision Only):", ground_emb.shape)  # Expect [4, 512]
print("Satellite Embeddings (CLIP Vision Only):", sat_emb.shape)  # Expect [4, 512]
print("Ground Embeddings (CLIP Vision Only, paired call):", ground_emb_.shape)  # Expect [4, 512]
print("Satellite Embeddings (CLIP Vision Only, paired call):", sat_emb_.shape)  # Expect [4, 512]

=== Test 1: CLIP Vision Only ===


  return self.fget.__get__(instance, owner)()


Ground Embeddings (CLIP Vision Only): torch.Size([4, 512])
Satellite Embeddings (CLIP Vision Only): torch.Size([4, 512])
Ground Embeddings (CLIP Vision Only): torch.Size([4, 512])
Satellite Embeddings (CLIP Vision Only): torch.Size([4, 512])


In [5]:
###################################
# 2. CLIP Vision + Text
###################################
print("=== Test 2: CLIP Vision + Text ===")
# Same CLIP checkpoint as Test 1, this time with the text branch switched on.
clip_vision_text_model = GLIPLocModel(
    model_name="openai/clip-vit-base-patch32",
    pretrained=True,
    use_text=True,
).to(device)
clip_vision_text_model.eval()

# With use_text=True the forward pass yields four outputs:
# ground-image, satellite-image, ground-caption and satellite-caption embeddings.
with torch.no_grad():
    (
        ground_emb_clip,
        sat_emb_clip,
        ground_txt_emb_clip,
        sat_txt_emb_clip,
    ) = clip_vision_text_model(
        ground_image=ground_images,
        satellite_image=sat_images,
        ground_captions=ground_captions,
        satellite_captions=sat_captions,
    )

for _label, _emb in (
    ("Ground Emb (CLIP V+T):", ground_emb_clip),
    ("Sat Emb (CLIP V+T):", sat_emb_clip),
    ("Ground Text Emb (CLIP V+T):", ground_txt_emb_clip),
    ("Sat Text Emb (CLIP V+T):", sat_txt_emb_clip),
):
    print(_label, _emb.shape)
del _label, _emb
print()

=== Test 2: CLIP Vision + Text ===
Ground Emb (CLIP V+T): torch.Size([4, 512])
Sat Emb (CLIP V+T): torch.Size([4, 512])
Ground Text Emb (CLIP V+T): torch.Size([4, 512])
Sat Text Emb (CLIP V+T): torch.Size([4, 512])



In [6]:
###################################
# 3. ConvNeXt Vision Only
###################################
print("=== Test 3: ConvNeXt Vision Only ===")
# ConvNeXt image backbone, no text branch.
convnext_vision_model = GLIPLocModel(
    model_name="convnext_base",
    pretrained=True,
    use_text=False,
).to(device)
convnext_vision_model.eval()

# Passing both image batches returns a (ground, satellite) pair.
with torch.no_grad():
    (ground_emb_conv, sat_emb_conv) = convnext_vision_model(
        ground_image=ground_images,
        satellite_image=sat_images,
    )

for _label, _emb in (
    ("Ground Emb (ConvNeXt Vision):", ground_emb_conv),
    ("Sat Emb (ConvNeXt Vision):", sat_emb_conv),
):
    print(_label, _emb.shape)
del _label, _emb
print()

=== Test 3: ConvNeXt Vision Only ===
Ground Emb (ConvNeXt Vision): torch.Size([4, 1024])
Sat Emb (ConvNeXt Vision): torch.Size([4, 1024])



In [7]:
###################################
# 4. ConvNeXt Vision + Text
###################################
print("=== Test 4: ConvNeXt Vision + Text ===")
convnext_vision_text_model = GLIPLocModel(
    model_name="convnext_base",
    pretrained=True,
    use_text=True
).to(device)
convnext_vision_text_model.eval()

# Store results under ConvNeXt-specific names so Test 2's CLIP results
# (`*_clip`) are not silently overwritten.
with torch.no_grad():
    ground_emb_conv_vt, sat_emb_conv_vt, ground_txt_emb_conv_vt, sat_txt_emb_conv_vt = convnext_vision_text_model(
        ground_image=ground_images,
        satellite_image=sat_images,
        ground_captions=ground_captions,
        satellite_captions=sat_captions
    )

# NOTE(review): previously recorded image embeddings here were 512-d vs 1024-d in
# Test 3, presumably a projection into the text embedding space — confirm in GLIPLocModel.
print("Ground Emb (ConvNeXt V+T):", ground_emb_conv_vt.shape)
print("Sat Emb (ConvNeXt V+T):", sat_emb_conv_vt.shape)
print("Ground Text Emb (ConvNeXt V+T):", ground_txt_emb_conv_vt.shape)
print("Sat Text Emb (ConvNeXt V+T):", sat_txt_emb_conv_vt.shape)
print()

=== Test 4: ConvNeXt Vision + Text ===
Ground Emb (CLIP V+T): torch.Size([4, 512])
Sat Emb (CLIP V+T): torch.Size([4, 512])
Ground Text Emb (CLIP V+T): torch.Size([4, 512])
Sat Text Emb (CLIP V+T): torch.Size([4, 512])

