In [3]:
import torch
from torch.utils.data import Dataset
import sys
sys.path.append("../../../../../")

from torchvision import transforms
import graph
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import scipy
import skimage
# import custom functions
import sys
import utils
import os
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import re

import numpy as np
import pandas as pd
import scipy

In [4]:
class CNN_cell512_channel2(nn.Module):
    """Convolutional network for 2-channel image inputs.

    Six convolutions (the first five each followed by ReLU and 2x2 max
    pooling) feed four fully connected layers. The first three linear layers
    form the encoder and end in a 128-dimensional vector; the fourth maps
    that vector to ``size`` outputs.

    ``fc1`` takes 2048 flattened features, which is what the convolutional
    stack yields for 512 x 512 inputs (512 channels x 2 x 2).

    Parameters
    ----------
    size : int
        Number of output units of the final linear layer (``fc4``).
    """

    def __init__(self, size):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable
        # for previously saved checkpoints to load.
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=2, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=1)
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=2, stride=1)
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=size)

    def cnn_encoder(self, x):
        """Return the 128-d embedding: the raw output of ``fc3`` (no activation)."""
        # conv1..conv5 share the same "conv -> ReLU -> max-pool" pattern.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.pool(F.relu(conv(x)))
        # The last convolution is activated but not pooled.
        feats = F.relu(self.conv6(x))
        # Collapse every dimension except the batch dimension.
        feats = feats.flatten(start_dim=1)
        hidden = F.relu(self.fc1(feats))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def forward(self, x):
        """Encode ``x`` and map the ReLU-activated embedding through ``fc4``."""
        embedding = self.cnn_encoder(x)
        return self.fc4(F.relu(embedding))

In [20]:
class SingleCellImageDataset_Stream(Dataset):
    """Dataset that reads one single-cell image per item from disk.

    Item ``i`` is loaded from ``<img_path>/img_{i:05d}.npy`` and paired with
    ``cell_nbhd[i]``.

    Parameters
    ----------
    img_path : str
        Directory containing the ``img_XXXXX.npy`` files.
    train_mask : object
        Accepted for call-site compatibility; not used by this class.
    cell_nbhd : np.ndarray
        Per-sample labels; the first axis indexes samples
        (presumably shape (n_samples, d) -- confirm against the caller).
    use_transform : bool
        When true, the image is converted with ``torch.Tensor`` and passed
        through random rotation / horizontal flip / vertical flip. When
        false, the array is returned exactly as loaded.
    """

    def __init__(self, img_path, train_mask, cell_nbhd, use_transform):
        super().__init__()
        self.img_path = img_path
        self.labels = cell_nbhd
        self.use_transform = use_transform
        # The augmentation pipeline is always built; it is only applied when
        # use_transform is true.
        augmentations = [
            transforms.RandomRotation(degrees=180),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ]
        self.transform = transforms.Compose(augmentations)

    def __len__(self):
        """Number of samples, taken from the first axis of the label array."""
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample stored under ``index``."""
        file_name = f"img_{index:05d}.npy"
        img = np.load(os.path.join(self.img_path, file_name))
        if not self.use_transform:
            return img, self.labels[index]
        return self.transform(torch.Tensor(img)), self.labels[index]

In [21]:
# Folder holding the precomputed arrays for this benchmark.
load_path = '/mnt/cloud1/sheng-projects/st_projects/spatial_clust/spatial-clust-scripts/ipynb/Bokai_reorg/benchmark/spleen/data/'

# cell_nbhd: per-cell label matrix (default setting: res0.5, k20).
# train_mask: per-cell mask, used further down to split rows into train/test.
cell_nbhd, train_mask = (
    np.load(os.path.join(load_path, file_name))
    for file_name in ("cell_nbhd_res0.5_k20.npy", "train_mask.npy")
)

In [22]:
# Sanity check: both arrays should have the same number of rows (one per cell).
print([arr.shape for arr in (cell_nbhd, train_mask)])

[(53500, 17), (53500,)]


In [44]:
# Root folder under which the processed single-cell images are stored.
load_path2 = '/mnt/cloud1/sheng-projects/st_projects/spatial_clust/data/codex_murine/'

# `alpha` fills the "qr{alpha}" part of the image folder name
# (0.6, 0.7 and 0.8 are the other values that were used here).
alpha = 0.9
# `size` fills the "size{size}" part of the image folder name; 512 is the default.
size = 512

# Split the label rows with the mask and its complement.
train_nbhd, test_nbhd = cell_nbhd[train_mask, :], cell_nbhd[~train_mask, :]

image_folder = os.path.join(
    load_path2,
    "processed_data",
    "single_cell_images",
    f"size{size}_qr{alpha}",
)

# Dataset over every cell (full label array), with augmentation disabled.
dataset = SingleCellImageDataset_Stream(
    img_path=os.path.join(image_folder, "images"),
    train_mask=None,
    cell_nbhd=cell_nbhd,
    use_transform=False,
)


In [45]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

batch_size = 512

# Unshuffled loader: batches arrive in dataset order, so batch outputs can be
# written back to consecutive rows.
testloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [46]:
# For each saved checkpoint, run the encoder over every cell image and save
# the resulting embedding matrix as a .npy file.
# `l` and `a` only feed the folder / file names below.
l = 1
for a in [512]:

    # Folder containing the per-epoch checkpoints of the trained network.
    model_save_path = os.path.join(load_path, "cnn", f"cnn_512_l{l}_layer6_testalpha:0.9_checkpoints", "epochs")
    # Output folder for the embeddings; it does not depend on the epoch.
    save_folder = os.path.join(model_save_path, 'embed', f"cnn_{a}_testalpha:0.9_l{l}_layer6_byepoch")

    # NOTE: `size` here is the network's output dimension (number of label
    # columns). It reuses the name of the image-size variable set in the
    # dataset cell; the name is kept so later cells see the same value as before.
    size = cell_nbhd.shape[1]
    cnn_net = CNN_cell512_channel2(size).to(device)
    # Inference only: switch to eval mode so any train-time-only layers would
    # be disabled (a no-op for layers without such behaviour).
    cnn_net.eval()

    # Width of the encoder output, read from the model instead of hard-coding it.
    embed_dim = cnn_net.fc3.out_features
    data_size = cell_nbhd.shape[0]

    for epoch in tqdm([100, 200, 300, 400]):
        checkpoint_file = os.path.join(model_save_path, f"epoch{epoch}_model_weights.pth")
        # map_location lets checkpoints saved from a GPU load on whatever
        # device was selected (including the CPU fallback).
        cnn_net.load_state_dict(torch.load(checkpoint_file, map_location=device))

        cnn_embedding = np.zeros((data_size, embed_dim))

        # The loader is unshuffled, so batch rows map to consecutive rows of
        # the embedding matrix.
        start_idx = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs = inputs.to(device).to(torch.float32)
                outputs = cnn_net.cnn_encoder(inputs)
                stop_idx = start_idx + inputs.shape[0]
                # Convert explicitly to NumPy before writing into the array.
                cnn_embedding[start_idx:stop_idx] = outputs.cpu().numpy()
                start_idx = stop_idx

        # Every row of the embedding matrix must have been filled exactly once.
        assert start_idx == data_size, f"embedded {start_idx} samples, expected {data_size}"

        # exist_ok avoids the separate existence check.
        os.makedirs(save_folder, exist_ok=True)
        np.save(os.path.join(save_folder, f"cnn_embedding_{a}_full_l{l}_dim{embed_dim}_epoch{epoch}.npy"), cnn_embedding)

100%|█████████████████████████████████████████| 4/4 [01:24<00:00, 21.06s/it]
