In [None]:
import math

import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    """Low-rank additive update ``(x @ A @ B) * scaling`` for a linear map.

    The update is parameterised by two trainable factors: ``A`` of shape
    ``(input_dim, rank)`` and ``B`` of shape ``(rank, output_dim)``. Together
    they add ``rank * (input_dim + output_dim)`` parameters instead of the
    ``input_dim * output_dim`` a dense update would need.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        output_dim: Size of the last dimension of the returned update.
        rank: Inner dimension of the factorisation; must be positive.
        alpha: Optional LoRA scaling numerator. When given, the update is
            multiplied by ``alpha / rank`` (as in the LoRA paper). When
            ``None`` (the default), no scaling is applied.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, input_dim, output_dim, rank, alpha=None):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be positive, got {rank}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.rank = rank
        self.scaling = 1.0 if alpha is None else alpha / rank

        # Low-rank decomposition matrices. The storage starts uninitialised
        # because both factors are initialised explicitly just below.
        self.A = nn.Parameter(torch.empty(input_dim, rank))
        self.B = nn.Parameter(torch.empty(rank, output_dim))

        # A is stored as (input_dim, rank), i.e. transposed relative to an
        # nn.Linear weight. kaiming_uniform_ reads fan_in from dim 1 (= rank
        # here). mode="fan_out" reads dim 0 (= input_dim) instead, which gives
        # the usual nn.Linear / loralib bound of 1/sqrt(input_dim).
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5), mode="fan_out")
        # B starts at zero, so the adapter contributes nothing until trained.
        nn.init.zeros_(self.B)

    def forward(self, x):
        """Return the low-rank update for ``x`` (last dim ``input_dim``).

        This evaluates ``(x @ A) @ B`` rather than ``x @ (A @ B)``. The two are
        mathematically identical, but this order never materialises the dense
        ``input_dim x output_dim`` matrix, so the cost scales with ``rank``.
        """
        projected = torch.matmul(x, self.A)
        update = torch.matmul(projected, self.B)
        return update * self.scaling

class LoRA(nn.Module):
    """Wrap a base model and add the outputs of one or more adapter modules.

    ``forward(x)`` returns ``base_model(x)`` plus ``layer(x)`` for every
    module in ``lora_layers``. Every adapter receives the same, unmodified
    input ``x``, and each adapter's output must be addable to the base
    model's output.

    Args:
        base_model: Module whose output the adapter outputs are added to.
        lora_layers: Iterable of modules (e.g. ``LoRALayer``) applied to ``x``.
    """

    def __init__(self, base_model, lora_layers):
        super().__init__()
        self.base_model = base_model
        # ModuleList registers the adapters so their parameters are visible
        # to .parameters(), .to(), state_dict(), etc.
        self.lora_layers = nn.ModuleList(lora_layers)
        # NOTE(review): base_model's parameters remain trainable here; LoRA
        # fine-tuning usually freezes them — confirm this is intended.

    def forward(self, x):
        """Return ``base_model(x)`` plus the sum of all adapter outputs."""
        output = self.base_model(x)
        for lora_layer in self.lora_layers:
            # Out-of-place add. An in-place `+=` would write into whatever
            # tensor the base model returned:
            # - if that aliases `x` (e.g. nn.Identity), the caller's input and
            #   the input seen by later adapters get corrupted;
            # - if autograd saved it (e.g. a trailing Sigmoid), backward fails.
            output = output + lora_layer(x)
        return output

# Example usage: dimensions of the demo model and the adapter's rank.
input_dim, output_dim, rank = 256, 256, 4

# Define a simple base model
class SimpleModel(nn.Module):
    """Minimal base model: a single affine layer ``input_dim -> output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name `linear` determines the state_dict keys.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the affine layer to ``x`` (last dim ``input_dim``)."""
        projected = self.linear(x)
        return projected

# Assemble the demo: a plain linear base model plus one low-rank adapter.
base_model = SimpleModel(input_dim=input_dim, output_dim=output_dim)
lora_layer = LoRALayer(input_dim=input_dim, output_dim=output_dim, rank=rank)
lora_model = LoRA(base_model=base_model, lora_layers=[lora_layer])

# A batch of 32 inputs drawn uniformly from [0, 1).
x = torch.rand(32, input_dim)

# Run the wrapped model and report the output shape.
output = lora_model(x)
print(output.shape)  # expected: torch.Size([32, output_dim])