Defining ML Model 

In [6]:
%pip install -r ..\requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [63]:

#initialise neural network
import torch
import torch.nn as nn

class model(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            #first layer
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.5),
            #hidden layer
            nn.Linear(64, 8),
            nn.ReLU(),
            nn.BatchNorm1d(8),
            nn.Dropout(0.5),
            #output scalar score
            nn.Linear(8, 1)
        )

    def forward(self, x):
        return self.mlp(x)
    
embed_to_score = model(input_dim=512)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(embed_to_score.parameters(), lr=1e-4, weight_decay=1e-4)

Data preprocessing

In [50]:
#entering data
import csv
import ast
import numpy as np
embeddings = []
scores = []

with open(r'C:\Users\rwwj8\OneDrive\Documents\scores_embeds.csv', 'r', encoding = 'utf-8') as file:
    reader = csv.DictReader(file)
    data = [row for row in reader]
    for row in data:
        rowlist = ast.literal_eval(row['embeds'])
        #print(len(rowlist))
        embeddings.append(rowlist)
        score = float(row['score'])
        scores.append(score)
embeddingarr = np.array(embeddings)
scorearr = np.array(scores)
print(scorearr.shape)

(469,)


In [51]:
#gpu code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed_to_score.to(device)

model(
  (mlp): Sequential(
    (0): Linear(in_features=512, out_features=64, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=64, out_features=8, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.5, inplace=False)
    (8): Linear(in_features=8, out_features=1, bias=True)
  )
)

In [32]:
#data preprocessing
from sklearn.model_selection import train_test_split
emb_train, score_train, emb_val, score_val = train_test_split(embeddingarr, scorearr, test_size=0.2, random_state=42)
#convert to tensors
emb_train = torch.FloatTensor(emb_train).to(device)
emb_val = torch.FloatTensor(emb_val).to(device)  # Add dim for MSE loss
score_train = torch.FloatTensor(score_train).to(device)
score_val = torch.FloatTensor(score_val).to(device)
print(emb_train.shape, emb_val.shape, score_train.shape, score_val.shape)
#data loader
from torch.utils.data import DataLoader, TensorDataset
train_dataset = TensorDataset(emb_train, emb_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = TensorDataset(score_train, score_val)
val_loader = DataLoader(val_dataset, batch_size=32)

torch.Size([375, 512]) torch.Size([375]) torch.Size([94, 512]) torch.Size([94])


Model training

In [52]:
#training loop
import torch.optim as optim

def train(model, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0
    for X_batch, y_batch in tqdm(dataloader):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def validate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            total_loss += loss.item()
    return total_loss / len(dataloader)

In [65]:
from tqdm import tqdm
tqdm._instances.clear()

best_val_loss = float('inf')
patience = 50  # Stop after 5 epochs without improvement
epochs_no_improve = 0

for epoch in range(1000):  # Max epochs
    train_loss = train(embed_to_score, train_loader, criterion, optimizer)
    val_loss = validate(embed_to_score, val_loader, criterion)
    if epoch % 10 == 0:  # Print every 10 epochs
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(embed_to_score.state_dict(), 'best_model.pth')  # Save best model
    else:
        epochs_no_improve += 1
        if epochs_no_improve == patience:
            print("Early stopping!")
            break

100%|██████████| 12/12 [00:00<00:00, 331.46it/s]


Epoch 0: Train Loss = 1.3267, Val Loss = 0.7117


100%|██████████| 12/12 [00:00<00:00, 405.74it/s]
100%|██████████| 12/12 [00:00<00:00, 257.69it/s]
100%|██████████| 12/12 [00:00<00:00, 368.12it/s]
100%|██████████| 12/12 [00:00<00:00, 404.99it/s]
100%|██████████| 12/12 [00:00<00:00, 404.70it/s]
100%|██████████| 12/12 [00:00<00:00, 408.29it/s]
100%|██████████| 12/12 [00:00<00:00, 418.78it/s]
100%|██████████| 12/12 [00:00<00:00, 412.34it/s]
100%|██████████| 12/12 [00:00<00:00, 410.11it/s]
100%|██████████| 12/12 [00:00<00:00, 390.34it/s]


Epoch 10: Train Loss = 1.3379, Val Loss = 0.7100


100%|██████████| 12/12 [00:00<00:00, 369.98it/s]
100%|██████████| 12/12 [00:00<00:00, 399.02it/s]
100%|██████████| 12/12 [00:00<00:00, 400.07it/s]
100%|██████████| 12/12 [00:00<00:00, 412.17it/s]
100%|██████████| 12/12 [00:00<00:00, 425.55it/s]
100%|██████████| 12/12 [00:00<00:00, 414.60it/s]
100%|██████████| 12/12 [00:00<00:00, 408.80it/s]
100%|██████████| 12/12 [00:00<00:00, 421.74it/s]
100%|██████████| 12/12 [00:00<00:00, 425.80it/s]
100%|██████████| 12/12 [00:00<00:00, 402.57it/s]


Epoch 20: Train Loss = 1.2994, Val Loss = 0.7106


100%|██████████| 12/12 [00:00<00:00, 401.01it/s]
100%|██████████| 12/12 [00:00<00:00, 384.52it/s]
100%|██████████| 12/12 [00:00<00:00, 375.63it/s]
100%|██████████| 12/12 [00:00<00:00, 372.91it/s]
100%|██████████| 12/12 [00:00<00:00, 372.78it/s]
100%|██████████| 12/12 [00:00<00:00, 392.40it/s]
100%|██████████| 12/12 [00:00<00:00, 392.16it/s]
100%|██████████| 12/12 [00:00<00:00, 369.35it/s]
100%|██████████| 12/12 [00:00<00:00, 369.81it/s]
100%|██████████| 12/12 [00:00<00:00, 341.13it/s]


Epoch 30: Train Loss = 1.2699, Val Loss = 0.7107


100%|██████████| 12/12 [00:00<00:00, 338.20it/s]
100%|██████████| 12/12 [00:00<00:00, 361.92it/s]
100%|██████████| 12/12 [00:00<00:00, 389.88it/s]
100%|██████████| 12/12 [00:00<00:00, 391.71it/s]
100%|██████████| 12/12 [00:00<00:00, 390.16it/s]
100%|██████████| 12/12 [00:00<00:00, 398.94it/s]
100%|██████████| 12/12 [00:00<00:00, 390.66it/s]
100%|██████████| 12/12 [00:00<00:00, 394.35it/s]
100%|██████████| 12/12 [00:00<00:00, 390.69it/s]
100%|██████████| 12/12 [00:00<00:00, 395.50it/s]


Epoch 40: Train Loss = 1.2945, Val Loss = 0.7098


100%|██████████| 12/12 [00:00<00:00, 371.63it/s]
100%|██████████| 12/12 [00:00<00:00, 388.19it/s]
100%|██████████| 12/12 [00:00<00:00, 397.95it/s]
100%|██████████| 12/12 [00:00<00:00, 401.19it/s]
100%|██████████| 12/12 [00:00<00:00, 386.72it/s]
100%|██████████| 12/12 [00:00<00:00, 406.68it/s]
100%|██████████| 12/12 [00:00<00:00, 405.22it/s]
100%|██████████| 12/12 [00:00<00:00, 398.49it/s]
100%|██████████| 12/12 [00:00<00:00, 391.79it/s]
100%|██████████| 12/12 [00:00<00:00, 389.02it/s]


Epoch 50: Train Loss = 1.2290, Val Loss = 0.7106


100%|██████████| 12/12 [00:00<00:00, 385.02it/s]
100%|██████████| 12/12 [00:00<00:00, 390.57it/s]
100%|██████████| 12/12 [00:00<00:00, 389.80it/s]
100%|██████████| 12/12 [00:00<00:00, 382.93it/s]
100%|██████████| 12/12 [00:00<00:00, 393.05it/s]
100%|██████████| 12/12 [00:00<00:00, 405.04it/s]
100%|██████████| 12/12 [00:00<00:00, 368.28it/s]
100%|██████████| 12/12 [00:00<00:00, 372.64it/s]
100%|██████████| 12/12 [00:00<00:00, 383.89it/s]
100%|██████████| 12/12 [00:00<00:00, 380.46it/s]


Epoch 60: Train Loss = 1.1923, Val Loss = 0.7082


100%|██████████| 12/12 [00:00<00:00, 360.18it/s]
100%|██████████| 12/12 [00:00<00:00, 396.27it/s]
100%|██████████| 12/12 [00:00<00:00, 398.36it/s]
100%|██████████| 12/12 [00:00<00:00, 385.33it/s]
100%|██████████| 12/12 [00:00<00:00, 257.88it/s]
100%|██████████| 12/12 [00:00<00:00, 358.64it/s]
100%|██████████| 12/12 [00:00<00:00, 376.06it/s]
100%|██████████| 12/12 [00:00<00:00, 394.90it/s]
100%|██████████| 12/12 [00:00<00:00, 381.16it/s]
100%|██████████| 12/12 [00:00<00:00, 379.70it/s]


Epoch 70: Train Loss = 1.2461, Val Loss = 0.7069


100%|██████████| 12/12 [00:00<00:00, 375.86it/s]
100%|██████████| 12/12 [00:00<00:00, 374.56it/s]
100%|██████████| 12/12 [00:00<00:00, 385.93it/s]
100%|██████████| 12/12 [00:00<00:00, 354.87it/s]
100%|██████████| 12/12 [00:00<00:00, 386.78it/s]
100%|██████████| 12/12 [00:00<00:00, 396.99it/s]
100%|██████████| 12/12 [00:00<00:00, 385.66it/s]
100%|██████████| 12/12 [00:00<00:00, 385.04it/s]
100%|██████████| 12/12 [00:00<00:00, 383.57it/s]
100%|██████████| 12/12 [00:00<00:00, 383.45it/s]


Epoch 80: Train Loss = 1.1642, Val Loss = 0.7066


100%|██████████| 12/12 [00:00<00:00, 374.13it/s]
100%|██████████| 12/12 [00:00<00:00, 400.27it/s]
100%|██████████| 12/12 [00:00<00:00, 371.19it/s]
100%|██████████| 12/12 [00:00<00:00, 372.06it/s]
100%|██████████| 12/12 [00:00<00:00, 381.16it/s]
100%|██████████| 12/12 [00:00<00:00, 393.69it/s]
100%|██████████| 12/12 [00:00<00:00, 382.51it/s]
100%|██████████| 12/12 [00:00<00:00, 395.22it/s]
100%|██████████| 12/12 [00:00<00:00, 339.77it/s]
100%|██████████| 12/12 [00:00<00:00, 363.01it/s]


Epoch 90: Train Loss = 1.1926, Val Loss = 0.7067


100%|██████████| 12/12 [00:00<00:00, 353.22it/s]
100%|██████████| 12/12 [00:00<00:00, 362.23it/s]
100%|██████████| 12/12 [00:00<00:00, 399.65it/s]
100%|██████████| 12/12 [00:00<00:00, 403.63it/s]
100%|██████████| 12/12 [00:00<00:00, 400.84it/s]
100%|██████████| 12/12 [00:00<00:00, 394.04it/s]
100%|██████████| 12/12 [00:00<00:00, 364.19it/s]
100%|██████████| 12/12 [00:00<00:00, 378.73it/s]
100%|██████████| 12/12 [00:00<00:00, 403.82it/s]
100%|██████████| 12/12 [00:00<00:00, 399.04it/s]


Epoch 100: Train Loss = 1.2113, Val Loss = 0.7062


100%|██████████| 12/12 [00:00<00:00, 407.53it/s]
100%|██████████| 12/12 [00:00<00:00, 416.02it/s]
100%|██████████| 12/12 [00:00<00:00, 416.04it/s]
100%|██████████| 12/12 [00:00<00:00, 416.03it/s]
100%|██████████| 12/12 [00:00<00:00, 421.88it/s]
100%|██████████| 12/12 [00:00<00:00, 418.53it/s]
100%|██████████| 12/12 [00:00<00:00, 420.76it/s]
100%|██████████| 12/12 [00:00<00:00, 421.88it/s]
100%|██████████| 12/12 [00:00<00:00, 428.30it/s]
100%|██████████| 12/12 [00:00<00:00, 430.73it/s]


Epoch 110: Train Loss = 1.2216, Val Loss = 0.7062


100%|██████████| 12/12 [00:00<00:00, 408.11it/s]
100%|██████████| 12/12 [00:00<00:00, 424.71it/s]
100%|██████████| 12/12 [00:00<00:00, 424.73it/s]
100%|██████████| 12/12 [00:00<00:00, 411.56it/s]
100%|██████████| 12/12 [00:00<00:00, 416.08it/s]
100%|██████████| 12/12 [00:00<00:00, 394.47it/s]
100%|██████████| 12/12 [00:00<00:00, 403.50it/s]
100%|██████████| 12/12 [00:00<00:00, 381.60it/s]
100%|██████████| 12/12 [00:00<00:00, 358.14it/s]
100%|██████████| 12/12 [00:00<00:00, 373.52it/s]


Epoch 120: Train Loss = 1.1300, Val Loss = 0.7061


100%|██████████| 12/12 [00:00<00:00, 364.79it/s]
100%|██████████| 12/12 [00:00<00:00, 394.37it/s]
100%|██████████| 12/12 [00:00<00:00, 383.78it/s]
100%|██████████| 12/12 [00:00<00:00, 380.56it/s]
100%|██████████| 12/12 [00:00<00:00, 373.94it/s]
100%|██████████| 12/12 [00:00<00:00, 392.07it/s]
100%|██████████| 12/12 [00:00<00:00, 396.91it/s]
100%|██████████| 12/12 [00:00<00:00, 302.34it/s]
100%|██████████| 12/12 [00:00<00:00, 333.34it/s]
100%|██████████| 12/12 [00:00<00:00, 301.25it/s]


Epoch 130: Train Loss = 1.2230, Val Loss = 0.7067


100%|██████████| 12/12 [00:00<00:00, 348.41it/s]
100%|██████████| 12/12 [00:00<00:00, 244.89it/s]
100%|██████████| 12/12 [00:00<00:00, 353.59it/s]
100%|██████████| 12/12 [00:00<00:00, 386.51it/s]
100%|██████████| 12/12 [00:00<00:00, 391.42it/s]
100%|██████████| 12/12 [00:00<00:00, 388.85it/s]
100%|██████████| 12/12 [00:00<00:00, 394.33it/s]
100%|██████████| 12/12 [00:00<00:00, 391.68it/s]
100%|██████████| 12/12 [00:00<00:00, 383.86it/s]
100%|██████████| 12/12 [00:00<00:00, 381.53it/s]


Epoch 140: Train Loss = 1.1756, Val Loss = 0.7058


100%|██████████| 12/12 [00:00<00:00, 370.73it/s]
100%|██████████| 12/12 [00:00<00:00, 378.21it/s]
100%|██████████| 12/12 [00:00<00:00, 396.17it/s]
100%|██████████| 12/12 [00:00<00:00, 380.15it/s]
100%|██████████| 12/12 [00:00<00:00, 391.38it/s]
100%|██████████| 12/12 [00:00<00:00, 344.36it/s]
100%|██████████| 12/12 [00:00<00:00, 396.16it/s]
100%|██████████| 12/12 [00:00<00:00, 381.16it/s]
100%|██████████| 12/12 [00:00<00:00, 346.41it/s]
100%|██████████| 12/12 [00:00<00:00, 374.77it/s]


Epoch 150: Train Loss = 1.1479, Val Loss = 0.7051


100%|██████████| 12/12 [00:00<00:00, 348.67it/s]
100%|██████████| 12/12 [00:00<00:00, 407.36it/s]
100%|██████████| 12/12 [00:00<00:00, 397.98it/s]
100%|██████████| 12/12 [00:00<00:00, 401.63it/s]
100%|██████████| 12/12 [00:00<00:00, 398.08it/s]
100%|██████████| 12/12 [00:00<00:00, 397.67it/s]
100%|██████████| 12/12 [00:00<00:00, 373.16it/s]
100%|██████████| 12/12 [00:00<00:00, 387.81it/s]
100%|██████████| 12/12 [00:00<00:00, 399.67it/s]
100%|██████████| 12/12 [00:00<00:00, 392.80it/s]


Epoch 160: Train Loss = 1.1808, Val Loss = 0.7036


100%|██████████| 12/12 [00:00<00:00, 400.83it/s]
100%|██████████| 12/12 [00:00<00:00, 402.91it/s]
100%|██████████| 12/12 [00:00<00:00, 392.20it/s]
100%|██████████| 12/12 [00:00<00:00, 400.01it/s]
100%|██████████| 12/12 [00:00<00:00, 402.29it/s]
100%|██████████| 12/12 [00:00<00:00, 397.11it/s]
100%|██████████| 12/12 [00:00<00:00, 399.63it/s]
100%|██████████| 12/12 [00:00<00:00, 405.57it/s]
100%|██████████| 12/12 [00:00<00:00, 418.29it/s]
100%|██████████| 12/12 [00:00<00:00, 379.19it/s]


Epoch 170: Train Loss = 1.1911, Val Loss = 0.7042


100%|██████████| 12/12 [00:00<00:00, 391.98it/s]
100%|██████████| 12/12 [00:00<00:00, 390.07it/s]
100%|██████████| 12/12 [00:00<00:00, 400.09it/s]
100%|██████████| 12/12 [00:00<00:00, 410.71it/s]
100%|██████████| 12/12 [00:00<00:00, 403.29it/s]
100%|██████████| 12/12 [00:00<00:00, 369.10it/s]
100%|██████████| 12/12 [00:00<00:00, 408.97it/s]
100%|██████████| 12/12 [00:00<00:00, 414.68it/s]
100%|██████████| 12/12 [00:00<00:00, 353.95it/s]
100%|██████████| 12/12 [00:00<00:00, 382.93it/s]


Epoch 180: Train Loss = 1.1540, Val Loss = 0.7042


100%|██████████| 12/12 [00:00<00:00, 386.97it/s]
100%|██████████| 12/12 [00:00<00:00, 386.93it/s]
100%|██████████| 12/12 [00:00<00:00, 396.60it/s]
100%|██████████| 12/12 [00:00<00:00, 394.67it/s]
100%|██████████| 12/12 [00:00<00:00, 375.23it/s]
100%|██████████| 12/12 [00:00<00:00, 380.10it/s]
100%|██████████| 12/12 [00:00<00:00, 386.27it/s]
100%|██████████| 12/12 [00:00<00:00, 388.25it/s]
100%|██████████| 12/12 [00:00<00:00, 381.60it/s]
100%|██████████| 12/12 [00:00<00:00, 380.77it/s]


Epoch 190: Train Loss = 1.1570, Val Loss = 0.7038


100%|██████████| 12/12 [00:00<00:00, 381.47it/s]
100%|██████████| 12/12 [00:00<00:00, 400.63it/s]
100%|██████████| 12/12 [00:00<00:00, 376.45it/s]
100%|██████████| 12/12 [00:00<00:00, 377.14it/s]
100%|██████████| 12/12 [00:00<00:00, 276.55it/s]
100%|██████████| 12/12 [00:00<00:00, 388.85it/s]
100%|██████████| 12/12 [00:00<00:00, 390.00it/s]
100%|██████████| 12/12 [00:00<00:00, 387.37it/s]
100%|██████████| 12/12 [00:00<00:00, 389.01it/s]
100%|██████████| 12/12 [00:00<00:00, 381.36it/s]


Epoch 200: Train Loss = 1.0833, Val Loss = 0.7032


100%|██████████| 12/12 [00:00<00:00, 375.43it/s]
100%|██████████| 12/12 [00:00<00:00, 390.32it/s]
100%|██████████| 12/12 [00:00<00:00, 372.79it/s]
100%|██████████| 12/12 [00:00<00:00, 360.97it/s]
100%|██████████| 12/12 [00:00<00:00, 380.59it/s]
100%|██████████| 12/12 [00:00<00:00, 380.28it/s]
100%|██████████| 12/12 [00:00<00:00, 386.60it/s]
100%|██████████| 12/12 [00:00<00:00, 384.97it/s]
100%|██████████| 12/12 [00:00<00:00, 345.09it/s]
100%|██████████| 12/12 [00:00<00:00, 345.61it/s]


Epoch 210: Train Loss = 1.1520, Val Loss = 0.7038


100%|██████████| 12/12 [00:00<00:00, 353.00it/s]
100%|██████████| 12/12 [00:00<00:00, 383.97it/s]
100%|██████████| 12/12 [00:00<00:00, 384.53it/s]
100%|██████████| 12/12 [00:00<00:00, 351.97it/s]
100%|██████████| 12/12 [00:00<00:00, 327.05it/s]
100%|██████████| 12/12 [00:00<00:00, 345.90it/s]
100%|██████████| 12/12 [00:00<00:00, 359.51it/s]
100%|██████████| 12/12 [00:00<00:00, 386.05it/s]
100%|██████████| 12/12 [00:00<00:00, 384.84it/s]
100%|██████████| 12/12 [00:00<00:00, 387.22it/s]


Epoch 220: Train Loss = 1.1223, Val Loss = 0.7040


100%|██████████| 12/12 [00:00<00:00, 381.43it/s]
100%|██████████| 12/12 [00:00<00:00, 375.83it/s]
100%|██████████| 12/12 [00:00<00:00, 397.24it/s]
100%|██████████| 12/12 [00:00<00:00, 402.63it/s]
100%|██████████| 12/12 [00:00<00:00, 329.53it/s]
100%|██████████| 12/12 [00:00<00:00, 327.65it/s]
100%|██████████| 12/12 [00:00<00:00, 344.91it/s]
100%|██████████| 12/12 [00:00<00:00, 360.08it/s]
100%|██████████| 12/12 [00:00<00:00, 381.25it/s]
100%|██████████| 12/12 [00:00<00:00, 385.71it/s]


Epoch 230: Train Loss = 1.1066, Val Loss = 0.7031


100%|██████████| 12/12 [00:00<00:00, 368.50it/s]
100%|██████████| 12/12 [00:00<00:00, 349.03it/s]
100%|██████████| 12/12 [00:00<00:00, 391.91it/s]
100%|██████████| 12/12 [00:00<00:00, 392.85it/s]
100%|██████████| 12/12 [00:00<00:00, 374.47it/s]
100%|██████████| 12/12 [00:00<00:00, 366.00it/s]
100%|██████████| 12/12 [00:00<00:00, 358.48it/s]
100%|██████████| 12/12 [00:00<00:00, 379.98it/s]
100%|██████████| 12/12 [00:00<00:00, 376.81it/s]
100%|██████████| 12/12 [00:00<00:00, 376.09it/s]


Epoch 240: Train Loss = 1.1332, Val Loss = 0.7028


100%|██████████| 12/12 [00:00<00:00, 349.62it/s]
100%|██████████| 12/12 [00:00<00:00, 374.19it/s]
100%|██████████| 12/12 [00:00<00:00, 337.66it/s]
100%|██████████| 12/12 [00:00<00:00, 368.07it/s]
100%|██████████| 12/12 [00:00<00:00, 379.66it/s]
100%|██████████| 12/12 [00:00<00:00, 362.41it/s]
100%|██████████| 12/12 [00:00<00:00, 378.84it/s]
100%|██████████| 12/12 [00:00<00:00, 372.87it/s]
100%|██████████| 12/12 [00:00<00:00, 389.93it/s]
100%|██████████| 12/12 [00:00<00:00, 376.05it/s]


Epoch 250: Train Loss = 1.0763, Val Loss = 0.7022


100%|██████████| 12/12 [00:00<00:00, 362.41it/s]
100%|██████████| 12/12 [00:00<00:00, 377.02it/s]
100%|██████████| 12/12 [00:00<00:00, 392.10it/s]
100%|██████████| 12/12 [00:00<00:00, 381.90it/s]
100%|██████████| 12/12 [00:00<00:00, 375.54it/s]
100%|██████████| 12/12 [00:00<00:00, 387.72it/s]
100%|██████████| 12/12 [00:00<00:00, 366.64it/s]
100%|██████████| 12/12 [00:00<00:00, 375.82it/s]
100%|██████████| 12/12 [00:00<00:00, 368.83it/s]
100%|██████████| 12/12 [00:00<00:00, 356.87it/s]


Epoch 260: Train Loss = 1.1060, Val Loss = 0.7023


100%|██████████| 12/12 [00:00<00:00, 336.57it/s]
100%|██████████| 12/12 [00:00<00:00, 269.48it/s]
100%|██████████| 12/12 [00:00<00:00, 382.73it/s]
100%|██████████| 12/12 [00:00<00:00, 383.59it/s]
100%|██████████| 12/12 [00:00<00:00, 366.63it/s]
100%|██████████| 12/12 [00:00<00:00, 341.70it/s]
100%|██████████| 12/12 [00:00<00:00, 327.22it/s]
100%|██████████| 12/12 [00:00<00:00, 336.25it/s]
100%|██████████| 12/12 [00:00<00:00, 359.00it/s]
100%|██████████| 12/12 [00:00<00:00, 375.25it/s]


Epoch 270: Train Loss = 1.1417, Val Loss = 0.7019


100%|██████████| 12/12 [00:00<00:00, 368.55it/s]
100%|██████████| 12/12 [00:00<00:00, 314.76it/s]
100%|██████████| 12/12 [00:00<00:00, 310.55it/s]
100%|██████████| 12/12 [00:00<00:00, 300.35it/s]
100%|██████████| 12/12 [00:00<00:00, 387.28it/s]
100%|██████████| 12/12 [00:00<00:00, 357.74it/s]
100%|██████████| 12/12 [00:00<00:00, 342.73it/s]
100%|██████████| 12/12 [00:00<00:00, 385.31it/s]
100%|██████████| 12/12 [00:00<00:00, 375.53it/s]
100%|██████████| 12/12 [00:00<00:00, 372.60it/s]


Epoch 280: Train Loss = 1.0729, Val Loss = 0.7020


100%|██████████| 12/12 [00:00<00:00, 363.61it/s]
100%|██████████| 12/12 [00:00<00:00, 364.66it/s]
100%|██████████| 12/12 [00:00<00:00, 375.34it/s]
100%|██████████| 12/12 [00:00<00:00, 373.16it/s]
100%|██████████| 12/12 [00:00<00:00, 382.21it/s]
100%|██████████| 12/12 [00:00<00:00, 335.97it/s]
100%|██████████| 12/12 [00:00<00:00, 368.69it/s]
100%|██████████| 12/12 [00:00<00:00, 359.24it/s]
100%|██████████| 12/12 [00:00<00:00, 375.11it/s]
100%|██████████| 12/12 [00:00<00:00, 375.00it/s]


Epoch 290: Train Loss = 1.0278, Val Loss = 0.7015


100%|██████████| 12/12 [00:00<00:00, 358.29it/s]
100%|██████████| 12/12 [00:00<00:00, 360.30it/s]
100%|██████████| 12/12 [00:00<00:00, 388.27it/s]
100%|██████████| 12/12 [00:00<00:00, 371.18it/s]
100%|██████████| 12/12 [00:00<00:00, 390.32it/s]
100%|██████████| 12/12 [00:00<00:00, 388.46it/s]
100%|██████████| 12/12 [00:00<00:00, 393.12it/s]
100%|██████████| 12/12 [00:00<00:00, 385.66it/s]
100%|██████████| 12/12 [00:00<00:00, 364.73it/s]
100%|██████████| 12/12 [00:00<00:00, 390.03it/s]


Epoch 300: Train Loss = 1.0891, Val Loss = 0.7017


100%|██████████| 12/12 [00:00<00:00, 402.07it/s]
100%|██████████| 12/12 [00:00<00:00, 386.82it/s]
100%|██████████| 12/12 [00:00<00:00, 395.49it/s]
100%|██████████| 12/12 [00:00<00:00, 405.27it/s]
100%|██████████| 12/12 [00:00<00:00, 385.45it/s]
100%|██████████| 12/12 [00:00<00:00, 370.45it/s]
100%|██████████| 12/12 [00:00<00:00, 398.77it/s]
100%|██████████| 12/12 [00:00<00:00, 379.46it/s]
100%|██████████| 12/12 [00:00<00:00, 392.19it/s]
100%|██████████| 12/12 [00:00<00:00, 376.90it/s]


Epoch 310: Train Loss = 1.0884, Val Loss = 0.7011


100%|██████████| 12/12 [00:00<00:00, 350.61it/s]
100%|██████████| 12/12 [00:00<00:00, 390.44it/s]
100%|██████████| 12/12 [00:00<00:00, 398.80it/s]
100%|██████████| 12/12 [00:00<00:00, 374.23it/s]
100%|██████████| 12/12 [00:00<00:00, 354.30it/s]
100%|██████████| 12/12 [00:00<00:00, 376.83it/s]
100%|██████████| 12/12 [00:00<00:00, 399.43it/s]
100%|██████████| 12/12 [00:00<00:00, 396.65it/s]
100%|██████████| 12/12 [00:00<00:00, 383.04it/s]
100%|██████████| 12/12 [00:00<00:00, 359.07it/s]


Epoch 320: Train Loss = 1.0412, Val Loss = 0.7009


100%|██████████| 12/12 [00:00<00:00, 248.97it/s]
100%|██████████| 12/12 [00:00<00:00, 341.94it/s]
100%|██████████| 12/12 [00:00<00:00, 272.55it/s]
100%|██████████| 12/12 [00:00<00:00, 279.88it/s]
100%|██████████| 12/12 [00:00<00:00, 313.50it/s]
100%|██████████| 12/12 [00:00<00:00, 374.39it/s]
100%|██████████| 12/12 [00:00<00:00, 375.57it/s]
100%|██████████| 12/12 [00:00<00:00, 371.71it/s]
100%|██████████| 12/12 [00:00<00:00, 378.86it/s]
100%|██████████| 12/12 [00:00<00:00, 358.26it/s]


Epoch 330: Train Loss = 1.0671, Val Loss = 0.7013


100%|██████████| 12/12 [00:00<00:00, 339.66it/s]
100%|██████████| 12/12 [00:00<00:00, 344.85it/s]
100%|██████████| 12/12 [00:00<00:00, 335.92it/s]
100%|██████████| 12/12 [00:00<00:00, 307.53it/s]
100%|██████████| 12/12 [00:00<00:00, 338.44it/s]
100%|██████████| 12/12 [00:00<00:00, 325.16it/s]
100%|██████████| 12/12 [00:00<00:00, 365.20it/s]
100%|██████████| 12/12 [00:00<00:00, 377.53it/s]
100%|██████████| 12/12 [00:00<00:00, 366.43it/s]
100%|██████████| 12/12 [00:00<00:00, 371.46it/s]


Epoch 340: Train Loss = 1.0344, Val Loss = 0.7020


100%|██████████| 12/12 [00:00<00:00, 341.20it/s]
100%|██████████| 12/12 [00:00<00:00, 381.39it/s]
100%|██████████| 12/12 [00:00<00:00, 384.41it/s]
100%|██████████| 12/12 [00:00<00:00, 350.08it/s]
100%|██████████| 12/12 [00:00<00:00, 323.43it/s]
100%|██████████| 12/12 [00:00<00:00, 342.33it/s]
100%|██████████| 12/12 [00:00<00:00, 331.93it/s]
100%|██████████| 12/12 [00:00<00:00, 362.34it/s]
100%|██████████| 12/12 [00:00<00:00, 324.75it/s]
100%|██████████| 12/12 [00:00<00:00, 378.63it/s]


Epoch 350: Train Loss = 1.0154, Val Loss = 0.7018


100%|██████████| 12/12 [00:00<00:00, 364.99it/s]
100%|██████████| 12/12 [00:00<00:00, 385.59it/s]
100%|██████████| 12/12 [00:00<00:00, 375.65it/s]
100%|██████████| 12/12 [00:00<00:00, 375.50it/s]
100%|██████████| 12/12 [00:00<00:00, 374.64it/s]
100%|██████████| 12/12 [00:00<00:00, 375.28it/s]
100%|██████████| 12/12 [00:00<00:00, 316.10it/s]
100%|██████████| 12/12 [00:00<00:00, 294.16it/s]
100%|██████████| 12/12 [00:00<00:00, 319.39it/s]
100%|██████████| 12/12 [00:00<00:00, 288.20it/s]


Epoch 360: Train Loss = 1.0825, Val Loss = 0.7013


100%|██████████| 12/12 [00:00<00:00, 349.55it/s]
100%|██████████| 12/12 [00:00<00:00, 392.20it/s]
100%|██████████| 12/12 [00:00<00:00, 370.44it/s]
100%|██████████| 12/12 [00:00<00:00, 367.46it/s]
100%|██████████| 12/12 [00:00<00:00, 374.42it/s]
100%|██████████| 12/12 [00:00<00:00, 375.69it/s]
100%|██████████| 12/12 [00:00<00:00, 331.80it/s]
100%|██████████| 12/12 [00:00<00:00, 367.14it/s]
100%|██████████| 12/12 [00:00<00:00, 377.64it/s]
100%|██████████| 12/12 [00:00<00:00, 369.66it/s]


Epoch 370: Train Loss = 1.0204, Val Loss = 0.7016


100%|██████████| 12/12 [00:00<00:00, 371.89it/s]
100%|██████████| 12/12 [00:00<00:00, 369.21it/s]

Early stopping!





Model took multiple training runs to stop - about 2400 epochs total. Thankfully each epoch is fast and can be done on cpu. 
Patience is there to prevent overfitting - when validation loss increases but training loss goes down, thats overfitting
I realised that if patience is too low, model will stop in a local minima for validation loss. Patience = 50 worked well here. 

Testing best model weights on random internet picture. At least the grade falls below 1 and 5. 

In [72]:
#Testing the trained model
#This tensor is converted from picture of a man https://www.123rf.com/photo_117285262_cheerful-male-person-wearing-colorful-clothes-while-posing-on-camera.html
sample_tensor = torch.FloatTensor([ 1.6438e-02,  2.1896e-02, -8.0362e-02,  2.8341e-02, -2.8072e-02,
         -2.9751e-02,  3.3327e-03,  7.1567e-03, -3.3463e-03, -1.4960e-02,
          1.6220e-02, -5.9948e-02, -3.5069e-03, -3.2648e-02, -4.8549e-02,
          5.3182e-02, -1.2044e-01,  1.3931e-02, -2.7137e-03, -4.5740e-02,
         -5.4338e-03, -7.7437e-03,  1.9806e-03,  3.1177e-02,  2.3815e-02,
         -4.8724e-02,  3.4536e-02,  3.0051e-02,  7.7540e-02,  1.6905e-03,
          8.1902e-02,  2.0170e-02,  3.2379e-02, -5.5702e-04,  6.4838e-03,
          1.6441e-02, -7.1309e-03,  3.3163e-02,  7.0299e-03, -5.3777e-02,
          2.1970e-02, -4.3867e-02,  2.1012e-02,  1.3051e-02, -2.8605e-02,
         -2.1673e-02, -7.7806e-02, -5.0445e-03, -1.7270e-03,  3.8109e-02,
         -1.3078e-02, -9.2852e-04,  1.9087e-02, -1.4945e-02, -6.1000e-02,
          1.8101e-02, -1.5107e-02, -2.0382e-03, -1.2078e-01, -5.7932e-03,
         -8.0106e-02, -1.3348e-02, -1.6996e-02, -4.8043e-02, -4.8187e-02,
         -8.3656e-03,  1.7748e-02, -7.7310e-02,  1.2552e-02, -2.7422e-03,
         -5.8433e-02, -1.9581e-02, -1.2393e-03, -3.2840e-02,  1.9550e-02,
         -6.0912e-02,  2.5954e-02,  2.7044e-02,  1.5751e-02, -2.9800e-02,
         -5.7729e-03,  3.0403e-02, -2.5153e-02,  3.2811e-02,  2.6373e-02,
          3.6481e-02,  1.8692e-02,  1.2436e-02,  1.3068e-02, -5.4420e-02,
          3.5197e-02,  3.0339e-03, -2.9471e-02, -1.6294e-02,  1.9376e-03,
         -9.6888e-03, -9.2772e-04,  2.3004e-02,  1.8474e-02, -1.9492e-02,
         -1.8044e-02, -4.1529e-02, -1.5912e-03, -8.9523e-03,  5.9890e-02,
         -4.0150e-02, -2.6220e-02,  3.3787e-03, -1.3667e-02, -4.1747e-02,
         -6.0041e-02, -8.8984e-03, -4.2063e-03,  5.9290e-03, -2.0958e-02,
         -5.1528e-02, -1.1521e-02,  1.0781e-02,  1.7129e-02,  1.6626e-02,
          2.0090e-02, -1.3816e-02, -1.2629e-02,  3.5362e-02, -5.0381e-03,
         -4.6351e-02, -7.0735e-02, -3.3706e-04,  3.9993e-02, -8.3239e-03,
          6.3205e-04,  3.1786e-02,  5.0934e-02, -4.0317e-02, -9.7049e-02,
          4.1309e-03, -1.8061e-02,  2.9901e-03,  7.8605e-03,  1.7492e-02,
          2.6236e-03,  1.6291e-02, -3.4488e-02,  2.6159e-02, -9.3525e-03,
          4.2920e-02, -1.0373e-02,  1.3415e-02,  2.3318e-02, -4.4368e-02,
         -6.5866e-02,  4.0162e-03,  1.8135e-01,  3.4263e-02, -4.5111e-03,
          2.4679e-02, -2.6330e-02,  3.1580e-03, -4.3842e-02, -7.0146e-03,
          3.0014e-02,  2.1881e-02, -2.8969e-02,  7.2498e-03,  5.8582e-02,
         -4.8642e-02,  6.9074e-04, -1.2843e-01, -5.9926e-02, -1.6402e-02,
         -2.7898e-02, -1.1446e-02, -5.4659e-03,  1.5569e-02, -1.4688e-02,
          8.8983e-02,  2.7876e-02, -1.3252e-02,  8.6685e-02,  2.8667e-02,
          4.4407e-02,  4.5487e-03, -1.2681e-02, -1.2898e-02, -4.0447e-02,
          8.4092e-03,  8.4929e-02,  2.8877e-02, -2.1838e-02, -2.8689e-02,
          8.2374e-04, -4.7018e-02, -4.2617e-02,  4.0515e-02,  2.4294e-02,
          3.6246e-02, -3.2249e-02, -5.6838e-02,  2.5454e-05,  6.8819e-02,
          1.0482e-01, -6.8002e-02,  5.7314e-02,  1.6914e-02, -2.3342e-02,
          1.6111e-02, -1.4589e-02, -4.0899e-02,  1.0747e-02, -2.0445e-02,
          3.9802e-03, -1.4422e-02, -1.3104e-02, -7.0705e-03,  1.0223e-03,
         -1.1471e-02, -6.6757e-03, -3.8847e-02, -3.6280e-02,  5.1143e-02,
          1.7621e-02, -1.5823e-02,  4.1379e-03,  1.6249e-02,  7.7959e-02,
          3.1737e-02,  7.7840e-03, -3.6849e-02,  3.9102e-03, -2.1117e-02,
          1.7891e-02, -1.8777e-02, -2.6012e-02,  4.8590e-02,  1.3336e-01,
         -1.3727e-02, -6.1243e-02, -8.3668e-03,  2.3006e-03,  2.0935e-02,
          2.9661e-02, -7.2977e-02,  3.5479e-02, -2.4728e-02,  1.6162e-03,
         -4.0983e-02, -2.1773e-02, -1.2003e-02, -1.6234e-02,  4.2936e-02,
         -4.7685e-02, -1.6079e-02,  3.3041e-02, -5.8450e-02, -5.6025e-02,
          7.9667e-03, -3.1043e-02,  2.5277e-02,  1.2441e-02, -1.0250e-02,
         -2.4603e-02, -2.7787e-02,  4.4998e-02,  2.5554e-03, -3.1553e-02,
          2.2477e-02, -3.2626e-02,  5.9970e-02,  9.5149e-03, -1.0181e-02,
         -8.6941e-03, -1.7810e-02,  2.7997e-02, -4.9730e-04, -1.5720e-02,
         -1.7308e-03,  1.3384e-02,  3.0308e-02, -2.9071e-02, -6.0074e-03,
         -2.7530e-02, -2.1719e-02,  2.7147e-03,  2.9770e-02,  1.8560e-02,
         -2.8023e-02,  3.4007e-03,  3.8103e-02,  7.9685e-03,  2.2906e-02,
          9.8222e-03,  2.1956e-03,  1.8347e-02, -1.1035e-02, -2.1805e-02,
          8.0348e-03,  6.5365e-03,  5.7519e-03, -9.0100e-03,  3.4582e-03,
         -8.5428e-02, -8.2064e-03,  3.7818e-02, -9.2870e-03,  4.3539e-02,
         -1.0486e-02,  3.6035e-02, -1.8147e-01, -3.0396e-02,  1.1968e-02,
         -6.5025e-02, -8.5649e-02, -3.7049e-02, -2.0510e-02,  4.6033e-02,
          2.8916e-02, -1.9188e-02, -4.0011e-02, -1.7165e-02, -4.7628e-02,
         -5.9055e-03, -4.4336e-02, -4.5208e-02,  3.4894e-02, -3.3236e-03,
          3.5319e-03, -8.0306e-02, -3.3924e-02,  1.5684e-02, -1.4497e-02,
         -1.3824e-02,  1.3490e-02, -1.6188e-02,  2.1202e-02,  4.4270e-03,
          6.1059e-04, -3.2296e-02, -2.4578e-02, -5.8685e-02, -3.1978e-02,
          7.9443e-03, -2.7962e-03,  1.5279e-02, -1.8560e-02,  1.3277e-02,
         -1.8105e-02,  7.7305e-04, -2.1878e-02, -3.5787e-02,  4.1071e-02,
          6.1255e-02,  1.2895e-02, -2.7177e-02, -2.2769e-02, -6.8263e-02,
          8.7712e-03, -2.8485e-02,  7.2261e-03, -3.1435e-02, -4.9915e-02,
          4.9357e-03, -8.0902e-03, -3.0833e-02,  2.5438e-02,  2.3160e-02,
          3.1968e-03, -9.8436e-03,  3.6580e-02, -2.1161e-04, -3.0524e-02,
          3.5791e-02,  3.6604e-02,  1.0698e-02,  2.0288e-02, -1.0971e-02,
         -9.5724e-02,  1.1956e-02, -4.6368e-03,  1.8884e-02, -3.6222e-02,
          2.5449e-02, -3.3758e-02, -2.1567e-02, -4.5111e-03,  3.0986e-02,
          5.4160e-02, -3.9285e-03,  1.7523e-02,  1.0818e-02,  5.6329e-03,
         -2.1182e-03, -7.9104e-02,  4.3279e-02, -5.3102e-03, -4.8174e-02,
         -4.6041e-03, -1.1762e-01, -8.3034e-03,  1.3346e-02, -9.3953e-03,
          5.1457e-02, -1.1119e-02,  7.3139e-03, -2.0534e-02, -1.1732e-02,
          2.6643e-02, -9.7342e-03,  4.2314e-02, -6.7216e-02,  2.8781e-02,
          2.2335e-02, -1.6726e-02,  2.2697e-02,  2.6102e-02,  3.4108e-02,
         -1.3881e-02, -2.6161e-02, -4.4447e-02, -4.9447e-02, -2.6481e-03,
          8.4490e-03, -1.9126e-02, -7.8287e-02, -2.5057e-02,  3.7006e-02,
         -2.3662e-03, -8.1124e-03, -2.2771e-02, -2.3840e-02,  1.6940e-02,
          4.4685e-02, -6.3397e-03,  3.7412e-03,  1.9873e-02,  7.6575e-02,
          1.5466e-02, -2.6530e-02,  7.1482e-03, -2.8864e-03,  1.0379e-02,
         -9.7083e-03,  1.5241e-02,  5.2368e-02, -7.9797e-02,  1.1748e-02,
         -1.8466e-02,  1.5176e-03,  5.3380e-02, -3.0241e-02, -1.9172e-03,
          4.2681e-03,  5.1388e-03,  6.1217e-03,  1.9138e-02,  4.0037e-03,
         -4.4822e-02, -1.5271e-02, -3.2483e-02, -7.6644e-02,  8.8780e-03,
         -1.3251e-02, -6.0519e-03, -2.6461e-02,  7.2956e-03,  4.7201e-02,
          9.8676e-03,  1.4828e-02,  3.8183e-02, -3.2692e-02,  1.6484e-02,
         -3.1506e-02, -8.2864e-02,  1.8854e-02, -2.4680e-02, -5.8000e-02,
          7.4619e-02,  4.6146e-02, -2.2261e-02, -3.5827e-02, -1.4187e-02,
          2.3814e-02,  7.2704e-02, -1.6533e-03, -1.1226e-02, -6.5184e-03,
         -3.2976e-02,  1.2230e-02, -1.3260e-02, -1.3116e-02,  8.0362e-03,
         -1.7711e-02,  6.4595e-02,  2.1063e-02,  1.2096e-02, -5.1930e-02,
          5.4860e-02, -2.2176e-02,  4.1113e-02, -1.5210e-02,  5.1906e-03,
         -2.6861e-02, -2.8147e-02,  3.4399e-02, -4.9842e-02,  2.5388e-02,
          2.6219e-02,  4.1337e-02, -2.2133e-02,  2.3895e-02, -1.2442e-02,
          1.0604e-02,  1.8947e-02]).unsqueeze(0).to(device)
embed_to_score.load_state_dict(torch.load('best_model.pth'))
with torch.no_grad():
    embed_to_score.eval()
    output = embed_to_score(sample_tensor.to(device))
    print(output.item())

3.159768581390381


The embedding (hardcoded here) was generated by running this code in embed.ipynb

image = Image.open(r'C:\Users\rwwj8\OneDrive\Documents\colourful_man.jpg').convert("RGB")
text = "a man wearing green bandana, purple hoodie and long orange trousers"
image_embeds, text_embeds = run_embedding([image], [text])
combined_embeds = combine_embeddings(image_embeds, text_embeds)
print("Image Embedding Shape:", image_embeds.shape)
print(combined_embeds)