In [44]:
import torch
import torch.nn as nn
import random
import pandas as pd
from tqdm import tqdm

import numpy as np
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from pathlib import Path

In [45]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', DEVICE)

data_path= '/kaggle/input/aml-competition/train/train/train.npz'
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'

data = np.load(Path('/kaggle/input/aml-competition/train/train/train.npz'))

caption_embeddings = data['captions/embeddings']
image_embeddings = data['images/embeddings']
caption_labels = data['captions/label']

data.close()

x_abs = torch.from_numpy(caption_embeddings) # captions space
y_abs = torch.from_numpy(image_embeddings[np.argmax(caption_labels, axis=1)]) # images space

print('Captions latent space:', x_abs.shape)
print('Images latent space:', y_abs.shape)

Device: cuda
Captions latent space: torch.Size([125000, 1024])
Images latent space: torch.Size([125000, 1536])


In [46]:
class LinearTranslator(nn.Module):
    def __init__(self, dim: int, use_bias: bool, xavier_init: bool):
        super().__init__()
        assert isinstance(dim, int) and dim > 0, "Expecting positive dimension"
        assert isinstance(use_bias, bool), 'Expecting boolean param for "use_bias"'
        assert isinstance(xavier_init, bool), 'Expecting boolean param for "xavier_init"'
                
        self.linear = nn.Linear(dim, dim, bias=use_bias)

        if xavier_init:
            self._xavier_init()
    
    def _xavier_init(self):
            nn.init.xavier_uniform_(self.linear.weight)
            
            if self.linear.bias is not None:
                nn.init.constant_(self.linear.bias, 0)
    
    def forward(self, x):
        return self.linear(x)


class MLPTranslator(nn.Module):
    def __init__(self, input_dim, hidden_dims=[1200, 600, 300], output_dim=None, dropout=0.15):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim

        layers = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))  # output layer

        self.net = nn.Sequential(*layers)

        # Xavier initialization
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)



def train_translator(model: LinearTranslator, train_loader: DataLoader, val_loader: DataLoader, epochs: int, lr: float):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            optimizer.zero_grad()

            # outputs = model(X_batch)

            y_batch = F.normalize(y_batch, p=2, dim=1)
            outputs = F.normalize(model(X_batch), p=2, dim=1)
            
            loss = 1 - (outputs * y_batch).sum(dim=1).mean()

            loss.backward()

            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                
                y_batch = F.normalize(y_batch, p=2, dim=1)
                outputs = F.normalize(model(X_batch), p=2, dim=1)
                
                loss = 1 - (outputs * y_batch).sum(dim=1).mean()

                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

    return model

In [47]:
def center(data: torch.Tensor):
    return data - data.mean(dim=0, keepdim=True)


def normalize(data: torch.Tensor):
    return F.normalize(data, p=2, dim=1)


def extract_anchors(x_abs: torch.Tensor, y_abs: torch.Tensor, anchors_number: int):
    indices = torch.randperm(x_abs.size(0))[:anchors_number]
    
    return x_abs[indices], y_abs[indices]


def releative_representation(data: torch.Tensor, anchors: torch.Tensor):
    return F.normalize(data, p=2, dim=1) @ F.normalize(anchors.T, p=2, dim=1)

In [48]:
# omega = 8
# delta = 0.65
anchors_number = 1000

x = normalize(center(x_abs))
y = normalize(center(y_abs))

x_anchors, y_anchors = extract_anchors(x, y, anchors_number)

x_rel = x @ x_anchors.T # releative_representation(x_centered, x_anchors)
y_rel = y @ y_anchors.T # releative_representation(y_centered, y_anchors)

In [49]:
n_train = int(0.9 * x_rel.shape[0])
batch_size = 256

indices = torch.randperm(x_rel.shape[0])
train_idx = indices[:n_train]
val_idx = indices[n_train:]

x_train, x_val = x_rel[train_idx], x_rel[val_idx]
y_train, y_val = y_rel[train_idx], y_rel[val_idx]

train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [50]:
epochs = 300
lr = 0.0003

#model = LinearTranslator(x_rel.shape[1], True, True).to(DEVICE)
model = MLPTranslator(x_rel.shape[1]).to(DEVICE)

train_translator(model, train_loader, val_loader, epochs, lr)

Epoch 1/300: 100%|██████████| 440/440 [00:02<00:00, 169.63it/s]


Epoch 1: Train Loss = 0.425923, Val Loss = 0.342918


Epoch 2/300: 100%|██████████| 440/440 [00:02<00:00, 176.59it/s]


Epoch 2: Train Loss = 0.343787, Val Loss = 0.324698


Epoch 3/300: 100%|██████████| 440/440 [00:06<00:00, 66.78it/s] 


Epoch 3: Train Loss = 0.328265, Val Loss = 0.316758


Epoch 4/300: 100%|██████████| 440/440 [00:02<00:00, 175.74it/s]


Epoch 4: Train Loss = 0.318443, Val Loss = 0.309302


Epoch 5/300: 100%|██████████| 440/440 [00:02<00:00, 173.14it/s]


Epoch 5: Train Loss = 0.312157, Val Loss = 0.305658


Epoch 6/300: 100%|██████████| 440/440 [00:02<00:00, 164.55it/s]


Epoch 6: Train Loss = 0.307446, Val Loss = 0.303373


Epoch 7/300: 100%|██████████| 440/440 [00:02<00:00, 172.38it/s]


Epoch 7: Train Loss = 0.303322, Val Loss = 0.301008


Epoch 8/300: 100%|██████████| 440/440 [00:02<00:00, 174.99it/s]


Epoch 8: Train Loss = 0.300123, Val Loss = 0.296705


Epoch 9/300: 100%|██████████| 440/440 [00:02<00:00, 164.07it/s]


Epoch 9: Train Loss = 0.296864, Val Loss = 0.294960


Epoch 10/300: 100%|██████████| 440/440 [00:02<00:00, 173.06it/s]


Epoch 10: Train Loss = 0.294532, Val Loss = 0.292540


Epoch 11/300: 100%|██████████| 440/440 [00:02<00:00, 174.08it/s]


Epoch 11: Train Loss = 0.291820, Val Loss = 0.292099


Epoch 12/300: 100%|██████████| 440/440 [00:02<00:00, 163.76it/s]


Epoch 12: Train Loss = 0.289793, Val Loss = 0.292132


Epoch 13/300: 100%|██████████| 440/440 [00:02<00:00, 171.25it/s]


Epoch 13: Train Loss = 0.287790, Val Loss = 0.290566


Epoch 14/300: 100%|██████████| 440/440 [00:02<00:00, 172.17it/s]


Epoch 14: Train Loss = 0.285912, Val Loss = 0.288298


Epoch 15/300: 100%|██████████| 440/440 [00:02<00:00, 176.56it/s]


Epoch 15: Train Loss = 0.284275, Val Loss = 0.289035


Epoch 16/300: 100%|██████████| 440/440 [00:02<00:00, 174.00it/s]


Epoch 16: Train Loss = 0.282306, Val Loss = 0.286782


Epoch 17/300: 100%|██████████| 440/440 [00:02<00:00, 174.18it/s]


Epoch 17: Train Loss = 0.280769, Val Loss = 0.285714


Epoch 18/300: 100%|██████████| 440/440 [00:02<00:00, 171.93it/s]


Epoch 18: Train Loss = 0.278950, Val Loss = 0.286088


Epoch 19/300: 100%|██████████| 440/440 [00:02<00:00, 164.60it/s]


Epoch 19: Train Loss = 0.277280, Val Loss = 0.284180


Epoch 20/300: 100%|██████████| 440/440 [00:02<00:00, 173.15it/s]


Epoch 20: Train Loss = 0.275816, Val Loss = 0.284824


Epoch 21/300: 100%|██████████| 440/440 [00:02<00:00, 169.47it/s]


Epoch 21: Train Loss = 0.274281, Val Loss = 0.282417


Epoch 22/300: 100%|██████████| 440/440 [00:02<00:00, 164.62it/s]


Epoch 22: Train Loss = 0.272513, Val Loss = 0.282570


Epoch 23/300: 100%|██████████| 440/440 [00:02<00:00, 173.48it/s]


Epoch 23: Train Loss = 0.271196, Val Loss = 0.281407


Epoch 24/300: 100%|██████████| 440/440 [00:02<00:00, 175.61it/s]


Epoch 24: Train Loss = 0.269624, Val Loss = 0.280726


Epoch 25/300: 100%|██████████| 440/440 [00:02<00:00, 160.96it/s]


Epoch 25: Train Loss = 0.268391, Val Loss = 0.280465


Epoch 26/300: 100%|██████████| 440/440 [00:02<00:00, 172.89it/s]


Epoch 26: Train Loss = 0.266924, Val Loss = 0.280341


Epoch 27/300: 100%|██████████| 440/440 [00:02<00:00, 174.41it/s]


Epoch 27: Train Loss = 0.265062, Val Loss = 0.280150


Epoch 28/300: 100%|██████████| 440/440 [00:02<00:00, 165.53it/s]


Epoch 28: Train Loss = 0.263748, Val Loss = 0.279801


Epoch 29/300: 100%|██████████| 440/440 [00:02<00:00, 173.66it/s]


Epoch 29: Train Loss = 0.262462, Val Loss = 0.278849


Epoch 30/300: 100%|██████████| 440/440 [00:02<00:00, 173.61it/s]


Epoch 30: Train Loss = 0.260829, Val Loss = 0.278726


Epoch 31/300: 100%|██████████| 440/440 [00:02<00:00, 164.17it/s]


Epoch 31: Train Loss = 0.259261, Val Loss = 0.278569


Epoch 32/300: 100%|██████████| 440/440 [00:02<00:00, 173.49it/s]


Epoch 32: Train Loss = 0.258053, Val Loss = 0.279012


Epoch 33/300: 100%|██████████| 440/440 [00:02<00:00, 174.81it/s]


Epoch 33: Train Loss = 0.256855, Val Loss = 0.278701


Epoch 34/300: 100%|██████████| 440/440 [00:02<00:00, 164.15it/s]


Epoch 34: Train Loss = 0.255148, Val Loss = 0.279116


Epoch 35/300: 100%|██████████| 440/440 [00:02<00:00, 176.04it/s]


Epoch 35: Train Loss = 0.253823, Val Loss = 0.277973


Epoch 36/300: 100%|██████████| 440/440 [00:02<00:00, 172.67it/s]


Epoch 36: Train Loss = 0.252507, Val Loss = 0.277957


Epoch 37/300: 100%|██████████| 440/440 [00:02<00:00, 165.84it/s]


Epoch 37: Train Loss = 0.251273, Val Loss = 0.277593


Epoch 38/300: 100%|██████████| 440/440 [00:02<00:00, 175.55it/s]


Epoch 38: Train Loss = 0.250006, Val Loss = 0.277258


Epoch 39/300: 100%|██████████| 440/440 [00:02<00:00, 174.78it/s]


Epoch 39: Train Loss = 0.248686, Val Loss = 0.276758


Epoch 40/300: 100%|██████████| 440/440 [00:02<00:00, 164.35it/s]


Epoch 40: Train Loss = 0.247354, Val Loss = 0.276780


Epoch 41/300: 100%|██████████| 440/440 [00:02<00:00, 175.54it/s]


Epoch 41: Train Loss = 0.246852, Val Loss = 0.276097


Epoch 42/300: 100%|██████████| 440/440 [00:02<00:00, 172.89it/s]


Epoch 42: Train Loss = 0.245312, Val Loss = 0.275712


Epoch 43/300: 100%|██████████| 440/440 [00:02<00:00, 162.53it/s]


Epoch 43: Train Loss = 0.243921, Val Loss = 0.276272


Epoch 44/300: 100%|██████████| 440/440 [00:02<00:00, 173.80it/s]


Epoch 44: Train Loss = 0.242812, Val Loss = 0.275905


Epoch 45/300: 100%|██████████| 440/440 [00:02<00:00, 173.70it/s]


Epoch 45: Train Loss = 0.241807, Val Loss = 0.275989


Epoch 46/300: 100%|██████████| 440/440 [00:02<00:00, 174.79it/s]


Epoch 46: Train Loss = 0.240868, Val Loss = 0.275364


Epoch 47/300: 100%|██████████| 440/440 [00:02<00:00, 171.78it/s]


Epoch 47: Train Loss = 0.239538, Val Loss = 0.276814


Epoch 48/300: 100%|██████████| 440/440 [00:02<00:00, 174.57it/s]


Epoch 48: Train Loss = 0.238885, Val Loss = 0.275314


Epoch 49/300: 100%|██████████| 440/440 [00:02<00:00, 169.59it/s]


Epoch 49: Train Loss = 0.237722, Val Loss = 0.275056


Epoch 50/300: 100%|██████████| 440/440 [00:02<00:00, 160.98it/s]


Epoch 50: Train Loss = 0.236780, Val Loss = 0.275688


Epoch 51/300: 100%|██████████| 440/440 [00:02<00:00, 173.71it/s]


Epoch 51: Train Loss = 0.235886, Val Loss = 0.274468


Epoch 52/300: 100%|██████████| 440/440 [00:02<00:00, 170.57it/s]


Epoch 52: Train Loss = 0.235096, Val Loss = 0.275312


Epoch 53/300: 100%|██████████| 440/440 [00:02<00:00, 163.78it/s]


Epoch 53: Train Loss = 0.233856, Val Loss = 0.274975


Epoch 54/300: 100%|██████████| 440/440 [00:02<00:00, 173.30it/s]


Epoch 54: Train Loss = 0.232824, Val Loss = 0.275787


Epoch 55/300: 100%|██████████| 440/440 [00:02<00:00, 172.66it/s]


Epoch 55: Train Loss = 0.232103, Val Loss = 0.274639


Epoch 56/300: 100%|██████████| 440/440 [00:02<00:00, 163.95it/s]


Epoch 56: Train Loss = 0.231376, Val Loss = 0.274260


Epoch 57/300: 100%|██████████| 440/440 [00:02<00:00, 175.30it/s]


Epoch 57: Train Loss = 0.230480, Val Loss = 0.273958


Epoch 58/300: 100%|██████████| 440/440 [00:02<00:00, 174.86it/s]


Epoch 58: Train Loss = 0.229601, Val Loss = 0.273741


Epoch 59/300: 100%|██████████| 440/440 [00:02<00:00, 164.80it/s]


Epoch 59: Train Loss = 0.228705, Val Loss = 0.274202


Epoch 60/300: 100%|██████████| 440/440 [00:02<00:00, 173.21it/s]


Epoch 60: Train Loss = 0.228355, Val Loss = 0.274362


Epoch 61/300: 100%|██████████| 440/440 [00:02<00:00, 172.21it/s]


Epoch 61: Train Loss = 0.227282, Val Loss = 0.274502


Epoch 62/300: 100%|██████████| 440/440 [00:02<00:00, 161.78it/s]


Epoch 62: Train Loss = 0.226619, Val Loss = 0.274122


Epoch 63/300: 100%|██████████| 440/440 [00:02<00:00, 174.67it/s]


Epoch 63: Train Loss = 0.225629, Val Loss = 0.275091


Epoch 64/300: 100%|██████████| 440/440 [00:02<00:00, 174.19it/s]


Epoch 64: Train Loss = 0.225605, Val Loss = 0.273665


Epoch 65/300: 100%|██████████| 440/440 [00:02<00:00, 161.05it/s]


Epoch 65: Train Loss = 0.224599, Val Loss = 0.274024


Epoch 66/300: 100%|██████████| 440/440 [00:02<00:00, 175.14it/s]


Epoch 66: Train Loss = 0.223600, Val Loss = 0.273467


Epoch 67/300: 100%|██████████| 440/440 [00:02<00:00, 173.70it/s]


Epoch 67: Train Loss = 0.223184, Val Loss = 0.273550


Epoch 68/300: 100%|██████████| 440/440 [00:02<00:00, 164.56it/s]


Epoch 68: Train Loss = 0.222494, Val Loss = 0.274710


Epoch 69/300: 100%|██████████| 440/440 [00:02<00:00, 173.24it/s]


Epoch 69: Train Loss = 0.221572, Val Loss = 0.273614


Epoch 70/300: 100%|██████████| 440/440 [00:02<00:00, 174.65it/s]


Epoch 70: Train Loss = 0.221145, Val Loss = 0.272811


Epoch 71/300: 100%|██████████| 440/440 [00:02<00:00, 164.50it/s]


Epoch 71: Train Loss = 0.220410, Val Loss = 0.273293


Epoch 72/300: 100%|██████████| 440/440 [00:02<00:00, 174.44it/s]


Epoch 72: Train Loss = 0.219636, Val Loss = 0.273381


Epoch 73/300: 100%|██████████| 440/440 [00:02<00:00, 176.03it/s]


Epoch 73: Train Loss = 0.219210, Val Loss = 0.273659


Epoch 74/300: 100%|██████████| 440/440 [00:02<00:00, 173.04it/s]


Epoch 74: Train Loss = 0.218720, Val Loss = 0.273624


Epoch 75/300: 100%|██████████| 440/440 [00:02<00:00, 171.45it/s]


Epoch 75: Train Loss = 0.217615, Val Loss = 0.274235


Epoch 76/300: 100%|██████████| 440/440 [00:02<00:00, 172.45it/s]


Epoch 76: Train Loss = 0.217111, Val Loss = 0.274398


Epoch 77/300: 100%|██████████| 440/440 [00:02<00:00, 174.49it/s]


Epoch 77: Train Loss = 0.216660, Val Loss = 0.274390


Epoch 78/300: 100%|██████████| 440/440 [00:02<00:00, 162.59it/s]


Epoch 78: Train Loss = 0.216210, Val Loss = 0.273708


Epoch 79/300: 100%|██████████| 440/440 [00:02<00:00, 174.68it/s]


Epoch 79: Train Loss = 0.215635, Val Loss = 0.273249


Epoch 80/300: 100%|██████████| 440/440 [00:02<00:00, 174.04it/s]


Epoch 80: Train Loss = 0.215244, Val Loss = 0.272890


Epoch 81/300: 100%|██████████| 440/440 [00:02<00:00, 163.16it/s]


Epoch 81: Train Loss = 0.214484, Val Loss = 0.273254


Epoch 82/300: 100%|██████████| 440/440 [00:02<00:00, 175.43it/s]


Epoch 82: Train Loss = 0.213944, Val Loss = 0.273556


Epoch 83/300: 100%|██████████| 440/440 [00:02<00:00, 174.49it/s]


Epoch 83: Train Loss = 0.213792, Val Loss = 0.273694


Epoch 84/300: 100%|██████████| 440/440 [00:02<00:00, 165.46it/s]


Epoch 84: Train Loss = 0.213210, Val Loss = 0.273732


Epoch 85/300: 100%|██████████| 440/440 [00:02<00:00, 171.59it/s]


Epoch 85: Train Loss = 0.212567, Val Loss = 0.274550


Epoch 86/300: 100%|██████████| 440/440 [00:02<00:00, 173.27it/s]


Epoch 86: Train Loss = 0.211928, Val Loss = 0.273686


Epoch 87/300: 100%|██████████| 440/440 [00:02<00:00, 161.99it/s]


Epoch 87: Train Loss = 0.211499, Val Loss = 0.273612


Epoch 88/300: 100%|██████████| 440/440 [00:02<00:00, 176.57it/s]


Epoch 88: Train Loss = 0.211048, Val Loss = 0.274488


Epoch 89/300: 100%|██████████| 440/440 [00:02<00:00, 174.13it/s]


Epoch 89: Train Loss = 0.210722, Val Loss = 0.273618


Epoch 90/300: 100%|██████████| 440/440 [00:02<00:00, 166.84it/s]


Epoch 90: Train Loss = 0.210365, Val Loss = 0.273409


Epoch 91/300: 100%|██████████| 440/440 [00:02<00:00, 172.41it/s]


Epoch 91: Train Loss = 0.209576, Val Loss = 0.272994


Epoch 92/300: 100%|██████████| 440/440 [00:02<00:00, 175.40it/s]


Epoch 92: Train Loss = 0.209546, Val Loss = 0.274562


Epoch 93/300: 100%|██████████| 440/440 [00:02<00:00, 164.99it/s]


Epoch 93: Train Loss = 0.208609, Val Loss = 0.273898


Epoch 94/300: 100%|██████████| 440/440 [00:02<00:00, 174.41it/s]


Epoch 94: Train Loss = 0.208392, Val Loss = 0.273387


Epoch 95/300: 100%|██████████| 440/440 [00:02<00:00, 175.53it/s]


Epoch 95: Train Loss = 0.207752, Val Loss = 0.273398


Epoch 96/300: 100%|██████████| 440/440 [00:02<00:00, 164.69it/s]


Epoch 96: Train Loss = 0.207864, Val Loss = 0.273716


Epoch 97/300: 100%|██████████| 440/440 [00:02<00:00, 177.93it/s]


Epoch 97: Train Loss = 0.207111, Val Loss = 0.273414


Epoch 98/300: 100%|██████████| 440/440 [00:02<00:00, 169.37it/s]


Epoch 98: Train Loss = 0.206867, Val Loss = 0.274340


Epoch 99/300: 100%|██████████| 440/440 [00:02<00:00, 161.29it/s]


Epoch 99: Train Loss = 0.206364, Val Loss = 0.273801


Epoch 100/300: 100%|██████████| 440/440 [00:02<00:00, 168.68it/s]


Epoch 100: Train Loss = 0.206059, Val Loss = 0.274596


Epoch 101/300: 100%|██████████| 440/440 [00:02<00:00, 174.43it/s]


Epoch 101: Train Loss = 0.205414, Val Loss = 0.273824


Epoch 102/300: 100%|██████████| 440/440 [00:02<00:00, 162.60it/s]


Epoch 102: Train Loss = 0.205221, Val Loss = 0.274295


Epoch 103/300: 100%|██████████| 440/440 [00:02<00:00, 174.00it/s]


Epoch 103: Train Loss = 0.205062, Val Loss = 0.273773


Epoch 104/300: 100%|██████████| 440/440 [00:02<00:00, 172.61it/s]


Epoch 104: Train Loss = 0.204352, Val Loss = 0.274006


Epoch 105/300: 100%|██████████| 440/440 [00:02<00:00, 170.59it/s]


Epoch 105: Train Loss = 0.204182, Val Loss = 0.273976


Epoch 106/300: 100%|██████████| 440/440 [00:02<00:00, 171.87it/s]


Epoch 106: Train Loss = 0.203480, Val Loss = 0.273464


Epoch 107/300: 100%|██████████| 440/440 [00:02<00:00, 171.93it/s]


Epoch 107: Train Loss = 0.203233, Val Loss = 0.274038


Epoch 108/300: 100%|██████████| 440/440 [00:02<00:00, 170.78it/s]


Epoch 108: Train Loss = 0.203092, Val Loss = 0.273616


Epoch 109/300: 100%|██████████| 440/440 [00:02<00:00, 160.44it/s]


Epoch 109: Train Loss = 0.202842, Val Loss = 0.273830


Epoch 110/300: 100%|██████████| 440/440 [00:02<00:00, 175.93it/s]


Epoch 110: Train Loss = 0.202636, Val Loss = 0.273817


Epoch 111/300: 100%|██████████| 440/440 [00:02<00:00, 171.51it/s]


Epoch 111: Train Loss = 0.201795, Val Loss = 0.273324


Epoch 112/300: 100%|██████████| 440/440 [00:02<00:00, 164.50it/s]


Epoch 112: Train Loss = 0.201413, Val Loss = 0.274732


Epoch 113/300: 100%|██████████| 440/440 [00:02<00:00, 172.94it/s]


Epoch 113: Train Loss = 0.201141, Val Loss = 0.273630


Epoch 114/300: 100%|██████████| 440/440 [00:02<00:00, 172.80it/s]


Epoch 114: Train Loss = 0.200787, Val Loss = 0.274131


Epoch 115/300: 100%|██████████| 440/440 [00:02<00:00, 164.43it/s]


Epoch 115: Train Loss = 0.200743, Val Loss = 0.274812


Epoch 116/300: 100%|██████████| 440/440 [00:02<00:00, 171.54it/s]


Epoch 116: Train Loss = 0.200037, Val Loss = 0.274115


Epoch 117/300: 100%|██████████| 440/440 [00:02<00:00, 176.70it/s]


Epoch 117: Train Loss = 0.199824, Val Loss = 0.273953


Epoch 118/300: 100%|██████████| 440/440 [00:02<00:00, 163.42it/s]


Epoch 118: Train Loss = 0.199718, Val Loss = 0.274355


Epoch 119/300: 100%|██████████| 440/440 [00:02<00:00, 175.75it/s]


Epoch 119: Train Loss = 0.199025, Val Loss = 0.274860


Epoch 120/300: 100%|██████████| 440/440 [00:02<00:00, 172.14it/s]


Epoch 120: Train Loss = 0.198844, Val Loss = 0.274276


Epoch 121/300: 100%|██████████| 440/440 [00:02<00:00, 164.69it/s]


Epoch 121: Train Loss = 0.198969, Val Loss = 0.274279


Epoch 122/300: 100%|██████████| 440/440 [00:02<00:00, 173.76it/s]


Epoch 122: Train Loss = 0.198438, Val Loss = 0.274782


Epoch 123/300: 100%|██████████| 440/440 [00:02<00:00, 173.90it/s]


Epoch 123: Train Loss = 0.197848, Val Loss = 0.273708


Epoch 124/300: 100%|██████████| 440/440 [00:02<00:00, 162.17it/s]


Epoch 124: Train Loss = 0.197882, Val Loss = 0.273900


Epoch 125/300: 100%|██████████| 440/440 [00:02<00:00, 173.82it/s]


Epoch 125: Train Loss = 0.197767, Val Loss = 0.273912


Epoch 126/300: 100%|██████████| 440/440 [00:02<00:00, 173.58it/s]


Epoch 126: Train Loss = 0.197138, Val Loss = 0.274632


Epoch 127/300: 100%|██████████| 440/440 [00:02<00:00, 163.18it/s]


Epoch 127: Train Loss = 0.196931, Val Loss = 0.273982


Epoch 128/300: 100%|██████████| 440/440 [00:02<00:00, 170.98it/s]


Epoch 128: Train Loss = 0.196639, Val Loss = 0.273486


Epoch 129/300: 100%|██████████| 440/440 [00:02<00:00, 172.39it/s]


Epoch 129: Train Loss = 0.196391, Val Loss = 0.274048


Epoch 130/300: 100%|██████████| 440/440 [00:02<00:00, 161.91it/s]


Epoch 130: Train Loss = 0.196199, Val Loss = 0.274273


Epoch 131/300: 100%|██████████| 440/440 [00:02<00:00, 170.14it/s]


Epoch 131: Train Loss = 0.195793, Val Loss = 0.274197


Epoch 132/300: 100%|██████████| 440/440 [00:02<00:00, 174.77it/s]


Epoch 132: Train Loss = 0.195642, Val Loss = 0.273294


Epoch 133/300: 100%|██████████| 440/440 [00:02<00:00, 163.56it/s]


Epoch 133: Train Loss = 0.195393, Val Loss = 0.274457


Epoch 134/300: 100%|██████████| 440/440 [00:02<00:00, 174.52it/s]


Epoch 134: Train Loss = 0.195431, Val Loss = 0.273439


Epoch 135/300: 100%|██████████| 440/440 [00:02<00:00, 171.63it/s]


Epoch 135: Train Loss = 0.194908, Val Loss = 0.273888


Epoch 136/300: 100%|██████████| 440/440 [00:02<00:00, 174.62it/s]


Epoch 136: Train Loss = 0.194288, Val Loss = 0.273410


Epoch 137/300: 100%|██████████| 440/440 [00:02<00:00, 174.71it/s]


Epoch 137: Train Loss = 0.194343, Val Loss = 0.274455


Epoch 138/300: 100%|██████████| 440/440 [00:02<00:00, 171.01it/s]


Epoch 138: Train Loss = 0.193889, Val Loss = 0.273948


Epoch 139/300: 100%|██████████| 440/440 [00:02<00:00, 173.59it/s]


Epoch 139: Train Loss = 0.193632, Val Loss = 0.273226


Epoch 140/300: 100%|██████████| 440/440 [00:02<00:00, 159.62it/s]


Epoch 140: Train Loss = 0.193621, Val Loss = 0.274579


Epoch 141/300: 100%|██████████| 440/440 [00:02<00:00, 174.63it/s]


Epoch 141: Train Loss = 0.193126, Val Loss = 0.273977


Epoch 142/300: 100%|██████████| 440/440 [00:02<00:00, 173.38it/s]


Epoch 142: Train Loss = 0.192972, Val Loss = 0.274398


Epoch 143/300: 100%|██████████| 440/440 [00:02<00:00, 164.97it/s]


Epoch 143: Train Loss = 0.192913, Val Loss = 0.274202


Epoch 144/300: 100%|██████████| 440/440 [00:02<00:00, 175.55it/s]


Epoch 144: Train Loss = 0.192855, Val Loss = 0.274236


Epoch 145/300: 100%|██████████| 440/440 [00:02<00:00, 174.28it/s]


Epoch 145: Train Loss = 0.192111, Val Loss = 0.273046


Epoch 146/300: 100%|██████████| 440/440 [00:02<00:00, 165.56it/s]


Epoch 146: Train Loss = 0.192068, Val Loss = 0.273871


Epoch 147/300: 100%|██████████| 440/440 [00:02<00:00, 176.15it/s]


Epoch 147: Train Loss = 0.192415, Val Loss = 0.274258


Epoch 148/300: 100%|██████████| 440/440 [00:02<00:00, 172.57it/s]


Epoch 148: Train Loss = 0.191754, Val Loss = 0.273447


Epoch 149/300: 100%|██████████| 440/440 [00:02<00:00, 162.34it/s]


Epoch 149: Train Loss = 0.191567, Val Loss = 0.274236


Epoch 150/300: 100%|██████████| 440/440 [00:02<00:00, 170.86it/s]


Epoch 150: Train Loss = 0.191019, Val Loss = 0.274318


Epoch 151/300: 100%|██████████| 440/440 [00:02<00:00, 171.06it/s]


Epoch 151: Train Loss = 0.191158, Val Loss = 0.274728


Epoch 152/300: 100%|██████████| 440/440 [00:02<00:00, 159.70it/s]


Epoch 152: Train Loss = 0.190894, Val Loss = 0.273749


Epoch 153/300: 100%|██████████| 440/440 [00:02<00:00, 170.26it/s]


Epoch 153: Train Loss = 0.190479, Val Loss = 0.274088


Epoch 154/300: 100%|██████████| 440/440 [00:02<00:00, 174.26it/s]


Epoch 154: Train Loss = 0.190621, Val Loss = 0.273725


Epoch 155/300: 100%|██████████| 440/440 [00:02<00:00, 163.19it/s]


Epoch 155: Train Loss = 0.190053, Val Loss = 0.274134


Epoch 156/300: 100%|██████████| 440/440 [00:02<00:00, 171.68it/s]


Epoch 156: Train Loss = 0.189961, Val Loss = 0.273660


Epoch 157/300: 100%|██████████| 440/440 [00:02<00:00, 171.17it/s]


Epoch 157: Train Loss = 0.189991, Val Loss = 0.273815


Epoch 158/300: 100%|██████████| 440/440 [00:02<00:00, 163.23it/s]


Epoch 158: Train Loss = 0.189623, Val Loss = 0.274565


Epoch 159/300: 100%|██████████| 440/440 [00:02<00:00, 171.32it/s]


Epoch 159: Train Loss = 0.189339, Val Loss = 0.274232


Epoch 160/300: 100%|██████████| 440/440 [00:02<00:00, 170.79it/s]


Epoch 160: Train Loss = 0.189118, Val Loss = 0.274362


Epoch 161/300: 100%|██████████| 440/440 [00:02<00:00, 163.72it/s]


Epoch 161: Train Loss = 0.189021, Val Loss = 0.274422


Epoch 162/300: 100%|██████████| 440/440 [00:02<00:00, 172.42it/s]


Epoch 162: Train Loss = 0.188920, Val Loss = 0.274415


Epoch 163/300: 100%|██████████| 440/440 [00:02<00:00, 174.14it/s]


Epoch 163: Train Loss = 0.188637, Val Loss = 0.274744


Epoch 164/300: 100%|██████████| 440/440 [00:02<00:00, 162.79it/s]


Epoch 164: Train Loss = 0.188518, Val Loss = 0.274941


Epoch 165/300: 100%|██████████| 440/440 [00:02<00:00, 172.43it/s]


Epoch 165: Train Loss = 0.188437, Val Loss = 0.274207


Epoch 166/300: 100%|██████████| 440/440 [00:02<00:00, 174.74it/s]


Epoch 166: Train Loss = 0.188217, Val Loss = 0.274618


Epoch 167/300: 100%|██████████| 440/440 [00:02<00:00, 171.36it/s]


Epoch 167: Train Loss = 0.187680, Val Loss = 0.274301


Epoch 168/300: 100%|██████████| 440/440 [00:02<00:00, 172.76it/s]


Epoch 168: Train Loss = 0.187123, Val Loss = 0.274511


Epoch 169/300: 100%|██████████| 440/440 [00:02<00:00, 173.88it/s]


Epoch 169: Train Loss = 0.187583, Val Loss = 0.275208


Epoch 170/300: 100%|██████████| 440/440 [00:02<00:00, 169.78it/s]


Epoch 170: Train Loss = 0.187648, Val Loss = 0.275118


Epoch 171/300: 100%|██████████| 440/440 [00:02<00:00, 163.94it/s]


Epoch 171: Train Loss = 0.187247, Val Loss = 0.275065


Epoch 172/300: 100%|██████████| 440/440 [00:02<00:00, 174.57it/s]


Epoch 172: Train Loss = 0.187397, Val Loss = 0.275022


Epoch 173/300: 100%|██████████| 440/440 [00:02<00:00, 172.42it/s]


Epoch 173: Train Loss = 0.186645, Val Loss = 0.274946


Epoch 174/300: 100%|██████████| 440/440 [00:02<00:00, 164.08it/s]


Epoch 174: Train Loss = 0.186688, Val Loss = 0.274709


Epoch 175/300: 100%|██████████| 440/440 [00:02<00:00, 173.46it/s]


Epoch 175: Train Loss = 0.186567, Val Loss = 0.274616


Epoch 176/300: 100%|██████████| 440/440 [00:02<00:00, 174.66it/s]


Epoch 176: Train Loss = 0.186140, Val Loss = 0.275283


Epoch 177/300: 100%|██████████| 440/440 [00:02<00:00, 165.40it/s]


Epoch 177: Train Loss = 0.186167, Val Loss = 0.274639


Epoch 178/300: 100%|██████████| 440/440 [00:02<00:00, 172.29it/s]


Epoch 178: Train Loss = 0.185763, Val Loss = 0.274978


Epoch 179/300: 100%|██████████| 440/440 [00:02<00:00, 174.96it/s]


Epoch 179: Train Loss = 0.185992, Val Loss = 0.274454


Epoch 180/300: 100%|██████████| 440/440 [00:02<00:00, 164.84it/s]


Epoch 180: Train Loss = 0.185264, Val Loss = 0.274981


Epoch 181/300: 100%|██████████| 440/440 [00:02<00:00, 174.36it/s]


Epoch 181: Train Loss = 0.185369, Val Loss = 0.275546


Epoch 182/300: 100%|██████████| 440/440 [00:02<00:00, 175.46it/s]


Epoch 182: Train Loss = 0.185313, Val Loss = 0.274635


Epoch 183/300: 100%|██████████| 440/440 [00:02<00:00, 167.01it/s]


Epoch 183: Train Loss = 0.185109, Val Loss = 0.275108


Epoch 184/300: 100%|██████████| 440/440 [00:02<00:00, 176.98it/s]


Epoch 184: Train Loss = 0.184627, Val Loss = 0.274706


Epoch 185/300: 100%|██████████| 440/440 [00:02<00:00, 175.52it/s]


Epoch 185: Train Loss = 0.184694, Val Loss = 0.274893


Epoch 186/300: 100%|██████████| 440/440 [00:02<00:00, 164.07it/s]


Epoch 186: Train Loss = 0.184308, Val Loss = 0.275030


Epoch 187/300: 100%|██████████| 440/440 [00:02<00:00, 173.18it/s]


Epoch 187: Train Loss = 0.184420, Val Loss = 0.275364


Epoch 188/300: 100%|██████████| 440/440 [00:02<00:00, 175.59it/s]


Epoch 188: Train Loss = 0.184199, Val Loss = 0.274567


Epoch 189/300: 100%|██████████| 440/440 [00:02<00:00, 162.03it/s]


Epoch 189: Train Loss = 0.184221, Val Loss = 0.274380


Epoch 190/300: 100%|██████████| 440/440 [00:02<00:00, 175.10it/s]


Epoch 190: Train Loss = 0.184021, Val Loss = 0.275217


Epoch 191/300: 100%|██████████| 440/440 [00:02<00:00, 175.52it/s]


Epoch 191: Train Loss = 0.183556, Val Loss = 0.275502


Epoch 192/300: 100%|██████████| 440/440 [00:02<00:00, 165.10it/s]


Epoch 192: Train Loss = 0.183332, Val Loss = 0.275333


Epoch 193/300: 100%|██████████| 440/440 [00:02<00:00, 174.23it/s]


Epoch 193: Train Loss = 0.183785, Val Loss = 0.274986


Epoch 194/300: 100%|██████████| 440/440 [00:02<00:00, 172.35it/s]


Epoch 194: Train Loss = 0.183538, Val Loss = 0.274862


Epoch 195/300: 100%|██████████| 440/440 [00:02<00:00, 171.22it/s]


Epoch 195: Train Loss = 0.183612, Val Loss = 0.274861


Epoch 196/300: 100%|██████████| 440/440 [00:02<00:00, 175.23it/s]


Epoch 196: Train Loss = 0.182672, Val Loss = 0.275478


Epoch 197/300: 100%|██████████| 440/440 [00:02<00:00, 171.30it/s]


Epoch 197: Train Loss = 0.183252, Val Loss = 0.274848


Epoch 198/300: 100%|██████████| 440/440 [00:02<00:00, 173.16it/s]


Epoch 198: Train Loss = 0.182740, Val Loss = 0.275405


Epoch 199/300: 100%|██████████| 440/440 [00:02<00:00, 162.77it/s]


Epoch 199: Train Loss = 0.182573, Val Loss = 0.275013


Epoch 200/300: 100%|██████████| 440/440 [00:02<00:00, 172.31it/s]


Epoch 200: Train Loss = 0.182514, Val Loss = 0.275686


Epoch 201/300: 100%|██████████| 440/440 [00:02<00:00, 176.66it/s]


Epoch 201: Train Loss = 0.182341, Val Loss = 0.275227


Epoch 202/300: 100%|██████████| 440/440 [00:02<00:00, 165.00it/s]


Epoch 202: Train Loss = 0.181789, Val Loss = 0.274410


Epoch 203/300: 100%|██████████| 440/440 [00:02<00:00, 172.86it/s]


Epoch 203: Train Loss = 0.182019, Val Loss = 0.275051


Epoch 204/300: 100%|██████████| 440/440 [00:02<00:00, 173.31it/s]


Epoch 204: Train Loss = 0.181802, Val Loss = 0.275524


Epoch 205/300: 100%|██████████| 440/440 [00:02<00:00, 164.62it/s]


Epoch 205: Train Loss = 0.181713, Val Loss = 0.275157


Epoch 206/300: 100%|██████████| 440/440 [00:02<00:00, 174.28it/s]


Epoch 206: Train Loss = 0.181685, Val Loss = 0.276141


Epoch 207/300: 100%|██████████| 440/440 [00:02<00:00, 172.99it/s]


Epoch 207: Train Loss = 0.181275, Val Loss = 0.275260


Epoch 208/300: 100%|██████████| 440/440 [00:02<00:00, 164.81it/s]


Epoch 208: Train Loss = 0.181523, Val Loss = 0.275803


Epoch 209/300: 100%|██████████| 440/440 [00:02<00:00, 178.00it/s]


Epoch 209: Train Loss = 0.180955, Val Loss = 0.275258


Epoch 210/300: 100%|██████████| 440/440 [00:02<00:00, 173.34it/s]


Epoch 210: Train Loss = 0.181270, Val Loss = 0.276131


Epoch 211/300: 100%|██████████| 440/440 [00:02<00:00, 162.37it/s]


Epoch 211: Train Loss = 0.181322, Val Loss = 0.275266


Epoch 212/300: 100%|██████████| 440/440 [00:02<00:00, 175.22it/s]


Epoch 212: Train Loss = 0.180778, Val Loss = 0.275628


Epoch 213/300: 100%|██████████| 440/440 [00:02<00:00, 173.86it/s]


Epoch 213: Train Loss = 0.180200, Val Loss = 0.275328


Epoch 214/300: 100%|██████████| 440/440 [00:02<00:00, 164.91it/s]


Epoch 214: Train Loss = 0.180355, Val Loss = 0.276232


Epoch 215/300: 100%|██████████| 440/440 [00:02<00:00, 171.71it/s]


Epoch 215: Train Loss = 0.180606, Val Loss = 0.274246


Epoch 216/300: 100%|██████████| 440/440 [00:02<00:00, 174.69it/s]


Epoch 216: Train Loss = 0.180440, Val Loss = 0.274669


Epoch 217/300: 100%|██████████| 440/440 [00:02<00:00, 163.66it/s]


Epoch 217: Train Loss = 0.180194, Val Loss = 0.274791


Epoch 218/300: 100%|██████████| 440/440 [00:02<00:00, 174.44it/s]


Epoch 218: Train Loss = 0.180117, Val Loss = 0.274770


Epoch 219/300: 100%|██████████| 440/440 [00:02<00:00, 172.42it/s]


Epoch 219: Train Loss = 0.180267, Val Loss = 0.275035


Epoch 220/300: 100%|██████████| 440/440 [00:02<00:00, 162.02it/s]


Epoch 220: Train Loss = 0.180095, Val Loss = 0.275153


Epoch 221/300: 100%|██████████| 440/440 [00:02<00:00, 175.81it/s]


Epoch 221: Train Loss = 0.179576, Val Loss = 0.276217


Epoch 222/300: 100%|██████████| 440/440 [00:02<00:00, 172.04it/s]


Epoch 222: Train Loss = 0.179811, Val Loss = 0.275517


Epoch 223/300: 100%|██████████| 440/440 [00:02<00:00, 162.33it/s]


Epoch 223: Train Loss = 0.179373, Val Loss = 0.275920


Epoch 224/300: 100%|██████████| 440/440 [00:02<00:00, 173.92it/s]


Epoch 224: Train Loss = 0.179593, Val Loss = 0.275551


Epoch 225/300: 100%|██████████| 440/440 [00:02<00:00, 173.66it/s]


Epoch 225: Train Loss = 0.179144, Val Loss = 0.274653


Epoch 226/300: 100%|██████████| 440/440 [00:02<00:00, 172.71it/s]


Epoch 226: Train Loss = 0.179341, Val Loss = 0.275682


Epoch 227/300: 100%|██████████| 440/440 [00:02<00:00, 174.00it/s]


Epoch 227: Train Loss = 0.179092, Val Loss = 0.275291


Epoch 228/300: 100%|██████████| 440/440 [00:02<00:00, 171.99it/s]


Epoch 228: Train Loss = 0.179159, Val Loss = 0.275362


Epoch 229/300: 100%|██████████| 440/440 [00:02<00:00, 173.27it/s]


Epoch 229: Train Loss = 0.178460, Val Loss = 0.276175


Epoch 230/300: 100%|██████████| 440/440 [00:02<00:00, 162.73it/s]


Epoch 230: Train Loss = 0.178476, Val Loss = 0.275313


Epoch 231/300: 100%|██████████| 440/440 [00:02<00:00, 172.87it/s]


Epoch 231: Train Loss = 0.178555, Val Loss = 0.275928


Epoch 232/300: 100%|██████████| 440/440 [00:02<00:00, 173.48it/s]


Epoch 232: Train Loss = 0.178958, Val Loss = 0.274887


Epoch 233/300: 100%|██████████| 440/440 [00:02<00:00, 162.87it/s]


Epoch 233: Train Loss = 0.178139, Val Loss = 0.274665


Epoch 234/300: 100%|██████████| 440/440 [00:02<00:00, 173.97it/s]


Epoch 234: Train Loss = 0.178129, Val Loss = 0.275777


Epoch 235/300: 100%|██████████| 440/440 [00:02<00:00, 176.40it/s]


Epoch 235: Train Loss = 0.177864, Val Loss = 0.275047


Epoch 236/300: 100%|██████████| 440/440 [00:02<00:00, 164.72it/s]


Epoch 236: Train Loss = 0.178110, Val Loss = 0.275171


Epoch 237/300: 100%|██████████| 440/440 [00:02<00:00, 173.92it/s]


Epoch 237: Train Loss = 0.177937, Val Loss = 0.274518


Epoch 238/300: 100%|██████████| 440/440 [00:02<00:00, 176.73it/s]


Epoch 238: Train Loss = 0.177841, Val Loss = 0.275217


Epoch 239/300: 100%|██████████| 440/440 [00:02<00:00, 165.26it/s]


Epoch 239: Train Loss = 0.177804, Val Loss = 0.275159


Epoch 240/300: 100%|██████████| 440/440 [00:02<00:00, 173.63it/s]


Epoch 240: Train Loss = 0.177209, Val Loss = 0.275048


Epoch 241/300: 100%|██████████| 440/440 [00:02<00:00, 176.26it/s]


Epoch 241: Train Loss = 0.177489, Val Loss = 0.276051


Epoch 242/300: 100%|██████████| 440/440 [00:02<00:00, 164.30it/s]


Epoch 242: Train Loss = 0.177368, Val Loss = 0.275700


Epoch 243/300: 100%|██████████| 440/440 [00:02<00:00, 175.68it/s]


Epoch 243: Train Loss = 0.177220, Val Loss = 0.275860


Epoch 244/300: 100%|██████████| 440/440 [00:02<00:00, 175.04it/s]


Epoch 244: Train Loss = 0.177204, Val Loss = 0.275676


Epoch 245/300: 100%|██████████| 440/440 [00:02<00:00, 167.66it/s]


Epoch 245: Train Loss = 0.176966, Val Loss = 0.275287


Epoch 246/300: 100%|██████████| 440/440 [00:02<00:00, 178.05it/s]


Epoch 246: Train Loss = 0.176658, Val Loss = 0.275163


Epoch 247/300: 100%|██████████| 440/440 [00:02<00:00, 176.16it/s]


Epoch 247: Train Loss = 0.176492, Val Loss = 0.274846


Epoch 248/300: 100%|██████████| 440/440 [00:02<00:00, 162.33it/s]


Epoch 248: Train Loss = 0.176621, Val Loss = 0.275397


Epoch 249/300: 100%|██████████| 440/440 [00:02<00:00, 177.09it/s]


Epoch 249: Train Loss = 0.176708, Val Loss = 0.275696


Epoch 250/300: 100%|██████████| 440/440 [00:02<00:00, 176.28it/s]


Epoch 250: Train Loss = 0.176915, Val Loss = 0.275290


Epoch 251/300: 100%|██████████| 440/440 [00:02<00:00, 165.89it/s]


Epoch 251: Train Loss = 0.176655, Val Loss = 0.275143


Epoch 252/300: 100%|██████████| 440/440 [00:02<00:00, 175.50it/s]


Epoch 252: Train Loss = 0.175922, Val Loss = 0.275432


Epoch 253/300: 100%|██████████| 440/440 [00:02<00:00, 176.47it/s]


Epoch 253: Train Loss = 0.176430, Val Loss = 0.275438


Epoch 254/300: 100%|██████████| 440/440 [00:02<00:00, 177.83it/s]


Epoch 254: Train Loss = 0.176222, Val Loss = 0.275315


Epoch 255/300: 100%|██████████| 440/440 [00:02<00:00, 175.97it/s]


Epoch 255: Train Loss = 0.176200, Val Loss = 0.275232


Epoch 256/300: 100%|██████████| 440/440 [00:02<00:00, 177.48it/s]


Epoch 256: Train Loss = 0.175610, Val Loss = 0.275252


Epoch 257/300: 100%|██████████| 440/440 [00:02<00:00, 176.81it/s]


Epoch 257: Train Loss = 0.175954, Val Loss = 0.275723


Epoch 258/300: 100%|██████████| 440/440 [00:02<00:00, 166.87it/s]


Epoch 258: Train Loss = 0.175523, Val Loss = 0.275004


Epoch 259/300: 100%|██████████| 440/440 [00:02<00:00, 174.22it/s]


Epoch 259: Train Loss = 0.175199, Val Loss = 0.276166


Epoch 260/300: 100%|██████████| 440/440 [00:02<00:00, 174.14it/s]


Epoch 260: Train Loss = 0.175189, Val Loss = 0.275445


Epoch 261/300: 100%|██████████| 440/440 [00:02<00:00, 163.20it/s]


Epoch 261: Train Loss = 0.175511, Val Loss = 0.276108


Epoch 262/300: 100%|██████████| 440/440 [00:02<00:00, 173.94it/s]


Epoch 262: Train Loss = 0.175515, Val Loss = 0.275243


Epoch 263/300: 100%|██████████| 440/440 [00:02<00:00, 174.19it/s]


Epoch 263: Train Loss = 0.174950, Val Loss = 0.274849


Epoch 264/300: 100%|██████████| 440/440 [00:02<00:00, 165.26it/s]


Epoch 264: Train Loss = 0.175191, Val Loss = 0.275302


Epoch 265/300: 100%|██████████| 440/440 [00:02<00:00, 174.43it/s]


Epoch 265: Train Loss = 0.174977, Val Loss = 0.275771


Epoch 266/300: 100%|██████████| 440/440 [00:02<00:00, 174.73it/s]


Epoch 266: Train Loss = 0.175355, Val Loss = 0.275488


Epoch 267/300: 100%|██████████| 440/440 [00:02<00:00, 165.09it/s]


Epoch 267: Train Loss = 0.174945, Val Loss = 0.276152


Epoch 268/300: 100%|██████████| 440/440 [00:02<00:00, 174.87it/s]


Epoch 268: Train Loss = 0.174900, Val Loss = 0.275727


Epoch 269/300: 100%|██████████| 440/440 [00:02<00:00, 173.77it/s]


Epoch 269: Train Loss = 0.174671, Val Loss = 0.275442


Epoch 270/300: 100%|██████████| 440/440 [00:02<00:00, 164.88it/s]


Epoch 270: Train Loss = 0.174537, Val Loss = 0.275245


Epoch 271/300: 100%|██████████| 440/440 [00:02<00:00, 177.58it/s]


Epoch 271: Train Loss = 0.174342, Val Loss = 0.275096


Epoch 272/300: 100%|██████████| 440/440 [00:02<00:00, 174.32it/s]


Epoch 272: Train Loss = 0.174235, Val Loss = 0.275693


Epoch 273/300: 100%|██████████| 440/440 [00:02<00:00, 165.04it/s]


Epoch 273: Train Loss = 0.174220, Val Loss = 0.275790


Epoch 274/300: 100%|██████████| 440/440 [00:02<00:00, 171.25it/s]


Epoch 274: Train Loss = 0.174552, Val Loss = 0.275480


Epoch 275/300: 100%|██████████| 440/440 [00:02<00:00, 173.66it/s]


Epoch 275: Train Loss = 0.173882, Val Loss = 0.275632


Epoch 276/300: 100%|██████████| 440/440 [00:02<00:00, 162.67it/s]


Epoch 276: Train Loss = 0.174547, Val Loss = 0.275214


Epoch 277/300: 100%|██████████| 440/440 [00:02<00:00, 173.53it/s]


Epoch 277: Train Loss = 0.173867, Val Loss = 0.275976


Epoch 278/300: 100%|██████████| 440/440 [00:02<00:00, 173.33it/s]


Epoch 278: Train Loss = 0.174072, Val Loss = 0.275477


Epoch 279/300: 100%|██████████| 440/440 [00:02<00:00, 165.40it/s]


Epoch 279: Train Loss = 0.173437, Val Loss = 0.275631


Epoch 280/300: 100%|██████████| 440/440 [00:02<00:00, 175.10it/s]


Epoch 280: Train Loss = 0.173686, Val Loss = 0.276042


Epoch 281/300: 100%|██████████| 440/440 [00:02<00:00, 172.76it/s]


Epoch 281: Train Loss = 0.173680, Val Loss = 0.275353


Epoch 282/300: 100%|██████████| 440/440 [00:02<00:00, 164.86it/s]


Epoch 282: Train Loss = 0.173453, Val Loss = 0.275050


Epoch 283/300: 100%|██████████| 440/440 [00:02<00:00, 175.30it/s]


Epoch 283: Train Loss = 0.173497, Val Loss = 0.275135


Epoch 284/300: 100%|██████████| 440/440 [00:02<00:00, 174.45it/s]


Epoch 284: Train Loss = 0.173562, Val Loss = 0.275571


Epoch 285/300: 100%|██████████| 440/440 [00:02<00:00, 171.71it/s]


Epoch 285: Train Loss = 0.173292, Val Loss = 0.275394


Epoch 286/300: 100%|██████████| 440/440 [00:02<00:00, 174.13it/s]


Epoch 286: Train Loss = 0.173177, Val Loss = 0.275209


Epoch 287/300: 100%|██████████| 440/440 [00:02<00:00, 175.42it/s]


Epoch 287: Train Loss = 0.173307, Val Loss = 0.275409


Epoch 288/300: 100%|██████████| 440/440 [00:02<00:00, 171.44it/s]


Epoch 288: Train Loss = 0.173332, Val Loss = 0.275992


Epoch 289/300: 100%|██████████| 440/440 [00:02<00:00, 166.05it/s]


Epoch 289: Train Loss = 0.172828, Val Loss = 0.275305


Epoch 290/300: 100%|██████████| 440/440 [00:02<00:00, 175.02it/s]


Epoch 290: Train Loss = 0.172747, Val Loss = 0.276085


Epoch 291/300: 100%|██████████| 440/440 [00:02<00:00, 173.13it/s]


Epoch 291: Train Loss = 0.173067, Val Loss = 0.275860


Epoch 292/300: 100%|██████████| 440/440 [00:02<00:00, 162.38it/s]


Epoch 292: Train Loss = 0.172558, Val Loss = 0.275644


Epoch 293/300: 100%|██████████| 440/440 [00:02<00:00, 174.96it/s]


Epoch 293: Train Loss = 0.172647, Val Loss = 0.275778


Epoch 294/300: 100%|██████████| 440/440 [00:02<00:00, 175.62it/s]


Epoch 294: Train Loss = 0.172608, Val Loss = 0.275187


Epoch 295/300: 100%|██████████| 440/440 [00:02<00:00, 164.84it/s]


Epoch 295: Train Loss = 0.172259, Val Loss = 0.275610


Epoch 296/300: 100%|██████████| 440/440 [00:02<00:00, 172.91it/s]


Epoch 296: Train Loss = 0.172228, Val Loss = 0.275931


Epoch 297/300: 100%|██████████| 440/440 [00:02<00:00, 173.74it/s]


Epoch 297: Train Loss = 0.172591, Val Loss = 0.275632


Epoch 298/300: 100%|██████████| 440/440 [00:02<00:00, 163.60it/s]


Epoch 298: Train Loss = 0.172375, Val Loss = 0.275699


Epoch 299/300: 100%|██████████| 440/440 [00:02<00:00, 172.75it/s]


Epoch 299: Train Loss = 0.172019, Val Loss = 0.275631


Epoch 300/300: 100%|██████████| 440/440 [00:02<00:00, 176.17it/s]


Epoch 300: Train Loss = 0.171668, Val Loss = 0.276042


MLPTranslator(
  (net): Sequential(
    (0): Linear(in_features=1000, out_features=1200, bias=True)
    (1): LayerNorm((1200,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.15, inplace=False)
    (4): Linear(in_features=1200, out_features=600, bias=True)
    (5): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
    (6): ReLU()
    (7): Dropout(p=0.15, inplace=False)
    (8): Linear(in_features=600, out_features=300, bias=True)
    (9): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    (10): ReLU()
    (11): Dropout(p=0.15, inplace=False)
    (12): Linear(in_features=300, out_features=1000, bias=True)
  )
)

In [51]:
torch.save({
    'model_state_dict': model.state_dict(),
    'x_anchors': x_anchors,
    'y_anchors': y_anchors
}, 'translator_with_anchors.pth')

In [52]:
with torch.no_grad():
    y_rel_pred = model(x_rel.to(DEVICE)).cpu()  # output del modello nello spazio relativo immagini

# Cosine similarity tra y_rel e y_rel_pred
cos_sim = F.cosine_similarity(y_rel_pred, y_rel, dim=1)
print("Media cosine similarity nello spazio relativo:", cos_sim.mean().item())
print("Cosine similarity min/max:", cos_sim.min().item(), cos_sim.max().item())

Media cosine similarity nello spazio relativo: 0.8538557291030884
Cosine similarity min/max: -0.709624707698822 0.9906015992164612


In [53]:
pseudo_inverse = torch.linalg.pinv(y_anchors.T).to(DEVICE)


def generate_submission(model, test_path: Path, output_file="submission.csv", device=None):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_data.close()
    
    test_embds = torch.from_numpy(test_embds).float()
    test_embds = normalize(center(test_embds)) # like training

    with torch.no_grad():
        x_rel_test = (test_embds @ x_anchors.T).to(DEVICE)
        pred_embds = model(x_rel_test) @ pseudo_inverse

    print('Y abs reconstructed shape:', pred_embds.shape)

    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission


generate_submission(model, Path(test_path), device=DEVICE)

Y abs reconstructed shape: torch.Size([1500, 1536])
Generating submission file...
✓ Saved submission to submission.csv


Unnamed: 0,id,embedding
0,1,"[0.04971304163336754, 0.003106743097305298, -0..."
1,2,"[-0.022800683975219727, -0.03812825679779053, ..."
2,3,"[-0.007564276456832886, -0.053823813796043396,..."
3,4,"[-0.0014978647232055664, 0.04147624969482422, ..."
4,5,"[0.04639643058180809, -0.0068305134773254395, ..."
...,...,...
1495,1496,"[0.006414793431758881, -0.023131832480430603, ..."
1496,1497,"[0.0017552375793457031, 0.07606494426727295, -..."
1497,1498,"[0.012645125389099121, 0.0034373849630355835, ..."
1498,1499,"[0.02855837345123291, 0.0019985437393188477, -..."
