# SASP score

In [None]:
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from tgae import TGAE

import proteins_info
import fine_tune
import get_index

# Data Pre-Processing
This section is required.
If you are fine-tuning the model, run the Fine-tuning section next.
If not, jump directly to the Encode SASP score section.

In [None]:
# Read the cohort table. `clean_data` returns two values that the rest of the
# notebook uses as positional column selectors with `df.iloc`: one for the
# protein columns and one for the age column.
df = pd.read_csv("your_data.csv")
protein_idx, age_idx = proteins_info.clean_data(df, tunning=True)

# This step is only necessary for cohorts other than UKB.
# NOTE(review): presumably rescales the protein columns towards the UKB
# distribution -- confirm against `proteins_info.match_ukb_dist`.
protein_matched = proteins_info.match_ukb_dist(df.iloc[:, protein_idx])
df.iloc[:, protein_idx] = protein_matched

# Fine-tuning
Run this section only if you are fine-tuning the model; otherwise skip it entirely.

## Hyper-parameter

In [None]:
# Seed every random number generator this notebook touches with one value so
# that repeated runs draw the same random numbers.
seed = 42
for seed_fn in (
    random.seed,                 # Python random
    np.random.seed,              # NumPy
    torch.manual_seed,           # PyTorch CPU
    torch.cuda.manual_seed,      # PyTorch GPU
    torch.cuda.manual_seed_all,  # if multi-GPU
):
    seed_fn(seed)

# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Prepare data for fine-tuning

In [None]:

# Ask the helper for row positions of the training and validation subsets
# (60% / 20% of the rows, per the arguments below).
train_idx, valid_idx = fine_tune.split_df(n_rows=len(df), train_size=0.6, valid_size=0.2)

# Training subset: protein columns as inputs, age column as target.
train_proteins = df.iloc[train_idx, protein_idx]
train_age = df.iloc[train_idx, age_idx]
train_loader, proteins_embedding = fine_tune.prepare_data(
    train_sample=train_proteins,
    train_target=train_age,
    device=device,
)

# Validation subset, prepared through the same helper (its keyword names say
# "train" but the validation rows are what is passed here).
# NOTE(review): this call rebinds `proteins_embedding`, so the value used
# later is the one returned for the validation subset -- confirm that the
# embedding does not depend on which rows are passed in.
valid_proteins = df.iloc[valid_idx, protein_idx]
valid_age = df.iloc[valid_idx, age_idx]
valid_loader, proteins_embedding = fine_tune.prepare_data(
    train_sample=valid_proteins,
    train_target=valid_age,
    device=device,
)

## Fine-tuning

In [None]:
# Build the model with the same explicit architecture that the scoring cells
# further down use, so the network that is fine-tuned here is guaranteed to
# be the one that is later reloaded from "tgae_fine_tunned.pth" (relying on
# the class defaults could silently diverge, e.g. in the number of heads).
tgae = TGAE(
    d_model=128,
    latent_dim=6,
    nhead=8,
    num_layers=2,
).to(device)
# The checkpoint is read onto the CPU; load_state_dict then copies the values
# into the parameters that already live on `device`.
tgae.load_state_dict(torch.load("tgae_pre_trained.pth", map_location=torch.device('cpu')))

# Two parameter groups: the regressor sub-module, and every other parameter
# (selected by name prefix).
regressor_params = list(tgae.regressor.parameters())
other_params = [p for n, p in tgae.named_parameters() if not n.startswith("regressor")]

# The regressor group gets a 10x larger learning rate than the rest.
optimizer = torch.optim.Adam([
    {'params': regressor_params, 'lr': 1e-3},
    {'params': other_params, 'lr': 1e-4}
])

# Run the project fine-tuning loop; it returns the model together with the
# training and validation loss histories that the plotting cell uses.
tgae, train_losses, valid_losses = fine_tune.fine_tune(
    epochs=30,
    tgae=tgae,
    proteins_embedding=proteins_embedding,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
)

# --- Save the fine-tuned model weights ---
torch.save(tgae.state_dict(), "tgae_fine_tunned.pth")

## Plot losses

In [None]:
# Plot the training and validation loss histories from fine-tuning.
# The first recorded value of each history (index 0) is left out of the plot.
# NOTE(review): the x-axis therefore starts at 0 for the second recorded
# value -- confirm whether index 0 is a pre-training baseline or epoch 1.
train_curve = train_losses[1:]
valid_curve = valid_losses[1:]
epochs = range(len(train_curve))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, train_curve, marker='o', label='Training Error')
# These are the validation losses, so label them as such (the legend used to
# say "Testing Error", contradicting the title and the data plotted).
ax.plot(epochs, valid_curve, marker='o', label='Validation Error')

ax.set_title('Training and Validation Losses')
ax.set_xlabel('Epoch')
ax.set_ylabel('Losses')
ax.legend()
ax.grid(True)
plt.show()

# Encode SASP score

In [None]:
# Protein columns only (selected by position); this is the input handed to
# `get_index.gen_index` in the scoring cells below.
df_raw = df.iloc[:,protein_idx]

## SASP Score from parent model

In [None]:
# Load TGAE model and data
tgae = TGAE(
    d_model=128,
    latent_dim = 6,
    nhead=8,
    num_layers=2,
).to(device)
tgae.load_state_dict(torch.load("tgae_pre_trained.pth", map_location=torch.device('cpu')))

sasp_raw = get_index.gen_index(
    raw_df=df_raw,
    tgae=tgae,
    device=device
)

sasp_raw = pd.DataFrame({'sasp_score_raw': sasp_raw})
df_combined = pd.concat([ID, Age, sasp_raw], axis=1)
df_combined.to_csv("sasp_score.csv", index=False)

## SASP score from fine-tuned model
Skip this section if you did not fine-tune the model.

In [None]:
# Load TGAE model and data
tgae = TGAE(
    d_model=128,
    latent_dim = 6,
    nhead=8,
    num_layers=2,
).to(device)
tgae.load_state_dict(torch.load("tgae_fine_tunned.pth", map_location=torch.device('cpu')))

sasp_tuned = get_index.gen_index(
    raw_df=df_raw,
    tgae=tgae,
    device=device
)

sasp_tuned = pd.DataFrame({'sasp_score_tuned': sasp_tuned})
df_combined = pd.concat([ID, Age, sasp_tuned], axis=1)
df_combined.to_csv("sasp_score.csv", index=False)

In [None]:
# Check correlation between age and each SASP score.
# NOTE(review): assumes the first two columns of `df` are the sample
# identifier and the age -- confirm for your input file.
ID = df.iloc[:,0]
Age = df.iloc[:,1]

# By this point the scores are one-column DataFrames of shape (n, 1).
# np.corrcoef needs both inputs as 1-D vectors of equal length (passing the
# (n, 1) frame directly raises a ValueError), so flatten them first.
print(np.corrcoef(Age, np.ravel(sasp_raw)))

# The fine-tuned score only exists when the optional fine-tuning section was
# run; skip its correlation otherwise instead of failing with a NameError.
tuned_scores = globals().get("sasp_tuned")
if tuned_scores is not None:
    print(np.corrcoef(Age, np.ravel(tuned_scores)))