# Diffusion Model for Open Dialogue

Simplified version

References
- DiffuSeq (cited below)

Adapted from:

@inproceedings{gong2022diffuseq,
  author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
  booktitle = {International Conference on Learning Representations, ICLR},
  title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
  year = 2023
}

@article{gong2023diffuseqv2,
  title={DiffuSeq-v2: Bridging Discrete and Continuous Text Spaces for Accelerated Seq2Seq Diffusion Models},
  author={Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
  journal={arXiv preprint arXiv:2310.05793},
  year={2023}
}


In [159]:
# TODO adapt from other codes
import numpy as np 

## Dataset

- Use Commonsense Conversation dataset (from Reddit)


in diffuseq text_datasets.py some steps to load the dataset itself

- [ ] prepare datasets for training and validation in the format (stored as jsonl file?)
```
{"src": "", "train": ""}
```

- word embeddings (to be loaded?)
- use a corpus


## Training

Note that, in DiffuSeq, a model file is created to store all training progress, configuration etc. (in bash format poitning to raw files?)

- denoise rate ?
- using updates in v2 diffuseq took it from 2 days -> 11 hr learning time

In [160]:
import torch
# use GPU if available
is_cuda = torch.cuda.is_available()

if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU not available, CPU used


Get tokenizer from BERT

In [161]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size
print(vocab_size)




30522


Embedding function $EMB(w)$ to map the discrete text $w$ into a continuous space. Use hidden dim 128 like in DiffuSeq.

In [162]:
# choose embedding dimension = 128
# choose hidden_dim
hidden_dim = 128
model_emb = torch.nn.Embedding(tokenizer.vocab_size, 128)

# initialize random embeddings
torch.nn.init.normal_(model_emb.weight)


Parameter containing:
tensor([[ 0.0670, -0.2668, -0.1132,  ..., -1.2613, -0.4085,  1.2003],
        [ 0.2923, -0.3376,  0.5528,  ...,  0.3891,  0.9239, -0.2849],
        [ 0.5678, -0.7169,  0.6235,  ..., -0.5397, -0.7850,  0.5352],
        ...,
        [ 0.3775,  0.5138,  0.0430,  ..., -0.6896, -0.9261,  0.0412],
        [ 2.5701, -0.6170,  1.1171,  ...,  0.1083,  0.5843, -1.7219],
        [ 0.3754, -0.0276, -0.4616,  ...,  0.8747, -0.6232, -0.7910]],
       requires_grad=True)

Load the sample text data from file

In [163]:
# read in the data in training data json file 
# TODO do this in a different way 
# TODO load actual dataset from Amazon

import json

data_dir = "./datasets/sample"
path = f'{data_dir}/train.jsonl'

sentence_lst = {'src':[], 'trg': []}
with open(path, 'r') as f_reader:
        for row in f_reader:
            content = json.loads(row)
            sentence_lst['src'].append(content['src'].strip())
            sentence_lst['trg'].append(content['trg'].strip())

# TODO use pandas to load faster? any other package can just load json directly rather than row by row

Tokenize dataset

In [164]:
from datasets import Dataset
raw_datasets = Dataset.from_dict(sentence_lst)
raw_datasets

Dataset({
    features: ['src', 'trg'],
    num_rows: 19
})

In [165]:
vocab_dict = tokenizer.get_vocab()

def tokenize_function(examples):
        input_id_x = tokenizer(examples['src'], add_special_tokens=True)['input_ids']
        input_id_y = tokenizer(examples['trg'], add_special_tokens=True)['input_ids']
        result_dict = {'input_id_x': input_id_x, 'input_id_y': input_id_y}

        return result_dict

tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=['src', 'trg'],
    load_from_cache_file=True,
    desc="Running tokenizer on dataset",
)

Running tokenizer on dataset (num_proc=4): 100%|██████████| 19/19 [00:00<00:00, 75.23 examples/s]


Get the training dataset

In [166]:
# helper function to collate the batch
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    result = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    mask_ = torch.full([len(examples), max_length], pad_token_id, dtype=torch.int64).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result

In [167]:
vocab_dict = tokenizer#.get_vocab() = loaded_vocab variable, too
# :param seq_len: the max sequence length (one-side).
# TODO figure out correct parameter for this
seq_len = 128

def merge_and_mask(group_lst):
        lst = []
        mask = []
        for i in range(len(group_lst['input_id_x'])):
            end_token = group_lst['input_id_x'][i][-1]
            src = group_lst['input_id_x'][i][:-1]
            trg = group_lst['input_id_y'][i][:-1]
            while len(src) + len(trg) > seq_len - 3:
                if len(src)>len(trg):
                    src.pop()
                elif len(src)<len(trg):
                    trg.pop()
                else:
                    src.pop()
                    trg.pop()
            src.append(end_token)
            trg.append(end_token)

            lst.append(src + [vocab_dict.sep_token_id] + trg)
            mask.append([0]*(len(src)+1))
        group_lst['input_ids'] = lst
        group_lst['input_mask'] = mask
        return group_lst

tokenized_datasets = tokenized_datasets.map(
        merge_and_mask,
        batched=True,
        num_proc=1,
        desc=f"merge and mask",
    )
    
def pad_function(group_lst):
    max_length = seq_len
    group_lst['input_ids'] = _collate_batch_helper(group_lst['input_ids'], vocab_dict.pad_token_id, max_length)
    group_lst['input_mask'] = _collate_batch_helper(group_lst['input_mask'], 1, max_length)
    return group_lst

lm_datasets = tokenized_datasets.map(
        pad_function,
        batched=True,
        num_proc=1,
        desc=f"padding",
    )


merge and mask: 100%|██████████| 19/19 [00:00<00:00, 1885.48 examples/s]

In [None]:
print(lm_datasets, 'padded dataset')

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 19
}) padded dataset


In [None]:
import datasets
raw_datasets = datasets.DatasetDict()
raw_datasets['train'] = lm_datasets

In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
        num_rows: 19
    })
})

In [None]:
raw_datasets["train"]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 19
})

Load data into iterable data variable

In [None]:
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, text_datasets, model_emb=None):
        super().__init__()
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets['train'])
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with torch.no_grad():

            input_ids = self.text_datasets['train'][idx]['input_ids']
            hidden_state = self.model_emb(torch.tensor(input_ids))

            # obtain the input vectors, only used when word embedding is fixed (not trained end-to-end)
            arr = np.array(hidden_state, dtype=np.float32)

            out_kwargs = {}
            out_kwargs['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
            out_kwargs['input_mask'] = np.array(self.text_datasets['train'][idx]['input_mask'])

            return arr, out_kwargs

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

train_dataset = TextDataset(raw_datasets, model_emb=model_emb)

data_loader = DataLoader(
    train_dataset,
    batch_size=20,  # TODO choose an appropriate batch size ?
)

data = iter(data_loader.dataset)

next(data)

# NOTE load data for the validation, inference differently

(array([[ 0.44345245, -0.3449697 , -0.04056583, ...,  2.329228  ,
         -0.62913144, -1.0037589 ],
        [ 0.16234268, -0.0413381 ,  2.94077   , ..., -1.1515596 ,
          0.11964969, -0.1952339 ],
        [-0.876911  ,  0.10885099,  0.02171185, ...,  0.7354041 ,
          0.32704896,  0.35938853],
        ...,
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235],
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235],
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235]], dtype=float32),
 {'input_ids': array([  101,  2023,  2003,  2026,  5440,  2466,  8115,  1012,  6187,
          1050,  1005,  1056,  3524,  2000,  2156,  2129,  2002,  2515,
          1999,  1996,  2778,  5420,   999,  1996,  2265,  2003,  2026,
         11302,  2868,  2005,  1996,  2733,  1012,   102,   102,   101,
          6300,  2050,  2009,  1005,  1055,  2524,  2025,

In [None]:
data_loader.dataset[0]

(array([[ 0.44345245, -0.3449697 , -0.04056583, ...,  2.329228  ,
         -0.62913144, -1.0037589 ],
        [ 0.16234268, -0.0413381 ,  2.94077   , ..., -1.1515596 ,
          0.11964969, -0.1952339 ],
        [-0.876911  ,  0.10885099,  0.02171185, ...,  0.7354041 ,
          0.32704896,  0.35938853],
        ...,
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235],
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235],
        [-2.4106717 , -1.938898  , -0.5104473 , ...,  0.87380236,
          0.29839227,  0.77376235]], dtype=float32),
 {'input_ids': array([  101,  2023,  2003,  2026,  5440,  2466,  8115,  1012,  6187,
          1050,  1005,  1056,  3524,  2000,  2156,  2129,  2002,  2515,
          1999,  1996,  2778,  5420,   999,  1996,  2265,  2003,  2026,
         11302,  2868,  2005,  1996,  2733,  1012,   102,   102,   101,
          6300,  2050,  2009,  1005,  1055,  2524,  2025,

Create the model and diffusion

In [None]:
#  model = TransformerNetModel(
#         input_dims=hidden_dim,
#         output_dims=(hidden_dim if not learn_sigma else hidden_dim*2),
#         hidden_t_dim=hidden_t_dim,
#         dropout=dropout,
#         config_name=config_name,
#         vocab_size=vocab_size,
#         init_pretrained=use_plm_init
#     )

#     betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

#     if not timestep_respacing:
#         timestep_respacing = [diffusion_steps]

#     diffusion = SpacedDiffusion(
#         use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
#         betas=betas,
#         rescale_timesteps=rescale_timesteps,
#         predict_xstart=predict_xstart,
#         learn_sigmas = learn_sigma,
#         sigma_small = sigma_small,
#         use_kl = use_kl,
#         rescale_learned_sigmas=rescale_learned_sigmas
#     )

# FIXME need to implement a way to save the model progress so that
# trained model can be loaded later for assessment purposes

In [None]:
import math 
import torch as th

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


In [None]:
# FIXME need to define TransformerNetModel
#  The full Transformer model with attention and timestep embedding.

# Adapted from diffuSeq

# TODO code the transformer from scratch
# TODO design the transformer from scratch

from transformers import BertConfig, BertModel
import torch.nn as nn
import torch as th

# config = BertConfig.from_pretrained('bert-base-uncased')

# Config from DiffuSeq
  #     self,
    #     input_dims,
    #     output_dims,
    #     hidden_t_dim,
    #     dropout=0,
    #     config=None,
    #     config_name='bert-base-uncased',
    #     vocab_size=None,
    #     init_pretrained='no',
    #     logits_mode=1,
    # ):


class TransformerNetModel(nn.Module):
    def __init__(self, vocab_size, input_dims, hidden_t_dim, output_dims):
        super().__init__()

        # FIXME just using the default config not a param
        config = BertConfig.from_pretrained('bert-base-uncased')
        # FIXME add in this to the params - config, defining here as hard-coded config=
        # FIXME set to an actual value
        # config.hidden_size = what?
        # FIXME set to an actual value
        config.hidden_dropout_prob = 0
        # then pass this to the BertModel config=config

        # FIXME specify input_dims
        self.input_dims = input_dims

        # TODO include condfiguration for hidden_t_dim -> time embeddings hidden dimensions
        self.hidden_t_dim = hidden_t_dim
        self.output_dims = output_dims
        # self.dropout = dropout
        # TODO check this gets assigned by default BERT config
        self.hidden_size = config.hidden_size

        # TODO add work embeddings
        # FIXME vocab_size is define way up above

        # Generate logits for hidden representation
        self.lm_head = nn.Linear(self.input_dims, vocab_size)

        # Time embeddings
        # FIXME investigate this implementation
        time_embed_dim = hidden_t_dim * 4
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_t_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, config.hidden_size),
        )

        # FIXME What does this do
        # Function to deal with having a hidden size not equal to input size, project to hidden size (?)
        if self.input_dims != config.hidden_size:
            self.input_up_proj = nn.Sequential(nn.Linear(input_dims, config.hidden_size),
                                            nn.Tanh(), nn.Linear(config.hidden_size, config.hidden_size))

        # FIXME why is this temporary 
        temp_bert = BertModel.from_pretrained('bert-base-uncased', config=config)
        self.word_embedding = temp_bert.embeddings.word_embeddings
       
       # FIXME why do we do this 2 times????
        with th.no_grad():
            self.lm_head.weight = self.word_embedding.weight
        # self.lm_head.weight.requires_grad = False
        # self.word_embedding.weight.requires_grad = False
            
        # TODO explain what is happening
        self.input_transformers = temp_bert.encoder
        # TODO explain what is doing
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embeddings = temp_bert.embeddings.position_embeddings
        self.LayerNorm = temp_bert.embeddings.LayerNorm
     
        del temp_bert.embeddings
        del temp_bert.pooler

        # FIXME When does this get used
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # FIXME what is happening here
        if self.output_dims != config.hidden_size:
            self.output_down_proj = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size),
                                                nn.Tanh(), nn.Linear(config.hidden_size, self.output_dims))
    
    def get_embeds(self, input_ids):
        return self.word_embedding(input_ids)

    # FIXME make a real design decision how to use this function
    def get_logits(self, hidden_repr):
    # NOTE we make a simplifying assumption get the logits from linear layer
    # FIXME what is a best option for generating logits?? cosine similarity is better?
        return self.lm_head(hidden_repr)
                

    # FIXME what is the difference btween BertModel, BertConfig, BertTokenizer, maybe define it all in one place config for tokenizer + embeddings?


    def forward(self, x, timesteps):
        # FIXME update string comment
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x C x ...] Tensor of outputs.
        """

        # Embedded timestep
        emb_t = self.time_embed(timestep_embedding(timesteps, self.hidden_t_dim))

        # FIXME when would we ever not be the right size? what size will we choose?
        if self.input_dims != self.hidden_size:
            emb_x = self.input_up_proj(x)
        else:
            emb_x = x

        seq_length = x.size(1)
        position_ids = self.position_ids[:, : seq_length ]
    #     # print(emb_x.shape, emb_t.shape, self.position_embeddings)
        emb_inputs = self.position_embeddings(position_ids) + emb_x + emb_t.unsqueeze(1).expand(-1, seq_length, -1)
        emb_inputs = self.dropout(self.LayerNorm(emb_inputs))

        input_trans_hidden_states = self.input_transformers(emb_inputs).last_hidden_state
        
        if self.output_dims != self.hidden_size:
            h = self.output_down_proj(input_trans_hidden_states)
        else:
            h = input_trans_hidden_states
        h = h.type(x.dtype)
        return h

In [None]:
# TODO explore other simplistic sample code
# https://github.com/lucidrains/denoising-diffusion-pytorch
# https://e-dorigatti.github.io/math/deep%20learning/2023/06/25/diffusion.html
# https://github.com/tanelp/tiny-diffusion

In [None]:
# NOTE adapted from diffuSeq, which is adapted from https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

class GaussianDiffusion():
    def __init__(
        self,
        betas): 

        # TODO consider adding these other parameters, where are they used, are they required?
    #     predict_xstart,
    #     rescale_learned_sigmas,
    #     learn_sigmas,
    #     sigma_small,
    #     use_kl,
    #     rescale_timesteps=False,
    # ):
    #     self.rescale_timesteps = rescale_timesteps
    #     self.predict_xstart = predict_xstart
    #     self.rescale_learned_sigmas = rescale_learned_sigmas
    #     self.learn_sigmas = learn_sigmas
    #     self.sigma_small = sigma_small
    #     self.use_kl = use_kl

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        # FIXME what is alphas? why?
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # FIXME copied directly
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)


        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    # NOTE the below comments from diffuSeq
    # self.mapping_func = None # implement in train main()
    # self.add_mask_noise = False # TODO

    # FIXME copied directly from diffuSeq
    def training_losses(self, model, *args, **kwargs):
        self.model = model
        return self.training_losses_seq2seq(model, *args, **kwargs)
    
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t
    
    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(
            self.log_one_minus_alphas_cumprod, t, x_start.shape
        )
        return mean, variance, log_variance


    def q_sample(self, x_start, t, noise=None, mask=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :param mask: anchoring masked position
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        x_t = (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

        if mask == None:
            return x_t
        else:
            mask = th.broadcast_to(mask.unsqueeze(dim=-1), x_start.shape)
            return th.where(mask==0, x_start, x_t)
        

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior: 
            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.size(0), x.size(-1)
        assert t.shape == (B,)
        # print(x.shape)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)
        
        # for fixedlarge, we set the initial (log-)variance like so
        # to get a better decoder log likelihood.
        model_variance = np.append(self.posterior_variance[1], self.betas[1:])
        model_log_variance = np.log(np.append(self.posterior_variance[1], self.betas[1:]))
        
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                # print(denoised_fn)
                x = denoised_fn(x, t)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.predict_xstart:
            pred_xstart = process_xstart(model_output)
        else:
            ### model is used to predict eps
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )

        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_xstart, x_t=x, t=t
        )

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def p_sample(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
            top_p=None, mask=None, x_start=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param mask: anchoring masked position to x_start
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if top_p is not None and top_p > 0:
            # print('top_p sampling')
            noise = th.randn_like(x)
            replace_mask = th.abs(noise) > top_p
            while replace_mask.any():
                noise[replace_mask] = th.randn_like(noise[replace_mask])
                replace_mask = th.abs(noise) > top_p
            assert (th.abs(noise) <= top_p).all()

        else:
            noise = th.randn_like(x)

        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        if mask == None:
            pass
        else:
            sample = th.where(mask==0, x_start, sample)

        return {
            "sample": sample, 
            "pred_xstart": out["pred_xstart"],
            "greedy_mean": out["mean"], 
            "out": out
        }

    
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
        gap=1,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param mask: anchoring masked position to x_start
        :param clamp_step: in clamp_first mode, choose end clamp step, otherwise starting clamp step
        :param clamp_first: bool, clamp_first mode
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = []
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            top_p=top_p,
            clamp_step=clamp_step,
            clamp_first=clamp_first,
            mask=mask,
            x_start=x_start
        ):
            final.append(sample['sample'])
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None: # custom your the start point of x_0
            sample_x = noise
        else:
            sample_x = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i in indices: # from T to 0
            t = th.tensor([i] * shape[0], device=device)
            if not clamp_first:
                if i > clamp_step:
                    denoised_fn_cur = None
                else:
                    denoised_fn_cur = denoised_fn
            else:
                if i >= clamp_step:
                    denoised_fn_cur = denoised_fn
                else:
                    denoised_fn_cur = None
            with th.no_grad():
                out = self.p_sample(
                    model,
                    sample_x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn_cur,
                    model_kwargs=model_kwargs,
                    top_p=top_p,
                    mask=mask,
                    x_start=x_start
                )
                yield out
                sample_x = out["sample"]


    def _get_x_start(self, x_start_mean, std):
        '''
        Word embedding projection from {Emb(w)} to {x_0}
        :param x_start_mean: word embedding
        :return: x_0
        '''
        noise = th.randn_like(x_start_mean)
        assert noise.shape == x_start_mean.shape
        # print(x_start_mean.device, noise.device)
        return (
             x_start_mean + std * noise
        )

    def _token_discrete_loss(self, x_t, get_logits, input_ids, mask=None, truncate=False, t=None):
        '''
        the loss of -log p(w|z_0)
        :param x_start_mean: word embedding
        :return: x_0
        '''
        reshaped_x_t = x_t
        logits = get_logits(reshaped_x_t)  # bsz, seqlen, vocab
        # print(logits.shape)
        loss_fct = th.nn.CrossEntropyLoss(reduction='none')
        decoder_nll = loss_fct(logits.view(-1, logits.size(-1)), input_ids.view(-1)).view(input_ids.shape)
        if mask != None:
            decoder_nll *= mask
        # print(decoder_nll.shape)
        if mask != None:
            decoder_nll = decoder_nll.sum(dim=-1)/mask.sum(dim=-1)
        else:
            decoder_nll = decoder_nll.mean(dim=-1)

        return decoder_nll

    def _x0_helper(self, model_output, x, t):

        if self.predict_xstart:
            pred_xstart = model_output
            pred_prev, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )

        else: # predict eps
            pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        
            pred_prev, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )

        return {'pred_xprev':pred_prev, 'pred_xstart':pred_xstart}

    def training_losses_seq2seq(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs. # not used unless fixing the input embeddings
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        x_start_fix = x_start # save the orignal x_0
        assert 'input_ids' in model_kwargs
        input_ids_x = model_kwargs.pop('input_ids').to(t.device)
        input_ids_mask = model_kwargs.pop('input_mask').to(t.device)
        x_start_mean = model.model.module.get_embeds(input_ids_x)
        
        std = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
                                   th.tensor([0]).to(x_start_mean.device),
                                   x_start_mean.shape)
        # print(std.shape, )
        # x_start_log_var = 2 * th.log(std)
        x_start = self._get_x_start(x_start_mean, std)
        # print(x_start_mean.shape, x_start.shape)
        if noise is None:
            noise = th.randn_like(x_start)

        x_t = self.q_sample(x_start, t, noise=noise, mask=input_ids_mask) # reparametrization trick.

        get_logits = model.model.module.get_logits

        terms = {}

        target = x_start
        model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
        assert model_output.shape == target.shape == x_start.shape
        terms["mse"] = mean_flat((target - model_output) ** 2)

        model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart'] # predicted_xstart = model_output
        t0_mask = (t == 0)
        t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
        terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])

        # tT_mask = (t == self.num_timesteps - 1)
        out_mean, _, _ = self.q_mean_variance(x_start, th.LongTensor([self.num_timesteps - 1]).to(x_start.device))
        tT_loss =  mean_flat(out_mean ** 2)

        decoder_nll = self._token_discrete_loss(x_start, get_logits, input_ids_x) # embedding regularization
        terms["nll"] = self._token_discrete_loss(model_out_x_start, get_logits, input_ids_x, mask=input_ids_mask, truncate=True, t=t) # x_0->model_out_x_start
        # assert (model.lm_head.weight == model.word_embedding.weight).all()

        terms["loss"] = terms["mse"] + decoder_nll + tT_loss

        return terms

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
        langevin_fn=None,
        mask=None,
        x_start=None
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        # print(sigma.mean())
        sample = mean_pred + nonzero_mask * sigma * noise
        if langevin_fn:
            print(t.shape)
            sample=langevin_fn(sample, mean_pred, sigma, self.alphas_cumprod_prev[t[0]], t, x)
        
        if mask == None:
            pass
        else:
            sample = th.where(mask==0, x_start, sample)
        
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        top_p=None,
        clamp_step=None,
        clamp_first=None,
        mask=None,
        x_start=None,
        gap=1,
    ):
        """
        Generate samples from the model using DDIM.
        :param gap: compute ddim sampling for each {gap} step

        Same usage as p_sample_loop().
        """
        final = []
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            mask=mask,
            x_start=x_start,
            gap = gap
        ):
            final.append(sample['sample'])
        return final

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        langevin_fn=None,
        mask=None,
        x_start=None,
        gap=1
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            sample_x = noise
        else:
            sample_x = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1][::gap]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    sample_x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                    mask=mask,
                    x_start=x_start
                )
                yield out
                sample_x = out["sample"]

def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)

def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

In [None]:
class UniformSampler():
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    Sampler performs unbiased importance sampling, in which the
    objective's mean is unchanged.
    TODO confirm & update comment
    """

    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


In [None]:
# Helper functions for training loop
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)

def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()

def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        # logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            # logger.logkv_mean(f"{key}_q{quartile}", sub_loss)

In [None]:
from torch.optim import AdamW
import copy

class TrainLoop():
    def __init__(
        self,
    #     *,
        model,
        diffusion,
        data,
        batch_size,
    #     microbatch,
        lr,
        ema_rate,
    #     log_interval,
        #  schedule_sampler=None,
        weight_decay=0.0,
        learning_steps=0,
    #     checkpoint_path='',
    #     gradient_clipping=-1.,
        eval_data=None,
        eval_interval=-1,
    ):
          self.model = model
          self.diffusion = diffusion
          self.data = data
          self.batch_size = batch_size
        #   self.microbatch = microbatch if microbatch > 0 else batch_size
          # Assume no microbatch
          self.microbatch = batch_size

          self.lr = lr
          self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
            )
          
          # NOTE assuming uniform sampler
          self.schedule_sampler = UniformSampler(diffusion)
          self.weight_decay = weight_decay
          self.learning_steps = learning_steps
          self.eval_data = eval_data
          self.eval_interval = eval_interval
          
          self.step = 0

          # TODO check other initialization steps are covered

          self.model_params = list(self.model.parameters())
          self.master_params = self.model_params

          # Optimizer
          self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)

          self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

    def _log_grad_norm(self):
        sqsum = 0.0
        for p in self.master_params:
            if p.grad != None:
                sqsum += (p.grad ** 2).sum().item()
        # TODO implement logging
        # logger.logkv_mean("grad_norm", np.sqrt(sqsum))

    def _anneal_lr(self):
        if not self.learning_steps:
            return
        frac_done = (self.step) / self.learning_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    
    def optimize_normal(self):
        # NOTE assuming no gradient clipping
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
     

    def run_step(self, batch, cond):
        # TODO implement this fn
        self.forward_backward(batch, cond)
        # NOTE assuming not using fp16 optimization
        self.optimize_normal()
        # TODO do this fn - logging
        # self.log_step()

    def forward_only(self, batch, cond):
        with th.no_grad():
            zero_grad(self.model_params)
            for i in range(0, batch.shape[0], self.microbatch):
                micro = batch[i: i + self.microbatch].to(device)
                micro_cond = {
                    k: v[i: i + self.microbatch].to(device)
                    for k, v in cond.items()
                }
                last_batch = (i + self.microbatch) >= batch.shape[0]
                t, weights = self.schedule_sampler.sample(micro.shape[0], device)

                compute_losses = self.diffusion.training_losses(self.ddp_model, micro, t, model_kwargs=micro_cond)

                # NOTE not using ddp - distributed training
                with self.ddp_model.no_sync():
                        losses = compute_losses()

                log_loss_dict(
                    self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
                )


    def forward_backward(self, batch, cond):
        zero_grad(self.model_params)
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch] #.to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch]#.to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0],torch.device)# dist_util.dev())
            # print(micro_cond.keys())
            compute_losses = self.diffusion.training_losses(self.ddp_model, micro, t, model_kwargs=micro_cond)

            # NOTE not using ddp - distributed training
            with self.ddp_model.no_sync():
                losses = compute_losses()

            # if isinstance(self.schedule_sampler, LossAwareSampler):
            #     self.schedule_sampler.update_with_local_losses(
            #         t, losses["loss"].detach()
            #     )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            # NOTE not using fp16 optimization 
            loss.backward()


    # NOTE not going to implement capacity to resume training     
    # NOTE removed saving checkpoints 
    # NOTE removed some logging
    def run_loop(self):
        while (
            # FIXME check these defined correctly
            not self.learning_steps
            or self.step < self.learning_steps
        ):
            batch, cond = next(self.data)
            # TODO add this fn
            self.run_step(batch, cond)
            if self.eval_data is not None and self.step % self.eval_interval == 0:
                batch_eval, cond_eval = next(self.eval_data)
                # TODO add this fn
                self.forward_only(batch_eval, cond_eval)
                print('eval on validation set')
            self.step += 1
    

Instantiate all the classes for training loop

In [None]:
# Define the noise schedule
# TODO consider other noise schedules, here take the simplifying assumption that we use linear noise schedule 
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.

# TODO make number of diffusion steps a parameter 
num_diffusion_timesteps = 1000

scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = np.linspace(
    beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)

# NOTE we assume NO timestep respacing 
# TODO instantiate the diffusion model & transformer and save it (?)

diffusion = GaussianDiffusion(betas)

# TODO specify correct dimensions!
model = TransformerNetModel(vocab_size=vocab_size, input_dims=64, hidden_t_dim=64, output_dims=64)

model.to(device) 

pytorch_total_params = sum(p.numel() for p in model.parameters())
# TODO add to logger
print(f'### The parameter count is {pytorch_total_params}')

### The parameter count is 110414970


In [None]:
len(data)

1

Set parameters for training

Note the implementation details in DiffuSeq (first version) is
"The maximum sequence length is 128, with embedding dimension d = 128, diffusion steps T = 2000
and a square-root noise schedule."

How is it different in v2 or other papers?

In [None]:
# TODO figure out what are the right params to recreate diffuSeq
batch_size = 20 
lr = 0.001 # learning rate
ema_rate = 0.999
weight_decay = 0.01
learning_steps = 2000

In [None]:
# Run training loop



TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data, # TODO load the text data
        batch_size=batch_size, # TODO set hyperparams from args like args.batch_size
        # microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        # log_interval=args.log_interval,
        weight_decay=weight_decay,
        learning_steps=learning_steps,
        # checkpoint_path=args.checkpoint_path,
        # gradient_clipping=args.gradient_clipping,
        # eval_data=data_valid,
        # eval_interval=args.eval_interval
    ).run_loop()


StopIteration: 

## Inference step (sampling / generation part)

Once the training completed, we can start the inference step and get cross validation accuracy.