In [1]:


import torch
import torch.nn as nn

# Random scenario setup
B = 1       # batch size
z_dim = 5      # latent size
K = 3          # number of classes
e = 2          # embedding dimension
img_dim = 6    # fake image flattened size (toy)

# Create random z and labels
z = torch.randn(B, z_dim)
print(z.shape)

y = torch.randint(0, K, (B,))

# Embedding for classes
embedding = nn.Embedding(K, e)
print(embedding)
y_emb = embedding(y)
print(y_emb.shape)

# Generator input
g_in = torch.cat([z, y_emb], dim=1)
print(g_in.shape)
# Fake generator (just linear for demo)
G = nn.Linear(z_dim + e, img_dim)
x_hat = G(g_in)

# Flatten real "images" (random)
x = torch.randn(B, img_dim)
x_flat = x.view(B, -1)

# Discriminator input
d_in = torch.cat([x_flat, y_emb], dim=1)
D = nn.Linear(img_dim + e, 1)
score = D(d_in)




torch.Size([1, 5])
Embedding(3, 2)
torch.Size([1, 2])
torch.Size([1, 7])


In [3]:
import torch
from src.models.shs_gan.shs_generator import Generator
from src.models.shs_gan.shs_discriminator import Critic3D

def test_shapes():
    gen = Generator()
    critic = Critic3D()
    
    # Test input
    x = torch.randn(2, 3, 224, 224)
    print(f"Generator input: {x.shape}")
    

    fake_hsi = gen(x)
    print(f"Generator output: {fake_hsi.shape}")  # Should be [2, 16, 224, 224]
    
  
    score = critic(fake_hsi)
    print(f"Critic output: {score.shape}")  # Should be [2, 1]
    
    return fake_hsi, score

fake_hsi, score = test_shapes()

Generator input: torch.Size([2, 3, 224, 224])
Generator output: torch.Size([2, 16, 224, 224])
Critic output: torch.Size([2, 1])


In [13]:
# Cell 1: Import dependencies
import torch
import yaml
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf


In [16]:
# Test the dataset and configuration only

# 1. Load and analyze the config file
import yaml
from pathlib import Path

config_path = "/mnt/datahdd/kris_volume/dgm-2025.2/projects/hyperskin/configs/data/hsi_dermoscopy_synth.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("=== CONFIG FILE ANALYSIS ===")
print(f"Class path: {config['data']['class_path']}")
print(f"Image size: {config['data']['init_args']['image_size']}")
print(f"Batch size: {config['data']['init_args']['batch_size']}")
print(f"Allowed labels: {config['data']['init_args']['allowed_labels']}")
print(f"Data directory: {config['data']['init_args']['data_dir']}")

# 2. Fix the interpolation issues in transforms
image_size = config['data']['init_args']['image_size']

# Replace the interpolation variables with actual values
if 'transforms' in config['data']['init_args']:
    transforms = config['data']['init_args']['transforms']
    for stage in ['train', 'val', 'test']:
        if stage in transforms:
            for transform in transforms[stage]:
                if 'init_args' in transform:
                    for key, value in transform['init_args'].items():
                        if isinstance(value, str) and '${data.init_args.image_size}' in value:
                            transform['init_args'][key] = image_size

print("\n=== FIXED TRANSFORMS ===")
print("Replaced interpolation variables with actual values")

# 3. Import and setup the data module
from src.data_modules.hsi_dermoscopy import HSIDermoscopyDataModule

# Use the fixed config directly instead of OmegaConf
init_args = config['data']['init_args']
init_args['class_path'] = "src.data_modules.HSIDermoscopyDataModule"

datamodule = HSIDermoscopyDataModule(**init_args)

print("\n=== DATAMODULE CREATED ===")
print(f"Task: {datamodule.hparams.task}")
print(f"Image size: {datamodule.hparams.image_size}")
print(f"Batch size: {datamodule.hparams.batch_size}")

# 4. Prepare data (download if needed)
print("\n=== PREPARING DATA ===")
datamodule.prepare_data()

# 5. Setup for training
print("\n=== SETUP DATA SPLITS ===")
datamodule.setup(stage='fit')

# 6. Check dataset sizes
print("\n=== DATASET SIZES ===")
print(f"Training samples: {len(datamodule.data_train)}")
print(f"Validation samples: {len(datamodule.data_val)}")
if hasattr(datamodule, 'data_test'):
    print(f"Test samples: {len(datamodule.data_test)}")

# 7. Test one batch from training loader
print("\n=== TESTING ONE BATCH ===")
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

if isinstance(batch, (list, tuple)):
    images, labels = batch
    print(f"Batch images shape: {images.shape}")  # [B, C, H, W]
    print(f"Batch labels shape: {labels.shape}")
    print(f"Image dtype: {images.dtype}")
    print(f"Label dtype: {labels.dtype}")
    print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Labels: {labels}")
else:
    print(f"Unexpected batch format: {type(batch)}")

print("\n=== CONFIGURATION TEST COMPLETE ===")

=== CONFIG FILE ANALYSIS ===
Class path: data_modules.HSIDermoscopyDataModule
Image size: 256
Batch size: 4
Allowed labels: ['melanoma']
Data directory: data/hsi_dermoscopy

=== FIXED TRANSFORMS ===
Replaced interpolation variables with actual values


TypeError: HSIDermoscopyDataModule.__init__() got an unexpected keyword argument 'class_path'