##### Mask-based Discrete Diffusion Language Model

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import logging, BertTokenizer, BertModel
from tqdm import tqdm
import os
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

def seed_everything(seed=1234):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the notebook's import cell does not import ``random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are safe to make on machines without a GPU.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade some speed for run-to-run reproducibility of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Restrict this process to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Only show error-level messages from the transformers library.
logging.set_verbosity_error()
# Fix all RNG seeds (default seed) before any model or data is created.
seed_everything()

In [2]:
class MDM(nn.Module):
    """
    Masked Diffusion Model.

    Forward (noising) process: every token is independently replaced by the
    tokenizer's [MASK] token with probability ``(1 - eps) * t + eps``.

    Reverse (denoising) process: a BERT encoder followed by a linear head
    predicts a token for every position; masked positions are filled with the
    arg-max prediction and, for steps > 0, non-prompt positions are randomly
    re-masked with probability ``step / num_steps``.
    """
    def __init__(self, vocab_size, max_seq_len=128, num_steps=100, device='cpu',
                 bert_path='/mnt/disk/ModelHub/bert-base-uncased'):
        """
        Args:
            vocab_size: Size of the output vocabulary of the prediction head.
            max_seq_len: Length of the sequences produced by ``sample``.
            num_steps: Number of denoising steps used by ``sample``.
            device: Device on which ``sample`` allocates its tensors.
            bert_path: Name or path passed to ``from_pretrained`` for both the
                tokenizer and the BERT encoder.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.num_steps = num_steps  # Timesteps.
        self.device = device

        # Denoising model, initialized from a pre-trained BERT.
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.bert = BertModel.from_pretrained(bert_path)

        # Head that maps hidden states to vocabulary logits.
        self.predictor = nn.Linear(self.bert.config.hidden_size, vocab_size)

        # MASK token ID
        self.mask_token_id = self.tokenizer.mask_token_id

        # Number of leading prompt tokens that ``denoise`` must never re-mask.
        # Set by ``sample``; None means "no prompt".
        self._prompt_len = None

    def _predict_logits(self, x_t):
        """Run the encoder and the prediction head: ids -> [batch_size, seq_len, vocab_size]."""
        hidden_states = self.bert(x_t).last_hidden_state  # [batch_size, seq_len, hidden_size]
        return self.predictor(hidden_states)

    def _trim_generated(self, tokens, prompt_len):
        """Drop the prompt and cut the remainder at the first [SEP] that follows it (if any)."""
        generated = tokens[prompt_len:]
        sep_pos = np.where(generated == self.tokenizer.sep_token_id)[0]
        if sep_pos.size > 0:
            generated = generated[:sep_pos[0]]
        return generated

    def add_noise(self, x, t, eps=1e-3):
        """
        Diffusion process: switch tokens to the MASK state.

        Args:
            x: Input sequence [batch_size, seq_len]
            t: Noise level per sequence [batch_size]
            eps: Lower bound of the masking probability (reached at t = 0).

        Returns:
            noisy_x: Copy of ``x`` with the selected positions set to [MASK]
                [batch_size, seq_len]
            mask: Boolean tensor, True where a token was masked
                [batch_size, seq_len]
            p_mask: Masking probability used at each position
                [batch_size, seq_len]
        """
        seq_len = x.shape[1]

        # One probability per sequence, broadcast over its positions.
        p_mask = (1 - eps) * t + eps
        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]

        # Independent Bernoulli(p_mask) draw per position.
        rand = torch.rand(x.shape, device=x.device)
        mask = rand < p_mask  # [batch_size, seq_len]

        # masked_fill returns a new tensor, so the caller's ``x`` is untouched.
        noisy_x = x.masked_fill(mask, self.mask_token_id)

        return noisy_x, mask, p_mask

    def denoise(self, x_t, t=None):
        """
        Denoising process: predict the MASK tokens, then randomly re-mask.

        Args:
            x_t: Noisy sequence at the current step [batch_size, seq_len]
            t: Current step, either a number or a tensor whose first element
                is used for the whole batch. None disables re-masking.

        Returns:
            Tensor [batch_size, seq_len] in which every masked position of
            ``x_t`` has been replaced by the arg-max prediction and, when the
            step is > 0, positions after the prompt have been re-masked with
            probability ``step / num_steps``.
        """
        logits = self._predict_logits(x_t)  # [batch_size, seq_len, vocab_size]
        pred_x_0 = torch.argmax(logits, dim=-1)  # [batch_size, seq_len]

        # Only masked positions take the prediction; visible tokens are kept.
        mask = (x_t == self.mask_token_id)
        result = torch.where(mask, pred_x_0, x_t)

        # Prompt length recorded by ``sample`` (0 when absent or None).
        prompt_len = getattr(self, '_prompt_len', None) or 0

        if t is not None:
            # flatten() also accepts a 0-dim tensor.
            step = t.flatten()[0].item() if isinstance(t, torch.Tensor) else t
            if step > 0:
                # Each position can be re-masked with probability t / T ...
                remask_prob = step / self.num_steps
                remask = torch.rand_like(result.float()) < remask_prob
                # ... except the prompt, which must stay intact.
                if prompt_len > 0:
                    remask[:, :prompt_len] = False
                result[remask] = self.mask_token_id

        return result

    def sample(self, batch_size=1, initial_text=None):
        """
        Generate text by iterative denoising, starting from a fully masked
        sequence (optionally prefixed by a prompt).

        Args:
            batch_size: Number of sequences to generate.
            initial_text: Optional prompt. Its token ids are placed at the
                start of every sequence and are never re-masked.

        Returns:
            List of ``batch_size`` decoded strings: the tokens after the
            prompt, cut at the first [SEP] that follows it.
        """
        # Explicit dtype so the ids are integer regardless of torch version.
        x_T = torch.full((batch_size, self.max_seq_len), self.mask_token_id,
                         dtype=torch.long, device=self.device)
        if initial_text is None:
            # From full MASK, framed by [CLS] ... [SEP].
            x_T[:, 0] = self.tokenizer.cls_token_id
            x_T[:, -1] = self.tokenizer.sep_token_id
            prompt_len = 0
        else:
            # From given prompt; cap it at max_seq_len so it fits into x_T.
            tokens = self.tokenizer(initial_text, truncation=True,
                                    max_length=self.max_seq_len, return_tensors='pt')
            prompt_ids = tokens['input_ids'].to(self.device)
            prompt_len = prompt_ids.shape[1]
            x_T[:, :prompt_len] = prompt_ids

        # Tell ``denoise`` which leading positions belong to the prompt.
        self._prompt_len = prompt_len

        # Denoising steps, from num_steps - 1 down to 0.
        x_t = x_T
        for t in tqdm(range(self.num_steps - 1, -1, -1), desc="Sampling"):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            with torch.no_grad():
                x_t = self.denoise(x_t, t_batch)

        generated_texts = []
        for i in range(batch_size):
            tokens = x_t[i].cpu().numpy()
            # Look for the end-of-text [SEP] only AFTER the prompt, so this
            # works with or without a prompt.
            generated = self._trim_generated(tokens, prompt_len)
            # Decode to text.
            text = self.tokenizer.decode(generated, skip_special_tokens=True)
            generated_texts.append(text)

        return generated_texts

    def forward(self, x, t=None, eps=1e-3):
        """
        Training objective for one randomly drawn noise level per sequence.

        Args:
            x: Input sequence [batch_size, seq_len]
            t: Noise level per sequence [batch_size]; drawn uniformly from
                [0, 1) when None.
            eps: Lower bound of the masking probability, forwarded to
                ``add_noise``.

        Returns:
            Scalar loss: cross-entropy on the masked positions, weighted by
            1 / p_mask, divided by seq_len and averaged over the batch.
        """
        batch_size, seq_len = x.shape

        if t is None:
            # Draw on the input's device so the op never mixes devices.
            t = torch.rand(batch_size, device=x.device)

        # Add noise to the input sequence.
        x_t, mask, p_mask = self.add_noise(x, t, eps)

        logits = self._predict_logits(x_t)  # [batch_size, seq_len, vocab_size]

        token_loss = F.cross_entropy(
            logits.reshape(-1, self.vocab_size), x.reshape(-1), reduction='none'
        ).view(batch_size, seq_len)
        # Only masked tokens contribute, re-weighted by their masking probability.
        masked_loss = (token_loss * mask) / p_mask  # [batch_size, seq_len]
        loss = masked_loss.sum(dim=1) / seq_len  # [batch_size]

        return loss.mean()

In [3]:
class TextDataset(Dataset):
    """
    Dataset that tokenizes raw strings on access.

    Each item is a dict with a single key ``'input_ids'`` holding the token
    ids of one text as a 1-D tensor, padded/truncated to ``max_length`` by
    the tokenizer.
    """

    def __init__(self, texts, tokenizer, max_length=128):
        """
        Args:
            texts: Sequence of strings.
            tokenizer: Callable tokenizer accepting ``padding``, ``max_length``,
                ``truncation`` and ``return_tensors`` keyword arguments and
                returning a mapping with an ``'input_ids'`` tensor.
            max_length: Target length passed to the tokenizer.
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at position ``idx`` and return its ids."""
        encoding = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        # Drop only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
        }

In [None]:
# Train settings
batch_size = 10  # Batch size for training
epochs = 10      # Number of training epochs
lr = 4e-5        # Learning rate
max_seq_len = 32 # Maximum sequence length
num_steps = 200  # Number of diffusion steps
device = "cuda" if torch.cuda.is_available() else "cpu"  # Device
save_dir = './saved_models'    # Directory to save models

os.makedirs(save_dir, exist_ok=True)

# Tokenizer used for the dataset and for sizing the prediction head.
# NOTE(review): MDM loads its own tokenizer internally from a different
# location; confirm both resolve to the same vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = MDM(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=max_seq_len,
    num_steps=num_steps,
    device=device
).to(device)

# Examples.
sample_texts = [
    'Please introduce yourself. I am a nano code assistant to help you understand the model easily.',
    'Please introduce yourself. I am a  model assistant to help you understand the model easily.',
    'Please introduce yourself. I am an AI assistant designed to help you with a variety of tasks.',
    'Please introduce yourself. I am here to provide information and generate text.',
    'Please introduce yourself. I am a conversational AI created to be helpful and harmless.',
    'Please introduce yourself. My purpose is to assist users by answering questions and completing requests.',
    'Please introduce yourself. I am a model assistant, and I can help you write, summarize, and brainstorm.',
    'Please introduce yourself. I am an AI-powered collaborator, ready to assist with your projects.',
    'Please introduce yourself. I am an artificial intelligence assistant, here to help you understand the model easily.',
    'Please introduce yourself. I exist as a program to process information and respond to your queries.',
]

# The ten examples are repeated 100x so one epoch has 100 batches of 10.
dataset = TextDataset(sample_texts*100, tokenizer, max_length=max_seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer = optim.AdamW(model.parameters(), lr=lr)

# train
print(f"Training on {device}...")
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)

        optimizer.zero_grad()
        # MDM.forward draws a random noise level per sequence and returns the loss.
        loss = model(input_ids)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    # The progress bar only shows the last batch's loss; report the epoch
    # average as well, which is a far less noisy training signal.
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - average loss: {avg_epoch_loss:.4f}")

# save model
model_path = os.path.join(save_dir, "diffusion_lm.pt")
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Training on cuda...


Epoch 1/10: 100%|██████████| 100/100 [00:06<00:00, 15.63it/s, loss=3.9648]
Epoch 2/10: 100%|██████████| 100/100 [00:05<00:00, 16.85it/s, loss=1.7504]
Epoch 3/10: 100%|██████████| 100/100 [00:05<00:00, 17.18it/s, loss=0.7683]
Epoch 4/10: 100%|██████████| 100/100 [00:05<00:00, 16.84it/s, loss=0.6778]
Epoch 5/10: 100%|██████████| 100/100 [00:06<00:00, 16.61it/s, loss=0.4703]
Epoch 6/10: 100%|██████████| 100/100 [00:06<00:00, 16.55it/s, loss=0.4295]
Epoch 7/10: 100%|██████████| 100/100 [00:05<00:00, 16.72it/s, loss=0.0967]
Epoch 8/10: 100%|██████████| 100/100 [00:05<00:00, 17.10it/s, loss=0.1296]
Epoch 9/10: 100%|██████████| 100/100 [00:05<00:00, 17.42it/s, loss=0.0268]
Epoch 10/10: 100%|██████████| 100/100 [00:05<00:00, 17.24it/s, loss=0.0236]


Model saved to ./saved_models/diffusion_lm.pt


In [13]:
# Test: generate text from a prompt with the trained model.

# --- Settings (must match the values used for training) ---
model_path = './saved_models/diffusion_lm.pt'  # Checkpoint written by the training cell
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_seq_len = 32
num_steps = 200
prompt = "introduce yourself."
num_samples = 1

# --- Tokenizer (only needed here to size the prediction head) ---
tokenizer = BertTokenizer.from_pretrained('/mnt/disk/ModelHub/bert-base-uncased')

# --- Model ---
model = MDM(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=max_seq_len,
    num_steps=num_steps,
    device=device,
).to(device)

# Restore the trained weights when a checkpoint is available; otherwise
# continue with the freshly initialized model so the cell still runs.
if not os.path.exists(model_path):
    print(f"Warning: Model file {model_path} not found. Using untrained model for showing.")
else:
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_path}")

# --- Generation ---
model.eval()  # inference mode for dropout etc.

with torch.no_grad():
    generated_texts = model.sample(batch_size=num_samples, initial_text=prompt)

# Show the prompt followed by each generated continuation.
for text in generated_texts:
    print(f"{prompt}\n{text}")

Loaded model from ./saved_models/diffusion_lm.pt


Sampling: 100%|██████████| 200/200 [00:01<00:00, 171.25it/s]

introduce yourself.
i am a model assistant, and i can help you write, summarize, and brainstorm.



