In [37]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, weight, r, alpha):
        super(LoRALayer, self).__init__()
        self.weight = weight
        self.weight.requires_grad = False
        self.r = r
        self.alpha = alpha
        out_features = self.weight.shape[0]
        in_features = self.weight.shape[1]
        self.A = nn.Parameter(self.weight.new_zeros(self.r, in_features))
        self.B = nn.Parameter(self.weight.new_zeros(out_features, r))
    
    def forward(self, x):
        result = x @ self.weight.T
        result += x @ (self.A.T @ self.B.T)
        return result

In [47]:
class FFN(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels):
        super(FFN, self).__init__()
        self.linear1 = nn.Linear(in_channels, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, out_channels)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return self.sigmoid(x)

In [49]:
from torch.utils.data import DataLoader, TensorDataset

ffn = FFN(2, 16, 1)
x_xor = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y_xor = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

dataset_xor = TensorDataset(x_xor, y_xor)
dataloader_xor = DataLoader(dataset_xor, batch_size=1, shuffle=True)

def train_xor_model(model, dataloader):
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(150):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

def validate_xor_model(model, dataloader):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No gradients needed for predictions
        for inputs, labels in dataloader:
            outputs = model(inputs)
            print(f"Input: {inputs.numpy()}, Predicted: {outputs.numpy()}")

train_xor_model(ffn, dataloader_xor)
validate_xor_model(ffn, dataloader_xor)


Epoch 10, Loss: 0.18813684582710266
Epoch 20, Loss: 0.20629438757896423
Epoch 30, Loss: 0.12048628181219101
Epoch 40, Loss: 0.1141696348786354
Epoch 50, Loss: 0.054154861718416214
Epoch 60, Loss: 0.04189957305788994
Epoch 70, Loss: 0.024022269994020462
Epoch 80, Loss: 0.016496988013386726
Epoch 90, Loss: 0.012147413566708565
Epoch 100, Loss: 0.008923808112740517
Epoch 110, Loss: 0.008023484610021114
Epoch 120, Loss: 0.005624703131616116
Epoch 130, Loss: 0.004331280943006277
Epoch 140, Loss: 0.0038061970844864845
Epoch 150, Loss: 0.003470025723800063
Input: [[0. 1.]], Predicted: [[0.9449311]]
Input: [[1. 1.]], Predicted: [[0.05871405]]
Input: [[1. 0.]], Predicted: [[0.9436962]]
Input: [[0. 0.]], Predicted: [[0.05923378]]


In [50]:
ffn_weight = ffn.linear1.weight
lora_layer = LoRALayer(ffn_weight, 1, 0.1)
ffn.linear1 = lora_layer

y_or = torch.tensor([[0], [1], [1], [1]], dtype=torch.float32)

dataset_xor = TensorDataset(x_xor, y_or)
dataloader_xor = DataLoader(dataset_xor, batch_size=1, shuffle=True)

train_xor_model(ffn, dataloader_xor)
validate_xor_model(ffn, dataloader_xor)

Epoch 10, Loss: 0.9626904726028442
Epoch 20, Loss: 0.22305436432361603
Epoch 30, Loss: 0.2219855636358261
Epoch 40, Loss: 0.05405284836888313
Epoch 50, Loss: 0.0004834290884900838
Epoch 60, Loss: 6.875165126984939e-05
Epoch 70, Loss: 6.342078995658085e-05
Epoch 80, Loss: 0.00043148259283043444
Epoch 90, Loss: 0.00043931047548539937
Epoch 100, Loss: 0.006996463984251022
Epoch 110, Loss: 5.764966408605687e-05
Epoch 120, Loss: 0.00046141701750457287
Epoch 130, Loss: 0.06457482278347015
Epoch 140, Loss: 0.05713595449924469
Epoch 150, Loss: 0.0004778832080774009
Input: [[0. 1.]], Predicted: [[0.99257636]]
Input: [[1. 0.]], Predicted: [[0.978135]]
Input: [[1. 1.]], Predicted: [[0.9391143]]
Input: [[0. 0.]], Predicted: [[0.22481981]]
