In [39]:

# KL divergence per row between the three files (A,B,C)
import torch, numpy as np, pandas as pd

# Additive smoothing term: kl_divergence adds it to every entry of both
# distributions (then renormalises) so np.log never receives an exact zero.
EPS = 1e-8

def load_2d_tensor(path: str) -> torch.Tensor:
    """Load a 2D tensor from a file written by ``torch.save``.

    The saved object may be either a tensor, or a list/tuple in which case the
    first element that is a 2D tensor is used.

    Args:
        path: Location of the saved object.

    Returns:
        A detached float32 CPU tensor with exactly two dimensions.

    Raises:
        ValueError: If a list/tuple holds no 2D tensor, or the saved tensor
            itself is not 2D.
        TypeError: If the saved object is neither a tensor nor a list/tuple.
    """
    # NOTE(security): torch.load is pickle-based; only use it on trusted files
    # (consider weights_only=True where the installed torch supports it).
    obj = torch.load(path, map_location='cpu')
    if torch.is_tensor(obj):
        t = obj
    elif isinstance(obj, (list, tuple)):
        # First 2D tensor wins; None when the container has no such element.
        t = next(
            (item for item in obj if torch.is_tensor(item) and item.ndim == 2),
            None,
        )
        if t is None:
            raise ValueError(f"{path}: Could not find a 2D tensor inside the saved object.")
    else:
        raise TypeError(f"{path}: Expected a tensor or list/tuple, got {type(obj)}")
    if t.ndim != 2:
        raise ValueError(f"{path}: Expected a 2D tensor [N,K], got shape {tuple(t.shape)}")
    # detach() drops any autograd history saved with the tensor; without it a
    # later .numpy() call fails on tensors that require grad.
    return t.detach().float().cpu()

def to_probs(x: torch.Tensor, assume_logits: bool) -> torch.Tensor:
    """Turn each row of ``x`` into a probability distribution.

    Args:
        x: 2D tensor; normalisation runs along ``dim=1``.
        assume_logits: When True, rows are passed through softmax. When False,
            rows are treated as unnormalised scores: negative entries are
            clipped to zero and each row is divided by its sum.

    Returns:
        A tensor with the same shape as ``x``. In the non-logit branch a row
        with no positive mass comes back as all zeros rather than dividing by
        zero.
    """
    if assume_logits:
        return torch.softmax(x, dim=1)
    # Clip before normalising: clipping afterwards leaves rows that no longer
    # sum to 1, and a negative row sum would be clamped to 1e-12 and blow up.
    nonneg = x.clamp_min(0.0)
    total = nonneg.sum(dim=1, keepdim=True).clamp_min(1e-12)
    return nonneg / total

def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:
    """Return KL(p || q) for two discrete distributions.

    Both inputs are cast to float64, shifted by ``eps`` and renormalised over
    all their entries before the divergence is evaluated, so zero entries do
    not produce infinities.

    Args:
        p: Reference distribution.
        q: Distribution compared against ``p``.
        eps: Additive smoothing applied to every entry of ``p`` and ``q``.

    Returns:
        Sum over entries of ``p * (log p - log q)`` using natural logarithms.
    """
    p_smooth = p.astype(np.float64) + eps
    q_smooth = q.astype(np.float64) + eps
    p_norm = p_smooth / p_smooth.sum()
    q_norm = q_smooth / q_smooth.sum()
    log_gap = np.log(p_norm) - np.log(q_norm)
    return np.sum(p_norm * log_gap)

# File paths (A,B,C)
path_A = 'forget_lacuna5_allcnn_final_all_outputs_tensor.pt'
path_B = 'scrub_forget_lacuna5_allcnn_final_all_outputs_tensor.pt'
path_C = 'forget_set_all_probs_tensor.pt'

# A and B are treated as logits (softmax applied); C is only row-normalised.
A = to_probs(load_2d_tensor(path_A), True).numpy()
B = to_probs(load_2d_tensor(path_B), True).numpy()
C = to_probs(load_2d_tensor(path_C), False).numpy()

# Rows are compared element-wise, so all three files must agree on the number
# of columns; NumPy broadcasting would otherwise silently accept e.g. K vs 1.
if not (A.shape[1] == B.shape[1] == C.shape[1]):
    raise ValueError(
        f'Column count mismatch: A={A.shape} B={B.shape} C={C.shape}'
    )

# Differing row counts are tolerated: warn and compare only the common prefix.
n = min(len(A), len(B), len(C))
if len(A) != len(B) or len(A) != len(C):
    print(f'Warning: different lengths A={len(A)} B={len(B)} C={len(C)}; using n={n}')
A, B, C = A[:n], B[:n], C[:n]

# One KL value per row for each ordered pair of files.
kl_A_B = np.array([kl_divergence(a, b) for a, b in zip(A, B)])
kl_A_C = np.array([kl_divergence(a, c) for a, c in zip(A, C)])
kl_B_C = np.array([kl_divergence(b, c) for b, c in zip(B, C)])

# Collect the per-row divergences into one table keyed by 0-based row index.
rows = np.arange(n)
df = pd.DataFrame({
    'row': rows,
    'KL(A||B)': kl_A_B,
    'KL(A||C)': kl_A_C,
    'KL(B||C)': kl_B_C,
})
try:
    display(df.head())
except NameError:
    # `display` is injected by IPython/Jupyter only; fall back to plain text
    # when this code runs under a regular Python interpreter.
    print(df.head().to_string(index=False))

# Mean divergence over all compared rows, as plain Python floats for printing.
means = {
    'mean KL(A||B)': float(kl_A_B.mean()),
    'mean KL(A||C)': float(kl_A_C.mean()),
    'mean KL(B||C)': float(kl_B_C.mean()),
}
print('Means:', means)

df.to_csv('kl_rowwise.csv', index=False)
print('Saved per-row KL to kl_rowwise.csv')


Unnamed: 0,row,KL(A||B),KL(A||C),KL(B||C)
0,0,1.329944,2.145773,4.982303
1,1,0.263995,0.813363,0.086874
2,2,9.868894,6.576357,0.000495
3,3,1.601833,4.639836,0.218614
4,4,0.913462,2.841591,2.36361


Means: {'mean KL(A||B)': 1.9119363137255536, 'mean KL(A||C)': 3.6501410131554355, 'mean KL(B||C)': 1.3546289544872547}
Saved per-row KL to kl_rowwise.csv
