In [1]:
import os

# Create the folder layout read by the dataset cell: one directory of real
# images, plus synthetic images with a sibling directory of depth arrays.
for _data_dir in ("data/real", "data/synthetic/images", "data/synthetic/depth"):
    os.makedirs(_data_dir, exist_ok=True)

In [3]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch

class MixedDomainDataset(Dataset):
    """Real images followed by synthetic images that come with depth arrays.

    Indices ``0 .. len(real_images) - 1`` address real samples; the remaining
    indices address synthetic samples. Every image is converted to RGB,
    resized to 256x512, turned into a tensor and normalized with
    mean ``[0.485, 0.456, 0.406]`` / std ``[0.229, 0.224, 0.225]``.

    Returned samples are dicts:
      * real:      ``{'image', 'is_real': True, 'filepath'}``
      * synthetic: ``{'image', 'depth', 'is_real': False, 'filepath'}``

    The two kinds of sample have different key sets (real samples carry no
    ``'depth'``), which matters to whatever collates them into batches.

    Args:
        real_dir: directory containing the real ``.jpg`` images.
        synthetic_dir: directory containing ``images/<name>.jpg`` and
            ``depth/<name>.npy``.
    """

    def __init__(self, real_dir, synthetic_dir):
        self.real_images = sorted(
            os.path.join(real_dir, f) for f in os.listdir(real_dir) if f.endswith('.jpg')
        )
        self.synthetic_images, self.synthetic_depths = self._pair_synthetic_files(synthetic_dir)

        self.transform = transforms.Compose([
            transforms.Resize((256, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _pair_synthetic_files(synthetic_dir):
        """Return ``(image_paths, depth_paths)`` aligned index-by-index.

        Both lists are derived from one sorted list of image file names.
        Sorting the image paths and the depth paths independently can break
        the pairing, because swapping the extension changes the sort order
        (e.g. ``x.jpg < x.k.jpg`` but ``x.k.npy < x.npy``). Only ``.jpg``
        files are considered, and only the extension itself is replaced.
        """
        image_dir = os.path.join(synthetic_dir, 'images')
        depth_dir = os.path.join(synthetic_dir, 'depth')
        names = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))
        image_paths = [os.path.join(image_dir, f) for f in names]
        depth_paths = [os.path.join(depth_dir, os.path.splitext(f)[0] + '.npy') for f in names]
        return image_paths, depth_paths

    def __len__(self):
        return len(self.real_images) + len(self.synthetic_images)

    def __getitem__(self, idx):
        total = len(self)
        # Resolve negative indices against the combined length; otherwise
        # dataset[-1] would silently return a real image instead of the last
        # synthetic one.
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")

        n_real = len(self.real_images)
        if idx < n_real:
            path = self.real_images[idx]
            img = self.transform(Image.open(path).convert('RGB'))
            return {'image': img, 'is_real': True, 'filepath': path}

        syn_idx = idx - n_real
        path = self.synthetic_images[syn_idx]
        img = self.transform(Image.open(path).convert('RGB'))
        depth = np.load(self.synthetic_depths[syn_idx])
        # Add a leading channel axis so the depth map is (1, H, W) float32.
        depth = torch.from_numpy(depth).unsqueeze(0).float()
        return {'image': img, 'depth': depth, 'is_real': False, 'filepath': path}

In [1]:
# import matplotlib.pyplot as plt

# dataset = MixedDomainDataset('data/real', 'data/synthetic')

# def plot_sample(sample):
#     plt.figure(figsize=(12, 4))
#     plt.subplot(1, 3, 1)
#     plt.imshow(sample['image'].permute(1, 2, 0).numpy() * 0.5 + 0.5) 
#     plt.title(f"{'Real' if sample['is_real'] else 'Synthetic'}: {os.path.basename(sample['filepath'])}")
    
#     if not sample['is_real']:
#         plt.subplot(1, 3, 2)
#         plt.imshow(sample['depth'].squeeze(), cmap='jet')
#         plt.title('Depth Map')
        
#         plt.subplot(1, 3, 3)
#         plt.hist(sample['depth'].flatten().numpy(), bins=50)
#         plt.title('Depth Distribution')
#     plt.show()

# plot_sample(dataset[0])
# plot_sample(dataset[len(dataset.real_images)])

In [7]:
from torch.autograd import Function

class GradientReversal(Function):
    """Autograd op that is the identity going forward and flips gradients going back.

    ``forward`` returns a view of ``x`` unchanged and remembers ``alpha``;
    ``backward`` returns the incoming gradient negated and multiplied by
    ``alpha`` (and ``None`` for ``alpha`` itself, which receives no gradient).
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale factor for the backward pass.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        flipped = -grad_output
        return flipped * ctx.alpha, None

def grad_reverse(x, alpha=1.0):
    """Pass ``x`` through ``GradientReversal`` with scale ``alpha`` (default 1.0)."""
    reversed_x = GradientReversal.apply(x, alpha)
    return reversed_x

In [9]:
import torch.nn as nn

class DomainClassifier(nn.Module):
    """Two-way domain classifier applied to a feature map behind a gradient reversal.

    Args:
        input_channels: number of channels of the incoming feature map.

    ``forward(x, alpha)`` takes a ``(B, input_channels, H, W)`` feature map,
    applies ``grad_reverse(x, alpha)``, two conv + ReLU stages (with a 2x2 max
    pool in between), global average pooling and a linear layer, and returns
    ``(B, 2)`` logits.
    """

    def __init__(self, input_channels):
        super().__init__()
        # padding=2 keeps the 5x5 / stride-2 convolutions usable on small maps.
        # Without padding, an 8x16 input (ResNet-18 features of a 256x512
        # image) shrinks to 2x6 after conv1 and 1x3 after the pool, which is
        # smaller than conv2's kernel and raises a RuntimeError.
        # Padding does not change any parameter shape.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.fc = nn.Linear(128, 2)

    def forward(self, x, alpha=1.0):
        # Identity in the forward pass; gradients flowing back into the
        # features are negated and scaled by alpha.
        x = grad_reverse(x, alpha)
        # Stateless ops are called functionally instead of constructing new
        # nn.Module objects on every forward pass.
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.adaptive_avg_pool2d(x, 1)
        x = x.flatten(1)
        return self.fc(x)

In [11]:
class DepthDomainAdaptationModel(nn.Module):
    """Shared ResNet-18 trunk with a depth decoder and a domain classifier.

    ``forward(x, alpha)`` returns ``(depth, domain_logits)``:
      * ``depth`` comes from ``depth_head`` (two conv + ReLU + 2x upsample
        stages, then a 1x1 conv and a sigmoid), so it has one channel and
        values in (0, 1);
      * ``domain_logits`` comes from ``DomainClassifier(512)`` applied to the
        same features, with ``alpha`` forwarded to it.
    """

    def __init__(self):
        super().__init__()

        # Pretrained ResNet-18 loaded through torch.hub, keeping every child
        # module except the last two so the trunk outputs a spatial feature map.
        backbone = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
        trunk_layers = list(backbone.children())[:-2]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Decoder: 512 -> 256 -> 128 -> 1 channels, upsampling by 2 twice.
        decoder_layers = [
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.depth_head = nn.Sequential(*decoder_layers)

        self.domain_classifier = DomainClassifier(512)

    def forward(self, x, alpha=1.0):
        shared = self.feature_extractor(x)
        depth_map = self.depth_head(shared)
        logits = self.domain_classifier(shared, alpha)
        return depth_map, logits

In [None]:
from torch.utils.data import DataLoader
import tqdm

# NOTE(review): this cell reads `dataset`, `model`, `device` and `optimizer`
# from the notebook namespace. None of them is defined in the cells visible
# here (the only `dataset = ...` line sits in a commented-out cell), so they
# must be created before this cell is run.

NUM_EPOCHS = 10
BATCH_SIZE = 8
GRL_ALPHA = 2.0            # gradient-reversal strength passed to the model
DOMAIN_LOSS_WEIGHT = 0.1   # weight of the domain loss in the total loss
DEPTH_SIZE = (64, 128)     # (H, W) of the zero depth placeholder for all-real batches


def collate_mixed(samples):
    """Collate dataset samples into a batch, tolerating a missing 'depth' key.

    Real samples carry no 'depth' entry, so the default collate function
    fails on batches that mix real and synthetic samples. Here every sample
    without a depth map gets an all-zero placeholder shaped like the depth
    maps present in the batch (or ``(1, *DEPTH_SIZE)`` if the batch has none).
    """
    template = next((s['depth'] for s in samples if 'depth' in s), None)
    if template is None:
        template = torch.zeros(1, *DEPTH_SIZE)
    return {
        'image': torch.stack([s['image'] for s in samples]),
        'depth': torch.stack(
            [s['depth'] if 'depth' in s else torch.zeros_like(template) for s in samples]
        ),
        'is_real': torch.tensor([bool(s['is_real']) for s in samples], dtype=torch.bool),
        'filepath': [s['filepath'] for s in samples],
    }


dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_mixed)

# Build the loss modules once instead of on every step.
depth_criterion = nn.MSELoss()
domain_criterion = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        images = batch['image'].to(device)
        is_real = batch['is_real'].to(device)
        # Move the targets to the device *before* masking: indexing a CPU
        # tensor with a mask that lives on another device fails.
        depths = batch['depth'].to(device)
        synth_mask = ~is_real

        pred_depths, domain_logits = model(images, alpha=GRL_ALPHA)

        if synth_mask.any():
            synth_pred = pred_depths[synth_mask]
            synth_target = depths[synth_mask]
            # The depth head upsamples the trunk features only 4x, so its
            # output can be smaller than the ground-truth map; resize the
            # prediction to the target resolution before comparing.
            if synth_pred.shape[-2:] != synth_target.shape[-2:]:
                synth_pred = nn.functional.interpolate(
                    synth_pred, size=synth_target.shape[-2:], mode='bilinear', align_corners=False
                )
            # NOTE(review): the depth head ends in a Sigmoid, so predictions
            # lie in (0, 1) -- confirm the .npy depth maps use the same range.
            depth_loss = depth_criterion(synth_pred, synth_target)
        else:
            # All-real batch: MSE over an empty selection would be NaN and
            # poison the weights, so contribute zero depth loss instead.
            depth_loss = pred_depths.new_zeros(())

        domain_labels = is_real.long()  # 1 = real, 0 = synthetic
        domain_loss = domain_criterion(domain_logits, domain_labels)
        loss = depth_loss + DOMAIN_LOSS_WEIGHT * domain_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    print(f"Epoch {epoch+1} Loss: {total_loss/max(len(dataloader), 1):.4f}")

In [None]:
from sklearn.manifold import TSNE
import numpy as np

def visualize_features():
    """Plot a 2-D t-SNE embedding of the trunk features, coloured by domain.

    Reads the notebook-level ``model``, ``dataset`` and ``device``. Every
    sample is passed through ``model.feature_extractor`` and the feature map
    is averaged over its two spatial axes to get one vector per image. Real
    samples are drawn in red, synthetic samples in blue.

    Raises:
        ValueError: if the dataset yields fewer than two samples.
    """
    # Imported locally: the only visible `import matplotlib.pyplot as plt` in
    # this notebook is inside a commented-out cell.
    import matplotlib.pyplot as plt

    model.eval()
    real_features = []
    synth_features = []

    with torch.no_grad():
        for batch in DataLoader(dataset, batch_size=32):
            images = batch['image'].to(device)
            features = model.feature_extractor(images)
            features = features.mean(dim=[2, 3]).cpu().numpy()

            for feature, sample_is_real in zip(features, batch['is_real']):
                if sample_is_real:
                    real_features.append(feature)
                else:
                    synth_features.append(feature)

    n_real = len(real_features)
    n_total = n_real + len(synth_features)
    if n_total < 2:
        raise ValueError("visualize_features needs at least two samples to embed")

    # Concatenate the lists first so an empty domain does not break vstack;
    # real samples still come first, synthetic samples after.
    all_features = np.vstack(real_features + synth_features)

    # scikit-learn requires perplexity < n_samples, so cap the default of 30
    # for small datasets. A fixed seed makes the plot reproducible.
    tsne = TSNE(n_components=2, perplexity=min(30.0, n_total - 1), random_state=0)
    reduced = tsne.fit_transform(all_features)

    plt.figure(figsize=(10, 6))
    plt.scatter(reduced[:n_real, 0], reduced[:n_real, 1], c='r', label='Real')
    plt.scatter(reduced[n_real:, 0], reduced[n_real:, 1], c='b', label='Synthetic')
    plt.title("Feature Space Distribution (t-SNE)")
    plt.legend()
    plt.show()

visualize_features()

In [None]:
def test_depth_prediction():
    """Show the first synthetic sample next to its true and predicted depth.

    Reads the notebook-level ``model``, ``dataset`` and ``device``. Index
    ``len(dataset.real_images)`` is the first synthetic sample, which is the
    kind of sample that carries a ``'depth'`` entry.
    """
    # Imported locally: the only visible `import matplotlib.pyplot as plt` in
    # this notebook is inside a commented-out cell.
    import matplotlib.pyplot as plt

    model.eval()
    test_sample = dataset[len(dataset.real_images)]

    with torch.no_grad():
        pred_depth, _ = model(test_sample['image'].unsqueeze(0).to(device))

    # Undo the dataset's Normalize(mean, std) per channel. The previous
    # `* 0.5 + 0.5` assumed mean = std = 0.5, which does not match the
    # statistics the dataset uses, so colours were wrong and values fell
    # outside [0, 1].
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = test_sample['image'].permute(1, 2, 0).numpy() * std + mean
    image = np.clip(image, 0.0, 1.0)  # guard against small overshoot from resizing

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title("Input Image")

    plt.subplot(1, 3, 2)
    plt.imshow(test_sample['depth'].squeeze().numpy(), cmap='jet')
    plt.title("Ground Truth Depth")

    plt.subplot(1, 3, 3)
    plt.imshow(pred_depth.squeeze().cpu().numpy(), cmap='jet')
    plt.title("Predicted Depth")
    plt.show()

test_depth_prediction()