<a href="https://colab.research.google.com/github/taravatp/Text_Style_Transfer/blob/main/TransRNN_TST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
cd /content/drive/MyDrive/text_style_transfer

/content/drive/MyDrive/text_style_transfer


In [None]:
!pip install -qU hazm
!pip install -q transformers==3.1.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.7/316.7 KB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.6/233.6 KB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for nltk (setup.py) ... [?25l[?25hdone
  Building wheel for libwapiti (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m884.0/884.0 KB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m97.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m82.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

from transformers import BertConfig, BertTokenizer,BertModel
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

import hazm
from hazm import word_tokenize

In [None]:
# Decide once whether a CUDA device is visible; everything below reuses this flag.
train_on_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if train_on_gpu else "cpu")
print(f'device: {device}')

# Tell the user which backend the rest of the notebook will run on.
if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
else:
    print('CUDA is not available.  Training on CPU ...')

device: cuda:0
CUDA is available!  Training on GPU ...


# Data Cleaning

In [None]:
import re

# hazm Normalizer instance shared by cleaning() below.
normalizer = hazm.Normalizer()

# Raw informal/formal sentence pairs, stored as an Excel sheet on Drive.
DATASET_PATH = '/content/drive/MyDrive/text_style_transfer/dataset.xlsx'
dataset = pd.read_excel(DATASET_PATH)

def cleaning(text):
  """Normalize one raw sentence before tokenization.

  Steps: strip outer whitespace, run the module-level hazm ``normalizer``,
  put a space in front of '.', '!' and '?' so they become separate tokens,
  and collapse every whitespace run into a single space.

  Args:
    text: raw cell value from the dataset.

  Returns:
    The cleaned sentence, or None when ``text`` is not a string (e.g. the
    NaN pandas uses for an empty Excel cell), so a later ``dropna()`` can
    discard the pair instead of the whole cell crashing on ``.strip()``.
  """
  if not isinstance(text, str):
    return None
  text = normalizer.normalize(text.strip())
  # Detach sentence punctuation from the preceding word.
  text = re.sub(r"([.!?])", r" \1", text)
  # Raw string: "\s" in a plain string literal is an invalid escape sequence.
  text = re.sub(r"\s+", " ", text)
  # The substitutions above can leave a single leading/trailing space.
  return text.strip()

def truncate(sentence,max_len=20):
  """Keep ``sentence`` only if it has fewer than ``max_len`` hazm word tokens.

  Despite the name, nothing is cut: a sentence that is too long is replaced
  by None so the later ``dataset.dropna()`` removes the whole pair.

  Args:
    sentence: cleaned sentence; non-string values (None/NaN left by
      ``cleaning``) are treated as missing.
    max_len: exclusive upper bound on the number of tokens.

  Returns:
    ``sentence`` unchanged, or None if it is missing or too long.
  """
  if not isinstance(sentence, str):
    return None
  return sentence if len(word_tokenize(sentence)) < max_len else None

# Clean both sides of every pair; cleaning()/truncate() yield None for
# missing or over-long sentences.
for column in ('formalForm', 'inFormalForm'):
  dataset[column] = dataset[column].apply(cleaning).apply(truncate)

# Only the two sentence columns decide whether a pair is kept; NaNs in any
# other spreadsheet column should not silently drop training pairs.
dataset = dataset.dropna(subset=['formalForm', 'inFormalForm'])
# drop=True: don't turn the old row labels into an extra "index" column,
# which would otherwise be written into the cleaned CSV.
dataset = dataset.reset_index(drop=True)

In [None]:
# Persist the cleaned pairs as UTF-8 CSV (no pandas row index) so later
# cells can reload them with read_csv.
writePath = '/content/drive/MyDrive/text_style_transfer/CleanedDataset_v2.csv'
dataset.to_csv(
    writePath,
    index=False,
    encoding='utf-8',
)

# Training embedding models

In [None]:
from hazm import word_tokenize
import gensim
from gensim.models.word2vec import Word2Vec

# Word2Vec is trained on the cleaned CSV written by the previous cell.
# DATASET_PATH still points at the raw dataset.xlsx here, which read_csv
# cannot parse.
CLEANED_DATASET_PATH = '/content/drive/MyDrive/text_style_transfer/CleanedDataset_v2.csv'
dataset = pd.read_csv(CLEANED_DATASET_PATH)

# One token list per formal (target) sentence.
targets = [word_tokenize(sentence) for sentence in dataset.formalForm]

# The vectors should match BERT's hidden size. `config` is only created in a
# later cell, so fall back to 768 (the size noted for the decoder hidden
# state) when it does not exist yet.
EMBEDDING_SIZE = getattr(globals().get('config'), 'hidden_size', 768)
# gensim 4 renamed Word2Vec's `size` argument to `vector_size`.
_w2v_size_kwarg = 'vector_size' if int(gensim.__version__.split('.')[0]) >= 4 else 'size'

model = Word2Vec(sentences=targets, window=10, min_count=5, seed=42, workers=5,
                 **{_w2v_size_kwarg: EMBEDDING_SIZE})
# Relative path: saved into the working directory set by the `cd` at the top.
model.save('targets_embedding.w2v')

# Models

In [None]:
class Encoder(nn.Module):
  """Pretrained BERT that returns its pooled output as the sentence representation."""

  def __init__(self,MODEL_NAME_OR_PATH,config):
    """Load the BERT weights named by ``MODEL_NAME_OR_PATH`` with ``config``."""
    super(Encoder,self).__init__()
    self.bert = BertModel.from_pretrained(MODEL_NAME_OR_PATH,config=config)

  def forward(self,input_ids,attention_mask,token_type_ids):
    """Encode a tokenized batch and return BERT's pooled output (one vector per sequence)."""
    outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    # Index instead of tuple-unpacking: element 1 is the pooled output both when
    # BertModel returns a plain tuple and when it returns a ModelOutput object
    # (unpacking a ModelOutput yields its key names, not tensors).
    pooled_output = outputs[1]
    return pooled_output

In [None]:
class Decoder(nn.Module):
    """One-step GRU decoder over the target vocabulary.

    Each call embeds a single token id, advances a ``num_layers`` GRU by one
    step and returns log-probabilities for the next token.
    """

    def __init__(self, hidden_size, output_size,num_layers):
        """
        Args:
            hidden_size: embedding and GRU width.
            output_size: number of words in the target vocabulary.
            num_layers: number of stacked GRU layers.
        """
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.num_layers = num_layers
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: tensor holding exactly one token id (e.g. shape [1, 1], [1]
                or a scalar); it is reshaped to a length-1, batch-1 sequence.
            hidden: GRU state of shape (num_layers, 1, hidden_size), or a single
                state of shape (1, 1, hidden_size) such as the encoder's
                sentence vector, which is then copied to every layer.

        Returns:
            (log_probs of shape [1, output_size], new hidden state of shape
            (num_layers, 1, hidden_size)).
        """
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        # nn.GRU requires one initial state per layer; broadcast a single state.
        if hidden.size(0) == 1 and self.num_layers > 1:
            hidden = hidden.repeat(self.num_layers, 1, 1)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return an all-zero state of shape (num_layers, 1, hidden_size).

        The state is created on the same device as this module's parameters.
        """
        return torch.zeros(self.num_layers, 1, self.hidden_size, device=self.out.weight.device)

# Creating language style objects (used to build the vocabulary of the formal target sentences)

In [None]:
# Reserved decoder ids: decoding starts from SOS, and target sequences end with EOS.
SOS_token, EOS_token = 0, 1

# Reload the cleaned informal/formal pairs written earlier in the notebook.
DATASET_PATH = '/content/drive/MyDrive/text_style_transfer/CleanedDataset_v2.csv'
dataset = pd.read_csv(DATASET_PATH)

In [None]:
class LangStyle:
  """Incrementally built vocabulary for one language style.

  Ids 0 and 1 are reserved for the "SOS" and "EOS" markers (the values of
  SOS_token and EOS_token); real tokens get ids from 2 upward in order of
  first appearance.
  """

  def __init__(self):
    self.word2index = {}                    # token -> id
    # id -> token; the reserved ids are included so every decoder output id,
    # including SOS/EOS, maps back to a string.
    self.index2word = {0: "SOS", 1: "EOS"}
    self.word2count = {}                    # token -> number of occurrences
    self.n_words = 2                        # vocabulary size == next free id

  def add_sentence_to_lang(self, sentence):
    """Tokenize ``sentence`` with hazm's word_tokenize and register every token.

    New tokens get the next free id; known tokens only have their count bumped.
    """
    for token in word_tokenize(sentence):
      if token in self.word2index:
        self.word2count[token] += 1
        continue
      self.word2index[token] = self.n_words
      self.index2word[self.n_words] = token
      self.word2count[token] = 1
      self.n_words += 1

  # Original (misspelled) method name, kept so existing callers keep working.
  add_setence_to_lang = add_sentence_to_lang

# Creating DataLoaders

In [None]:
class TSTSDATA(Dataset):
  """Informal -> formal sentence pairs for the BERT encoder / GRU decoder model.

  The CSV is split by row position: the first 90% of rows when
  ``flag == 'train'``, the remaining rows for any other flag.
  """

  def __init__(self, dataset_path, BertTokenizer, config, max_len, formalStyle, flag):
    frame = pd.read_csv(dataset_path)
    split_at = int(len(frame) * 0.9)
    self.dataset = frame.iloc[:split_at] if flag == 'train' else frame.iloc[split_at:]

    self.BertTokenizer = BertTokenizer  # tokenizer instance used for the encoder side
    self.config = config                # kept on the instance; not read by this class
    self.max_len = max_len
    self.formalStyle = formalStyle      # vocabulary for the formal (target) side

  def __len__(self):
    return len(self.dataset)

  def get_encoder_input(self, informal_sentence):
    """Tokenize the informal sentence into fixed-length (max_len) 1-D BERT input tensors."""
    encoded = self.BertTokenizer.encode_plus(
      informal_sentence,
      add_special_tokens=True,
      truncation=True,
      max_length=self.max_len,
      return_token_type_ids=True,
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt',
    )

    features = {'informal_sentence': informal_sentence}
    for key in ('input_ids', 'attention_mask', 'token_type_ids'):
      # return_tensors='pt' adds a leading batch dim of 1; drop it per sample.
      features[key] = encoded[key].flatten()
    return features

  def get_decoder_input(self, formal_sentence):
    """Map the formal sentence to vocabulary ids and terminate it with EOS_token."""
    token_ids = [self.formalStyle.word2index[word] for word in word_tokenize(formal_sentence)]
    token_ids = token_ids + [EOS_token]
    return torch.tensor(token_ids, dtype=torch.long)

  def __getitem__(self, index):
    informal_text = self.dataset['inFormalForm'].iloc[index]
    formal_text = self.dataset['formalForm'].iloc[index]
    return (self.get_encoder_input(informal_text), self.get_decoder_input(formal_text), formal_text)

# Setting hyperparameters

In [None]:
MAX_LEN = 20             # BERT input length (pad/truncate) and greedy-decoding step cap
BATCH_SIZE = 1           # the training/evaluation loops build [[SOS_token]] and call .item(), i.e. assume 1 sentence per batch
NUM_EPOCHS = 10
NUM_DECODER_LAYERS = 3

LEARNING_RATE_ENCODER = 2e-5   # AdamW fine-tuning of BERT
LEARNING_RATE_DECODER = 0.01   # SGD for the GRU decoder
TEACHER_FORCE = 0.5            # per-sentence probability of feeding gold tokens while training

# Every path lives on the same Drive mount as the `cd` at the top of the notebook.
PROJECT_DIR = '/content/drive/MyDrive/text_style_transfer'
DATASET_PATH = f'{PROJECT_DIR}/CleanedDataset_v2.csv'
HAZM_EMBEDDING_PATH = f'{PROJECT_DIR}/targets_embedding.w2v'
MODEL_NAME_OR_PATH = 'HooshvareLab/bert-fa-base-uncased'

# Instantiating required objects

In [None]:
# Build the target-side vocabulary from every formal sentence in the cleaned data.
formalStyle = LangStyle()
for formal_sentence in dataset['formalForm']:
  formalStyle.add_setence_to_lang(formal_sentence)

In [None]:
# Load the config and tokenizer of the pretrained checkpoint named by MODEL_NAME_OR_PATH.
config = BertConfig.from_pretrained(MODEL_NAME_OR_PATH)
# NOTE(review): this rebinds the name `BertTokenizer` from the class to a tokenizer
# instance, and the next cell passes that instance to TSTSDATA. A separate name
# (e.g. `tokenizer`) would be clearer, but those call sites would need to change too.
BertTokenizer = BertTokenizer.from_pretrained(MODEL_NAME_OR_PATH)

In [None]:
# 90% / 10% positional split of the cleaned CSV (see TSTSDATA).
train_data = TSTSDATA(DATASET_PATH,BertTokenizer,config,MAX_LEN,formalStyle,'train')
test_data = TSTSDATA(DATASET_PATH,BertTokenizer,config,MAX_LEN,formalStyle,'test')

# Reshuffle the training pairs every epoch so SGD does not see them in the same
# fixed CSV order each time; evaluation order stays deterministic.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
# Pretrained BERT encoder, plus a GRU decoder whose width matches BERT's hidden
# size and whose output layer spans the formal vocabulary.
encoder = Encoder(MODEL_NAME_OR_PATH, config=config).to(device)
decoder = Decoder(
    hidden_size=config.hidden_size,
    output_size=formalStyle.n_words,
    num_layers=NUM_DECODER_LAYERS,
).to(device)

In [None]:
# The pretrained encoder gets AdamW with a small learning rate; the decoder,
# trained from scratch, gets plain SGD with a much larger one.
encoder_optimizer = AdamW(
    encoder.parameters(),
    lr=LEARNING_RATE_ENCODER,
    correct_bias=False,
)
decoder_optimizer = torch.optim.SGD(decoder.parameters(), lr=LEARNING_RATE_DECODER)

# NLLLoss consumes the log-probabilities produced by the decoder's LogSoftmax.
criterion = nn.NLLLoss()

# Training

In [None]:
loss_per_epoch = []
for epoch in range(NUM_EPOCHS):
  # Re-enable dropout etc.; the evaluation cell may have left the models in eval mode.
  encoder.train()
  decoder.train()
  sumOfLosses = 0
  print(f"starting epoch number : {epoch}")
  # Train on the training split only; test_dataloader is held out for evaluation.
  for batch_idx, batch in enumerate(train_dataloader):
    informal_sentence = batch[0]            # dict of BERT inputs, each [batch_size, MAX_LEN]
    formal_sentence = batch[1].to(device)   # target ids ending in EOS, [batch_size, num_tokens]

    input_ids = informal_sentence['input_ids'].to(device)
    attention_mask = informal_sentence['attention_mask'].to(device)
    token_type_ids = informal_sentence['token_type_ids'].to(device)

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    sentence_representation = encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    decoder_input = torch.tensor([[SOS_token]], device=device)
    # nn.GRU expects an initial state of shape (num_layers, batch, hidden):
    # seed every decoder layer with the sentence vector.
    decoder_hidden = sentence_representation.view(1, 1, -1).repeat(decoder.num_layers, 1, 1)
    target_length = formal_sentence.shape[1]

    # Per sentence, feed either the gold previous token (teacher forcing) or
    # the model's own greedy prediction.
    use_teacher_forcing = TEACHER_FORCE > random.random()
    loss = 0
    decoded_steps = 0
    for index in range(target_length):
      decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
      target_token = formal_sentence[:, index]
      loss += criterion(decoder_output, target_token)
      decoded_steps += 1
      if use_teacher_forcing:
        decoder_input = target_token
      else:
        _, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()
        # Free-running decoding stops as soon as the model emits EOS.
        if decoder_input.item() == EOS_token:
          break

    # Average over the steps actually decoded; free-running mode can stop
    # before target_length. decoded_steps >= 1 because targets always hold EOS.
    sumOfLosses += loss.item() / decoded_steps
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

  if epoch % 5 == 0:
    torch.save(encoder.state_dict(), f"encoder{epoch}.pth")
    torch.save(decoder.state_dict(), f"decoder{epoch}.pth")
  epoch_loss = sumOfLosses / len(train_dataloader)
  loss_per_epoch.append(epoch_loss)
  print(f'end of epoch {epoch} and loss is {epoch_loss}')

In [None]:
# Save the weights after the last epoch (`epoch` still holds the final loop value).
for prefix, module in (("encoder", encoder), ("decoder", decoder)):
  torch.save(module.state_dict(), f"{prefix}{epoch}.pth")

# Evaluation

In [None]:
import torchtext
from torchtext.data.metrics import bleu_score

In [None]:
# Greedy-decode the first 11 test sentences and print them next to their references.
encoder_was_training, decoder_was_training = encoder.training, decoder.training
# eval(): disable dropout so greedy decoding is deterministic.
encoder.eval()
decoder.eval()
try:
  with torch.no_grad():
    predicted_sentences = []
    for batch_idx, batch in enumerate(test_dataloader):
      informal_sentence = batch[0]   # dict of BERT inputs, each [batch_size, MAX_LEN]
      target_sentence = batch[2]     # raw formal sentence(s), collated into a list

      input_ids = informal_sentence['input_ids'].to(device)
      attention_mask = informal_sentence['attention_mask'].to(device)
      token_type_ids = informal_sentence['token_type_ids'].to(device)

      sentence_representation = encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

      decoder_input = torch.tensor([[SOS_token]], device=device)
      # nn.GRU needs one initial state per layer: copy the sentence vector to every layer.
      decoder_hidden = sentence_representation.view(1, 1, -1).repeat(decoder.num_layers, 1, 1)

      decoded_words = []
      for _ in range(MAX_LEN):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        _, topi = decoder_output.topk(1)
        predicted_id = topi.item()
        if predicted_id == EOS_token:
          decoded_words.append('<EOS>')
          break
        # .get: an id with no vocabulary entry (e.g. SOS) must not abort the evaluation.
        decoded_words.append(formalStyle.index2word.get(predicted_id, '<UNK>'))
        decoder_input = topi.squeeze().detach()

      print(target_sentence)
      print(decoded_words)
      # TODO: corpus BLEU via torchtext's bleu_score(candidate_corpus, references_corpus),
      # where candidates are token lists and each reference is a list of token lists.
      predicted_sentences.append(decoded_words)
      if batch_idx == 10:
        break
finally:
  # Restore the previous modes so a later training run is unaffected.
  encoder.train(encoder_was_training)
  decoder.train(decoder_was_training)