In [16]:
import torch
import torch.nn as nn
import random
import pandas as pd
from tqdm import tqdm

import numpy as np
import datetime
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from pathlib import Path

In [17]:
# Use the GPU when the kernel can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', DEVICE)

# Locations of the competition archives inside the Kaggle input mount.
data_path= Path('/kaggle/input/aml-competition/train/train/train.npz')
test_path= Path('/kaggle/input/aml-competition/test/test/test.clean.npz')

# The context manager closes the archive even if one of the key lookups raises;
# the arrays read inside the block stay valid after it is closed.
with np.load(data_path) as data:
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']

x_abs = torch.from_numpy(caption_embeddings) # captions space
# For every caption, take the image whose label column is largest, so that row i
# of y_abs is the image paired with caption i.
# NOTE(review): presumably each row of caption_labels is one-hot — confirm.
y_abs = torch.from_numpy(image_embeddings[np.argmax(caption_labels, axis=1)]) # images space

print('Captions latent space:', x_abs.shape)
print('Images latent space:', y_abs.shape)

Device: cuda
Captions latent space: torch.Size([125000, 1024])
Images latent space: torch.Size([125000, 1536])


In [18]:
def center(data: torch.Tensor):
    """Remove the per-column mean so every feature averages to zero over dim 0."""
    column_means = torch.mean(data, dim=0, keepdim=True)
    return data - column_means


def normalize(data: torch.Tensor):
    """Rescale every row (dim 1) of ``data`` to unit L2 norm."""
    unit_rows = F.normalize(data, dim=1, p=2)
    return unit_rows


def extract_anchors(x_abs: torch.Tensor, y_abs: torch.Tensor, anchors_number: int, generator=None):
    """Draw the same random rows from two row-aligned tensors to use as anchors.

    Args:
        x_abs: First tensor; its number of rows defines the sampling range.
        y_abs: Second tensor, indexed with the same row indices as ``x_abs``.
        anchors_number: Number of rows to draw (without replacement). If it
            exceeds ``x_abs.size(0)``, all rows are returned in random order.
        generator: Optional ``torch.Generator`` that makes the draw
            reproducible. ``None`` (the default) uses the global torch RNG,
            which is the previous behaviour.

    Returns:
        Tuple ``(x_anchors, y_anchors)`` holding the selected rows of each input.
    """
    # One permutation is shared by both tensors so the anchor pairs stay aligned.
    indices = torch.randperm(x_abs.size(0), generator=generator)[:anchors_number]

    return x_abs[indices], y_abs[indices]


def releative_representation(data: torch.Tensor, anchors: torch.Tensor):
    """Cosine similarity of every row of ``data`` to every anchor row.

    The (misspelled) name is kept unchanged for existing callers.

    Args:
        data: Tensor of shape (n_samples, dim).
        anchors: Tensor of shape (n_anchors, dim).

    Returns:
        Tensor of shape (n_samples, n_anchors) whose entry (i, j) is the cosine
        similarity between ``data[i]`` and ``anchors[j]``.
    """
    unit_data = F.normalize(data, p=2, dim=1)
    # Normalise each anchor vector BEFORE transposing. Normalising ``anchors.T``
    # along dim 1 would rescale every feature across anchors instead, which
    # does not yield unit-length anchors.
    unit_anchors = F.normalize(anchors, p=2, dim=1)
    return unit_data @ unit_anchors.T

def procrustes_align(X, Y, scale=True):
    """Fit a similarity transform (rotation, scale, translation) mapping X onto Y.

    Finds ``R``, ``s`` and ``t`` such that ``s * (X @ R) + t`` approximates ``Y``
    in the least-squares sense, with ``R`` constrained to be a proper rotation
    (orthogonal with determinant +1).

    Args:
        X: Source points, shape (n, d).
        Y: Target points, shape (n, d), row-aligned with ``X``.
        scale: If True, also estimate an isotropic scale; otherwise ``s`` is 1.0.

    Returns:
        Tuple ``(R, s, t)``: ``R`` is a (d, d) rotation matrix, ``s`` is a 0-dim
        tensor (or the float 1.0 when ``scale`` is False), and ``t`` is the
        translation, broadcastable against ``X @ R``.
    """
    mu_X = X.mean(dim=0, keepdim=True)
    mu_Y = Y.mean(dim=0, keepdim=True)

    X_centered = X - mu_X
    Y_centered = Y - mu_Y

    # Cross-covariance between the centred point sets.
    C = X_centered.T @ Y_centered

    U, S, Vt = torch.linalg.svd(C, full_matrices=True)
    R = U @ Vt

    if torch.det(R) < 0:
        # The unconstrained optimum is a reflection. Flip the direction tied to
        # the smallest singular value to obtain the closest proper rotation.
        Vt[-1, :] *= -1
        # That singular value now contributes with a negative sign to
        # trace(R^T C); the scale below must use the signed values, otherwise
        # it is over-estimated for the corrected rotation.
        S = S.clone()
        S[-1] *= -1
        R = U @ Vt

    if scale:
        # Optimal isotropic scale for the chosen R: trace(R^T C) / ||X_centered||^2.
        s = S.sum() / (X_centered ** 2).sum()
    else:
        s = 1.0

    # Translation that maps the transformed source centroid onto the target centroid.
    t = mu_Y.squeeze() - s * (mu_X @ R)

    return R, s, t

def align_matrix(m: torch.Tensor, R, s, t):
    """Apply a fitted similarity transform to the rows of ``m``.

    Rotates with ``R``, multiplies by the scale ``s`` and then adds the
    translation ``t``.
    """
    rotated = m @ R
    scaled = s * rotated
    return scaled + t

In [19]:
anchors_number = 1000
ANCHOR_SEED = 42

# Anchor selection is random. Seed the global torch RNG so that the anchors,
# the saved checkpoint and the submission are identical on every
# Restart & Run All.
torch.manual_seed(ANCHOR_SEED)

# Centre each space and put every embedding on the unit sphere.
x = normalize(center(x_abs))
y = normalize(center(y_abs))

x_anchors, y_anchors = extract_anchors(x, y, anchors_number)

# Relative representations: dot products with the anchors. Because both the
# data rows and the anchors are unit-norm, these are cosine similarities.
x_rel = x @ x_anchors.T
y_rel = y @ y_anchors.T

# A rank below anchors_number means some anchors are linearly dependent.
print('Rank X anchors', torch.linalg.matrix_rank(x_anchors).item())
print('Rank Y anchors', torch.linalg.matrix_rank(y_anchors).item())

# Fit the similarity transform that maps the caption-relative space onto the
# image-relative space.
R, s, t = procrustes_align(x_rel, y_rel)

x_rel_aligned = align_matrix(x_rel, R, s, t)

print("Mean squared distance after alignment:", ((x_rel_aligned - y_rel) ** 2).mean().item())

Rank X anchors 1000
Rank Y anchors 988
Mean squared distance after alignment: 0.009509388357400894


In [20]:
# Persist the anchor sets together with the fitted Procrustes parameters so
# the alignment can be reloaded without refitting.
torch.save(
    dict(
        zip(
            ('x_anchors', 'y_anchors', 'procrustes_R', 'procrustes_s', 'procrustes_t'),
            (x_anchors, y_anchors, R, s, t),
        )
    ),
    'anchors_and_procrustes.pth',
)

In [21]:
def generate_submission(test_path: Path, output_file_name="submission", train_mean=None):
    """Predict image-space embeddings for the test captions and write a CSV.

    Relies on the notebook-level ``x_anchors``, ``y_anchors``, ``R``, ``s`` and
    ``t`` produced by the anchor/Procrustes fitting cell.

    Args:
        test_path: Path to the test ``.npz`` archive; it must contain the keys
            ``captions/ids`` and ``captions/embeddings``.
        output_file_name: Prefix of the CSV file; a timestamp is appended.
        train_mean: Optional per-feature mean used to centre the test
            captions (e.g. ``x_abs.mean(dim=0, keepdim=True)``). When ``None``
            (the default, previous behaviour) the test set is centred with its
            own mean.

    Returns:
        ``numpy.ndarray`` with one predicted embedding per test caption.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    output_file = f'{output_file_name}--{timestamp}.csv'

    # The context manager closes the archive even if a key is missing.
    with np.load(test_path) as test_data:
        sample_ids = test_data['captions/ids']
        test_embds = test_data['captions/embeddings']

    test_embds = torch.from_numpy(test_embds).float()
    if train_mean is None:
        # NOTE(review): this centres with statistics of the TEST set, which
        # differ from the training mean used to build the anchors. Pass
        # ``train_mean`` to apply exactly the training-time preprocessing.
        test_embds = normalize(center(test_embds))
    else:
        offset = torch.as_tensor(train_mean, dtype=test_embds.dtype)
        test_embds = normalize(test_embds - offset)

    # Relative coordinates are dot products with the anchors (v @ anchors.T),
    # so the pseudo-inverse of y_anchors.T maps them back to the anchors' space.
    pseudo_inverse = torch.linalg.pinv(y_anchors.T)

    x_rel_test = (test_embds @ x_anchors.T)
    x_rel_test = align_matrix(x_rel_test, R, s, t)

    # NOTE(review): the result lives in whatever space y_anchors were taken
    # from; if those were centred/normalised, this is not the raw image space
    # despite the "abs" wording in the message below — confirm.
    pred_embds = x_rel_test @ pseudo_inverse

    print('Y abs reconstructed shape:', pred_embds.shape)

    print("Generating submission file...")

    # pred_embds is always a tensor here (product of tensors); convert once.
    pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return pred_embds


# Write the submission CSV for the test archive and keep the predicted
# embeddings (returned as a NumPy array) for inspection.
pred_embds = generate_submission(test_path)

Y abs reconstructed shape: torch.Size([1500, 1536])
Generating submission file...
✓ Saved submission to submission--2025-10-30_14-23-27.csv
