In [None]:
import torch
import io

from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.utils import unicode_csv_reader

In [None]:
from classification import datasets, vocab, models

In [None]:
# Build the vocabulary from the shuffled phase-1 training TSV. The positional
# `[2]` presumably selects the text column(s) and `ngrams=2` presumably adds
# bigrams, matching the bigram text pipeline in the last cell -- TODO confirm
# against the helper's definition.
# NOTE(review): `create_vocab_from_tsv` is neither defined nor imported in the
# visible cells, so a fresh kernel would raise NameError here; it may live in the
# `classification.vocab` module imported above, which this assignment shadows.
vocab = create_vocab_from_tsv("../datasets/systematic_review/phase1.train.shuf.tsv", [2], ngrams=2)

In [None]:
tokenizer = get_tokenizer("basic_english")


def text_pipeline(x):
    """Tokenize raw text ``x`` and map the tokens to vocabulary ids.

    This already applies ``vocab``. ``create_torch_dataloader`` calls ``vocab``
    on whatever its ``text_pipeline`` returns, so hand it a tokens-only pipeline
    rather than this function.
    """
    tokens = tokenizer(x)
    return vocab(tokens)


def label_pipeline(x):
    """Convert raw label ``x`` to ``int`` and shift it up by one.

    NOTE(review): the model below is built with ``num_class = 2``; if the raw
    labels include 1, the shifted value 2 is out of range -- confirm the raw
    label encoding.
    """
    return int(x) + 1


In [None]:
import io
from torchtext.utils import (
    unicode_csv_reader,
)
from torchtext.vocab import Vocab
from torchtext.data.utils import get_tokenizer
from typing import Callable, List
from torch.utils import data
import torch

# Tokenizer backing the default text pipeline.
_default_tokenizer = get_tokenizer("basic_english")


def _identity_label(x):
    """Default label pipeline: return the label unchanged."""
    return x


def _tokenize_text(x):
    """Default text pipeline: split raw text into basic-English tokens."""
    return _default_tokenizer(x)


DEFAULT_LABEL_PIPELINE = _identity_label
DEFAULT_TEXT_PIPELINE = _tokenize_text


def create_torch_dataloader(
    dataset: data.Dataset,
    vocab: Vocab,
    label_pipeline: Callable = DEFAULT_LABEL_PIPELINE,
    text_pipeline: Callable = DEFAULT_TEXT_PIPELINE,
    device=None,
    **kwargs
):
    """Wrap ``dataset`` in a DataLoader whose batches suit ``nn.EmbeddingBag``.

    Each batch is collated into ``(labels, token_ids, offsets)``:

    * ``labels``: int64 tensor of ``label_pipeline(label)`` per example.
    * ``token_ids``: int64 tensor with every example's ids concatenated.
    * ``offsets``: int64 tensor with the start index of each example in
      ``token_ids``.

    Args:
        dataset: yields ``(label, text)`` pairs.
        vocab: called on the output of ``text_pipeline``; must return a
            sequence of integer ids.
        label_pipeline: applied to each raw label before tensorising.
        text_pipeline: turns raw text into the input passed to ``vocab``.
        device: device the batch tensors are moved to. ``None`` (default) picks
            CUDA when available, otherwise CPU. Pass the model's device to keep
            batches and model together. PyTorch's data-loading docs also advise
            against returning CUDA tensors from worker processes, so prefer
            ``"cpu"`` when using ``num_workers > 0``.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device)

    def _collate_batch(batch):
        labels, token_tensors, lengths = [], [], []
        for raw_label, raw_text in batch:
            labels.append(label_pipeline(raw_label))
            ids = torch.tensor(vocab(text_pipeline(raw_text)), dtype=torch.int64)
            token_tensors.append(ids)
            lengths.append(ids.size(0))
        # Exclusive prefix sum of the lengths: the i-th offset is where
        # example i starts inside the flat token tensor.
        offsets = torch.tensor([0] + lengths[:-1], dtype=torch.int64).cumsum(dim=0)
        return (
            torch.tensor(labels, dtype=torch.int64).to(target_device),
            torch.cat(token_tensors).to(target_device),
            offsets.to(target_device),
        )

    return data.DataLoader(dataset, collate_fn=_collate_batch, **kwargs)


class TSVRawTextIterableDataset(data.IterableDataset):
    """Stream ``(label, text)`` pairs from a tab-separated file.

    Rows are produced by ``_create_data_from_tsv``. Each call to ``iter()``
    (e.g. once per training epoch) restarts from the first row.
    """

    def __init__(self, filepath: str, data_columns: List[int]):
        self._filepath = filepath
        self._data_columns = data_columns
        # Line count of the file, reported by __len__ (e.g. for progress logs).
        self._number_of_items = _get_tsv_file_length(filepath)
        self._iterator = self._new_row_iterator()
        self._current_position = 0

    def _new_row_iterator(self):
        """Return a fresh generator over the file's ``(label, text)`` rows."""
        return _create_data_from_tsv(
            self._filepath, data_column_indices=self._data_columns
        )

    def __iter__(self):
        # A single shared generator would be exhausted after the first pass,
        # making every later epoch empty; start a new one each time instead.
        # Closing the old one releases its file handle if it was mid-read.
        self._iterator.close()
        self._iterator = self._new_row_iterator()
        self._current_position = 0
        return self

    def __next__(self):
        item = next(self._iterator)
        self._current_position += 1
        return item

    def __len__(self):
        return self._number_of_items


class TSVRawTextMapDataset(data.Dataset):
    """Map-style dataset holding every ``(label, text)`` row of a TSV in memory."""

    def __init__(self, filepath: str, data_columns: List[int]):
        # Read the whole file eagerly so rows support random access by index.
        rows = _create_data_from_tsv(filepath, data_column_indices=data_columns)
        self._records = list(rows)

    def __len__(self):
        return len(self._records)

    def __getitem__(self, index):
        return self._records[index]


def _create_data_from_tsv(data_path, data_column_indices):
    """Yield ``(label, text)`` for each row of a tab-separated file.

    Args:
        data_path: path to a UTF-8 encoded TSV file.
        data_column_indices: indices of the columns whose values are joined,
            separated by single spaces, to form ``text``.

    Yields:
        ``(int(row[0]), text)`` per row, in file order.
    """
    # The csv module requires newline="" on the underlying file; otherwise
    # newlines embedded in quoted fields are not interpreted correctly.
    with io.open(data_path, encoding="utf8", newline="") as f:
        reader = unicode_csv_reader(f, delimiter="\t")
        for row in reader:
            # Named `fields` so it does not shadow the `torch.utils.data` import.
            fields = [row[i] for i in data_column_indices]
            yield int(row[0]), " ".join(fields)


def _get_tsv_file_length(data_path):
    with io.open(data_path, encoding="utf8") as f:
        row_count = sum(1 for row in f)

    return row_count


In [None]:
train_tsv_path = "../datasets/systematic_review/phase1.train.shuf.tsv"

# Streaming and in-memory views of the same training file, both reading column 2.
# NOTE(review): `data_iter` is not used in this cell.
data_iter = TSVRawTextIterableDataset(train_tsv_path, [2])
data_map = TSVRawTextMapDataset(train_tsv_path, [2])

dl = create_torch_dataloader(
    data_map,
    vocab,
    label_pipeline=label_pipeline,
    shuffle=False,
)

In [None]:
from torch import nn

class EmbeddingBagLinearModel(nn.Module):
    """Bag-of-embeddings text classifier.

    Token ids are pooled per example by an ``nn.EmbeddingBag`` (using its
    default reduction mode), then projected to class scores by a linear layer.
    """

    def __init__(self, vocab_size: int, embed_dim: int, num_class: int):
        super().__init__()
        # sparse=True makes the embedding produce sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw embedding and linear weights from U(-0.5, 0.5); zero the bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return class scores, one row per bag delimited by ``offsets``."""
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

In [None]:
# Binary classifier over the vocabulary built earlier, with 64-dim embeddings.
emsize = 64
num_class = 2
vocab_size = len(vocab)
model = EmbeddingBagLinearModel(
    vocab_size=vocab_size, embed_dim=emsize, num_class=num_class
)

In [None]:
import time
# create_torch_dataloader places batches on CUDA when it is available, so the
# model and the loss's class weights must live on that same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Errors on class 1 are weighted 30x those on class 0 (presumably to counter
# class imbalance -- TODO confirm against the label counts).
weights = torch.tensor([1.0, 30.0], device=device)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(model.parameters(), lr=10)
# Multiply the learning rate by 0.1 on every scheduler.step() call.
# StepLR expects an integer step_size.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

def train(dataloader, epoch: int = 0, log_interval: int = 500):
    """Train the notebook-level ``model`` for one pass over ``dataloader``.

    Uses the module-level ``optimizer`` and ``criterion``. Gradients are clipped
    to a total norm of 0.1 before each optimizer step. Every ``log_interval``
    batches, the accuracy accumulated since the previous report is printed and
    the counters are reset.

    Args:
        dataloader: yields ``(label, text, offsets)`` batches; ``len()`` must
            work on it when a progress line is printed.
        epoch: epoch number shown in the progress line.
        log_interval: number of batches between progress lines.
    """
    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)

        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            
def evaluate(dataloader, net=None):
    """Run a model over ``dataloader`` and collect predictions and gold labels.

    Args:
        dataloader: yields ``(label, text, offsets)`` batches, e.g. from
            ``create_torch_dataloader``.
        net: model to evaluate; defaults to the notebook-level ``model``. It is
            left in eval mode afterwards.

    Returns:
        ``(pred, true)`` as numpy arrays: the argmax of the model's outputs per
        example, and the labels, concatenated across batches. Both are empty
        when ``dataloader`` yields no batches.
    """
    net = model if net is None else net
    net.eval()

    pred, true = [], []
    with torch.no_grad():
        for label, text, offsets in dataloader:
            # Batches may be on CUDA; .numpy() below only works on CPU tensors.
            pred.append(net(text, offsets).argmax(1).cpu())
            true.append(label.cpu())

    if not pred:
        # torch.cat rejects an empty list, so build empty results explicitly.
        empty = torch.empty(0, dtype=torch.int64)
        return empty.numpy(), empty.numpy()
    return torch.cat(pred).numpy(), torch.cat(true).numpy()

In [None]:
# NOTE(review): `train_dl` below is built from the *dev* split, while `data_iter`
# (not used in this cell) points at the train split -- confirm which file is intended.
data_iter = TSVRawTextIterableDataset(
    "../datasets/systematic_review/phase1.train.shuf.tsv", [2]
)
data_map = TSVRawTextMapDataset(
    "../datasets/systematic_review/phase1.dev.shuf.tsv", [2]
)


def label_pipeline(x):
    """Clamp labels at zero: positive labels pass through, all others become 0."""
    if x > 0:
        return x
    return 0


def text_pipeline(x):
    """Tokenize ``x`` and return the tokens together with their bigrams."""
    grams = ngrams_iterator(tokenizer(x), 2)
    return list(grams)


train_dl = create_torch_dataloader(
    data_map,
    vocab,
    label_pipeline=label_pipeline,
    text_pipeline=text_pipeline,
)