In [157]:
import torch
import torch.nn as nn
import torch.optim as optim

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

print(f'GPU available: {torch.cuda.is_available()}')

Sat May  1 23:30:29 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P0    34W / 250W |  14553MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [158]:
!pip install emoji



In [159]:
import StockSentimentDataset
from torch.utils.data import DataLoader
from sklearn.utils.class_weight import compute_class_weight

train_ds, val_ds, test_ds, embeddings, featurizer = StockSentimentDataset.load_all_data()

class_weights = compute_class_weight(class_weight='balanced', classes=[0,1,2], y=train_ds.labels)
class_weights_tensor = torch.FloatTensor(class_weights)

In [160]:
import torch.nn.functional as F
class CRAN(nn.Module):
    def __init__(self, hidden_dim, embedding_dim, vocab_size, pad_token):
        super(CRAN, self).__init__()

        self.pad_token = pad_token
        # self.n_filters = n_filters
        self.hidden_dim = hidden_dim

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        #CNN
        # Paper only used 3,4,5 grams
        self.cnn_3gram = nn.Conv1d(embedding_dim, out_channels=1, kernel_size=3, stride=1)
        self.cnn_4gram = nn.Conv1d(embedding_dim, out_channels=1, kernel_size=4, stride=1)
        self.cnn_5gram = nn.Conv1d(embedding_dim, out_channels=1, kernel_size=5, stride=1)

        self.conv_fc = nn.Linear(3, hidden_dim)
        
        #BiGRU
        self.bigru = torch.nn.GRU(embedding_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.gru_fc = nn.Linear(hidden_dim*2, hidden_dim)

        #Attention and output
        self.u_activation = torch.nn.Tanh()

        self.attention_softmax = nn.LogSoftmax(dim=1)
        self.output_layer = nn.Linear(hidden_dim, 3) # Whoops its sentiment analysis

    def forward(self, X, train=False):
        X = X.permute(1,0) # I did all this backwards from the data
        embedded = self.embeddings(X)

        embedded_permute = embedded.permute(0,2,1)

        x = self.cnn_3gram(embedded_permute)
        c3_pool = F.max_pool1d(x, x.shape[2])
        x = self.cnn_4gram(embedded_permute)
        c4_pool = F.max_pool1d(x, x.shape[2])
        x = self.cnn_5gram(embedded_permute)
        c5_pool = F.max_pool1d(x, x.shape[2])
    
        pooled_rep = torch.cat((c3_pool, c4_pool, c5_pool), dim = 1).squeeze()
        context = self.conv_fc(pooled_rep)

        gru_output, gru_hidden = self.bigru(embedded)

        input = self.u_activation(self.gru_fc(gru_output))

        attention_scores = torch.bmm(input, context.unsqueeze(2))
        attention_weights = self.attention_softmax(attention_scores)

        attention_output = torch.bmm(attention_weights.permute(0,2,1), input).squeeze()

        logits = self.output_layer(attention_output)

        return logits



In [161]:
from sklearn.metrics import f1_score
import random

def score_model(model, test_ds):
  model.eval()
  with torch.no_grad():
    test_dl = DataLoader(dataset=test_ds, batch_size=len(test_ds.samples),
      shuffle=False, collate_fn=StockSentimentDataset.collate_fn)
    predictions = []
    for x, y in test_dl:
      x = x.cuda()
      y = y.cuda()
      pred_scores = model(x)
      predictions.extend(torch.argmax(pred_scores, dim=1).tolist())
  return f1_score(test_ds.labels, predictions, labels=[0,1,2], average='macro')


In [162]:
#Training

from random import sample
import tqdm
import os
import subprocess
import random
from pretrained_embeddings import add_embeddings_to_embedding_matrix


def train(model, best_val_model, train_ds, val_ds, class_weights):
  class_weights = class_weights.cuda()
  optimizer = optim.Adam(model.parameters(), lr=0.001)
  loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

  train_dl = DataLoader(dataset=train_ds, batch_size=len(train_ds.samples),
    shuffle=True, collate_fn=StockSentimentDataset.collate_fn)
  print(len(train_ds.samples))

  val_dl = DataLoader(dataset=val_ds, batch_size=len(val_ds.samples),
    shuffle=False, collate_fn=StockSentimentDataset.collate_fn)

  epoch_stats = []
  best_f1 = 0.0
  n_epochs_without_improvement = 0
  best_val_model.load_state_dict(model.state_dict())
  for epoch in range(250):
    totalLoss = 0.0
    n_samples = 0
    model.train()
    for x, y in tqdm.notebook.tqdm(train_dl, leave=False):
      x = x.cuda()
      y = y.cuda()
      model.zero_grad()
      predicted_scores = model(x)
      loss = loss_fn(predicted_scores, y)
      loss.backward()
      totalLoss += loss.item()
      n_samples += x.size()[1]
      optimizer.step()


    average_train_loss = totalLoss / float(n_samples)

    model.eval()
    total_val_loss = 0.0
    n_val_samples = 0
    with torch.no_grad():
      for x, y in val_dl:
        x = x.cuda()
        y = y.cuda()
        predicted_scores = model(x)
        loss = loss_fn(predicted_scores, y)
        total_val_loss += loss.item()
        n_val_samples += x.size()[1]
    
    average_val_loss = total_val_loss / float(n_val_samples)
    epoch_stats.append((epoch, average_train_loss, average_val_loss))
    val_f1 = score_model(model, val_ds)
    if epoch % 10 == 0:
      print("Epoch {0} finished. Train loss: {1}, Val loss: {2}".format(epoch, average_train_loss, average_val_loss))

    if val_f1 > best_f1:
      best_f1 = val_f1
      n_epochs_without_improvement = 0
      best_val_model.load_state_dict(model.state_dict())
    # else:
    #   n_epochs_without_improvement += 1
    #   if n_epochs_without_improvement == 10:
    #     break

  return best_f1, epoch_stats

model = CRAN(120, 100, featurizer.vocab_size, featurizer.PADDING_ID)
best_val_model = CRAN(120, 100, featurizer.vocab_size, featurizer.PADDING_ID)
add_embeddings_to_embedding_matrix(embeddings, model.embeddings, featurizer)
model = model.cuda()
best_val_model = best_val_model.cuda()
train(model, best_val_model, train_ds, val_ds, class_weights_tensor)

909


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 0 finished. Train loss: 0.04298386610511637, Val loss: 0.3608958185935507


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 10 finished. Train loss: 0.02311165602949455, Val loss: 0.09705030674837073


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 20 finished. Train loss: 0.01172960447137243, Val loss: 0.047513377909757654


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 30 finished. Train loss: 0.005361286326996957, Val loss: 0.03302955140872878


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 40 finished. Train loss: 0.003949680212963008, Val loss: 0.015786062697975004


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 50 finished. Train loss: 0.0015624656666754627, Val loss: 0.012839965674341942


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 60 finished. Train loss: 0.0013831295195979253, Val loss: 0.01634192710020104


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 70 finished. Train loss: 0.000770655271112591, Val loss: 0.014881236212594169


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 80 finished. Train loss: 0.0011690748800145517, Val loss: 0.009040787511942337


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 90 finished. Train loss: 0.0016063143580135602, Val loss: 0.008351170286840322


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 100 finished. Train loss: 0.000669670577096467, Val loss: 0.009608550339329


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 110 finished. Train loss: 0.0006318734292805653, Val loss: 0.007786989212036133


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 120 finished. Train loss: 0.0002593327276777513, Val loss: 0.008025339063333005


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 130 finished. Train loss: 0.00019245674245559475, Val loss: 0.008070552227448444


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 140 finished. Train loss: 0.00015091800978212598, Val loss: 0.008145386467174607


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 150 finished. Train loss: 0.00012340395823039107, Val loss: 0.008353522845676966


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 160 finished. Train loss: 0.0001028986571907866, Val loss: 0.00848904799441902


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 170 finished. Train loss: 8.714722000201805e-05, Val loss: 0.008699433535945659


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 180 finished. Train loss: 7.429490290065803e-05, Val loss: 0.00888353403733701


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 190 finished. Train loss: 6.393853439749664e-05, Val loss: 0.00909208886477412


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 200 finished. Train loss: 5.52633537514375e-05, Val loss: 0.009298499749631298


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 210 finished. Train loss: 4.7949408114415456e-05, Val loss: 0.00950616294023942


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 220 finished. Train loss: 4.1760360807737766e-05, Val loss: 0.009727791255834152


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 230 finished. Train loss: 3.650618924034966e-05, Val loss: 0.00994520041407371


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Epoch 240 finished. Train loss: 3.202434330031161e-05, Val loss: 0.010161700905585776


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

(0.5700969856216479,
 [(0, 0.04298386610511637, 0.3608958185935507),
  (1, 0.12409364639466876, 0.42153868383290816),
  (2, 0.1494850183870926, 0.37652548965142696),
  (3, 0.13317599154935025, 0.24100455459283324),
  (4, 0.08036726982024374, 0.19328956214749082),
  (5, 0.06171004292201681, 0.05105538270911392),
  (6, 0.014780951018380646, 0.17092842958411392),
  (7, 0.05013796782205076, 0.21634214751574457),
  (8, 0.06576210332520069, 0.15495928939507933),
  (9, 0.048535291380090155, 0.08779657130338708),
  (10, 0.02311165602949455, 0.09705030674837073),
  (11, 0.02583602133101494, 0.07382567074834084),
  (12, 0.02037902555056531, 0.09959878726881377),
  (13, 0.030015840519904042, 0.07349070237607372),
  (14, 0.02046679716036789, 0.06854069476224939),
  (15, 0.01826026547204281, 0.05264662236583476),
  (16, 0.01309759729635073, 0.048786742346627374),
  (17, 0.015698228338764053, 0.05628483636038644),
  (18, 0.020124349394778345, 0.055084043619584064),
  (19, 0.01731650139501255, 0.0480

In [163]:
print(score_model(model, train_ds))
print(score_model(model, test_ds))
print(score_model(model, val_ds))

print(score_model(best_val_model, train_ds))
print(score_model(best_val_model, test_ds))
print(score_model(best_val_model, val_ds))


0.9959855239190873
0.5520553111250787
0.549552749770279
0.9626565023816663
0.5334913159552448
0.5700969856216479
