In [1]:
import torch
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F
import math

from transformers import AutoTokenizer, BertForMaskedLM
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler

import numpy as np
import matplotlib.pyplot as plt

from src.modeling_diffbert_sample import DiffBertForDiffusion
from src.modeling_diffllama import DiffLlamaForDiffusionLM
from src.modeling_diffmamba import DiffMambaForDiffusionLM
from src.configuration_diffbert import DiffBertConfig
from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from src.schedulers.ddpm import DDPMScheduler

    

    
# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape

  from .autonotebook import tqdm as notebook_tqdm


[2023-12-10 15:56:48,507] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Load the tokenizer from the sample checkpoint directory and register '<pad>' as its pad token.
tokenizer = AutoTokenizer.from_pretrained("models/diffmamba-mini-sample")
tokenizer.add_special_tokens({'pad_token': '<pad>'})
# The scheduler configuration is read from the same directory as the tokenizer.
# Commented-out alternative: DDIMScheduler(prediction_type="sample", num_train_timesteps=2000)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained("models/diffmamba-mini-sample")
# The model weights come from a different directory (note the "-trained" suffix); they are
# loaded in float16 and moved to the GPU.
model = DiffMambaForDiffusionLM.from_pretrained("models/diffmamba-mini-sample-trained", torch_dtype=torch.float16).to("cuda")
# Later cells create their tensors on the model's device.
device = model.device


cross_attention False


In [5]:
# Rebind the global `scheduler` to a DDPM scheduler loaded from the sample checkpoint directory.
# NOTE: `DDPMScheduler` is imported twice in the imports cell (from `diffusers` and from
# `src.schedulers.ddpm`); the later import wins, so this resolves to the project-local class.
scheduler = DDPMScheduler.from_pretrained("models/diffmamba-mini-sample")

In [14]:
# The scheduler can be configured with more steps than we trained on
# (sometimes it gives even better results).
# Options that were tried and are currently not passed:
#   beta_start=0.00085, beta_end=0.012, clip_sample=False,
#   skip_prk_steps=True, set_alpha_to_one=False, interpolation_type="linear"
# NOTE: this rebinds the global `scheduler`; the sampling loop uses whichever
# scheduler cell was executed last.
scheduler = EulerAncestralDiscreteScheduler(
    num_train_timesteps=2000,
    beta_schedule="sqrt",
    prediction_type="sample",
    steps_offset=0,
)


## Functions

In [3]:

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and hand back the resulting schedule.

    Exactly one of two paths is taken:

    * `timesteps` is `None`: the scheduler is asked for `num_inference_steps` steps using its own
      default spacing, and `num_inference_steps` is returned unchanged.
    * `timesteps` is given: the custom schedule is forwarded to the scheduler (which must accept a
      `timesteps` keyword), and the returned step count is the length of the scheduler's schedule.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read `timesteps` from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps to run. Only used when `timesteps` is `None`.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. When passed, it takes precedence over `num_inference_steps`.
        **kwargs:
            Extra keyword arguments forwarded verbatim to `scheduler.set_timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `(scheduler.timesteps, number_of_inference_steps)`.

    Raises:
        ValueError: if `timesteps` is passed but `scheduler.set_timesteps` has no `timesteps` parameter.
    """
    # Default path: let the scheduler derive its own spacing.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly support a `timesteps` keyword.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)

def get_timesteps(num_inference_steps, strength, device):
    """
    Select the tail of the current schedule for a partial, strength-scaled denoising pass.

    This function reads the notebook-global `scheduler` (it is not a parameter): it slices
    `scheduler.timesteps` and uses `scheduler.order`, so the scheduler's timesteps must
    already have been set (e.g. by `retrieve_timesteps`) before calling it.

    Args:
        num_inference_steps (`int`): total number of steps in the configured schedule.
        strength (`float`): fraction of the schedule to run. Values above 1 run the whole
            schedule; values at or below 0 run nothing.
        device: accepted for interface compatibility; not used by this function.

    Returns:
        `Tuple[timesteps, int]`: the remaining timesteps and how many inference steps they cover.
    """
    # Clamp to [0, num_inference_steps]. Without the lower bound a negative `strength`
    # produced a negative step count in the return value.
    init_timestep = max(0, min(int(num_inference_steps * strength), num_inference_steps))

    # Index of the first step to run; always within [0, num_inference_steps] after clamping.
    t_start = num_inference_steps - init_timestep
    timesteps = scheduler.timesteps[t_start * scheduler.order :]

    return timesteps, num_inference_steps - t_start
        
def vectors_to_indices(vectors):
    """Return, for each vector along the last dimension, the index of its largest entry."""
    return vectors.argmax(dim=-1)

def sample_text(probabilities, temperature=1.0):
    """
    Sample one token id per position from a `(batch, seq_len, vocab)` tensor of scores.

    Despite the parameter name, the input is treated as unnormalised scores: it is divided
    by `temperature` and passed through a softmax before sampling.

    Args:
        probabilities (`torch.Tensor`): 3-D tensor of per-token scores. It may be
            non-contiguous (e.g. a permuted view).
        temperature (`float`): softmax temperature. `0` selects the highest-scoring token
            at every position (greedy decoding) instead of dividing by zero.

    Returns:
        `torch.Tensor`: integer token ids of shape `(batch, seq_len)`.
    """
    batch_size, seq_len, vocab_size = probabilities.size()

    # Dividing by zero would fill the scores with inf/nan and make `multinomial` fail;
    # the zero-temperature limit of the softmax is an argmax, so use that directly.
    if temperature == 0:
        return probabilities.argmax(dim=-1)

    # `reshape` (unlike `view`) also accepts non-contiguous inputs.
    flat_scores = probabilities.reshape(batch_size * seq_len, vocab_size)
    token_probs = F.softmax(flat_scores / temperature, dim=-1)

    # One draw per (batch, position) row.
    sampled = torch.multinomial(token_probs, 1)
    return sampled.view(batch_size, seq_len)

## Generate

In [6]:
from IPython.display import display, clear_output



# --- sampling configuration ---------------------------------------------------
BATCH_SIZE = 8              # first dimension of the latents; one decoded text per row
SEQ_LEN = 64                # second dimension of the latents
HIDDEN_SIZE = 768           # last dimension of the latents fed to the model as embeddings
NUM_INFERENCE_STEPS = 2000  # number of denoising steps requested from the scheduler
DISPLAY_EVERY = 10          # refresh the preview every N denoising steps
SEED = 0                    # seeds torch's global RNG so the initial latents are reproducible


def show_samples(header, logits):
    """Clear the cell output, then display `header` and the argmax decoding of each row of `logits`."""
    clear_output(wait=True)
    display(header)
    for n in range(logits.shape[0]):
        display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(logits[n]), skip_special_tokens=True))
    display("---------------")


torch.manual_seed(SEED)

# Reset so an empty schedule cannot silently display results left over from an earlier run.
latents_final = None

with torch.no_grad():
    # NOTE(review): the initial latents are uniform in [0, 1) (`torch.rand`), not Gaussian
    # (`torch.randn`); confirm this matches how the model was trained.
    latents = torch.rand((BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), device=device).to(torch.float16)
    # Currently unused: the model call below does not pass an attention mask.
    attention_mask = torch.ones((BATCH_SIZE, SEQ_LEN), device=device)
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, NUM_INFERENCE_STEPS, device)

    # Wrap the iterable itself (not `enumerate`) so tqdm knows the total and can show progress/ETA.
    for i, t in enumerate(tqdm(timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)
        outputs = model(
            input_embeds=latent_model_input,
            timesteps=t.reshape(1,).long().to(device),
        )
        # What the scheduler consumes; its meaning (noise vs. sample) depends on the
        # scheduler's `prediction_type`.
        model_output = outputs.last_hidden_state
        # Per-token vocabulary scores, used only to preview/decode text.
        latents_final = outputs.logits

        if i % DISPLAY_EVERY == 0:
            show_samples(f"SAMPLES[{i}]--->", latents_final)

        step = scheduler.step(model_output, t, latents, return_dict=True)
        latents = step["prev_sample"]

if latents_final is None:
    raise RuntimeError("The scheduler produced no timesteps, so there is nothing to decode.")

show_samples("FINAL --->", latents_final)

'FINAL --->'

'0    --->    front medieval castle b +lit throughunx medieval glass in wonder nose, old, evil technology, las outside, concept art lopies portrait run runningvedally k technology human cor #ith flyingroidio in�amiically ::ships magana, Le faces, cables Al C bacon and hunder'

'1    --->    great station sever their male theiromb cynumeillerybuilder down future mouth detailed, detailed sh cat lotrifying w ch secret nost heic vibrant vaporwave a anime, other, stained, award winning, intricate,ho back ray, crossble'

'2    --->    “ subonaut +ray of one b android egg by red head ever ang brain town bast D&Dading detailed x robot medieval shell ranger beth trending on lant, serious which in neurukaach da E bar detailed des'

'3    --->    portraitop great concept artoro funings battle rings do unft bus, blood Eons, evil material, hand, energy hands, intricate,ely detailed, concept art, tallethantly'

'4    --->    Character portrait of daana E� In dri harate mobilestrhouse, death wars disoch diunt, ivy, scar fabric, my The hard sculptly down, sl eye, simpleity,ocharp expression, large texture, tallnd enVual, silcing Render, unreal enginenelike'

'5    --->    Characteret group of an below medieval + battle brainag +ized downomb itne I modernay form new b.houral dis scientist,atureemy legws, 26 ill tall open. solarpunk,rom design, turn soft lightning. HD'

'6    --->    aty winter abstract otherphd pet people dawn cowath inside in the markx gu oneage, ser downocal accurate architectumn advent planetions sed, beautiful full long shot, animals hair, mark halfck, pen distance, blood time, artgerm, yoshida, lotenderite resels))'

'7    --->    above of indhead room areie landscapeastn rott + that pe your outsideoraocal earhip bow'

'---------------'