In [None]:
from google.colab import drive

drive.mount('/content/drive', force_remount=True)
%cd /content/drive/My\ Drive/HW4

Mounted at /content/drive
/content/drive/My Drive/HW4


In [None]:
# Setup for GLoVE embeddings (only need to run once)
# !wget http://nlp.stanford.edu/data/glove.6B.zip
# !unzip glove*.zip

# GLoVE Stuff

In [None]:
import numpy as np
embeddings_dict={}
with open("glove.6B.50d.txt", 'r') as f:
  for line in f:
      values = line.split()
      word = values[0]
      vector = np.asarray(values[1:], "float32")
      embeddings_dict[word] = vector

In [None]:
def get_embedding_features(raw_data, is_for_generation=False):
  """
  raw_data: List of tuples (words, label) corresponding to tweets and their corresponding hashtags
  is_for_generation: If using features for generation,
  return the sequence of word vectors, otherwise average over that dimension
  """
  featurized_data = []
  if not is_for_generation:
    vec = np.zeros((1,50))
  else:
    vec = []
  for row in raw_data:
    count = 0
    for word in row[0]:
      if not is_for_generation:
        if word in embeddings_dict:
          vec += embeddings_dict[word]
          count = count+1
      else:
        if word in embeddings_dict:
          vec.append(embeddings_dict[word])
    
    if not is_for_generation and count > 0:
      featurized_data.append([vec / count, row[1]])
    elif count > 0:
      featurized_data.append(np.stack(vec, axis=0))
  return featurized_data

# Hashtag Classification

In [None]:
import os
import json
import re

hashtag_to_label = {
    'superbowl': 0,
    'sb49': 1,
    'patriots': 2,
    'nfl': 3,
    'gopatriots': 4,
    'gohawks': 5
}

datadir = "data"
raw_data = []
for fname in os.listdir(datadir):
  if not fname.endswith('.txt'):
    continue
  
  pound_idx = fname.index('#')
  hashtag = fname[pound_idx+1:-4]
  label_for_file = hashtag_to_label[hashtag]

  with open(os.path.join(datadir, fname)) as json_file:
    for line in json_file.readlines():
      data_pt = json.loads(line)
      
      tweet = data_pt['tweet']['text']
      words = tweet.split()
      words = [x for x in words if x[0] != '#' and not x.startswith('http')] # remove hashtags and urls from tweet
      words = [re.sub(r'[^\w\s]', '', x) for x in words] # Remove punctuation
      
      raw_data.append([words, label_for_file])

In [None]:
classification_feats = get_embedding_features(raw_data, is_for_generation=False)
classification_labels = np.array([x[1] for x in classification_feats])
classification_feats = np.concatenate([x[0] for x in classification_feats], axis=0)
print(classification_feats.shape, classification_labels.shape)
# import numpy as np
# classification_feats = np.load("features.npy")
# classification_labels = np.load("labels.npy")

(2495446, 50) (2495446,)


In [None]:
np.save('features.npy', classification_feats)
np.save('labels.npy', classification_labels)

In [None]:
import torch.nn as nn

class MLP(nn.Module):
  def __init__(self, num_hidden_layers, hidden_width, input_width, output_width):
    super().__init__()
    self.input_layer = nn.Linear(input_width, hidden_width)
    self.relu = nn.ReLU()

    hidden_layer_list = []
    for i in range(num_hidden_layers):
      hidden_layer_list.append(nn.Linear(hidden_width, hidden_width))
      hidden_layer_list.append(nn.ReLU())
    self.hidden_layers = nn.Sequential(*hidden_layer_list)

    self.output_layer = nn.Linear(hidden_width, output_width)
  
  def forward(self, x):
    x = self.input_layer(x)
    x = self.relu(x)
    x = self.hidden_layers(x)
    return self.output_layer(x)

In [None]:
from sklearn.model_selection import train_test_split
import torch

model = MLP(num_hidden_layers=6, hidden_width=128, input_width=50, output_width=6) # 6 output classes corresponding to each hashtag
x_train, x_test, y_train, y_test = train_test_split(classification_feats, classification_labels, test_size=0.1)

x_train = torch.Tensor(x_train)
y_train = torch.Tensor(y_train)
x_test = torch.Tensor(x_test)
y_test = torch.Tensor(y_test)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
batch_size = 128
num_epochs = 1

model.cuda() # moves the model to the GPU

print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)

torch.Size([2245901, 50]) torch.Size([249545, 50]) torch.Size([2245901]) torch.Size([249545])


In [None]:
from tqdm import tqdm
# Helper function to print validation loss/accuracy
def evaluate(model, x_test, y_test, epoch):
  model.eval()
  with torch.no_grad():
    num_correct = 0.0
    for i in range(0, x_test.shape[0], batch_size):
      if i+batch_size > x_test.shape[0]:
        i_end = x_test.shape[0]
      else:
        i_end = i + batch_size

      model_in = x_test[i:i_end].cuda()
      labels = y_test[i:i_end].cuda()

      preds = model(model_in)
      num_correct += torch.sum(torch.argmax(preds, dim=1) == labels)
  print(f"Epoch: {epoch}, Validation Accuracy: {num_correct / x_test.shape[0]}")
  model.train()

evaluate(model, x_test, y_test, -1)
# Main training loop
for epoch in range(num_epochs):
  avg_loss = 0.0
  count = 0
  for i in tqdm(range(0, x_train.shape[0], batch_size)):
    if i+batch_size > x_train.shape[0]:
      i_end = x_train.shape[0]
    else:
      i_end = i + batch_size
  
    model_in = x_train[i:i_end].cuda() # Gets one batch of training data and GT labels onto the GPU
    labels = y_train[i:i_end].type('torch.cuda.LongTensor')

    # Forward pass through the model
    model_out = model(model_in)

    # Backwards pass to calculate gradients and update parameters
    loss = loss_fn(model_out, labels)
    avg_loss += loss.item()
    count += 1
    # print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  print(f"Average Training Loss for Epoch {epoch}: {avg_loss / count}")
  evaluate(model, x_test, y_test, epoch)


Epoch: -1, Validation Accuracy: 0.15814782679080963


100%|██████████| 17547/17547 [01:07<00:00, 261.13it/s]


Average Training Loss for Epoch 0: 5.479392735832636
Epoch: 0, Validation Accuracy: 0.4214750826358795


# Tweet Generation

In [None]:
import os
import json
import re

datadir = "data"
fname = "tweets_#gopatriots.txt" # Change this to load data for a different hashtag
raw_data = []
full_tweets = []
with open(os.path.join(datadir, fname)) as json_file:
  for line in json_file.readlines():
    data_pt = json.loads(line)
      
    tweet = data_pt['tweet']['text']
    words = tweet.split()
    words = [x.lower() for x in words if x[0] != '#' and not x.startswith('http')] # remove hashtags and urls from tweet
    words = [re.sub(r'[^\w\s]', '', x) for x in words] # Remove punctuation
    words = [x for x in words if len(x) > 0]
    if len(words) < 2:
      continue

    full_tweets.append(words)
      
    for i in range(len(words)-1):
      raw_data.append((words[:i+1], words[i+1]))

In [None]:
print(len(raw_data))
print(len(full_tweets))

128957
18731


In [None]:
!pip install pytorch-nlp

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-nlp
  Downloading pytorch_nlp-0.5.0-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 KB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pytorch-nlp
Successfully installed pytorch-nlp-0.5.0


In [None]:
from torchnlp.encoders.text import StaticTokenizerEncoder

# Creates the vocabulary over all the tweets
encoder = StaticTokenizerEncoder(full_tweets, tokenize=lambda x:x, min_occurrences=10)
print(len(encoder.vocab))

# Tokenizes each tweet in the dataset
processed_data = []
for tweet in full_tweets:
  tokenized = encoder.encode(tweet)

  input = tokenized[:-1]
  label = tokenized[-1]

  processed_data.append((input, label))

# Caluclates max sequence length for use in padding during training
sequence_length = max([x[0].shape[0] for x in processed_data])
print(sequence_length)

1554
27


In [None]:
import torch.nn as nn
import torch
class Generator(nn.Module):
  def __init__(self, vocab_size, hidden_dim=128, num_lstm_layers=3):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
    self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True)
    self.output_layer = nn.Linear(hidden_dim, vocab_size)

  
  def forward(self, x, in_state=None):
    x = self.embedding(x)
    x, state = self.lstm(x)
    return self.output_layer(x), state

def evaluate(model, epoch, test_data, batch_size):
  model.eval()
  num_correct = 0
  for i in tqdm(range(0, len(test_data), batch_size)):
    if i+batch_size > len(test_data):
      i_end = len(test_data)
    else:
      i_end = i+batch_size

    batch = test[i:i_end]

    padded = [pad_tensor(x[0], length=sequence_length) for x in batch] # Necessary to pad each sequence so they are all the same length
    model_in = torch.stack(padded, dim=0).cuda()
    labels = torch.stack([x[1] for x in batch], dim=0).type('torch.cuda.LongTensor')

    with torch.no_grad():
      preds, _ = model(model_in)

    num_correct += torch.sum(torch.argmax(preds[:,-1,:], dim=-1) == labels)
  print(f"Validation Accuracy for epoch {epoch}: {num_correct / len(test_data)}")
  model.train()

In [None]:
from torchnlp.encoders.text import pad_tensor
from tqdm import tqdm
from sklearn.model_selection import train_test_split

model = Generator(len(encoder.vocab)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 50
batch_size = 64

train, test = train_test_split(processed_data, test_size=0.2)

state = (torch.zeros((3, batch_size, 128)).cuda(), torch.zeros((3, batch_size, 128)).cuda())
for epoch in range(num_epochs):
  average_loss = 0.0
  for i in tqdm(range(0, len(train), batch_size)):
    if i+batch_size > len(train):
      continue
    else:
      i_end = i+batch_size
    
    batch = train[i:i_end]

    padded = [pad_tensor(x[0], length=sequence_length) for x in batch] # Necessary to pad each sequence so they are all the same length
    model_in = torch.stack(padded, dim=0).cuda()
    labels = torch.stack([x[1] for x in batch], dim=0).type('torch.cuda.LongTensor')

    preds, state = model(model_in)

    # state = (state[0].detach(), state[1].detach())

    loss = loss_fn(preds[:,-1,:], labels)
    average_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  print(f"Average Loss for epoch {epoch}: {average_loss / len(train)}")
  evaluate(model, epoch, test, batch_size)

100%|██████████| 235/235 [00:02<00:00, 101.12it/s]


Average Loss for epoch 0: 0.07795519831660465


100%|██████████| 59/59 [00:00<00:00, 546.35it/s]


Validation Accuracy for epoch 0: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 163.23it/s]


Average Loss for epoch 1: 0.07406135510456352


100%|██████████| 59/59 [00:00<00:00, 521.86it/s]


Validation Accuracy for epoch 1: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 172.62it/s]


Average Loss for epoch 2: 0.07396487585376815


100%|██████████| 59/59 [00:00<00:00, 515.48it/s]


Validation Accuracy for epoch 2: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 159.34it/s]


Average Loss for epoch 3: 0.07390789225403153


100%|██████████| 59/59 [00:00<00:00, 410.27it/s]


Validation Accuracy for epoch 3: 0.29917266964912415


100%|██████████| 235/235 [00:02<00:00, 100.89it/s]


Average Loss for epoch 4: 0.07385780210808135


100%|██████████| 59/59 [00:00<00:00, 221.45it/s]


Validation Accuracy for epoch 4: 0.29917266964912415


100%|██████████| 235/235 [00:03<00:00, 66.38it/s]


Average Loss for epoch 5: 0.07381475192941706


100%|██████████| 59/59 [00:00<00:00, 172.31it/s]


Validation Accuracy for epoch 5: 0.29917266964912415


100%|██████████| 235/235 [00:03<00:00, 70.63it/s]


Average Loss for epoch 6: 0.07378184585347254


100%|██████████| 59/59 [00:00<00:00, 388.69it/s]


Validation Accuracy for epoch 6: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 176.09it/s]


Average Loss for epoch 7: 0.07375315293168487


100%|██████████| 59/59 [00:00<00:00, 531.78it/s]


Validation Accuracy for epoch 7: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 176.54it/s]


Average Loss for epoch 8: 0.07372814085925256


100%|██████████| 59/59 [00:00<00:00, 537.02it/s]


Validation Accuracy for epoch 8: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 175.87it/s]


Average Loss for epoch 9: 0.07370438584655094


100%|██████████| 59/59 [00:00<00:00, 502.10it/s]


Validation Accuracy for epoch 9: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 178.72it/s]


Average Loss for epoch 10: 0.07373811867042333


100%|██████████| 59/59 [00:00<00:00, 530.18it/s]


Validation Accuracy for epoch 10: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 176.45it/s]


Average Loss for epoch 11: 0.07336523291966651


100%|██████████| 59/59 [00:00<00:00, 535.33it/s]


Validation Accuracy for epoch 11: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 132.85it/s]


Average Loss for epoch 12: 0.07306857442894023


100%|██████████| 59/59 [00:00<00:00, 428.34it/s]


Validation Accuracy for epoch 12: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 139.06it/s]


Average Loss for epoch 13: 0.0729550260690464


100%|██████████| 59/59 [00:00<00:00, 521.26it/s]


Validation Accuracy for epoch 13: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 173.32it/s]


Average Loss for epoch 14: 0.07282935739263646


100%|██████████| 59/59 [00:00<00:00, 513.86it/s]


Validation Accuracy for epoch 14: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 175.37it/s]


Average Loss for epoch 15: 0.07274662735496558


100%|██████████| 59/59 [00:00<00:00, 542.77it/s]


Validation Accuracy for epoch 15: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 177.02it/s]


Average Loss for epoch 16: 0.07237911081161336


100%|██████████| 59/59 [00:00<00:00, 534.05it/s]


Validation Accuracy for epoch 16: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 174.04it/s]


Average Loss for epoch 17: 0.07176812324877799


100%|██████████| 59/59 [00:00<00:00, 514.20it/s]


Validation Accuracy for epoch 17: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 175.10it/s]


Average Loss for epoch 18: 0.07149604483205751


100%|██████████| 59/59 [00:00<00:00, 514.15it/s]


Validation Accuracy for epoch 18: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 171.03it/s]


Average Loss for epoch 19: 0.07076626451207683


100%|██████████| 59/59 [00:00<00:00, 519.13it/s]


Validation Accuracy for epoch 19: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 156.45it/s]


Average Loss for epoch 20: 0.06997346631422822


100%|██████████| 59/59 [00:00<00:00, 414.01it/s]


Validation Accuracy for epoch 20: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 131.76it/s]


Average Loss for epoch 21: 0.06933380758068616


100%|██████████| 59/59 [00:00<00:00, 384.24it/s]


Validation Accuracy for epoch 21: 0.29917266964912415


100%|██████████| 235/235 [00:01<00:00, 148.89it/s]


Average Loss for epoch 22: 0.06873630712010678


100%|██████████| 59/59 [00:00<00:00, 516.54it/s]


Validation Accuracy for epoch 22: 0.3077128529548645


100%|██████████| 235/235 [00:01<00:00, 169.42it/s]


Average Loss for epoch 23: 0.06792278286294381


100%|██████████| 59/59 [00:00<00:00, 515.25it/s]


Validation Accuracy for epoch 23: 0.3191886842250824


100%|██████████| 235/235 [00:01<00:00, 172.19it/s]


Average Loss for epoch 24: 0.06685457931949822


100%|██████████| 59/59 [00:00<00:00, 523.77it/s]


Validation Accuracy for epoch 24: 0.32052308320999146


100%|██████████| 235/235 [00:01<00:00, 171.96it/s]


Average Loss for epoch 25: 0.06579543956447526


100%|██████████| 59/59 [00:00<00:00, 518.93it/s]


Validation Accuracy for epoch 25: 0.32292500138282776


100%|██████████| 235/235 [00:01<00:00, 169.45it/s]


Average Loss for epoch 26: 0.0649813267233788


100%|██████████| 59/59 [00:00<00:00, 524.42it/s]


Validation Accuracy for epoch 26: 0.32506003975868225


100%|██████████| 235/235 [00:01<00:00, 171.70it/s]


Average Loss for epoch 27: 0.0642297486379321


100%|██████████| 59/59 [00:00<00:00, 507.94it/s]


Validation Accuracy for epoch 27: 0.3234587609767914


100%|██████████| 235/235 [00:01<00:00, 171.44it/s]


Average Loss for epoch 28: 0.0636410390032337


100%|██████████| 59/59 [00:00<00:00, 515.91it/s]


Validation Accuracy for epoch 28: 0.3269282281398773


100%|██████████| 235/235 [00:01<00:00, 128.59it/s]


Average Loss for epoch 29: 0.06309957064795609


100%|██████████| 59/59 [00:00<00:00, 356.64it/s]


Validation Accuracy for epoch 29: 0.3255937993526459


100%|██████████| 235/235 [00:01<00:00, 134.67it/s]


Average Loss for epoch 30: 0.06215911829974393


100%|██████████| 59/59 [00:00<00:00, 509.11it/s]


Validation Accuracy for epoch 30: 0.3269282281398773


100%|██████████| 235/235 [00:01<00:00, 171.18it/s]


Average Loss for epoch 31: 0.061344105215296665


100%|██████████| 59/59 [00:00<00:00, 525.63it/s]


Validation Accuracy for epoch 31: 0.3290632665157318


100%|██████████| 235/235 [00:01<00:00, 170.42it/s]


Average Loss for epoch 32: 0.06094600359704873


100%|██████████| 59/59 [00:00<00:00, 513.73it/s]


Validation Accuracy for epoch 32: 0.3263944685459137


100%|██████████| 235/235 [00:01<00:00, 167.36it/s]


Average Loss for epoch 33: 0.06036318307946343


100%|██████████| 59/59 [00:00<00:00, 498.81it/s]


Validation Accuracy for epoch 33: 0.33039766550064087


100%|██████████| 235/235 [00:01<00:00, 169.25it/s]


Average Loss for epoch 34: 0.05984117339717725


100%|██████████| 59/59 [00:00<00:00, 503.13it/s]


Validation Accuracy for epoch 34: 0.3293301463127136


100%|██████████| 235/235 [00:01<00:00, 167.18it/s]


Average Loss for epoch 35: 0.05889410542665289


100%|██████████| 59/59 [00:00<00:00, 516.95it/s]


Validation Accuracy for epoch 35: 0.3293301463127136


100%|██████████| 235/235 [00:01<00:00, 173.95it/s]


Average Loss for epoch 36: 0.057950867635326725


100%|██████████| 59/59 [00:00<00:00, 503.17it/s]


Validation Accuracy for epoch 36: 0.32959702610969543


100%|██████████| 235/235 [00:01<00:00, 143.79it/s]


Average Loss for epoch 37: 0.05761367017358494


100%|██████████| 59/59 [00:00<00:00, 372.13it/s]


Validation Accuracy for epoch 37: 0.33013078570365906


100%|██████████| 235/235 [00:01<00:00, 131.22it/s]


Average Loss for epoch 38: 0.05721606902113842


100%|██████████| 59/59 [00:00<00:00, 396.35it/s]


Validation Accuracy for epoch 38: 0.3311983048915863


100%|██████████| 235/235 [00:01<00:00, 156.93it/s]


Average Loss for epoch 39: 0.056868157123284614


100%|██████████| 59/59 [00:00<00:00, 492.90it/s]


Validation Accuracy for epoch 39: 0.33413398265838623


100%|██████████| 235/235 [00:01<00:00, 172.44it/s]


Average Loss for epoch 40: 0.05590186542963129


100%|██████████| 59/59 [00:00<00:00, 485.58it/s]


Validation Accuracy for epoch 40: 0.3311983048915863


100%|██████████| 235/235 [00:01<00:00, 171.16it/s]


Average Loss for epoch 41: 0.05535262395084318


100%|██████████| 59/59 [00:00<00:00, 484.83it/s]


Validation Accuracy for epoch 41: 0.3311983048915863


100%|██████████| 235/235 [00:01<00:00, 174.65it/s]


Average Loss for epoch 42: 0.054339035860179775


100%|██████████| 59/59 [00:00<00:00, 518.67it/s]


Validation Accuracy for epoch 42: 0.32078996300697327


100%|██████████| 235/235 [00:01<00:00, 173.52it/s]


Average Loss for epoch 43: 0.05344844133346016


100%|██████████| 59/59 [00:00<00:00, 511.60it/s]


Validation Accuracy for epoch 43: 0.3290632665157318


100%|██████████| 235/235 [00:01<00:00, 171.85it/s]


Average Loss for epoch 44: 0.05332253438167501


100%|██████████| 59/59 [00:00<00:00, 521.10it/s]


Validation Accuracy for epoch 44: 0.3263944685459137


100%|██████████| 235/235 [00:01<00:00, 171.27it/s]


Average Loss for epoch 45: 0.05347071122626729


100%|██████████| 59/59 [00:00<00:00, 388.68it/s]


Validation Accuracy for epoch 45: 0.3282626271247864


100%|██████████| 235/235 [00:01<00:00, 130.33it/s]


Average Loss for epoch 46: 0.0526625410287188


100%|██████████| 59/59 [00:00<00:00, 413.12it/s]


Validation Accuracy for epoch 46: 0.3255937993526459


100%|██████████| 235/235 [00:01<00:00, 124.11it/s]


Average Loss for epoch 47: 0.05098552986764679


100%|██████████| 59/59 [00:00<00:00, 501.44it/s]


Validation Accuracy for epoch 47: 0.3293301463127136


100%|██████████| 235/235 [00:01<00:00, 140.11it/s]


Average Loss for epoch 48: 0.05020086902446818


100%|██████████| 59/59 [00:00<00:00, 517.63it/s]


Validation Accuracy for epoch 48: 0.31812116503715515


100%|██████████| 235/235 [00:01<00:00, 163.65it/s]


Average Loss for epoch 49: 0.04925700512978716


100%|██████████| 59/59 [00:00<00:00, 467.39it/s]

Validation Accuracy for epoch 49: 0.3101147711277008





**GRU**

In [None]:
class Generator_GRU1(nn.Module):
  def __init__(self, vocab_size, hidden_dim=128, num_lstm_layers=3):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
    self.lstm = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_lstm_layers, batch_first=True)
    self.output_layer = nn.Linear(hidden_dim, vocab_size)

  
  def forward(self, x, in_state=None):
    x = self.embedding(x)
    x, state = self.lstm(x)
    return self.output_layer(x), state

class Generator_GRU2(nn.Module):
  def __init__(self, vocab_size, hidden_dim=128, num_gru_layers=3):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
    self.gru1 = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_gru_layers, batch_first=True)
    self.gru2 = nn.GRU(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_gru_layers, batch_first=True)
    self.output_layer1 = nn.Linear(hidden_dim, hidden_dim)
    self.output_layer2 = nn.Linear(hidden_dim, vocab_size)

  
  def forward(self, x, in_state=None):
    x = self.embedding(x)
    x, state1 = self.gru1(x)
    x, state2 = self.gru2(x)
    x = self.output_layer1(x)
    return self.output_layer2(x), state2

In [None]:
model = Generator_GRU1(len(encoder.vocab)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 50
batch_size = 64

train, test = train_test_split(processed_data, test_size=0.2)

state = (torch.zeros((3, batch_size, 128)).cuda(), torch.zeros((3, batch_size, 128)).cuda())
for epoch in range(num_epochs):
  average_loss = 0.0
  for i in tqdm(range(0, len(train), batch_size)):
    if i+batch_size > len(train):
      continue
    else:
      i_end = i+batch_size
    
    batch = train[i:i_end]

    padded = [pad_tensor(x[0], length=sequence_length) for x in batch] # Necessary to pad each sequence so they are all the same length
    model_in = torch.stack(padded, dim=0).cuda()
    labels = torch.stack([x[1] for x in batch], dim=0).type('torch.cuda.LongTensor')

    preds, state = model(model_in)

    # state = (state[0].detach(), state[1].detach())

    loss = loss_fn(preds[:,-1,:], labels)
    average_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  print(f"Average Loss for epoch {epoch}: {average_loss / len(train)}")
  evaluate(model, epoch, test, batch_size)

100%|██████████| 235/235 [00:01<00:00, 167.84it/s]


Average Loss for epoch 0: 0.07833673234171748


100%|██████████| 59/59 [00:00<00:00, 591.91it/s]


Validation Accuracy for epoch 0: 0.31838804483413696


100%|██████████| 235/235 [00:01<00:00, 163.63it/s]


Average Loss for epoch 1: 0.07323294728437466


100%|██████████| 59/59 [00:00<00:00, 495.81it/s]


Validation Accuracy for epoch 1: 0.34347477555274963


100%|██████████| 235/235 [00:01<00:00, 160.73it/s]


Average Loss for epoch 2: 0.06899597271839311


100%|██████████| 59/59 [00:00<00:00, 525.29it/s]


Validation Accuracy for epoch 2: 0.3480117619037628


100%|██████████| 235/235 [00:01<00:00, 184.13it/s]


Average Loss for epoch 3: 0.0657163165465944


100%|██████████| 59/59 [00:00<00:00, 494.47it/s]


Validation Accuracy for epoch 3: 0.3640245497226715


100%|██████████| 235/235 [00:03<00:00, 69.38it/s]


Average Loss for epoch 4: 0.06286191496375533


100%|██████████| 59/59 [00:00<00:00, 206.92it/s]


Validation Accuracy for epoch 4: 0.3688284158706665


100%|██████████| 235/235 [00:02<00:00, 96.98it/s] 


Average Loss for epoch 5: 0.060053566704443034


100%|██████████| 59/59 [00:00<00:00, 322.75it/s]


Validation Accuracy for epoch 5: 0.37443289160728455


100%|██████████| 235/235 [00:02<00:00, 92.41it/s]


Average Loss for epoch 6: 0.05726077037257303


100%|██████████| 59/59 [00:00<00:00, 349.12it/s]


Validation Accuracy for epoch 6: 0.36722710728645325


100%|██████████| 235/235 [00:02<00:00, 103.91it/s]


Average Loss for epoch 7: 0.054463094829303726


100%|██████████| 59/59 [00:00<00:00, 536.64it/s]


Validation Accuracy for epoch 7: 0.3658927083015442


100%|██████████| 235/235 [00:01<00:00, 188.00it/s]


Average Loss for epoch 8: 0.05165145965864871


100%|██████████| 59/59 [00:00<00:00, 615.20it/s]


Validation Accuracy for epoch 8: 0.35361623764038086


100%|██████████| 235/235 [00:01<00:00, 182.51it/s]


Average Loss for epoch 9: 0.049005454832770516


100%|██████████| 59/59 [00:00<00:00, 385.77it/s]


Validation Accuracy for epoch 9: 0.34133973717689514


100%|██████████| 235/235 [00:01<00:00, 135.61it/s]


Average Loss for epoch 10: 0.046311097719805544


100%|██████████| 59/59 [00:00<00:00, 399.30it/s]


Validation Accuracy for epoch 10: 0.33226582407951355


100%|██████████| 235/235 [00:01<00:00, 143.67it/s]


Average Loss for epoch 11: 0.04376141346684765


100%|██████████| 59/59 [00:00<00:00, 608.91it/s]


Validation Accuracy for epoch 11: 0.33466774225234985


100%|██████████| 235/235 [00:01<00:00, 180.60it/s]


Average Loss for epoch 12: 0.04126648543610397


100%|██████████| 59/59 [00:00<00:00, 596.64it/s]


Validation Accuracy for epoch 12: 0.3282626271247864


100%|██████████| 235/235 [00:01<00:00, 185.87it/s]


Average Loss for epoch 13: 0.03858400224939235


100%|██████████| 59/59 [00:00<00:00, 565.72it/s]


Validation Accuracy for epoch 13: 0.3263944685459137


100%|██████████| 235/235 [00:01<00:00, 183.72it/s]


Average Loss for epoch 14: 0.03610295218798927


100%|██████████| 59/59 [00:00<00:00, 575.26it/s]


Validation Accuracy for epoch 14: 0.3186549246311188


100%|██████████| 235/235 [00:01<00:00, 180.64it/s]


Average Loss for epoch 15: 0.03384550867300904


100%|██████████| 59/59 [00:00<00:00, 598.92it/s]


Validation Accuracy for epoch 15: 0.32292500138282776


100%|██████████| 235/235 [00:01<00:00, 181.62it/s]


Average Loss for epoch 16: 0.0314866024115731


100%|██████████| 59/59 [00:00<00:00, 579.41it/s]


Validation Accuracy for epoch 16: 0.30878037214279175


100%|██████████| 235/235 [00:01<00:00, 183.66it/s]


Average Loss for epoch 17: 0.02926587322562758


100%|██████████| 59/59 [00:00<00:00, 549.70it/s]


Validation Accuracy for epoch 17: 0.29436883330345154


100%|██████████| 235/235 [00:01<00:00, 178.14it/s]


Average Loss for epoch 18: 0.027108661253631785


100%|██████████| 59/59 [00:00<00:00, 365.90it/s]


Validation Accuracy for epoch 18: 0.29490259289741516


100%|██████████| 235/235 [00:01<00:00, 135.62it/s]


Average Loss for epoch 19: 0.025358290172361526


100%|██████████| 59/59 [00:00<00:00, 393.91it/s]


Validation Accuracy for epoch 19: 0.2975713908672333


100%|██████████| 235/235 [00:01<00:00, 140.07it/s]


Average Loss for epoch 20: 0.023464658507929087


100%|██████████| 59/59 [00:00<00:00, 526.69it/s]


Validation Accuracy for epoch 20: 0.30851349234580994


100%|██████████| 235/235 [00:01<00:00, 183.21it/s]


Average Loss for epoch 21: 0.02178502782137077


100%|██████████| 59/59 [00:00<00:00, 508.05it/s]


Validation Accuracy for epoch 21: 0.30557781457901


100%|██████████| 235/235 [00:01<00:00, 179.77it/s]


Average Loss for epoch 22: 0.020249564807267348


100%|██████████| 59/59 [00:00<00:00, 589.72it/s]


Validation Accuracy for epoch 22: 0.29490259289741516


100%|██████████| 235/235 [00:01<00:00, 184.38it/s]


Average Loss for epoch 23: 0.018662288279789243


100%|██████████| 59/59 [00:00<00:00, 575.57it/s]


Validation Accuracy for epoch 23: 0.28956499695777893


100%|██████████| 235/235 [00:01<00:00, 182.23it/s]


Average Loss for epoch 24: 0.017274732039659596


100%|██████████| 59/59 [00:00<00:00, 554.91it/s]


Validation Accuracy for epoch 24: 0.3029089868068695


100%|██████████| 235/235 [00:01<00:00, 185.05it/s]


Average Loss for epoch 25: 0.016168941633782793


100%|██████████| 59/59 [00:00<00:00, 514.02it/s]


Validation Accuracy for epoch 25: 0.29196691513061523


100%|██████████| 235/235 [00:01<00:00, 182.65it/s]


Average Loss for epoch 26: 0.01500099562446637


100%|██████████| 59/59 [00:00<00:00, 549.69it/s]


Validation Accuracy for epoch 26: 0.2933013141155243


100%|██████████| 235/235 [00:01<00:00, 166.32it/s]


Average Loss for epoch 27: 0.013940902163303922


100%|██████████| 59/59 [00:00<00:00, 378.21it/s]


Validation Accuracy for epoch 27: 0.2975713908672333


100%|██████████| 235/235 [00:01<00:00, 141.88it/s]


Average Loss for epoch 28: 0.0130179548368566


100%|██████████| 59/59 [00:00<00:00, 397.50it/s]


Validation Accuracy for epoch 28: 0.3010408282279968


100%|██████████| 235/235 [00:01<00:00, 145.37it/s]


Average Loss for epoch 29: 0.012254560536327555


100%|██████████| 59/59 [00:00<00:00, 542.06it/s]


Validation Accuracy for epoch 29: 0.29730451107025146


100%|██████████| 235/235 [00:01<00:00, 184.04it/s]


Average Loss for epoch 30: 0.011813385172429343


100%|██████████| 59/59 [00:00<00:00, 568.68it/s]


Validation Accuracy for epoch 30: 0.2892981171607971


100%|██████████| 235/235 [00:01<00:00, 181.68it/s]


Average Loss for epoch 31: 0.011068947189594423


100%|██████████| 59/59 [00:00<00:00, 574.79it/s]


Validation Accuracy for epoch 31: 0.2815585732460022


100%|██████████| 235/235 [00:01<00:00, 182.09it/s]


Average Loss for epoch 32: 0.01067033521254242


100%|██████████| 59/59 [00:00<00:00, 526.06it/s]


Validation Accuracy for epoch 32: 0.2882305979728699


100%|██████████| 235/235 [00:01<00:00, 183.60it/s]


Average Loss for epoch 33: 0.009958161982905082


100%|██████████| 59/59 [00:00<00:00, 544.35it/s]


Validation Accuracy for epoch 33: 0.3029089868068695


100%|██████████| 235/235 [00:01<00:00, 183.94it/s]


Average Loss for epoch 34: 0.009167344806128113


100%|██████████| 59/59 [00:00<00:00, 525.30it/s]


Validation Accuracy for epoch 34: 0.29970642924308777


100%|██████████| 235/235 [00:01<00:00, 183.54it/s]


Average Loss for epoch 35: 0.008907114837300759


100%|██████████| 59/59 [00:00<00:00, 593.52it/s]


Validation Accuracy for epoch 35: 0.29436883330345154


100%|██████████| 235/235 [00:01<00:00, 175.19it/s]


Average Loss for epoch 36: 0.008429157072792572


100%|██████████| 59/59 [00:00<00:00, 373.78it/s]


Validation Accuracy for epoch 36: 0.2933013141155243


100%|██████████| 235/235 [00:01<00:00, 142.13it/s]


Average Loss for epoch 37: 0.008580387057313866


100%|██████████| 59/59 [00:00<00:00, 371.68it/s]


Validation Accuracy for epoch 37: 0.28796371817588806


100%|██████████| 235/235 [00:01<00:00, 152.08it/s]


Average Loss for epoch 38: 0.008210787655028888


100%|██████████| 59/59 [00:00<00:00, 541.39it/s]


Validation Accuracy for epoch 38: 0.2906325161457062


100%|██████████| 235/235 [00:01<00:00, 187.04it/s]


Average Loss for epoch 39: 0.007962875028408853


100%|██████████| 59/59 [00:00<00:00, 533.12it/s]


Validation Accuracy for epoch 39: 0.28235921263694763


100%|██████████| 235/235 [00:01<00:00, 183.39it/s]


Average Loss for epoch 40: 0.007668303690743332


100%|██████████| 59/59 [00:00<00:00, 554.89it/s]


Validation Accuracy for epoch 40: 0.29223379492759705


100%|██████████| 235/235 [00:01<00:00, 184.20it/s]


Average Loss for epoch 41: 0.0072616638125452965


100%|██████████| 59/59 [00:00<00:00, 613.22it/s]


Validation Accuracy for epoch 41: 0.2884974777698517


100%|██████████| 235/235 [00:01<00:00, 185.67it/s]


Average Loss for epoch 42: 0.007039090488215404


100%|██████████| 59/59 [00:00<00:00, 564.91it/s]


Validation Accuracy for epoch 42: 0.2959701120853424


100%|██████████| 235/235 [00:01<00:00, 185.25it/s]


Average Loss for epoch 43: 0.007362954094767316


100%|██████████| 59/59 [00:00<00:00, 600.96it/s]


Validation Accuracy for epoch 43: 0.2860955595970154


100%|██████████| 235/235 [00:01<00:00, 182.93it/s]


Average Loss for epoch 44: 0.007501057033400706


100%|██████████| 59/59 [00:00<00:00, 551.38it/s]


Validation Accuracy for epoch 44: 0.28742995858192444


100%|██████████| 235/235 [00:01<00:00, 171.98it/s]


Average Loss for epoch 45: 0.007134349262903733


100%|██████████| 59/59 [00:00<00:00, 400.33it/s]


Validation Accuracy for epoch 45: 0.28529492020606995


100%|██████████| 235/235 [00:01<00:00, 137.17it/s]


Average Loss for epoch 46: 0.006808930693728302


100%|██████████| 59/59 [00:00<00:00, 359.37it/s]


Validation Accuracy for epoch 46: 0.27595409750938416


100%|██████████| 235/235 [00:01<00:00, 144.39it/s]


Average Loss for epoch 47: 0.006608881043380616


100%|██████████| 59/59 [00:00<00:00, 605.66it/s]


Validation Accuracy for epoch 47: 0.27542033791542053


100%|██████████| 235/235 [00:01<00:00, 187.76it/s]


Average Loss for epoch 48: 0.00609434298470863


100%|██████████| 59/59 [00:00<00:00, 568.59it/s]


Validation Accuracy for epoch 48: 0.27782225608825684


100%|██████████| 235/235 [00:01<00:00, 185.62it/s]


Average Loss for epoch 49: 0.006064862510098917


100%|██████████| 59/59 [00:00<00:00, 541.91it/s]

Validation Accuracy for epoch 49: 0.28049105405807495





In [None]:
model = Generator_GRU2(len(encoder.vocab)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 50
batch_size = 64

train, test = train_test_split(processed_data, test_size=0.2)

state = (torch.zeros((3, batch_size, 128)).cuda(), torch.zeros((3, batch_size, 128)).cuda())
for epoch in range(num_epochs):
  average_loss = 0.0
  for i in tqdm(range(0, len(train), batch_size)):
    if i+batch_size > len(train):
      continue
    else:
      i_end = i+batch_size
    
    batch = train[i:i_end]

    padded = [pad_tensor(x[0], length=sequence_length) for x in batch] # Necessary to pad each sequence so they are all the same length
    model_in = torch.stack(padded, dim=0).cuda()
    labels = torch.stack([x[1] for x in batch], dim=0).type('torch.cuda.LongTensor')

    preds, state = model(model_in)

    # state = (state[0].detach(), state[1].detach())

    loss = loss_fn(preds[:,-1,:], labels)
    average_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  
  print(f"Average Loss for epoch {epoch}: {average_loss / len(train)}")
  evaluate(model, epoch, test, batch_size)

100%|██████████| 235/235 [00:02<00:00, 101.60it/s]


Average Loss for epoch 0: 0.0780502396093227


100%|██████████| 59/59 [00:00<00:00, 225.32it/s]


Validation Accuracy for epoch 0: 0.3146517276763916


100%|██████████| 235/235 [00:01<00:00, 120.30it/s]


Average Loss for epoch 1: 0.07389732488832305


100%|██████████| 59/59 [00:00<00:00, 450.93it/s]


Validation Accuracy for epoch 1: 0.3146517276763916


100%|██████████| 235/235 [00:01<00:00, 120.86it/s]


Average Loss for epoch 2: 0.07232489632974635


100%|██████████| 59/59 [00:00<00:00, 447.64it/s]


Validation Accuracy for epoch 2: 0.32479315996170044


100%|██████████| 235/235 [00:02<00:00, 90.85it/s]


Average Loss for epoch 3: 0.07003100097847698


100%|██████████| 59/59 [00:00<00:00, 322.94it/s]


Validation Accuracy for epoch 3: 0.3352015018463135


100%|██████████| 235/235 [00:02<00:00, 110.15it/s]


Average Loss for epoch 4: 0.06799056855548975


100%|██████████| 59/59 [00:00<00:00, 434.25it/s]


Validation Accuracy for epoch 4: 0.33680278062820435


100%|██████████| 235/235 [00:01<00:00, 121.33it/s]


Average Loss for epoch 5: 0.06620786416613349


100%|██████████| 59/59 [00:00<00:00, 428.42it/s]


Validation Accuracy for epoch 5: 0.34107285737991333


100%|██████████| 235/235 [00:01<00:00, 121.89it/s]


Average Loss for epoch 6: 0.0648749097744157


100%|██████████| 59/59 [00:00<00:00, 442.65it/s]


Validation Accuracy for epoch 6: 0.34400853514671326


100%|██████████| 235/235 [00:01<00:00, 120.07it/s]


Average Loss for epoch 7: 0.0635351215017332


100%|██████████| 59/59 [00:00<00:00, 448.47it/s]


Validation Accuracy for epoch 7: 0.3448091745376587


100%|██████████| 235/235 [00:01<00:00, 122.25it/s]


Average Loss for epoch 8: 0.06217514252955431


100%|██████████| 59/59 [00:00<00:00, 445.44it/s]


Validation Accuracy for epoch 8: 0.34827864170074463


100%|██████████| 235/235 [00:02<00:00, 92.79it/s]


Average Loss for epoch 9: 0.06082825425532942


100%|██████████| 59/59 [00:00<00:00, 342.47it/s]


Validation Accuracy for epoch 9: 0.3517480790615082


100%|██████████| 235/235 [00:02<00:00, 109.90it/s]


Average Loss for epoch 10: 0.05925103657651699


100%|██████████| 59/59 [00:00<00:00, 446.86it/s]


Validation Accuracy for epoch 10: 0.34881240129470825


100%|██████████| 235/235 [00:01<00:00, 120.85it/s]


Average Loss for epoch 11: 0.05766916768282286


100%|██████████| 59/59 [00:00<00:00, 429.61it/s]


Validation Accuracy for epoch 11: 0.3493461608886719


100%|██████████| 235/235 [00:01<00:00, 121.42it/s]


Average Loss for epoch 12: 0.055954464826110335


100%|██████████| 59/59 [00:00<00:00, 439.56it/s]


Validation Accuracy for epoch 12: 0.34374165534973145


100%|██████████| 235/235 [00:01<00:00, 121.59it/s]


Average Loss for epoch 13: 0.05451050065000364


100%|██████████| 59/59 [00:00<00:00, 440.02it/s]


Validation Accuracy for epoch 13: 0.347744882106781


100%|██████████| 235/235 [00:01<00:00, 119.92it/s]


Average Loss for epoch 14: 0.05320966877023358


100%|██████████| 59/59 [00:00<00:00, 441.78it/s]


Validation Accuracy for epoch 14: 0.3408059775829315


100%|██████████| 235/235 [00:02<00:00, 92.47it/s]


Average Loss for epoch 15: 0.05145557083933352


100%|██████████| 59/59 [00:00<00:00, 351.17it/s]


Validation Accuracy for epoch 15: 0.3517480790615082


100%|██████████| 235/235 [00:02<00:00, 108.39it/s]


Average Loss for epoch 16: 0.04940094312307864


100%|██████████| 59/59 [00:00<00:00, 435.05it/s]


Validation Accuracy for epoch 16: 0.338670939207077


100%|██████████| 235/235 [00:01<00:00, 121.81it/s]


Average Loss for epoch 17: 0.047511895773440074


100%|██████████| 59/59 [00:00<00:00, 431.52it/s]


Validation Accuracy for epoch 17: 0.3408059775829315


100%|██████████| 235/235 [00:01<00:00, 119.56it/s]


Average Loss for epoch 18: 0.04587861303651772


100%|██████████| 59/59 [00:00<00:00, 443.32it/s]


Validation Accuracy for epoch 18: 0.33893781900405884


100%|██████████| 235/235 [00:01<00:00, 123.16it/s]


Average Loss for epoch 19: 0.04458764236016637


100%|██████████| 59/59 [00:00<00:00, 437.23it/s]


Validation Accuracy for epoch 19: 0.3306645452976227


100%|██████████| 235/235 [00:01<00:00, 122.52it/s]


Average Loss for epoch 20: 0.042692864865653345


100%|██████████| 59/59 [00:00<00:00, 430.45it/s]


Validation Accuracy for epoch 20: 0.3170536458492279


100%|██████████| 235/235 [00:02<00:00, 94.38it/s]


Average Loss for epoch 21: 0.040378004718259634


100%|██████████| 59/59 [00:00<00:00, 342.20it/s]


Validation Accuracy for epoch 21: 0.3122498095035553


100%|██████████| 235/235 [00:02<00:00, 106.97it/s]


Average Loss for epoch 22: 0.038575598344532995


100%|██████████| 59/59 [00:00<00:00, 435.97it/s]


Validation Accuracy for epoch 22: 0.30370962619781494


100%|██████████| 235/235 [00:01<00:00, 120.96it/s]


Average Loss for epoch 23: 0.036575698663191686


100%|██████████| 59/59 [00:00<00:00, 435.65it/s]


Validation Accuracy for epoch 23: 0.29890578985214233


100%|██████████| 235/235 [00:01<00:00, 121.82it/s]


Average Loss for epoch 24: 0.03450968734287541


100%|██████████| 59/59 [00:00<00:00, 435.20it/s]


Validation Accuracy for epoch 24: 0.2562049627304077


100%|██████████| 235/235 [00:01<00:00, 119.73it/s]


Average Loss for epoch 25: 0.033100575900816116


100%|██████████| 59/59 [00:00<00:00, 430.30it/s]


Validation Accuracy for epoch 25: 0.26394450664520264


100%|██████████| 235/235 [00:01<00:00, 121.81it/s]


Average Loss for epoch 26: 0.03178087779849338


100%|██████████| 59/59 [00:00<00:00, 430.66it/s]


Validation Accuracy for epoch 26: 0.27381905913352966


100%|██████████| 235/235 [00:02<00:00, 96.49it/s]


Average Loss for epoch 27: 0.02997547427855643


100%|██████████| 59/59 [00:00<00:00, 328.96it/s]


Validation Accuracy for epoch 27: 0.2767547369003296


100%|██████████| 235/235 [00:02<00:00, 105.50it/s]


Average Loss for epoch 28: 0.028126976544756464


100%|██████████| 59/59 [00:00<00:00, 430.96it/s]


Validation Accuracy for epoch 28: 0.27301841974258423


100%|██████████| 235/235 [00:01<00:00, 122.03it/s]


Average Loss for epoch 29: 0.026397992233764152


100%|██████████| 59/59 [00:00<00:00, 436.02it/s]


Validation Accuracy for epoch 29: 0.2535361647605896


100%|██████████| 235/235 [00:01<00:00, 121.89it/s]


Average Loss for epoch 30: 0.024559544434187185


100%|██████████| 59/59 [00:00<00:00, 434.86it/s]


Validation Accuracy for epoch 30: 0.24766480922698975


100%|██████████| 235/235 [00:01<00:00, 120.24it/s]


Average Loss for epoch 31: 0.023536816155821386


100%|██████████| 59/59 [00:00<00:00, 437.80it/s]


Validation Accuracy for epoch 31: 0.2532692849636078


100%|██████████| 235/235 [00:01<00:00, 119.97it/s]


Average Loss for epoch 32: 0.02226507911946897


100%|██████████| 59/59 [00:00<00:00, 406.81it/s]


Validation Accuracy for epoch 32: 0.24312783777713776


100%|██████████| 235/235 [00:02<00:00, 97.60it/s]


Average Loss for epoch 33: 0.021011617390694494


100%|██████████| 59/59 [00:00<00:00, 335.56it/s]


Validation Accuracy for epoch 33: 0.2636776268482208


100%|██████████| 235/235 [00:02<00:00, 93.00it/s]


Average Loss for epoch 34: 0.01977544869815864


100%|██████████| 59/59 [00:00<00:00, 437.79it/s]


Validation Accuracy for epoch 34: 0.2580731213092804


100%|██████████| 235/235 [00:01<00:00, 119.79it/s]


Average Loss for epoch 35: 0.01869389657836131


100%|██████████| 59/59 [00:00<00:00, 436.72it/s]


Validation Accuracy for epoch 35: 0.25220176577568054


100%|██████████| 235/235 [00:01<00:00, 121.81it/s]


Average Loss for epoch 36: 0.017220280953129535


100%|██████████| 59/59 [00:00<00:00, 428.49it/s]


Validation Accuracy for epoch 36: 0.2540699243545532


100%|██████████| 235/235 [00:01<00:00, 122.52it/s]


Average Loss for epoch 37: 0.015810528359372415


100%|██████████| 59/59 [00:00<00:00, 409.37it/s]


Validation Accuracy for epoch 37: 0.24979984760284424


100%|██████████| 235/235 [00:01<00:00, 120.46it/s]


Average Loss for epoch 38: 0.015054294210000912


100%|██████████| 59/59 [00:00<00:00, 432.69it/s]


Validation Accuracy for epoch 38: 0.24259407818317413


100%|██████████| 235/235 [00:02<00:00, 93.30it/s]


Average Loss for epoch 39: 0.014391505835181052


100%|██████████| 59/59 [00:00<00:00, 369.26it/s]


Validation Accuracy for epoch 39: 0.24846544861793518


100%|██████████| 235/235 [00:02<00:00, 109.74it/s]


Average Loss for epoch 40: 0.013971020976251546


100%|██████████| 59/59 [00:00<00:00, 407.67it/s]


Validation Accuracy for epoch 40: 0.2415265589952469


100%|██████████| 235/235 [00:01<00:00, 120.86it/s]


Average Loss for epoch 41: 0.01328715040993385


100%|██████████| 59/59 [00:00<00:00, 441.04it/s]


Validation Accuracy for epoch 41: 0.24979984760284424


100%|██████████| 235/235 [00:01<00:00, 118.30it/s]


Average Loss for epoch 42: 0.012693085462586866


100%|██████████| 59/59 [00:00<00:00, 440.69it/s]


Validation Accuracy for epoch 42: 0.24259407818317413


100%|██████████| 235/235 [00:01<00:00, 121.90it/s]


Average Loss for epoch 43: 0.012126163931069245


100%|██████████| 59/59 [00:00<00:00, 437.03it/s]


Validation Accuracy for epoch 43: 0.2348545491695404


100%|██████████| 235/235 [00:01<00:00, 120.55it/s]


Average Loss for epoch 44: 0.01150765525375021


100%|██████████| 59/59 [00:00<00:00, 424.25it/s]


Validation Accuracy for epoch 44: 0.24312783777713776


100%|██████████| 235/235 [00:02<00:00, 90.62it/s]


Average Loss for epoch 45: 0.01055878689199352


100%|██████████| 59/59 [00:00<00:00, 342.74it/s]


Validation Accuracy for epoch 45: 0.2417934387922287


100%|██████████| 235/235 [00:02<00:00, 108.50it/s]


Average Loss for epoch 46: 0.010427652878584355


100%|██████████| 59/59 [00:00<00:00, 424.06it/s]


Validation Accuracy for epoch 46: 0.24232719838619232


100%|██████████| 235/235 [00:01<00:00, 120.98it/s]


Average Loss for epoch 47: 0.010329683340620498


100%|██████████| 59/59 [00:00<00:00, 437.37it/s]


Validation Accuracy for epoch 47: 0.2393915206193924


100%|██████████| 235/235 [00:01<00:00, 119.13it/s]


Average Loss for epoch 48: 0.009957379916582397


100%|██████████| 59/59 [00:00<00:00, 432.25it/s]


Validation Accuracy for epoch 48: 0.22951695322990417


100%|██████████| 235/235 [00:01<00:00, 119.24it/s]


Average Loss for epoch 49: 0.009942473262922368


100%|██████████| 59/59 [00:00<00:00, 434.47it/s]

Validation Accuracy for epoch 49: 0.24686415493488312





In [None]:
class MarkovModel:
   # Model from https://github.com/thomashikaru
   
    """Represents a Markov Model for a given text"""

    def __init__(self, n, text):
        """Constructor takes n-gram length and training text
        and builds dictionary mapping n-grams to
        character-probability mappings."""
        self.n = n
        self.d = {}
        for i in range(len(text)-n-1):
            ngram = text[i:i+n]
            nextchar = text[i+n:i+n+1]
            if ngram in self.d:
                if nextchar in self.d[ngram]:
                    self.d[ngram][nextchar] += 1
                else:
                    self.d[ngram][nextchar] = 1
            else:
                self.d[ngram] = {nextchar: 1}

    def test_init(self):
        for x in (list(self.d.items())[:10]):
            print(x)

    def get_next_char(self, ngram):
        """Generates a single next character based to come after the provided n-gram,
        based on the probability distribution learned from the text."""
        if ngram in self.d:
            dist = self.d[ngram]
            distlist = list(dist.items())
            keys = [k for k, _ in distlist]
            vals = [v for _, v in distlist]
            valsum = sum(vals)
            vals = list(map(lambda x: x/valsum, vals))
            return np.random.choice(keys, 1, p=vals)[0]
        else:
            # this should never happen if start string n-gram exists in train text
            return np.random.choice([x for x in "abcdefghijklmnopqrstuvwxyz"])

    def get_n_chars(self, length, ngram):
        """Returns a generated sequence of specified length,
        using the given n-gram as a starting seed."""
        s = []
        for i in range(length):
            nextchar = self.get_next_char(ngram)
            ngram = ngram[1:]+nextchar
            s.append(nextchar)
        return ''.join(s)

In [None]:
#raw_data
#full_tweets

In [None]:
ngram_length = 4
tweet_length = 280
model = MarkovModel(ngram_length, text)
initial_ngram = "Hill"[:ngram_length]
print(initial_ngram + model.get_n_chars(tweet_length, initial_ngram))

Hillary of @lisall be dignific people sure and fair unsel/Justing! The specialMelB &amp; over express security - realDonal law in term(s) in Marylanded from the five-Disgrace would site House present. on @thehill poll Trump....youtubes making give Voice....Senator the race. The Afric


In [None]:
tweet_text = ''
for tweets in raw_data:
  for tweet in tweets:
    #print(type(tweet[0][0]))
    temp = " ".join(tweet)
    tweet_text = tweet_text + temp

In [None]:
ngram_length = 4
tweet_length = 280
model = MarkovModel(ngram_length, tweet_text)
initial_ngram = "Football"[:ngram_length]
print(initial_ngram + model.get_n_chars(tweet_length, initial_ngram))

Footro see the supongod som a k ich one were we are bet when montra closersn oway timelisto para el y nobodyt obrey o u s e r s a e c klord day should be up a ti brady te d o wso walking louie now days scome goes thanks for jonás que eut einteres about to keeping fun brady gronkwoski
