<a href="https://colab.research.google.com/github/trilokimodi/lt2222-v23-a3/blob/main/Assignment3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
import io
from torch import nn
from torch.utils.data.dataset import random_split

In [3]:
# As instructed the data is in data folder and I assume that CWD is one directory before data dir.
dataset = os.path.join(os.getcwd(), 'data')  # To run in MLTGPU
dataset = "/content/drive/My Drive/MLSNLP/data/enron_sample"  # To run in Colab
#  dataset = "/scratch/lt2222-v23/enron_sample" # To run in mltgpu

In [4]:
os.listdir(dataset)

['corman-s',
 'lenhart-m',
 'donohoe-t',
 'may-l',
 'bailey-s',
 'keiser-k',
 'panus-s',
 'dean-c',
 'mccarty-d',
 'lay-k',
 'salisbury-h',
 'schwieger-j',
 'saibi-e',
 'quigley-d']

In [5]:
import os
import sys
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
    

In [28]:
dataset_paths = [os.path.join(os.path.join(dataset, iClass), iText) for iClass in os.listdir(dataset) for iText in os.listdir(os.path.join(dataset, iClass))]
dataset_paths = [iPath for iPath in dataset_paths if '_edit' not in iPath]
labelEncoder = LabelEncoder()
labelEncoder.fit(os.listdir(dataset))

In [7]:
tokenizer = get_tokenizer('basic_english')

In [8]:
import re
for iPath in dataset_paths:
    newFile = list()
    with io.open(iPath, encoding = 'utf-8') as fh:
        for line in fh:
            if line.startswith(('Message-ID', 'Mime-Version:', 'Content-Type:', 'Content-Transfer-Encoding:')):
                pass
            elif line.startswith(('X-From:', 'X-To:', 'X-cc:', 'X-bcc:', 'X-Folder:', 'X-Origin:', 'X-FileName:')):
                pass
            elif re.search('From:', line) is not None:
                pass
            elif re.search('To:', line) is not None:
                pass
            elif re.search('Date:', line) is not None:
                pass
            elif re.search('Sent:', line) is not None:
                pass
            elif re.search(' -----Original Message-----', line) is not None:
                pass
            else:
                newFile.append(line)
    fh.close()

    with open(iPath + '_edit', 'w') as fh:
        for line in newFile:
            fh.write(line)
    fh.close()

In [9]:
with io.open(dataset_paths[110], encoding = 'utf-8') as fh:
    print(fh.readlines())
fh.close()

['Subject: RE: Web site\n', '\n', 'Hi.  Thanks for the info.\n', '\n', 'Subject:\tWeb site\n', '\n', "Hi Shelley - Welcome home!  I hear you are feeling somewhat better.  Can't wait to see you again.  Rob said to forward this to you.\n", '\n', '\n', '\n', 'http://energycommerce.house.gov/\n', '\n']


In [10]:
with io.open(dataset_paths[110] + '_edit', encoding = 'utf-8') as fh:
    print(fh.readlines())
fh.close()

['Subject: RE: Web site\n', '\n', 'Hi.  Thanks for the info.\n', '\n', 'Subject:\tWeb site\n', '\n', "Hi Shelley - Welcome home!  I hear you are feeling somewhat better.  Can't wait to see you again.  Rob said to forward this to you.\n", '\n', '\n', '\n', 'http://energycommerce.house.gov/\n', '\n']


In [11]:
dataset_paths = [iPath + '_edit' for iPath in dataset_paths]

In [12]:
def yield_tokens(paths):
    for iPath in paths:
        with io.open(iPath, encoding = 'utf-8') as fh:
            for line in fh:
                yield tokenizer(line)
        fh.close()

In [13]:
vocabObject = build_vocab_from_iterator(yield_tokens(dataset_paths), specials=["<unk>"])
vocabObject.set_default_index(vocabObject["<unk>"])

In [14]:
len(vocabObject)

20194

In [15]:
class TextDataset(Dataset):
    def __init__(self, dataset_paths=dataset_paths, labelEncoder=labelEncoder, tokenizer=tokenizer, vocab=vocabObject):
        self.datasetPaths = dataset_paths
        self.labelEncoder = labelEncoder
        self.tokenizer=tokenizer
        self.vocab=vocab

    def __len__(self):
        return len(self.datasetPaths)

    def __getitem__(self, idx):
        docFilePath = self.datasetPaths[idx]
        text_pipeline = lambda x: self.vocab(self.tokenizer(x))
        with io.open(docFilePath, encoding = 'utf-8') as fh:
            content = fh.readlines()
        content = ''.join(content)
        textTensor = torch.tensor(text_pipeline(content), dtype=torch.int64)
        label = docFilePath.split('/')[-2]
        label = self.labelEncoder.transform([label])
        return textTensor, label[0]

In [16]:
complete_data = TextDataset()

In [17]:
class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size=len(vocabObject), embed_dim=50, num_class=len(labelEncoder.classes_)):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

In [18]:
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_text, _label) in batch:
         text_list.append(_text)
         label_list.append(_label)
         offsets.append(_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    return text_list.to(device), label_list.to(device), offsets.to(device)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextClassificationModel().to(device)

In [20]:
import time

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (text, label, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (text, label, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

In [21]:
import pandas as pd
df = pd.DataFrame(data = dataset_paths, columns=['X'])
df['y'] = df['X'].apply(lambda x: x.split('/')[-2])

In [22]:
train_prop = int(len(complete_data) * 0.80)
train_data, test_data = random_split(complete_data, [train_prop, len(complete_data) - train_prop])

In [23]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_dataset = to_map_style_dataset(train_data)
test_dataset = to_map_style_dataset(test_data)
num_train = int(len(train_data) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_data) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
      scheduler.step()
    else:
       total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
    print('-' * 59)

-----------------------------------------------------------
| end of epoch   1 | time:  0.62s | valid accuracy    0.799 
-----------------------------------------------------------
-----------------------------------------------------------
| end of epoch   2 | time:  0.48s | valid accuracy    0.848 
-----------------------------------------------------------
-----------------------------------------------------------
| end of epoch   3 | time:  0.45s | valid accuracy    0.914 
-----------------------------------------------------------
-----------------------------------------------------------
| end of epoch   4 | time:  0.47s | valid accuracy    0.889 
-----------------------------------------------------------
-----------------------------------------------------------
| end of epoch   5 | time:  0.49s | valid accuracy    0.922 
-----------------------------------------------------------
-----------------------------------------------------------
| end of epoch   6 | time:  0.52s |

In [24]:
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.924
