In [1]:
import torch
import torch.nn as nn
import pandas as pd

In [2]:
ls

Train Transformer 1.ipynb  dataloader.py
[1m[34m__pycache__[m[m/               models.py


# First, prepare the data

In [3]:
from dataloader import *

In [4]:
# load every dataset manually and create torch objects
comment_df = pd.read_csv("../data/attack_annotated_comments.tsv", sep ='\t')
body_df = pd.read_csv("../data/fake_news_bodies.csv")
stance_df = pd.read_csv("../data/fake_news_stances.csv")
vocab = Vocabulary([comment_df["comment"], body_df["articleBody"], stance_df["Headline"]])
annotation_df = pd.read_csv("../data/attack_annotations.tsv",  sep='\t')

wiki_dataset = WikiDataset(comment_df, annotation_df, vocab)
fake_news_dataset = FakeNewsDataset(body_df, stance_df, vocab)

In [5]:
comment_df.head(2)

Unnamed: 0,rev_id,comment,year,logged_in,ns,sample,split
0,37675,`-NEWLINE_TOKENThis is not ``creative``. Thos...,2002,False,article,random,train
1,44816,`NEWLINE_TOKENNEWLINE_TOKEN:: the term ``stand...,2002,False,article,random,train


In [6]:
body_df.head(2)

Unnamed: 0,Body ID,articleBody,sentence_as_idx
0,0,A small meteorite crashed into a wooded area i...,"[2, 126, 2627, 7088, 36713, 1541, 126, 64723, ..."
1,4,Last week we hinted at what was to come as Ebo...,"[2, 978, 9959, 301, 27551, 181, 184, 243, 21, ..."


In [7]:
stance_df.head(2)

Unnamed: 0,Headline,Body ID,Stance
0,Police find mass graves with at least '15 bodi...,712,unrelated
1,Hundreds of Palestinians flee floods in Gaza a...,158,agree


In [8]:
# check label correspondence
stance_df['Stance'].unique()

array(['unrelated', 'agree', 'disagree', 'discuss'], dtype=object)

In [9]:
annotation_df.head(2)

Unnamed: 0,rev_id,worker_id,quoting_attack,recipient_attack,third_party_attack,other_attack,attack
0,37675,1362,0.0,0.0,0.0,0.0,0.0
1,37675,2408,0.0,0.0,0.0,0.0,0.0


In [10]:
# check label bias
fake_news_dataset.y.mean(axis=0)

array([0.73130953, 0.07360122, 0.01680941, 0.17827984])

In [11]:
# check label bias
wiki_dataset.y.mean(axis=0)

array([0.88270731, 0.11729269])

# Once the data is loaded, prepare the model

In [12]:
from models import *

In [15]:
# simple transformer model, cant be used for classification
vocab_size = len(vocab)
embedding_dim = 64
nhead = 1
hidden_dim = 32
num_layers = 1
feedforward_dim = 64
def model_from_dataset(dataset):
    labels = len(dataset[0][1])
    model = TransformerClassifier(vocab_size, labels, embedding_dim, nhead, feedforward_dim, num_layers)
    return model

In [16]:
model = model_from_dataset([(1, np.array([0, 1, 0]))]) # fake dataset with 3 labels
nparr = np.random.randint(0, vocab_size, size=(5, 25))
x = torch.from_numpy(nparr)
print(x)
print(model(x))

tensor([[201288, 133873, 248726,   5244, 306116, 161056, 338715, 362149, 214736,
         287528, 307005, 323909, 173648,  89833, 138905,  15575,   9140, 240530,
         293650,  14501, 258201, 105418,  51832, 138803, 134351],
        [257260, 330811, 419858, 239446, 332499,  72424, 274231, 102853, 407554,
         332244, 382067, 157056, 138782,  64966, 379028, 180675, 373916, 126081,
          66220, 257145, 233126, 296096, 322472, 126914, 304744],
        [ 85264, 304619, 313022, 344618, 339394, 205910, 383950, 401967, 365365,
          23676, 102137, 125371,   3341, 132438, 224062, 198787, 158698, 182141,
         132533, 290088, 344641, 201694, 272195, 141012, 361418],
        [278137, 409492, 240337, 466018, 268849, 305400, 445415, 139938, 138067,
         139557, 473885,  46106, 455936, 130380, 163545, 441458, 435325, 148277,
          35704,  57611, 173653, 417989, 278269,  79516, 341788],
        [ 12010, 197074,  55527, 335206, 175479, 111003, 169623, 455608, 334065,
       

## Once the models are working, we can implement our train method

In [42]:
criterion = nn.CrossEntropyLoss()
lr = 0.1 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

import time
def train(model, dataset, batch_size=32):
    print("training encoder")
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    n = len(dataset)
    batch_order = np.arange(n)
    np.random.shuffle(batch_order)
    for i in range(0, n, batch_size):
        x_batch, y_batch = dataset[batch_order[i:min(i+batch_size, n)]]
        print(x_batch)
        break
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

In [43]:
train(model, wiki_dataset)

training encoder
[array([     2,     40,      5,    184,     40,   3695,     24,     16,
       390231,    434,   1119,  11151,     67,   3047,   2387,    126,
        20357,   1608,    665,   7637,     67,     85,    449,   1774,
           87,   3069,     21,   6828,     10,  33445,     13, 390232,
          349,      5,    184,  12967,    703,   1075,     21,   6828,
          126,   6287,    124,     10, 103374,     87,   1140,     24,
         2245,     77,    461,     85,   2545,   2415,    342,     10,
       389706,  39693, 390233,    243,    296,   2959,     21,    341,
           60,     77,  12081,   1368,     21,    186,    446,     21,
          212,     10, 284580,   4617,    297,    461,     87,    220,
          158,   6220,     40,    243,  15340,   4219,     16,  13153,
           60,    360,    344,     95,    449,  30588,     87,   2145,
           24,     51,     10,  25811,     77,   1866,    359,    449,
          281,     60,    703,     77,   4538,    237,     

In [39]:
wiki_dataset[0:2][1]

array([[1, 0],
       [1, 0]])