In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad

class SVGDLayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, num_particles):
        super(SVGDLayer, self).__init__()
        self.rank = rank
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_particles = num_particles

        # Initialize particles
        self.particles_u = nn.ParameterList([nn.Parameter(torch.randn(input_dim, rank)) for _ in range(num_particles)])
        self.particles_v = nn.ParameterList([nn.Parameter(torch.randn(rank, output_dim)) for _ in range(num_particles)])

    def forward(self, x):
        outputs = []
        for u, v in zip(self.particles_u, self.particles_v):
            low_rank_approx = torch.matmul(u, v)
            outputs.append(torch.matmul(x, low_rank_approx))
        return torch.stack(outputs).mean(dim=0)

    def svgd_update(self, loss_fn, x, y, lr=0.001, bandwidth=1.0):
        grads_u = []
        grads_v = []
        particles = list(zip(self.particles_u, self.particles_v))

        # Compute gradients for each particle
        for u, v in particles:
            u.requires_grad_(True)
            v.requires_grad_(True)
            low_rank_approx = torch.matmul(u, v)
            output = torch.matmul(x, low_rank_approx)
            loss = loss_fn(output, y)
            grad_u, grad_v = grad(loss, [u, v], create_graph=True)
            grads_u.append(grad_u)
            grads_v.append(grad_v)

        # Update each particle
        for i, (u, v) in enumerate(particles):
            k_u = torch.zeros_like(u)
            k_v = torch.zeros_like(v)
            phi_u = torch.zeros_like(u)
            phi_v = torch.zeros_like(v)
            for j, (u_j, v_j) in enumerate(particles):
                k_ij = self.kernel(u, u_j, bandwidth)
                k_u += k_ij * grads_u[j]
                k_v += k_ij * grads_v[j]
                phi_u += k_ij * (u_j - u)
                phi_v += k_ij * (v_j - v)
            self.particles_u[i].data += lr * (k_u / self.num_particles - phi_u / self.num_particles)
            self.particles_v[i].data += lr * (k_v / self.num_particles - phi_v / self.num_particles)

    def kernel(self, x1, x2, h=1.0):
        sq_dist = torch.sum((x1 - x2)**2)
        return torch.exp(-sq_dist / (2 * h**2))

class SVGDRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim, rank, num_particles):
        super(SVGDRegressionModel, self).__init__()
        self.svgd_layer = SVGDLayer(input_dim, output_dim, rank, num_particles)

    def forward(self, x):
        return self.svgd_layer(x)

# Create a toy dataset
torch.manual_seed(0)
x = torch.randn(100, 10)
true_weights = torch.randn(10, 1)
y = x @ true_weights + 0.1 * torch.randn(100, 1)

# Normalize the dataset
x = (x - x.mean(dim=0)) / x.std(dim=0)
y = (y - y.mean(dim=0)) / y.std(dim=0)

# Initialize model, loss, and optimizer
input_dim = 10
output_dim = 1
rank = 5
num_particles = 10
model = SVGDRegressionModel(input_dim, output_dim, rank, num_particles)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()
    optimizer.step()
    model.svgd_layer.svgd_update(loss_fn, x, y)

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

Epoch 0, Loss: 10.68558120727539
Epoch 10, Loss: 10.796398162841797
Epoch 20, Loss: 10.94674301147461
Epoch 30, Loss: 11.143123626708984
Epoch 40, Loss: 11.393436431884766
Epoch 50, Loss: 11.70773696899414
Epoch 60, Loss: 12.099539756774902
Epoch 70, Loss: 12.587889671325684
Epoch 80, Loss: 13.199819564819336
Epoch 90, Loss: 13.974235534667969


In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

class SVGDLayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, num_particles, lora_alpha, lora_dropout, merge_weights):
        super(SVGDLayer, self).__init__()
        self.rank = rank
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_particles = num_particles
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else lambda x: x
        self.merge_weights = merge_weights
        self.merged = False

        # Initialize particles
        self.particles_u = nn.ParameterList([nn.Parameter(torch.randn(input_dim, rank)) for _ in range(num_particles)])
        self.particles_v = nn.ParameterList([nn.Parameter(torch.randn(rank, output_dim)) for _ in range(num_particles)])

    def forward(self, x):
        outputs = []
        for u, v in zip(self.particles_u, self.particles_v):
            low_rank_approx = torch.matmul(u, v)
            outputs.append(torch.matmul(x, low_rank_approx))
        return torch.stack(outputs).mean(dim=0)

    def calculate_adaptive_scaling(self, grad_u, grad_v):
        norm_u = torch.norm(grad_u, p='fro')
        norm_v = torch.norm(grad_v, p='fro')
        adaptive_scaling = (self.lora_alpha / self.rank) * (norm_u + norm_v) / (norm_u + norm_v + 1e-8)
        return adaptive_scaling

    def adjust_rank(self, grad_u, grad_v):
        pass  
    
    def svgd_update(self, grads_u, grads_v, lr=0.01):
        for i, (u, v) in enumerate(zip(self.particles_u, self.particles_v)):
            self.particles_u[i].data += lr * grads_u[i]
            self.particles_v[i].data += lr * grads_v[i]

    def compute_gradients(self, loss_fn, x, y):
        grads_u = []
        grads_v = []
        for u, v in zip(self.particles_u, self.particles_v):
            low_rank_approx = torch.matmul(u, v)
            output = torch.matmul(x, low_rank_approx)
            loss = loss_fn(output, y)
            grad_u, grad_v = grad(loss, [u, v])
            grads_u.append(grad_u)
            grads_v.append(grad_v)
        return grads_u, grads_v

class SVGDEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, rank, num_particles, lora_alpha, lora_dropout, merge_weights, **kwargs):
        super(SVGDEmbedding, self).__init__(num_embeddings, embedding_dim, **kwargs)
        self.svgd_layer = SVGDLayer(embedding_dim, embedding_dim, rank, num_particles, lora_alpha, lora_dropout, merge_weights)

    def forward(self, x):
        embedding_output = super(SVGDEmbedding, self).forward(x)
        return self.svgd_layer(embedding_output)

class SVGDLinear(nn.Linear):
    def __init__(self, in_features, out_features, rank, num_particles, lora_alpha, lora_dropout, merge_weights, **kwargs):
        super(SVGDLinear, self).__init__(in_features, out_features, **kwargs)
        self.svgd_layer = SVGDLayer(in_features, out_features, rank, num_particles, lora_alpha, lora_dropout, merge_weights)

    def forward(self, x):
        linear_output = super(SVGDLinear, self).forward(x)
        return self.svgd_layer(linear_output)

embedding = SVGDEmbedding(10000, 512, rank=10, num_particles=10, lora_alpha=1, lora_dropout=0.1, merge_weights=True)
linear = SVGDLinear(512, 512, rank=10, num_particles=10, lora_alpha=1, lora_dropout=0.1, merge_weights=True)