In [126]:
import os
import re
import glob
import shutil

import torch
from torch.autograd import Variable

In [127]:
def construct_vocab(file_, max_size=200000, mincount=5):
    vocab2id = {'<s>': 0, '</s>': 1, '<pad>': 2, '<unk>': 3}
    id2vocab = {0: '<s>', 1: '</s>', 2: '<pad>', 3: '<unk>'}
    word_pad = {'<s>': 0, '</s>': 1, '<pad>': 2, '<unk>': 3}
    
    cnt = 4
    with open(file_, 'r') as fp:
        for line in fp:
            arr = re.split('\s', line[:-1])
            if arr[0] == ' ':
                continue
            if arr[0] in word_pad:
                continue
            if int(arr[1]) >= mincount:
                vocab2id[arr[0]] = cnt
                id2vocab[cnt] = arr[0]
                cnt += 1
            if len(vocab2id) == max_size:
                break
    
    return vocab2id, id2vocab

vocab2id, id2vocab = construct_vocab('../sum_data/vocab', max_size=50000)
print len(vocab2id)

50000


In [128]:
def create_batch_file(path_, file_, batch_size, clean=False):
    file_name = os.path.join(path_, file_)
    folder = os.path.join(path_, 'batch_folder'+str(batch_size))
    fkey = 'batch_'
    
    if os.path.exists(folder):
        batch_files = glob.glob(os.path.join(folder, fkey+'*'))
        if len(batch_files) > 0 and clean==False:
            return len(batch_files)
    
    try:
        shutil.rmtree(folder)
        os.mkdir(folder)
    except:
        os.mkdir(folder)
    
    fp = open(file_name, 'r')
    cnt = 0
    for line in fp:
        try:
            arr.append(line)
        except:
            arr = []
        if len(arr) == batch_size:
            fout = open(os.path.join(folder, fkey+str(cnt)), 'w')
            for itm in arr:
                fout.write(itm)
            fout.close()
            arr = []
            cnt += 1
    
    fout = open(os.path.join(folder, fkey+str(cnt)), 'w')
    for itm in arr:
        fout.write(itm)
    fout.close()
    arr = []
    cnt += 1
    fp.close()
    
    return cnt

create_batch_file('../sum_data', 'train.txt', batch_size=16)

17952

In [129]:
def process_minibatch(batch_id, path_, batch_size, vocab2id, max_lens=[400, 100]):
    
    folder = os.path.join(path_, 'batch_folder'+str(batch_size))
    fkey = 'batch_'
    file_ = folder + '/' + fkey + str(batch_id)
    fp = open(file_, 'r')
    src_arr = []
    trg_arr = []
    src_lens = []
    trg_lens = []
    for line in fp:
        arr = re.split('<sec>', line[:-1])
        dabs = re.split('\s', arr[0])
        dabs = filter(None, dabs)
        trg_lens.append(len(dabs))
        
        dabs2id = [
            vocab2id[wd] if wd in vocab2id
            else vocab2id['<unk>']
            for wd in dabs
        ]
        trg_arr.append(dabs2id)
                
        dart = re.split('\s', arr[1])
        dart = filter(None, dart)
        src_lens.append(len(dart))
        dart2id = [
            vocab2id[wd] if wd in vocab2id
            else vocab2id['<unk>']
            for wd in dart
        ]
        src_arr.append(dart2id)
    fp.close()
    
    src_max_lens = max(src_lens) 
    if max_lens[0] < max(src_lens):
        src_max_lens = max_lens[0]
    trg_max_lens = max(trg_lens)
    if max_lens[1] < max(trg_lens):
        trg_max_lens = max(trg_lens)
            
    src_arr = [itm[:src_max_lens] for itm in src_arr]
    trg_arr = [itm[:trg_max_lens] for itm in trg_arr]

    src_arr = [
        itm[:-1] + [vocab2id['<pad>']]*(1+src_max_lens-len(itm))
        for itm in src_arr
    ]
    trg_input_arr = [
        itm[:-1] + [vocab2id['<pad>']]*(1+trg_max_lens-len(itm))
        for itm in trg_arr
    ]
    trg_output_arr = [
        itm[1:] + [vocab2id['<pad>']]*(1+trg_max_lens-len(itm))
        for itm in trg_arr
    ]
    
    src_var = Variable(torch.LongTensor(src_arr))
    trg_input_var = Variable(torch.LongTensor(trg_input_arr))
    trg_output_var = Variable(torch.LongTensor(trg_output_arr))
    
    return src_var, trg_input_var, trg_output_var

process_minibatch(0, '../sum_data', 64, vocab2id)

(Variable containing:
    207      6    263  ...     377     33      2
  10407      6   4107  ...    3776      9      2
   4098      6    822  ...       8   1733      2
         ...            ⋱           ...         
     55    127     53  ...      56   2441      2
    612     55    127  ...      30     48      2
   1372      6    753  ...       2      2      2
 [torch.LongTensor of size 64x400], Variable containing:
      0   1334   5750  ...       2      2      2
      0     69     24  ...       2      2      2
      0    326   9775  ...       2      2      2
         ...            ⋱           ...         
      0     69     24  ...       2      2      2
      0    448     24  ...       2      2      2
      0    142    738  ...       2      2      2
 [torch.LongTensor of size 64x74], Variable containing:
   1334   5750    412  ...       2      2      2
     69     24     36  ...       2      2      2
    326   9775     22  ...       2      2      2
         ...            ⋱       