# Cycle-consistent BEV generator training

## Libraries

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import os
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from einops import rearrange

## Dataset

In [2]:
# === Dataset ===
class DirectoryNPYDataset(Dataset):
    """Dataset backed by every ``.npy`` file found directly inside a directory.

    The file paths are sorted, so a given index always maps to the same file.
    Each item is the loaded array converted to a ``torch.long`` tensor.
    """

    def __init__(self, mask_dir):
        """Collect the ``.npy`` paths in ``mask_dir``.

        Raises:
            AssertionError: if the directory contains no ``.npy`` file.
        """
        npy_names = [name for name in os.listdir(mask_dir) if name.endswith('.npy')]
        self.mask_paths = sorted(os.path.join(mask_dir, name) for name in npy_names)
        assert len(self.mask_paths) > 0, "No .npy files found in directory."

    def __len__(self):
        """Return the number of mask files that were discovered."""
        return len(self.mask_paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th file from disk and return it as a long tensor."""
        path = self.mask_paths[idx]
        # Callers in this file treat the array as a 2-D [H, W] class-index mask.
        array = np.load(path)
        return torch.tensor(array, dtype=torch.long)

## Cycle model

In [3]:
# === Model F: Segmentation → BEV ===
class SegToBEV(nn.Module):
    """Convolutional encoder-decoder that maps a segmentation map to a BEV map.

    The encoder halves the spatial size twice (two 2x2 max-pools) and the
    decoder doubles it twice (two stride-2 transposed convolutions), so when
    H and W are divisible by 4 the output has the input's spatial size.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the produced map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are listed in application order; their positions inside the
        # Sequential containers determine the state_dict keys.
        encoder_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.MaxPool2d(kernel_size=2),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, kernel_size=4, stride=2, padding=1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, H, W] to [B, out_channels, H', W']."""
        features = self.encoder(x)
        return self.decoder(features)
    
# === Model G: BEV → Reconstructed Segmentation ===
class BEVToSeg(nn.Module):
    """Convolutional network that maps a BEV map back to segmentation logits.

    One conv + 2x2 max-pool stage is followed by a stride-2 transposed
    convolution, a 1x1 convolution to ``out_channels`` and a bilinear resize
    to the fixed ``output_size``.

    Args:
        in_channels: number of channels of the BEV input.
        out_channels: number of channels of the output.
        output_size: (height, width) the output is resized to.
    """

    def __init__(self, in_channels, out_channels, output_size):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        decoder_layers = [
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, kernel_size=1),
            # The final resize makes the output size independent of the BEV size.
            nn.Upsample(size=output_size, mode='bilinear', align_corners=False),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, h, w] to [B, out_channels, *output_size]."""
        features = self.encoder(x)
        return self.decoder(features)
    
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual self-attention, then residual MLP.

    Args:
        dim: token embedding size; must be divisible by ``heads``.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward network.
        dropout: dropout probability used inside the feed-forward network.
            The attention layer is built with its default dropout (none).
    """

    def __init__(self, dim, heads=4, mlp_dim=128, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape [B, N, dim]; the shape is preserved."""
        # Normalize once and reuse the result as query, key and value. Running
        # the LayerNorm three times tripled its forward/backward cost and hid
        # from MultiheadAttention that this is self-attention.
        normed = self.norm1(x)
        # The attention weights are never used, so skip computing them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x
    
class SegToBEV_TransformerLite(nn.Module):
    """Transformer that maps a [B, C, H, W] input to a fixed-size BEV grid.

    The input is projected to ``embed_dim`` channels with a 1x1 convolution,
    average-pooled to ``bev_size``, flattened into one token per BEV cell,
    processed by ``depth`` transformer blocks and projected per token to
    ``out_channels``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the BEV output.
        input_size: accepted for interface compatibility; not used here,
            because the adaptive pooling handles any input size.
        embed_dim: token embedding size.
        depth: number of transformer blocks.
        bev_size: (height, width) of the BEV grid.
    """

    def __init__(self, in_channels, out_channels, input_size=(96, 128), embed_dim=64, depth=4, bev_size=(32, 32)):
        super().__init__()
        self.input_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.bev_H, self.bev_W = bev_size

        # One learned positional vector per BEV cell.
        num_tokens = self.bev_H * self.bev_W
        self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, embed_dim))

        blocks = [TransformerBlock(embed_dim, heads=4, mlp_dim=128) for _ in range(depth)]
        self.transformer = nn.Sequential(*blocks)
        self.output_proj = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_channels),
        )

    def forward(self, x):
        """Map ``x`` [B, C, H, W] to a BEV tensor [B, out_channels, bev_H, bev_W]."""
        grid = (self.bev_H, self.bev_W)
        feats = self.input_proj(x)                               # [B, D, H, W]
        feats = F.adaptive_avg_pool2d(feats, output_size=grid)   # [B, D, H_bev, W_bev]
        tokens = rearrange(feats, 'b d h w -> b (h w) d')        # [B, N, D]
        tokens = tokens + self.pos_embed[:, :tokens.size(1), :]
        tokens = self.transformer(tokens)                        # [B, N, D]
        logits = self.output_proj(tokens)                        # [B, N, C]
        return rearrange(logits, 'b (h w) c -> b c h w', h=self.bev_H, w=self.bev_W)
    
class BEVToSeg_TransformerLite(nn.Module):
    """Transformer that maps a BEV tensor back to segmentation logits.

    Every BEV cell becomes one token (1x1 convolution to ``embed_dim`` plus a
    learned positional embedding). The tokens go through ``depth`` transformer
    blocks, are projected to ``out_channels`` and the resulting grid is
    bilinearly resized to ``output_size``.

    Args:
        in_channels: number of channels of the BEV input.
        out_channels: number of channels of the output logits.
        bev_size: (height, width) of the BEV grid the positional embedding
            is learned for.
        output_size: (height, width) of the produced logits.
        embed_dim: token embedding size.
        depth: number of transformer blocks.
    """

    def __init__(self, in_channels, out_channels, bev_size=(32, 32), output_size=(96, 128), embed_dim=64, depth=3):
        super().__init__()
        self.bev_H, self.bev_W = bev_size
        self.output_size = output_size

        self.input_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)
        self.pos_embed = nn.Parameter(torch.randn(1, self.bev_H * self.bev_W, embed_dim))

        self.transformer = nn.Sequential(
            *[TransformerBlock(embed_dim, heads=4, mlp_dim=128) for _ in range(depth)]
        )
        self.output_proj = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_channels)
        )

    def _pos_embed_for(self, H, W):
        """Return a positional embedding of shape [1, H * W, D] for an H x W grid.

        When the token count matches, the learned parameter is returned as is.
        Otherwise it is bilinearly resampled from the (bev_H, bev_W) grid. The
        parameter itself is never replaced, so the learned values are kept and
        an optimizer created earlier keeps updating the right tensor.
        """
        if H * W == self.pos_embed.shape[1]:
            return self.pos_embed
        D = self.pos_embed.shape[-1]
        # Tokens are stored row-major (index = h * bev_W + w); rebuild the grid.
        grid = self.pos_embed.transpose(1, 2).reshape(1, D, self.bev_H, self.bev_W)
        grid = F.interpolate(grid, size=(H, W), mode='bilinear', align_corners=False)
        return grid.flatten(2).transpose(1, 2)  # [1, H * W, D]

    def forward(self, x):  # x: [B, C, H_bev, W_bev]
        """Map ``x`` [B, C, H_bev, W_bev] to logits [B, out_channels, *output_size]."""
        x = self.input_proj(x)  # [B, D, H, W]
        _, _, H, W = x.shape
        pos = self._pos_embed_for(H, W)

        x = rearrange(x, 'b d h w -> b (h w) d')  # [B, N, D]
        x = x + pos  # broadcast over the batch dimension
        x = self.transformer(x)
        x = self.output_proj(x)  # [B, N, C]
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=False)
        return x  # [B, C, H_out, W_out]

## Training

In [4]:
def total_variation_loss(img):
    """Return the anisotropic total-variation of a 4-D tensor.

    The result is the mean absolute difference between neighbours along
    dim 2 plus the mean absolute difference between neighbours along dim 3.
    """
    step_dim2 = img[:, :, :-1, :] - img[:, :, 1:, :]
    step_dim3 = img[:, :, :, :-1] - img[:, :, :, 1:]
    return step_dim2.abs().mean() + step_dim3.abs().mean()

# === Training ===
def train(mask_dir, num_classes=4, epochs=20, batch_size=16, device='cuda', embed_dim=32, bev_size=(32, 32), lambda_bev=0.1):
    """Jointly train the Seg->BEV and BEV->Seg transformers on a mask folder.

    Each batch of class-index masks is one-hot encoded, mapped to a BEV tensor
    by F and mapped back by G. The loss is the cross-entropy between the
    reconstruction and the input mask plus ``lambda_bev`` times the total
    variation of the BEV tensor.

    Args:
        mask_dir: directory holding the ``.npy`` masks.
        num_classes: number of classes used for one-hot encoding and outputs.
        epochs: number of passes over the dataset.
        batch_size: batch size of the shuffled data loader.
        device: device the models and batches are moved to.
        embed_dim: token embedding size of both transformers.
        bev_size: (height, width) of the intermediate BEV grid.
        lambda_bev: weight of the BEV smoothness term.

    Returns:
        The trained ``(model_F, model_G)`` pair.
    """
    dataset = DirectoryNPYDataset(mask_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # The first sample defines the mask size G has to reconstruct.
    H, W = dataset[0].shape
    mask_size = (H, W)

    # Transformer-based F and G (the convolutional SegToBEV / BEVToSeg pair
    # could be used here instead).
    model_F = SegToBEV_TransformerLite(
        in_channels=num_classes,
        out_channels=num_classes,
        input_size=mask_size,
        embed_dim=embed_dim,
        bev_size=bev_size,
    ).to(device)
    model_G = BEVToSeg_TransformerLite(
        in_channels=num_classes,
        out_channels=num_classes,
        bev_size=bev_size,
        output_size=mask_size,
        embed_dim=embed_dim,
    ).to(device)

    # A single optimizer updates both networks.
    trainable = [*model_F.parameters(), *model_G.parameters()]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model_F.train()
        model_G.train()
        total_loss = 0
        for seg_mask in dataloader:  # [B, H, W]
            seg_mask = seg_mask.to(device)
            onehot = F.one_hot(seg_mask, num_classes).permute(0, 3, 1, 2).float()  # [B, C, H, W]

            bev = model_F(onehot)   # [B, C, H_bev, W_bev]
            recon = model_G(bev)    # [B, C, H, W]

            # Cross-entropy takes the logits and the class-index target.
            reconstruction_term = criterion(recon, seg_mask)
            smoothness_term = total_variation_loss(bev)
            loss = reconstruction_term + lambda_bev * smoothness_term

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss / len(dataloader):.4f}")

    return model_F, model_G

In [5]:
# Train both networks on the masks in ./predicted_masks/train with the default
# device, batch size and BEV size; the per-epoch mean loss is printed below.
F_model, G_model = train("./predicted_masks/train", num_classes=4, epochs=20)

Epoch 1/20 - Loss: 1.2143
Epoch 2/20 - Loss: 0.7554
Epoch 3/20 - Loss: 0.5391
Epoch 4/20 - Loss: 0.4095
Epoch 5/20 - Loss: 0.2999
Epoch 6/20 - Loss: 0.2246
Epoch 7/20 - Loss: 0.1777
Epoch 8/20 - Loss: 0.1477
Epoch 9/20 - Loss: 0.1278
Epoch 10/20 - Loss: 0.1137
Epoch 11/20 - Loss: 0.1027
Epoch 12/20 - Loss: 0.0946
Epoch 13/20 - Loss: 0.0880
Epoch 14/20 - Loss: 0.0825
Epoch 15/20 - Loss: 0.0783
Epoch 16/20 - Loss: 0.0748
Epoch 17/20 - Loss: 0.0717
Epoch 18/20 - Loss: 0.0688
Epoch 19/20 - Loss: 0.0662
Epoch 20/20 - Loss: 0.0641


## Visualization

In [6]:
def decode_segmap(mask, color_map):
    """Turn a 2-D class-index mask into an RGB ``uint8`` image.

    Pixels equal to index ``i`` receive ``color_map[i]``; pixels whose value
    has no entry in ``color_map`` stay black.
    """
    rows, cols = mask.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label, rgb_value in enumerate(color_map):
        belongs_to_label = mask == label
        rgb[belongs_to_label] = rgb_value
    return rgb

# === Main visualization function ===
def visualize(models, dataset, color_map, save_dir="vis_results", num_samples=5, device='cuda'):
    """Save side-by-side figures of input mask, BEV prediction and reconstruction.

    For each of the first samples of ``dataset`` the mask is one-hot encoded,
    passed through F and G, and the arg-max class maps are colourised and
    written to ``save_dir`` as ``sample_XXX.png``.

    Args:
        models: ``(F_model, G_model)`` pair; both are switched to eval mode.
        dataset: indexable, sized collection yielding [H, W] class-index tensors.
        color_map: sequence of RGB colours; its length is used as the number
            of classes for one-hot encoding.
        save_dir: output directory, created if missing.
        num_samples: maximum number of samples to render; capped at
            ``len(dataset)`` so a short dataset does not raise ``IndexError``.
        device: device the input masks are moved to.
    """
    os.makedirs(save_dir, exist_ok=True)
    F_model, G_model = models
    F_model.eval()
    G_model.eval()

    num_classes = len(color_map)
    # Never index past the end of the dataset.
    sample_count = min(num_samples, len(dataset))

    with torch.no_grad():
        for idx in range(sample_count):
            mask = dataset[idx].unsqueeze(0).to(device)  # [1, H, W]

            # Convert to one-hot
            onehot = torch.nn.functional.one_hot(mask, num_classes).permute(0, 3, 1, 2).float()  # [1, C, H, W]

            # Forward pass
            bev = F_model(onehot)                     # [1, C, H_bev, W_bev]
            recon = G_model(bev)                      # [1, C, H, W]
            pred_mask = recon.argmax(1).squeeze(0).cpu().numpy()  # [H, W]
            bev_mask = bev.argmax(1).squeeze(0).cpu().numpy()     # [H_bev, W_bev]

            # Decode for visualization
            panels = (
                ("Original Seg", decode_segmap(mask.squeeze(0).cpu().numpy(), color_map)),
                ("BEV Predicted", decode_segmap(bev_mask, color_map)),
                ("Reconstructed Seg", decode_segmap(pred_mask, color_map)),
            )

            # Plot
            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            try:
                for ax, (title, image) in zip(axs, panels):
                    ax.imshow(image)
                    ax.set_title(title)
                    ax.axis('off')
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, f"sample_{idx:03d}.png"))
            finally:
                # Close this specific figure even if drawing or saving fails.
                plt.close(fig)

In [7]:
# RGB colour per class index; the list position is the class id, and its
# length is what visualize() uses as the number of classes.
color_map = [
    (0, 255, 0),      # sea
    (0, 0, 255),      # obstacle
    (255, 0, 0),      # sky
    (0, 0, 0)         # background / ignored
]

# Render the first 10 training masks with the models trained above; figures
# go to the default "vis_results" directory.
dataset = DirectoryNPYDataset("./predicted_masks/train")
visualize(
    models=(F_model, G_model),
    dataset=dataset,
    color_map=color_map,
    num_samples=10
)