In [1]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

class TestMLP(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim) -> None:
        super().__init__()
        self.net_11=MLP(input_dim, hidden_dim, output_dim)
        self.net_seq=nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        self.net_list=nn.ModuleList([nn.Linear(output_dim, hidden_dim), nn.ReLU(), nn.Linear(output_dim, output_dim)])  
        self.out=nn.Linear(output_dim, output_dim)
        self.relu=nn.ReLU()
        
    def forward(self, x):
        x=self.net_11(x)
        x=self.net_seq(x)
        x=self.net_list(x)
        x=self.out(x)
        x=self.relu(x)
        return x  
    
mlp=TestMLP(10, 20, 1)
mlp._modules.items()

odict_items([('net_11', MLP(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=1, bias=True)
  (relu): ReLU()
)), ('net_seq', Sequential(
  (0): Linear(in_features=1, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=1, bias=True)
)), ('net_list', ModuleList(
  (0): Linear(in_features=1, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=1, out_features=1, bias=True)
)), ('out', Linear(in_features=1, out_features=1, bias=True)), ('relu', ReLU())])

In [24]:
from GIFt.utils import ModuleIterator
from GIFt.strategies import FineTuningStrategy,LoRAFullFineTuningStrategy
from GIFt.utils import freeze_module,num_trainable_parameters
from GIFt.meta_types import FinetuableModule

def finetuning_sd_hook(module, state_dict, *args, **kwargs):
    '''
    Clean the state_dict of the module, removing all the parameters that are not trainable.
    It is better to remove all the parameters that are not trainable from the state_dict rather than create a new state_dict
    rather than create a new state_dict with trainable parameters only. This is because sometimes the state_dict also contains 
    untrainable buffers, which should be kept in the state_dict.
    '''
    new_state_dict = {}
    not_requires_grad_paras=[name for name,param in module.named_parameters() if not param.requires_grad]
    for key, value in state_dict.items():
        if key not in not_requires_grad_paras:
            new_state_dict[key] = value
    return new_state_dict

def finetuning_loadsd_posthook(module, incompatible_keys):
    '''
    Enable load_state_dict to load the finetuned model.
    The default load_state_dict will raise an error since it also tries to load the unfinetuned parameters.
    If you don't want to load this hook, you can also set `strick=False` in `load_state_dict` function.
    '''
    finetuned_sd_keys=module.state_dict().keys()
    key_copys=incompatible_keys.missing_keys.copy()
    for key in key_copys:
        if key not in finetuned_sd_keys:
            incompatible_keys.missing_keys.remove(key)

def trainable_parameters(module:nn.Module,recurse:bool=True):
    for name, param in module.named_parameters(recurse=recurse):
        if param.requires_grad:
            yield param

def num_trainable_parameters(module:nn.Module):
    return sum(p.numel() for p in trainable_parameters(module))

def num_parameters(module:nn.Module):
    return sum(p.numel() for p in module.parameters())

def replace_modules(module:nn.Module,finetuning_strategy:FineTuningStrategy,parent_name:str=""):
    # Replace layers with finetuable layers
    for name, global_name, class_name, layer_obj, has_child in ModuleIterator(module,parent_name):
        find=False
        if isinstance(layer_obj,FinetuableModule):
            raise ValueError(f"Layer {global_name} is already finetuable")
        for check_func,act_func in finetuning_strategy:
            if check_func(name, global_name, class_name, layer_obj):
                act_func(module,name, global_name, class_name, layer_obj)
                find=True
                break
        if not find and has_child:
            replace_modules(layer_obj,finetuning_strategy,name)
        else:
            freeze_module(layer_obj)

def enable_finetuning(module:nn.Module,finetuning_strategy:FineTuningStrategy):
    # replace modules
    replace_modules(module,finetuning_strategy)
    # add hook to the module to remove untrainable parameters from the state_dict
    module._register_state_dict_hook(finetuning_sd_hook)
    # add hook to the module to enable load_state_dict to load the finetuned model
    module.register_load_state_dict_post_hook(finetuning_loadsd_posthook)
    # add trainable_parameters function to the module
    setattr(module,"trainable_parameters",lambda recurse=True: trainable_parameters(module,recurse))
    
    


mlp=TestMLP(100, 200, 1)
#print(mlp.state_dict().keys())
print(num_trainable_parameters(mlp))
lora_strategy=LoRAFullFineTuningStrategy()
enable_finetuning(mlp,lora_strategy)
print(num_trainable_parameters(mlp))
#print(mlp.state_dict().keys())
current_sd=mlp.state_dict()
mlp.load_state_dict(current_sd)
mlp.trainable_parameters()
optimizer=torch.optim.Adam(mlp.trainable_parameters(),lr=0.01)
mlp.parameters()
mlp.trainable_parameters()
num_trainable_parameters(mlp)

21406
3324


3324

In [8]:
class myclass():
    
    def __init__(self) -> None:
        self.ml=[1,2,3,4,5]
    
    def __len__(self):
        return len(self.ml)
    
    def __getitem__(self, index):
        return self.ml[index]

a=myclass()
for a_i in a:
    print(a_i)

def add_func(self,x):
    print(x)
    print(self.ml)

setattr(a, "additional", lambda x: add_func(a,x))
a.additional(10)

1
2
3
4
5
10
[1, 2, 3, 4, 5]


In [7]:
class MyNumbers:
  def __iter__(self):
    print("iter")
    self.a = 1
    return self

  def __next__(self):
    if self.a <= 20:
      x = self.a
      self.a += 1
      return x
    else:
      raise StopIteration

myclass = MyNumbers()
myiter = iter(myclass)

for x in myiter:
  print(myiter.a)
print(myiter.a)  
for x in myiter:
  print(myiter.a)

iter
iter
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
21
iter
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
