In [None]:
from itertools import count

import h5py
import numpy as np
import torch
from torch import nn
from torch import optim
from torch import autograd
from torch.autograd import gradcheck
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from qso import dataset
from qso import utils

In [None]:
class GradientReversalFunction(autograd.Function):
    """Gradient reversal layer: identity going forward, sign flip going back.

    The forward pass returns its input unchanged. The backward pass hands the
    incoming gradient back negated, so parameters *upstream* of this op are
    pushed to increase whatever loss is computed downstream of it. The
    reversed gradient is not scaled here; any weighting comes from how the
    caller scales the downstream loss.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing to save for backward: the Jacobian is just the identity.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Identity Jacobian with the sign reversed.
        return -grad_output


# Functional handle used like a regular op: ``y = rev_grad(x)``.
rev_grad = GradientReversalFunction.apply

In [None]:
class FeatureExtractor(nn.Module):
    """Convolutional trunk shared by the label predictor and domain classifier.

    Two ``Conv1d -> ReLU -> MaxPool1d(2)`` stages (1 -> 32 -> 48 channels,
    kernel size 5, no padding), then everything except the batch dimension
    is flattened into one feature vector per sample.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: single input channel -> 32 feature maps, then halve length.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Stage 2: 32 -> 48 feature maps, then halve length again.
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=48, kernel_size=5)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x):
        """Map ``(batch, 1, length)`` input to ``(batch, 48 * L')`` features."""
        stage1 = self.pool1(F.relu(self.conv1(x)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        return stage2.flatten(start_dim=1)

class LabelPredictor(nn.Module):
    """MLP head mapping flattened features to one logit per sample.

    Layers: ``Linear(in_features, 100) -> ReLU -> Linear(100, 100) -> ReLU
    -> Linear(100, 1)``.

    Args:
        in_features: Width of the flattened feature vector. The default
            ``48 * 911`` keeps the original hard-coded size (48 channels from
            the conv trunk times, presumably, 911 pooled positions for the
            project's spectrum length — confirm against utils.N_WAVELENGTHS).
            Pass another value when the input length changes.
    """

    def __init__(self, in_features: int = 48 * 911):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        """Return raw logits of shape ``(batch, 1)``; no sigmoid is applied."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class DomainClassifier(nn.Module):
    """Adversarial head predicting which domain a feature vector came from.

    The input first passes through ``rev_grad`` (gradient reversal), so during
    backprop the feature extractor receives the negated gradient of the domain
    loss. That pushes it toward domain-indistinguishable features. Then
    ``Linear(in_features, 100) -> ReLU -> Linear(100, 1)``.

    Args:
        in_features: Width of the flattened feature vector. The default
            ``48 * 911`` keeps the original hard-coded size; pass another
            value when the feature extractor's output width changes.
    """

    def __init__(self, in_features: int = 48 * 911):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 100)
        self.fc2 = nn.Linear(100, 1)

    def forward(self, x):
        """Return raw domain logits of shape ``(batch, 1)``."""
        x = rev_grad(x)  # identity forward, negated gradient backward
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class DANN(nn.Module):
    """Domain-adversarial network: a shared feature extractor plus two heads.

    ``forward`` runs only the label branch (feature extractor -> label
    predictor). The domain classifier is registered as a submodule so its
    parameters appear in ``parameters()``, but it is invoked explicitly on
    extracted features by the training code.
    """

    def __init__(self):
        super().__init__()
        # Submodule registration order fixes parameters()/apply() traversal order.
        self.feature_extractor = FeatureExtractor()
        self.label_predictor = LabelPredictor()
        self.domain_classifier = DomainClassifier()

    def forward(self, x):
        """Return label logits for input ``x`` (domain head not involved)."""
        return self.label_predictor(self.feature_extractor(x))

In [None]:
BATCH_SIZE = 64  # batch size used for BOTH the source and target loaders
SEED = 0

# Seed the RNGs behind weight init and DataLoader shuffling so that
# Restart & Run All reproduces the same run (cuDNN kernels may still be
# nondeterministic on GPU).
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Read both surveys fully into memory, then close the HDF5 files; nothing
# below needs the open handles.
# NOTE(review): both domains use the "*_va" datasets — confirm this is
# intended rather than a training split.
with h5py.File(dataset.SDSS_DATASET, "r") as sdss:
    # Source domain (labelled). Shaped (N, 1, N_WAVELENGTHS) for Conv1d and
    # cast to float32 to match the model's parameter dtype.
    source_X = sdss["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS).astype("f4", copy=False)
    source_y = sdss["y_va"][...].astype("f4")

with h5py.File(dataset.LAMOST_DATASET, "r") as lamost:
    # Target domain: only spectra are used, no labels.
    target_X = lamost["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS).astype("f4", copy=False)

source_tensors = [torch.from_numpy(source_X), torch.from_numpy(source_y)]
source_ds = TensorDataset(*source_tensors)
target_ds = TensorDataset(torch.from_numpy(target_X))

# drop_last keeps every batch at exactly BATCH_SIZE, so each concatenated
# source+target batch is balanced between the two domains.
source_dl = DataLoader(source_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
target_dl = DataLoader(target_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
def init_weights(m):
    """Xavier-uniform init for Conv1d/Linear weights, zeros for their biases.

    Meant to be passed to ``Module.apply``; modules of any other type are
    left untouched.

    Args:
        m: A submodule visited by ``Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        # nn.init functions already run under no_grad, so no `.data` needed.
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
# Gradient-reversal weight from the schedule
#   lam = 2 / (1 + exp(-GAMMA * p)) - 1,
# where p is presumably training progress in [0, 1]. p is pinned at 1 for
# now, which gives lam ~= 0.99991 for the whole run.
GAMMA = 10
p = 1    # TODO compute from training progress
lam = (2 / (1 + np.exp(-GAMMA * p))) - 1

writer = SummaryWriter("runs/dann")
# Fall back to CPU so the notebook still runs on machines without a GPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dann = DANN().to(dev)
# Module.apply visits feature_extractor, label_predictor and domain_classifier
# in registration order. This matches calling .apply on each of them in turn.
dann.apply(init_weights)

opt = optim.Adam(dann.parameters())

# Both heads emit raw logits, so use the numerically stable with-logits BCE.
prediction_loss = F.binary_cross_entropy_with_logits
domain_loss = F.binary_cross_entropy_with_logits

# Global step counter. It lives in this cell so re-running the training cell
# continues the TensorBoard step count instead of restarting at 1.
iterator = count(1)

In [None]:
# Show the gradient-reversal weight in use (GAMMA = 10, p = 1 -> ~0.99991).
lam

In [None]:
# One pass over the shorter of the two loaders. Each step pushes a labelled
# source batch and an unlabelled target batch through the shared feature
# extractor once. The domain head (behind gradient reversal) sees both halves;
# the label head sees only the source half.
for (source_data, source_label), (target_data,) in zip(source_dl, target_dl):
    # Take the step number only after both loaders produced a batch, so the
    # global counter never skips a value when a loader runs out.
    i = next(iterator)

    source_data, source_label = source_data.to(dev), source_label.to(dev)
    target_data = target_data.to(dev)
    n_source, n_target = len(source_data), len(target_data)

    opt.zero_grad()

    # Single shared feature pass over the concatenated source + target batch.
    features = dann.feature_extractor(torch.cat([source_data, target_data]))

    # Domain labels: 0 = source, 1 = target. Built from the actual batch sizes
    # so a short final batch from either loader cannot cause a shape mismatch.
    domain_label = torch.cat([
        torch.zeros(n_source, 1, device=dev),
        torch.ones(n_target, 1, device=dev),
    ])
    domain_pred = dann.domain_classifier(features)
    loss_dom = domain_loss(domain_pred, domain_label)

    # Reuse the source half of the shared features instead of re-running the
    # feature extractor via dann(source_data). The gradients are the same.
    pred = dann.label_predictor(features[:n_source])
    loss_pred = prediction_loss(pred, source_label.unsqueeze(-1))

    loss = loss_pred + lam * loss_dom
    loss.backward()
    opt.step()

    writer.add_scalar("loss/domain", loss_dom.item(), i)
    writer.add_scalar("loss/prediction", loss_pred.item(), i)