In [1]:
from numpy.core.fromnumeric import mean
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from models import RnnModel, TextCNN
import torch
import torch.nn as nn
import spacy
import torch
from torchtext import data
import argparse
import os
import logging
import random
from tqdm import tqdm
import numpy as np
import json
import torch.optim as optim
import torch.cuda
import nltk
from utils import AverageMeter, TextDataset, count_parameters, layer_wise_parameters, human_format

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Class label: a single un-tokenized value per example with no vocabulary,
# so it stays a raw string (e.g. '3') until converted downstream.
LABEL = data.Field(is_target=True, use_vocab=False, sequential=False)

# Example text: lower-cased and split into word tokens with NLTK.
TEXT = data.Field(lower=True, tokenize=nltk.word_tokenize, sequential=True)



In [14]:
@torch.no_grad()
def do_validation(model: nn.Module, val_iter) -> float:
    """Compute the classification accuracy of ``model`` over ``val_iter``.

    Each batch is indexed as ``batch[0]`` (labels) and ``batch[1]`` (token ids).
    The token-id tensor is transposed before being passed to the model,
    presumably ``[batch_size, text_len] -> [text_len, batch_size]``
    (TODO confirm against the model's expected input layout).

    Args:
        model: classifier whose output holds class scores along dim 1.
        val_iter: iterable of ``(labels, texts)`` batches, e.g. a DataLoader
            over a labelled ``TextDataset``.

    Returns:
        Fraction of correctly predicted items, or ``0.0`` if ``val_iter``
        yields no items (instead of raising ZeroDivisionError).
    """
    # Remember the caller's mode so validating in the middle of training does
    # not silently leave dropout/batch-norm layers in eval mode afterwards.
    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    try:
        for batch in tqdm(val_iter, desc="Validating"):
            labels = batch[0]  # [batch_size]
            texts = batch[1].t()  # [text_len, batch_size]

            output = model(texts)
            predictions = torch.argmax(output, dim=1)
            correct += torch.sum(predictions == labels).item()
            total += len(labels)
    finally:
        model.train(was_training)

    # No `logger` global is defined in this notebook's visible cells, so
    # fetch a logger explicitly rather than risk a NameError after a full run.
    logging.getLogger(__name__).info(
        "Validation: total %d items; %d are correct.", total, correct)
    return correct / total if total else 0.0

In [3]:
# Labelled train/dev splits from merged_data/: column 1 -> label_id, column 2 -> text.
train, val = data.TabularDataset.splits(
    path='merged_data',
    train='train.csv',
    validation='dev.csv',
    format='csv',
    fields=[('label_id', LABEL), ('text', TEXT)],
    skip_header=True,
)

# Test split: the first column is skipped (field None), only the text is kept.
test = data.TabularDataset(
    path=os.path.join('merged_data', 'test.csv'),
    format='csv',
    fields=[('label_id', None), ('text', TEXT)],
    skip_header=True,
)



In [4]:
# Build the token vocabulary from the training split only, capped at
# max_size=10000 entries with min_freq=10, and attach the pre-trained
# 'glove.840B.300d' vectors.
TEXT.build_vocab(train, min_freq=10, max_size=10000, vectors='glove.840B.300d')

In [5]:
# Peek at the first training label: LABEL has use_vocab=False and no
# preprocessing, so the value is still the raw CSV string (e.g. '3').
train[0].label_id

'3'

In [6]:
class TextDataset(Dataset):
    """Label and token-id tensors built from a torchtext-style dataset.

    NOTE: this notebook-local class shadows the ``TextDataset`` imported from
    ``utils`` in the first cell.

    Args:
        datas: object whose ``label_id`` attribute iterates over raw labels
            (strings such as ``'3'``) and whose ``text`` attribute iterates
            over token lists, e.g. the ``TabularDataset`` splits loaded earlier.
        vocab: object with a ``stoi`` token -> integer id mapping
            (e.g. ``TEXT.vocab``).
        device: device the label/feature tensors are moved to.
        is_test: if True, labels are not read and items are features only.
        max_len: optional fixed length. When given, each token list is
            truncated to ``max_len`` tokens and right-padded with ``pad_token``
            (on a copy, the input is not modified) before lookup, so splits
            that were not padded beforehand can be wrapped too. The default
            ``None`` keeps the original behaviour: every token list must
            already have the same length.
        pad_token: filler token used when ``max_len`` is given.
    """

    def __init__(self, datas, vocab, device, is_test=False,
                 max_len=None, pad_token="<pad>"):
        self.is_test = is_test
        if not is_test:
            # Raw labels are strings (e.g. '3'); convert them to int64 ids.
            self.labels = torch.tensor(
                [int(label) for label in datas.label_id],
                dtype=torch.long).to(device)
        sentences = datas.text
        if max_len is not None:
            sentences = [self._fit_length(tokens, max_len, pad_token)
                         for tokens in sentences]
        # torch.tensor needs a rectangular nested list, so every sentence must
        # have the same length here (pre-padded, or fitted via max_len).
        # The explicit dtype keeps the result int64 even for an empty dataset.
        self.features = torch.tensor(
            [[vocab.stoi[token] for token in tokens] for tokens in sentences],
            dtype=torch.long).to(device)

    @staticmethod
    def _fit_length(tokens, max_len, pad_token):
        """Return a new list: ``tokens`` cut to ``max_len`` items, then right-padded."""
        fitted = list(tokens[:max_len])
        fitted.extend([pad_token] * (max_len - len(fitted)))
        return fitted

    def __len__(self):
        """Number of examples (rows of the feature tensor)."""
        return len(self.features)

    def __getitem__(self, index):
        """Return ``features[index]`` for test data, else ``(label, features)``."""
        if self.is_test:
            return self.features[index]
        return self.labels[index], self.features[index]

In [7]:
def padding(text, length, pad_token="<pad>"):
    """Right-pad ``text`` in place with ``pad_token`` until it has ``length`` items.

    The list is mutated and also returned, so both ``padding(x, n)`` and
    ``x = padding(x, n)`` work. A list that is already at least ``length``
    long is returned unchanged (no truncation; slice before calling).

    Args:
        text: list of tokens to pad (modified in place).
        length: target length.
        pad_token: filler token; defaults to ``"<pad>"``.

    Returns:
        The same list object, now ``max(len(text), length)`` items long.
    """
    # A single extend instead of appending one token per loop iteration;
    # a non-positive shortfall multiplies to an empty list (no-op).
    text.extend([pad_token] * (length - len(text)))
    return text

In [9]:
# Fixed sequence length. Each training example is cut to its first MAX_LEN
# tokens and padded back up to MAX_LEN, so every row has the same length
# (TextDataset builds one rectangular tensor from all rows).
# NOTE(review): only `train` is padded here; `val`/`test` need the same
# treatment before being wrapped in TextDataset without its max_len option.
MAX_LEN = 200
for train_item in train:
    train_item.text = padding(train_item.text[:MAX_LEN], MAX_LEN)

In [10]:
# Wrap the padded training split as label / token-id tensors on DEVICE.
dataset = TextDataset(datas=train, vocab=TEXT.vocab, device=DEVICE)

In [13]:
# Inspect one item: a (label, token-id tensor) pair. The trailing run of 1s
# in the output is the "<pad>" tokens appended by `padding`.
dataset[0]

(tensor(3, device='cuda:0'),
 tensor([ 426,  557, 1547,    0,  120,   72,    2,  772,   15,   34,   16,   34,
           17,    0,    3,  426,  369,   24,    0, 3372,    7,    0,    3,   50,
         3834,  813,  321,    4,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,

In [15]:
import pandas as pd

In [16]:
# Original CSVs, loaded only to sanity-check the files in merged_data/.
ori_train, ori_dev = pd.read_csv('data/train.csv'), pd.read_csv('data/dev.csv')

In [17]:
# Merged CSVs to be compared column-by-column against the originals.
new_train, new_dev = pd.read_csv('merged_data/train.csv'), pd.read_csv('merged_data/dev.csv')

In [25]:
# The merged train file must keep labels unchanged and differ from the
# original only by replacing each backslash with a space in title/description.
print((new_train["label"] == ori_train["Class Index"]).all())
print((new_train["title"] == ori_train["Title"].apply(lambda text: text.replace('\\', ' '))).all())
print((new_train["description"] == ori_train["Description"].apply(lambda text: text.replace('\\', ' '))).all())

True
True
True


In [26]:
# Same invariants for the dev split: identical labels, and text that differs
# only by backslash -> space in title/description.
print((new_dev["label"] == ori_dev["Class Index"]).all())
print((new_dev["title"] == ori_dev["Title"].apply(lambda text: text.replace('\\', ' '))).all())
print((new_dev["description"] == ori_dev["Description"].apply(lambda text: text.replace('\\', ' '))).all())

True
True
True
