Skip to content

Commit

Permalink
Support Stable Video Diffusion model.
Browse files Browse the repository at this point in the history
  • Loading branch information
shiimizu committed Jan 16, 2024
1 parent 99c062d commit 195fed5
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tiled_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict):
cond_tile = self.repeat_tensor(c_crossattn, n_rep)
c_tile = c_in.copy()
c_tile['c_crossattn'] = cond_tile
if 'time_context' in c_in:
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep)
for key in ['y', 'c_concat']:
if key in c_tile:
c_tile[key] = self.repeat_tensor(c_tile[key], n_rep)
Expand Down Expand Up @@ -499,6 +501,8 @@ def __call__(self, model_function: BaseModel.apply_model, args: dict):
tcond_tile = self.repeat_tensor(c_crossattn, n_rep) # just repeat
c_tile = c_in.copy()
c_tile['c_crossattn'] = tcond_tile
if 'time_context' in c_in:
c_tile['time_context'] = self.repeat_tensor(c_in['time_context'], n_rep) # just repeat
for key in ['y', 'c_concat']:
if key in c_in:
icond_tile = torch.cat(icond_map[key], dim=0) # differs each
Expand Down

0 comments on commit 195fed5

Please sign in to comment.