training fix for test
ricsinaruto committed Mar 28, 2023
1 parent ff4fd8a commit 553cff7
Showing 3 changed files with 8 additions and 6 deletions.
requirements.txt: 3 changes (2 additions & 1 deletion)
@@ -16,4 +16,5 @@ transformers
 PyWavelets
 hyperopt
 pyriemann
-xgboost
+xgboost
+sails
training.py: 5 changes (1 addition & 4 deletions)
@@ -80,10 +80,7 @@ def __init__(self, args, dataset=None):
             #self.args.dataset = self.dataset
         else:
             self.model_path = os.path.join(self.args.result_dir, 'model.pt')
-        if args.from_pretrained:
-            self.model = args.model.from_pretrained(args)
-        else:
-            self.model = args.model(args)
+        self.model = args.model(args)
 
         try:
             self.model = self.model.cuda()
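For context, a minimal sketch of the construction path the trainer uses after this change; DummyModel, the SimpleNamespace wiring, and 'results' are hypothetical stand-ins, not code from this repository:

# Minimal sketch (hypothetical stand-ins, not repository code): the trainer
# now always constructs the class stored in args.model, without branching
# on a from_pretrained flag.
import os
from types import SimpleNamespace

class DummyModel:
    def __init__(self, args):
        self.args = args

args = SimpleNamespace(model=DummyModel, result_dir='results')
model_path = os.path.join(args.result_dir, 'model.pt')
model = args.model(args)   # pretrained loading now lives in the model class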
transformers_quantized.py: 6 changes (5 additions & 1 deletion)
@@ -112,11 +112,15 @@ class TransformerQuantizedPretrained(TransformerQuantized):
     '''
     Same as TransformerQuantized, but uses the pretrained GPT2 model
     '''
+    def __init__(self, args):
+        super().__init__(args.gpt2_config)
+
     @classmethod
     def from_pretrained(cls, args):
         return super().from_pretrained('gpt2',
                                        args,
-                                       config=args.gpt2_config)
+                                       config=args.gpt2_config,
+                                       cache_dir=args.result_dir)
 
 
 class TransformerQuantizedConcatEmb(TransformerQuantized):
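For context, a minimal sketch of the two construction paths this class now exposes, expressed against the Hugging Face transformers API; GPT2LMHeadModel stands in for the TransformerQuantized base class (whose internals are not shown in this diff), and result_dir is a hypothetical placeholder:

# Sketch only: rough Hugging Face equivalents of the two paths.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config()      # plays the role of args.gpt2_config
result_dir = 'results'     # plays the role of args.result_dir

# 1) Plain construction with randomly initialised weights, analogous to
#    TransformerQuantizedPretrained(args) via the new __init__.
model_scratch = GPT2LMHeadModel(config)

# 2) Load pretrained GPT-2 weights, caching the download under result_dir,
#    analogous to TransformerQuantizedPretrained.from_pretrained(args).
model_pretrained = GPT2LMHeadModel.from_pretrained(
    'gpt2', config=config, cache_dir=result_dir)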
