Skip to content

Commit

Permalink
fixbugs
Browse files Browse the repository at this point in the history
  • Loading branch information
ShengdingHu committed Jun 6, 2022
1 parent 9a385f8 commit f56dd68
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 8 deletions.
Binary file modified dist/opendelta-0.0.4.tar.gz
Binary file not shown.
Binary file added dist/opendelta-0.1.0-py3-none-any.whl
Binary file not shown.
Binary file added dist/opendelta-0.1.0.tar.gz
Binary file not shown.
10 changes: 6 additions & 4 deletions examples/examples_prompt/src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def main():

config, tokenizer, model = get_backbone(model_args=model_args)

# model parallelize
if hasattr(training_args, "model_parallel") and training_args.model_parallel:
logger.info('parallelize model!')
model.parallelize()

from opendelta import Visualization
Visualization(model).structure_graph()

Expand All @@ -161,10 +166,7 @@ def main():
delta_model.freeze_module(set_state_dict = True)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)

# model parallelize
if hasattr(training_args, "model_parallel") and training_args.model_parallel:
logger.info('parallelize model!')
model.parallelize()




Expand Down
2 changes: 1 addition & 1 deletion opendelta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

__version__ = "0.0.4"
__version__ = "0.1.0"

class GlobalSetting:
def __init__(self):
Expand Down
6 changes: 5 additions & 1 deletion opendelta/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.cuda import move_dict_to_cuda

logger = logging.get_logger(__name__)

def is_leaf_module(module):
Expand Down Expand Up @@ -109,6 +111,7 @@ def __init__(self,
else:
self.modified_modules = interactive(backbone_model, port=interactive_modify)
self.common_structure = False
self.exclude_modules = self.default_exclude_modules
else:
self.modified_modules = self.default_modified_modules
self.common_structure = True
Expand Down Expand Up @@ -328,11 +331,12 @@ def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
"""
if module is None:
module = self.backbone_model
device = get_device(module)
try:
dummy_inputs = module.dummy_inputs
dummy_inputs = move_dict_to_cuda(dummy_inputs, device)
module(**dummy_inputs)
except AttributeError:
device = get_device(module)
logger.warning("No dummy_inputs attributes, create a common input_ids for input.")
pseudo_input = torch.tensor([[0,0]]).to(device)
if "decoder_input_ids" in signature(module.forward).args:
Expand Down
9 changes: 9 additions & 0 deletions opendelta/utils/cuda.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union
import torch.nn as nn
import torch

def get_device(module : Union[nn.Module, nn.Parameter]):
if not (isinstance(module, nn.Module) \
Expand All @@ -17,6 +18,14 @@ def get_device(module : Union[nn.Module, nn.Parameter]):
raise RuntimeError("The module is paralleled acrossed device, please get device in a inner module")


def move_dict_to_cuda(dict_of_tensor, device):
for key in dict_of_tensor:
if isinstance(dict_of_tensor[key], torch.Tensor):
dict_of_tensor[key] = dict_of_tensor[key].to(device)
return dict_of_tensor



# unitest, should be removed later
if __name__ == "__main__":
import torch
Expand Down
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@

import setuptools
import os
import os

def get_requirements(path):
print("path is :", path)
ret = []
with open(os.path.join(path, "requirements.txt"), encoding="utf-8") as freq:

with open(os.path.join("/mnt/sfs_turbo/hsd/opendelta1.0.0_beta/OpenDelta/requirements.txt"), encoding="utf-8") as freq:
for line in freq.readlines():
ret.append( line.strip() )
return ret
Expand All @@ -17,7 +20,7 @@ def get_requirements(path):
with open('README.md', 'r') as f:
setuptools.setup(
name = 'opendelta',
version = '0.0.4',
version = "0.1.0",
description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
Expand All @@ -28,6 +31,11 @@ def get_requirements(path):
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
python_requires=">=3.6.0",
install_requires=requires,
package_dir={'opendelta':'opendelta'},
package_data= {
'opendelta':["utils/interactive/templates/*.html"],
},
include_package_data=True,
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
Expand Down

0 comments on commit f56dd68

Please sign in to comment.