# Diffusion Model for Open Dialogue

Simplified version

References
- DiffuSeq (cited below)

Adapted from:

[1] @inproceedings{gong2022diffuseq,
  author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
  booktitle = {International Conference on Learning Representations, ICLR},
  title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
  year = 2023
}

[2] @article{gong2023diffuseqv2,
  title={DiffuSeq-v2: Bridging Discrete and Continuous Text Spaces for Accelerated Seq2Seq Diffusion Models},
  author={Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
  journal={arXiv preprint arXiv:2310.05793},
  year={2023}
}

## Dataset

- Use Commonsense Conversation dataset (from Reddit)


DiffuSeq's `text_datasets.py` contains the steps that load the dataset itself.

- [ ] prepare datasets for training and validation in the following format (stored as a JSON Lines file, one object per line?)
```
{"src": "", "trg": ""}
```

- word embeddings (to be loaded?)
- use a corpus

## Dataset

- Use Commonsense Conversation dataset (from Reddit)


DiffuSeq's `text_datasets.py` contains the steps that load the dataset itself.

- [ ] prepare datasets for training and validation in the following format (stored as a JSON Lines file, one object per line?)
```
{"src": "", "trg": ""}
```

- word embeddings (to be loaded?)
- use a corpus
## Training

Note that, in DiffuSeq, a model file is created to store all training progress, configuration etc. (in bash format pointing to raw files?)

- denoise rate?
- the updates in DiffuSeq-v2 [2] reduced the training time from 2 days to 11 hours

Set parameters

DiffuSeq [1]:
12 layers of Transformer with 12 attention heads, where the time step embedding is plugged akin to the position embedding.
The maximum sequence length is 128, with embedding dimension $d = 128$, diffusion steps $T = 2{,}000$
and a square-root noise schedule. To reduce the out-of-vocabulary generation, we apply Byte Pair
Encoding (Sennrich et al., 2016) to construct the vocabulary.

In [2]:
import numpy as np
import torch
from transformers import BertTokenizer
from datasets import Dataset
import json
import datasets
from torch.utils.data import Dataset, DataLoader, RandomSampler
import torch as th
from datasets import Dataset, DatasetDict
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# choose embedding dimension = 128
embedding_dim = 128

# hidden size of time embedding
hidden_dim = 128

# Maximum length of the merged sequence: merge_and_mask truncates source and
# target so that their combined length fits within this budget.
seq_len = 128

# TODO good value for this
output_dims = 128

# NOTE: the DiffuSeq configuration quoted in the notes above uses T = 2,000
# diffusion steps; 500 is the smaller value used in this notebook.
num_diffusion_timesteps = 500

# TODO figure out what are the right params to recreate diffuSeq
batch_size = 10
# learning rate (an earlier `lr=1e-04` in this cell was dead code: it was
# overwritten by this assignment before anything read it)
lr = 0.001
ema_rate = 0.999
weight_decay = 0.01
learning_steps = 1000


# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print("GPU is available" if is_cuda else "GPU not available, CPU used")
    

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size
print(vocab_size)

GPU is available
30522


Get tokenizer from BERT <br>
Embedding function $EMB(w)$ to map the discrete text $w$ into a continuous space. <br>
Load the sample text data from file

In [5]:
model_emb = torch.nn.Embedding(tokenizer.vocab_size, embedding_dim)

# initialize random embeddings
torch.nn.init.normal_(model_emb.weight)

model_emb

Embedding(30522, 128)

In [6]:
# read in the data in training data json file 
# TODO do this in a different way 
# TODO load actual dataset from Amazon

# data_dir = "./datasets/sample"
# path = f'{data_dir}/train.jsonl'

# sentence_lst = {'src':[], 'trg': []}
# with open(path, 'r') as f_reader:
#         for row in f_reader:
#             content = json.loads(row)
#             sentence_lst['src'].append(content['src'].strip())
#             sentence_lst['trg'].append(content['trg'].strip())

# TODO use pandas to load faster? any other package can just load json directly rather than row by row

data_dir = "./datasets"
train_path = f'{data_dir}/train_full.jsonl'
valid_path = f'{data_dir}/valid_full.jsonl'
test_path = f'{data_dir}/test_full.jsonl'

def load_data(path, limit=None):
    """Read a JSON Lines file of ``{"src": ..., "trg": ...}`` records.

    :param path: path of the ``.jsonl`` file; one JSON object per line.
    :param limit: maximum number of rows to read. ``None`` reads the whole
                  file; ``0`` reads nothing and returns empty lists.
    :return: a dict ``{'src': [...], 'trg': [...]}`` holding the
             whitespace-stripped strings in file order.
    """
    sentence_lst = {'src': [], 'trg': []}
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f_reader:
        for i, row in enumerate(f_reader):
            # Compare against None: a plain truthiness test treated limit=0
            # as "no limit" and loaded the whole file.
            if limit is not None and i >= limit:
                break
            content = json.loads(row)
            sentence_lst['src'].append(content['src'].strip())
            sentence_lst['trg'].append(content['trg'].strip())
    return sentence_lst

# Load datasets with size restriction
train_limit = None  # Limit the size of the training set
valid_limit = None   # Limit the size of the validation set
test_limit = None    # Limit the size of the test set

train_data = load_data(train_path, limit=train_limit)
valid_data = load_data(valid_path, limit=valid_limit)
test_data = load_data(test_path, limit=test_limit)

print("Training Data Samples:", len(train_data['src']))
print("Validation Data Samples:", len(valid_data['src']))
print("Test Data Samples:", len(test_data['src']))


Training Data Samples: 3382137
Validation Data Samples: 2048
Test Data Samples: 10000


In [7]:
train_data['src'][0]
raw_datasets = Dataset.from_dict(train_data)
raw_datasets
raw_datasets[0]

{'src': 'jesus , what kind of concerts do you go to where people sucker punch you for being born tall ?',
 'trg': 'the kind that allow bitter short people in . so basically all of them .'}

In [None]:
# Tokenize dataset
# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size
print("Vocabulary Size:", vocab_size)

def tokenize_function(examples, tokenizer):
    """Tokenize the source and target columns of a batch.

    :param examples: mapping with ``'src'`` and ``'trg'`` entries.
    :param tokenizer: callable invoked as ``tokenizer(texts, add_special_tokens=True)``
                      whose result exposes ``['input_ids']``.
    :return: ``{'input_id_x': <src ids>, 'input_id_y': <trg ids>}``.
    """
    column_map = (('input_id_x', 'src'), ('input_id_y', 'trg'))
    return {
        out_name: tokenizer(examples[text_name], add_special_tokens=True)['input_ids']
        for out_name, text_name in column_map
    }

# Bind the tokenizer up front so `Dataset.map` only has to supply the batch.
tokenize_function_with_tokenizer = partial(tokenize_function, tokenizer=tokenizer)

# Wrap the raw {'src': [...], 'trg': [...]} dicts as datasets
train_dataset = Dataset.from_dict(train_data)
valid_dataset = Dataset.from_dict(valid_data)
test_dataset = Dataset.from_dict(test_data)

# Options shared by all three splits; only the progress-bar label differs.
_tokenize_map_options = dict(
    batched=True,
    num_proc=4,
    remove_columns=['src', 'trg'],
    load_from_cache_file=True,
)

tokenized_train_dataset = train_dataset.map(
    tokenize_function_with_tokenizer,
    desc="Tokenizing training dataset",
    **_tokenize_map_options,
)
tokenized_valid_dataset = valid_dataset.map(
    tokenize_function_with_tokenizer,
    desc="Tokenizing validation dataset",
    **_tokenize_map_options,
)
tokenized_test_dataset = test_dataset.map(
    tokenize_function_with_tokenizer,
    desc="Tokenizing test dataset",
    **_tokenize_map_options,
)

# Group the tokenized splits under their conventional names
tokenized_datasets = DatasetDict({
    'train': tokenized_train_dataset,
    'validation': tokenized_valid_dataset,
    'test': tokenized_test_dataset
})

print("Tokenization complete.")
for _label, _split in (
    ("Training Set:", 'train'),
    ("Validation Set:", 'validation'),
    ("Test Set:", 'test'),
):
    print(_label, len(tokenized_datasets[_split]))

# tokenized_datasets
# len(tokenized_datasets['train']["input_id_x"][0])

Vocabulary Size: 30522


Tokenizing training dataset (num_proc=4):  30%|██████▎              | 1016000/3382137 [02:58<06:07, 6442.28 examples/s]

In [None]:
# Function to merge and mask sequences
def merge_and_mask(group_lst):
    """Join each source/target pair into one id sequence and build its mask.

    For every row the last id of the source is taken as the end token, both
    sequences lose their final id, and the longer one is trimmed (both when
    they are equally long) until their combined length is at most
    ``seq_len - 3``. The end token is then re-appended to each and they are
    joined as ``src + [sep] + trg``.

    :param group_lst: batch with ``'input_id_x'`` and ``'input_id_y'`` lists.
    :return: the same batch with ``'input_ids'`` (merged sequences) and
             ``'input_mask'`` (zeros covering the source part plus the
             separator) added.
    """
    merged_ids = []
    source_masks = []
    budget = seq_len - 3
    for row, ids_x in enumerate(group_lst['input_id_x']):
        ids_y = group_lst['input_id_y'][row]
        end_token = ids_x[-1]
        src = ids_x[:-1]
        trg = ids_y[:-1]
        while len(src) + len(trg) > budget:
            # Trim whichever side is longer; on a tie trim both.
            surplus = len(src) - len(trg)
            if surplus >= 0:
                src.pop()
            if surplus <= 0:
                trg.pop()
        src.append(end_token)
        trg.append(end_token)

        merged_ids.append(src + [tokenizer.sep_token_id] + trg)
        source_masks.append([0] * (len(src) + 1))
    group_lst['input_ids'] = merged_ids
    group_lst['input_mask'] = source_masks
    return group_lst

# Function to pad sequences
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result

def pad_function(group_lst):
    """Pad ``input_ids`` and ``input_mask`` of a batch to ``seq_len``.

    ``input_ids`` is filled with the tokenizer's pad id and ``input_mask``
    with 1.

    :param group_lst: batch holding ``'input_ids'`` and ``'input_mask'``.
    :return: the same batch with both columns replaced by padded rows.
    """
    fill_values = (('input_ids', tokenizer.pad_token_id), ('input_mask', 1))
    for column, fill in fill_values:
        group_lst[column] = _collate_batch_helper(group_lst[column], fill, seq_len)
    return group_lst

# Apply merge and mask to the tokenized datasets
tokenized_datasets = DatasetDict({
    'train': tokenized_train_dataset,
    'validation': tokenized_valid_dataset,
    'test': tokenized_test_dataset
})

tokenized_datasets = tokenized_datasets.map(
    merge_and_mask,
    batched=True,
    num_proc=1,
    desc="Merging and masking"
)

# Apply padding to the datasets
lm_datasets = tokenized_datasets.map(
    pad_function,
    batched=True,
    num_proc=1,
    desc="Padding"
)

print("Merging, masking, and padding complete.")
print("Training Set:", len(lm_datasets['train']))
print("Validation Set:", len(lm_datasets['validation']))
print("Test Set:", len(lm_datasets['test']))

print(lm_datasets, 'padded dataset')

In [None]:
# Combine into DatasetDict
raw_datasets = datasets.DatasetDict()
raw_datasets['train'] = lm_datasets['train']
raw_datasets['validation'] = lm_datasets['validation']
raw_datasets['test'] = lm_datasets['test']

class TextDataset(torch.utils.data.Dataset):
    """Map-style torch dataset over one split of the padded token data.

    The base class is written out as ``torch.utils.data.Dataset`` on purpose:
    in this notebook the bare name ``Dataset`` is last imported from the
    HuggingFace ``datasets`` package, whose constructor takes an Arrow table,
    so ``class TextDataset(Dataset)`` with a bare ``super().__init__()`` does
    not build a torch dataset.

    :param text_datasets: mapping from split name to an indexable collection
                          of rows with ``'input_ids'`` and ``'input_mask'``.
    :param split: key of the split to expose.
    :param model_emb: callable applied to the ``input_ids`` tensor of a row.
    """

    def __init__(self, text_datasets, split, model_emb=None):
        super().__init__()
        self.text_datasets = text_datasets[split]
        self.length = len(self.text_datasets)
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(embedded_ids, {'input_ids': ..., 'input_mask': ...})``.

        The first element is a float32 NumPy array; the dict values are NumPy
        arrays built from the row's lists.
        """
        # Fetch the row once instead of re-indexing the dataset three times.
        row = self.text_datasets[idx]
        input_ids = row['input_ids']
        with torch.no_grad():
            hidden_state = self.model_emb(torch.tensor(input_ids))

        # Bring tensors to host memory first; np.array cannot read a tensor
        # that lives on another device.
        if isinstance(hidden_state, torch.Tensor):
            hidden_state = hidden_state.detach().cpu().numpy()
        arr = np.array(hidden_state, dtype=np.float32)

        out_kwargs = {
            'input_ids': np.array(input_ids),
            'input_mask': np.array(row['input_mask']),
        }
        return arr, out_kwargs

# Define model embedding
model_emb = lambda x: x  # Placeholder: Replace with actual model embedding function

# Create datasets for training, validation, and test sets
train_dataset = TextDataset(raw_datasets, 'train', model_emb=model_emb)
valid_dataset = TextDataset(raw_datasets, 'validation', model_emb=model_emb)
test_dataset = TextDataset(raw_datasets, 'test', model_emb=model_emb)

# Create data loaders with RandomSampler
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=RandomSampler(train_dataset))
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=RandomSampler(valid_dataset))
test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=RandomSampler(test_dataset))

def infinite_data_loader(data_loader, device):
    """Cycle over ``data_loader`` forever, moving every item to ``device``.

    :param data_loader: re-iterable yielding ``(batch, cond)`` pairs, where
                        ``cond`` is a dict of objects with a ``.to`` method.
    :param device: target passed to every ``.to`` call.
    :return: a generator of ``(batch, cond)`` pairs that never terminates.
    """
    def _on_device(pair):
        batch, cond = pair
        moved_batch = batch.to(device)
        return moved_batch, {name: value.to(device) for name, value in cond.items()}

    while True:
        # One full pass over the loader, then start over.
        yield from map(_on_device, data_loader)
            
# Convert the data loaders to infinite data loaders           
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = infinite_data_loader(train_loader, device)  # Ensure train_loader returns batches on the correct device
valid_loader = infinite_data_loader(valid_loader, device)  # Ensure valid_loader returns batches on the correct device

# Sample data from the infinite data loaders
train_data_iter = iter(train_loader)
valid_data_iter = iter(valid_loader)
test_data_iter = iter(test_loader)  # Test data is usually not infinite

print("Sample from train dataset:", next(train_data_iter))
print("Sample from validation dataset:", next(valid_data_iter))
print("Sample from test dataset:", next(test_data_iter))


In [None]:
import math 
import torch as th

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are
    sines of ``timestep * frequency``; an odd ``dim`` gets one trailing zero
    column.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponents = th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
    freqs = th.exp(-math.log(max_period) * exponents / half)
    angles = timesteps[:, None].float() * freqs[None]
    columns = [th.cos(angles), th.sin(angles)]
    if dim % 2:
        # Odd width: pad with a single zero column.
        columns.append(th.zeros_like(columns[0][:, :1]))
    return th.cat(columns, dim=-1)


In [None]:
# FIXME need to define TransformerNetModel
#  The full Transformer model with attention and timestep embedding.

# Adapted from diffuSeq

# TODO code the transformer from scratch
# TODO design the transformer from scratch

from transformers import BertConfig, BertModel
import torch.nn as nn
import torch as th

class TransformerNetModel(nn.Module):
    """
    Denoising network built around the encoder of a pretrained
    'bert-base-uncased' model.

    ``forward`` adds BERT position embeddings and a projected sinusoidal
    timestep embedding to the input, applies BERT's embedding LayerNorm and
    runs the BERT encoder. Word embeddings and ``lm_head`` share one weight.

    NOTE(review): ``input_up_proj`` and ``output_down_proj`` are created in
    ``__init__`` but never referenced in ``forward``.
    """

    def __init__(self, vocab_size, input_dims, hidden_t_dim, output_dims):
        """
        :param vocab_size: size of the initial embedding table and of the
                           ``lm_head`` output.
        :param input_dims: width of the initial embedding table / ``lm_head`` input.
        :param hidden_t_dim: width of the sinusoidal timestep embedding fed to ``time_embed``.
        :param output_dims: output width of ``output_down_proj``.
        """
        super().__init__()

        config = BertConfig.from_pretrained('bert-base-uncased')
        # FIXME set to an actual value
        # With 0 here, self.dropout below is constructed with p=0.
        config.hidden_dropout_prob = 0

        self.input_dims = input_dims
        self.hidden_t_dim = hidden_t_dim
        self.output_dims = output_dims
        # self.dropout = dropout
        self.hidden_size = config.hidden_size
#         print("transformer self.hidden_size", self.hidden_size)

        # Freshly initialised embedding table; replaced further down by the
        # word embeddings of the pretrained BERT model.
        self.word_embedding = nn.Embedding(vocab_size, self.input_dims)
        # Generate logits for hidden representation
        self.lm_head = nn.Linear(self.input_dims, vocab_size)
        # Weight tying: lm_head and word_embedding reference the same Parameter.
        with th.no_grad():
            self.lm_head.weight = self.word_embedding.weight

        # Time embeddings: hidden_t_dim -> 4 * hidden_t_dim -> config.hidden_size
        time_embed_dim = hidden_t_dim * 4
        self.time_embed = nn.Sequential(
            # params as input features * output features
            nn.Linear(hidden_t_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, config.hidden_size),
        )

        # Function to deal with having a hidden size not equal to input size, project to hidden size (?)
        if self.input_dims != config.hidden_size:
            # NOTE input_dims = 128
            # hidden_size = 768
            # self.input_up_proj = nn.Sequential(nn.Linear(input_dims, config.hidden_size),
            #                                 nn.Tanh(), nn.Linear(config.hidden_size, config.hidden_size))
            # FIXME this is trying to convert to hidden size 768 why??? it's already in the hidden size
            # FIXME this actually doesn't seem necessary to do????
            # NOTE(review): as written this maps config.hidden_size -> input_dims
            # (the reverse of the commented-out version above) and is not
            # referenced in forward().
            self.input_up_proj = nn.Sequential(nn.Linear(config.hidden_size, input_dims),
                                             nn.Tanh(), nn.Linear(input_dims, input_dims))

        # FIXME why is this temporary
        # temp_bert is only a source of pretrained sub-modules: its word and
        # position embeddings, embedding LayerNorm and encoder are re-attached
        # to this module below.
        temp_bert = BertModel.from_pretrained('bert-base-uncased', config=config)
        self.word_embedding = temp_bert.embeddings.word_embeddings

        # FIXME why do we do this 2 times????
        # word_embedding was just replaced by BERT's table, so the tie made
        # above pointed at the old weight; re-bind lm_head to the new one.
        with th.no_grad():
            self.lm_head.weight = self.word_embedding.weight
        # self.lm_head.weight.requires_grad = False
        # self.word_embedding.weight.requires_grad = False

        # TODO explain what is happening
        # The pretrained BERT encoder stack is the denoising backbone used in forward().
        self.input_transformers = temp_bert.encoder
        # TODO explain what is doing
        # Buffer of position indices 0 .. max_position_embeddings - 1 with a
        # leading batch axis of 1; forward() slices it to the sequence length.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embeddings = temp_bert.embeddings.position_embeddings
        self.LayerNorm = temp_bert.embeddings.LayerNorm

        # Drop temp_bert's own references; the sub-modules kept above stay
        # attached to this module.
        del temp_bert.embeddings
        del temp_bert.pooler

        # FIXME When does this get used
        # Applied in forward() right after the LayerNorm.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # FIXME what is happening here
        # Projection config.hidden_size -> output_dims, created only when the
        # two differ. NOTE(review): not referenced in forward().
        if self.output_dims != config.hidden_size:
            self.output_down_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                                nn.Tanh(), nn.Linear(config.hidden_size, self.output_dims))

    def get_embeds(self, input_ids):
        """Look up ``input_ids`` in the (BERT) word-embedding table."""
        return self.word_embedding(input_ids)

    def get_logits(self, hidden_repr):
        """Map hidden representations to vocabulary logits through ``lm_head``."""
        # NOTE we make a simplifying assumption get the logits from linear layer
        return self.lm_head(hidden_repr)


    # FIXME what is the difference between BertModel, BertConfig, BertTokenizer, maybe define it all in one place config for tokenizer + embeddings?


    def forward(self, x, timesteps):
        """
        Run the encoder on ``x`` conditioned on ``timesteps``.

        :param x: input embeddings; added directly to the position embeddings,
                  so presumably (batch, seq, config.hidden_size) — TODO confirm.
        :param timesteps: one timestep per batch element.
        :return: the encoder's last hidden state, cast to ``x.dtype``.
        """
        emb_t = self.time_embed(timestep_embedding(timesteps.to(self.time_embed[0].weight.device), self.hidden_t_dim))  # Ensure timesteps are on the same device
        emb_x = x
        seq_length = x.size(1)
        position_ids = self.position_ids[:, :seq_length].to(x.device)  # Ensure position_ids are on the same device
        # The per-example time embedding is repeated along the sequence axis
        # before being summed with position and input embeddings.
        emb_inputs = self.position_embeddings(position_ids) + emb_x + emb_t.unsqueeze(1).expand(-1, seq_length, -1)
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))
        input_trans_hidden_states = self.input_transformers(emb_inputs).last_hidden_state
        h = input_trans_hidden_states
        h = h.type(x.dtype)
        return h

In [None]:
# TODO explore other simplistic sample code
# https://github.com/lucidrains/denoising-diffusion-pytorch
# https://e-dorigatti.github.io/math/deep%20learning/2023/06/25/diffusion.html
# https://github.com/tanelp/tiny-diffusion
# NOTE adapted from diffuSeq, which is adapted from https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

class GaussianDiffusion():
    def __init__(self, betas, predict_xstart=True):
        self.predict_xstart = predict_xstart
        self.rescale_timesteps = True

        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    # NOTE the below comments from diffuSeq
    # self.mapping_func = None # implement in train main()
    # self.add_mask_noise = False # TODO

    # FIXME copied directly from diffuSeq
    def training_losses(self, model, *args, **kwargs):
        self.model = model
        return self.training_losses_seq2seq(model, *args, **kwargs)
    
    def _predict_xstart_from_eps(self, x_t, t, eps):
        """
        Recover an x_0 estimate from a noisy sample and a noise estimate:
        ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * eps``.

        :param x_t: the noisy sample.
        :param t: timestep indices used to look up the coefficients.
        :param eps: noise estimate with the same shape as ``x_t``.
        :return: the x_0 estimate.
        """
        assert x_t.shape == eps.shape
        shape, device = x_t.shape, x_t.device
        # _extract_into_tensor is defined outside this class; presumably it
        # gathers the per-timestep entries and broadcasts them to `shape`.
        scale_xt = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape, device)
        scale_eps = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape, device)
        return scale_xt * x_t - scale_eps * eps

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """
        Recover the noise estimate implied by a noisy sample and an x_0
        estimate (the inverse of ``_predict_xstart_from_eps``).

        :param x_t: the noisy sample.
        :param t: timestep indices used to look up the coefficients.
        :param pred_xstart: the x_0 estimate.
        :return: the implied noise estimate.
        """
        shape, device = x_t.shape, x_t.device
        scale_xt = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, shape, device)
        residual = scale_xt * x_t - pred_xstart
        scale_eps = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, shape, device)
        return residual / scale_eps

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t
    
    def q_mean_variance(self, x_start, t):
        """
        Return the parameters of the forward marginal q(x_t | x_0).

        :param x_start: the tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), each looked up for
                 ``x_start``'s shape.
        """
        shape, device = x_start.shape, x_start.device
        signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, shape, device)
        mean = signal_scale * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, shape, device)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, shape, device)
        return mean, variance, log_variance


    def q_sample(self, x_start, t, noise=None, mask=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise; must have the
                      same shape as ``x_start``.
        :param mask: optional anchoring mask, one entry per position (without
                     the trailing feature axis). Positions where it equals 0
                     keep their ``x_start`` value. ``None`` noises every position.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        device = x_start.device
        x_t = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape, device) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape, device) * noise
        )

        # `mask` defaults to None; previously the unconditional
        # `mask.unsqueeze` raised AttributeError for unmasked diffusion.
        if mask is None:
            return x_t

        # Add a trailing axis so the per-position mask broadcasts over features.
        mask = torch.broadcast_to(mask.unsqueeze(dim=-1), x_start.shape).to(device)
        return torch.where(mask == 0, x_start, x_t)

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)

        :param x_start: x_0 (or an estimate of it); same shape as ``x_t``.
        :param x_t: the noisy sample at step ``t``.
        :param t: timestep indices used to look up the coefficients.
        :return: a tuple (posterior_mean, posterior_variance,
                 posterior_log_variance_clipped).
        """
        device = x_start.device
        assert x_start.shape == x_t.shape
        shape = x_t.shape

        def lookup(table):
            # Per-timestep coefficients for this batch.
            return _extract_into_tensor(table, t, shape, device)

        weight_x0 = lookup(self.posterior_mean_coef1)
        weight_xt = lookup(self.posterior_mean_coef2)
        posterior_mean = weight_x0 * x_start + weight_xt * x_t
        posterior_variance = lookup(self.posterior_variance)
        posterior_log_variance_clipped = lookup(self.posterior_log_variance_clipped)

        batch = x_start.shape[0]
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == batch
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_sample(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
        top_p=None, mask=None, x_start=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param mask: anchoring masked position to x_start
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        device = x.device
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if top_p is not None and top_p > 0:
            # print('top_p sampling')
            noise = th.randn_like(x)
            replace_mask = th.abs(noise) > top_p
            while replace_mask.any():
                noise[replace_mask] = th.randn_like(noise[replace_mask])
                replace_mask = th.abs(noise) > top_p
            assert (th.abs(noise) <= top_p).all()

        else:
            noise = th.randn_like(x)

        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        if mask is None:
            pass
        else:
            sample = th.where(mask==0, x_start, sample)

        return {
            "sample": sample, 
            "pred_xstart": out["pred_xstart"],
            "greedy_mean": out["mean"], 
            "out": out
        }

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Compute the mean and variance for the diffusion posterior at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'mean': the posterior mean.
                 - 'variance': the posterior variance.
                 - 'log_variance': the log of the posterior variance.
                 - 'pred_xstart': the predicted x_0.
        """
        device = x.device
        if model_kwargs is None:
            model_kwargs = {}
        model_output = model(x, t, **model_kwargs).to(device)
        if denoised_fn is not None:
            model_output = denoised_fn(model_output)
        pred_xstart = self._predict_xstart_from_eps(x, t, model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)

        posterior_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )
        return {
            "mean": posterior_mean,
            "variance": posterior_variance,
            "log_variance": posterior_log_variance,
            "pred_xstart": pred_xstart,
        }

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
        gap=1,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param mask: anchoring masked position to x_start
        :param clamp_step: in clamp_first mode, choose end clamp step, otherwise starting clamp step
        :param clamp_first: bool, clamp_first mode
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = []
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            top_p=top_p,
            clamp_step=clamp_step,
            clamp_first=clamp_first,
            mask=mask,
            x_start=x_start
        ):
            final.append(sample['sample'])
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
    ):
        """
        Run the reverse chain step by step and yield every intermediate result.

        The chain starts from ``noise`` when it is given, otherwise from fresh
        Gaussian noise of the requested ``shape``. Timesteps are visited from
        ``num_timesteps - 1`` down to 0 and ``p_sample()`` is called once per
        step; its ``"sample"`` entry becomes the input of the next step.

        ``denoised_fn`` is handed to ``p_sample()`` selectively:
        - ``clamp_step`` is None: on every step;
        - ``clamp_first`` falsy: only on steps with ``i <= clamp_step``;
        - ``clamp_first`` truthy: only on steps with ``i >= clamp_step``.

        :param model: the model; its first parameter supplies the device when
                      ``device`` is None.
        :param shape: tuple/list; ``shape[0]`` is used as the batch size of
                      the timestep tensor.
        :param progress: wrap the timestep list in a tqdm progress bar.
        :param top_p, mask, x_start, clip_denoised, model_kwargs: forwarded
                      unchanged to ``p_sample()``.
        :return: a generator over the dicts returned by ``p_sample()``.
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))

        # A caller-supplied tensor replaces the random starting point.
        sample_x = th.randn(*shape, device=device) if noise is None else noise

        timesteps = list(range(self.num_timesteps))[::-1]
        if progress:
            # Imported lazily so tqdm is only required when requested.
            from tqdm.auto import tqdm
            timesteps = tqdm(timesteps)

        for i in timesteps:  # from T - 1 down to 0
            t = th.tensor([i] * shape[0], device=device)

            if clamp_step is None:
                denoised_fn_cur = denoised_fn
            elif clamp_first:
                denoised_fn_cur = denoised_fn if i >= clamp_step else None
            else:
                denoised_fn_cur = None if i > clamp_step else denoised_fn

            # The yield stays inside no_grad, exactly as the step itself.
            with th.no_grad():
                out = self.p_sample(
                    model,
                    sample_x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn_cur,
                    model_kwargs=model_kwargs,
                    top_p=top_p,
                    mask=mask,
                    x_start=x_start
                )
                yield out
                sample_x = out["sample"]

    def _get_x_start(self, x_start_mean, std):
        '''
        Draw x_0 around the word embeddings: x_0 = Emb(w) + std * eps,
        with eps sampled from a standard normal.

        :param x_start_mean: word-embedding tensor used as the mean
        :param std: scale of the Gaussian perturbation (must broadcast
                    against x_start_mean)
        :return: the perturbed embeddings x_0
        '''
        eps = th.randn_like(x_start_mean)
        assert eps.shape == x_start_mean.shape
        perturbation = std * eps
        return x_start_mean + perturbation

    def _token_discrete_loss(self, x_t, get_logits, input_ids, mask=None, truncate=False, t=None):
        '''
        Rounding loss -log p(w | x): token-level cross-entropy between the
        logits decoded from ``x_t`` and the reference ids, reduced to one
        value per sequence.

        :param x_t: embeddings handed to ``get_logits``
        :param get_logits: callable mapping embeddings to vocabulary logits
                           (vocabulary on the last dimension)
        :param input_ids: reference token ids, same leading shape as the
                          logits; cast to long before use
        :param mask: optional per-token weights shaped like ``input_ids``;
                     when given, the loss is the masked sum divided by the
                     mask sum of each row
        :param truncate: unused, kept so existing call sites keep working
        :param t: unused, kept so existing call sites keep working
        :return: tensor with the last (token) dimension reduced away
        '''
        logits = get_logits(x_t)  # bsz, seqlen, vocab

        # CrossEntropyLoss needs integer class indices.
        input_ids = input_ids.long()

        loss_fct = th.nn.CrossEntropyLoss(reduction='none')
        decoder_nll = loss_fct(
            logits.view(-1, logits.size(-1)), input_ids.view(-1)
        ).view(input_ids.shape)

        if mask is None:
            return decoder_nll.mean(dim=-1)

        token_count = mask.sum(dim=-1)
        # A row with no selected token would divide 0 by 0 and poison the
        # batch loss with NaN; report 0 for such rows instead.
        safe_count = th.where(token_count == 0, th.ones_like(token_count), token_count)
        return (decoder_nll * mask).sum(dim=-1) / safe_count

    def _x0_helper(self, model_output, x, t):
        """
        Turn a model output into an x_0 estimate plus the matching previous-step mean.

        ``model_output`` is treated as predicted noise: it is converted with
        ``_predict_xstart_from_eps`` and the resulting x_0 estimate is passed
        to ``q_posterior_mean_variance``, whose first return value is kept.

        NOTE(review): training_losses_seq2seq regresses the model output
        directly onto x_start, which suggests the network predicts x_0 rather
        than epsilon; confirm which parameterisation is intended here.

        :param model_output: raw network output for the noisy input ``x``
        :param x: the noisy input x_t
        :param t: timestep tensor for ``x``
        :return: dict with keys 'pred_xprev' and 'pred_xstart'
        """
        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        pred_prev, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        return {'pred_xprev': pred_prev, 'pred_xstart': pred_xstart}

    def training_losses_seq2seq(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute the training loss terms for one batch.

        The token ids are embedded with ``model.get_embeds``, perturbed into
        x_0, noised to x_t with ``q_sample`` and passed through the model.

        Returned terms (one value per batch element):
        - "mse":  squared error between x_0 and the model output; for
                  elements with ``t == 0`` it is replaced by the squared error
                  between the clean embeddings and the x_0 estimate from
                  ``_x0_helper``.
        - "nll":  masked rounding loss on that x_0 estimate. It is reported
                  only and is NOT part of "loss".
        - "loss": "mse" + unmasked rounding loss on x_0 + mean squared value
                  of the ``q_mean_variance`` mean at the last timestep.

        NOTE(review): "mse" regresses the model output onto x_0, while
        ``_x0_helper`` converts the same output with
        ``_predict_xstart_from_eps``; confirm the intended parameterisation.

        :param model: network exposing ``get_embeds``, ``get_logits`` and
                      ``__call__(x_t, timesteps, **kwargs)``.
        :param x_start: only its ``.device`` is used; the actual x_0 is
                        rebuilt from the token ids.
        :param t: tensor of timestep indices, one per batch element.
        :param model_kwargs: dict that must contain 'input_ids' and
                             'input_mask'; the remaining entries are forwarded
                             to the model. The dict itself is not modified.
        :param noise: optional noise for ``q_sample``; sampled when None.
        :return: dict with keys "mse", "nll" and "loss".
        :raises ValueError: if ``model_kwargs`` is None.
        """
        if model_kwargs is None:
            raise ValueError("model_kwargs must provide 'input_ids' and 'input_mask'")

        device = x_start.device

        # Shallow copy: popping from the caller's dict would silently strip
        # the keys for any later use of the same batch.
        model_kwargs = dict(model_kwargs)
        input_ids_x = model_kwargs.pop('input_ids').long().to(device)
        input_ids_mask = model_kwargs.pop('input_mask').to(device)
        x_start_mean = model.get_embeds(input_ids_x).to(device)

        # Noise level of the first diffusion step sets the spread of x_0.
        std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, torch.tensor([0]).to(device), x_start_mean.shape, device)
        x_start = self._get_x_start(x_start_mean, std).to(device)

        if noise is None:
            noise = torch.randn_like(x_start)

        x_t = self.q_sample(x_start, t, noise=noise, mask=input_ids_mask).to(device)

        get_logits = model.get_logits

        terms = {}

        target = x_start
        model_output = model(x_t, self._scale_timesteps(t), **model_kwargs).to(device)
        mse = mean_flat((target - model_output) ** 2)

        model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart'].to(device)
        # t may live on another device than the losses, so move the mask.
        t0_mask = (t == 0).to(device)
        t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
        terms["mse"] = torch.where(t0_mask, t0_loss, mse)

        out_mean, _, _ = self.q_mean_variance(x_start, torch.LongTensor([self.num_timesteps - 1]).to(device))
        tT_loss = mean_flat(out_mean ** 2).to(device)

        decoder_nll = self._token_discrete_loss(x_start, get_logits, input_ids_x).to(device)
        terms["nll"] = self._token_discrete_loss(model_out_x_start, get_logits, input_ids_x, mask=input_ids_mask, truncate=True, t=t).to(device)

        terms["loss"] = terms["mse"] + decoder_nll + tT_loss

        return terms


    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
        langevin_fn=None,
        mask=None,
        x_start=None
    ):
        """
        Take one DDIM step from x_t to the previous timestep.

        :param model: the model evaluated through ``p_mean_variance``.
        :param x: the current noisy tensor x_t.
        :param t: tensor of timestep indices, one per batch element.
        :param eta: scales the step's noise term; with 0.0 sigma is zero.
        :param langevin_fn: optional callable applied to the sample as
                            ``langevin_fn(sample, mean_pred, sigma,
                            alphas_cumprod_prev[t[0]], t, x)``.
        :param mask: when given, positions where ``mask == 0`` are
                     overwritten with the values of ``x_start``.
        :param x_start: source of the overwritten positions (used only
                        together with ``mask``).
        :return: dict with "sample" and "pred_xstart".
        """
        device = x.device
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        pred_xstart = out["pred_xstart"]

        # Epsilon is re-derived from the x_0 estimate so the step works
        # whichever quantity the model predicts.
        eps = self._predict_eps_from_xstart(x, t, pred_xstart)
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape, device)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape, device)

        variance_ratio = th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        sigma = eta * variance_ratio * th.sqrt(1 - alpha_bar / alpha_bar_prev)

        # Equation 12.
        noise = th.randn_like(x)
        direction = th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        mean_pred = pred_xstart * th.sqrt(alpha_bar_prev) + direction

        # no noise when t == 0
        trailing_dims = [1] * (len(x.shape) - 1)
        nonzero_mask = (t != 0).float().view(-1, *trailing_dims)
        sample = mean_pred + nonzero_mask * sigma * noise

        if langevin_fn:
            sample = langevin_fn(sample, mean_pred, sigma, self.alphas_cumprod_prev[t[0]], t, x)

        if mask is not None:
            sample = th.where(mask==0, x_start, sample)

        return {"sample": sample, "pred_xstart": pred_xstart}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Take one step of the DDIM reverse ODE, from x_t to x_{t+1}.

        Only the deterministic path is supported, so ``eta`` must be 0.0.

        :param model: the model evaluated through ``p_mean_variance``.
        :param x: the current tensor x_t.
        :param t: tensor of timestep indices, one per batch element.
        :return: dict with "sample" (the x_{t+1} mean) and "pred_xstart".
        """
        device = x.device
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        pred_xstart = out["pred_xstart"]

        # Epsilon is re-derived from the x_0 estimate so the step works
        # whichever quantity the model predicts.
        sqrt_recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape, device)
        numerator = sqrt_recip * x - pred_xstart
        sqrt_recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape, device)
        eps = numerator / sqrt_recipm1

        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape, device)

        # Equation 12. reversed
        signal_part = pred_xstart * th.sqrt(alpha_bar_next)
        noise_part = th.sqrt(1 - alpha_bar_next) * eps
        mean_pred = signal_part + noise_part

        return {"sample": mean_pred, "pred_xstart": pred_xstart}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
        gap=1,
    ):
        """
        Run the full DDIM chain and collect the sample of every visited step.

        Thin wrapper around ``ddim_sample_loop_progressive()`` that keeps only
        the ``"sample"`` entry of each yielded dict.

        :param gap: stride over the timesteps (every ``gap``-th step is run).
        :param top_p, clamp_step, clamp_first: accepted for signature parity
               with p_sample_loop() but not forwarded to the DDIM loop.
        :return: list of samples, in the order the steps were executed.
        """
        steps = self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            mask=mask,
            x_start=x_start,
            gap=gap,
        )
        return [step['sample'] for step in steps]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        langevin_fn=None,
        mask=None,
        x_start=None,
        gap=1
    ):
        """
        Use DDIM to sample from the model and yield the result of every
        visited timestep.

        The chain starts from ``noise`` when given, otherwise from fresh
        Gaussian noise of ``shape``, and visits the timesteps
        ``num_timesteps - 1, num_timesteps - 1 - gap, ...`` in that order.

        :param model: the model; its first parameter supplies the device when
                      ``device`` is None.
        :param shape: tuple/list; ``shape[0]`` is the batch size of the
                      timestep tensor.
        :param progress: wrap the timestep list in a tqdm progress bar.
        :param eta: forwarded to ``ddim_sample()`` (0.0 keeps the sampler
                    deterministic).
        :param langevin_fn: forwarded to ``ddim_sample()``.
        :param mask, x_start: forwarded to ``ddim_sample()``.
        :param gap: stride over the reversed timestep list.
        :return: a generator over the dicts returned by ``ddim_sample()``.
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            sample_x = noise
        else:
            sample_x = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1][::gap]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    sample_x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    # Previously dropped here, which made both arguments
                    # silently ineffective.
                    eta=eta,
                    langevin_fn=langevin_fn,
                    mask=mask,
                    x_start=x_start
                )
                yield out
                sample_x = out["sample"]

def _extract_into_tensor(arr, timesteps, broadcast_shape, device):
    """
    Look up per-timestep values in a 1-D numpy array and broadcast them.

    :param arr: the 1-D numpy array holding one value per timestep.
    :param timesteps: a tensor of indices into the array.
    :param broadcast_shape: target shape of K dimensions whose batch
                            dimension equals the length of timesteps.
    :param device: the device the lookup table is moved to before indexing.
    :return: a float tensor expanded to ``broadcast_shape``.
    """
    table = torch.from_numpy(arr).to(device=device)
    picked = table[timesteps].float()

    # Append singleton axes until the rank matches the target shape.
    missing_dims = len(broadcast_shape) - picked.dim()
    if missing_dims > 0:
        picked = picked.reshape(*picked.shape, *([1] * missing_dims))
    return picked.expand(broadcast_shape)

def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)


In [None]:
class UniformSampler:
    """
    Timestep sampler that gives every diffusion step the same weight.

    Timesteps are drawn with probability proportional to ``weights()`` and
    each draw comes with the importance weight ``1 / (N * p)``. Because the
    weights are all ones here, every timestep is equally likely and every
    importance weight equals 1.
    """

    def __init__(self, diffusion):
        # One weight per diffusion step, all equal.
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        """Return the (unnormalised) numpy weight of each timestep."""
        return self._weights

    def sample(self, batch_size, device):
        """
        Draw timesteps for a batch together with their loss weights.

        :param batch_size: how many timesteps to draw.
        :param device: currently unused; both returned tensors are created
                       on the CPU and callers move them where needed.
        :return: a tuple (timesteps, weights):
                 - timesteps: a long tensor of timestep indices.
                 - weights: a float tensor of weights to scale the losses.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        num_steps = len(probs)

        picked = np.random.choice(num_steps, size=(batch_size,), p=probs)
        importance = 1 / (num_steps * probs[picked])

        timesteps = th.from_numpy(picked).long()
        loss_weights = th.from_numpy(importance).float()
        return timesteps, loss_weights


In [None]:
# Helper functions for training loop
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter towards its source parameter, in place,
    following an exponential moving average:
    ``target = rate * target + (1 - rate) * source``.

    :param target_params: the sequence of parameters holding the average.
    :param source_params: the sequence of current parameters.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for averaged, current in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update the target.
        shadow = averaged.detach()
        shadow.mul_(rate)
        shadow.add_(current, alpha=1 - rate)

def zero_grad(model_params):
    """
    Reset the gradient of every parameter that has one.

    Gradients are detached from any graph and zeroed in place; parameters
    whose ``grad`` is None are left untouched.
    """
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        # Same sequence as torch.optim.Optimizer.zero_grad (see
        # https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group)
        grad.detach_()
        grad.zero_()

def log_loss_dict(diffusion, ts, losses):
    """
    Walk over the per-sample losses and bucket each one by timestep quartile.

    No logger is wired up, so the quartile is computed but nothing is
    recorded or returned; the function is a placeholder for real logging.

    :param diffusion: object exposing ``num_timesteps``.
    :param ts: tensor of timestep indices, one per sample.
    :param losses: dict mapping a loss name to a per-sample loss tensor.
    """
    for values in losses.values():
        pairs = zip(ts.cpu().numpy(), values.detach().cpu().numpy())
        for sub_t, sub_loss in pairs:
            # Quartile (0-3) of the diffusion process this timestep falls in;
            # a logger call keyed by loss name and quartile would go here.
            quartile = int(4 * sub_t / diffusion.num_timesteps)

In [None]:
import numpy as np
import os
import torch
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim import AdamW  # Make sure to import AdamW optimizer

# Define the noise schedule: a linear beta schedule whose end points are
# rescaled by 1000 / T so the total amount of noise stays comparable when the
# number of steps changes.
# NOTE(review): the DiffuSeq description quoted in this notebook uses a
# square-root schedule with T = 2000; confirm the linear schedule is intended.
num_diffusion_timesteps = 2000  # number of diffusion steps T
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

# Instantiate the diffusion model & transformer
diffusion = GaussianDiffusion(betas=betas)

# Select the device before the model is moved onto it, so this cell does not
# depend on `device` having been defined by an earlier cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify correct dimensions
embedding_dim = 128
hidden_dim = 128
output_dims = 128
model = TransformerNetModel(vocab_size=vocab_size, input_dims=embedding_dim, hidden_t_dim=hidden_dim, output_dims=output_dims).to(device)

pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f'### The parameter count is {pytorch_total_params}')

class TrainLoop():
    """
    Single-process training driver for the diffusion model.

    Each step runs a forward/backward pass on one batch, applies an AdamW
    update with optional linear learning-rate decay, and refreshes one
    exponential-moving-average copy of the parameters per EMA rate. When
    evaluation data is supplied, a no-grad loss is computed every
    ``eval_interval`` steps. ``run_loop()`` finally writes the model weights
    to ``checkpoints/trained_model.pth`` and plots the recorded losses.

    The class relies on the module-level ``device`` for tensor placement.

    :param model: the network to train.
    :param diffusion: object providing ``training_losses`` and
                      ``num_timesteps``.
    :param data: iterator yielding ``(batch, cond)`` pairs.
    :param batch_size: batch size; also used as the micro-batch size.
    :param lr: initial learning rate.
    :param ema_rate: a number, a comma-separated string of numbers, or an
                     iterable of numbers.
    :param weight_decay: AdamW weight decay.
    :param learning_steps: number of steps to run; 0 means "until the data
                           iterator is exhausted" and disables LR decay.
    :param eval_data: optional iterator yielding ``(batch, cond)`` pairs.
    :param eval_interval: evaluate every this many steps; values <= 0
                          disable evaluation.
    """

    def __init__(self, model, diffusion, data, batch_size, lr, ema_rate, weight_decay=0.0, learning_steps=0, eval_data=None, eval_interval=-1):
        self.model = model.to(device)
        self.ddp_model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = batch_size
        self.lr = lr
        self.ema_rate = self._parse_ema_rate(ema_rate)
        self.schedule_sampler = UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.learning_steps = learning_steps
        self.eval_data = eval_data
        self.eval_interval = eval_interval
        self.step = 0
        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        # One independent copy of the parameters per EMA rate.
        self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))]
        self.train_losses = []
        self.eval_losses = []

    @staticmethod
    def _parse_ema_rate(ema_rate):
        """Normalise ``ema_rate`` (number, "a,b" string or iterable) to a list of floats."""
        if isinstance(ema_rate, str):
            return [float(x) for x in ema_rate.split(",")]
        if isinstance(ema_rate, (int, float)):
            return [float(ema_rate)]
        return [float(x) for x in ema_rate]

    def _log_grad_norm(self):
        """Return the global L2 norm of the current gradients (no logger is attached)."""
        sqsum = 0.0
        for p in self.master_params:
            if p.grad is not None:
                sqsum += (p.grad ** 2).sum().item()
        return sqsum ** 0.5

    def _anneal_lr(self):
        """Linearly decay the learning rate towards 0 over ``learning_steps``."""
        if not self.learning_steps:
            return
        frac_done = self.step / self.learning_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def _should_evaluate(self):
        """Tell whether an evaluation pass is due at the current step."""
        # A non-positive interval means "never": 0 would otherwise raise
        # ZeroDivisionError and -1 (the default) would match every step.
        return (
            self.eval_data is not None
            and self.eval_interval > 0
            and self.step % self.eval_interval == 0
        )

    def optimize_normal(self):
        """Apply one optimiser update and refresh the EMA parameter copies."""
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def run_step(self, batch, cond):
        """Run one full training step on ``batch`` with conditioning ``cond``."""
        batch = batch.to(device)
        cond = {k: v.to(device) for k, v in cond.items()}
        self.forward_backward(batch, cond)
        self.optimize_normal()

    def forward_only(self, batch, cond):
        """
        Compute the loss on an evaluation batch without gradients and append
        its mean to ``eval_losses``.

        NOTE(review): the model is not switched to eval() here, so any
        train-time-only layers it may contain stay active; confirm intended.
        """
        with torch.no_grad():
            zero_grad(self.model_params)
            for i in range(0, batch.shape[0], self.microbatch):
                micro = batch[i: i + self.microbatch].to(device)
                micro_cond = {k: v[i: i + self.microbatch].to(device) for k, v in cond.items()}
                t, weights = self.schedule_sampler.sample(micro.shape[0], device)
                weights = weights.to(device)  # the sampler returns CPU tensors
                losses = self.diffusion.training_losses(self.ddp_model, micro, t, model_kwargs=micro_cond)
                log_loss_dict(self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()})
                self.eval_losses.append(losses['loss'].mean().item())

    def forward_backward(self, batch, cond):
        """
        Zero the gradients, then accumulate the weighted loss gradient of each
        micro-batch and record its loss value in ``train_losses``.
        """
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i: i + self.microbatch].to(device)
            micro_cond = {k: v[i: i + self.microbatch].to(device) for k, v in cond.items()}
            t, weights = self.schedule_sampler.sample(micro.shape[0], device)
            weights = weights.to(device)  # the sampler returns CPU tensors
            losses = self.diffusion.training_losses(self.ddp_model, micro, t, model_kwargs=micro_cond)
            loss = (losses["loss"] * weights).mean()
            log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})
            loss.backward()
            self.train_losses.append(loss.item())

    def _save_checkpoint(self):
        """Write the model weights to ``checkpoints/trained_model.pth``."""
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs('checkpoints', exist_ok=True)
        torch.save(self.model.state_dict(), 'checkpoints/trained_model.pth')
        print("Model saved successfully.")

    def _plot_losses(self):
        """Plot the recorded training (and, if any, validation) losses."""
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Training Loss')
        if self.eval_losses:
            plt.plot(self.eval_losses, label='Validation Loss')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss over Time')
        plt.show()

    def run_loop(self):
        """
        Train until ``learning_steps`` is reached or a data iterator is
        exhausted, then save the model and plot the loss curves.
        """
        try:
            # total=None gives an open-ended bar when no step limit is set.
            with tqdm(total=self.learning_steps or None, desc="Training Progress") as pbar:
                while not self.learning_steps or self.step < self.learning_steps:
                    batch, cond = next(self.data)
                    self.run_step(batch, cond)
                    if self._should_evaluate():
                        batch_eval, cond_eval = next(self.eval_data)
                        self.forward_only(batch_eval, cond_eval)
                    self.step += 1
                    pbar.update(1)
        except StopIteration:
            print("Data loader exhausted. Saving the model...")

        # Save the model after training is completed
        self._save_checkpoint()

        # Plotting the training and validation losses
        self._plot_losses()


# Training length and how often (in steps) a validation batch is evaluated
learning_steps = 2000
eval_interval = 1

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap both loaders with infinite_data_loader (defined elsewhere in the notebook).
# NOTE(review): the names are rebound in place, so re-running this cell wraps
# the already wrapped loaders again — confirm that is harmless.
train_loader = infinite_data_loader(train_loader, device)  # presumably yields batches on `device` — verify against its definition
valid_loader = infinite_data_loader(valid_loader, device)  # presumably yields batches on `device` — verify against its definition

# Build the training loop and run it to completion; batch_size, lr, ema_rate
# and weight_decay must already be defined by an earlier cell.
TrainLoop(
    model=model.to(device),
    diffusion=diffusion,
    data=iter(train_loader),
    batch_size=batch_size,
    lr=lr,
    ema_rate=ema_rate,
    weight_decay=weight_decay,
    learning_steps=learning_steps,
    eval_data=iter(valid_loader),
    eval_interval=eval_interval
).run_loop()


Instantiate all the classes needed for the training loop. <br>
Set the training parameters.

Note that the implementation details given in the first DiffuSeq paper are:
"The maximum sequence length is 128, with embedding dimension d = 128, diffusion steps T = 2000
and a square-root noise schedule."

How is it different in v2 or other papers?

When the data loader raises a StopIteration exception, training is complete.

Get sample output using the forward step of the trained model.

## Inference step (sampling / generation part)

Once training has completed, we can run the inference step and compute the cross-validation accuracy.

In [None]:
import numpy as np
import torch
from transformers import BertTokenizer

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the network with the same constructor arguments as in the training
# cell (vocab_size, embedding_dim, hidden_dim and output_dims come from earlier
# cells) so that the saved state dict matches the module layout
model = TransformerNetModel(
    vocab_size=vocab_size, 
    input_dims=embedding_dim, 
    hidden_t_dim=hidden_dim, 
    output_dims=output_dims
)
model.to(device)

# Load the weights written by TrainLoop.run_loop(); map_location places the
# tensors on the selected device regardless of where they were saved from
model.load_state_dict(torch.load('checkpoints/trained_model.pth', map_location=device))
# Switch the module to evaluation mode for inference
model.eval()

# Load the pretrained 'bert-base-uncased' tokenizer
# NOTE(review): assumes training used the same tokenizer/vocabulary — confirm
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def temperature_sampling(logits, temperature):
    """
    Sample one class index per row from temperature-scaled logits.

    :param logits: 1-D or 2-D tensor of unnormalised scores with the classes
                   on the last dimension.
    :param temperature: non-negative scale. Values below 1 sharpen the
                        distribution, values above 1 flatten it, and 0 selects
                        the highest-scoring class deterministically.
    :return: long tensor of sampled indices with a trailing dimension of 1.
    :raises ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    if temperature == 0:
        # Limit of the softmax as the temperature goes to 0; dividing by 0
        # would otherwise turn the logits into inf/NaN.
        return torch.argmax(logits, dim=-1, keepdim=True)
    probabilities = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)

def generate_text(model, tokenizer, prompt, max_length=128, num_timesteps=2000, temperature=1.0):
    """
    Generate text by running the reverse diffusion chain and decoding the result.

    The prompt is tokenised and embedded, a noise tensor of the same shape is
    denoised with ``p_sample_loop``, the final embeddings are projected to
    vocabulary logits with ``model.lm_head`` and one token per position is
    drawn with ``temperature_sampling``.

    NOTE(review): the prompt only determines the shape of the noise tensor —
    no mask or x_start is passed to the sampler, so as far as this function
    shows the generation is not conditioned on the prompt's content.

    Relies on the module-level ``device``.

    :param model: trained network exposing ``word_embedding`` and ``lm_head``.
    :param tokenizer: Hugging Face tokenizer used for encoding and decoding.
    :param prompt: input text.
    :param max_length: prompts longer than this many tokens are truncated.
    :param num_timesteps: number of diffusion steps; the beta schedule is
                          built exactly as in the training cell.
    :param temperature: sampling temperature for the final token draw.
    :return: the decoded text (special tokens removed).
    """
    model.eval()
    with torch.no_grad():
        # Tokenize the input prompt, honouring max_length
        input_ids = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_length
        ).input_ids.to(device)

        # Get the embeddings for the input prompt
        input_embeds = model.word_embedding(input_ids).to(device)

        # Initialize noise
        noise = torch.randn_like(input_embeds)

        # The schedule must match the one used for training (end points
        # scaled by 1000 / T); the unscaled 1e-4..0.02 range describes a
        # different noise process whenever T != 1000.
        scale = 1000 / num_timesteps
        betas = np.linspace(scale * 0.0001, scale * 0.02, num_timesteps, dtype=np.float64)
        diffusion = GaussianDiffusion(betas=betas)

        # Sample from the model using p_sample_loop
        samples = diffusion.p_sample_loop(
            model=model,
            shape=input_embeds.shape,
            noise=noise,
            device=device,
            progress=True,
            clamp_step=None  # Set this to a specific value if needed
        )

        # Project the final embeddings to vocabulary logits
        logits = model.lm_head(samples[-1].to(device))

        # Apply temperature sampling on (batch_size * seq_len, vocab_size)
        flat_logits = logits.view(-1, logits.size(-1))
        sampled_ids = temperature_sampling(flat_logits, temperature)
        sampled_ids = sampled_ids.view(1, -1)  # Reshape back to (1, seq_len)

        # Index the single row instead of squeeze() so a one-token result
        # still decodes from a list.
        generated_text = tokenizer.decode(sampled_ids[0].tolist(), skip_special_tokens=True)

        # Print input IDs and output IDs
        print(f"Input IDs: {input_ids}")
        print(f"Output IDs: {sampled_ids}")

        return generated_text

# Example usage: generate a response for one prompt with a sampling
# temperature of 0.7 and print the prompt next to the generated text
prompt = "What is the capital of Australia?"
generated_response = generate_text(model, tokenizer, prompt, temperature=0.7)
print(f"Prompt: {prompt}")
print(f"Generated Response: {generated_response}")
