In [None]:
import os

import numpy as np

import torch
import torch.nn.functional as F

# Report the installed PyTorch version, then pick the compute device:
# the first CUDA GPU when one is visible, otherwise the CPU.
print('PyTorch Version: {}'.format(torch.__version__))
if not torch.cuda.is_available():
    print('GPU unavailable')
    device = torch.device('cpu')
else:
    print('GPU available: {}'.format(torch.cuda.get_device_name(0)))
    device = torch.device('cuda:0')

# Model initialization

First, load an existing dictionary matrix (for inference or fine-tuning) or initialize a new one (for training from scratch).

In [None]:
n_atoms = 3800
feature_dim = 1600

# Location of the dictionary file.
#   - If it points at an existing file, that dictionary is loaded.
#   - If it points at a file that does not exist yet, a new dictionary is
#     initialized and saved there so it can be reused later.
#   - If it is left empty, a new dictionary is initialized and NOT saved.
path_to_dictionary = ""

if path_to_dictionary and os.path.isfile(path_to_dictionary):
    # Load existing dictionary; never overwrite it with a random one.
    D = torch.load(path_to_dictionary, map_location=device)
    # The shape should be (n_atoms, feature_dim); fail early if it is not,
    # otherwise the mismatch only surfaces later in `prob_weights @ D`.
    if tuple(D.shape) != (n_atoms, feature_dim):
        raise ValueError(
            'Loaded dictionary has shape {}, expected {}'.format(
                tuple(D.shape), (n_atoms, feature_dim)))
else:
    # Or initialize a new dictionary: Gaussian entries, then every row (atom)
    # rescaled to unit L2 norm.
    D = torch.randn(n_atoms, feature_dim)/np.sqrt(feature_dim)
    D = D/torch.linalg.vector_norm(D, dim=1, keepdim=True)
    if path_to_dictionary:
        # Save the dictionary if you need to use the trained model later
        torch.save(D, path_to_dictionary)

D = D.to(device)

In [None]:
# Build the two halves of the autoencoder and move each onto the selected device.
# The encoder's output size is the number of dictionary atoms; the decoder's
# output size is the feature dimension.
from CAE_components import Encoder, Decoder

encoder = Encoder(out_dim=n_atoms)
encoder = encoder.to(device)
decoder = Decoder(out_dim=feature_dim)
decoder = decoder.to(device)

# Optionally restore previously trained weights at this point:
# encoder.load_state_dict
# decoder.load_state_dict

# Training

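Prepare your own training and testing datasets and dataloaders. Note that the image size should be 250. The training loop below reads each batch as a dictionary and takes the images from its `'img'` entry.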

In [None]:
# Optimizer: Adam with one parameter group per network (encoder, decoder),
# each using a learning rate of 1e-4.
param_groups = [
    {'params': module.parameters(), 'lr': 1e-4}
    for module in (encoder, decoder)
]
optimizer = torch.optim.Adam(param_groups)

# Learning-rate schedule: multiply every group's rate by 0.98 on each scheduler step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

In [None]:
# Suppose that you already have train_dataloader; every batch it yields must be
# a dict whose 'img' entry holds the image tensor.

# Set weights for loss function, they can be sensitive and vary for different datasets
lam1 = 1.0      # weight of the reconstruction (MSE) term
lam2 = 0.0072   # weight of the entropy term
lam3 = 0.00001  # weight of the Dirichlet-style term

total_epoch = 50
encoder.train()
decoder.train()
for epoch in range(1, total_epoch+1):
    for sample in train_dataloader:
        optimizer.zero_grad()
        img = sample['img'].to(device)

        # Encoder output -> combination of dictionary rows -> decoder.
        # D is not handed to the optimizer in this notebook, so it is not updated here.
        prob_weights = encoder(img)
        img_feature = prob_weights@D
        img_reconstruct = decoder(img_feature)

        # Reconstruction error, written in F.mse_loss's (input, target) order.
        loss_reconstruction = F.mse_loss(img_reconstruct, img)

        # Entropy of the encoder output, averaged over the samples actually in
        # this batch. Using the real batch size (instead of a fixed, externally
        # defined batch_size) keeps the scale right when the last batch is smaller.
        # Add a small term to avoid nan
        n_samples = img.shape[0]
        loss_entropy = -torch.sum(prob_weights*torch.log(prob_weights+1e-10))/n_samples

        # Negative sum of the log of the weights averaged over dim 0.
        # NOTE(review): presumably dim 0 is the batch axis, so this penalizes
        # atoms that are rarely used across the batch - confirm against Encoder.
        loss_dirichlet = -torch.sum(torch.log(torch.mean(prob_weights+1e-10, dim=0)))

        loss = lam1*loss_reconstruction + lam2*loss_entropy + lam3*loss_dirichlet

        loss.backward()
        optimizer.step()
    # Decay the learning rate once per epoch.
    scheduler.step()

# Testing (img reconstruction)

In [None]:
# Suppose you have test images tensor: test_img

# Switch both networks to evaluation mode before running inference.
encoder.eval()
decoder.eval()
test_img = test_img.to(device)

# No gradients are needed for reconstruction, so skip building the autograd
# graph; this saves memory and yields a tensor that does not require grad.
with torch.no_grad():
    test_img_reconstruct = decoder(encoder(test_img)@D)

# You can check the MSE loss and plot the reconstructed images

For downstream tasks, discard the decoder and add a classifier head (MLP) to the end of the encoder.