In [18]:
# preprocess
from transformers import PLBartTokenizer 

# Load the tokenizer that ships with the multi-task Python PLBart checkpoint,
# using Python as both the source and the target language.
tokenizer = PLBartTokenizer.from_pretrained(
    "uclanlp/plbart-multi_task-python",
    src_lang="python",
    tgt_lang="python",
)



Downloading:   0%|          | 0.00/963k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/811 [00:00<?, ?B/s]

In [267]:
# model
import torch
from transformers import PLBartForConditionalGeneration
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)

# Width of the style vector that is concatenated to every encoder state.
# InRepPlusGAN.__init__ reads this module-level name when it sizes its modifier layer.
# NOTE: a later cell rebinds style_dim to 2 before the model is instantiated.
style_dim = 10

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Build decoder inputs from `input_ids`: every token moves one position to the right and each row's last
    non-pad token (the <LID> token) is wrapped around into position 0. MBart-style models have no single
    `decoder_start_token_id`, which is why the start token is taken from the row itself.

    The input tensor is not modified; a shifted copy is returned. Raises `ValueError` if `pad_token_id` is None.
    """
    shifted = input_ids.clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # labels mark ignored positions with -100; treat those positions as padding
    shifted = shifted.masked_fill(shifted == -100, pad_token_id)

    # position of the last non-pad token in every row
    non_pad_per_row = (shifted != pad_token_id).sum(dim=1)
    last_token_position = (non_pad_per_row - 1).unsqueeze(-1)
    start_tokens = shifted.gather(1, last_token_position).squeeze()

    # clone the source slice because it overlaps the destination slice
    shifted[:, 1:] = shifted[:, :-1].clone()
    shifted[:, 0] = start_tokens

    return shifted

class InRepPlusGAN(torch.nn.Module):
    """Pretrained PLBart encoder/decoder with a trainable style-conditioned bridge.

    The encoder output is concatenated with a per-token style encoding and projected back to
    `d_model` by `self.modifier`; the decoder then cross-attends to that projection instead of
    the raw encoder states. The PLBart weights are frozen, so `self.modifier` is the only
    trainable part of the module.
    """

    def __init__(self):
        super().__init__()
        self.model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-multi_task-python")
        self.encoder = self.model.get_encoder()
        self.decoder = self.model.get_decoder()
        self.config = self.model.config

        # Freeze the pretrained weights. Running the decoder under `torch.no_grad()` (as before)
        # also cut the graph between the logits and `self.modifier`, so the modifier could never
        # receive a gradient. Freezing keeps PLBart fixed while letting gradients flow through it.
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

        # `style_dim` is the module-level setting at construction time.
        self.style_dim = style_dim
        self.modifier = torch.nn.Linear(self.config.d_model + self.style_dim, self.config.d_model)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        style_encoding: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds=None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        style_encoding (`torch.FloatTensor` of shape `(batch_size, sequence_length, style_dim)`):
            Style vector for every encoder position. It is concatenated with the encoder's last hidden state
            along the feature axis, so its first two dimensions must match the encoder output. Required.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:
            A `Seq2SeqLMOutput` when `return_dict` is true, otherwise a tuple whose first element is the loss
            (only if `labels` is given) followed by the logits and the remaining decoder/encoder outputs.
        Raises:
            ValueError: if `style_encoding` is not provided.
        """
        if style_encoding is None:
            raise ValueError("style_encoding is required: it is concatenated with the encoder output.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        # different to other models, PLBart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)

        # encoder E, with no grad: nothing trainable sits upstream of it, so no graph is needed here
        if encoder_outputs is None:
            with torch.no_grad():
                encoder_outputs = self.encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    inputs_embeds=inputs_embeds,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # tunable encoder M: project [encoder state ; style] back to d_model.
        # nn.Linear acts on the last dimension, so all positions are handled in one call.
        encoder_hidden = encoder_outputs[0]
        style_encoding = style_encoding.to(device=encoder_hidden.device, dtype=encoder_hidden.dtype)
        combined_encoding = torch.cat((encoder_hidden, style_encoding), dim=2)
        modifier_outputs = self.modifier(combined_encoding)

        # decoder G: its weights are frozen in __init__, but the graph is kept so that
        # gradients can reach the modifier through the cross-attention inputs.
        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=modifier_outputs,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            outputs = decoder_outputs + encoder_outputs
        else:
            outputs = Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        # The LM head is frozen as well; no `no_grad` here for the same reason as the decoder.
        lm_logits = self.model.lm_head(outputs[0]) + self.model.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # Fully qualified: `CrossEntropyLoss` is not imported by name in this file.
            loss_fct = torch.nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
    # def forward(self, **inputs):
    #     outputs = self.model(**inputs)
    #     return outputs

In [268]:
# Tokenize one toy snippet into PyTorch tensors (input_ids and attention_mask); this single
# example is reused as the "real" sample throughout the experiments below.
inputs = tokenizer("def hello_world():", return_tensors="pt")

In [269]:
inputs.input_ids.shape

torch.Size([1, 7])

In [270]:
# Two style classes for this experiment; the sample's style is class 0.
style_dim = 2
# One-hot row vector with a leading batch axis: shape (1, style_dim).
style_tensor = torch.zeros(1, style_dim)
style_tensor[0, 0] = 1
style_tensor, style_tensor.shape

(tensor([[1., 0.]]), torch.Size([1, 2]))

In [271]:
batch_size = inputs.input_ids.shape[0]
seq_len = inputs.input_ids.shape[1]

# Broadcast the style vector to every token position: (batch, style_dim) -> (batch, seq_len, style_dim).
# A single expand replaces the earlier loop that re-concatenated the tensor once per position
# (quadratic copying), and it also covers batches larger than one when style_tensor has a
# leading batch axis of 1. `.contiguous()` materializes an independent tensor, as the loop did.
style_encoding = style_tensor.unsqueeze(1).expand(batch_size, seq_len, -1).contiguous()

In [272]:
style_encoding.shape

torch.Size([1, 7, 2])

In [273]:
# inference
# Instantiate the generator (this loads the pretrained PLBart checkpoint) and run a single
# forward pass on the sample, conditioning on the per-token style encoding.
model = InRepPlusGAN()
output = model(**inputs, style_encoding=style_encoding)


In [274]:
# lm_logits = model.model.lm_head(outputs[0]) + model.model.final_logits_bias
# InRepPlusGAN.forward already applies the LM head and bias, so the vocabulary scores
# can be read directly from the returned output.
logits = output.logits

In [275]:
# Greedy decode: pick the highest-scoring token at every position and map the ids back to text.
tokenizer.batch_decode(logits.argmax(-1))
# ['def world_world ( #python']

['def get (world ( #en_XX']

In [276]:
# we can start with 1 layer
# use embedding layers
class Discriminator(torch.nn.Module):
    """Style classifier over token-id sequences: embedding -> single-layer RNN -> linear -> softmax.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        output_size: hidden size of the RNN.
        style_dim: number of style classes to score.
    """

    def __init__(self, vocab_size, embedding_dim, output_size, style_dim):
        super().__init__()

        self.output_size = output_size
        self.style_dim = style_dim

        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.RNN(embedding_dim, output_size, 1, batch_first=True)
        self.linear = torch.nn.Linear(output_size, style_dim)

        # Normalizes over the class axis of a (batch, style_dim) tensor.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, style_dim) for token ids `x` of shape (batch, seq_len)."""
        embedded_x = self.embedding(x)

        # RNN Layer
        # With no explicit h_0 the RNN starts from zeros on the input's own device and dtype,
        # which avoids the CPU-only tensor that used to be allocated here.
        _, hidden = self.rnn(embedded_x)

        # `hidden` is (num_layers, batch, hidden) even with batch_first=True, so take the last
        # layer. The previous `squeeze(1)` only produced (batch, hidden) when batch == 1; for
        # larger batches the softmax below normalized across the batch axis.
        final_hidden = hidden[-1]

        # Linear Layer
        scores = self.linear(final_hidden)
        probabilities = self.softmax(scores)
        return probabilities


In [277]:
# Style classifier sharing the generator's vocabulary size, so generated token ids index its embedding table.
discriminator = Discriminator(
    vocab_size=model.config.vocab_size,
    embedding_dim=512,
    output_size=128,
    style_dim=style_dim,
)

In [278]:
# Predicted style class for the real sample, and the index of the hot entry in its one-hot style label.
pred = torch.argmax(discriminator(inputs.input_ids), dim=-1)
gold = torch.argmax(style_tensor, dim=-1)

In [279]:
pred, gold

(tensor([0]), tensor([0]))

In [280]:
discriminator

Discriminator(
  (embedding): Embedding(50008, 512)
  (rnn): RNN(512, 128, batch_first=True)
  (linear): Linear(in_features=128, out_features=2, bias=True)
  (softmax): Softmax(dim=1)
)

In [281]:
inputs

{'input_ids': tensor([[  134,  4498, 33456, 11393,  4071,     2, 50002]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [285]:
# One-hot target for real samples (class 0). It has no batch axis here; the training loop
# adds one with unsqueeze(0) before passing it to the loss.
real_style_tensor = torch.zeros(style_dim)
real_style_tensor[0] = 1

In [290]:
# Target for generated samples: an all-zero row of shape (1, style_dim), i.e. "no style class".
fake_label_tensor = torch.zeros(1, style_dim)

In [291]:
# Initialize BCELoss function
# (the training loop applies it to the discriminator's softmax outputs against the style targets)
criterion = torch.nn.BCELoss()
# Setup Adam optimizers for both G and D
# Both use Adam's default hyperparameters; the generator optimizer is handed every parameter of `model`.
discriminator_optimizer = torch.optim.Adam(discriminator.parameters())
generator_optimizer = torch.optim.Adam(model.parameters())

In [292]:
from tqdm.auto import tqdm
num_epochs = 100
for epoch in tqdm(range(num_epochs)):
    # for data in enumerate(dataloader, 0):
    ############################
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    ###########################
    ## Train with all-real batch
    discriminator.zero_grad()
    # Format batch
    real_data = inputs.input_ids

    # Forward pass real batch through D
    output = discriminator(real_data)

    # Calculate loss on all-real batch
    discriminator_real_loss = criterion(output, real_style_tensor.unsqueeze(0))

    # Calculate gradients for D in backward pass
    discriminator_real_loss.backward()
    # D_x = output.mean().item()

    ## Train with all-fake batch

    generator_output = model(**inputs, style_encoding=style_encoding)
    generated_logits = generator_output.logits
    # hard=True yields one-hot samples whose backward pass uses the soft relaxation
    # (straight-through), so the sample itself stays differentiable w.r.t. the logits.
    generated_tokens = torch.nn.functional.gumbel_softmax(generated_logits, hard=True, dim=-1)
    # Integer ids carry no autograd history, so the D update below cannot leak into G.
    fake_data = generated_tokens.argmax(-1)
    print(tokenizer.batch_decode(fake_data))
    # Classify all fake batch with D
    output = discriminator(fake_data)

    # Calculate D's loss on the all-fake batch
    discriminator_fake_loss = criterion(output, fake_label_tensor)

    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    discriminator_fake_loss.backward()

    # D_G_z1 = output.mean().item()

    # Compute error of D as sum over the fake and the real batches
    discriminator_loss = discriminator_real_loss + discriminator_fake_loss

    # Update D
    discriminator_optimizer.step()

    ############################
    # (2) Update G network: maximize log(D(G(z)))
    ###########################
    model.zero_grad()

    # Since we just updated D, perform another forward pass of all-fake batch through D.
    # Feeding `fake_data` (argmax ids) here would cut the graph, leaving every generator
    # gradient as None and turning generator_optimizer.step() into a no-op. Instead, push the
    # one-hot sample through D's embedding matrix: one_hot @ weight selects the same embedding
    # rows as an id lookup but keeps the path back to the generator's logits differentiable.
    sampled_embeddings = generated_tokens @ discriminator.embedding.weight
    _, sampled_hidden = discriminator.rnn(sampled_embeddings)
    output = discriminator.softmax(discriminator.linear(sampled_hidden[-1]))
    # Calculate G's loss based on this output
    generator_class_loss = criterion(output, style_tensor)
    # Calculate gradients for G
    # (this also accumulates gradients in D, which are cleared by discriminator.zero_grad()
    # at the top of the next iteration)
    generator_class_loss.backward()
    # D_G_z2 = output.mean().item()
    # Update G
    generator_optimizer.step()

  0%|          | 0/100 [00:00<?, ?it/s]

['def export (try—ß #!!!']
['def matmul_ Depth ( #en_XX']
['ÁµÄ parse (full_ #PAY']
['def __sp (·æ∂ systemId']
['def·ºè_format ( fallen_XX']
['EncoderCycle_c_Íò´en_XX']
['defSH (spec discover #en_XX']
['def threw (resource_ Functionpython']
['def s (fileecessary printRaster']
['department on (context ( #en_XX']
['def get ( virt ( storepython']
['def do (specÍß¥Convert\u0cd0']
['def getFont.</file ( aryen_XX']
['def refresh ( iterating ( ifen_XX']
['def _ﬁØ si (Â¨æS']
['def·∫ñ (tr (HttpResponse Page']
['def init_iterable ( #‚°£']
['defAllQuery ( PAGE ( #‚Ñû']
['berry do_state ( superlocals']
['def get_params ( fpython']
['def el (Á®± ( )en_XX']
['declar convert_bin_ ifen_XX']
['the export_class (stateen_XX']
['GERPrincipal ( msbuild ( wsdlpython']
['Ë≠∑ _ (‡¨ï (entryÈúç']
['def make (table_AND peek']
['def·É∑ ( parentNode ( catalog.']
['def _›ßall_ (en_XX']
['def render_stream ( #en_XX']
["def'?Âê™array ( Listpython"]
['def.1 (tor2 assert‡≠Ç']
['def cs studentbus ( #method']
['def metho

In [None]:
# Reference setup in the style of a DCGAN training script.
# NOTE(review): nn, optim, nz, device, lr, beta1, netD and netG are not defined in the cells
# visible here — this cell cannot run until they are provided.

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
# The assignment was left incomplete (a syntax error); with BCELoss and real_label = 1.
# the complementary target for generated samples is 0.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# Training Loop
# Reference GAN loop: each batch first updates the discriminator on a real and a generated
# batch, then updates the generator against the freshly updated discriminator.
# NOTE(review): dataloader, netD, netG, optimizerD, optimizerG, device, nz, real_label,
# fake_label, fixed_noise and vutils are not defined in the cells visible above this one's
# setup cell — confirm they exist before running.

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        # Reuse the label tensor in place, now filled with the fake target
        label.fill_(fake_label)
        # Classify all fake batch with D
        # (detach so that D's loss does not backpropagate into G)
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # (no detach this time: the gradient must reach G)
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations and on the very last batch of the last epoch)
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1