In [None]:
from itertools import count

import h5py
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from qso import dataset, utils

In [None]:
def cov(data, dev=None):
    """Unbiased (``n - 1``-normalised) covariance matrix of the columns of ``data``.

    Parameters
    ----------
    data : torch.Tensor
        2-D tensor of shape ``(n, d)``: ``n`` observations (rows) of ``d`` features.
    dev : torch.device, optional
        Unused. It is kept so existing ``cov(x, dev)`` calls keep working. All work
        happens on ``data``'s own device and in ``data``'s own dtype.

    Returns
    -------
    torch.Tensor
        ``(d, d)`` covariance matrix with the same dtype and device as ``data``.

    Raises
    ------
    ValueError
        If ``data`` has fewer than two rows, where the ``n - 1`` denominator is zero
        and the result would silently be NaN/inf.
    """
    n = data.size(0)
    if n < 2:
        raise ValueError(f"cov needs at least 2 observations, got {n}")
    # Row vector of column sums, shape (1, d). data.sum keeps data's dtype and
    # device, unlike a freshly allocated float32 CPU ones-vector moved to `dev`.
    col_sums = data.sum(dim=0, keepdim=True)
    # sum_i x_i x_i^T - (sum x)(sum x)^T / n == sum_i (x_i - mean)(x_i - mean)^T
    return (data.t() @ data - (col_sums.t() @ col_sums) / n) / (n - 1)

def coral(source_data, target_data, dev):
    """CORAL distance between the feature covariances of two batches.

    Computes ``||C_s - C_t||_F^2 / (4 * d**2)``. Here ``C_s`` and ``C_t`` are the
    covariance matrices returned by ``cov`` for the source and target batches.
    ``d`` is the feature dimension, read from ``source_data.size(1)``.

    ``dev`` is forwarded unchanged to ``cov``.
    """
    n_features = source_data.size(1)
    cov_gap = cov(source_data, dev) - cov(target_data, dev)
    return torch.sum(cov_gap ** 2) / (4 * n_features ** 2)

In [None]:
BATCH_SIZE = 64

# `[...]` reads each HDF5 dataset fully into a NumPy array, so the files can be
# closed as soon as the arrays are materialised (no leaked file handles).
# NOTE(review): both domains use the "X_va" split here; confirm that this,
# rather than the training split, is the intended adaptation data.
with h5py.File(dataset.LAMOST_DATASET, "r") as lamost, h5py.File(dataset.SDSS_DATASET, "r") as sdss:
    # Source domain (SDSS): inputs plus labels. Inputs are reshaped to
    # (N, 1, N_WAVELENGTHS) to match the single-channel Conv1d input.
    source_X = sdss["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS)
    source_y = sdss["y_va"][...].astype("f4")  # float32 targets for BCE-with-logits
    # Target domain (LAMOST): inputs only; no labels are used for adaptation.
    target_X = lamost["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS)

source_tensors = list(map(torch.from_numpy, [source_X, source_y]))
source_ds = TensorDataset(*source_tensors)
target_ds = TensorDataset(torch.from_numpy(target_X))

# drop_last=True: a trailing batch of a single sample would make the
# (n - 1)-normalised covariance in `cov` undefined and poison the CORAL loss.
source_dl = DataLoader(source_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
target_dl = DataLoader(target_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
class FeatureExtractor(nn.Module):
    """1-D CNN that maps a single-channel input ``(batch, 1, length)`` to 100 features.

    The network has two conv -> ReLU -> max-pool stages. Each conv preserves the
    length (kernel 5, padding 2) and each pool halves it. Two fully connected
    layers follow. ``fc2``'s output is returned without an activation; the
    ``Predictor`` head applies the ReLU.

    ``fc1`` expects ``48 * 914`` inputs, so the input length must satisfy
    ``length // 4 == 914``.
    NOTE(review): presumably 914 == utils.N_WAVELENGTHS // 4; confirm.
    """

    def __init__(self):
        super().__init__()
        # Registration order is unchanged: it fixes parameters() order and the
        # order in which the default initialisers draw random numbers.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=48, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=48 * 914, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=100)

    def forward(self, x):
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(F.relu(conv(x)))
        hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
        return self.fc2(hidden)

class Predictor(nn.Module):
    """Single-logit head: ReLU on the incoming 100-d features, then one linear layer.

    The output is raw logits (no sigmoid here); the training loop feeds them to
    ``binary_cross_entropy_with_logits``.
    """

    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(in_features=100, out_features=1)

    def forward(self, x):
        return self.fc3(torch.relu(x))

class DeepCORAL(nn.Module):
    """Shared feature extractor followed by the logit head.

    Both sub-modules are exposed as attributes. The training loop can therefore
    run ``feature_extractor`` on both domains and ``predictor`` on the labelled
    source features only. ``forward`` chains the two for plain inference.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.predictor = Predictor()

    def forward(self, x):
        return self.predictor(self.feature_extractor(x))

In [None]:
# Weight of the CORAL term relative to the classification loss.
# TODO: tune so the classification and (scaled) CORAL losses end up on a similar scale.
LAMBDA = 1000
SEED = 42  # NOTE(review): arbitrary; any fixed value makes runs repeatable

writer = SummaryWriter("runs/deep_coral")
# Prefer the GPU, but fall back to CPU so the notebook also runs without CUDA.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG before building the model. This makes weight
# initialisation and the DataLoader shuffling in the training cell repeatable.
torch.manual_seed(SEED)

deep_coral = DeepCORAL().to(dev)
deep_coral.feature_extractor.apply(utils.init_weights)
deep_coral.predictor.apply(utils.init_weights)

opt = optim.Adam(deep_coral.parameters())
# Global step counter for TensorBoard. It lives in this cell, so re-running the
# training cell continues the numbering instead of restarting at 1.
iterator = count(1)

In [None]:
# One pass over the shorter of the two loaders. `iterator` goes LAST in zip:
# zip pulls its arguments left to right and stops at the first exhausted one.
# With the counter first, one step number was consumed and lost at the end of
# every pass, leaving gaps in the TensorBoard x-axis.
for (source_data, source_label), (target_data,), i in zip(source_dl, target_dl, iterator):
    source_data, source_label = source_data.to(dev), source_label.to(dev)
    target_data = target_data.to(dev)

    opt.zero_grad()

    # Both domains pass through the shared feature extractor; only the labelled
    # source features feed the predictor.
    source_features = deep_coral.feature_extractor(source_data)
    pred = deep_coral.predictor(source_features)
    target_features = deep_coral.feature_extractor(target_data)

    # pred is (B, 1); unsqueeze the labels to match (assumes 1-D labels per batch).
    loss_class = F.binary_cross_entropy_with_logits(pred, source_label.unsqueeze(-1))
    loss_coral = coral(source_features, target_features, dev)
    loss = loss_class + LAMBDA * loss_coral

    loss.backward()
    opt.step()

    # Log plain Python floats rather than graph-attached tensors.
    writer.add_scalar("loss/class", loss_class.item(), i)
    writer.add_scalar("loss/coral", LAMBDA * loss_coral.item(), i)

writer.flush()  # make the final points visible in TensorBoard right away