Commit
fix seq len error.
shibing624 committed Apr 12, 2023
1 parent ec540f4 commit 7a0be59
Showing 4 changed files with 22 additions and 30 deletions.
textgen/config/model_args.py (10 changes: 5 additions & 5 deletions)

@@ -63,7 +63,7 @@ class ModelArgs:
     logging_steps: int = 50
     manual_seed: int = None
     max_grad_norm: float = 1.0
-    max_seq_length: int = 128
+    max_seq_length: int = 128 # max length of input sequence
     model_name: str = None
     model_type: str = None
     multiprocessing_chunksize: int = -1
@@ -153,7 +153,7 @@ class T5Args(ModelArgs):
     early_stopping: bool = True
     evaluate_generated_text: bool = False
     length_penalty: float = 2.0
-    max_length: int = 20
+    max_length: int = 128 # max length of the sequence to be generated
     max_steps: int = -1
     num_beams: int = 1
     num_return_sequences: int = 1
@@ -183,7 +183,7 @@ class CopyT5Args(ModelArgs):
     early_stopping: bool = True
     evaluate_generated_text: bool = False
     length_penalty: float = 2.0
-    max_length: int = 20
+    max_length: int = 128 # max length of the sequence to be generated
     max_steps: int = -1
     num_beams: int = 3
     num_return_sequences: int = 1
@@ -247,7 +247,7 @@ class Seq2SeqArgs(ModelArgs):
     faiss_d: int = 768
     faiss_m: int = 128
     length_penalty: float = 2.0
-    max_length: int = 20
+    max_length: int = 128 # max length of the sequence to be generated
     max_steps: int = -1
     num_beams: int = 1
     num_return_sequences: int = 1
@@ -275,7 +275,7 @@ class LanguageGenerationArgs(ModelArgs):
     early_stopping: bool = True
     evaluate_generated_text: bool = False
     length_penalty: float = 2.0
-    max_length: int = 20
+    max_length: int = 128 # max length of the sequence to be generated
     max_steps: int = -1
     num_beams: int = 1
     num_return_sequences: int = 1
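The comments added above make the split explicit: max_seq_length caps the tokenized input, while max_length caps the sequence produced at generation time, which this commit raises from 20 to 128 so outputs are no longer cut short. A minimal sketch of the two limits against the Hugging Face transformers API (the checkpoint name and sample text are placeholders, not taken from this repository):

from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "t5-small"  # placeholder checkpoint, for illustration only
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

max_seq_length = 128  # ModelArgs.max_seq_length: limit on the tokenized input
max_length = 128      # T5Args.max_length: limit on the generated sequence (previously 20)

inputs = tokenizer(
    "summarize: The quick brown fox jumps over the lazy dog.",
    max_length=max_seq_length,  # input-side truncation
    truncation=True,
    return_tensors="pt",
)
# With max_length=20 the decoder stops after 20 tokens, even when a longer answer is needed.
output_ids = model.generate(**inputs, max_length=max_length, num_beams=1)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))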
textgen/seq2seq/bart_seq2seq_model.py (30 changes: 13 additions & 17 deletions)

@@ -8,24 +8,16 @@
 import random
 import warnings
 from dataclasses import asdict
-from multiprocessing import Pool, cpu_count
+from multiprocessing import Pool

 import numpy as np
 import pandas as pd
 import torch
+from datasets import load_from_disk
+from loguru import logger
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm.auto import tqdm, trange
-from transformers.optimization import (
-    get_constant_schedule,
-    get_constant_schedule_with_warmup,
-    get_linear_schedule_with_warmup,
-    get_cosine_schedule_with_warmup,
-    get_cosine_with_hard_restarts_schedule_with_warmup,
-    get_polynomial_decay_schedule_with_warmup,
-)
-from transformers.optimization import AdamW, Adafactor
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -45,7 +37,6 @@
     ElectraConfig,
     ElectraModel,
     ElectraTokenizerFast,
-    EncoderDecoderConfig,
     EncoderDecoderModel,
     LongformerConfig,
     LongformerModel,
@@ -56,8 +47,6 @@
     MobileBertConfig,
     MobileBertModel,
     MobileBertTokenizerFast,
-    PreTrainedModel,
-    PreTrainedTokenizerFast,
     RagTokenizer,
     RagRetriever,
     RagTokenForGeneration,
@@ -67,8 +56,15 @@
     RobertaModel,
     RobertaTokenizerFast,
 )
-from datasets import load_from_disk
-from loguru import logger
+from transformers.optimization import AdamW, Adafactor
+from transformers.optimization import (
+    get_constant_schedule,
+    get_constant_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+    get_polynomial_decay_schedule_with_warmup,
+)

 from textgen.config.model_args import Seq2SeqArgs
 from textgen.seq2seq.bart_seq2seq_utils import (
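The changes in this file reorganize imports: AdamW, Adafactor, and the learning-rate schedule helpers move below the large transformers import block, and Dataset joins the torch.utils.data import; behavior is unchanged. For reference, a typical pairing of the relocated symbols (the stand-in module and step counts below are invented for the sketch):

import torch
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

model = torch.nn.Linear(10, 2)  # stand-in module; the trainer passes the real seq2seq model
optimizer = AdamW(model.parameters(), lr=3e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000  # hypothetical step budget
)
# per training step: loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()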
textgen/seq2seq/bart_seq2seq_utils.py (7 changes: 4 additions & 3 deletions)

@@ -205,7 +205,7 @@ def preprocess_data(data):
     )
     target_inputs = encoder_tokenizer.generator(
         target_text,
-        max_length=args.max_seq_length,
+        max_length=args.max_length,
         padding="max_length",
         return_tensors="pt",
         truncation=True,
@@ -229,7 +229,7 @@ def preprocess_data(data):

     target_text = decoder_tokenizer.encode(
         target_text,
-        max_length=args.max_seq_length,
+        max_length=args.max_length,
         padding="max_length",
         return_tensors="pt",
         truncation=True,
@@ -312,7 +312,7 @@ def preprocess_data_bart(data):

     target_ids = tokenizer.batch_encode_plus(
         [target_text],
-        max_length=args.max_seq_length,
+        max_length=args.max_length,
         padding="max_length",
         return_tensors="pt",
         truncation=True,
@@ -334,6 +334,7 @@ def preprocess_data_mbart(data):
         src_lang=args.src_lang,
         tgt_lang=args.tgt_lang,
         max_length=args.max_seq_length,
+        max_target_length=args.max_length,
         padding="max_length", # pad_to_max_length=True won't work in this case
         return_tensors="pt",
         truncation=True,
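The effect of these edits is that source text keeps args.max_seq_length as its truncation limit while target (label) text is now encoded with args.max_length, the same cap used at generation time. A stand-alone sketch of that split with plain tokenizer calls (checkpoint and texts are illustrative, not from the repository):

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # placeholder checkpoint
max_seq_length, max_length = 128, 128  # input cap vs. target/generation cap

source = tokenizer(
    "input document to be summarized ...",
    max_length=max_seq_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
target = tokenizer(
    "reference summary ...",
    max_length=max_length,  # previously args.max_seq_length was reused here
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(source["input_ids"].shape, target["input_ids"].shape)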
textgen/seq2seq/seq2seq_trainer.py (5 changes: 0 additions & 5 deletions)

@@ -72,11 +72,6 @@ def __init__(self, config=None, data_args=None, *args, **kwargs):

         if self.args.label_smoothing == 0:
             self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-        else:
-            # dynamically import label_smoothed_nll_loss
-            from utils import label_smoothed_nll_loss
-
-            self.loss_fn = label_smoothed_nll_loss

     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
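With the label-smoothing branch removed, the trainer always falls back to cross-entropy that skips padding positions. A small self-contained sketch of that loss on (batch, seq_len, vocab) logits, with shapes and ids invented for illustration:

import torch

pad_token_id = 1  # illustrative; the real value comes from self.config.pad_token_id
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

logits = torch.randn(2, 6, 50265)         # (batch, seq_len, vocab_size), random for the sketch
labels = torch.randint(2, 50265, (2, 6))  # fake target token ids
labels[:, -2:] = pad_token_id             # padded positions contribute nothing to the loss

loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
print(loss.item())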
