In [1]:
import os

import torch
import torch.nn as nn
from torchvision.models import resnet18

from models import (
    NeuralImageCompressor,
    create_resnet_autoencoder,
    create_resnet_autoencoder_abs,
    SimpleResidualDecoder32x_ABS)


def save_separate_weights(
        autoencoder: "NeuralImageCompressor",
        full_weights_path: str,
        prefix: str = "full__") -> None:
    """Split a combined autoencoder checkpoint into encoder and decoder files.

    Loads the state dict stored at ``full_weights_path`` into ``autoencoder``
    (mutating it in place). It then writes ``autoencoder.encoder.state_dict()``
    and ``autoencoder.decoder.state_dict()`` into the same directory as the
    source file, named ``encoder__<basis>`` and ``decoder__<basis>``.
    ``<basis>`` is the source file name with a leading ``prefix`` removed, if
    that prefix is present; otherwise it is the file name unchanged.

    Args:
        autoencoder: Model whose state-dict layout matches the checkpoint.
            It must expose ``.encoder`` and ``.decoder`` submodules.
        full_weights_path: Path to the combined checkpoint, e.g.
            ``./weights/full__<name>.pt``.
        prefix: File-name marker of a combined checkpoint, stripped from the
            output names. Defaults to ``"full__"``.

    Raises:
        FileNotFoundError: If ``full_weights_path`` does not exist (raised by
            ``torch.load``).
        RuntimeError: If the checkpoint keys do not match ``autoencoder``
            (``load_state_dict`` uses its default ``strict=True``).
    """
    weights_dir, file_name = os.path.split(full_weights_path)

    # "full__X.pt" -> basis "X.pt"; the length comes from the prefix itself
    # instead of a hard-coded slice index.
    if prefix and file_name.startswith(prefix):
        file_name_basis = file_name[len(prefix):]
    else:
        file_name_basis = file_name

    # NOTE: depending on the installed torch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects. Only load trusted checkpoints.
    state_dict = torch.load(full_weights_path, map_location=torch.device('cpu'))
    autoencoder.load_state_dict(state_dict)

    submodules = (
        (autoencoder.encoder, "encoder"),
        (autoencoder.decoder, "decoder"),
    )
    for model, model_type in submodules:
        torch.save(
            model.state_dict(),
            os.path.join(weights_dir, f"{model_type}__{file_name_basis}"))

In [2]:
# Hyper-parameters for the residual decoder used by the first autoencoder.
# 512 presumably matches the encoder's output channel count (cf. the
# "512x16x16" checkpoint names) — TODO confirm against models.py.
up_func_name = "upsample"
decoder_in_channels = 512
last_decoder_activation = nn.ReLU()

decoder = SimpleResidualDecoder32x_ABS(
    decoder_in_channels,
    last_activation=last_decoder_activation,
    up_func_name=up_func_name,
)

In [17]:
# Combined checkpoint to split into encoder/decoder files.
weights = "./weights/full__resnet_autoencoder__512x16x16__upsample__B_6__66_epochs__2023-06-10T00_39.pt"

# ResNet-18 based autoencoder with nn.Identity() as enc_feat_extract and the
# decoder built in the previous cell.
autoencoder = create_resnet_autoencoder(
    resnet18(),
    enc_feat_extract=nn.Identity(),
    decoder=decoder,
)

In [18]:
save_separate_weights(autoencoder=autoencoder, full_weights_path=weights)

In [2]:
# Combined checkpoint of the "abs" autoencoder variant (last ReLU on decoder).
weights = "./weights/resnet_autoencoder_abs/full__resnet_autoencoder__512x16x16__upsample__B_6__6_epochs__last_relu__2023-06-18T17_42.pt"

autoencoder = create_resnet_autoencoder_abs(
    resnet18(),
    enc_feat_extract=nn.Identity(),
    last_decoder_activation=nn.ReLU(inplace=True),
)

In [3]:
save_separate_weights(autoencoder=autoencoder, full_weights_path=weights)