In [1]:
%run ../../src/config.py

In [2]:
import pandas as pd
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [3]:
# Load the three pre-split transcript tables and display the training split.
# NOTE(review): the preview shows an "Unnamed: 0" column, which suggests the CSVs
# were saved with their index; consider `index_col=0` if that column is unused.
transcript_train, transcript_test, transcript_validation = map(
    pd.read_csv,
    (
        '../../data/prepared_data/transcript_train.csv',
        '../../data/prepared_data/transcript_test.csv',
        '../../data/prepared_data/transcript_validation.csv',
    ),
)
transcript_train

Unnamed: 0,campaign_no,arc_no,episode_no,episode_index,episode_label,section_no,line_no,speaker,line,nwords
0,2,5,18,223.0,2-5-18,2,1054,LIAM,"Fjord, do you want to carry these, or do you w...",16
1,3,1,19,274.0,3-1-19,2,1943,LAURA,"Oh, it's my glasses. Hold on. (laughter)",7
2,3,2,17,295.0,3-2-17,4,710,MATT,"I'll allow it for the time being, yeah.",8
3,3,2,25,303.0,3-2-25,4,468,MARISHA,You're more just visiting another landscape.,6
4,2,6,15,241.0,2-6-15,4,13,TRAVIS,As a point of clarification--,5
...,...,...,...,...,...,...,...,...,...,...
371415,3,3,19,325.0,3-3-19,2,1258,MATT,"All right, finishing FCG's go.",5
371416,3,1,11,266.0,3-1-11,2,415,MATT,22 points of lightning damage.,5
371417,2,6,28,254.0,2-6-28,4,395,LIAM,But I'm still pretty low.,5
371418,3,3,20,326.0,3-3-20,2,1058,MATT,"All right, who's keeping watch?",5


In [4]:
# Attach each speaker's numeric id to every split. `cast` is presumably provided by
# the %run of config.py; a speaker name missing from it fails the lookup.
transcript_train['speakerno']      = transcript_train['speaker'].map(lambda name: cast[name]['speakerno'])
transcript_test['speakerno']       = transcript_test['speaker'].map(lambda name: cast[name]['speakerno'])
transcript_validation['speakerno'] = transcript_validation['speaker'].map(lambda name: cast[name]['speakerno'])

In [5]:
# Tokenizer shared by vocabulary construction and text_pipeline. The sample output
# of the text-pipeline cell shows it splitting punctuation (',' and '?') into
# separate tokens.
tokenizer = get_tokenizer("basic_english")

def yield_tokens(data):
    """Lazily tokenize every entry of ``data['line']``.

    Args:
        data: object indexable by ``'line'`` (a transcript DataFrame here).

    Yields:
        ``tokenizer(line)`` for each line, in order, so the vocabulary builder
        can stream over the corpus without materialising all token lists.
    """
    yield from (tokenizer(line) for line in data['line'])


# The vocabulary is built from the training split only. "<unk>" is registered as a
# special token and then made the default index, so any word not seen in training
# (e.g. in validation/test text) maps to "<unk>" instead of raising.
vocab = build_vocab_from_iterator(
    yield_tokens(transcript_train),
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])

In [6]:
def text_pipeline(x):
    """Tokenize a raw transcript line and map each token to its vocabulary index."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a ``speakerno`` value into a class index by subtracting one.

    NOTE(review): this presumably assumes speaker numbers start at 1 so that
    class indices start at 0 — confirm against ``cast`` in config.py.
    """
    return int(x) - 1

# Sanity check: invented words map to index 0, the "<unk>" default.
print(*map(text_pipeline, ['how do you want to do this williamchadyoung, habcldiekso?', '<unk>']), sep='\n')

[97, 26, 5, 86, 7, 26, 18, 0, 2, 0, 14]
[0]


In [7]:
from torch.utils.data import DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def collate_batch(batch):
    """Collate ``(label, text)`` pairs into the flat inputs ``nn.EmbeddingBag`` takes.

    Args:
        batch: sequence of ``(raw_label, raw_text)`` pairs from the dataset.

    Returns:
        tuple ``(labels, token_ids, offsets)``, each moved to ``device``:
        ``labels`` is an int64 tensor of ``label_pipeline`` results (one per pair);
        ``token_ids`` is every pair's ``text_pipeline`` ids concatenated into one
        1-D int64 tensor; ``offsets`` holds the start position of each pair's ids
        inside ``token_ids``.
    """
    labels, pieces, lengths = [], [], [0]
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        pieces.append(ids)
        lengths.append(ids.size(0))
    # lengths = [0, len_0, len_1, ...]; dropping the last entry and taking a running
    # sum turns the per-sequence lengths into per-sequence start offsets.
    offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)
    token_ids = torch.cat(pieces)
    label_tensor = torch.tensor(labels, dtype=torch.int64)
    return label_tensor.to(device), token_ids.to(device), offsets.to(device)

def yield_data(data):
    """Yield ``(speakerno, line)`` pairs, one per row of ``data``, in row order.

    Args:
        data: DataFrame with at least ``'speakerno'`` and ``'line'`` columns.

    Yields:
        tuple: ``(speakerno, line)`` for each row.
    """
    # Zip the two columns instead of using DataFrame.iterrows(). iterrows builds a
    # full Series for every row, which is slow on the ~371k-row training split. It
    # also does not preserve dtypes: an int speakerno can come back as float when
    # the other columns are numeric.
    yield from zip(data['speakerno'], data['line'])

# A DataLoader indexes into a map-style dataset, so it needs len() and __getitem__.
# A bare generator has neither: iterating a loader built directly on
# yield_data(...) raises a TypeError (generators have no len()), and a generator
# can only be consumed once anyway. Materialising the (speakerno, line) pairs into
# a list gives a valid, re-iterable dataset.
dataloader = DataLoader(
    list(yield_data(transcript_train)),
    batch_size = 8,
    shuffle    = False,
    collate_fn = collate_batch
)

In [8]:
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: an ``nn.EmbeddingBag`` followed by a linear layer.

    Token embeddings of each sequence are pooled by ``nn.EmbeddingBag`` (PyTorch's
    default reduction is ``mode='mean'``), then projected to one logit per class.

    Args:
        vocab_size: number of rows in the embedding table (range of token ids).
        embed_dim: width of each embedding vector and of the pooled representation.
        num_class: number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # Dense gradients (sparse=False) so a standard optimizer such as SGD applies.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Redraw embedding and linear weights from U(-0.5, 0.5) and zero the bias."""
        bound = 0.5
        # In-place re-initialisation must not be recorded by autograd.
        with torch.no_grad():
            self.embedding.weight.uniform_(-bound, bound)
            self.fc.weight.uniform_(-bound, bound)
            self.fc.bias.zero_()

    def forward(self, text, offsets):
        """Return class logits for a batch of flattened token sequences.

        Args:
            text: 1-D tensor of token ids for all sequences, concatenated
                (as produced by ``collate_batch``).
            offsets: 1-D tensor with each sequence's start index into ``text``.

        Returns:
            Tensor of shape ``(len(offsets), num_class)``.
        """
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

In [None]:
# Size the classifier from the label values CrossEntropyLoss will actually see.
# collate_batch turns each speakerno into label_pipeline(speakerno), so the model
# needs max(label) + 1 logits. Counting distinct labels is only correct when the
# labels form a contiguous 0..K-1 range that is fully present in the training
# split. A missing speaker number would leave the highest labels out of range.
# Validation/test labels are included so those splits fit too. Reading the column
# directly also avoids a slow row-by-row scan of the full training split.
all_labels = pd.concat(
    [
        transcript_train['speakerno'],
        transcript_validation['speakerno'],
        transcript_test['speakerno'],
    ],
    ignore_index=True,
).map(label_pipeline)
assert all_labels.min() >= 0, "label_pipeline produced a negative class index"
num_class  = int(all_labels.max()) + 1
vocab_size = len(vocab)
emsize     = 64
model      = TextClassificationModel(vocab_size, emsize, num_class).to(device)
print(num_class)
print(vocab_size)

In [None]:
import time

def train(dataloader, epoch=None, log_interval=500):
    """Run one optimisation pass of the notebook-global ``model`` over ``dataloader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``criterion``. Every
    ``log_interval`` batches it prints the accuracy over the batches seen since
    the previous progress line.

    Args:
        dataloader: sized iterable of ``(label, text, offsets)`` batches, e.g. a
            DataLoader using ``collate_batch``.
        epoch: epoch number shown in progress lines. If omitted, a notebook-global
            ``epoch`` (as set by an epoch loop) is used when defined, otherwise 0.
        log_interval: number of batches between progress lines.
    """
    if epoch is None:
        # Reading the global unconditionally raised NameError whenever a progress
        # line was printed outside an epoch loop; fall back to 0 instead.
        epoch = globals().get("epoch", 0)

    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1, presumably to keep the large-LR SGD
        # updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )
            # Report accuracy per window, not cumulatively.
            total_acc, total_count = 0, 0

def evaluate(dataloader):
    """Return the accuracy of the notebook-global ``model`` over ``dataloader``.

    Args:
        dataloader: iterable of ``(label, text, offsets)`` batches, e.g. a
            DataLoader using ``collate_batch``.

    Returns:
        float: fraction of examples whose arg-max logit equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no examples (accuracy is undefined).
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    if total_count == 0:
        raise ValueError("evaluate() got an empty dataloader; accuracy is undefined")
    return total_acc / total_count

In [None]:
from torchtext.data.functional import to_map_style_dataset

# Hyperparameters
EPOCHS     = 10    # number of epochs for the epoch loop (currently commented out)
LR         = 5     # learning rate
BATCH_SIZE = 64    # batch size for all three loaders
N_SAMPLES  = 1000  # rows taken from each split for this quick run; None = full data

criterion  = torch.nn.CrossEntropyLoss()
optimizer  = torch.optim.SGD(model.parameters(), lr=LR)
# StepLR's step_size is an integer number of scheduler steps between decays.
scheduler  = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
total_accu = None

# .iloc[:None] selects every row, so N_SAMPLES = None runs on the full splits.
train_dataset = to_map_style_dataset(yield_data(transcript_train.iloc[:N_SAMPLES]))
valid_dataset = to_map_style_dataset(yield_data(transcript_validation.iloc[:N_SAMPLES]))
test_dataset  = to_map_style_dataset(yield_data(transcript_test.iloc[:N_SAMPLES]))

num_train = len(train_dataset)


def make_dataloader(dataset, shuffle):
    """Wrap a map-style ``(label, text)`` dataset in a DataLoader that uses collate_batch."""
    return DataLoader(
        dataset,
        batch_size = BATCH_SIZE,
        shuffle    = shuffle,
        collate_fn = collate_batch
    )


# Only training data needs shuffling. Evaluation order does not affect accuracy,
# and a fixed order keeps validation/test passes reproducible.
train_dataloader = make_dataloader(train_dataset, shuffle=True)
valid_dataloader = make_dataloader(valid_dataset, shuffle=False)
test_dataloader  = make_dataloader(test_dataset,  shuffle=False)

train(train_dataloader)

# for epoch in range(1, EPOCHS + 1):
#     epoch_start_time = time.time()
#     train(train_dataloader)
#     accu_val = evaluate(valid_dataloader)
#     if total_accu is not None and total_accu > accu_val:
#         scheduler.step()
#     else:
#         total_accu = accu_val
#     print("-" * 59)
#     print(
#         "| end of epoch {:3d} | time: {:5.2f}s | "
#         "valid accuracy {:8.3f} ".format(
#             epoch, time.time() - epoch_start_time, accu_val
#         )
#     )
#     print("-" * 59)