In [1]:
import sys
from types import SimpleNamespace
import torch

# Add the source folder to sys.path
sys.path.append('../source')

# Import the top-level Loss class
from cliploss import Loss

# Define dummy args for this run: the semantic CLIP loss is switched off and the
# geometric (conv-feature) CLIP loss is switched on. The recorded output of this
# cell lists only clip_conv_loss* and fc terms, which matches these values.
args = SimpleNamespace(
    device="cuda" if torch.cuda.is_available() else "cpu",
    percep_loss="none",                  # no perceptual loss like LPIPS
    train_with_clip=False,               # semantic CLIP loss is turned off for this test
    clip_weight=0,                    # weight for semantic CLIP loss (zero here)
    start_clip=0,                       # presumably the epoch from which CLIP loss applies; 0 = immediately (confirm in cliploss)
    clip_conv_loss=1,              # geometric (conv-feature) loss is turned ON for this test
    clip_fc_loss_weight=0.1,           # weight for the fc term; an "fc" entry does appear in the output
    clip_text_guide=0.0,               # unused here
    num_aug_clip=4,                    # number of augmentations
    # NOTE(review): "augemntations" looks misspelled, but it is the attribute name
    # passed to Loss; presumably cliploss reads it under this spelling -- confirm
    # before renaming.
    augemntations=["affine"],          # apply affine augmentation
    include_target_in_aug=False,       # only augment sketch
    augment_both=False,
    clip_model_name="ViT-B/32",
    clip_conv_loss_type="L2",          # presumably the distance used by the conv loss -- confirm in cliploss
    clip_conv_layer_weights=[0, 0, 1.0, 1.0, 0],  # per-layer weights; only indices 2 and 3 are non-zero (output shows layer2/layer3)
)

# Helper that fabricates a random stand-in image for the sketch / target inputs
def get_dummy_image(size=(224, 224)):
    """Return a random tensor shaped like a single 3-channel image batch.

    The result has shape ``(1, 3, *size)`` and is filled by ``torch.rand``,
    i.e. with values drawn uniformly from ``[0, 1)``.
    """
    batch_shape = (1, 3, *size)
    return torch.rand(batch_shape)

# Draw two independent random images (sketch first, then target) and move
# each one to the configured device.
sketch, target = (get_dummy_image().to(args.device) for _ in range(2))

# Build the combined Loss module and evaluate it once in training mode
loss_fn = Loss(args).to(args.device)
losses_dict = loss_fn(sketch, target, epoch=100, mode="train")

# The total is the plain sum of every entry the loss module returned
final_loss = sum(losses_dict.values())

# Assemble the report line by line, then emit it in a single print call
report = ["Loss breakdown:"]
report.extend(
    f"  {name}: {val.item():.4f}" for name, val in losses_dict.items()
)
report.append(f"Total weighted loss: {final_loss.item():.4f}")
print("\n".join(report))

Loss breakdown:
  clip_conv_loss: 0.0000
  clip_conv_loss_layer2: 0.0206
  clip_conv_loss_layer3: 0.0124
  fc: 0.0008
Total weighted loss: 0.0337
