In [1]:
import torch
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F
import math

from transformers import AutoTokenizer, BertForMaskedLM
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler

import numpy as np
import matplotlib.pyplot as plt


from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM
from src.decoders.bert_decoder import BertLMHeadModel
from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from src.schedulers.ddpm import DDPMScheduler

    

    
# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape

  from .autonotebook import tqdm as notebook_tqdm


[2023-12-11 01:06:13,065] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:

path = "models/diffmamba-mini-sample-trained"
tokenizer = AutoTokenizer.from_pretrained(path, subfolder="tokenizer")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(path, subfolder="scheduler")#DDIMScheduler(prediction_type="sample", num_train_timesteps=2000)
model = DiffMambaForDiffusionLM.from_pretrained(path, torch_dtype=torch.float16, subfolder="denoiser").to("cuda")
decoder = BertLMHeadModel.from_pretrained(path, torch_dtype=torch.float16, subfolder="decoder").to("cuda")

device = model.device


cross_attention False


In [4]:
scheduler.num_train_timesteps

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


1200

## functions

In [3]:

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

def get_timesteps(num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]

        return timesteps, num_inference_steps - t_start
        
def vectors_to_indices(vectors):
    indices = torch.argmax(vectors, dim=-1)
    return indices

def sample_text(probabilities, temperature=1.0):
    batch_size, seq_len, vocab_size = probabilities.size()
    flattened_probs = probabilities.view(batch_size * seq_len, -1)
    
    scaled_logits = flattened_probs / temperature
    scaled_probs = F.softmax(scaled_logits, dim=-1)
    
    sampled_indices = torch.multinomial(scaled_probs, 1)
    sampled_token_ids = sampled_indices.view(batch_size, seq_len)
    
    return sampled_token_ids

## Generate

In [18]:
from IPython.display import display, clear_output



with torch.no_grad():
    latents = torch.rand((2, 64, 768), device=device).to(torch.float16)# + torch.rand((8, 128, 768), device=device).to(torch.float16)
    attention_mask = torch.ones((2, 64), device=device)
    num_inference_steps = scheduler.num_train_timesteps
    timesteps=None
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)

    for i, t in tqdm(enumerate(timesteps)):
        # if i >= 0.7 * num_inference_steps:
        #     break
        # expand the latents if we are doing classifier free guidance
        latent_model_input =  latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # rnd_latents = torch.rand((1, 64, 4096), device=device).to(torch.float16)
        # print(latent_model_input.dtype)
        outputs = model(
            input_embeds=latent_model_input,
            timesteps=t.reshape(1,).long().to(device),
            # attention_mask=attention_mask
        )
        noise_pred = outputs.last_hidden_state
        latents_final = outputs.logits
        if i % 10 ==0 :
            clear_output(wait=True)
            display(f"SAMPLES[{i}]--->")
            for n in range(latents_final.shape[0]):
                display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
            display("---------------")

        step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]
        latents = step["prev_sample"]


clear_output(wait=True)
display(f"FINAL --->")
for n in range(latents_final.shape[0]):
    display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
display("---------------")

'FINAL --->'

'0    --->    germut of on body concept, greg r and and b, hyper realistic, asantant e concepty, h by j her byineer and by digital, intr on on artstation, h,, Art, engine, style'

'1    --->    a of- hairistic, highly dal hair, 3 fantasy, intricate, elegant loeteicate intricate detailed landscape l a wicateu, h details and hingane by by female byseonaing vsephonin, h,'

'---------------'

In [19]:
# model.eval()
# inputs = tokenizer(["Today is"], return_tensors="pt")
# # print(inputs.input_ids)
# bsz, seq_ln = inputs.input_ids.shape

# latents = torch.rand((1, 64, 768), device=device).to(torch.float16)
encoder_hidden_states = latents
decoder_input_ids = [0]
predicted_ids = []
for i in range(64): 
    outputs = decoder(input_ids=torch.tensor(([decoder_input_ids])).to(model.device), encoder_hidden_states=encoder_hidden_states)
    print(outputs.logits.shape)
    logits = outputs.logits[:,i,:]
    # perform argmax on the last dimension (i.e. greedy decoding)
    predicted_id = logits.argmax(-1)
    print(predicted_id[0].item())
    predicted_ids.append(predicted_id[0].item())
    print(tokenizer.decode([predicted_id[0].squeeze()]))
    # add predicted id to decoder_input_ids
    decoder_input_ids = decoder_input_ids + [predicted_id[0].item()]
print(tokenizer.decode(predicted_ids))

torch.Size([2, 1, 32001])
1395
gre
torch.Size([2, 2, 32001])
29887
g
torch.Size([2, 3, 32001])
364
r
torch.Size([2, 4, 32001])
329
ut
torch.Size([2, 5, 32001])
7000
kow
torch.Size([2, 6, 32001])
2574
ski
torch.Size([2, 7, 32001])
322
and
torch.Size([2, 8, 32001])
394
al
torch.Size([2, 9, 32001])
17607
phon
torch.Size([2, 10, 32001])
344
se
torch.Size([2, 11, 32001])
1568
much
torch.Size([2, 12, 32001])
29874
a
torch.Size([2, 13, 32001])
29892
,
torch.Size([2, 14, 32001])
330
g
torch.Size([2, 15, 32001])
677
low
torch.Size([2, 16, 32001])
292
ing
torch.Size([2, 17, 32001])
29892
,
torch.Size([2, 18, 32001])
330
g
torch.Size([2, 19, 32001])
677
low
torch.Size([2, 20, 32001])
292
ing
torch.Size([2, 21, 32001])
26068
lights
torch.Size([2, 22, 32001])
29892
,
torch.Size([2, 23, 32001])
534
tr
torch.Size([2, 24, 32001])
2548
ending
torch.Size([2, 25, 32001])
373
on
torch.Size([2, 26, 32001])
1616
art
torch.Size([2, 27, 32001])
19569
station
torch.Size([2, 28, 32001])
29892
,
torch.Size([2, 2

In [12]:
# model.eval()
# inputs = tokenizer(["Today is"], return_tensors="pt")
# # print(inputs.input_ids)
encoder_hidden_states = latents

decoder_input_ids = [0]
predicted_ids = []
for i in range(64): 
    outputs = decoder(input_ids=torch.tensor([decoder_input_ids]).to(model.device), encoder_hidden_states=encoder_hidden_states)
    print(outputs.logits.shape)
    logits = outputs.logits[:, i, :]
    # Handling 32000 token
    argmax_value = logits.argmax(-1)
    top_logits, top_indices = logits.topk(2, dim=-1)
    
    print(argmax_value.item(), top_indices)
    predicted_id = argmax_value.item() if argmax_value.item() != 32000 else top_indices[0][1].item()
    print(predicted_id)
    predicted_ids.append(predicted_id)
    print(tokenizer.decode([predicted_id]))
    # add predicted id to decoder_input_ids
    decoder_input_ids = decoder_input_ids + [predicted_id]
print(tokenizer.decode(predicted_ids))

torch.Size([1, 1, 32001])
385 tensor([[385, 319]], device='cuda:0')
385
an
torch.Size([1, 2, 32001])
603 tensor([[ 603, 8678]], device='cuda:0')
603
ime
torch.Size([1, 3, 32001])
21760 tensor([[21760,  1820]], device='cuda:0')
21760
portrait
torch.Size([1, 4, 32001])
29892 tensor([[29892,   310]], device='cuda:0')
29892
,
torch.Size([1, 5, 32001])
5094 tensor([[5094, 1472]], device='cuda:0')
5094
cy
torch.Size([1, 6, 32001])
495 tensor([[  495, 14203]], device='cuda:0')
495
ber
torch.Size([1, 7, 32001])
29886 tensor([[29886,  1212]], device='cuda:0')
29886
p
torch.Size([1, 8, 32001])
2960 tensor([[ 2960, 29892]], device='cuda:0')
2960
unk
torch.Size([1, 9, 32001])
29892 tensor([[29892, 29889]], device='cuda:0')
29892
,
torch.Size([1, 10, 32001])
5094 tensor([[5094, 1472]], device='cuda:0')
5094
cy
torch.Size([1, 11, 32001])
495 tensor([[  495, 14203]], device='cuda:0')
495
ber
torch.Size([1, 12, 32001])
29886 tensor([[29886,  1212]], device='cuda:0')
29886
p
torch.Size([1, 13, 32001])
