<a href="https://colab.research.google.com/github/taravatp/Text_Style_Transfer/blob/main/Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers==3.1.0
!pip install -qU hazm

In [None]:
import numpy as np
import pandas as pd
import re

import torch
from torch.utils.data import Dataset
from transformers import BertConfig, BertTokenizer

In [None]:
import hazm
from hazm import word_tokenize
import gensim
from gensim.models.word2vec import Word2Vec

In [None]:
class FormalityDataset(Dataset):
  """Torch ``Dataset`` of (informal, formal) sentence pairs read from a CSV file.

  The CSV must provide ``inFormalForm`` and ``formalForm`` columns (read in
  ``__getitem__``). Each item is the BERT encoding of the informal sentence
  together with the raw formal sentence as the target.
  """

  def __init__(self, dataset_path, embedding_path, BertTokenizer, config, max_len, flag,
               test_size=1000):
    """Load the CSV and select the requested split.

    Args:
      dataset_path: path or file-like object accepted by ``pd.read_csv``.
      embedding_path: path of a Word2Vec model; currently unused because the
        embedding load is disabled.
      BertTokenizer: tokenizer *instance* exposing ``encode_plus``. The name
        shadows the ``transformers`` class; it is kept for keyword callers.
      config: model config, stored unchanged as ``self.config``.
      max_len: sequence length used for encoder padding/truncation and for
        decoder token padding.
      flag: ``'train'`` keeps every row after the first ``test_size`` rows;
        ``'test'`` keeps the first ``test_size`` rows; any other value keeps
        the whole file.
      test_size: number of leading rows reserved for the test split
        (defaults to the previously hard-coded 1000).
    """
    self.dataset = pd.read_csv(dataset_path)
    if flag == 'train':
      self.dataset = self.dataset.iloc[test_size:]
    elif flag == 'test':
      self.dataset = self.dataset.iloc[:test_size]
    # NOTE(review): any other flag silently uses the full CSV, including the
    # rows reserved for testing -- confirm this is intended.

    self.BertTokenizer = BertTokenizer
    self.config = config
    self.max_len = max_len
    # embedding_path is accepted for interface compatibility only; the
    # Word2Vec load (gensim.models.Word2Vec.load(embedding_path)) is disabled.

  def __len__(self):
    """Return the number of sentence pairs in the selected split."""
    return len(self.dataset)

  def get_encoder_input(self, informal_sentence):
    """Encode an informal sentence with the BERT tokenizer.

    Returns a dict holding the original sentence plus its ``input_ids``,
    ``attention_mask`` and ``token_type_ids`` tensors, each flattened to 1-D.
    """
    encoding = self.BertTokenizer.encode_plus(
      informal_sentence,
      add_special_tokens=True,
      truncation=True,
      max_length=self.max_len,
      return_token_type_ids=True,
      padding='max_length',
      return_attention_mask=True,
      return_tensors='pt')

    # return_tensors='pt' presumably yields a (1, max_len) batch per field;
    # flatten() collapses it to a single 1-D sequence.
    return {
      'informal_sentence': informal_sentence,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'token_type_ids': encoding['token_type_ids'].flatten(),
    }

  def _pad_tokens(self, tokens):
    """Truncate or right-pad ``tokens`` with '' to exactly ``self.max_len`` items."""
    kept = list(tokens[:self.max_len])
    return kept + [''] * (self.max_len - len(kept))

  def get_decoder_input(self, formal_sentence):
    """Tokenize a formal sentence with hazm and fit it to ``max_len`` tokens.

    Sentences longer than ``max_len`` tokens are truncated; shorter ones are
    right-padded with empty strings, so the result always has ``max_len``
    word tokens (never individual characters).
    """
    return self._pad_tokens(word_tokenize(formal_sentence))

  def __getitem__(self, index):
    """Return ``(encoder_input, target)`` for the pair at position ``index``.

    ``encoder_input`` is the dict built by ``get_encoder_input``; ``target``
    is the raw formal sentence (``get_decoder_input`` is not applied here).
    """
    informal = self.dataset['inFormalForm'].iloc[index]
    target = self.dataset['formalForm'].iloc[index]
    return self.get_encoder_input(informal), target

In [None]:
from torch.utils.data import DataLoader
# Number of dataset samples per training mini-batch (fed to the DataLoader).
TRAIN_BATCH_SIZE: int = 32

In [None]:
if __name__ == "__main__":
  # Input locations on the mounted Google Drive and the pretrained model id.
  DATASET_PATH = '/content/drive/MyDrive/text_style_transfer/CleanedDataset.csv'
  HAZM_EMBEDDING_PATH = '/content/drive/MyDrive/text_style_transfer/targets_embedding.w2v'
  MODEL_NAME_OR_PATH = 'HooshvareLab/bert-fa-base-uncased'

  # NOTE(review): this rebinds the imported BertTokenizer class name to a
  # tokenizer instance for the rest of the session; kept as-is because later
  # consumers may rely on the global, but a distinct name would be clearer.
  BertTokenizer = BertTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
  config = BertConfig.from_pretrained(MODEL_NAME_OR_PATH)

  MAX_LEN = 32
  train_dataset = FormalityDataset(
    dataset_path=DATASET_PATH,
    embedding_path=HAZM_EMBEDDING_PATH,
    BertTokenizer=BertTokenizer,
    config=config,
    max_len=MAX_LEN,
    flag='train',
  )
  train_data_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE)