In [1]:
# preprocess
from transformers import PLBartTokenizer 

tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-multi_task-python", language_codes="multi", src_lang="python", tgt_lang ="python")



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
# model
import torch
from transformers import PLBartForConditionalGeneration
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
    have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
    prev_output_tokens = input_ids.clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens

    return prev_output_tokens

class InRepPlusGAN(torch.nn.Module):
    def __init__(self, style_dim):
        super(InRepPlusGAN, self).__init__()
        self.model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-multi_task-python", )
        self.encoder = self.model.get_encoder()
        self.decoder = self.model.get_decoder()
        self.config = self.model.config
        self.modifier = torch.nn.Linear(self.config.d_model + style_dim, self.config.d_model)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        style_encoding: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds=None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # different to other models, PLBart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
        
        # encoder E, with no grad
        if encoder_outputs is None:
            with torch.no_grad():
                encoder_outputs = self.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        
        # need an additional tunable encoder M
        
        batch_size = encoder_outputs[0].shape[0]
        seq_len = encoder_outputs[0].shape[1]
        
        style_encoding = style_encoding.unsqueeze(1).expand(-1, seq_len, -1)
#         for _ in range(1, seq_len):
#             style_encoding = torch.cat((style_encoding, style_encoding.unsqueeze(1)), dim=1)
            
#         print(encoder_outputs[0].shape, style_encoding.shape)
        combined_encoding = torch.cat((encoder_outputs[0], style_encoding), dim=-1)
        modifier_outputs = []
        for i in range(seq_len):
            modifier_output = self.modifier(combined_encoding[:, i, :])
            modifier_outputs += [modifier_output.unsqueeze(1)]
        modifier_outputs = torch.cat(modifier_outputs, dim=1)
        
        # decoder G, with no grad
        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        with torch.no_grad():
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=modifier_outputs,
                encoder_attention_mask=attention_mask,
                head_mask=decoder_head_mask,
                cross_attn_head_mask=cross_attn_head_mask,
                past_key_values=past_key_values,
                inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        outputs = None
        if not return_dict:
            outputs = decoder_outputs + encoder_outputs
        else:
            outputs = Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        
        with torch.no_grad():
            lm_logits = self.model.lm_head(outputs[0]) + self.model.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        
        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        ), modifier_outputs
    def get_encoding(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        style_encoding: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds=None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # encoder E, with no grad
        if encoder_outputs is None:
            with torch.no_grad():
                encoder_outputs = self.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        return encoder_outputs
    # def forward(self, **inputs):
    #     outputs = self.model(**inputs)
    #     return outputs

In [3]:
# we can start with 1 layer
# use embedding layers
class Discriminator(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_size, style_dim):
        super(Discriminator, self).__init__()
        
        self.output_size = output_size
        self.style_dim = style_dim
        
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.RNN(embedding_dim, output_size, 1, batch_first=True)
        self.linear = torch.nn.Linear(output_size, style_dim)
        
        self.softmax = torch.nn.Softmax(dim=1)
        
        # self.l2 = torch.nn.Linear(self.config.d_model + style_dim, self.config.d_model)
        # self.l3 = torch.nn.Linear(self.config.d_model + style_dim, self.config.d_model)

    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        embedded_x = self.embedding(x)

        # RNN Layer
        init_hidden = torch.zeros(1, batch_size, self.output_size)
        output, hidden = self.rnn(embedded_x, init_hidden)

        # Linear Layer
        hidden = hidden.squeeze(0)
        output = self.linear(hidden)
        logits = self.softmax(output)
        return logits


In [4]:
from datasets import load_from_disk

train_dataset = load_from_disk('datasets/plbart_train.hf')
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
print ((train_dataset[0]["input_ids"].shape))

torch.Size([1024])


In [5]:
from transformers import default_data_collator

In [6]:

from torch.utils.data import DataLoader
BATCH_SIZE = 2
def get_data_loader(split="train"):
    # dataset_map = {
    #     "train": train_set,
    #     "dev": dev_set,
    #     "test": test_set,
    # }

    # tokenized_set = prepare_slot_dataset(
    #     dataset_map[split], pretrained_tokenizer, split
    # )
    data_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, collate_fn=default_data_collator
    )
    return data_loader

In [7]:
data_loader = get_data_loader()

In [8]:
from collections import Counter

UNK_TOKEN = "<unk>"

class Vocab:
    """
    Vocabulary Class
    Store the index mapping for the tokens and recognize the unknown token and then return it
    """

    def __init__(self, tokens, base_map={}, max_size=None, least_freq=0):
        self.token2idx = base_map
        # count the word/token/tags frequency
        self.freq = Counter(
            [token for sequence in tokens for token in sequence]
        )

        vocab_size = 0
        # store the token start from higher frequency
        for word, count in sorted(
            self.freq.items(), key=lambda item: item[1], reverse=True
        ):
            if count < least_freq:
                break
            # if vocab size is larger than max size, stop inserting words into vocab
            if max_size is not None and vocab_size > max_size:
                break
            self.insert(word)
            vocab_size += 1

        self.idx2token = reverse_map(self.token2idx)

    def insert(self, token):
        if token in self.token2idx.keys():
            return
        self.token2idx[token] = len(self.token2idx)

    def lookup_index(self, word):
        if word not in self.token2idx.keys():
            word = UNK_TOKEN
        return self.token2idx[word]

    def lookup_token(self, idx):
        return self.idx2token[idx]

    def __len__(self):
        return len(self.token2idx)

    def __repr__(self):
        return str(self.token2idx)


def reverse_map(_map):
    reversed_map = {}
    for key, val in _map.items():
        reversed_map[val] = key
    return reversed_map

In [9]:
cluster_labels = train_dataset["labels"].detach().numpy()
cluster_labels_no_outliers = cluster_labels[cluster_labels != -1]
cluster_vocab = Vocab([cluster_labels])

In [10]:
STYLE_DIM = len(cluster_vocab) - 1

In [12]:
generator = InRepPlusGAN(style_dim=STYLE_DIM)
discriminator = Discriminator(vocab_size=generator.config.vocab_size, embedding_dim=512, output_size=128, style_dim=STYLE_DIM)

In [13]:
# Initialize BCELoss function
criterion = torch.nn.BCELoss()
# Setup Adam optimizers for both G and D
discriminator_optimizer = torch.optim.Adam(discriminator.parameters())
generator_optimizer = torch.optim.Adam(generator.parameters())

In [14]:
def label_tensor_to_one_hot(label_tensor):
    one_hot_tensor_list = []
    for idx in range(label_tensor.shape[0]):
        cluster_idx = label_tensor[idx].item()
        one_hot_tensor = cluster_to_one_hot_tensor(cluster_idx, STYLE_DIM)
        one_hot_tensor_list.append(one_hot_tensor.unsqueeze(0))
        
    return torch.cat(one_hot_tensor_list, dim=0)

def cluster_to_one_hot_tensor(cluster_idx, style_dim):
    style_tensor = torch.zeros(style_dim)
    if cluster_idx < 0:
        return style_tensor
    style_tensor[cluster_idx] = 1
    return style_tensor

In [18]:
def training_step(input_batch, discriminator, generator, criterion, discriminator_optimizer, generator_optimizer):
    # update D network
    discriminator.zero_grad()
    
    # All-real training
    # Format real batch
    real_data = input_batch["input_ids"]
    real_style = label_tensor_to_one_hot(input_batch["labels"])
    
    # Forward pass real batch through D
    output = discriminator(real_data)

    # Calculate loss on all-real batch
    discriminator_real_loss = criterion(output, real_style)
    
    # Calculate gradients for D in backward pass
    discriminator_real_loss.backward()

    # All-fake training
    
    # sampling the target styles for a whole patch
    sampled_style_indexes = random.sample(list(cluster_labels_no_outliers), BATCH_SIZE)
    style_encoding = label_tensor_to_one_hot(torch.Tensor(sampled_style_indexes).long())
    
    # Forward pass to generate the styled output
    generator_output, modifier_output = generator(
        input_ids=input_batch["input_ids"], 
        attention_mask=input_batch["attention_mask"], 
        style_encoding=style_encoding
    )
    generated_logits = generator_output.logits
    
    # use Gumbel Softmax to decode the output
    generated_tokens = torch.nn.functional.gumbel_softmax(generated_logits, hard=True, dim=-1)
    
    # produce the fake data
    fake_data = generated_tokens.argmax(-1)
    # print(tokenizer.batch_decode(fake_data))

    # Classify all fake batch with D
    output = discriminator(fake_data)

    # Calculate D's loss on the all-fake batch
    discriminator_fake_loss = criterion(output, style_encoding)

    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    discriminator_fake_loss.backward()

    # Compute error of D as sum over the fake and the real batches
    discriminator_loss = discriminator_real_loss + discriminator_fake_loss

    # Update D
    discriminator_optimizer.step()
    
    # update M network

    generator.zero_grad()

    # Since we just updated D, perform another forward pass of all-fake batch through D
    output = discriminator(fake_data)
    
    # Calculate G's loss based on this output
    generator_class_loss = criterion(output, style_encoding)
    
    # Calculate gradients for G
    generator_class_loss.backward()
    
    # TODO: add the modifier loss
    generator_loss = generator_class_loss
    
    # Update G
    generator_optimizer.step()
    
    return generator_loss, discriminator_loss

In [16]:
import random
from tqdm.auto import tqdm
num_epochs = 1
for epoch in range(num_epochs):
    epoch_loss_g = 0
    epoch_loss_d = 0
    for batch in tqdm(data_loader):
        loss_g, loss_d = training_step(batch, discriminator, generator, criterion, discriminator_optimizer, generator_optimizer)
        epoch_loss_g += loss_g
        epoch_loss_d += loss_d
    print(epoch_loss_g, epoch_loss_d)

  0%|          | 0/86020 [00:00<?, ?it/s]

tensor(0.2651, grad_fn=<AddBackward0>)
tensor(0.2059, grad_fn=<AddBackward0>)
tensor(0.2007, grad_fn=<AddBackward0>)
tensor(0.2644, grad_fn=<AddBackward0>)
tensor(0.1914, grad_fn=<AddBackward0>)
tensor(0.2068, grad_fn=<AddBackward0>)
tensor(0.3454, grad_fn=<AddBackward0>)
tensor(0.2614, grad_fn=<AddBackward0>)
tensor(0.2736, grad_fn=<AddBackward0>)
tensor(0.1976, grad_fn=<AddBackward0>)


KeyboardInterrupt: 

In [242]:
generator.get_encoding(fake_data)

BaseModelOutput(last_hidden_state=tensor([[[-0.1185, -0.1352,  0.2604,  ..., -0.0646,  0.0714,  0.2654],
         [-0.2654, -0.1587, -0.0490,  ..., -0.2254, -0.0637,  0.2084],
         [-0.3837, -0.0131, -0.3011,  ..., -1.5033,  0.2889,  0.5032],
         ...,
         [-1.4976,  0.1250,  0.2543,  ..., -0.4279,  1.3188,  0.1330],
         [-0.9512,  0.1536, -0.5133,  ...,  0.3540,  1.4173,  0.7646],
         [-0.2426,  0.1208,  0.1431,  ...,  0.4847,  1.1334,  0.8726]],

        [[ 0.6149, -0.4667,  0.7789,  ...,  0.1539,  0.7153,  1.2592],
         [-0.6662, -0.4611,  0.2164,  ..., -0.1547,  0.0227,  0.2276],
         [-1.0176, -1.1486, -0.1792,  ...,  0.7986,  0.5448,  0.8010],
         ...,
         [-1.0494, -0.4101,  0.7645,  ...,  0.2734, -0.1472, -0.2740],
         [ 0.5117,  0.8126,  0.2128,  ...,  0.5699,  0.3619, -0.1144],
         [ 0.0317, -0.0154, -0.0085,  ..., -0.0026, -0.0344, -0.0209]]]), hidden_states=None, attentions=None)

In [19]:
import time

In [23]:
round(time.time())

1656725576

In [21]:
time.time()

1656725553.8966103