Skip to content

Commit

Permalink
Merge commit '295681590db2b41faab3b847d11a889088df1851' into fix/dataset-0620
Browse files Browse the repository at this point in the history

* commit '295681590db2b41faab3b847d11a889088df1851':
  fix glm4v images (modelscope#1194)
  fix glm4v dataloader (modelscope#1183)
  Fix dataset concatenation (modelscope#1193)
  • Loading branch information
tastelikefeet committed Jun 21, 2024
2 parents 6e4300a + 2956815 commit 1224f11
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,10 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
if labels is not None:
labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
image_size: int = self.model.config.vision_config['image_size']
patch_size: int = self.model.config.vision_config['patch_size']
num_patches = (image_size // patch_size // 2)**2
labels = (labels[:idx] + [-100] * (len(placeholder_id) + num_patches - 1) + labels[idx + 1:])
messages = history_to_messages(example.get('history', []), example['query'], example.get('system', None))
messages[0]['image'] = image
inputs2: Dict[str, Any] = self.tokenizer.apply_chat_template(messages, return_dict=True)
Expand All @@ -949,8 +952,15 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
inputs['labels'] = labels
return inputs, {}

def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate ``batch`` with the parent collator and attach image data.

    Args:
        batch: Examples to collate; an example may carry an ``'images'`` entry.
        padding_to: Forwarded unchanged to the parent ``data_collator``.

    Returns:
        The parent's collated dict. When at least one example has an
        ``'images'`` entry, those entries are joined with ``torch.concat``
        and stored under ``'images'``.
    """
    collated = super().data_collator(batch, padding_to)

    # Gather image entries only from the examples that actually have one.
    image_tensors = []
    for example in batch:
        if 'images' in example:
            image_tensors.append(example['images'])

    # Leave the parent's result untouched when no example has images.
    if not image_tensors:
        return collated

    collated['images'] = torch.concat(image_tensors)
    return collated


# NOTE(review): these two calls look like the removed/added sides of a single
# diff line whose +/- markers were lost; presumably only the second call (with
# use_model=True) remains in the file after this commit — confirm against the repo.
register_template(TemplateType.glm4v, GLM4VTemplate(), infer_media_type='dialogue', lazy_tokenize=True)
register_template(TemplateType.glm4v, GLM4VTemplate(), infer_media_type='dialogue', lazy_tokenize=True, use_model=True)

register_template(
TemplateType.yi_vl,
Expand Down

0 comments on commit 1224f11

Please sign in to comment.