In [None]:
from vacation.model import VCNN
from vacation.data import GalaxyDataset
import numpy as np
import torch

# One seed for every RNG in this notebook so Restart & Run All reproduces results.
SEED = 1337
# Seed torch too; numpy's Generator below does not affect torch. Presumably VCNN's
# weight initialisation / batch shuffling draws from torch's RNG — confirm.
torch.manual_seed(SEED)
# Used below to draw the train/valid index subsets.
rng = np.random.default_rng(SEED)

In [None]:
# Training subset: 6000 distinct images drawn from the train file.
# Sample WITHOUT replacement — `rng.integers` would draw with replacement and
# produce ~950 duplicate indices in expectation, shrinking the effective subset
# and over-weighting the repeated images.
# NOTE(review): 16813 is presumably the number of images in the train file — confirm.
train_ds = GalaxyDataset(
    path="/scratch/tgross/vacation_data/reduced_size/Galaxy10_DECals_train.h5",
    device="cuda:1",
    max_cache_size="14G",
    cache_loaded=True,
    index_collection=rng.choice(16813, size=6000, replace=False),
)

In [None]:
# Validation subset: 1000 distinct images drawn from the valid file.
# Sample WITHOUT replacement — `rng.integers` would draw with replacement and
# produce ~110 duplicate indices in expectation.
# NOTE(review): 4204 is presumably the number of images in the valid file — confirm.
valid_ds = GalaxyDataset(
    path="/scratch/tgross/vacation_data/reduced_size/Galaxy10_DECals_valid.h5",
    device="cuda:1",
    max_cache_size="5G",
    cache_loaded=True,
    index_collection=rng.choice(4204, size=1000, replace=False),
)

In [None]:
# train_ds.plot_distribution()
# valid_ds.plot_distribution()

In [None]:
def calculate_size(dim, kernel_size, padding, stride):
    """Return the spatial output size of a convolution or pooling layer.

    Evaluates ``(dim - kernel_size + 2 * padding) / stride + 1`` with true
    division and does not floor the result. A fractional result therefore
    means the kernel/stride combination does not tile the input exactly;
    ``calculate_network`` uses that to reject such layers.

    Args:
        dim: Input height/width.
        kernel_size: Kernel size along that dimension.
        padding: Zero padding added on each side.
        stride: Stride of the layer.

    Returns:
        The output size (a float for numeric inputs).
    """
    return ((dim - kernel_size + 2 * padding) / stride) + 1


def calculate_network(
    input_dim,
    num_conv_blocks,
    conv_kernel_size,
    conv_padding,
    conv_stride,
    pool_kernel_size,
    pool_padding,
    pool_stride,
):
    """Trace the spatial image size through a stack of conv + pool blocks.

    Each block is one convolution followed by one pooling layer, all blocks
    sharing the same kernel/padding/stride settings. The sizes after every
    block are printed and validated.

    Args:
        input_dim: Height/width of the (square) input image.
        num_conv_blocks: Number of conv + pool blocks.
        conv_kernel_size, conv_padding, conv_stride: Convolution settings.
        pool_kernel_size, pool_padding, pool_stride: Pooling settings.

    Returns:
        list: Sizes after every layer, alternating post-conv and post-pool
        values (``2 * num_conv_blocks`` entries; empty for zero blocks).

    Raises:
        ValueError: If a kernel is larger than the padded input of its layer,
            or if any resulting size is not a whole number.
    """
    sizes = []
    dim = input_dim
    for i in range(num_conv_blocks):
        # Padding enlarges the effective input, so the kernel only fails to
        # fit when it exceeds the padded size.
        if dim + 2 * conv_padding < conv_kernel_size:
            raise ValueError(
                f"The image size after layer {i} is smaller than the convolution kernel!"
            )

        conv_dim = calculate_size(
            dim=dim,
            kernel_size=conv_kernel_size,
            padding=conv_padding,
            stride=conv_stride,
        )
        sizes.append(conv_dim)

        if conv_dim + 2 * pool_padding < pool_kernel_size:
            raise ValueError(
                f"The image size after the convolution of layer {i+1} is smaller than the pooling kernel!"
            )

        pool_dim = calculate_size(
            dim=conv_dim,
            kernel_size=pool_kernel_size,
            padding=pool_padding,
            stride=pool_stride,
        )
        sizes.append(pool_dim)

        print("-------------- Layer", (i + 1), "--------------")
        print(f"POST-CONV: {conv_dim} | POST-POOL: {pool_dim}")

        if not pool_dim.is_integer() or not conv_dim.is_integer():
            raise ValueError(
                f"An image size after layer {i+1} is not an integer value!"
            )

        # The next block consumes this block's pooled output.
        dim = pool_dim

    return sizes

In [None]:
# Sanity-check the spatial sizes through the conv/pool stack for 128x128 inputs.
# The pooling settings mirror the model's pool_kernel_args (kernel 2, padding 1, stride 2).
# NOTE(review): this traces 5 blocks, while the VCNN below is built with
# num_conv_blocks=6 — confirm which depth is intended.
s1_1 = calculate_network(
    input_dim=128,
    num_conv_blocks=5,
    conv_kernel_size=3,
    conv_padding=0,
    conv_stride=1,
    pool_kernel_size=2,
    pool_padding=1,
    pool_stride=2,
)

In [None]:
# Six conv blocks, one dense layer of 300 units, no dropout anywhere;
# AdamW with decoupled weight decay and cross-entropy loss.
# NOTE(review): out_channels starts at 1 while the input images have 3 channels
# (see the summarize input dims) — presumably this is the first conv's output
# width; confirm the 1-channel bottleneck is intended.
model = VCNN(
    train_batch_size=32,  # 2**5
    valid_batch_size=8,  # 2**3
    num_conv_blocks=6,
    num_dense_layers=1,
    out_channels=[1, 11, 12, 12, 12, 12],
    conv_dropout_rates=[0.0] * 6,
    lin_out_features=[300],
    lin_dropout_rates=[0.0],
    optimizer=torch.optim.AdamW,
    activation_func=torch.nn.PReLU,
    learning_rate=0.001,
    weight_decay=0.01,
    loss_func=torch.nn.CrossEntropyLoss,
    pool_kernel_args={"kernel_size": 2, "padding": 1, "stride": 2},
    device="cuda:1",
)
model.init_data(train_dataset=train_ds, valid_dataset=valid_ds)

In [None]:
# Presumably (batch, channels, height, width): train batch size 2**5, 3 channels, 128x128.
model.summarize(input_dims=(2**5, 3, 128, 128))

In [None]:
# Train for 40 epochs on the datasets attached via model.init_data.
model.train_epochs(n_epochs=40)

In [None]:
model.save_state(".models/model2.pt", relative_to_package=True)

In [None]:
# Reload the saved model; the optimizer/activation/loss classes are passed again,
# matching the ones used to construct `model`.
# NOTE(review): the state was saved as ".models/model2.pt" with
# relative_to_package=True, but is loaded here as "model2.pt" without that flag —
# confirm VCNN.load resolves this to the same file.
model1 = VCNN.load(
    "model2.pt",
    optimizer=torch.optim.AdamW,
    activation_func=torch.nn.PReLU,
    loss_func=torch.nn.CrossEntropyLoss,
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.plot(model1._metrics["accuracy"].train_vals, label="Train")
plt.plot(model1._metrics["accuracy"].valid_vals, label="Valid")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()

In [None]:
# Train vs. validation loss per epoch. Labelled and captioned the same way as
# the accuracy plot so the two curves can be told apart.
plt.plot(model1._loss_metric.train_vals, label="Train")
plt.plot(model1._loss_metric.valid_vals, label="Valid")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()