-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
106 lines (77 loc) · 3.15 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from torch.utils.data import DataLoader,Dataset
import pandas as pd
import spacy
from torch.nn.utils.rnn import pad_sequence
import os
import pickle
import unicodedata
class Vocabulary:
    """Token <-> integer-id mapping with a minimum-frequency cutoff.

    Ids 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS>, <UNK>;
    corpus words are assigned ids starting at 4, in order of first reaching
    the frequency threshold.
    """

    def __init__(self, freqThresold):
        # id -> token and token -> id tables, pre-seeded with the specials.
        self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.stoi = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        # A word must occur at least this many times to enter the vocabulary.
        # (Name kept as-is, typo included, for caller compatibility.)
        self.freqThresold = freqThresold

    def __len__(self):
        """Vocabulary size, special tokens included."""
        return len(self.stoi)

    @staticmethod
    def tokenizer(text):
        """Lower-cased whitespace tokenization of ``text``.

        Uses str.split() with no argument so consecutive spaces, tabs or
        newlines never yield empty-string tokens (split(" ") did, polluting
        the vocabulary with "" entries).
        """
        return [tok.lower() for tok in str(text).split()]

    def build_vocabulary(self, sentenceList):
        """Count tokens across ``sentenceList``; admit each word exactly once,
        at the moment its count reaches ``freqThresold``."""
        freq = {}
        idx = 4  # next free id after the reserved specials
        for sent in sentenceList:
            for word in self.tokenizer(sent):
                freq[word] = freq.get(word, 0) + 1
                if freq[word] == self.freqThresold:
                    self.itos[idx] = word
                    self.stoi[word] = idx
                    idx += 1

    def encode(self, text):
        """Return the id sequence for ``text``; out-of-vocabulary tokens map
        to the <UNK> id."""
        unk = self.stoi['<UNK>']
        return [self.stoi.get(token, unk) for token in self.tokenizer(text)]

    def storeVocab(self, name, outDir='Language-Translation-Using-PyTorch/output'):
        """Pickle both lookup tables to ``<outDir>/<name>_{itos,stoi}.pkl``.

        ``outDir`` defaults to the original hard-coded path for backward
        compatibility. The directory is created if missing — the original
        raised FileNotFoundError when it did not exist.
        """
        print("Saving Vocab Dict...")
        os.makedirs(outDir, exist_ok=True)
        with open(os.path.join(outDir, name + '_itos.pkl'), 'wb') as f:
            pickle.dump(self.itos, f, pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(outDir, name + '_stoi.pkl'), 'wb') as f:
            pickle.dump(self.stoi, f, pickle.HIGHEST_PROTOCOL)
class TranslateDataset(Dataset):
    """English->Hindi parallel-sentence dataset for seq2seq training.

    Reads a CSV with 'english_sentence' and 'hindi_sentence' columns and
    builds one Vocabulary per language over the full corpus.

    Args:
        text_file: path to the CSV file.
        freqThresold: minimum word frequency for vocabulary inclusion.
    """

    def __init__(self, text_file, freqThresold=2):
        self.df = pd.read_csv(text_file)
        self.english = self.df['english_sentence']
        self.hindi = self.df['hindi_sentence']
        # One vocabulary per language, fitted on the whole column.
        self.eng_vocab = Vocabulary(freqThresold)
        self.hin_vocab = Vocabulary(freqThresold)
        self.eng_vocab.build_vocabulary(self.english.tolist())
        self.hin_vocab.build_vocabulary(self.hindi.tolist())

    def __len__(self):
        """Number of sentence pairs."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return the pair at ``index`` as (english_ids, hindi_ids) tensors,
        each wrapped in <SOS> ... <EOS> markers."""
        # Positional access via .iloc: the original used label indexing
        # (self.english[index]), which raises KeyError whenever the CSV
        # does not yield a default 0..n-1 RangeIndex.
        english = self.english.iloc[index]
        hindi = self.hindi.iloc[index]
        encoded_english = [self.eng_vocab.stoi["<SOS>"]]
        encoded_english += self.eng_vocab.encode(english)
        encoded_english.append(self.eng_vocab.stoi["<EOS>"])
        encoded_hindi = [self.hin_vocab.stoi["<SOS>"]]
        encoded_hindi += self.hin_vocab.encode(hindi)
        encoded_hindi.append(self.hin_vocab.stoi["<EOS>"])
        return torch.tensor(encoded_english), torch.tensor(encoded_hindi)
"""
class MyCollate:
def __init__(self,padIdx):
self.padIdx = padIdx
def __call__(self,batch):
imgs = [item[0].unsqueeze(0) for item in batch]
imgs = torch.cat(imgs,dim = 0)
targets = [item[1] for item in batch]
targets = pad_sequence(targets, batch_first=False, padding_value= self.padIdx)
return imgs,targets
"""
# trainDataset = TranslateDataset('Language-Translation-Using-PyTorch/input/train_df.csv')
# print(trainDataset.eng_vocab.stoi.keys())