In [None]:
import numpy as np
import torch 
import matplotlib.pyplot as plt
import sklearn
import torchvision
from torchvision import datasets
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import pandas as pd
import random
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from tqdm import tqdm
import wandb
import scipy.ndimage

In [None]:
# Sequence-length setting. It is assigned the same value again in the model setup cell and is
# not read by any visible module-level code.
# NOTE(review): presumably the number of frames per input sequence — confirm.
seq_len = 6

# Feature extraction from images

Features are extracted from the spatially resolved image data with the encoder part of a pre-implemented convolutional autoencoder (CNN), which yields a low-dimensional embedding.

1) Reload the pre-trained encoder part.
2) Feed every sequence through the encoder to obtain a low-dimensional embedding of each sequence.
3) Store the low-dimensional embedding of each sequence in a tensor dataset that can serve as input to the LSTM.
 

In [None]:
## Construction of encoder part
class Encoder_original(nn.Module):
    """Convolutional encoder: three strided convolutions followed by a two-layer linear head.

    The linear head expects 3 * 3 * 32 flattened features, which is what the
    convolution stack produces for a single-channel 28x28 input.

    Args:
        encoded_space_dim: Size of the embedding vector returned by ``forward``.
        fc2_input_dim: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()

        # Convolutional feature extractor (layers are created in the same order as they run)
        conv_layers = [
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(inplace=True),
        ]
        self.encoder_cnn = nn.Sequential(*conv_layers)

        # Collapse (channels, height, width) into one feature axis, keeping the batch axis
        self.flatten = nn.Flatten(start_dim=1)

        # Fully connected head mapping the flattened features to the embedding
        dense_layers = [
            nn.Linear(in_features=3 * 3 * 32, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=encoded_space_dim),
        ]
        self.encoder_lin = nn.Sequential(*dense_layers)

    def forward(self, x):
        """Encode a batch of images into vectors of length ``encoded_space_dim``."""
        return self.encoder_lin(self.flatten(self.encoder_cnn(x)))

In [None]:
# Load the pre-built train/test sequence datasets from files in the working directory.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load files from a trusted source.
# NOTE(review): the training loop unpacks 7 values per batch, so each item is presumably a
# 7-field sample — confirm against the script that created these files.
dataset_train = torch.load("train_seq_data.pt")
dataset_test = torch.load("test_seq_data.pt")

In [None]:
# Wrap the datasets in batch iterators (128 samples per training batch, 1024 per test batch).
# NOTE(review): shuffle=False is also used for training, so every epoch visits the batches in
# the same order — confirm this is intentional.
trainloader = DataLoader(dataset=dataset_train, batch_size=128, shuffle=False)
testloader = DataLoader(dataset=dataset_test, batch_size=1024, shuffle=False)


## Low-dimensional embedding with LSTM

In [None]:
import torch
import torch.nn as nn

class Simple_LSTM(nn.Module):
    """Per-time-step linear projection (with ReLU and dropout) followed by a stacked LSTM.

    Args:
        lin_in: Number of input features per time step.
        lin_out: Number of features produced by the linear projection.
        lstm_in: Input size of the LSTM (the caller is responsible for matching ``lin_out``).
        lstm_hidden_size: Hidden size of every LSTM layer.
        lstm_num_layers: Number of stacked LSTM layers.
        dropout_prob: Dropout probability applied after the projection.
    """

    def __init__(self, lin_in, lin_out, lstm_in, lstm_hidden_size, lstm_num_layers, dropout_prob):
        super().__init__()
        self.linear = nn.Linear(in_features=lin_in, out_features=lin_out)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_prob)
        # batch_first=True: inputs and outputs are laid out as [batch, sequence, features]
        self.lstm = nn.LSTM(
            input_size=lstm_in,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x):
        """Project ``x`` and run it through the LSTM.

        Returns:
            The tuple produced by ``nn.LSTM``: ``(output, (h_n, c_n))``. Both the
            per-step outputs and the final hidden/cell states are returned.
        """
        # nn.Linear acts on the last axis only, so the sequence structure is preserved
        projected = self.dropout(self.activation(self.linear(x)))
        return self.lstm(projected)


## Prediction of next image in sequence with Decoder CNN

In [None]:
## Construction of decoder part
class Decoder_original(nn.Module):
    """Decoder that maps an embedding vector to a single-channel image with values in [0, 1].

    A two-layer linear stack expands the embedding to 3 * 3 * 32 features, which
    are reshaped to (32, 3, 3) and upsampled by three transposed convolutions.

    Args:
        encoded_space_dim: Length of the input embedding vector.
        fc2_input_dim: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()

        # Fully connected expansion of the embedding
        dense_layers = [
            nn.Linear(in_features=encoded_space_dim, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=3 * 3 * 32),
            nn.ReLU(inplace=True),
        ]
        self.decoder_lin = nn.Sequential(*dense_layers)

        # Reshape the flat feature vector into a (channels, height, width) map
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Transposed convolutions that upsample the feature map to the output image
        deconv_layers = [
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3,
                               stride=2, output_padding=0),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(num_features=8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
        ]
        self.decoder_conv = nn.Sequential(*deconv_layers)

    def forward(self, x):
        """Decode a batch of embeddings; the sigmoid squashes pixel values into [0, 1]."""
        feature_map = self.unflatten(self.decoder_lin(x))
        return torch.sigmoid(self.decoder_conv(feature_map))

In [None]:
def plot_comparison(img_pred, img_orig):
    """Show an original image and its prediction side by side.

    Args:
        img_pred: Tensor holding the predicted image; singleton axes are squeezed away.
        img_orig: Tensor holding the ground-truth image; singleton axes are squeezed away.
    """
    # .cpu() is required before .numpy() when the tensors still live on the GPU;
    # it is a no-op for tensors that are already on the CPU.
    img_pred_np = img_pred.squeeze().detach().cpu().numpy()
    img_orig_np = img_orig.squeeze().detach().cpu().numpy()

    # Plot images side-by-side: original on the left, prediction on the right
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(img_orig_np, cmap='gray')
    ax[0].set_title('Original')
    ax[1].imshow(img_pred_np, cmap='gray')
    ax[1].set_title('Predicted')
    plt.show()


In [None]:
import random
import matplotlib.pyplot as plt

def plot_random_predictions(conc_label, conc_target, conc_out):
    """Plot one randomly chosen target/prediction pair for each label 0..9.

    The top row shows the target images, the bottom row the corresponding
    predictions. A label without any sample gets an empty column instead of
    raising an error.

    Args:
        conc_label: Tensor of labels, one per sample.
        conc_target: Tensor of target images, indexed like ``conc_label``.
        conc_out: Tensor of predicted images, indexed like ``conc_label``.

    Returns:
        The matplotlib figure containing the 2 x 10 image grid.
    """
    num_labels = 10
    fig, axes = plt.subplots(2, num_labels, figsize=(20, 4))

    for label_value in range(num_labels):
        target_ax = axes[0][label_value]
        pred_ax = axes[1][label_value]
        target_ax.set_title(f'Label {label_value}')

        # Indices of all samples carrying the current label
        candidates = (conc_label == label_value).nonzero(as_tuple=True)[0]
        if len(candidates) == 0:
            # random.randint(0, -1) would raise ValueError; leave this column blank instead
            target_ax.axis('off')
            pred_ax.axis('off')
            continue

        # Pick one sample of this label at random and show target and prediction for it
        sample_idx = candidates[random.randint(0, len(candidates) - 1)]
        # .cpu() keeps the conversion valid for tensors that still live on the GPU
        target_ax.imshow(conc_target[sample_idx].squeeze().detach().cpu().numpy(), cmap='gray')
        pred_ax.imshow(conc_out[sample_idx].squeeze().detach().cpu().numpy(), cmap='gray')

    # Set the title for the plot
    fig.suptitle('Randomly chosen original images and their corresponding predicted images')

    # Show the plot
    plt.show()
    return fig


In [None]:
## Implementation of training and evaluation function

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# The model setup cell assigns `device` again as a torch.device object.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
def train_img_prediction(lstm, decoder, device, dataloader, loss_fn, optimizer):
    """Train the LSTM and the decoder for one pass over ``dataloader``.

    Each batch must unpack into seven items; only the first (input sequences)
    and the second (target images) are used here.

    Args:
        lstm: Module returning ``(output, (h_n, c_n))`` like ``nn.LSTM``.
        decoder: Module mapping the last layer's final hidden state to an image batch.
        device: Device the input and target batches are moved to.
        dataloader: Iterable yielding 7-item batches.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: Optimizer holding the parameters of ``lstm`` and ``decoder``.

    Returns:
        The mean of the per-batch losses (``np.mean`` over the collected values).
    """
    # Set train mode for both the LSTM and the decoder
    lstm.train()
    decoder.train()
    batch_losses = []
    for image_batch, target_batch, _, _, _, _, _ in dataloader:
        # Move tensors to the proper device
        image_batch = image_batch.to(device)
        target_batch = target_batch.to(device)

        # Run the sequence model; only the final hidden states are needed
        _, (h_n, _) = lstm(image_batch)
        # h_n is indexed by layer first, so h_n[-1] is the final hidden state of the
        # top layer for any number of stacked layers (a fixed index of 2 only works
        # for exactly three layers).
        embedded_data = h_n[-1]

        # Decode the embedding into the predicted next image
        pred_img_batch = decoder(embedded_data)

        # Evaluate loss
        loss = loss_fn(pred_img_batch.float(), target_batch.float())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.detach().cpu().numpy())

    return np.mean(batch_losses)

In [None]:
## Evaluation of model on test dataset
def evaluate_img_prediction(lstm, decoder, device, dataloader, loss_fn):
    """Evaluate the LSTM/decoder pair on ``dataloader`` and plot sample predictions.

    Each batch must unpack into seven items; the first (input sequences), second
    (target images) and third (labels) are used here.

    Args:
        lstm: Module returning ``(output, (h_n, c_n))`` like ``nn.LSTM``.
        decoder: Module mapping the last layer's final hidden state to an image batch.
        device: Device the input and target batches are moved to.
        dataloader: Iterable yielding 7-item batches.
        loss_fn: Callable ``loss_fn(prediction, target)``.

    Returns:
        ``(val_loss, rec_fig)``: the loss computed over all predictions at once, and
        the figure returned by ``plot_random_predictions``.
    """
    # Set evaluation mode for the LSTM and the decoder
    lstm.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Per-batch outputs, collected on the CPU
        conc_out = []
        conc_target = []
        conc_label = []
        for image_batch, target_batch, label, _, _, _, _ in dataloader:
            # Move tensors to the proper device
            image_batch = image_batch.to(device)
            target_batch = target_batch.to(device)
            # Run the sequence model; only the final hidden states are needed
            _, (h_n, _) = lstm(image_batch)
            # h_n[-1] is the top layer's final hidden state for any layer count
            # (a fixed index of 2 only works for exactly three layers).
            embedded_data = h_n[-1]
            # Decode data
            pred_img_batch = decoder(embedded_data)
            # Append the network output, the target image and the label to the lists
            conc_out.append(pred_img_batch.cpu())
            conc_target.append(target_batch.cpu())
            conc_label.append(label.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_target = torch.cat(conc_target)
        conc_label = torch.cat(conc_label)

    # Evaluate the global loss over the whole dataset
    val_loss = loss_fn(conc_out.float(), conc_target.float())

    # Plot one random target/prediction pair per label
    rec_fig = plot_random_predictions(conc_out=conc_out, conc_target=conc_target, conc_label=conc_label)

    return val_loss, rec_fig


In [None]:
### Define the loss function
loss_fn = torch.nn.MSELoss()

### Learning rate of the optimizer (shared by the LSTM and the decoder)
lr= 0.001

### Set the random seeds for reproducible results
# torch.manual_seed only seeds PyTorch. The prediction plots pick samples with Python's
# `random` module, so that generator (and NumPy's) is seeded as well.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

### Dimensions of the input data
d = 25
seq_len = 6

### Define weight decay
weight_decay = 0

### Define model parameters
lin_in = d
lin_out = 128
hidden_size = 25
stack = 3

### Initialize the two networks
# The decoder consumes the LSTM's final hidden state, so its input size is hidden_size.
lstm_model = Simple_LSTM(lin_in=lin_in, lin_out=lin_out, lstm_in=lin_out, lstm_hidden_size=hidden_size, lstm_num_layers=stack, dropout_prob=0)
decoder_prediction = Decoder_original(encoded_space_dim=hidden_size,fc2_input_dim=128)
params_to_optimize = [
    {'params': lstm_model.parameters()},
    {'params': decoder_prediction.parameters()}
]

# NOTE(review): this name shadows the `optim` module imported from torch at the top of the
# notebook. It is kept because later cells read `optim`.
optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=weight_decay)

# Check if the GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Selected device: {device}')
# Move both the decoder and the LSTM to the selected device
decoder_prediction.to(device)
lstm_model.to(device)


In [None]:
num_epochs = 12  # number of passes over the training set
diz_loss = {'train_loss': [], 'val_loss': []}  # per-epoch training and evaluation loss
for epoch in range(num_epochs):
    # One optimisation pass over the training data
    train_loss = train_img_prediction(lstm_model, decoder_prediction, device, trainloader, loss_fn, optim)
    # Evaluate on the test set; the returned figure is not needed here
    val_loss, _ = evaluate_img_prediction(lstm_model, decoder_prediction, device, testloader, loss_fn)
    print(f'\n EPOCH {epoch + 1}/{num_epochs} \t train loss {train_loss} \t val loss {val_loss}')

    diz_loss['train_loss'].append(train_loss)
    diz_loss['val_loss'].append(val_loss)

In [None]:
def plot_loss(diz_loss):
    """Plot the training and validation loss per epoch and return the figure.

    Args:
        diz_loss: Dict with the keys 'train_loss' and 'val_loss', each holding
            one value per epoch.

    Returns:
        The matplotlib figure containing both curves.
    """
    train_curve = diz_loss['train_loss']
    val_curve = diz_loss['val_loss']
    # Epochs are numbered from 1 on the x-axis
    epoch_axis = range(1, len(train_curve) + 1)

    fig, ax = plt.subplots()
    ax.plot(epoch_axis, train_curve, 'g', label='Training loss')
    ax.plot(epoch_axis, val_curve, 'b', label='Validation loss')
    ax.set_title('Training and Validation loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()

    return fig


plt_loss = plot_loss(diz_loss)

In [None]:
# Log the loss-curve figure under the key "Plot_loss".
# NOTE(review): `run` is not defined in any visible cell — presumably `run = wandb.init(...)`
# lived in a cell that is no longer part of the notebook. As written this fails on a fresh
# kernel; TODO confirm and restore the initialisation.
run.log({"Plot_loss": plt_loss})

In [None]:
# Evaluate the trained model once more on the test set and keep the prediction figure.
val_loss, rec_fig = evaluate_img_prediction(lstm=lstm_model,decoder=decoder_prediction,device=device,dataloader=testloader, loss_fn=loss_fn)

In [None]:
# Log the figure with the sample predictions under the key "Pred_Img".
# NOTE(review): `run` is not defined in any visible cell (presumably a wandb run) — TODO confirm.
run.log({"Pred_Img": rec_fig})

In [None]:
# Collect the state of the LSTM, the decoder and the optimizer
# (`optim` here is the Adam instance created in the setup cell, not the torch.optim module).
lstm_state_dict = lstm_model.state_dict()
decoder_state_dict = decoder_prediction.state_dict()
optimizer_state_dict = optim.state_dict()

# Combine the model and optimizer state dictionaries into a single checkpoint dictionary
state_dict = {'lstm': lstm_state_dict,'decoder': decoder_state_dict, 'optimizer': optimizer_state_dict}

# Destination of the checkpoint file.
# NOTE(review): hard-coded absolute path (looks like a mounted Google Drive in Colab); the
# save fails anywhere this directory does not exist — consider a configurable output directory.
PATH = "/content/drive/MyDrive/Healing_MNIST/CNN_LSTM_const_seq/model_const_rand_rot_rand_square_UNIDIRECTIONAL.pth"

# Save the combined state dictionary to the specified path
torch.save(state_dict, PATH)


In [None]:
def get_embeddings(data, device, lstm):
    """Compute the LSTM embedding of every sample and collect it with the sample's metadata.

    Each sample must be indexable with at least seven fields: field 0 is the model
    input, fields 2-6 are stored as 'label', 'rotation', 'square', 'square_count'
    and 'seq_len'.
    NOTE(review): samples are passed to the LSTM one at a time without adding a batch
    axis, so field 0 is presumably a single unbatched sequence — confirm.

    Args:
        data: Iterable of samples (e.g. a dataset).
        device: Device the model input is moved to.
        lstm: Module returning ``(output, (h_n, c_n))`` like ``nn.LSTM``.

    Returns:
        A pandas DataFrame with one row per sample: the columns "Enc. Variable i"
        hold the embedding, followed by the metadata columns.
    """
    # Switching to eval mode and disabling gradients is loop-invariant, so do it once
    lstm.eval()
    encoded_samples = []
    with torch.no_grad():
        for sample in tqdm(data):
            sequence = sample[0].to(device)
            # Only the final hidden states are needed; h_n[-1] belongs to the top layer
            _, (h_n, _) = lstm(sequence)
            encoded_img = h_n[-1].cpu().numpy()

            # One column per embedding component, then the metadata of the sample
            encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
            encoded_sample['label'] = sample[2].numpy()
            encoded_sample['square'] = sample[4].numpy()
            encoded_sample['rotation'] = sample[3].numpy()
            encoded_sample['square_count'] = sample[5].numpy()
            encoded_sample['seq_len'] = sample[6].numpy()
            encoded_samples.append(encoded_sample)
    return pd.DataFrame(encoded_samples)

In [None]:
# Embed every training sample with the trained LSTM (one DataFrame row per sample).
train_embeddings = get_embeddings(data=dataset_train, device=device, lstm=lstm_model)

In [None]:
# Embed every test sample with the trained LSTM (one DataFrame row per sample).
test_embeddings = get_embeddings(data=dataset_test, device=device, lstm=lstm_model)

In [None]:
from sklearn.manifold import TSNE
import plotly.io as pio
import plotly.express as px

# Project the test-set embeddings to 2-D with t-SNE (metadata columns are dropped first).
# t-SNE is stochastic; fixing random_state makes the projection, and every scatter plot
# built from `tsne_results`, reproducible across runs.
tsne = TSNE(n_components=2, random_state=0)
tsne_results = tsne.fit_transform(test_embeddings.drop(['label', 'square', 'rotation', 'seq_len', 'square_count'],axis=1))
# Scatter plot coloured by rotation, with the marker symbol encoding the label
fig_tsne_rot = px.scatter(tsne_results, x=0, y=1,
                 color=test_embeddings.rotation.astype(str),
                 symbol=test_embeddings.label.astype(str),
                 #symbol_sequence=['circle', 'cross'],
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig_tsne_rot.show()

In [None]:
# Same t-SNE projection, coloured by label, with the marker symbol encoding the rotation.
fig_tsne_label = px.scatter(tsne_results, x=0, y=1,
                 color=test_embeddings.label.astype(str),
                 symbol=test_embeddings.rotation.astype(str),
                 #symbol_sequence=['circle', 'cross'],
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig_tsne_label.show()

In [None]:
# Same t-SNE projection, coloured by the `square` column of the test embeddings.
fig_tsne_square = px.scatter(tsne_results, x=0, y=1,
                 color=test_embeddings.square.astype(str),
                 #symbol=test_embeddings.label.astype(str),
                 #symbol_sequence=['circle', 'cross'],
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig_tsne_square.show()

In [None]:
# Same t-SNE projection, coloured by the `square_count` column of the test embeddings.
fig_tsne_square_count = px.scatter(tsne_results, x=0, y=1,
                 color=test_embeddings.square_count.astype(str),
                 #symbol=test_embeddings.label.astype(str),
                 #symbol_sequence=['circle', 'cross'],
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig_tsne_square_count.show()

In [None]:
# Same t-SNE projection, coloured by the `seq_len` column of the test embeddings.
fig_tsne_seq = px.scatter(tsne_results, x=0, y=1,
                 color=test_embeddings.seq_len.astype(str),
                 #symbol=test_embeddings.label.astype(str),
                 #symbol_sequence=['circle', 'cross'],
                 labels={'0': 'tsne-2d-one', '1': 'tsne-2d-two'})
fig_tsne_seq.show()