In [1]:
from pathlib import Path

import h5py
import numpy as np
from matplotlib import rc
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from qso.utils import load_ds, predict, WAVES

In [2]:
# Typeset all figure text with LaTeX, using a 16 pt serif font.
rc("font", family="serif", size=16)
rc("text", usetex=True)

In [3]:
# Validation splits of the source ("sdss") and target ("lamost") groups,
# both wrapped in loaders that use the same evaluation batch size.
bs_eval = 2048
ds_file = Path("data", "dataset.hdf5")

src_ds_va, trg_ds_va = (
    load_ds(ds_file, survey, va=True) for survey in ("sdss", "lamost")
)
src_dl_va, trg_dl_va = (
    DataLoader(ds_va, batch_size=bs_eval) for ds_va in (src_ds_va, trg_ds_va)
)

In [4]:
# Read the `id_va` datasets of both groups fully into memory; they are used
# later to look up the identifiers of misclassified validation samples.
with h5py.File(ds_file, mode="r") as hdf:
    src_id_va = hdf["sdss/id_va"][:]
    trg_id_va = hdf["lamost/id_va"][:]

In [5]:
class LeNet_5(nn.Module):
    """LeNet-5-style 1-D convolutional network with a single output unit.

    Two convolution + max-pooling stages are followed by three fully
    connected layers. The ``forward_*`` methods expose the activations after
    successive stages; each one builds on the previous one.
    """

    def __init__(self):
        super().__init__()
        # Convolutions preserve the length (kernel 5, padding 2); every
        # pooling stage shrinks it by a factor of 16.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32,
                               kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool1d(kernel_size=16, stride=16)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=48,
                               kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool1d(kernel_size=16, stride=16)
        # fc1 expects 48 channels x 14 positions once flattened.
        self.fc1 = nn.Linear(in_features=48 * 14, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=100)
        self.fc3 = nn.Linear(in_features=100, out_features=1)

    def forward_conv(self, x):
        """Return the flattened output of the two conv/pool stages."""
        hidden = self.pool1(torch.relu(self.conv1(x)))
        hidden = self.pool2(torch.relu(self.conv2(hidden)))
        return hidden.flatten(start_dim=1)

    def forward_fc1(self, x):
        """Return the ReLU activations of the first fully connected layer."""
        return torch.relu(self.fc1(self.forward_conv(x)))

    def forward_fc2(self, x):
        """Return the ReLU activations of the second fully connected layer."""
        return torch.relu(self.fc2(self.forward_fc1(x)))

    def forward(self, x):
        """Return the output of the last linear layer (no activation applied)."""
        return self.fc3(self.forward_fc2(x))

def mmd(src_f, trg_f):
    """Squared distance between the mean feature vectors of two samples.

    Both inputs are averaged over their first axis and the two averages are
    subtracted. For 2-D inputs of shape (n_samples, n_features) the result is
    the squared Euclidean norm of that difference, i.e. the squared MMD
    estimate for a linear kernel.

    Parameters
    ----------
    src_f, trg_f : torch.Tensor
        Feature tensors; the sizes of their first axes may differ, the
        remaining axes must match.

    Returns
    -------
    torch.Tensor
        A 0-d tensor for 2-D inputs. For inputs of higher rank the original
        ``gap @ gap.T`` product is returned unchanged.
    """
    # A distinct local name avoids shadowing the function itself.
    mean_gap = src_f.sum(axis=0) / src_f.size(0) - trg_f.sum(axis=0) / trg_f.size(0)
    if mean_gap.dim() == 1:
        # Plain dot product: same value as `gap @ gap.T`, but without calling
        # `.T` on a 1-D tensor, which PyTorch documents as deprecated.
        return mean_gap.dot(mean_gap)
    return mean_gap @ mean_gap.T

def eval_layer(layer, src_dl, trg_dl, bottleneck, dev):
    """Return the square root of `mmd` between two loaders' layer features.

    Every batch of each loader is moved to `dev`, passed through `layer`, and
    the resulting features are gathered into one (n_samples, bottleneck)
    matrix per loader before `mmd` is applied.

    Each loader is consumed completely and independently. Zipping the two
    loaders would stop at the shorter one and leave the unprocessed rows of
    the larger dataset at zero, which biases its mean. For the same reason
    the loaders no longer need equal batch sizes.

    Parameters
    ----------
    layer : callable
        Maps an input batch to features with `bottleneck` values per sample.
    src_dl, trg_dl : DataLoader
        Loaders yielding ``(inputs, labels)`` pairs over datasets that expose
        ``.tensors`` (as read by the size lookup below).
    bottleneck : int
        Number of feature values per sample.
    dev : torch.device
        Device on which `layer` is evaluated.

    Returns
    -------
    float
    """
    def _collect(dl):
        # Gather the features of every sample yielded by `dl`.
        feats = torch.zeros(dl.dataset.tensors[0].size(0), bottleneck)
        filled = 0
        # Only the values are needed, so no autograd history is kept.
        with torch.no_grad():
            for xb, _ in dl:
                n = xb.size(0)
                # reshape (rather than squeeze) keeps the batch axis intact,
                # so bottleneck == 1 and multi-dimensional features work too.
                feats[filled:filled + n] = layer(xb.to(dev)).reshape(n, -1)
                filled += n
        # Drop rows that were never written (e.g. a loader with drop_last).
        return feats[:filled]

    return torch.sqrt(mmd(_collect(src_dl), _collect(trg_dl))).item()

# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lenet = LeNet_5().to(dev)
lenet_pt = Path("models") / "lenet-5.pt"
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine. NOTE: torch.load unpickles the file, so only load
# checkpoints from a trusted source.
lenet.load_state_dict(torch.load(lenet_pt, map_location=dev))

<All keys matched successfully>

In [6]:
# Labels and predictions for both validation loaders; this is inference
# only, so gradient tracking is switched off.
with torch.no_grad():
    (src_true_va, src_pred_va), (trg_true_va, trg_pred_va) = (
        predict(lenet, loader, dev) for loader in (src_dl_va, trg_dl_va)
    )

In [None]:
# Output directory for the saved figures. It is created up front because
# plt.savefig raises FileNotFoundError when the target directory is missing.
fig_dir = Path("figs")
fig_dir.mkdir(parents=True, exist_ok=True)

# Source Errors

In [7]:
# Boolean ground truth, and predictions obtained by thresholding the model
# output at zero.
src_y_true, src_y_pred = src_true_va.bool(), src_pred_va > 0
src_X = src_ds_va.tensors[0].squeeze()

# scikit-learn layout: rows are true classes, columns are predicted classes.
confusion_matrix(src_y_true, src_y_pred)

array([[42916,   532],
       [  274,  6278]])

In [8]:
# False negatives: labelled positive, predicted negative.
src_fn_idx = src_y_true & ~src_y_pred
src_fn_X, src_fn_id = src_X[src_fn_idx], src_id_va[src_fn_idx]

# False positives: labelled negative, predicted positive.
src_fp_idx = ~src_y_true & src_y_pred
src_fp_X, src_fp_id = src_X[src_fp_idx], src_id_va[src_fp_idx]

src_fn_X.shape, src_fp_X.shape

(torch.Size([274, 3659]), torch.Size([532, 3659]))

In [12]:
import pandas as pd

In [None]:
# FIXME(review): neither `plot_random` nor `src_fn` is defined in any cell
# visible in this notebook, so on a fresh kernel this line raises NameError.
# Possibly `plot_spec(src_fn_X, 6)` was intended — confirm before running.
plot_random(src_fn, 6)
# Save the current figure under fig_dir as "source_fn.pdf".
plt.savefig(str(fig_dir / "source_fn.pdf"))

In [None]:
# FIXME(review): neither `plot_random` nor `src_fp` is defined in any cell
# visible in this notebook, so on a fresh kernel this line raises NameError.
# Possibly `plot_spec(src_fp_X, 6)` was intended — confirm before running.
plot_random(src_fp, 6)
# Save the current figure under fig_dir as "source_fp.pdf".
plt.savefig(str(fig_dir / "source_fp.pdf"))

# Target Errors

In [21]:
# Boolean ground truth, and predictions obtained by thresholding the model
# output at zero.
trg_y_true, trg_y_pred = trg_true_va.bool(), trg_pred_va > 0
trg_X = trg_ds_va.tensors[0].squeeze()

# scikit-learn layout: rows are true classes, columns are predicted classes.
confusion_matrix(trg_y_true, trg_y_pred)

array([[48873,   937],
       [   44,   146]])

In [22]:
# False negatives: labelled positive, predicted negative.
trg_fn_idx = trg_y_true & ~trg_y_pred
trg_fn_X, trg_fn_id = trg_X[trg_fn_idx], trg_id_va[trg_fn_idx]

# False positives: labelled negative, predicted positive.
trg_fp_idx = ~trg_y_true & trg_y_pred
trg_fp_X, trg_fp_id = trg_X[trg_fp_idx], trg_id_va[trg_fp_idx]

trg_fn_X.shape, trg_fp_X.shape

(torch.Size([44, 3659]), torch.Size([937, 3659]))

In [None]:
# Builds an output path from one false-negative id: the id is decoded, its
# last 8 characters are dropped and ".pdf" is appended.
# FIXME(review): `idx` is not assigned in any earlier visible cell (the only
# visible assignment is the loop variable of a later cell), so on a fresh
# kernel this raises NameError. Looks like a leftover scratch expression —
# confirm whether it can be removed.
str(fig_dir / trg_fn_id[idx].decode()[:-8]) + ".pdf"

In [None]:
# Plot a random sample of target false positives, one PDF per spectrum.
SIZE = 15
n_fp = trg_fp_X.size(0)
# NOTE: the draw is unseeded, so a different sample is plotted on every run.
# The sample size is capped at the number of available spectra because
# sampling without replacement raises when more are requested than exist.
idxs = np.random.choice(n_fp, size=min(SIZE, n_fp), replace=False)

for idx in idxs:
    # Twice the default matplotlib width and 60 % of the default height.
    fig, ax = plt.subplots(figsize=[2 * 6.4, 4.8 * 0.6])
    ax.plot(WAVES, trg_fp_X[idx])
    ax.set_ylabel("Scaled flux")
    # Raw string: "\A" is not a valid Python escape sequence and triggers a
    # warning in a normal string literal; the text itself is unchanged.
    ax.set_xlabel(r"Vacuum wavelength (\AA{})")
    fig.tight_layout()
    # The file name is the decoded id without its last 8 characters
    # (presumably a fixed-length extension such as ".fits.gz" — TODO confirm).
    fig.savefig(str(fig_dir / trg_fp_id[idx].decode()[:-8]) + ".pdf")

In [None]:
def plot_spec(specs, size):
    """Plot `size` randomly chosen spectra in vertically stacked panels.

    Parameters
    ----------
    specs : tensor-like
        Collection of spectra; must provide ``.size(0)`` and support indexing
        with an integer index array. Each selected row is plotted against
        ``WAVES``.
    size : int
        Number of distinct spectra to draw (sampling without replacement, so
        it must not exceed ``specs.size(0)``). The figure height scales with
        it: three panels take one default figure height.
    """
    # squeeze=False always yields an array of axes, so size == 1 works too
    # (by default plt.subplots returns a bare Axes for a single panel).
    fig, axs = plt.subplots(nrows=size, sharex=True, squeeze=False,
                            figsize=[6.4, (size / 3) * 4.8])
    # NOTE: the draw is unseeded, so every call picks different spectra.
    spec_idxs = np.random.choice(specs.size(0), size=size,
                                 replace=False)
    for ax, spec in zip(axs.flatten(), specs[spec_idxs]):
        ax.plot(WAVES, spec)
        ax.set_ylabel("Scaled flux")
    # The x axis is shared, so only the bottom (last) panel gets a label.
    # Raw string: "\A" is not a valid Python escape sequence; the text itself
    # is unchanged.
    ax.set_xlabel(r"Wavelength (\AA{})")