In [None]:
from itertools import count

import h5py
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from qso import dataset, utils

In [None]:
# Number of spectra per mini-batch for both data loaders.
BATCH_SIZE = 64

# `dataset[...]` copies the selected HDF5 data into in-memory NumPy arrays, so
# the files are only needed while reading. The context manager closes both
# handles right afterwards instead of leaving them open for the whole kernel
# session.
with h5py.File(dataset.LAMOST_DATASET, "r") as lamost, \
        h5py.File(dataset.SDSS_DATASET, "r") as sdss:
    # Source domain (SDSS): spectra together with their labels.
    source_X = sdss["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS)
    source_y = sdss["y_va"][...].astype("f4")
    # Target domain (LAMOST): spectra only, no labels are read.
    target_X = lamost["X_va"][...].reshape(-1, 1, utils.N_WAVELENGTHS)

# Spectra are reshaped to (n_samples, 1, N_WAVELENGTHS), i.e. one input
# channel, which is the layout the Conv1d layers consume.
source_tensors = list(map(torch.from_numpy, [source_X, source_y]))
source_ds = TensorDataset(*source_tensors)
target_ds = TensorDataset(torch.from_numpy(target_X))

source_dl = DataLoader(source_ds, batch_size=BATCH_SIZE, shuffle=True)
target_dl = DataLoader(target_ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a one-channel 1-D signal to a 100-d code.

    Two "same"-padded convolutions, each followed by a max-pooling layer that
    halves the length, feed two fully connected layers.

    Args:
        n_wavelengths: length of the input signal along the last axis. The
            default (3659) reproduces the original hard-coded flattened size
            of 48 * 914 = 43872 features.

    Input shape:  (batch, 1, n_wavelengths)
    Output shape: (batch, 100)
    """

    def __init__(self, n_wavelengths=3659):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 32, 5, padding=2)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(32, 48, 5, padding=2)
        self.pool2 = nn.MaxPool1d(2, 2)
        # The convolutions keep the length (kernel 5, padding 2); each pooling
        # layer (kernel 2, stride 2) floors it in half.
        pooled_length = n_wavelengths // 2 // 2
        self.fc1 = nn.Linear(48 * pooled_length, 100)
        self.fc2 = nn.Linear(100, 100)

    def forward(self, x):
        """Encode `x`; the final linear layer is returned without activation."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # Collapse (channels, length) into one feature axis, keep the batch axis.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class Predictor(nn.Module):
    """Prediction head: ReLU followed by one linear layer (100 -> 1).

    Input shape:  (batch, 100)
    Output shape: (batch, 1), a raw score with no activation applied.
    """

    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        """Rectify the incoming code, then project it to a single output."""
        return self.fc3(torch.relu(x))

class Decoder(nn.Module):
    """Decoder mapping a 100-d code back to a one-channel 1-D signal.

    Two fully connected layers expand the code to 48 channels, which are then
    upsampled and convolved twice until the requested length is reached.

    Args:
        n_wavelengths: length of the reconstructed signal. The default (3659)
            reproduces the original hard-coded sizes 914 -> 1829 -> 3659.

    Input shape:  (batch, 100)
    Output shape: (batch, 1, n_wavelengths)
    """

    def __init__(self, n_wavelengths=3659):
        super().__init__()
        # Intermediate lengths follow two successive floor-halvings of the
        # output length, so the decoder grows back through the same sizes.
        half_length = n_wavelengths // 2
        self.bottleneck_length = half_length // 2
        self.fc1 = nn.Linear(100, 100)
        self.fc2 = nn.Linear(100, 48 * self.bottleneck_length)
        self.upsample1 = nn.Upsample(size=half_length)
        self.conv1 = nn.Conv1d(48, 32, 5, padding=2)
        self.upsample2 = nn.Upsample(size=n_wavelengths)
        self.conv2 = nn.Conv1d(32, 1, 5, padding=2)

    def forward(self, x):
        """Decode `x`; the final convolution is returned without activation."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Unflatten the feature axis into (channels, length).
        x = x.view(-1, 48, self.bottleneck_length)
        x = self.upsample1(x)
        x = F.relu(self.conv1(x))
        x = self.upsample2(x)
        return self.conv2(x)

class DRCN(nn.Module):
    """Container bundling an encoder, a predictor and a decoder.

    Calling the module runs the prediction path only (encoder, then
    predictor). The decoder is kept as a submodule but is not used by
    `forward`.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.predictor = Predictor()
        self.decoder = Decoder()

    def forward(self, x):
        """Return the predictor's output for the encoded input `x`."""
        code = self.encoder(x)
        return self.predictor(code)

In [None]:
# LAMBDA weights the classification loss and (1 - LAMBDA) the reconstruction
# loss. Typically, the optimal value was in the range [0.4, 0.7].
LAMBDA = 0.55

writer = SummaryWriter("runs/drcn")
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

drcn = DRCN().to(dev)
drcn.encoder.apply(utils.init_weights)
drcn.predictor.apply(utils.init_weights)
drcn.decoder.apply(utils.init_weights)

# Two separate optimizers, one per objective. Both update the encoder, so
# each keeps its own Adam state for the encoder parameters.
opt_c = optim.Adam([
    {"params": drcn.encoder.parameters()},
    {"params": drcn.predictor.parameters()}
])
opt_r = optim.Adam([
    {"params": drcn.encoder.parameters()},
    {"params": drcn.decoder.parameters()}
])

classification_loss = nn.BCEWithLogitsLoss()
reconstruction_loss = nn.MSELoss()

# Global step counters for the TensorBoard scalars. They live in this cell so
# that re-running a training cell continues the numbering rather than
# restarting it at 1.
iterator_c = count(1)
iterator_r = count(1)

In [None]:
# One pass over the labelled source data: update the encoder and predictor on
# the LAMBDA-weighted classification loss.
#
# The data loader is the FIRST argument of zip() on purpose. zip() pulls from
# its arguments left to right, so with the counter first it would advance once
# more before discovering that the loader is exhausted, silently dropping one
# step number at the end of every pass.
for (source_data, source_label), i in zip(source_dl, iterator_c):
    source_data, source_label = source_data.to(dev), source_label.to(dev)
    opt_c.zero_grad()
    pred = drcn.predictor(drcn.encoder(source_data))
    # Labels get a trailing axis so they match the (batch, 1) predictions.
    loss_c = LAMBDA * classification_loss(pred, source_label.unsqueeze(-1))
    loss_c.backward()
    opt_c.step()
    # Log a plain Python float rather than a tensor attached to the graph.
    writer.add_scalar("loss/classification", loss_c.item(), i)

In [None]:
# One pass over the unlabelled target data: update the encoder and decoder on
# the (1 - LAMBDA)-weighted reconstruction loss.
#
# The data loader is the FIRST argument of zip() on purpose. zip() pulls from
# its arguments left to right, so with the counter first it would advance once
# more before discovering that the loader is exhausted, silently dropping one
# step number at the end of every pass.
for (target_data,), i in zip(target_dl, iterator_r):
    target_data = target_data.to(dev)
    opt_r.zero_grad()
    output = drcn.decoder(drcn.encoder(target_data))
    # The input itself is the reconstruction target.
    loss_r = (1 - LAMBDA) * reconstruction_loss(output, target_data)
    loss_r.backward()
    opt_r.step()
    # Log a plain Python float rather than a tensor attached to the graph.
    writer.add_scalar("loss/reconstruction", loss_r.item(), i)