In [1]:
import copy
import os
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
import peft
import torch
from torch import nn
import torch.nn.functional as F
from MLP_function import *


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
set_all_seeds(45)
# Data 1
X_1 = torch.rand((500, 1, 28, 28, 28))
# set_all_seeds(45)
y_1 = (X_1.sum(dim=[1,2,3,4]) > 11000).long()  # Binary classification based on voxel sum

# Data 2
X_2 = torch.rand((500, 1, 28, 28, 28))
# set_all_seeds(45)
y_2 = (X_2.sum(dim=[1,2,3,4]) > 11000).long()  # Binary classification based on voxel sum

# Data 3
X_3 = torch.rand((500, 1, 28, 28, 28))
# set_all_seeds(45)
y_3 = (X_3.sum(dim=[1,2,3,4]) > 11000).long()  # Binary classification based on voxel sum

n_train = 400
batch_size = 32
train_dataloader_1 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_1[:n_train], y_1[:n_train]),
    batch_size=batch_size,
    shuffle=True,
)
eval_dataloader_1 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_1[n_train:], y_1[n_train:]),
    batch_size=batch_size,
)

train_dataloader_2 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_2[:n_train], y_2[:n_train]),
    batch_size=batch_size,
    shuffle=True,
)
eval_dataloader_2 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_2[n_train:], y_2[n_train:]),
    batch_size=batch_size,
)
train_dataloader_3 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_3[:n_train], y_3[:n_train]),
    batch_size=batch_size,
    shuffle=True,
)
eval_dataloader_3 = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X_3[n_train:], y_3[n_train:]),
    batch_size=batch_size,
)
lr = 0.002
max_epochs = 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = peft.LoraConfig(
    r=8,
    target_modules=["conv.0", "conv.3", 'fc.0'],
    # modules_to_save=["fc"] ,
)

In [3]:
# Base model
original_model = CNN3D().to(device)
original_model_copy = copy.deepcopy(original_model)
all_params = torch.cat([param.data.view(-1) for param in original_model.parameters()])
print('original_model details:')
print(f"All params: mean= {all_params.mean().item():.6f}, std={all_params.std().item():.6f}")


original_model details:
All params: mean= -0.000005, std=0.006723


In [4]:
### train base model with data 1
optimizer = optim.Adam(original_model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
train(original_model, optimizer, criterion, train_dataloader_1, eval_dataloader_1, device, epochs=max_epochs)
# Save model checkpoints
model_path = 'model_checkpoints/3DCNN_base'
if not os.path.exists(model_path):
    os.makedirs(model_path)
torch.save(original_model.state_dict(), os.path.join(model_path, f"model.pt"))

Batch 0 get train loss: 0.7069319486618042 val loss: 0.5935408473014832
Batch 1 get train loss: 0.6280664801597595 val loss: 0.5944093465805054
Batch 2 get train loss: 0.6432785987854004 val loss: 0.6026040315628052
Batch 3 get train loss: 0.6130321621894836 val loss: 0.5715664625167847
Batch 4 get train loss: 0.6160570979118347 val loss: 0.5756078958511353
Batch 5 get train loss: 0.6223089098930359 val loss: 0.5971855521202087
Batch 6 get train loss: 0.621879518032074 val loss: 0.571010410785675
Batch 7 get train loss: 0.623405933380127 val loss: 0.572990894317627
Batch 8 get train loss: 0.6278942823410034 val loss: 0.570976972579956
Batch 9 get train loss: 0.6140992641448975 val loss: 0.5857475399971008
Batch 10 get train loss: 0.6092574000358582 val loss: 0.5711942315101624
Batch 11 get train loss: 0.6117033362388611 val loss: 0.5811612606048584
Batch 12 get train loss: 0.608684241771698 val loss: 0.5724223852157593
Batch 13 get train loss: 0.6101047396659851 val loss: 0.57342219352

In [5]:
original_model_all_params = torch.cat([param.data.view(-1) for param in original_model.parameters()])
print('original_model after train details:')
print(f"All params: mean= {original_model_all_params.mean().item():.6f}, std={original_model_all_params.std().item():.6f}")

original_model_copy_all_params = torch.cat([param.data.view(-1) for param in original_model_copy.parameters()])
print('original_model_copy details:')
print(f"All params: mean= {original_model_copy_all_params.mean().item():.6f}, std={original_model_copy_all_params.std().item():.6f}")

original_model after train details:
All params: mean= -0.001732, std=0.007998
original_model_copy details:
All params: mean= -0.000005, std=0.006723


In [6]:
base_model = CNN3D().to(device)
load_model_path = 'model_checkpoints/3DCNN_base/model.pt'
base_model.load_state_dict(torch.load(load_model_path))
base_model_all_params = torch.cat([param.data.view(-1) for param in base_model.parameters()])
print('base_model details:')
print(f"All params: mean= {base_model_all_params.mean().item():.6f}, std={base_model_all_params.std().item():.6f}")

peft_model = peft.get_peft_model(base_model, config)
print('peft_model details:')
all_params = torch.cat([param.data.view(-1) for param in peft_model.parameters()])
print(f"All params: mean= {all_params.mean().item():.6f}, std={all_params.std().item():.6f}")

optimizer_peft = optim.Adam(peft_model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
train(peft_model, optimizer_peft, criterion, train_dataloader_1, eval_dataloader_1, device, epochs=max_epochs)

base_model details:
All params: mean= -0.001732, std=0.007998
peft_model details:
All params: mean= -0.001603, std=0.008132
Batch 0 get train loss: 0.6083725690841675 val loss: 0.573870837688446
Batch 1 get train loss: 0.5985857844352722 val loss: 0.5798882246017456
Batch 2 get train loss: 0.5933542847633362 val loss: 0.5784934163093567
Batch 3 get train loss: 0.5917235612869263 val loss: 0.5729011297225952
Batch 4 get train loss: 0.5912915468215942 val loss: 0.5783311128616333
Batch 5 get train loss: 0.5881150364875793 val loss: 0.5742651224136353
Batch 6 get train loss: 0.5919513702392578 val loss: 0.573917031288147
Batch 7 get train loss: 0.5904673933982849 val loss: 0.5750579833984375
Batch 8 get train loss: 0.5882015824317932 val loss: 0.5734982490539551
Batch 9 get train loss: 0.593615710735321 val loss: 0.5728814601898193
Batch 10 get train loss: 0.5833855867385864 val loss: 0.5799612998962402
Batch 11 get train loss: 0.5822224617004395 val loss: 0.5719284415245056
Batch 12 get 

In [7]:
# Save lora adapter
adapter_save_path = 'lora_adapter/adapter_1'
peft_model.save_pretrained(adapter_save_path)

In [8]:
# load adapter
base_model = CNN3D().to(device)
load_model_path = 'model_checkpoints/3DCNN_base/model.pt'
base_model.load_state_dict(torch.load(load_model_path))
adapter_name = 'lora_adapter/adapter_1'
peft_model_from_adapter = peft.PeftModel.from_pretrained(base_model, adapter_name)
print('peft_model_from_adapter type',type(peft_model_from_adapter))
print('base_model type',type(base_model))
with torch.no_grad():
    y_peft = peft_model(X_1.to(device))
    y_loaded_adapter = peft_model_from_adapter(X_1.to(device))
print('if peft_model and peft_model_from_adapter have same results:',torch.allclose(y_peft, y_loaded_adapter))
del y_peft, y_loaded_adapter
torch.cuda.empty_cache()

peft_model_from_adapter_all_params = torch.cat([param.data.view(-1) for param in peft_model_from_adapter.parameters()])
print('peft_model_from_adapter details:')
print(f"All params: mean= {peft_model_from_adapter_all_params.mean().item():.6f}, std={peft_model_from_adapter_all_params.std().item():.6f}")

peft_model_from_adapter type <class 'peft.peft_model.PeftModel'>
base_model type <class 'MLP_function.CNN3D'>
if peft_model and peft_model_from_adapter have same results: True
peft_model_from_adapter details:
All params: mean= -0.001624, std=0.011556


In [9]:
base_model = CNN3D().to(device)
load_model_path = 'model_checkpoints/3DCNN_base/model.pt'
base_model.load_state_dict(torch.load(load_model_path))

adapter_name = 'lora_adapter/adapter_1'
unmerged_peft_model = peft.PeftModel.from_pretrained(base_model, adapter_name)
print('unmerged_peft_model details:',type(unmerged_peft_model))
unmerged_peft_model_all_params = torch.cat([param.data.view(-1) for param in unmerged_peft_model.parameters()])
print(f"All params: mean= {unmerged_peft_model_all_params.mean().item():.6f}, std={unmerged_peft_model_all_params.std().item():.6f}")
merged_model_1 = unmerged_peft_model.merge_and_unload()  

model_path = 'model_checkpoints/3DCNN_base'
if not os.path.exists(model_path):
    os.makedirs(model_path)
torch.save(original_model.state_dict(), os.path.join(model_path, f"merged_model_1.pt"))


unmerged_peft_model details: <class 'peft.peft_model.PeftModel'>
All params: mean= -0.001624, std=0.011556


In [11]:
## Fine-tuning with data 2
peft_model = peft.get_peft_model(merged_model_1, config)
optimizer_peft = optim.Adam(peft_model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
train(peft_model, optimizer_peft, criterion, train_dataloader_2, eval_dataloader_2, device, epochs=max_epochs)

Batch 0 get train loss: 0.5848673582077026 val loss: 0.5117658972740173
Batch 1 get train loss: 0.5924632549285889 val loss: 0.5063806176185608
Batch 2 get train loss: 0.5938414335250854 val loss: 0.5077865123748779
Batch 3 get train loss: 0.5907303094863892 val loss: 0.5112826824188232
Batch 4 get train loss: 0.5913812518119812 val loss: 0.5106264352798462
Batch 5 get train loss: 0.5782382488250732 val loss: 0.5152074098587036
Batch 6 get train loss: 0.5847507119178772 val loss: 0.5055463910102844
Batch 7 get train loss: 0.5808209776878357 val loss: 0.5207860469818115
Batch 8 get train loss: 0.5787420868873596 val loss: 0.5044960975646973
Batch 9 get train loss: 0.568045437335968 val loss: 0.5193659067153931
Batch 10 get train loss: 0.5595629215240479 val loss: 0.5115875005722046
Batch 11 get train loss: 0.5444388389587402 val loss: 0.5090911388397217
Batch 12 get train loss: 0.5368410348892212 val loss: 0.5409210920333862
Batch 13 get train loss: 0.5278259515762329 val loss: 0.518061

In [13]:
# Save lora adapter
adapter_save_path = 'lora_adapter/adapter_2'
peft_model.save_pretrained(adapter_save_path)

In [16]:
base_model = CNN3D().to(device)
load_model_path = 'model_checkpoints/3DCNN_base/merged_model_1.pt'
base_model.load_state_dict(torch.load(load_model_path))

adapter_name = 'lora_adapter/adapter_2'
unmerged_peft_model_2 = peft.PeftModel.from_pretrained(base_model, adapter_name)

merged_model_2 = unmerged_peft_model_2.merge_and_unload()

with torch.no_grad():
    y__loaded = merged_model_2(X_2.to(device))
    y_trained = peft_model(X_2.to(device))

print('if the trained model and loaded model are the same:', torch.allclose(y__loaded, y_trained))
del y__loaded, y_trained
torch.cuda.empty_cache()


if the trained model and loaded model are the same: False
