In [None]:
# ==========================================
# 0. Imports & Device
# ==========================================
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Reproducibility (optional): seed torch first, then numpy, both with 42.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

In [None]:
# ==========================================
# 1. Load Data (train + validation)
# ==========================================
def _read_labels(path):
    """Read a label CSV in a single pass.

    low_memory=False makes pandas infer each column's dtype from the whole file,
    which suppresses the mixed-dtype (DtypeWarning) chunked inference.
    """
    return pd.read_csv(path, low_memory=False)


train_seq = pd.read_csv('train_sequences.csv')
val_seq   = pd.read_csv('validation_sequences.csv')

train_labels = _read_labels('train_labels.csv')
val_labels   = _read_labels('validation_labels.csv')

print(f"Train seq shape: {train_seq.shape}")
print(f"Val   seq shape: {val_seq.shape}")
print(f"Train labels shape: {train_labels.shape}")
print(f"Val   labels shape: {val_labels.shape}")


In [None]:
# ==========================================
# 2. Tokenize sequences (A,C,G,U -> 1..4, PAD=0)
#    Any other character also maps to 0, i.e. it shares the PAD id.
# ==========================================
mapping = {'A': 1, 'C': 2, 'G': 3, 'U': 4}

def tokenize(seq: str) -> list:
    """Convert a nucleotide string into integer tokens (A=1, C=2, G=3, U=4).

    Bases are upper-cased first, so lowercase input is tokenized rather than
    silently turned into padding. Characters outside A/C/G/U map to 0 (PAD).
    A missing sequence (None or a float NaN, as pandas yields for empty cells)
    returns an empty list instead of tokenizing the text 'nan'.
    """
    if seq is None or (isinstance(seq, float) and seq != seq):
        return []
    return [mapping.get(base, 0) for base in str(seq).upper()]

# Normalize the join key and attach token lists to both sequence tables (mutates them in place).
for frame in (train_seq, val_seq):
    frame['target_id'] = frame['target_id'].astype(str).str.strip()
    frame['tokenized'] = frame['sequence'].map(tokenize)

print("Tokenized train/val sequences.")


In [None]:
# ==========================================
# 3. Label preprocessing (train + validation)
#    - train_labels: a single structure per residue in (x_1, y_1, z_1)
#    - val_labels: up to K structure slots (x_1..z_K); missing values are stored as
#      huge sentinels such as -1e18, so np.isfinite alone cannot detect them and an
#      abs(value) < ABS_THRESH check is added.
# ==========================================
XYZ = ['x_1','y_1','z_1']
ABS_THRESH = 1e17      # rejects sentinel "missing" values such as -1e18
MIN_VALID_POINTS = 30  # drop targets with too few valid residues (stability)

_OUT_COLUMNS = ['target_id', 'coordinates', 'coord_mask']


def _clean_base(df: pd.DataFrame) -> pd.DataFrame:
    """Derive target_id from ID, coerce resid to int, and sort by (target_id, resid).

    target_id is everything in ID before the last underscore. Rows whose resid
    is not numeric are dropped. The input frame is not modified.
    """
    df = df.copy()
    df['target_id'] = df['ID'].astype(str).str.rsplit('_', n=1).str[0]
    df['resid'] = pd.to_numeric(df['resid'], errors='coerce')
    df = df.dropna(subset=['resid']).copy()
    df['resid'] = df['resid'].astype(int)
    return df.sort_values(['target_id', 'resid'])


def _valid_xyz_rows(arr: np.ndarray) -> np.ndarray:
    """Return a (R,) bool mask: all three coordinates finite and |value| < ABS_THRESH."""
    return np.isfinite(arr).all(axis=1) & (np.abs(arr) < ABS_THRESH).all(axis=1)


def _per_target_lists(df: pd.DataFrame) -> pd.DataFrame:
    """Keep targets with >= MIN_VALID_POINTS valid residues and collect per-target lists.

    Expects float32 columns x_1/y_1/z_1 (zero where missing) and coord_ok (0/1),
    with rows already sorted by (target_id, resid). Returns a frame with columns
    target_id, coordinates (list of [x, y, z]) and coord_mask (list of 0/1 floats),
    ordered by target_id. An input with no surviving targets gives an empty frame
    that still has those columns.
    """
    valid_counts = df.groupby('target_id')['coord_ok'].sum()
    good_ids = valid_counts[valid_counts >= MIN_VALID_POINTS].index
    df = df[df['target_id'].isin(good_ids)]

    # groupby keeps the original row order inside each group, i.e. resid order.
    records = [
        {
            'target_id': tid,
            'coordinates': grp[XYZ].to_numpy(np.float32).tolist(),
            'coord_mask': grp['coord_ok'].to_numpy(np.float32).tolist(),
        }
        for tid, grp in df.groupby('target_id', sort=True)
    ]
    return pd.DataFrame(records, columns=_OUT_COLUMNS)


def build_coords_from_train_labels(train_labels: pd.DataFrame) -> pd.DataFrame:
    """Build per-target coordinate and mask lists from single-structure train labels.

    Residues with missing or sentinel coordinates are zero-filled and get
    coord_mask = 0, so the loss can ignore them.
    """
    df = _clean_base(train_labels)

    for c in XYZ:
        df[c] = pd.to_numeric(df[c], errors='coerce')
    ok = _valid_xyz_rows(df[XYZ].to_numpy(dtype=np.float64))
    df['coord_ok'] = ok.astype(np.float32)

    # Missing values become 0 and are excluded through coord_ok.
    df.loc[~ok, XYZ] = 0.0
    df[XYZ] = df[XYZ].astype(np.float32)

    return _per_target_lists(df)


def build_coords_from_val_labels(val_labels: pd.DataFrame, K: int = 40) -> pd.DataFrame:
    """Build per-target coordinate and mask lists from multi-slot validation labels.

    For each target, pick the single slot k in 1..K with the most valid residues
    (lowest k on ties). Use that slot for every residue of the target, and mask
    out the residues missing in it. Choosing a slot per residue could stitch
    together coordinates from different structures, possibly in different
    frames. That would break the rigid-body assumption behind Kabsch alignment.
    Slots whose x/y/z columns are absent are skipped.
    """
    df = _clean_base(val_labels)
    n_rows = len(df)

    slot_xyz, slot_ok = [], []
    for k in range(1, K + 1):
        cols = [f'x_{k}', f'y_{k}', f'z_{k}']
        if not all(c in df.columns for c in cols):
            continue
        arr = df[cols].apply(pd.to_numeric, errors='coerce').to_numpy(dtype=np.float64)
        slot_xyz.append(arr)
        slot_ok.append(_valid_xyz_rows(arr))

    chosen = np.zeros((n_rows, 3), dtype=np.float32)
    ok_mask = np.zeros((n_rows,), dtype=np.float32)

    if slot_xyz:
        xyz_stack = np.stack(slot_xyz, axis=1)  # (R, S, 3)
        ok_stack = np.stack(slot_ok, axis=1)    # (R, S) bool
        # For every row: its target's number of valid residues in each slot.
        counts = (pd.DataFrame(ok_stack.astype(np.int64), index=df.index)
                  .groupby(df['target_id'])
                  .transform('sum')
                  .to_numpy())
        best = counts.argmax(axis=1)  # constant within a target; first max = lowest k
        rows = np.arange(n_rows)
        sel_ok = ok_stack[rows, best]
        chosen[sel_ok] = xyz_stack[rows, best][sel_ok].astype(np.float32)
        ok_mask[sel_ok] = 1.0

    df['x_1'], df['y_1'], df['z_1'] = chosen[:, 0], chosen[:, 1], chosen[:, 2]
    df['coord_ok'] = ok_mask

    return _per_target_lists(df)

train_coords = build_coords_from_train_labels(train_labels)
val_coords   = build_coords_from_val_labels(val_labels, K=40)

# Normalize join keys
train_coords['target_id'] = train_coords['target_id'].astype(str).str.strip()
val_coords['target_id']   = val_coords['target_id'].astype(str).str.strip()

# Combine sequences and coordinates. A target_id present in both the train and
# the validation files would otherwise appear twice (or m*n times after the
# merge) and could end up on both sides of the holdout split. Keep the first
# (train) copy.
all_seq = pd.concat([train_seq, val_seq], ignore_index=True)
all_coords = pd.concat([train_coords, val_coords], ignore_index=True)
n_dup_seq = int(all_seq['target_id'].duplicated().sum())
n_dup_coords = int(all_coords['target_id'].duplicated().sum())
all_seq = all_seq.drop_duplicates(subset='target_id', keep='first').reset_index(drop=True)
all_coords = all_coords.drop_duplicates(subset='target_id', keep='first').reset_index(drop=True)
print(f"Dropped duplicate target_ids: seq={n_dup_seq}, coords={n_dup_coords}")

all_df = all_seq.merge(all_coords, on='target_id', how='inner', validate='one_to_one')
print("all_df shape:", all_df.shape)

# Sequences and coordinate lists are paired by position. RNADataset truncates to
# the shorter of the two, so a mismatch does not crash but is worth surfacing.
n_len_mismatch = int((all_df['tokenized'].str.len() != all_df['coordinates'].str.len()).sum())
print("Targets with len(sequence) != len(coordinates):", n_len_mismatch)

# Drop targets with too few valid points inside the MAX_LEN training window
# (keeps the Kabsch alignment well conditioned).
MAX_LEN = 200
all_df['valid_in_window'] = all_df['coord_mask'].apply(
    lambda m: float(np.sum(np.asarray(m, dtype=np.float32)[:MAX_LEN]))
)
before = len(all_df)
all_df = all_df[all_df['valid_in_window'] >= MIN_VALID_POINTS].copy()
print(f"Filtered all_df by valid_in_window>={MIN_VALID_POINTS}: {before} -> {len(all_df)}")

# sanity: NaN/Inf check
def has_nan_inf(coords):
    """Return True if any coordinate value is NaN or +/-Inf."""
    a = np.asarray(coords, dtype=np.float32)
    return (not np.isfinite(a).all())
print("NaN/Inf coords after cleaning:", all_df['coordinates'].apply(has_nan_inf).sum())


In [None]:
# ==========================================
# 4. Kabsch RMSD Loss (mask supported, and correct SVD handling)
# ==========================================
def kabsch_rotation(P, Q, mask=None, detach_rotation=True):
    """Rigidly superimpose P onto Q (optimal rotation + translation, no scaling).

    Args:
        P, Q: (B, N, 3) point sets; row i of P is matched with row i of Q.
        mask: optional (B, N), 1 for valid points and 0 for invalid/padded ones.
            Only valid points contribute to the centroids and the rotation.
        detach_rotation: if True (default), the rotation is computed from a
            detached covariance, so no gradient flows through the SVD. The
            Kabsch rotation minimises the masked squared error. For a loss
            built from that error, the gradient through R is therefore zero at
            the optimum (envelope theorem), so detaching leaves the gradient
            unchanged. It avoids the NaN/Inf gradients that SVD backward
            produces for (near-)degenerate singular values, e.g. collapsed
            predictions early in training.

    Returns:
        (B, N, 3) aligned copy of P; points with mask 0 are zeroed.
    """
    B, N = P.shape[0], P.shape[1]
    if mask is None:
        mask = torch.ones(B, N, device=P.device, dtype=P.dtype)

    w = mask.unsqueeze(-1)                                  # (B,N,1)
    w_sum = w.sum(dim=1, keepdim=True).clamp_min(1e-8)

    P_mean = (P * w).sum(dim=1, keepdim=True) / w_sum
    Q_mean = (Q * w).sum(dim=1, keepdim=True) / w_sum

    P_c = (P - P_mean) * w
    Q_c = (Q - Q_mean) * w

    H = P_c.transpose(1, 2) @ Q_c                          # (B,3,3) covariance
    if detach_rotation:
        H = H.detach()

    # torch.linalg.svd returns U, S, Vh with H = U diag(S) Vh
    U, _, Vh = torch.linalg.svd(H, full_matrices=False)
    V = Vh.transpose(1, 2)
    Ut = U.transpose(1, 2)

    # Flip the last axis when V U^T is a reflection, so R is a proper rotation.
    det = torch.det(V @ Ut)
    d = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
    diag = torch.ones(B, 3, device=P.device, dtype=P.dtype)
    diag[:, 2] = d

    R = V @ torch.diag_embed(diag) @ Ut                    # (B,3,3)
    P_aligned = P_c @ R.transpose(1, 2) + Q_mean

    return P_aligned * w


class KabschRMSDLoss(nn.Module):
    """Best-of-K masked RMSD after per-candidate Kabsch superposition.

    forward(preds, target, mask):
        preds : (B, K, N, 3)  K candidate structures per sample
        target: (B, N, 3)
        mask  : (B, N)        1 valid, 0 invalid/pad
    Returns the batch mean over samples of min over K of
        sqrt(sum_i mask_i * ||aligned_i - target_i||^2 / sum_i mask_i + 1e-8),
    i.e. the standard per-point RMSD (mean squared distance per point, not per
    coordinate component).
    """

    def __init__(self):
        super().__init__()

    def forward(self, preds, target, mask):
        B, K, N, _ = preds.shape

        # Fold K into the batch so all candidates are aligned in one batched call.
        flat_preds = preds.reshape(B * K, N, 3)
        flat_target = target.unsqueeze(1).expand(B, K, N, 3).reshape(B * K, N, 3)
        flat_mask = mask.unsqueeze(1).expand(B, K, N).reshape(B * K, N)

        aligned = kabsch_rotation(flat_preds, flat_target, flat_mask)

        sq_dist = ((aligned - flat_target) ** 2).sum(dim=-1)   # (B*K, N) per-point
        sum_sq = (sq_dist * flat_mask).sum(dim=1)              # (B*K,)
        n_valid = flat_mask.sum(dim=1).clamp_min(1.0)          # (B*K,)
        rmsd = torch.sqrt(sum_sq / n_valid + 1e-8).view(B, K)

        return rmsd.min(dim=1).values.mean()

In [None]:
# ==========================================
# 5. Dataset (uses coord_mask to ignore missing labels)
# ==========================================
class RNADataset(Dataset):
    """Fixed-length samples of (tokens, coordinates, validity mask).

    Each item is cut to the shortest of: its token list, its coordinate rows,
    its mask, and max_len. Everything past that point is zero padding, and the
    returned mask is 0 there as well as wherever coord_mask marks a missing label.
    """

    def __init__(self, sequences, coordinates, coord_masks, max_len=200):
        self.sequences = sequences
        self.coordinates = coordinates
        self.coord_masks = coord_masks
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        tokens = self.sequences[idx]
        xyz = np.asarray(self.coordinates[idx], dtype=np.float32)
        valid = np.asarray(self.coord_masks[idx], dtype=np.float32)

        width = self.max_len
        n = min(len(tokens), xyz.shape[0], valid.shape[0], width)

        tok_out = np.zeros(width, dtype=np.int64)
        xyz_out = np.zeros((width, 3), dtype=np.float32)
        mask_out = np.zeros(width, dtype=np.float32)

        tok_out[:n] = np.asarray(tokens[:n], dtype=np.int64)
        xyz_out[:n] = xyz[:n]
        mask_out[:n] = valid[:n]

        return tuple(torch.as_tensor(a) for a in (tok_out, xyz_out, mask_out))

In [None]:
# ==========================================
# 6. Train/Holdout split & DataLoaders
#    - The validation targets were merged into the training pool above, so a
#      separate holdout split is carved out here.
# ==========================================
train_idx, val_idx = train_test_split(range(len(all_df)), test_size=0.1, random_state=42)

train_df = all_df.iloc[train_idx].reset_index(drop=True)
hold_df  = all_df.iloc[val_idx].reset_index(drop=True)

print("train_df:", train_df.shape, "hold_df:", hold_df.shape)


def _make_dataset(frame):
    """Wrap a split frame in RNADataset.

    Uses the same MAX_LEN window that all_df was filtered with
    (valid_in_window), so every sample keeps its guaranteed valid points.
    """
    return RNADataset(
        frame['tokenized'].values,
        frame['coordinates'].values,
        frame['coord_mask'].values,
        max_len=MAX_LEN,
    )


train_dataset = _make_dataset(train_df)
val_dataset = _make_dataset(hold_df)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)


In [None]:
# ==========================================
# 7. Model (Transformer + padding mask for attention)
# ==========================================
class RNATransformer(nn.Module):
    """Transformer encoder that predicts num_preds candidate 3D traces per sequence.

    Input:  tokens (B, T) int64, 0 = PAD.
    Output: (B, num_preds, T, 3) coordinates.
    Positions use a learned table of length max_len (default 1000).
    """

    def __init__(self, n_tokens=5, d_model=128, nhead=8, num_layers=4, dropout=0.1,
                 num_preds=5, max_len=1000):
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(n_tokens, d_model, padding_idx=0)
        self.pos_encoder = nn.Parameter(torch.zeros(1, max_len, d_model))

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        self.num_preds = num_preds
        self.fc_coords = nn.Linear(d_model, 3 * num_preds)

    def forward(self, tokens):
        # tokens: (B,T)
        B, T = tokens.shape
        if T > self.max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_len={self.max_len} of the positional table"
            )

        padding_mask = (tokens == 0)  # True at PAD positions
        # If every position of a row is PAD, every attention score is masked
        # and softmax yields NaN. Un-mask such rows so the encoder stays finite.
        padding_mask = padding_mask & ~padding_mask.all(dim=1, keepdim=True)

        x = self.embedding(tokens)
        x = x + self.pos_encoder[:, :T, :]
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        out = self.fc_coords(x)  # (B,T,3*K)
        out = out.view(B, T, self.num_preds, 3).permute(0, 2, 1, 3)  # (B,K,T,3)
        return out

# Five candidate structures per sequence; the loss keeps the best one.
model = RNATransformer(num_preds=5)
model = model.to(device)
print("Model Initialized (Best-of-5 Output Strategy).")


In [None]:
# ==========================================
# 8. Train (robust batch filtering)
# ==========================================
criterion = KabschRMSDLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-2)  # small LR for stability
epochs = 5

print(f"Starting Training for {epochs} epochs...")

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    n_steps = 0
    n_skipped = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for seq, target, mask in pbar:
        seq = seq.to(device)
        target = target.to(device)
        mask = mask.to(device)

        # ===== Key step: drop samples with too few valid points =====
        # Kabsch needs at least 3 non-collinear points to be meaningful.
        # MIN_VALID_POINTS is a much larger safety margin.
        keep = mask.sum(dim=1) >= MIN_VALID_POINTS

        if keep.sum() < 2:
            # Too few samples left in this batch; skip it.
            n_skipped += 1
            continue

        seq, target, mask = seq[keep], target[keep], mask[keep]

        optimizer.zero_grad(set_to_none=True)

        preds = model(seq)
        loss = criterion(preds, target, mask)

        if not torch.isfinite(loss):
            print("Warning: Loss is NaN/Inf even after filtering. Skipping batch.")
            n_skipped += 1
            continue

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if not torch.isfinite(grad_norm):
            # A finite loss can still produce NaN/Inf gradients. Clipping does
            # not remove them, and optimizer.step() would write them into every
            # weight, so discard this update.
            optimizer.zero_grad(set_to_none=True)
            print("Warning: non-finite gradient norm. Skipping batch.")
            n_skipped += 1
            continue
        optimizer.step()

        train_loss += loss.item()
        n_steps += 1
        pbar.set_postfix({'loss': float(loss.item()), 'kept': int(keep.sum())})

    avg_loss = train_loss / max(1, n_steps)
    print(f"Epoch {epoch+1} Average Train Loss: {avg_loss:.6f} (steps={n_steps}, skipped={n_skipped})")


In [None]:
# ==========================================
# 9. Validation
# ==========================================
model.eval()
val_loss = 0.0      # sum of per-sample RMSD over batches with a finite loss
n_val_samples = 0
n_bad_batches = 0

with torch.no_grad():
    for seq, target, mask in val_loader:
        seq = seq.to(device)
        target = target.to(device)
        mask = mask.to(device)

        preds = model(seq)
        loss = criterion(preds, target, mask)

        if torch.isfinite(loss):
            # criterion returns a batch mean. Weight it by batch size so the
            # smaller last batch does not skew the average. Skipped batches
            # are left out of the denominator.
            val_loss += loss.item() * seq.size(0)
            n_val_samples += seq.size(0)
        else:
            n_bad_batches += 1

val_rmsd = val_loss / max(1, n_val_samples)
print(f"Validation RMSD (masked): {val_rmsd:.6f}")
if n_bad_batches:
    print(f"Warning: {n_bad_batches} validation batch(es) had NaN/Inf loss and were excluded.")


In [None]:
# ==========================================
# 10. Visualization (uses mask correctly)
# ==========================================
model.eval()  # keep dropout off even if the training cell was re-run just before
seq_batch, target_batch, mask_batch = next(iter(val_loader))
seq_batch = seq_batch.to(device)
target_batch = target_batch.to(device)
mask_batch = mask_batch.to(device)

with torch.no_grad():
    pred_batch = model(seq_batch)  # (B,5,T,3)
    # The loss ignores rigid motions, so raw predictions sit in an arbitrary
    # frame. Superimpose candidate k=0 of sample 0 onto its target first,
    # otherwise the plotted comparison is meaningless.
    pred0_aligned = kabsch_rotation(pred_batch[:1, 0], target_batch[:1], mask_batch[:1])[0]

mask0 = mask_batch[0].detach().cpu().numpy().astype(bool)
t0 = target_batch[0].detach().cpu().numpy()[mask0]     # (L_valid,3)
p0 = pred0_aligned.detach().cpu().numpy()[mask0]       # aligned pred k=0

In [None]:
# Per-axis comparison of the X coordinate over valid residues only.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(t0[:, 0], label="Actual X", alpha=0.7)
ax.plot(p0[:, 0], label="Pred X (k=0)", linestyle="--", alpha=0.9)
ax.set_title(f"Actual vs Predicted X (valid points={mask0.sum()})")
ax.set_xlabel("Valid-point Index")
ax.set_ylabel("X-coordinate")
ax.legend()
ax.grid(alpha=0.3)
plt.show()

# All three axes on one figure.
fig, ax = plt.subplots(figsize=(12, 6))
for i, name in enumerate(["X", "Y", "Z"]):
    ax.plot(t0[:, i], label=f"Actual {name}", alpha=0.5)
    ax.plot(p0[:, i], label=f"Pred {name}", linestyle="--", alpha=0.8)
ax.set_title("Actual vs Predicted XYZ (masked)")
ax.set_xlabel("Valid-point Index")
ax.set_ylabel("Coordinate")
ax.legend(ncol=3)
ax.grid(alpha=0.3)
plt.show()

# Importing Axes3D registers the '3d' projection on older Matplotlib versions.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
ax.plot(t0[:, 0], t0[:, 1], t0[:, 2], label="Actual", alpha=0.7)
ax.plot(p0[:, 0], p0[:, 1], p0[:, 2], label="Pred (k=0)", alpha=0.7, linestyle="--")
ax.set_title("3D Trace (masked)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend()
plt.show()
