In [None]:
!pip install transformers
!pip install pytorch_lightning
!pip install sentencepiece

In [None]:
!nvidia-smi

Sun Mar 13 18:15:54 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    27W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
from torch.utils.data import Dataset , DataLoader
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from itertools import chain
import ast
from transformers import T5TokenizerFast
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
torch.version.__version__

'1.10.0+cu111'

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
root = "drive/MyDrive/lign167_data"

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


# 1. Prepare Data

In [None]:
raw_data = pd.read_csv(f"{root}/sampled_999986.csv", converters={1:ast.literal_eval})
raw_data = raw_data[["sents"]]
raw_data = raw_data.sample(n=500000)

In [None]:
print(raw_data["sents"].apply(lambda x: len(x) - 1).sum())
raw_data["cumlen"] = raw_data["sents"].apply(lambda x: len(x) - 1).cumsum() - 1
raw_data["len"] = raw_data["sents"].apply(lambda x: len(x) - 1)
raw_data = raw_data.set_index("cumlen")

pd.options.display.max_colwidth = 150

1355200


In [None]:
# initialize tokenizer for Dataset building
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

In [None]:
sent_length = 32
class AmazonDataset(Dataset):
    def __init__(self,data):
        self.data = data
        self.len = raw_data["sents"].apply(lambda x: len(x) - 1).sum()

    def __len__(self):
        return self.len

    def to_token(self,sentence):
        return tokenizer.encode(sentence, max_length=sent_length, truncation=True, padding="max_length", return_tensors="pt")[0]
    
    def get_pair(self, idx):
      iidx = idx
      while iidx not in raw_data.index:
        iidx += 1
      line = raw_data["sents"].loc[iidx]
      base = idx - iidx - 2
      return (line[base], line[base + 1])

    def __getitem__(self,index):
        context, input = self.get_pair(index)
        return self.to_token(context), self.to_token(input)

In [None]:
AmazonDataset(raw_data)[0]

(tensor([ 100,  556,   19, 2033,   91,   13,  833,   11,    8, 4818,  163, 4951,
           95,   12, 1758,    3,    4,  345,    5,    1,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]),
 tensor([  94,  930, 8957,   16, 8217,    6,   68,  405,   59,  161,   16, 6687,
           18, 2360, 1758,  489,   41, 7965,   59, 5285, 3538,   18, 2360, 1758,
         7973,   68,    8, 4818,  405,   59,  570,    1]))

In [None]:
batch_size = 64

class AmazonDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        train_dataset, val_dataset = train_test_split(raw_data, test_size=0.01)
        self.train = AmazonDataset(train_dataset)
        self.test = AmazonDataset(val_dataset)
        self.val = AmazonDataset(val_dataset)

    def train_dataloader(self):
        return DataLoader(self.train , batch_size = batch_size , shuffle = True, num_workers=4)
    def test_dataloader(self):
        return DataLoader(self.test , batch_size = batch_size , shuffle = False, num_workers=4)
    def val_dataloader(self):
        return DataLoader(self.val , batch_size = batch_size , shuffle = False, num_workers=4)

# Model Definition

In [None]:
from transformers.models.t5.modeling_t5 import T5Stack, T5PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,)
from transformers.models.t5.configuration_t5 import T5Config
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
import warnings
import copy

__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""

In [None]:
lambda_factor = 1
class T5ForConditionalGenerationWithExtractor(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        extractor_config = copy.deepcopy(config)
        extractor_config.is_decoder = False
        extractor_config.use_cache = False
        extractor_config.is_encoder_decoder = False
        self.extractor = T5Stack(extractor_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.extractor.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.extractor.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.extractor = self.extractor.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.extractor.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder
    
    def get_extractor(self):
        return self.extractor

    def get_decoder(self):
        return self.decoder

    def get_extractor_output(self,
        input_ids=None,
        use_cache_context_ids=None, # use cache is simply to a trick to use the generator mixin
        use_cache_target_examplars_ids=None,
        use_cache_origin_examplars_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        extractor_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        context_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,):
      extractor_hidden = None
      if use_cache_context_ids is None:
        target_styles = ()
        for target_ids in use_cache_target_examplars_ids:
          extractor_hidden = self.extractor(
              input_ids=target_ids,
              attention_mask=attention_mask,
              inputs_embeds=context_embeds,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )[0]
          target_styles += (extractor_hidden,)

        original_styles = ()
        for origin_ids in use_cache_origin_examplars_ids:
          extractor_hidden = self.extractor(
              input_ids=origin_ids,
              attention_mask=attention_mask,
              inputs_embeds=context_embeds,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )[0]
          original_styles += (extractor_hidden,)
          
        input_style = self.extractor(
              input_ids=input_ids,
              attention_mask=attention_mask,
              inputs_embeds=context_embeds,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )[0]
        extractor_hidden = lambda_factor * (torch.mean(torch.vstack(target_styles), 0) - (torch.mean(torch.vstack(original_styles), 0))) + input_style
      
      else:
        if extractor_outputs is None:
            extractor_outputs = self.extractor(
                input_ids=use_cache_context_ids,
                attention_mask=attention_mask,
                inputs_embeds=context_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            extractor_outputs = BaseModelOutput(
                last_hidden_state=extractor_outputs[0],
                hidden_states=extractor_outputs[1] if len(extractor_outputs) > 1 else None,
                attentions=extractor_outputs[2] if len(extractor_outputs) > 2 else None,)
        extractor_hidden = extractor_outputs[0]
      return extractor_hidden
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        use_cache_extractor_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        context_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        Returns:
        Examples:
        ```python
        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        >>> ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0] + use_cache_extractor_outputs

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache_extractor_outputs=None,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            # "input_ids": input_ids,
            # "use_cache_context_ids": use_cache_context_ids,
            # "use_cache_target_examplars_ids": use_cache_target_examplars_ids,
            # "use_cache_origin_examplars_ids": use_cache_origin_examplars_ids,
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "use_cache_extractor_outputs": use_cache_extractor_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            warnings.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


In [None]:
model = T5ForConditionalGenerationWithExtractor.from_pretrained("t5-base")
# model = T5ForConditionalGenerationWithExtractor.from_pretrained("check")

Some weights of T5ForConditionalGenerationWithExtractor were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor.block.4.layer.0.SelfAttention.o.weight', 'extractor.block.5.layer.0.layer_norm.weight', 'extractor.block.8.layer.0.SelfAttention.q.weight', 'extractor.block.1.layer.1.layer_norm.weight', 'extractor.block.1.layer.1.DenseReluDense.wo.weight', 'extractor.block.11.layer.0.layer_norm.weight', 'extractor.block.8.layer.1.layer_norm.weight', 'extractor.block.6.layer.0.SelfAttention.o.weight', 'extractor.block.0.layer.0.layer_norm.weight', 'extractor.block.0.layer.0.SelfAttention.o.weight', 'extractor.block.6.layer.1.DenseReluDense.wi.weight', 'extractor.block.3.layer.1.DenseReluDense.wo.weight', 'extractor.block.4.layer.0.layer_norm.weight', 'extractor.block.1.layer.0.SelfAttention.k.weight', 'extractor.block.3.layer.0.SelfAttention.o.weight', 'extractor.block.7.layer.0.SelfAttention.k.weight', 'extractor.block.7.layer.0.SelfAttention.q.weight

In [None]:
model.extractor = copy.deepcopy(model.encoder)
model.extractor.is_extractor = True

# Utils

In [None]:
def peek_weights():
  for i, k in model.named_parameters():
    if "block.2.layer.0.SelfAttention.k.weight" in i:
      print(i)
      print(k)

In [None]:
def tokenize(input):
  return tokenizer(input, max_length=sent_length, truncation=True, padding="max_length", return_tensors="pt").input_ids.cuda()

In [83]:
def peek_output(input, context):
  print("input:", input)
  print("context:", context)
  input_ids = tokenize(input)
  context_ids = tokenize(context)
  extractor_output = model.net.get_extractor_output(use_cache_context_ids=context_ids)
  # print(extractor_output)
  outputs = model.net.generate(input_ids=input_ids, use_cache_extractor_outputs=extractor_output, no_repeat_ngram_size=2)
  print(outputs)
  return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [85]:
def peek_transfer_output(input, target_examplars, origin_examplars):
  targets = ()
  for sent in target_examplars:
    targets += (tokenize(sent),)
  origins = ()
  for sent in origin_examplars:
    origins += (tokenize(sent),)
  input_ids = tokenize(input)
  extractor_output = model.net.get_extractor_output(input_ids=input_ids, use_cache_origin_examplars_ids=origins, use_cache_target_examplars_ids=targets)
  outputs = model.net.generate(input_ids=input_ids, use_cache_extractor_outputs=extractor_output, no_repeat_ngram_size=2)
  return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
np.random.choice([False, True])

False

# Module and Training

In [None]:
def drop_noise(sent, drop_rate=0.2):
  for i in range(int(((sent > 1).sum() * drop_rate))):
    randIdx = np.random.choice(np.where((sent > 1).cpu())[0])
    sent = torch.concat((sent[:randIdx], sent[randIdx + 1:]))
  return sent


special_tokens_set = set(tokenizer.all_special_ids)

def rand_token():
  t = np.random.randint(tokenizer.vocab_size)
  if t in special_tokens_set:
    return rand_token()
  return t


def add_noise(sent, drop_rate=0.4):
  for i in range(int(((sent > 1).sum() * drop_rate))):
    randIdx = np.random.choice(np.where((sent > 1).cpu())[0])
    sent = torch.concat((sent[:randIdx], torch.tensor([rand_token()]).cuda(), sent[randIdx:]))
  return sent

def pad_sent(sent, target=sent_length):
  if sent.shape[0] > target:
    return sent[:target]
  return torch.concat((sent, torch.zeros(target - sent.shape[0], dtype=torch.long).cuda()))

# def drop_noise_(sent, drop_rate=0.4):
#   for i in range(int(sent.shape[0] * drop_rate)):
#     randIdx = np.random.randint(sent.shape[0])
#     sent = torch.concat((sent[:randIdx], sent[randIdx + 1:]))
#   return sent

def apply_noise(sents):
  res = ()
  for i, sent in enumerate(sents):
    sent = drop_noise(sent)
    sent = add_noise(sent)
    sent = pad_sent(sent)
    res += (sent,)
  return torch.vstack(res)

In [None]:
class TextSettrModel(LightningModule):
    def __init__(self):
      super().__init__()
      self.net = T5ForConditionalGenerationWithExtractor.from_pretrained("t5-base")
      self.net.extractor = copy.deepcopy(self.net.encoder)
    
    def training_step(self, batch, batch_idx):
      context_ids, input_ids = batch[0], batch[1]
      noisy_input_ids = apply_noise(input_ids)
      if np.random.choice([False, True]):
        # Noisy back translation
        noisy_input_ids = self.net.generate(input_ids=noisy_input_ids, use_cache_extractor_outputs=0, do_sample=True, max_length=sent_length, min_length=sent_length)
      extractor_output = self.net.get_extractor_output(use_cache_context_ids=context_ids)
      return self.net(input_ids=noisy_input_ids, labels = input_ids, use_cache_extractor_outputs=extractor_output).loss

    def validation_step(self, batch, batch_idx):
      context_ids, input_ids = batch[0], batch[1]
      noisy_input_ids = apply_noise(input_ids)
      noisy_input_ids = self.net.generate(input_ids=noisy_input_ids, use_cache_extractor_outputs=0, do_sample=True, max_length=sent_length, min_length=sent_length)
      extractor_output = self.net.get_extractor_output(use_cache_context_ids=context_ids)
      self.log("val_loss", self.net(input_ids=noisy_input_ids, labels = input_ids, use_cache_extractor_outputs=extractor_output).loss)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.net.parameters(), 1e-3)

In [None]:
%load_ext tensorboard

In [23]:
model = TextSettrModel()
module = AmazonDataModule()
# checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=root, filename='{epoch}')
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss")
logger = TensorBoardLogger("logs", name="style_transfer")
trainer = Trainer(max_epochs = 10, gpus=1, default_root_dir=root, val_check_interval=0.25, precision=32, logger=logger, resume_from_checkpoint = f"{root}/10-hour.ckpt)
trainer.fit(model,module)
# model.net.cuda()

Some weights of T5ForConditionalGenerationWithExtractor were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor.block.4.layer.0.SelfAttention.o.weight', 'extractor.block.5.layer.0.layer_norm.weight', 'extractor.block.8.layer.0.SelfAttention.q.weight', 'extractor.block.1.layer.1.layer_norm.weight', 'extractor.block.1.layer.1.DenseReluDense.wo.weight', 'extractor.block.11.layer.0.layer_norm.weight', 'extractor.block.8.layer.1.layer_norm.weight', 'extractor.block.6.layer.0.SelfAttention.o.weight', 'extractor.block.0.layer.0.layer_norm.weight', 'extractor.block.0.layer.0.SelfAttention.o.weight', 'extractor.block.6.layer.1.DenseReluDense.wi.weight', 'extractor.block.3.layer.1.DenseReluDense.wo.weight', 'extractor.block.4.layer.0.layer_norm.weight', 'extractor.block.1.layer.0.SelfAttention.k.weight', 'extractor.block.3.layer.0.SelfAttention.o.weight', 'extractor.block.7.layer.0.SelfAttention.k.weight', 'extractor.block.7.layer.0.SelfAttention.q.weight

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [24]:
# model.net
trainer.save_checkpoint(f"{root}/10-hour.ckpt")

In [None]:
%tensorboard --logdir logs/style_transfer

# Results

In [183]:
peek_output("This was a thought-provoking read", "")

input: This was a thought-provoking read
context: 
tensor([[    0,   100,    47,     3,     9,   816,    18, 28268,   608,     5,
             1]], device='cuda:0')


'This was a thought-provoking read.'

In [182]:
formal_examplars = ["This was a remarkably thought-provoking read.",
                  "It is certainly amongst my favorites."
                  "We humbly request your presence at our gala on the 12th."]
informal_examplars = ["reading this rly makes u think",
                      "Its def one of my favs",
                      "come swing by our bbq next week if ya can make it"]
formal_input = "I hereby commit to never purchase anything from this institution in the future."
lambda_factor = 5
peek_transfer_output(formal_input, informal_examplars, formal_examplars)

'i hereby gonna never purchase anything from this seller in the future'

In [162]:
formal_examplars = ["This was a remarkably thought-provoking read.",
                  "It is certainly amongst my favorites."
                  "We humbly request your presence at our gala on the 12th."]
informal_examplars = ["reading this rly makes u think",
                      "Its def one of my favs",
                      "come swing by our bbq next week if ya can make it"]
formal_input = "I couldn’t figure out what the author was trying to say."
lambda_factor = 3
peek_transfer_output(formal_input, formal_examplars, informal_examplars),peek_transfer_output(formal_input, informal_examplars, formal_examplars)

("I I couldn't figure out what the author was trying to say.",
 'I couldnt figure out what the author was trying to say.')

In [165]:
orig_ex = [          
"No thank you, I'd prefer not to.",
"This game could have been better designed.",
"Do you know why they might have delayed the launch?",
"Sorry, I wasn' certain if you were joking."
]

targ_ex = [
"Hell no, you can't make me do that.",
"This game is such a piece of garbage!",
"Why in god's name would they delay the damn launch? Are you frigging kidding me?"
]

sent_ex = "Please get "
lambda_factor = 3
peek_transfer_output(sent_ex, targ_ex, orig_ex), peek_transfer_output(sent_ex, orig_ex, targ_ex)

('Please get it!', 'Please get it. Please do get this.')

In [166]:
n_ex = 100
neg_ex = pd.read_csv(f"{root}/yelp/neg.csv", sep="\t").get("0")[:n_ex]
pos_ex  =  pd.read_csv(f"{root}/yelp/pos.csv", sep="\t").get("0")[:n_ex]
neg_ex, pos_ex

(0       windows have n't been cleaned in years you can see scum on them .
 1                                                   waitresses are slow .
 2                                        just a mess avoid at all costs !
 3                                                                   bad !
 4     now pizza is beyond awful and wings are down there with its level .
                                      ...                                 
 95    no touching , no going ahead of people , no laughter , no dancing .
 96                                                                _num_ .
 97                                                                _num_ .
 98                                                                _num_ .
 99                                                                _num_ .
 Name: 0, Length: 100, dtype: object,
 0                         these donuts have the perfect texture and taste .
 1                                                 good food

In [92]:
ex_sent = "The product is "
lambda_factor = 3
peek_transfer_output(ex_sent, neg_ex, pos_ex), peek_transfer_output(ex_sent, pos_ex, neg_ex)

('The product is defective.', 'The product is good The Product is excellent')

In [93]:
ex_sent = "Apple watch is"
peek_transfer_output(ex_sent, neg_ex, pos_ex), peek_transfer_output(ex_sent, pos_ex, neg_ex)

('Apple watch is useless.', 'Apple watch is amazing.')

In [180]:
ex_sent = "The University of California at San Diego is"
lambda_factor = 5
peek_transfer_output(ex_sent, neg_ex, pos_ex), peek_transfer_output(ex_sent, pos_ex, neg_ex)

('The University of California at San Diego is a dead zone.',
 'The University of California at San Diego is beautiful.')

In [195]:
american_ex = [
               "It cost ten bucks.",
              "My neighbor apologized.",
            "I'm heading out to the bar with some friends."
]
british_ex = [
  "It cost ten quid.",
  "My neighbour apologised.",
  "I'm heading out LO the pub with some mates."
]
sent_ex = "My favourite food: "
lambda_factor = 10
peek_transfer_output(sent_ex, american_ex, british_ex)

'My favourite food: My favorite food was Wendy.'