In [8]:
import torch
import torch.nn as nn
from GIFt import enable_finetuning
from GIFt.strategies import LoRAFullFineTuningStrategy
from GIFt.utils import num_trainable_parameters

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim,num_layers):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        for i in range(num_layers-1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.layers.append(nn.Linear(hidden_dim, output_dim))
        self.relu = nn.ReLU()
    
    def forward(self, x):
        for i in range(self.num_layers):
            x = self.layers[i](x)
            x = self.relu(x)
        x = self.layers[-1](x)
        return x

mlp=MLP(10, 100, 5, 5)
print(mlp)
print("Network before enable finetuning:",mlp)
print("Number of trainable parameters:",num_trainable_parameters(mlp))
# Enable fine-tuning with specified strategy
# LoRAFullFineTuningStrategy() will replace all Linear layers with LoRA layers and all Conv1/2/3D layers with LoRAConv1/2/3D layers
enable_finetuning(mlp, LoRAFullFineTuningStrategy())
print("Network afer enable finetuning:",mlp)
print("Number of trainable parameters after fine-tuning:",num_trainable_parameters(mlp))
# use `trainable_parameters()` rather than `parameters()` to get the trainable parameters for fine-tuning
# Also avaliable through `GIFt.utils.trainable_parameters()`
optimizer = torch.optim.Adam(mlp.trainable_parameters(), lr=0.001)
# After using `enable_finetuning`, the state dict of the model will be updated to only include the trainable parameters
state_dict = mlp.state_dict()
print(state_dict.keys())
torch.save(state_dict, "mlp.pth")
mlp.load_state_dict(torch.load("mlp.pth"))

MLP(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=100, bias=True)
    (1-4): 4 x Linear(in_features=100, out_features=100, bias=True)
    (5): Linear(in_features=100, out_features=5, bias=True)
  )
  (relu): ReLU()
)
Network before enable finetuning: MLP(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=100, bias=True)
    (1-4): 4 x Linear(in_features=100, out_features=100, bias=True)
    (5): Linear(in_features=100, out_features=5, bias=True)
  )
  (relu): ReLU()
)
Number of trainable parameters: 42005
Network afer enable finetuning: MLP(
  (layers): ModuleList(
    (0): LoRALinear(
      (parent_module): Linear(in_features=10, out_features=100, bias=True)
    )
    (1-4): 4 x LoRALinear(
      (parent_module): Linear(in_features=100, out_features=100, bias=True)
    )
    (5): LoRALinear(
      (parent_module): Linear(in_features=100, out_features=5, bias=True)
    )
  )
  (relu): ReLU()
)
Number of trainable parameters after fine-tuning: 

<All keys matched successfully>