# OLD

In [6]:
from layers import Lorentz_fully_connected
from layers import Lorentz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [14]:
from layers import Lorentz_fully_connected
from layers import Lorentz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class EuclideanToLorentz(nn.Module):
    """Project Euclidean features onto Lorentz manifold."""
    
    def __init__(self, in_features, out_features, manifold):
        """
        Args:
            in_features: Euclidean input dimension (e.g., 784 for MNIST)
            out_features: Lorentz output dimension INCLUDING time (e.g., 100 means 99 space + 1 time)
        """
        super().__init__()
        self.manifold = manifold
        self.linear = nn.Linear(in_features, out_features - 1)  # Output space components only
        
        # Small init to keep points near origin initially (more stable)
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        nn.init.zeros_(self.linear.bias)
    
    def forward(self, x):
        """
        Args:
            x: [batch, in_features] Euclidean vectors
        Returns:
            [batch, out_features] points on Lorentz manifold
        """
        space = self.linear(x)  # [batch, out_features - 1]
        return self.manifold.projection_space_orthogonal(space)  # [batch, out_features]
    
class EuclideanToLorentzConv(nn.Module):
    """Project Euclidean image onto Lorentz manifold via 1x1 conv."""
    
    def __init__(self, in_channels, out_channels, manifold):
        """
        Args:
            in_channels: Euclidean channels (e.g., 3 for RGB)
            out_channels: Lorentz channels INCLUDING time (e.g., 16 means 15 space + 1 time)
        """
        super().__init__()
        self.manifold = manifold
        self.conv = nn.Conv2d(in_channels, out_channels - 1, kernel_size=1)
        
        nn.init.xavier_uniform_(self.conv.weight, gain=0.1)
        nn.init.zeros_(self.conv.bias)
    
    def forward(self, x):
        """
        Args:
            x: [batch, in_channels, H, W] Euclidean image
        Returns:
            [batch, out_channels, H, W] on Lorentz manifold (each pixel is a Lorentz point)
        """
        space = self.conv(x)  # [batch, out_channels - 1, H, W]
        
        # Compute time component for each pixel
        # time = sqrt(||space||^2 + 1/k)
        time = torch.sqrt((space ** 2).sum(dim=1, keepdim=True) + 1.0 / self.manifold.k())
        
        return torch.cat([time, space], dim=1)  # [batch, out_channels, H, W]

class LorentzResidualMidpoint(nn.Module):
    """Residual via weighted Lorentz midpoint."""
    
    def __init__(self, dim, manifold, activation):
        super().__init__()
        self.manifold = manifold
        self.fc = Lorentz_fully_connected(
            in_features=dim,
            out_features=dim,
            manifold=manifold,
            reset_params="kaiming",
            activation=activation
        )
        # Learnable weight (0.5 = equal weighting)
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5
    
    def forward(self, x):
        out = self.fc(x)
        
        # Weighted midpoint on manifold
        alpha = torch.sigmoid(self.alpha_logit)
        # Stack for centroid computation: [batch, 2, dim]
        stacked = torch.stack([x, out], dim=-2)
        
        # Weights: [1, 2] -> broadcast to [batch, 2]
        weights = torch.tensor([[1 - alpha, alpha]], device=x.device)
        return self.manifold.lorentz_midpoint(stacked, weights)
    

class LorentzConv2d(nn.Module):
    """
    Lorentz Conv2d using direct concatenation + existing Lorentz FC.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int],
        padding: int | tuple[int, int],
        manifold: Lorentz,
        activation,
    ):
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # After concatenating k*k Lorentz points:
        # concat_dim = 1 + (in_channels - 1) * k * k
        concat_dim = 1 + (in_channels - 1) * kernel_size[0] * kernel_size[1]
        
        # Reuse existing Lorentz FC
        self.fc = Lorentz_fully_connected(
            in_features=concat_dim,
            out_features=out_channels,
            manifold=self.manifold,
            activation=activation,
            reset_params="orthogonal"
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, in_channels, H, W]
        Returns:
            [batch, out_channels, H', W']
        """
        batch, C, H, W = x.shape
        kH, kW = self.kernel_size
        sH, sW = self.stride
        pH, pW = self.padding
        
        # Pad with origin points
        if pH > 0 or pW > 0:
            sqrt_k_inv = (1.0 / self.manifold.k()).sqrt()
            x = F.pad(x, (pW, pW, pH, pH), mode='constant', value=0.0)
            _, _, H_pad, W_pad = x.shape
            
            # Fix time component in padded regions
            mask = torch.ones(1, 1, H_pad, W_pad, device=x.device, dtype=x.dtype)
            mask[:, :, pH:pH+H, pW:pW+W] = 0
            x[:, 0:1] = x[:, 0:1] * (1 - mask) + sqrt_k_inv * mask
        
        # Unfold: [batch, C * kH * kW, num_patches]
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        
        # Reshape to [batch, num_patches, kH * kW, C]
        num_patches = patches.shape[-1]
        patches = patches.view(batch, C, kH * kW, num_patches)
        patches = patches.permute(0, 3, 2, 1)  # [batch, num_patches, k*k, C]
        
        # Direct concat: [batch * num_patches, k*k, C] -> [batch * num_patches, concat_dim]
        patches_flat = patches.reshape(batch * num_patches, kH * kW, C)
        concat_points = self.manifold.direct_concat(patches_flat)
        
        # Apply Lorentz FC: [batch * num_patches, concat_dim] -> [batch * num_patches, out_channels]
        out = self.fc(concat_points)
        
        # Reshape to spatial: [batch, out_channels, H', W']
        H_out = (H + 2 * pH - kH) // sH + 1
        W_out = (W + 2 * pW - kW) // sW + 1
        out = out.view(batch, H_out, W_out, -1).permute(0, 3, 1, 2)
        
        return out
    
class LorentzMLPWithResidual(nn.Module):
    def __init__(
        self, 
        input_dim,
        hidden_dim,
        num_classes,
        num_layers=3,
        manifold=None,
        activation=F.relu
    ):
    
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        # Input projection
        self.input_proj = EuclideanToLorentz(input_dim, hidden_dim, self.manifold)
        
        # Select residual block type
        block_cls = LorentzResidualMidpoint
        
        # Hidden layers with residuals
        self.layers = nn.ModuleList([
            block_cls(hidden_dim, self.manifold, activation=activation)
            for _ in range(num_layers - 1)
        ])
        
        # Classifier
        self.classifier = Lorentz_fully_connected(
            in_features=hidden_dim,
            out_features=num_classes + 1,
            manifold=self.manifold,
            reset_params="kaiming",
            do_mlr=True
        )
    
    def forward(self, x):
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        
        x = self.input_proj(x)
        
        for layer in self.layers:
            x = layer(x)
        
        return self.classifier(x)
    

class LorentzConvNet(nn.Module):
    def __init__(
        self, 
        input_dim,
        hidden_dim,
        num_classes,
        num_layers=3,
        manifold=None,
        activation=F.relu
    ):
    
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        # Input projection
        self.input_proj = EuclideanToLorentzConv(input_dim, hidden_dim, self.manifold)
        
        # Select residual block type
        self.layer1 = LorentzConv2d(hidden_dim, hidden_dim, 3, 1, 0, manifold, activation)
        
        
        # Classifier
        self.classifier = Lorentz_fully_connected(
            in_features=hidden_dim,
            out_features=num_classes + 1,
            manifold=self.manifold,
            reset_params="kaiming",
            do_mlr=True
        )
    
    def forward(self, x):
        print(x.shape)
        x = self.input_proj(x)
        print(x.shape)
        
        x = self.layer1(x)
        print(x.shape)
        
        return self.classifier(x)

In [7]:
class EuclideanToLorentz(nn.Module):
    """Project Euclidean features onto Lorentz manifold."""
    
    def __init__(self, in_features, out_features, manifold):
        """
        Args:
            in_features: Euclidean input dimension (e.g., 784 for MNIST)
            out_features: Lorentz output dimension INCLUDING time (e.g., 100 means 99 space + 1 time)
        """
        super().__init__()
        self.manifold = manifold
        self.linear = nn.Linear(in_features, out_features - 1)  # Output space components only
        
        # Small init to keep points near origin initially (more stable)
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        nn.init.zeros_(self.linear.bias)
    
    def forward(self, x):
        """
        Args:
            x: [batch, in_features] Euclidean vectors
        Returns:
            [batch, out_features] points on Lorentz manifold
        """
        space = self.linear(x)  # [batch, out_features - 1]
        return self.manifold.projection_space_orthogonal(space)  # [batch, out_features]

In [8]:
class LorentzResidualMidpoint(nn.Module):
    """Residual via weighted Lorentz midpoint."""
    
    def __init__(self, dim, manifold, activation):
        super().__init__()
        self.manifold = manifold
        self.fc = Lorentz_fully_connected(
            in_features=dim,
            out_features=dim,
            manifold=manifold,
            reset_params="kaiming",
            activation=activation
        )
        # Learnable weight (0.5 = equal weighting)
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5
    
    def forward(self, x):
        out = self.fc(x)
        
        # Weighted midpoint on manifold
        alpha = torch.sigmoid(self.alpha_logit)
        # Stack for centroid computation: [batch, 2, dim]
        stacked = torch.stack([x, out], dim=-2)
        
        # Weights: [1, 2] -> broadcast to [batch, 2]
        weights = torch.tensor([[1 - alpha, alpha]], device=x.device)
        return self.manifold.lorentz_midpoint(stacked, weights)

In [9]:
class LorentzMLPWithResidual(nn.Module):
    def __init__(
        self, 
        input_dim,
        hidden_dim,
        num_classes,
        num_layers=3,
        manifold=None,
        activation=F.relu
    ):
    
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        # Input projection
        self.input_proj = EuclideanToLorentz(input_dim, hidden_dim, self.manifold)
        
        # Select residual block type
        block_cls = LorentzResidualMidpoint
        
        # Hidden layers with residuals
        self.layers = nn.ModuleList([
            block_cls(hidden_dim, self.manifold, activation=activation)
            for _ in range(num_layers - 1)
        ])
        
        # Classifier
        self.classifier = Lorentz_fully_connected(
            in_features=hidden_dim,
            out_features=num_classes + 1,
            manifold=self.manifold,
            reset_params="kaiming",
            do_mlr=True
        )
    
    def forward(self, x):
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        
        x = self.input_proj(x)
        
        for layer in self.layers:
            x = layer(x)
        
        return self.classifier(x)

In [12]:
class LorentzConv2d(nn.Module):
    """Residual via weighted Lorentz midpoint."""
    
    def __init__(self, dim, manifold, activation):
        super().__init__()
        self.manifold = manifold
        self.fc = Lorentz_fully_connected(
            in_features=dim,
            out_features=dim,
            manifold=manifold,
            reset_params="kaiming",
            activation=activation
        )
        # Learnable weight (0.5 = equal weighting)
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5
    
    def forward(self, x):
        print(x.shape)

        out = self.fc(x)
        
        # Weighted midpoint on manifold
        alpha = torch.sigmoid(self.alpha_logit)
        # Stack for centroid computation: [batch, 2, dim]
        stacked = torch.stack([x, out], dim=-2)
        
        # Weights: [1, 2] -> broadcast to [batch, 2]
        weights = torch.tensor([[1 - alpha, alpha]], device=x.device)
        return self.manifold.lorentz_midpoint(stacked, weights)

In [17]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276))]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzConvNet(input_dim=3, hidden_dim=16, num_classes=100, num_layers=5, manifold=manifold, activation=nn.Identity()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

torch.Size([128, 3, 32, 32])
torch.Size([128, 16, 32, 32])
torch.Size([128, 16, 30, 30])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (61440x30 and 16x100)

In [10]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=100, num_classes=100, num_layers=5, manifold=manifold, activation=nn.Identity()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.60907506942749
running loss: 4.254255323586646
training acc: tensor(0.0696, device='cuda:0')
val loss: 3.973182211303711
val acc: tensor(0.0831, device='cuda:0')
running loss: 3.8841326236724854
running loss: 3.876686127612739
training acc: tensor(0.1086, device='cuda:0')
val loss: 3.8230095092773437
val acc: tensor(0.1161, device='cuda:0')
running loss: 3.610873222351074
running loss: 3.7479927436284575
training acc: tensor(0.1282, device='cuda:0')
val loss: 3.783591979217529
val acc: tensor(0.1313, device='cuda:0')
running loss: 3.6887619495391846
running loss: 3.7151483322671344
training acc: tensor(0.1398, device='cuda:0')
val loss: 3.7308448249816895
val acc: tensor(0.1421, device='cuda:0')
running loss: 3.588874340057373
running loss: 3.664122220498964
training acc: tensor(0.1469, device='cuda:0')
val loss: 3.728555487060547
val acc: tensor(0.1392, device='cuda:0')
running loss: 3.761927843093872
running loss: 3.668708463334047
training acc: tensor(0.1516, device=

In [11]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=100, num_classes=100, num_layers=5, manifold=manifold, activation=nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.607980728149414
running loss: 4.343890246799889
training acc: tensor(0.0586, device='cuda:0')
val loss: 3.9989096130371093
val acc: tensor(0.0811, device='cuda:0')
running loss: 4.106558799743652
running loss: 3.9393786040225773
training acc: tensor(0.0965, device='cuda:0')
val loss: 3.87111513671875
val acc: tensor(0.1046, device='cuda:0')
running loss: 3.8728296756744385
running loss: 3.82414895722909
training acc: tensor(0.1138, device='cuda:0')
val loss: 3.8022483505249025
val acc: tensor(0.1194, device='cuda:0')
running loss: 3.7738242149353027
running loss: 3.742089424995754
training acc: tensor(0.1291, device='cuda:0')
val loss: 3.7483328674316407
val acc: tensor(0.1305, device='cuda:0')
running loss: 3.7000441551208496
running loss: 3.6871290674865054
training acc: tensor(0.1384, device='cuda:0')
val loss: 3.6990942077636717
val acc: tensor(0.1390, device='cuda:0')
running loss: 3.7008056640625
running loss: 3.628138753147013
training acc: tensor(0.1502, device=

In [None]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=100, num_classes=100, num_layers=5, manifold=manifold, activation=nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

In [42]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR10("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR10("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=100, num_classes=10, num_layers=5, manifold=manifold, activation=F.relu).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 2.3018221855163574
running loss: 1.9437158898051308
training acc: tensor(0.3459, device='cuda:0')
val loss: 1.6798603118896485
val acc: tensor(0.4022, device='cuda:0')
running loss: 1.688948631286621
running loss: 1.6294594895824834
training acc: tensor(0.4332, device='cuda:0')
val loss: 1.5643589988708495
val acc: tensor(0.4461, device='cuda:0')
running loss: 1.5382037162780762
running loss: 1.515978446410169
training acc: tensor(0.4677, device='cuda:0')
val loss: 1.499418251800537
val acc: tensor(0.4711, device='cuda:0')
running loss: 1.3877149820327759
running loss: 1.4477123211971614
training acc: tensor(0.4901, device='cuda:0')
val loss: 1.4749658712387086
val acc: tensor(0.4763, device='cuda:0')
running loss: 1.5040117502212524
running loss: 1.4133201105756714
training acc: tensor(0.5064, device='cuda:0')
val loss: 1.4511191493988038
val acc: tensor(0.4891, device='cuda:0')
running loss: 1.3741260766983032
running loss: 1.366600154489454
training acc: tensor(0.5265,

In [33]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR10("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR10("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=512, num_classes=10, num_layers=5, manifold=manifold).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 2.305220603942871
running loss: 1.8644284791524657
training acc: tensor(0.3703, device='cuda:0')
val loss: 1.615155252456665
val acc: tensor(0.4336, device='cuda:0')
running loss: 1.5587794780731201
running loss: 1.5548059872580797
training acc: tensor(0.4559, device='cuda:0')
val loss: 1.530853825378418
val acc: tensor(0.4639, device='cuda:0')
running loss: 1.5610239505767822
running loss: 1.4520017369707332
training acc: tensor(0.4970, device='cuda:0')
val loss: 1.440306875038147
val acc: tensor(0.5041, device='cuda:0')
running loss: 1.3051495552062988
running loss: 1.3428152725278046
training acc: tensor(0.5306, device='cuda:0')
val loss: 1.3976861961364746
val acc: tensor(0.5110, device='cuda:0')
running loss: 1.1987639665603638
running loss: 1.265783142094532
training acc: tensor(0.5550, device='cuda:0')
val loss: 1.355179746246338
val acc: tensor(0.5272, device='cuda:0')
running loss: 1.301975131034851
running loss: 1.2303708101177118
training acc: tensor(0.5756, de

In [34]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR10("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR10("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=32, num_classes=10, num_layers=5, manifold=manifold).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 2.3058390617370605
running loss: 2.0211340024546214
training acc: tensor(0.3046, device='cuda:0')
val loss: 1.7915371826171875
val acc: tensor(0.3596, device='cuda:0')
running loss: 1.9343494176864624
running loss: 1.7730843414496886
training acc: tensor(0.3812, device='cuda:0')
val loss: 1.678443069458008
val acc: tensor(0.4012, device='cuda:0')
running loss: 1.572021484375
running loss: 1.6336535065898254
training acc: tensor(0.4160, device='cuda:0')
val loss: 1.6309209560394287
val acc: tensor(0.4213, device='cuda:0')
running loss: 1.557555913925171
running loss: 1.5806652708439088
training acc: tensor(0.4370, device='cuda:0')
val loss: 1.597140937614441
val acc: tensor(0.4342, device='cuda:0')
running loss: 1.4735374450683594
running loss: 1.5395157643996717
training acc: tensor(0.4512, device='cuda:0')
val loss: 1.5782439380645752
val acc: tensor(0.4410, device='cuda:0')
running loss: 1.6799933910369873
running loss: 1.544618162205239
training acc: tensor(0.4625, dev

In [37]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR10("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR10("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=512, num_classes=10, num_layers=10, manifold=manifold).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 2.302795886993408
running loss: 2.0442883380780663
training acc: tensor(0.2520, device='cuda:0')
val loss: 1.8573101421356202
val acc: tensor(0.2844, device='cuda:0')
running loss: 1.932027816772461
running loss: 1.8325329554424292
training acc: tensor(0.3269, device='cuda:0')
val loss: 1.708279984664917
val acc: tensor(0.3707, device='cuda:0')
running loss: 1.7115061283111572
running loss: 1.6660110259129737
training acc: tensor(0.4009, device='cuda:0')
val loss: 1.6019438676834106
val acc: tensor(0.4154, device='cuda:0')
running loss: 1.6104451417922974
running loss: 1.5641192649609441
training acc: tensor(0.4405, device='cuda:0')
val loss: 1.5702522071838378
val acc: tensor(0.4346, device='cuda:0')
running loss: 1.646928310394287
running loss: 1.5112884959712634
training acc: tensor(0.4695, device='cuda:0')
val loss: 1.509370757675171
val acc: tensor(0.4672, device='cuda:0')
running loss: 1.2755305767059326
running loss: 1.41075922000999
training acc: tensor(0.4980, de

In [None]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR10("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
valset = torchvision.datasets.CIFAR10("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),
    torchvision.transforms.Lambda(lambda x: x.view(-1))]  # Flatten to [784]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzMLPWithResidual(input_dim=3072, hidden_dim=512, num_classes=10, num_layers=10, manifold=manifold).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

tensor([96, 90, 14, 77, 65,  7, 75, 27, 16, 30, 50, 83, 14, 51, 42, 70])

# NEW

In [136]:
from layers import Lorentz_fully_connected
from layers import Lorentz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision



class EuclideanToLorentzConv(nn.Module):
    """Project Euclidean image onto Lorentz manifold via 1x1 conv."""
    
    def __init__(self, in_channels, out_channels, manifold):
        """
        Args:
            in_channels: Euclidean channels (e.g., 3 for RGB)
            out_channels: Lorentz channels INCLUDING time (e.g., 16 means 15 space + 1 time)
        """
        super().__init__()
        self.manifold = manifold
        self.conv = nn.Conv2d(in_channels, out_channels - 1, kernel_size=1)
        
        nn.init.xavier_uniform_(self.conv.weight, gain=0.1)
        nn.init.zeros_(self.conv.bias)
    
    def forward(self, x):
        """
        Args:
            x: [batch, in_channels, H, W] Euclidean image
        Returns:
            [batch, out_channels, H, W] on Lorentz manifold (each pixel is a Lorentz point)
        """
        space = self.conv(x)  # [batch, out_channels - 1, H, W]
        
        # Compute time component for each pixel
        # time = sqrt(||space||^2 + 1/k)
        time = torch.sqrt((space ** 2).sum(dim=1, keepdim=True) + 1.0 / self.manifold.k())
        
        return torch.cat([time, space], dim=1)  # [batch, out_channels, H, W]

class LorentzResidualMidpoint(nn.Module):
    """Residual via weighted Lorentz midpoint."""
    
    def __init__(self, dim, manifold, activation):
        super().__init__()
        self.manifold = manifold
        self.fc = Lorentz_fully_connected(
            in_features=dim,
            out_features=dim,
            manifold=manifold,
            reset_params="kaiming",
            activation=activation
        )
        # Learnable weight (0.5 = equal weighting)
        self.alpha_logit = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5
    
    def forward(self, x):
        out = self.fc(x)
        
        # Weighted midpoint on manifold
        alpha = torch.sigmoid(self.alpha_logit)
        # Stack for centroid computation: [batch, 2, dim]
        stacked = torch.stack([x, out], dim=-2)
        
        # Weights: [1, 2] -> broadcast to [batch, 2]
        weights = torch.tensor([[1 - alpha, alpha]], device=x.device)
        return self.manifold.lorentz_midpoint(stacked, weights)
    

class LorentzConv2d(nn.Module):
    """
    Lorentz Conv2d using direct concatenation + existing Lorentz FC.
    """
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int],
        padding: int | tuple[int, int],
        manifold: Lorentz,
        activation,
    ):
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        
        # After concatenating k*k Lorentz points:
        # concat_dim = 1 + (in_channels - 1) * k * k
        concat_dim = 1 + (in_channels - 1) * kernel_size[0] * kernel_size[1]
        
        # Reuse existing Lorentz FC
        self.fc = Lorentz_fully_connected(
            in_features=concat_dim,
            out_features=out_channels,
            manifold=self.manifold,
            activation=activation,
            reset_params="kaiming"
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, in_channels, H, W]
        Returns:
            [batch, out_channels, H', W']
        """
        batch, C, H, W = x.shape
        kH, kW = self.kernel_size
        sH, sW = self.stride
        pH, pW = self.padding
        
        # Pad with origin points
        if pH > 0 or pW > 0:
            sqrt_k_inv = (1.0 / self.manifold.k()).sqrt()
            x = F.pad(x, (pW, pW, pH, pH), mode='constant', value=0.0)
            _, _, H_pad, W_pad = x.shape
            
            # Fix time component in padded regions
            mask = torch.ones(1, 1, H_pad, W_pad, device=x.device, dtype=x.dtype)
            mask[:, :, pH:pH+H, pW:pW+W] = 0
            x[:, 0:1] = x[:, 0:1] * (1 - mask) + sqrt_k_inv * mask
        
        # Unfold: [batch, C * kH * kW, num_patches]
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        
        # Reshape to [batch, num_patches, kH * kW, C]
        num_patches = patches.shape[-1]
        patches = patches.view(batch, C, kH * kW, num_patches)
        patches = patches.permute(0, 3, 2, 1)  # [batch, num_patches, k*k, C]
        
        # Direct concat: [batch * num_patches, k*k, C] -> [batch * num_patches, concat_dim]
        patches_flat = patches.reshape(batch * num_patches, kH * kW, C)
        concat_points = self.manifold.direct_concat(patches_flat)
        
        # Apply Lorentz FC: [batch * num_patches, concat_dim] -> [batch * num_patches, out_channels]
        out = self.fc(concat_points)
        
        # Reshape to spatial: [batch, out_channels, H', W']
        H_out = (H + 2 * pH - kH) // sH + 1
        W_out = (W + 2 * pW - kW) // sW + 1
        out = out.view(batch, H_out, W_out, -1).permute(0, 3, 1, 2)
        
        return out

class LorentzResBlock(nn.Module):
    def __init__(
        self, 
        input_dim,
        output_dim,
        kernel_size,
        stride,
        padding,
        manifold,
        activation,
    ):
    
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)

        self.layer1 = LorentzConv2d(in_channels=input_dim, out_channels=input_dim, kernel_size=kernel_size, stride=1, padding=padding, manifold=manifold, activation=activation)
        self.bn1 = LorentzBatchNorm2d(num_features = input_dim, manifold=manifold)
        self.layer2 = LorentzConv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=kernel_size, stride=stride, padding=padding, manifold=manifold, activation=nn.Identity())
        self.bn2 = LorentzBatchNorm2d(num_features = output_dim, manifold=manifold)
        if input_dim != output_dim:
            self.proj = LorentzConv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=1, stride=stride, padding=0, manifold=manifold, activation=nn.Identity())
        else:
            self.proj = nn.Identity()

        self.alpha_logit = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5

    
    def forward(self, x):
        x2 = self.layer1(x)
        x2 = self.bn1(x2)
        x2 = self.layer2(x2)
        x2 = self.bn2(x2)
        x = self.proj(x)


        x = x.permute(0, 2, 3, 1)
        x2 = x2.permute(0, 2, 3, 1)
        stacked = torch.stack([x, x2], dim=-2)
        alpha = torch.sigmoid(self.alpha_logit)
        weights = torch.stack([1-alpha, alpha])
        x = self.manifold.lorentz_midpoint(stacked, weights)

        x = x.permute(0, 3, 1, 2)

        return x

class LorentzConvNet(nn.Module):
    def __init__(
        self, 
        input_dim,
        hidden_dim,
        num_classes,
        num_layers,
        manifold,
        activation
    ):
    
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        
        # Input projection
        self.input_proj = EuclideanToLorentzConv(input_dim, hidden_dim, self.manifold)       

        self.resblock1 = LorentzResBlock(input_dim=hidden_dim, output_dim=hidden_dim*2, kernel_size=3, stride=2, padding=1, manifold=manifold, activation=activation)
        self.resblock2 = LorentzResBlock(input_dim=hidden_dim*2, output_dim=hidden_dim*4, kernel_size=3, stride=2, padding=1, manifold=manifold, activation=activation)
        self.resblock3 = LorentzResBlock(input_dim=hidden_dim*4, output_dim=hidden_dim*8, kernel_size=3, stride=2, padding=1, manifold=manifold, activation=activation)
        self.resblock4 = LorentzResBlock(input_dim=hidden_dim*8, output_dim=hidden_dim*16, kernel_size=3, stride=2, padding=1, manifold=manifold, activation=activation)
        self.resblock5 = LorentzResBlock(input_dim=hidden_dim*16, output_dim=hidden_dim*32, kernel_size=3, stride=2, padding=1, manifold=manifold, activation=activation)
        
        # Classifier
        self.classifier = Lorentz_fully_connected(
            in_features=hidden_dim*32,
            out_features=num_classes + 1,
            manifold=self.manifold,
            reset_params="kaiming",
            do_mlr=True
        )
    
    def forward(self, x):
        x = self.input_proj(x)
        # print(x[0, :, 0, 0])
        x = self.resblock1(x)
        # print(x[0, :, 0, 0])
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)

        x = x.squeeze(-1).squeeze(-1)
        if len(x.shape) == 2:
            return self.classifier(x)
        else:
            x = x.view(x.shape[0], x.shape[1], -1)
            x = x.permute(0, 2, 1)
            x = self.manifold.lorentz_midpoint(x)
            return self.classifier(x)

class LorentzBatchNorm2d(nn.Module):
    """
    Lorentz Batch Normalization following Bdeir et al.
    Simplified to use manifold primitives.
    """
    
    def __init__(
        self,
        num_features: int,
        manifold: Lorentz = None,
        momentum: float = 0.1,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.manifold = manifold or Lorentz(k=1.0)
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps
        
        # Learnable scale (positive real)
        self.gamma = nn.Parameter(torch.ones(1, num_features - 1, 1, 1))
        
        # Learnable shift (space components, will be projected to manifold)
        self.beta_space = nn.Parameter(torch.zeros(1, num_features - 1, 1, 1))
        
        # Running statistics (store space components of centroid)
        self.register_buffer('running_mean_space', torch.zeros(1, num_features - 1, 1, 1))
        self.register_buffer('running_var', torch.ones(1))
    
    def _to_flat(self, x: torch.Tensor) -> torch.Tensor:
        """[B, C, H, W] -> [B*H*W, C]"""
        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1])
    
    def _to_spatial(self, x_flat: torch.Tensor, batch: int, H: int, W: int) -> torch.Tensor:
        """[B*H*W, C] -> [B, C, H, W]"""
        return x_flat.view(batch, H, W, -1).permute(0, 3, 1, 2)
    
    def _compute_centroid(self, x: torch.Tensor) -> torch.Tensor:
        """Compute Lorentz centroid using manifold method."""
        batch, C, H, W = x.shape
        x_flat = self._to_flat(x)  # [N, C]
        
        # lorentz_midpoint expects [..., num_points, dim]
        # We want centroid over all N points, so reshape to [1, N, C]
        centroid = self.manifold.lorentz_midpoint(x_flat.unsqueeze(0))  # [1, C]
        
        return centroid.view(1, C, 1, 1)
    
    def _compute_variance(self, x: torch.Tensor, mean: torch.Tensor) -> torch.Tensor:
        """Compute Fréchet variance (mean squared geodesic distance)."""
        batch, C, H, W = x.shape
        x_flat = self._to_flat(x)  # [N, C]
        mean_flat = mean.view(1, C).expand(x_flat.shape[0], -1)  # [N, C]
        
        # Use manifold distance
        dist_sq = self.manifold.dist(x_flat, mean_flat, keepdim=False) ** 2
        return dist_sq.mean()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, C, H, W = x.shape
        
        if self.training:
            mean = self._compute_centroid(x)
            var = self._compute_variance(x, mean)
            
            with torch.no_grad():
                self.running_mean_space.mul_(1 - self.momentum).add_(
                    mean[:, 1:, :, :] * self.momentum
                )
                self.running_var.mul_(1 - self.momentum).add_(var * self.momentum)
        else:
            # Reconstruct mean from running space components
            mean_space_flat = self.running_mean_space.view(1, -1)  # [1, C-1]
            mean_flat = self.manifold.projection_space_orthogonal(mean_space_flat)  # [1, C]
            mean = mean_flat.view(1, C, 1, 1)
            var = self.running_var
        
        # Flatten for manifold operations
        x_flat = self._to_flat(x)  # [N, C]
        mean_flat = mean.view(1, C).expand(x_flat.shape[0], -1)
        
        # Origin point
        origin = self.manifold.origin(C).unsqueeze(0).expand(x_flat.shape[0], -1)
        
        # 1. Log map: get tangent vector at mean pointing to x
        # logmap expects [batch, dim] for base and [batch, m, dim] for target
        v_at_mean = self.manifold.logmap(mean_flat, x_flat.unsqueeze(1)).squeeze(1)  # [N, C]
        
        # 2. Parallel transport tangent vector from mean to origin
        v_at_origin = self.manifold.parallel_transport(mean_flat, v_at_mean.unsqueeze(1), origin).squeeze(1)
        
        # 3. Scale in tangent space (only space components, time should stay ~0)
        # Tangent vectors at origin have time ≈ 0
        v_space = v_at_origin[:, 1:]  # [N, C-1]
        gamma_flat = self.gamma.view(1, -1)  # [1, C-1]
        v_scaled_space = gamma_flat * v_space / (var.sqrt() + self.eps)
        
        # 4. Add shift (beta)
        beta_flat = self.beta_space.view(1, -1)  # [1, C-1]
        v_shifted_space = v_scaled_space + beta_flat
        
        # 5. Project back to manifold from space components
        x_out_flat = self.manifold.projection_space_orthogonal(v_shifted_space)
        
        return self._to_spatial(x_out_flat, batch, H, W)

In [2]:
from layers import LorentzConvNet
from layers import Lorentz
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [10]:
model.resblock1.layer1.fc.a

Parameter containing:
tensor([[ 0.0090,  0.0173, -0.0277,  0.0192, -0.0033,  0.0283, -0.0026, -0.0180,
         -0.0009,  0.0043,  0.0032, -0.0098, -0.0289, -0.0091, -0.0326]],
       device='cuda:0', requires_grad=True)

In [3]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276))]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzConvNet(input_dim=3, hidden_dim=16, num_classes=100, num_layers=5, manifold=manifold, activation=nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.613415718078613
running loss: 4.109422631907814
training acc: tensor(0.0930, device='cuda:0')
val loss: 3.6462308765411375
val acc: tensor(0.1336, device='cuda:0')
running loss: 3.6820125579833984
running loss: 3.5163753930030177
training acc: tensor(0.1701, device='cuda:0')
val loss: 3.270312439727783
val acc: tensor(0.2078, device='cuda:0')
running loss: 3.1570048332214355
running loss: 3.164134492874118
training acc: tensor(0.2300, device='cuda:0')
val loss: 3.1041206703186037
val acc: tensor(0.2369, device='cuda:0')
running loss: 2.9298155307769775
running loss: 2.9147103132115295
training acc: tensor(0.2700, device='cuda:0')
val loss: 2.9680451629638673
val acc: tensor(0.2616, device='cuda:0')
running loss: 2.683765411376953
running loss: 2.7043033775335155
training acc: tensor(0.3152, device='cuda:0')
val loss: 2.83157795791626
val acc: tensor(0.2951, device='cuda:0')
running loss: 2.8003339767456055
running loss: 2.545985702709308
training acc: tensor(0.3519, dev

In [142]:
model.resblock5.alpha_logit

Parameter containing:
tensor(-1.8509, device='cuda:0', requires_grad=True)

In [128]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276))]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzConvNet(input_dim=3, hidden_dim=16, num_classes=100, num_layers=5, manifold=manifold, activation=nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.60404634475708
running loss: 4.133440152762814
training acc: tensor(0.0987, device='cuda:0')
val loss: 3.6226573753356934
val acc: tensor(0.1476, device='cuda:0')
running loss: 3.6821200847625732
running loss: 3.5466236492681404
training acc: tensor(0.1664, device='cuda:0')
val loss: 3.3628051738739013
val acc: tensor(0.1944, device='cuda:0')
running loss: 3.363978862762451
running loss: 3.2429910072192207
training acc: tensor(0.2195, device='cuda:0')
val loss: 3.142463459777832
val acc: tensor(0.2352, device='cuda:0')
running loss: 2.9758212566375732
running loss: 2.97109549545084
training acc: tensor(0.2640, device='cuda:0')
val loss: 2.988465799713135
val acc: tensor(0.2667, device='cuda:0')
running loss: 2.5802388191223145
running loss: 2.7540598730182064
training acc: tensor(0.3055, device='cuda:0')
val loss: 2.858040071105957
val acc: tensor(0.2932, device='cuda:0')
running loss: 2.7477383613586426
running loss: 2.5997239111926818
training acc: tensor(0.3400, devi

In [126]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276))]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzConvNet(input_dim=3, hidden_dim=16, num_classes=100, num_layers=5, manifold=manifold, activation=nn.ReLU()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.compile()
data, target = next(iter(train_loader))
for _ in range(50):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.614555358886719
running loss: 4.178154148873151
training acc: tensor(0.0845, device='cuda:0')
val loss: 3.6940881423950196
val acc: tensor(0.1257, device='cuda:0')
running loss: 3.7043490409851074
running loss: 3.5906297556455886
training acc: tensor(0.1536, device='cuda:0')
val loss: 3.441946120071411
val acc: tensor(0.1725, device='cuda:0')
running loss: 3.2883236408233643
running loss: 3.2957802840544117
training acc: tensor(0.2022, device='cuda:0')
val loss: 3.221691360473633
val acc: tensor(0.2182, device='cuda:0')
running loss: 3.1265478134155273
running loss: 3.083497251221289
training acc: tensor(0.2475, device='cuda:0')
val loss: 3.0751177307128907
val acc: tensor(0.2461, device='cuda:0')
running loss: 2.826460838317871
running loss: 2.8660793152219077
training acc: tensor(0.2896, device='cuda:0')
val loss: 2.9953153076171875
val acc: tensor(0.2717, device='cuda:0')
running loss: 2.6801528930664062
running loss: 2.6513407267035847
training acc: tensor(0.3300, d

In [119]:
manifold = Lorentz(k=1.0)
trainset = torchvision.datasets.CIFAR100("./cifar", train=True, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276))]
    ))
valset = torchvision.datasets.CIFAR100("./cifar", train=False, download=True, transform=torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize((0.5074, 0.4867, 0.4411), (0.267, 0.256, 0.276)),]
    ))
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
model = LorentzConvNet(input_dim=3, hidden_dim=16, num_classes=100, num_layers=5, manifold=manifold, activation=nn.Identity()).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data, target = next(iter(train_loader))
for _ in range(10):
    running_loss, acc, counts = 0.0, 0.0, 0
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(x.cuda()).squeeze()
        loss = F.cross_entropy(logits, y.cuda())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if running_loss == 0.0:
            running_loss = loss.item()
        else:
            running_loss = 0.99*running_loss + 0.01*loss.item()
        acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
        counts += x.shape[0]
        if step % 200 == 0:
            print("running loss:", running_loss)
    print("training acc:", acc / counts)
    with torch.no_grad():
        running_loss, acc, counts = 0.0, 0.0, 0
        for x, y in val_loader:
            logits = model(x.cuda()).squeeze()
            loss = F.cross_entropy(logits, y.cuda(), reduction='sum')
            running_loss += loss.item()
            acc += (logits.argmax(dim=1) == y.cuda()).float().sum()
            counts += x.shape[0]
        
        print("val loss:", running_loss / counts)
        print("val acc:", acc / counts)

    
        

running loss: 4.602324485778809
running loss: 4.129569056731318
training acc: tensor(0.0875, device='cuda:0')
val loss: 3.784755062866211
val acc: tensor(0.1259, device='cuda:0')
running loss: 3.887423515319824
running loss: 3.698445036380171
training acc: tensor(0.1455, device='cuda:0')
val loss: 3.5837403938293457
val acc: tensor(0.1588, device='cuda:0')
running loss: 3.560063123703003
running loss: 3.5119678685079188
training acc: tensor(0.1706, device='cuda:0')
val loss: 3.518911082458496
val acc: tensor(0.1710, device='cuda:0')
running loss: 3.5764331817626953
running loss: 3.463195377418665
training acc: tensor(0.1811, device='cuda:0')
val loss: 3.41609995803833
val acc: tensor(0.1994, device='cuda:0')
running loss: 3.3845832347869873
running loss: 3.376326591840448
training acc: tensor(0.1967, device='cuda:0')
val loss: 3.344493743133545
val acc: tensor(0.2101, device='cuda:0')
running loss: 3.1745591163635254
running loss: 3.2864272021540972
training acc: tensor(0.2073, device=