<a href="https://colab.research.google.com/github/wei-enwang/space-ham/blob/main/main_driver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import nltk
from nltk.corpus import words
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils import data
from preprocess import WholeData, BalancedData

import models
import utils

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/weinwang/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
nltk.download('words')

device = "cuda" if torch.cuda.is_available() else "cpu"
assert device == "cuda"   # use gpu whenever you can!

seed = 32
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

[nltk_data] Downloading package words to /home/weinwang/nltk_data...
[nltk_data]   Package words is already up-to-date!


In [13]:
plot_yes = True

# use one dataset for now
train_data_dir = "./data/enron1/"
test_data_dir = "./data/enron2/"
output_dir = "./output/"

# hyperparameters
batch_size = 64
hidden_size = 128
num_layer = 3
dropout = 0.5
learning_rate = 2e-5
epochs = 100
max_len = 300

In [8]:
vocab = set([str.lower() for str in words.words()])

# train_dataset = WholeData(train_data_dir, src_vocab=vocab, use_max_len=True, max_len=max_len)
# test_dataset = WholeData(test_data_dir, src_vocab=vocab, use_max_len=True, max_len=max_len)
# Balanced dataset
train_dataset = BalancedData(train_data_dir, src_vocab=vocab, use_max_len=True, max_len=max_len)
test_dataset = BalancedData(test_data_dir, src_vocab=vocab, use_max_len=True, max_len=max_len)

w2idx = train_dataset.src_v2id

embed = utils.load_pretrained_vectors(w2idx, "fastText/crawl-300d-2M.vec")
embed = torch.tensor(embed)

Number of ham emails: 1500, spam emails: 1500
Number of ham emails: 1496, spam emails: 1496
Loading pretrained vectors...
234378


0it [00:00, ?it/s]

There are 75835 / 234379 pretrained vectors found.


In [14]:
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                                   num_workers=8, pin_memory=True)
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, 
                                  num_workers=8, pin_memory=True, drop_last=True)


In [15]:
model = models.spam_lstm(hidden_size=hidden_size, pretrained_embedding=embed, dropout=dropout).to(device)
loss_fn = nn.BCEWithLogitsLoss().to(device)
opt = Adam(model.parameters(), lr=learning_rate)

Using pretrained vectors...


In [None]:
utils.train_test_scheme(train_dataloader, test_dataloader, model, loss_fn, opt, 
                        task_name="balancew2v+lstm300hid128", epochs=epochs, 
                        vis=plot_yes, print_every=1, img_dir=output_dir)



  1%|          | 1/100 [00:07<12:05,  7.32s/it]

Epoch 0
-------------------------------
Training loss: 0.693263, avg accuracy: 0.500047
Testing loss: 0.693185, avg accuracy: 0.502378


  2%|▏         | 2/100 [00:14<11:56,  7.31s/it]

Epoch 1
-------------------------------
Training loss: 0.693228, avg accuracy: 0.500000
Testing loss: 0.693273, avg accuracy: 0.498641


  3%|▎         | 3/100 [00:21<11:50,  7.32s/it]

Epoch 2
-------------------------------
Training loss: 0.693262, avg accuracy: 0.500000
Testing loss: 0.693178, avg accuracy: 0.501698


  4%|▍         | 4/100 [00:29<11:41,  7.31s/it]

Epoch 3
-------------------------------
Training loss: 0.693137, avg accuracy: 0.500427
Testing loss: 0.693257, avg accuracy: 0.497622


  5%|▌         | 5/100 [00:36<11:34,  7.31s/it]

Epoch 4
-------------------------------
Training loss: 0.693070, avg accuracy: 0.499905
Testing loss: 0.693212, avg accuracy: 0.500000


  6%|▌         | 6/100 [00:43<11:28,  7.32s/it]

Epoch 5
-------------------------------
Training loss: 0.693040, avg accuracy: 0.500095
Testing loss: 0.693214, avg accuracy: 0.499660


  7%|▋         | 7/100 [00:51<11:20,  7.32s/it]

Epoch 6
-------------------------------
Training loss: 0.692983, avg accuracy: 0.500190
Testing loss: 0.693217, avg accuracy: 0.500340


  8%|▊         | 8/100 [00:58<11:13,  7.32s/it]

Epoch 7
-------------------------------
Training loss: 0.693052, avg accuracy: 0.499953
Testing loss: 0.693211, avg accuracy: 0.502038


  9%|▉         | 9/100 [01:05<11:05,  7.32s/it]

Epoch 8
-------------------------------
Training loss: 0.692956, avg accuracy: 0.500095
Testing loss: 0.693230, avg accuracy: 0.501019


 10%|█         | 10/100 [01:13<10:58,  7.32s/it]

Epoch 9
-------------------------------
Training loss: 0.692903, avg accuracy: 0.500190
Testing loss: 0.693262, avg accuracy: 0.498981


 11%|█         | 11/100 [01:20<10:51,  7.33s/it]

Epoch 10
-------------------------------
Training loss: 0.692799, avg accuracy: 0.500380
Testing loss: 0.693272, avg accuracy: 0.500000


 12%|█▏        | 12/100 [01:27<10:45,  7.33s/it]

Epoch 11
-------------------------------
Training loss: 0.692729, avg accuracy: 0.500000
Testing loss: 0.693315, avg accuracy: 0.502378


 13%|█▎        | 13/100 [01:35<10:38,  7.33s/it]

Epoch 12
-------------------------------
Training loss: 0.692567, avg accuracy: 0.500047
Testing loss: 0.693378, avg accuracy: 0.500000


 14%|█▍        | 14/100 [01:42<10:30,  7.34s/it]

Epoch 13
-------------------------------
Training loss: 0.692347, avg accuracy: 0.499810
Testing loss: 0.693444, avg accuracy: 0.500679


 15%|█▌        | 15/100 [01:49<10:24,  7.35s/it]

Epoch 14
-------------------------------
Training loss: 0.691872, avg accuracy: 0.499953
Testing loss: 0.693643, avg accuracy: 0.499321


 16%|█▌        | 16/100 [01:57<10:15,  7.33s/it]

Epoch 15
-------------------------------
Training loss: 0.690905, avg accuracy: 0.500095
Testing loss: 0.694138, avg accuracy: 0.497622


 17%|█▋        | 17/100 [02:04<10:08,  7.33s/it]

Epoch 16
-------------------------------
Training loss: 0.687581, avg accuracy: 0.499763
Testing loss: 0.698898, avg accuracy: 0.500679


 18%|█▊        | 18/100 [02:11<10:00,  7.33s/it]

Epoch 17
-------------------------------
Training loss: 0.681747, avg accuracy: 0.499858
Testing loss: 0.706628, avg accuracy: 0.498641


 19%|█▉        | 19/100 [02:19<09:53,  7.33s/it]

Epoch 18
-------------------------------
Training loss: 0.675756, avg accuracy: 0.500190
Testing loss: 0.710211, avg accuracy: 0.498641


 20%|██        | 20/100 [02:26<09:45,  7.32s/it]

Epoch 19
-------------------------------
Training loss: 0.672405, avg accuracy: 0.500142
Testing loss: 0.712243, avg accuracy: 0.502378


 21%|██        | 21/100 [02:33<09:38,  7.32s/it]

Epoch 20
-------------------------------
Training loss: 0.667753, avg accuracy: 0.499858
Testing loss: 0.699756, avg accuracy: 0.501019


 22%|██▏       | 22/100 [02:41<09:30,  7.32s/it]

Epoch 21
-------------------------------
Training loss: 0.664529, avg accuracy: 0.499905
Testing loss: 0.690080, avg accuracy: 0.498641


 23%|██▎       | 23/100 [02:48<09:23,  7.32s/it]

Epoch 22
-------------------------------
Training loss: 0.661691, avg accuracy: 0.500047
Testing loss: 0.700707, avg accuracy: 0.501019


 24%|██▍       | 24/100 [02:55<09:16,  7.32s/it]

Epoch 23
-------------------------------
Training loss: 0.653132, avg accuracy: 0.500237
Testing loss: 0.703403, avg accuracy: 0.502038


 25%|██▌       | 25/100 [03:03<09:09,  7.33s/it]

Epoch 24
-------------------------------
Training loss: 0.605354, avg accuracy: 0.499953
Testing loss: 0.828477, avg accuracy: 0.500679


 26%|██▌       | 26/100 [03:10<09:02,  7.33s/it]

Epoch 25
-------------------------------
Training loss: 0.557551, avg accuracy: 0.499953
Testing loss: 0.705589, avg accuracy: 0.501019


 27%|██▋       | 27/100 [03:17<08:54,  7.32s/it]

Epoch 26
-------------------------------
Training loss: 0.541237, avg accuracy: 0.500570
Testing loss: 0.656677, avg accuracy: 0.502378


 28%|██▊       | 28/100 [03:25<08:47,  7.33s/it]

Epoch 27
-------------------------------
Training loss: 0.535207, avg accuracy: 0.610087
Testing loss: 0.651495, avg accuracy: 0.664742


 29%|██▉       | 29/100 [03:32<08:39,  7.32s/it]

Epoch 28
-------------------------------
Training loss: 0.535945, avg accuracy: 0.724022
Testing loss: 0.642060, avg accuracy: 0.655571


 30%|███       | 30/100 [03:39<08:32,  7.32s/it]

Epoch 29
-------------------------------
Training loss: 0.541660, avg accuracy: 0.715046
Testing loss: 0.645273, avg accuracy: 0.647079


 31%|███       | 31/100 [03:47<08:25,  7.33s/it]

Epoch 30
-------------------------------
Training loss: 0.541465, avg accuracy: 0.714523
Testing loss: 0.649659, avg accuracy: 0.645720


 32%|███▏      | 32/100 [03:54<08:18,  7.32s/it]

Epoch 31
-------------------------------
Training loss: 0.544621, avg accuracy: 0.709489
Testing loss: 0.631591, avg accuracy: 0.636549


 33%|███▎      | 33/100 [04:01<08:11,  7.33s/it]

Epoch 32
-------------------------------
Training loss: 0.560256, avg accuracy: 0.676814
Testing loss: 0.635638, avg accuracy: 0.624321


 34%|███▍      | 34/100 [04:09<08:03,  7.32s/it]

Epoch 33
-------------------------------
Training loss: 0.553530, avg accuracy: 0.683511
Testing loss: 0.634641, avg accuracy: 0.627717


 35%|███▌      | 35/100 [04:16<07:56,  7.33s/it]

Epoch 34
-------------------------------
Training loss: 0.569643, avg accuracy: 0.630462
Testing loss: 0.648921, avg accuracy: 0.620245


 36%|███▌      | 36/100 [04:23<07:48,  7.32s/it]

Epoch 35
-------------------------------
Training loss: 0.557404, avg accuracy: 0.585391
Testing loss: 0.644692, avg accuracy: 0.587976


 37%|███▋      | 37/100 [04:31<07:41,  7.33s/it]

Epoch 36
-------------------------------
Training loss: 0.573599, avg accuracy: 0.558273
Testing loss: 0.642996, avg accuracy: 0.514946


 38%|███▊      | 38/100 [04:38<07:34,  7.33s/it]

Epoch 37
-------------------------------
Training loss: 0.566806, avg accuracy: 0.524079
Testing loss: 0.643471, avg accuracy: 0.510870


 39%|███▉      | 39/100 [04:45<07:27,  7.33s/it]

Epoch 38
-------------------------------
Training loss: 0.560844, avg accuracy: 0.525076
Testing loss: 0.643768, avg accuracy: 0.513587


 40%|████      | 40/100 [04:53<07:20,  7.34s/it]

Epoch 39
-------------------------------
Training loss: 0.560213, avg accuracy: 0.532532
Testing loss: 0.644879, avg accuracy: 0.517323


 41%|████      | 41/100 [05:00<07:13,  7.34s/it]

Epoch 40
-------------------------------
Training loss: 0.559806, avg accuracy: 0.524934
Testing loss: 0.644536, avg accuracy: 0.514606


 42%|████▏     | 42/100 [05:07<07:05,  7.34s/it]

Epoch 41
-------------------------------
Training loss: 0.558010, avg accuracy: 0.539039
Testing loss: 0.647199, avg accuracy: 0.515625


 43%|████▎     | 43/100 [05:15<06:58,  7.35s/it]

Epoch 42
-------------------------------
Training loss: 0.558016, avg accuracy: 0.560933
Testing loss: 0.645846, avg accuracy: 0.518003


 44%|████▍     | 44/100 [05:22<06:50,  7.34s/it]

Epoch 43
-------------------------------
Training loss: 0.558245, avg accuracy: 0.582637
Testing loss: 0.645210, avg accuracy: 0.518682


 45%|████▌     | 45/100 [05:29<06:43,  7.33s/it]

Epoch 44
-------------------------------
Training loss: 0.557822, avg accuracy: 0.558131
Testing loss: 0.646880, avg accuracy: 0.518682


 46%|████▌     | 46/100 [05:37<06:35,  7.33s/it]

Epoch 45
-------------------------------
Training loss: 0.557246, avg accuracy: 0.543551
Testing loss: 0.646164, avg accuracy: 0.515965


 47%|████▋     | 47/100 [05:44<06:28,  7.33s/it]

Epoch 46
-------------------------------
Training loss: 0.557445, avg accuracy: 0.553904
Testing loss: 0.646496, avg accuracy: 0.520380


 48%|████▊     | 48/100 [05:51<06:21,  7.33s/it]

Epoch 47
-------------------------------
Training loss: 0.557229, avg accuracy: 0.563022
Testing loss: 0.644099, avg accuracy: 0.521399


 49%|████▉     | 49/100 [05:59<06:13,  7.33s/it]

Epoch 48
-------------------------------
Training loss: 0.556425, avg accuracy: 0.588288
Testing loss: 0.646014, avg accuracy: 0.520041


 50%|█████     | 50/100 [06:06<06:06,  7.33s/it]

Epoch 49
-------------------------------
Training loss: 0.557088, avg accuracy: 0.551909
Testing loss: 0.647215, avg accuracy: 0.521739


 51%|█████     | 51/100 [06:13<05:59,  7.33s/it]

Epoch 50
-------------------------------
Training loss: 0.556727, avg accuracy: 0.565919
Testing loss: 0.643585, avg accuracy: 0.520720


 52%|█████▏    | 52/100 [06:21<05:52,  7.34s/it]

Epoch 51
-------------------------------
Training loss: 0.557269, avg accuracy: 0.554664
Testing loss: 0.644417, avg accuracy: 0.521739


In [12]:
torch.save(model.state_dict(), output_dir+"balancew2v_lstm300hid128.pt")