In [1]:
# Colab setup: install the third-party packages this notebook imports.
# The versions are not pinned; the run captured below resolved
# transformers 4.17.0 and sentencepiece 0.1.96.
!pip install transformers
!pip install pytorch_lightning
!pip install sentencepiece

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 4.0 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 56.9 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 58.4 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 5.0 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 36.8 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Foun

Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 3.2 MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


In [2]:
from torch.utils.data import Dataset , DataLoader
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from itertools import chain
import ast
from transformers import T5TokenizerFast
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
torch.version.__version__

'1.10.0+cu111'

In [3]:
# Mount Google Drive so the dataset and training artifacts under `root` are reachable.
from google.colab import drive
drive.mount('/content/drive/')
# NOTE(review): `root` is a relative path, so it resolves against the current
# working directory (presumably /content on Colab -- confirm).
root = "drive/MyDrive/lign167_data"

Mounted at /content/drive/


# Prepare Data

In [4]:
# Load the review file. The column at position 1 is parsed with
# ast.literal_eval so that it holds Python objects rather than raw strings;
# later cells treat each "splitted" value as a list of sentences.
raw_data = pd.read_csv(f"{root}/amazon_188703.csv", converters={1:ast.literal_eval})
# Keep only the "splitted" column.
raw_data = raw_data[["splitted"]]

In [5]:
# A row with k sentences yields k - 1 (context, next sentence) pairs.
pair_counts = raw_data["splitted"].apply(lambda x: len(x) - 1)
print(pair_counts.sum())

# "cumlen" is the global index of the LAST pair contributed by each row;
# it becomes the frame index so a pair index can be mapped back to its row.
raw_data["cumlen"] = pair_counts.cumsum() - 1
raw_data["len"] = pair_counts
raw_data = raw_data.set_index("cumlen")

pd.options.display.max_colwidth = 150
raw_data

344808


Unnamed: 0_level_0,splitted,len
cumlen,Unnamed: 1_level_1,Unnamed: 2_level_1
0,"[Excellent game, worked really well!, Makes me think quickly...good brain exercise!]",1
1,"[The Bible for Alexa is a great addition., Along with getting my daily dose from the scripture of the day, I can hear God's word anytime.]",1
4,"[I like it but I don't love it., I wish Alexa could just open my Verse of the day more easily , she seems to never understand ., I have to be so s...",3
5,"[Very useful skill and much needed., Setup is pretty easy and works reliably]",1
8,"[Takes a lil time to get everything set up but once its done, everything works like a charm., U can even back up your remotes from the app to open...",3
...,...,...
344800,"[It was hard to keep the names of the characters straight., And the development of the characters certainly needed more development., But the acti...",2
344802,"[Another enjoyable read from Catherine Bybee., The hard part is putting it down!!, Great characters that make you want to follow their journey.]",2
344803,"[I was never quite sure who was going to die, Booker or Dani., I recommend the book for a rainy weekend.]",1
344804,"[i really like the way the author leads you in multiple directions as to whom the ""client"" is., fun book to read during my breaks.]",1


In [6]:
# initialize tokenizer for Dataset building
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

In [7]:
sent_length = 32
class AmazonDataset(Dataset):
    """Dataset of (context sentence, following sentence) pairs.

    `data` is a DataFrame with a "splitted" column whose values are lists of
    sentences. A row with k sentences contributes k - 1 consecutive pairs
    (rows with fewer than two sentences contribute none). Pairs are numbered
    globally across the rows of `data`, in row order.

    The index is built from the `data` passed in, not from a module-level
    frame, so two datasets built from disjoint frames serve disjoint pairs.
    """

    def __init__(self,data):
        self.data = data
        # Materialise the sentence lists once; lookups are then positional and
        # do not depend on the frame's index labels.
        self._lines = list(data["splitted"])
        self._pair_counts = np.array(
            [max(len(line) - 1, 0) for line in self._lines], dtype=np.int64
        )
        # _ends[r] is the (exclusive) global pair index at which row r ends.
        self._ends = np.cumsum(self._pair_counts)
        self.len = int(self._ends[-1]) if len(self._ends) else 0

    def __len__(self):
        """Return the total number of sentence pairs."""
        return self.len

    def to_token(self,sentence):
        """Encode one sentence to a 1-D tensor of `sent_length` token ids (truncated / padded)."""
        return tokenizer.encode(sentence, max_length=sent_length, truncation=True, padding="max_length", return_tensors="pt")[0]

    def get_pair(self, idx):
        """Return the `idx`-th (sentence, next sentence) pair as a tuple of strings.

        Raises:
            IndexError: if `idx` is outside `[0, len(self))`.
        """
        idx = int(idx)
        if not 0 <= idx < self.len:
            raise IndexError(f"pair index {idx} out of range for dataset of size {self.len}")
        # First row whose end lies strictly beyond idx; side="right" also skips
        # rows that contribute no pairs (their end equals the previous end).
        row = int(np.searchsorted(self._ends, idx, side="right"))
        first_pair_of_row = int(self._ends[row] - self._pair_counts[row])
        offset = idx - first_pair_of_row
        line = self._lines[row]
        return (line[offset], line[offset + 1])

    def __getitem__(self,index):
        """Return the tokenized (context, input) pair at `index`."""
        context, target = self.get_pair(index)
        return self.to_token(context), self.to_token(target)

In [8]:
AmazonDataset(raw_data)[0]

(tensor([11497,   467,     6,  1279,   310,   168,    55,     1,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]),
 tensor([ 1796,     7,   140,   317,  1224,   233, 10452,  2241,  2510,    55,
             1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]))

In [9]:
batch_size = 32

class AmazonDataModule(pl.LightningDataModule):
    """Lightning data module that splits the global `raw_data` frame into train and validation parts.

    Args:
        test_size: fraction (or count) of rows held out, forwarded to
            `train_test_split`. Defaults to the previous hard-coded 0.2.
        random_state: seed forwarded to `train_test_split`; `None` (the
            default) keeps the previous unseeded behaviour.
        num_workers: worker processes per DataLoader (previously fixed at 4).

    The held-out rows back both the validation and the test loader.
    """

    def __init__(self, test_size=0.2, random_state=None, num_workers=4):
        super().__init__()
        self.num_workers = num_workers
        train_dataset, val_dataset = train_test_split(
            raw_data, test_size=test_size, random_state=random_state
        )
        self.train = AmazonDataset(train_dataset)
        self.test = AmazonDataset(val_dataset)
        self.val = AmazonDataset(val_dataset)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader; the global `batch_size` is read at call time."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._loader(self.train, shuffle=True)

    def test_dataloader(self):
        """Unshuffled loader over the held-out split."""
        return self._loader(self.test, shuffle=False)

    def val_dataloader(self):
        """Unshuffled loader over the held-out split."""
        return self._loader(self.val, shuffle=False)

# Model Definition

In [10]:
from transformers.models.t5.modeling_t5 import T5Stack, T5PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,)
from transformers.models.t5.configuration_t5 import T5Config
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
import warnings
import copy

__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""

In [11]:
lambda_factor = 1
class T5ForConditionalGenerationWithExtractor(T5PreTrainedModel):
    """T5 conditional-generation model with a second encoder stack, the "extractor".

    The encoder, extractor and decoder stacks share one token embedding. In
    `forward`, the extractor's hidden states (computed separately with
    `get_extractor_output` and passed in as `use_cache_extractor_outputs`) are
    added to the encoder's last hidden state; the decoder cross-attends to
    that sum.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # The extractor is configured exactly like the encoder.
        extractor_config = copy.deepcopy(config)
        extractor_config.is_decoder = False
        extractor_config.use_cache = False
        extractor_config.is_encoder_decoder = False
        self.extractor = T5Stack(extractor_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the three stacks over CUDA devices according to `device_map` (built from the encoder depth when omitted)."""
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.extractor.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo `parallelize`: move every sub-module back to the CPU and free cached CUDA memory."""
        self.encoder.deparallelize()
        self.extractor.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.extractor = self.extractor.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # Keep all three stacks pointing at the same embedding module.
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.extractor.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_extractor(self):
        return self.extractor

    def get_decoder(self):
        return self.decoder

    def get_extractor_output(self,
        input_ids=None,
        use_cache_context_ids=None, # the "use_cache" prefix is simply a trick to get these arguments through the generation mixin
        use_cache_target_examplars_ids=None,
        use_cache_origin_examplars_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        extractor_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        context_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,):
      """Compute the extractor hidden states to pass to `forward` / `generate`.

      Two modes, selected by `use_cache_context_ids`:

      * `use_cache_context_ids` given: return the extractor's first output for
        those ids (or the first element of `extractor_outputs` when supplied).
      * `use_cache_context_ids` is None: run the extractor on every element of
        `use_cache_target_examplars_ids` and `use_cache_origin_examplars_ids`
        and on `input_ids`, and return
        `lambda_factor * (mean(targets) - mean(origins)) + input_style`.

      Only `attention_mask`, `head_mask`, `context_embeds`,
      `output_attentions`, `output_hidden_states` and `return_dict` are
      forwarded to the extractor; the remaining keyword arguments are accepted
      but not used here.
      """

      def run_extractor(ids):
        # Single place for the extractor call shared by every branch below.
        return self.extractor(
            input_ids=ids,
            attention_mask=attention_mask,
            inputs_embeds=context_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

      extractor_hidden = None
      if use_cache_context_ids is None:
        target_styles = tuple(run_extractor(target_ids)[0] for target_ids in use_cache_target_examplars_ids)
        original_styles = tuple(run_extractor(origin_ids)[0] for origin_ids in use_cache_origin_examplars_ids)
        input_style = run_extractor(input_ids)[0]
        # Shift the input's own style by the (target - origin) style direction.
        extractor_hidden = lambda_factor * (torch.mean(torch.vstack(target_styles), 0) - (torch.mean(torch.vstack(original_styles), 0))) + input_style

      else:
        if extractor_outputs is None:
            extractor_outputs = run_extractor(use_cache_context_ids)
        # Fixed: this test used to inspect `encoder_outputs` (copy/paste slip).
        elif return_dict and not isinstance(extractor_outputs, BaseModelOutput):
            extractor_outputs = BaseModelOutput(
                last_hidden_state=extractor_outputs[0],
                hidden_states=extractor_outputs[1] if len(extractor_outputs) > 1 else None,
                attentions=extractor_outputs[2] if len(extractor_outputs) > 2 else None,)
        extractor_hidden = extractor_outputs[0]
      return extractor_hidden

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        use_cache_extractor_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        context_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run encoder (unless `encoder_outputs` is given), add the extractor output, decode, and project to the vocabulary.

        use_cache_extractor_outputs (*optional*):
            Value added to the encoder's last hidden state before decoding, normally the result of
            `get_extractor_output`. Anything broadcastable against the encoder output works (the training code
            passes `0` for "no style offset"). When `None`, the encoder output is used unchanged.
        labels (*optional*):
            Target token ids for the language-modeling loss. When given without `decoder_input_ids` /
            `decoder_inputs_embeds`, the decoder inputs are the labels shifted right. Positions set to `-100`
            are ignored by the loss.
        context_embeds:
            Accepted for signature compatibility; not used in this method.

        Returns:
            `Seq2SeqLMOutput` when `return_dict` is true, otherwise a tuple
            `(loss?, lm_logits, *decoder_outputs[1:], *encoder_outputs)`.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is
                # name-mangled and would raise NameError; fetch the
                # module-level constant by its string key instead.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # Inject the extractor output by adding it to the encoder states.
        # Without one, fall back to the plain encoder output instead of
        # failing on `tensor + None`.
        hidden_states = encoder_outputs[0]
        if use_cache_extractor_outputs is not None:
            hidden_states = hidden_states + use_cache_extractor_outputs

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache_extractor_outputs=None,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Assemble the keyword arguments `forward` receives at each generation step.

        `input_ids` here are the decoder ids generated so far; the extractor
        output is passed through unchanged on every step.
        """

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "use_cache_extractor_outputs": use_cache_extractor_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        """Reorder the cached decoder key/value states along the batch dimension to follow `beam_idx`."""
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            # Fixed: `warnings.warning` does not exist; the function is `warnings.warn`.
            warnings.warn("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # Select the rows named by beam_idx from each cached state of this layer
            # (index_select on dim 0).
            reordered_layer_past_states = tuple(
                layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
                for layer_past_state in layer_past_states
            )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


In [12]:
model = T5ForConditionalGenerationWithExtractor.from_pretrained("t5-base")
# model = T5ForConditionalGenerationWithExtractor.from_pretrained("check")

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

Some weights of T5ForConditionalGenerationWithExtractor were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor.block.9.layer.0.layer_norm.weight', 'extractor.block.2.layer.1.DenseReluDense.wo.weight', 'extractor.block.7.layer.1.DenseReluDense.wo.weight', 'extractor.block.0.layer.1.DenseReluDense.wi.weight', 'extractor.block.5.layer.1.DenseReluDense.wo.weight', 'extractor.block.2.layer.1.layer_norm.weight', 'extractor.block.6.layer.0.SelfAttention.q.weight', 'extractor.final_layer_norm.weight', 'extractor.block.9.layer.0.SelfAttention.o.weight', 'extractor.block.11.layer.1.DenseReluDense.wo.weight', 'extractor.block.1.layer.1.layer_norm.weight', 'extractor.block.2.layer.1.DenseReluDense.wi.weight', 'extractor.block.3.layer.0.layer_norm.weight', 'extractor.block.7.layer.0.SelfAttention.v.weight', 'extractor.block.8.layer.1.DenseReluDense.wo.weight', 'extractor.block.9.layer.0.SelfAttention.k.weight', 'extractor.block.10.layer.0.SelfAttention.v.we

In [13]:
# The t5-base checkpoint has no `extractor.*` weights (see the "newly
# initialized" warning above), so start the extractor as a copy of the
# pretrained encoder instead of from random weights.
model.extractor = copy.deepcopy(model.encoder)
# Sets an ad-hoc `is_extractor` attribute on the copied stack; no visible cell reads it.
model.extractor.is_extractor = True

# Utils

In [14]:
def peek_weights():
  """Print the name and value of every parameter of the global `model` whose
  name contains "block.2.layer.0.SelfAttention.k.weight" (debugging aid)."""
  needle = "block.2.layer.0.SelfAttention.k.weight"
  for param_name, param in model.named_parameters():
    if needle not in param_name:
      continue
    print(param_name)
    print(param)

In [15]:
def tokenize(input):
  """Encode `input` with the global tokenizer, truncated / padded to
  `sent_length` tokens, and return the id tensor moved to the GPU."""
  encoded = tokenizer(
      input,
      max_length=sent_length,
      truncation=True,
      padding="max_length",
      return_tensors="pt",
  )
  return encoded.input_ids.cuda()

In [16]:
def peek_output(input, context):
  """Generate from `input` using the extractor output of `context`.

  Prints both texts, the extractor output and the generated ids, then returns
  the decoded first sequence. Expects the global `model` to expose the T5
  network as `model.net`.
  """
  for label, text in (("input:", input), ("context:", context)):
    print(label, text)
  input_ids, context_ids = tokenize(input), tokenize(context)
  net = model.net
  extractor_output = net.get_extractor_output(use_cache_context_ids=context_ids)
  print(extractor_output)
  generated = net.generate(input_ids=input_ids, use_cache_extractor_outputs=extractor_output, no_repeat_ngram_size=2)
  print(generated)
  return tokenizer.decode(generated[0], skip_special_tokens=True)

In [17]:
def peek_transfer_output(input, target_examplars, origin_examplars):
  """Generate from `input` using exemplar sentences.

  Each exemplar sentence is tokenized separately. The extractor output is
  computed from the target exemplars, the origin exemplars and `input`
  itself, printed, and passed to generation. Returns the decoded first
  generated sequence. Expects the global `model` to expose the T5 network as
  `model.net`.
  """
  targets = tuple(tokenize(sentence) for sentence in target_examplars)
  origins = tuple(tokenize(sentence) for sentence in origin_examplars)
  input_ids = tokenize(input)
  net = model.net
  extractor_output = net.get_extractor_output(
      input_ids=input_ids,
      use_cache_origin_examplars_ids=origins,
      use_cache_target_examplars_ids=targets,
  )
  print(extractor_output)
  generated = net.generate(input_ids=input_ids, use_cache_extractor_outputs=extractor_output, no_repeat_ngram_size=2)
  return tokenizer.decode(generated[0], skip_special_tokens=True)

In [18]:
np.random.choice([False, True])

True

# Module and Training

In [19]:
sent_ex = torch.tensor([[   27,   131,  1663,    79,   133,    59,    43,   974,     8, 11769,
          4546,    49,     5,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  216,   237,  5495,  1361,    95,    16,     3,     9,  1996,   300,
            34,     5,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   27,   183, 13049,    82, 17442,     3,     9,  3591,  1088,    12,
            84,  2586,  2281,    56,    36,     8,  3800,     5,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   27,  2944,    48,     3,  4894,    12,  1115,     3,     9,   314,
           226,   591, 16739,   682,     5,     1,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [ 4242,  2437,    11,   614,    12,   888,   300,    68,  1355,  5366,
            55,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  299,    27,    47,   652,     3,     9,   385,  7718,    13,     8,
         24839,  4496,  5891,     5,     1,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  100,  4035, 15133,   930,   248,    30,    82,  1367,    49,     5,
             5,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  555,    13,     8,   200,  3370,  1335,    27,   664,   608,     5,
             1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   27,  1800,     8,   733,   906,     3,     9,   385,    72,  1848,
           606,     6,    68,     8,  1006,  4974,  5689,    39,  1388,     5,
             1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   27,  2112, 13771, 16802,    78,     8,  6519,   744,    31,    17,
         13965,   140,    68,   406,   135,     8, 17956,     7,    33, 10875,
             5,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   37, 10485,    11,  6922,    19,   248,    68,  1879,     8,  1974,
            19,   131,   773,    13, 15170,     5,     1,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [   37,  2507,   505,    19,     3,     9,   385,    72,  2881,   145,
             8,   119,  4935,  3592,    68,    70,   168,  1494,     8,   594,
             5,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [28291,    79,    31,    60,   131,    21, 11649,   383,     8,  6799,
             5,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [ 1853,  5705,  9664,  4386,  4486,  4083, 28027,   427, 17098,   272,
          5767,  5946, 11973,     3, 14750, 19056,  6223, 28969,  3001, 21490,
          1853,     3, 13729,   301, 19114,  7212, 21337,  8043,   445, 19114,
             3,     1],
        [  216,    65,     3,     9,  1627,  1418,    12,    36,  1835,    11,
         15391,    75,     6,    28,   418,    31,     7,    13,  2093,  2886,
          3358,     5,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [ 8302,   608,     6,  4324,  3801,    15,  1329,     3,    99,  1066,
         24839,     5,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]], device='cuda:0')

def drop_noise(sent, drop_rate=0.2):
  """Delete randomly chosen tokens from a 1-D tensor of token ids.

  Only ids greater than 1 are candidates. The number of deletions is fixed up
  front as int(count(ids > 1) * drop_rate). The result is shorter than the
  input by that many elements.
  """
  n_drops = int((sent > 1).sum() * drop_rate)
  for _ in range(n_drops):
    candidates = np.where((sent > 1).cpu())[0]
    victim = np.random.choice(candidates)
    head, tail = sent[:victim], sent[victim + 1:]
    sent = torch.concat((head, tail))
  return sent


# Ids that must never be injected as noise.
special_tokens_set = set(tokenizer.all_special_ids)

def rand_token():
  """Draw a uniformly random id in [0, tokenizer.vocab_size), redrawing until
  it is not one of the tokenizer's special ids."""
  while True:
    candidate = np.random.randint(tokenizer.vocab_size)
    if candidate not in special_tokens_set:
      return candidate


def add_noise(sent, drop_rate=0.4):
  """Insert random non-special tokens into a 1-D tensor of token ids.

  Despite its name, `drop_rate` is the insertion rate: the number of
  insertions is int(count(ids > 1) * drop_rate), fixed up front. Each
  insertion goes immediately before a randomly chosen id greater than 1. The
  result is longer than the input by the number of insertions.

  The inserted token is created on `sent`'s own device (previously it was
  always moved to CUDA, which failed for CPU inputs).
  """
  n_inserts = int((sent > 1).sum() * drop_rate)
  for _ in range(n_inserts):
    randIdx = np.random.choice(np.where((sent > 1).cpu())[0])
    new_token = torch.tensor([rand_token()], device=sent.device)
    sent = torch.concat((sent[:randIdx], new_token, sent[randIdx:]))
  return sent

def pad_sent(sent, target=sent_length):
  """Force a 1-D tensor of token ids to exactly `target` elements.

  Longer inputs are truncated; shorter ones are right-padded with zeros. The
  padding is created on `sent`'s own device (previously it was always moved
  to CUDA, which failed for CPU inputs).
  """
  current = sent.shape[0]
  if current > target:
    return sent[:target]
  padding = torch.zeros(target - current, dtype=torch.long, device=sent.device)
  return torch.concat((sent, padding))

# def drop_noise_(sent, drop_rate=0.4):
#   for i in range(int(sent.shape[0] * drop_rate)):
#     randIdx = np.random.randint(sent.shape[0])
#     sent = torch.concat((sent[:randIdx], sent[randIdx + 1:]))
#   return sent

def apply_noise(sents):
  """Corrupt every sentence in `sents` and stack the results.

  Each row goes through drop_noise, then add_noise, then pad_sent (so all rows
  end up the same length) before being stacked with torch.vstack.
  """
  noisy_rows = tuple(pad_sent(add_noise(drop_noise(row))) for row in sents)
  return torch.vstack(noisy_rows)


In [20]:
class TextSettrModel(LightningModule):
    """LightningModule that trains a T5 model with an extra ``extractor`` stack.

    ``self.net`` is a ``T5ForConditionalGenerationWithExtractor`` loaded from
    "t5-base"; its ``extractor`` is initialised as a deep copy of the encoder.
    Each step reconstructs the clean ``input_ids`` from a noised version,
    conditioned on the extractor output computed from ``context_ids``.
    """

    def __init__(self):
        super().__init__()
        pretrained = T5ForConditionalGenerationWithExtractor.from_pretrained("t5-base")
        # Start the extractor from the pretrained encoder weights.
        pretrained.extractor = copy.deepcopy(pretrained.encoder)
        self.net = pretrained

    def _reconstruction_loss(self, batch, back_translation_choices):
        """Compute the loss for one batch; shared by training and validation.

        Args:
            batch: indexable with ``batch[0]`` = context ids, ``batch[1]`` = input ids.
            back_translation_choices: booleans sampled uniformly with
                ``np.random.choice``; a True draw turns on noisy back translation.
        """
        context_ids = batch[0]
        input_ids = batch[1]
        corrupted_ids = apply_noise(input_ids)
        if np.random.choice(back_translation_choices):
            # Noisy back translation
            corrupted_ids = self.net.generate(
                input_ids=corrupted_ids,
                use_cache_extractor_outputs=0,
                do_sample=True,
                max_length=sent_length,
                min_length=sent_length,
            )
        extractor_output = self.net.get_extractor_output(use_cache_context_ids=context_ids)
        outputs = self.net(
            input_ids=corrupted_ids,
            labels=input_ids,
            use_cache_extractor_outputs=extractor_output,
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Return the training loss; back translation is drawn 1 time in 3."""
        return self._reconstruction_loss(batch, [False, False, True])

    def validation_step(self, batch, batch_idx):
        """Log the loss as "val_loss"; back translation is drawn 1 time in 2."""
        loss = self._reconstruction_loss(batch, [False, True])
        self.log("val_loss", loss)

    def configure_optimizers(self):
        """Optimise every parameter of ``self.net`` with AdamW at lr 1e-3."""
        return torch.optim.AdamW(self.net.parameters(), lr=1e-3)

In [21]:
model = TextSettrModel()
module = AmazonDataModule()
# checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=root, filename='{epoch}')
# Checkpoint on the "val_loss" value logged in TextSettrModel.validation_step.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss")
# The callback must be handed to the Trainer; previously it was created but
# never registered, so the val_loss-monitoring checkpoint was silently unused.
trainer = Trainer(
    max_epochs=10,
    gpus=1,
    default_root_dir=root,
    val_check_interval=0.25,  # run validation four times per training epoch
    callbacks=[checkpoint_callback],
)
trainer.fit(model, module)
# model.net.cuda()

Some weights of T5ForConditionalGenerationWithExtractor were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor.block.9.layer.0.layer_norm.weight', 'extractor.block.2.layer.1.DenseReluDense.wo.weight', 'extractor.block.7.layer.1.DenseReluDense.wo.weight', 'extractor.block.0.layer.1.DenseReluDense.wi.weight', 'extractor.block.5.layer.1.DenseReluDense.wo.weight', 'extractor.block.2.layer.1.layer_norm.weight', 'extractor.block.6.layer.0.SelfAttention.q.weight', 'extractor.final_layer_norm.weight', 'extractor.block.9.layer.0.SelfAttention.o.weight', 'extractor.block.11.layer.1.DenseReluDense.wo.weight', 'extractor.block.1.layer.1.layer_norm.weight', 'extractor.block.2.layer.1.DenseReluDense.wi.weight', 'extractor.block.3.layer.0.layer_norm.weight', 'extractor.block.7.layer.0.SelfAttention.v.weight', 'extractor.block.8.layer.1.DenseReluDense.wo.weight', 'extractor.block.9.layer.0.SelfAttention.k.weight', 'extractor.block.10.layer.0.SelfAttention.v.we

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [29]:
# Display the module tree (repr) of the network wrapped by the Lightning module.
model.net

T5ForConditionalGenerationWithExtractor(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              

In [None]:
# Quick check on a single sentence, passing an empty string as the second
# argument. NOTE(review): peek_output is defined elsewhere in the notebook;
# confirm what the second argument means.
peek_output("This was a thought-provoking read", "")

In [None]:
# Exemplar sentences for the two styles. Every entry needs its own trailing
# comma: adjacent string literals without one are silently concatenated, which
# previously merged the 2nd and 3rd formal exemplars into a single sentence.
formal_examplars = ["This was a remarkably thought-provoking read.",
                    "It is certainly amongst my favorites.",
                    "We humbly request your presence at our gala on the 12th."]
informal_examplars = ["reading this rly makes u think",
                      "Its def one of my favs",
                      "come swing by our bbq next week if ya can make it"]
formal_input = "I hereby commit to never purchase anything from this institution in the future."
# NOTE(review): module-level setting, presumably read by peek_transfer_output — confirm.
lambda_factor = 6
peek_transfer_output(formal_input, informal_examplars, formal_examplars)

In [None]:
# Exemplar sentences for the two styles. Every entry needs its own trailing
# comma: adjacent string literals without one are silently concatenated, which
# previously merged the 2nd and 3rd formal exemplars into a single sentence.
formal_examplars = ["This was a remarkably thought-provoking read.",
                    "It is certainly amongst my favorites.",
                    "We humbly request your presence at our gala on the 12th."]
informal_examplars = ["reading this rly makes u think",
                      "Its def one of my favs",
                      "come swing by our bbq next week if ya can make it"]
formal_input = "I couldn’t figure out what the author was trying to say."
# NOTE(review): module-level setting, presumably read by peek_transfer_output — confirm.
lambda_factor = 4
# Run the same input with the exemplar sets in both orders; the cell displays
# the two results as a tuple.
(
    peek_transfer_output(formal_input, formal_examplars, informal_examplars),
    peek_transfer_output(formal_input, informal_examplars, formal_examplars),
)

In [None]:
# Two hand-written exemplar sets plus a sentence prefix to transfer.
# NOTE(review): "wasn'" and "favorate" look like typos; they are left as-is
# because they are model inputs — confirm whether they are intentional.
orig_ex = [
    "No thank you, I'd prefer not to.",
    "This game could have been better designed.",
    "Do you know why they might have delayed the launch?",
    "Sorry, I wasn' certain if you were joking.",
]

targ_ex = [
    "Hell no, you can't make me do that.",
    "This game is such a piece of garbage!",
    "Why in god's name would they delay the damn launch? Are you frigging kidding me?",
]

sent_ex = "My favorate movie is"
# NOTE(review): module-level setting, presumably read by peek_transfer_output — confirm.
lambda_factor = 6
# Run the same input with the exemplar sets in both orders; the cell displays
# the two results as a tuple.
(
    peek_transfer_output(sent_ex, targ_ex, orig_ex),
    peek_transfer_output(sent_ex, orig_ex, targ_ex),
)

In [26]:
# Save the network twice: once to a "check" directory relative to the working
# directory, and once under `root`.
model.net.save_pretrained("check")
model.net.save_pretrained(f"{root}/check2")