In [8]:
import torch
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F
import math

from transformers import AutoTokenizer, BertForMaskedLM
from diffusers import DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler

import numpy as np
import matplotlib.pyplot as plt


from src.denoisers.modeling_diffmamba import DiffMambaForDiffusionLM
from src.decoders.bert_decoder import BertLMHeadModel
from src.schedulers.euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from src.schedulers.ddpm import DDPMScheduler

    

    
# model(inputs_embeds=inputs_embeds, timesteps=timesteps).logits.shape

In [9]:

path = "models/diffmamba-mini-sample-trained"
tokenizer = AutoTokenizer.from_pretrained(path, subfolder="tokenizer")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(path, subfolder="scheduler")#DDIMScheduler(prediction_type="sample", num_train_timesteps=2000)
model = DiffMambaForDiffusionLM.from_pretrained(path, torch_dtype=torch.float16, subfolder="denoiser").to("cuda")
decoder = BertLMHeadModel.from_pretrained(path, torch_dtype=torch.float16, subfolder="decoder").to("cuda")

device = model.device


cross_attention False


In [4]:
scheduler.num_train_timesteps

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


1200

## functions

In [3]:

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

def get_timesteps(num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]

        return timesteps, num_inference_steps - t_start
        
def vectors_to_indices(vectors):
    indices = torch.argmax(vectors, dim=-1)
    return indices

def sample_text(probabilities, temperature=1.0):
    batch_size, seq_len, vocab_size = probabilities.size()
    flattened_probs = probabilities.view(batch_size * seq_len, -1)
    
    scaled_logits = flattened_probs / temperature
    scaled_probs = F.softmax(scaled_logits, dim=-1)
    
    sampled_indices = torch.multinomial(scaled_probs, 1)
    sampled_token_ids = sampled_indices.view(batch_size, seq_len)
    
    return sampled_token_ids

## Generate

In [12]:
from IPython.display import display, clear_output



with torch.no_grad():
    latents = torch.rand((1, 64, 768), device=device).to(torch.float16)# + torch.rand((8, 128, 768), device=device).to(torch.float16)
    attention_mask = torch.ones((1, 64), device=device)
    num_inference_steps = scheduler.num_train_timesteps // 1
    timesteps=None
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, timesteps)

    for i, t in tqdm(enumerate(timesteps)):
        # if i >= 0.7 * num_inference_steps:
        #     break
        # expand the latents if we are doing classifier free guidance
        latent_model_input =  latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # rnd_latents = torch.rand((1, 64, 4096), device=device).to(torch.float16)
        # print(latent_model_input.dtype)
        outputs = model(
            input_embeds=latent_model_input,
            timesteps=t.reshape(1,).long().to(device),
            # attention_mask=attention_mask
        )
        noise_pred = outputs.last_hidden_state
        latents_final = outputs.logits
        if i % 10 ==0 :
            clear_output(wait=True)
            display(f"SAMPLES[{i}]--->")
            for n in range(latents_final.shape[0]):
                display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
            display("---------------")

        step = scheduler.step(noise_pred, t, latents, return_dict=True)#[0]
        latents = step["prev_sample"]


clear_output(wait=True)
display(f"FINAL --->")
for n in range(latents_final.shape[0]):
    display(f"{n}    --->    " + tokenizer.decode(vectors_to_indices(latents_final[n]), skip_special_tokens=True))
display("---------------")

'FINAL --->'

"0    --->    inov and focusedat of castleks of the of, dirt, T card, femaleom, of, lighting, gianated crow moon, bl, aically in pose,  detailed, articallyles, in front ult - realistic, cellical's  2"

'---------------'

In [13]:
# model.eval()
# inputs = tokenizer(["Today is"], return_tensors="pt")
# # print(inputs.input_ids)
# bsz, seq_ln = inputs.input_ids.shape

# encoder_hidden_states = torch.rand((1, 64, 768), device=device).to(torch.float16)
# print(encoder_hidden_states)
encoder_hidden_states = latents
decoder_input_ids = [0]
predicted_ids = []
for i in range(64): 
    outputs = decoder(input_ids=torch.tensor(([decoder_input_ids])).to(model.device), encoder_hidden_states=encoder_hidden_states)
    print(outputs.logits.shape)
    logits = outputs.logits[:,i,:]
    # perform argmax on the last dimension (i.e. greedy decoding)
    predicted_id = logits.argmax(-1)
    print(predicted_id[0].item())
    predicted_ids.append(predicted_id[0].item())
    print(tokenizer.decode([predicted_id[0].squeeze()]))
    # add predicted id to decoder_input_ids
    decoder_input_ids = decoder_input_ids + [predicted_id[0].item()]
print(tokenizer.decode(predicted_ids))

torch.Size([1, 1, 32001])
262
in
torch.Size([1, 2, 32001])
586
ov
torch.Size([1, 3, 32001])
322
and
torch.Size([1, 4, 32001])
5796
ru
torch.Size([1, 5, 32001])
1312
ined
torch.Size([1, 6, 32001])
310
of
torch.Size([1, 7, 32001])
3105
fut
torch.Size([1, 8, 32001])
332
ur
torch.Size([1, 9, 32001])
4695
istic
torch.Size([1, 10, 32001])
29892
,
torch.Size([1, 11, 32001])
270
d
torch.Size([1, 12, 32001])
2728
irt
torch.Size([1, 13, 32001])
29892
,
torch.Size([1, 14, 32001])
11266
hyper
torch.Size([1, 15, 32001])
29881
d
torch.Size([1, 16, 32001])
29892
,
torch.Size([1, 17, 32001])
4940
past
torch.Size([1, 18, 32001])
295
el
torch.Size([1, 19, 32001])
29892
,
torch.Size([1, 20, 32001])
12726
rim
torch.Size([1, 21, 32001])
3578
light
torch.Size([1, 22, 32001])
292
ing
torch.Size([1, 23, 32001])
29892
,
torch.Size([1, 24, 32001])
3578
light
torch.Size([1, 25, 32001])
292
ing
torch.Size([1, 26, 32001])
29892
,
torch.Size([1, 27, 32001])
330
g
torch.Size([1, 28, 32001])
713
ian
torch.Size([1, 29

In [7]:
# model.eval()
# inputs = tokenizer(["Today is"], return_tensors="pt")
# # print(inputs.input_ids)
encoder_hidden_states = latents

decoder_input_ids = [0]
last_pred = 0
predicted_ids = []
for i in range(64): 
    outputs = decoder(input_ids=torch.tensor([decoder_input_ids]).to(model.device), encoder_hidden_states=encoder_hidden_states)
    print(outputs.logits.shape)
    logits = outputs.logits[:, i, :]
    # Handling 32000 token
    argmax_value = logits.argmax(-1)
    top_logits, top_indices = logits.topk(2, dim=-1)
    
    print(argmax_value.item(), top_indices)
    predicted_id = argmax_value.item() if argmax_value.item() != 32000 and argmax_value.item() != last_pred else top_indices[0][1].item()
    last_pred = predicted_id
    print(predicted_id)
    predicted_ids.append(predicted_id)
    print(tokenizer.decode([predicted_id]))
    # add predicted id to decoder_input_ids
    decoder_input_ids = decoder_input_ids + [predicted_id]
print(tokenizer.decode(predicted_ids))

torch.Size([1, 1, 32001])
530 tensor([[  530, 12969]], device='cuda:0')
530
An
torch.Size([1, 2, 32001])
310 tensor([[  310, 15566]], device='cuda:0')
310
of
torch.Size([1, 3, 32001])
263 tensor([[ 263, 5765]], device='cuda:0')
263
a
torch.Size([1, 4, 32001])
310 tensor([[ 310, 6559]], device='cuda:0')
310
of
torch.Size([1, 5, 32001])
263 tensor([[263, 278]], device='cuda:0')
263
a
torch.Size([1, 6, 32001])
310 tensor([[  310, 24870]], device='cuda:0')
310
of
torch.Size([1, 7, 32001])
263 tensor([[263, 670]], device='cuda:0')
263
a
torch.Size([1, 8, 32001])
310 tensor([[  310, 24870]], device='cuda:0')
310
of
torch.Size([1, 9, 32001])
263 tensor([[263, 347]], device='cuda:0')
263
a
torch.Size([1, 10, 32001])
310 tensor([[  310, 24870]], device='cuda:0')
310
of
torch.Size([1, 11, 32001])
263 tensor([[263, 347]], device='cuda:0')
263
a
torch.Size([1, 12, 32001])
310 tensor([[  310, 21760]], device='cuda:0')
310
of
torch.Size([1, 13, 32001])
263 tensor([[263, 347]], device='cuda:0')
263
a