In [None]:
from itertools import count

import h5py
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from qso import dataset

In [None]:
def coral(src_data, trg_data, device):
    """Return the CORAL loss between two batches of feature vectors.

    The loss is the sum of squared differences between the sample
    covariance matrices of the two batches, divided by ``4 * d * d`` where
    ``d`` is the number of feature columns.

    Args:
        src_data: 2-D tensor of shape ``(n_src, d)`` with source features.
        trg_data: 2-D tensor of shape ``(n_trg, d)`` with target features.
        device: Unused. Kept so existing callers keep working; all
            intermediates are derived from the inputs and therefore already
            live on the inputs' device with the inputs' dtype.

    Returns:
        Zero-dimensional tensor holding the loss.

    Note:
        Each covariance is normalised by ``n - 1``, so a batch with a single
        row produces a non-finite result.
    """

    def _covariance(data):
        # Sample covariance: (X^T X - (1^T X)^T (1^T X) / n) / (n - 1).
        n = data.size(0)
        # Column sums taken directly from `data`, which avoids building a
        # float32 ones-vector on the CPU and copying it to the device.
        col_sum = data.sum(dim=0, keepdim=True)
        return (data.t() @ data - (col_sum.t() @ col_sum) / n) / (n - 1)

    d = src_data.size(1)
    cov_diff = _covariance(src_data) - _covariance(trg_data)
    return cov_diff.pow(2).sum() / (4 * d * d)

In [None]:
# Open both HDF5 files read-only; the handles stay open because the
# datasets built from them below index into them.
lamost = h5py.File(dataset.LAMOST_DATASET, mode="r")
sdss = h5py.File(dataset.SDSS_DATASET, mode="r")

In [None]:
# Source-domain dataset built from the "X_va" / "y_va" entries of the SDSS
# file; the training loop unpacks each of its batches as (data, label).
sdss_X, sdss_y = sdss["X_va"], sdss["y_va"]
src_set = dataset.HDF5Dataset(sdss_X, sdss_y)

In [None]:
class TargetDataset(Dataset):
    """Unlabelled dataset that serves rows of ``X`` one at a time.

    Each item is the indexed row reshaped to ``(1, -1)``, i.e. with a
    leading axis of length one in front of the row's values.
    """

    def __init__(self, X):
        """Store the array-like ``X``; its first axis indexes the samples."""
        self.X = X

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` reshaped to ``(1, -1)``."""
        sample = self.X[idx]
        return sample.reshape(1, -1)


# Target-domain dataset: `[...]` reads the whole "X_va" entry of the LAMOST
# file into memory before it is wrapped.
lamost_X = lamost["X_va"][...]
trg_set = TargetDataset(lamost_X)

In [None]:
# Both domains are served with the same loader settings: shuffled
# mini-batches of 64 samples.
loader_options = {"batch_size": 64, "shuffle": True}
src_loader = DataLoader(src_set, **loader_options)
trg_loader = DataLoader(trg_set, **loader_options)

In [None]:
class DeepCORAL(nn.Module):
    """1-D convolutional network split into a feature extractor and a head.

    ``features`` maps an input of shape ``(batch, 1, length)`` to a
    100-dimensional representation (the output of ``fc2``); ``output`` maps
    that representation to a single logit. The split lets the training loop
    apply the CORAL loss to the representation of both domains while the
    classification loss is computed from the logit.

    ``fc1`` expects ``48 * 911`` flattened values, which the two
    convolution/pooling stages produce for input lengths 3656 to 3659.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(32, 48, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.fc1 = nn.Linear(48 * 911, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def features(self, x):
        """Return the pre-activation output of ``fc2`` for input ``x``."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        # Keep the batch axis, merge channels and positions for the dense layers.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

    def output(self, x):
        """Map the representation from ``features`` to one logit per sample."""
        # ReLU is applied here rather than in `features`, so `features`
        # returns the raw `fc2` activations.
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def forward(self, x):
        """Run the full network: ``output(features(x))``."""
        # The original called the misspelled `self.fetures`, which raised
        # AttributeError on every forward pass.
        x = self.features(x)
        x = self.output(x)
        return x


def init_weights(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    ``nn.Conv1d`` and ``nn.Linear`` modules get Xavier-uniform weights and a
    zero bias. Every other module is left untouched.

    Args:
        m: The module to initialise.
    """
    # Exact type match (not isinstance): subclasses of these layers are
    # deliberately left with their own initialisation.
    if type(m) not in (nn.Conv1d, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight.data)
    # Layers created with bias=False have `bias is None`; skip them instead
    # of failing on `None.data`.
    if m.bias is not None:
        m.bias.data.fill_(0)

In [None]:
# Weight of the CORAL term relative to the classification loss.
LAMBDA = 1000
writer = SummaryWriter("runs/deep_coral")
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deep_coral = DeepCORAL().to(device)
deep_coral.apply(init_weights)
optimizer = optim.Adam(deep_coral.parameters())
# Step counter for the TensorBoard scalars, starting at 1. It is created
# here rather than in the training cell, so re-running that cell continues
# the numbering.
iterator = count(1)

In [None]:
deep_coral.train()
# zip stops as soon as the shorter of the two loaders is exhausted, so one
# run of this cell covers at most one pass over the smaller dataset.
for i, (src_data, src_label), trg_data in zip(iterator, src_loader, trg_loader):
    # `coral` normalises each covariance by (n - 1). A trailing batch with a
    # single sample would therefore give a non-finite loss, and the optimizer
    # step would write it into the weights. Skip such batches.
    if src_data.size(0) < 2 or trg_data.size(0) < 2:
        continue

    src_data, src_label = src_data.to(device), src_label.view(-1, 1).to(device)
    trg_data = trg_data.to(device)

    optimizer.zero_grad()

    # Only the source batch has labels, so only it goes through the head;
    # the target batch contributes through its features alone.
    src_features = deep_coral.features(src_data)
    output = deep_coral.output(src_features)
    trg_features = deep_coral.features(trg_data)

    loss_class = F.binary_cross_entropy_with_logits(output, src_label)
    loss_coral = coral(src_features, trg_features, device)
    loss = loss_class + LAMBDA * loss_coral
    loss.backward()

    optimizer.step()

    # Log plain Python floats; the CORAL term is logged with its weight
    # applied, i.e. as it enters the total loss.
    writer.add_scalar("loss/class", loss_class.item(), i)
    writer.add_scalar("loss/coral", LAMBDA * loss_coral.item(), i)