In [1]:
%reload_ext autoreload
%autoreload 2

from rnn_classifier.train import train_handler
import pandas as pd
import os
import pickle
import torchtext.data as data
from torchtext.data import TabularDataset
from torchtext.data import Field
import gensim
from gensim.models.word2vec import Word2Vec
from gensim.test.utils import common_texts, get_tmpfile
import torch
from torch import nn

## Load Data

In [2]:
# Directory that holds the cleaned CSV splits and the pickled vocabulary.
intermediates_dir = '../../../intermediates'

# File names of the three CSV splits.
TRAIN_PATH = 'cleaned_text_train.csv'
VAL_PATH = 'cleaned_text_val.csv'
TEST_PATH = 'cleaned_text_test.csv'

VOCABULARY_PATH = os.path.join(intermediates_dir, 'text_vocabulary.pkl')

# Open the pickle inside a context manager so the file handle is closed as
# soon as loading finishes (or fails), instead of leaking until GC.
# NOTE: pickle.load can execute arbitrary code; only load vocabulary files
# produced by this project's own preprocessing step.
with open(VOCABULARY_PATH, 'rb') as vocabulary_file:
    vocabulary = pickle.load(vocabulary_file)

# Number of entries in the pickled vocabulary.
num_tokens = len(vocabulary)

In [3]:
def tokenize(x):
    """Split a text string into tokens on runs of whitespace.

    Defined as a named function rather than a lambda bound to a name
    (PEP 8): it gets a real ``__name__`` in tracebacks and, unlike a
    lambda, can be pickled by reference.
    """
    return x.split()


# Text column: tokenized with `tokenize`, with lower-casing requested.
TEXT = Field(sequential = True, tokenize = tokenize, lower = True)
# Label columns: non-sequential and used without a vocabulary
# (assumes the CSV label values are already numeric -- TODO confirm).
LABEL = Field(sequential = False, use_vocab = False)

# (column name, field) pairs in CSV column order. Columns paired with None
# are not needed; every label column shares the LABEL field.
train_datafields = [
    ('file_name', None),
    ('text', TEXT),
    ('image_loc', None),
    *[(label_column, LABEL)
      for label_column in ('damaged_infrastructure',
                           'damaged_nature',
                           'fires',
                           'flood',
                           'human_damage',
                           'non_damage')],
]

# Load the training and validation splits from `intermediates_dir`,
# skipping the header row of each CSV.
train, val = TabularDataset.splits(
    path = intermediates_dir,
    train = TRAIN_PATH,
    validation = VAL_PATH,
    format = 'csv',
    skip_header = True,
    fields = train_datafields,
)

# The test CSV uses exactly the same column layout as the train/val CSVs, so
# derive its schema from `train_datafields` instead of repeating the list;
# a hand-maintained copy can silently drift out of sync with the original.
test_datafields = list(train_datafields)

# Load the held-out test split, skipping the CSV header row.
test = TabularDataset(path = os.path.join(intermediates_dir, TEST_PATH),
                      format = 'csv',
                      skip_header = True,
                      fields = test_datafields)

# Bundle the three splits under the keys 'train', 'val' and 'test'.
datasets = dict(train=train, val=val, test=test)

In [4]:
# Build the TEXT vocabulary from the training split only.
TEXT.build_vocab(datasets['train'])

In [7]:
# Settings handed to `train_handler`; edit here to try another configuration.
hyperparams = dict(
    rnn_type='LSTM',  # options: 'LSTM', 'GRU'
    embedding_size=100,
    num_hidden_units=500,
    num_layers=2,
    init_lr=1e-3,
    grad_clipping=5,
    num_epochs=10,
    batch_size=32,
    dropout_rate=0,
    is_bidirectional=True,
)

In [None]:
# Train the classifier and keep the returned model, best accuracy and the
# loss/accuracy histories.
# NOTE(review): the last argument is the size of the pickled vocabulary, not
# of the vocabulary built by TEXT.build_vocab -- confirm the two agree with
# what train_handler expects.
(model,
 best_acc,
 train_loss_history,
 train_acc_history,
 val_acc_history) = train_handler(
    hyperparams,
    datasets,
    TEXT,
    LABEL,
    len(vocabulary),
)

In [None]:
# Persist only the learned parameters (state dict) of the trained model.
# Create the destination directory first: torch.save does not create missing
# parent directories, and failing here would discard a finished training run.
model_dir = os.path.join('output', 'trained_models')
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(model_dir, 'lstm_model'))