In [1]:
import torch
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F
import math

from transformers import AutoTokenizer, BertForMaskedLM
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler

import numpy as np
import matplotlib.pyplot as plt

from src.modeling_diffbert_sample import DiffBertForDiffusion
from src.modeling_diffllama import DiffLlamaForDiffusionLM
from src.modeling_diffmamba import DiffMambaForDiffusionLM
from src.configuration_diffbert import DiffBertConfig
from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler

    

    
# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape

  from .autonotebook import tqdm as notebook_tqdm


[2023-12-09 00:03:40,075] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
tokenizer = AutoTokenizer.from_pretrained("models/diffmamba-mini-sample-trained-good")
tokenizer.add_special_tokens({'pad_token': '<pad>'})
scheduler = EulerAncestralDiscreteScheduler.from_pretrained("models/diffmamba-mini-sample")#DDIMScheduler(prediction_type="sample", num_train_timesteps=2000)
model = DiffMambaForDiffusionLM.from_pretrained("models/diffmamba-mini-sample-trained", add_cross_attention=True, torch_dtype=torch.float16).to("cuda")
device = model.device


cross_attention True


In [19]:
scheduler = DDIMScheduler.from_pretrained("models/diffmamba-mini-sample")#DDIMScheduler(prediction_type="sample", num_train_timesteps=2000)
scheduler.config

FrozenDict([('num_train_timesteps', 1000),
            ('beta_start', 0.0001),
            ('beta_end', 0.02),
            ('beta_schedule', 'linear'),
            ('trained_betas', None),
            ('clip_sample', True),
            ('set_alpha_to_one', True),
            ('steps_offset', 0),
            ('prediction_type', 'sample'),
            ('thresholding', False),
            ('dynamic_thresholding_ratio', 0.995),
            ('clip_sample_range', 1.0),
            ('sample_max_value', 1.0),
            ('timestep_spacing', 'leading'),
            ('rescale_betas_zero_snr', False),
            ('_class_name', 'DDIMScheduler'),
            ('_diffusers_version', '0.23.1')])

In [25]:
scheduler = DDIMScheduler(
  #   beta_end = 0.012,
  # beta_schedule = "scaled_linear",
  # beta_start = 0.00085,
  # clip_sample = False,
#   skip_prk_steps = True,
#   set_alpha_to_one = False,
  # steps_offset = 1,
#   interpolation_type = "linear",
  prediction_type ="sample", 
  num_train_timesteps = 2000)

## Functions

In [3]:

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

def get_timesteps(num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]

        return timesteps, num_inference_steps - t_start
        
def vectors_to_indices(vectors):
    indices = torch.argmax(vectors, dim=-1)
    return indices

def sample_text(probabilities, temperature=1.0):
    batch_size, seq_len, vocab_size = probabilities.size()
    flattened_probs = probabilities.view(batch_size * seq_len, -1)
    
    scaled_logits = flattened_probs / temperature
    scaled_probs = F.softmax(scaled_logits, dim=-1)
    
    sampled_indices = torch.multinomial(scaled_probs, 1)
    sampled_token_ids = sampled_indices.view(batch_size, seq_len)
    
    return sampled_token_ids

## Generate

In [6]:
from IPython.display import display, clear_output

batch_size = 8
cfg=1
prompt = ["A melancholic depiction of war with a surrealistic touch, featuring a staircase and weeping statues, illuminated by a sunset and surrounded by ice and intricate details."] * batch_size
neg_prompt = [""] * batch_size

input_ids = tokenizer(prompt, padding="max_length", max_length=64, return_tensors="pt").to("cuda")
neg_input_ids = tokenizer(neg_prompt, padding="max_length", max_length=64, return_tensors="pt").to("cuda")
encoder_hidden_states = model.apply_embeddings(input_ids.input_ids).to(model.dtype)
neg_encoder_hidden_states = model.apply_embeddings(neg_input_ids.input_ids).to(model.dtype)

with torch.no_grad():
    latents = torch.rand((batch_size, 64, 768), device=device).to(model.dtype) + torch.rand((8, 64, 768), device=device).to(torch.float16)
    attention_mask = torch.ones((batch_size, 64), device=device)
    num_inference_steps = 1000
    timesteps=None
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)

    
    for i, t in tqdm(enumerate(timesteps)):
        # if i >= 0.7 * num_inference_steps:
        #     break
        # expand the latents if we are doing classifier free guidance
        latent_model_input =  latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        latent_model_input = torch.cat([latents] * 2) if cfg > 1 else latents
        prompt_embeds = torch.cat([encoder_hidden_states, neg_encoder_hidden_states]) if cfg > 1 else encoder_hidden_states

        outputs = model(
            input_embeds=latent_model_input,
            timesteps=t.reshape(1,).long().to(device),
            encoder_hidden_states=prompt_embeds
        )
        noise_pred = outputs.last_hidden_state
        if cfg > 1:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + cfg * (noise_pred_text - noise_pred_uncond)

        
        latents_final = outputs.logits
        if i % 10 ==0 :
            clear_output(wait=True)
            display(f"SAMPLES[{i}]--->")
            for n in range(latents_final.shape[0]):
                display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
            display("---------------")

        step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]
        latents = step["prev_sample"]


clear_output(wait=True)
display(f"FINAL --->")
for n in range(latents_final.shape[0]):
    display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
display("---------------")

'FINAL --->'

'0    --->    '

'1    --->    ,,,,'

'2    --->    '

'3    --->    ,'

'4    --->    ,'

'5    --->    ,,'

'6    --->    '

'7    --->    '

'---------------'