In [22]:
!git clone https://github.com/HUANGLIZI/LViT.git

Cloning into 'LViT'...
remote: Enumerating objects: 323, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 323 (delta 65), reused 61 (delta 61), pack-reused 253 (from 1)[K
Receiving objects: 100% (323/323), 92.10 MiB | 29.88 MiB/s, done.
Resolving deltas: 100% (123/123), done.
Updating files: 100% (121/121), done.


In [23]:
!pip install ml_collections




In [24]:
# Make the cloned repository the working directory for the following cells.
%cd LViT


/LViT


In [25]:
# Step into the repository's datasets/ folder.
# NOTE(review): the Config paths defined further down are written as
# './datasets/MoNuSeg/...', which looks relative to the repository root rather
# than to this folder - confirm which working directory is intended.
%cd datasets

/LViT/datasets


In [26]:
!pip install transformers




In [27]:
import os
import torch
import time
import ml_collections

# Select the GPU and seed every random number generator used by the notebook.
import random

import numpy as np

# Expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
use_cuda = torch.cuda.is_available()

seed = 666
# PYTHONHASHSEED is read at interpreter start-up, so this only influences
# Python processes launched after this point, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
# Previously `seed` was defined but never applied; seed the generators that the
# augmentation code (random / np.random) and the model code (torch) draw from.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


In [28]:
import os

class Config:
    """Plain namespace of experiment settings; every value is a class attribute."""

    # --- run identity ---
    task_name = 'MoNuSeg'
    model_name = 'LViT'  # or 'LViT_pretrain' if using a pretrained model
    session_name = 'MoNuSeg_LViT_Session'
    seed = 42

    # --- optimisation schedule ---
    batch_size = 4
    epochs = 100
    learning_rate = 1e-4
    cosineLR = True
    early_stopping_patience = 10

    # --- input / output tensor layout ---
    img_size = 224
    n_channels = 3
    n_labels = 1  # Adjust based on your dataset

    # --- dataset folders ---
    train_dataset = './datasets/MoNuSeg/Train_Folder/'
    val_dataset = './datasets/MoNuSeg/Val_Folder/'
    test_dataset = './datasets/MoNuSeg/Test_Folder/'

    # --- artefacts written by a run ---
    save_path = './models/'
    tensorboard_folder = './tensorboard_logs/'
    logger_path = './training.log'


# Shared settings instance used by the rest of the notebook.
config = Config()


In [29]:
import numpy as np
import torch
import random
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
from torchvision import transforms as T
from torchvision.transforms import functional as F
import os
import cv2
from scipy import ndimage
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

  from scipy.ndimage.interpolation import zoom


In [30]:
# Load the matching tokenizer / encoder pair from one checkpoint id.
bert_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
bert_model = AutoModel.from_pretrained(bert_checkpoint)
# Switch the encoder to evaluation mode; eval() returns the module, so leaving
# it as the last expression also displays the architecture as the cell output.
bert_model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [31]:
def random_rot_flip(image, label):
    """Apply the same random quarter-turn rotation and axis flip to both inputs.

    Draws a rotation count in {0, 1, 2, 3} and a flip axis in {0, 1} from
    ``np.random`` (in that order), then transforms ``image`` and ``label``
    identically.

    Returns:
        Tuple ``(image, label)`` of freshly copied arrays.
    """
    quarter_turns = np.random.randint(0, 4)
    flip_axis = np.random.randint(0, 2)
    # .copy() materialises the flipped view into a new array.
    image, label = (
        np.flip(np.rot90(arr, quarter_turns), axis=flip_axis).copy()
        for arr in (image, label)
    )
    return image, label

def random_rotate(image, label):
    """Rotate image and label by one shared random angle.

    The angle is an integer drawn with ``np.random.randint(-20, 20)``. Both
    inputs are rotated with nearest-neighbour sampling (``order=0``) and keep
    their original shape (``reshape=False``).

    Returns:
        Tuple ``(image, label)`` of rotated arrays.
    """
    degrees = np.random.randint(-20, 20)

    def _turn(arr):
        return ndimage.rotate(arr, degrees, order=0, reshape=False)

    return _turn(image), _turn(label)

def to_long_tensor(pic):
    """Convert an array-like to an int64 tensor, passing through a uint8 array."""
    as_uint8 = np.array(pic, np.uint8)
    return torch.from_numpy(as_uint8).to(torch.long)

def correct_dims(*images):
    """Give every 2-D array a trailing axis of length 1; pass other arrays through.

    Returns:
        A list of arrays when more than one was given, otherwise the single
        array itself.
    """
    fixed = [
        np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
        for img in images
    ]
    if len(fixed) > 1:
        return fixed
    return fixed[0]

In [32]:
class RandomGenerator(object):
    """Joint random augmentation for an image / mask / text sample.

    Called with a dict holding ``'image'`` and ``'label'`` (objects with
    ``.astype``, i.e. numpy arrays) and ``'text'``; returns a dict with the
    same keys whose values are torch tensors.

    Args:
        output_size: two-element sequence; the image and label are rescaled
            with ``zoom`` when the input size differs from it.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    @staticmethod
    def _zoom_factors(array, plane_scale):
        """Extend a 2-element spatial scale with 1 for every additional axis of ``array``."""
        extra_axes = max(np.ndim(array) - 2, 0)
        return tuple(plane_scale) + (1,) * extra_axes

    def __call__(self, sample):
        # The notebook's import cell binds ``F`` to torchvision's functional
        # module and then rebinds it to ``torch.nn.functional``, which has no
        # ``to_pil_image`` / ``to_tensor``. Import the torchvision module under
        # a private name so this transform does not depend on that global.
        from torchvision.transforms import functional as TF

        image, label, text = sample['image'], sample['label'], sample['text']
        image, label = image.astype(np.uint8), label.astype(np.uint8)
        image, label = TF.to_pil_image(image), TF.to_pil_image(label)
        x, y = image.size
        # At most one of the two augmentations is applied per sample.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        if x != self.output_size[0] or y != self.output_size[1]:
            # NOTE(review): x, y come from PIL's (width, height) while the zoom
            # factors are applied to array axes in (row, column) order; the two
            # agree only for square inputs - confirm for non-square data.
            plane_scale = (self.output_size[0] / x, self.output_size[1] / y)
            # A colour image is 3-D, so the channel axis needs its own factor;
            # passing only two factors makes scipy's zoom raise.
            image = zoom(image, self._zoom_factors(image, plane_scale), order=3)
            label = zoom(label, self._zoom_factors(label, plane_scale), order=0)
        image = TF.to_tensor(image)
        label = to_long_tensor(label)
        text = torch.Tensor(text)
        return {'image': image, 'label': label, 'text': text}


In [33]:
from typing import Callable


In [38]:
class ImageToImage2D(Dataset):
    """Dataset of (image, binary mask, text embedding) samples.

    Images are read from ``<dataset_path>/img`` and masks from
    ``<dataset_path>/labelcol``; the mask of ``name.<ext>`` is ``name.png``.
    The description of each sample is looked up in ``row_text`` by mask file
    name and embedded with ``bert-base-uncased``.

    Args:
        dataset_path: folder containing the ``img`` and ``labelcol`` sub-folders.
        task_name: stored on the instance as ``task_name``.
        row_text: mapping from mask file name to its description string.
        joint_transform: callable taking and returning the sample dict
            (``'image'``, ``'label'``, ``'text'``). When falsy, a default
            conversion to tensors is used.
        one_hot_mask: when truthy, the number of classes; the label returned by
            the transform is expanded to ``one_hot_mask`` planes.
        image_size: side length images and masks are resized to.

    ``__getitem__`` returns ``(sample, image_filename)``.
    """

    # Number of token embeddings (rows) kept per description. Shorter
    # descriptions are zero-padded so every sample's text has the same shape.
    MAX_TEXT_TOKENS = 10

    def __init__(self, dataset_path: str, task_name: str, row_text: dict, joint_transform: Callable = None,
                 one_hot_mask: int = False, image_size: int = 224) -> None:
        self.dataset_path = dataset_path
        self.image_size = image_size
        self.input_path = os.path.join(dataset_path, 'img')
        self.output_path = os.path.join(dataset_path, 'labelcol')
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.images_list = sorted(os.listdir(self.input_path))
        self.mask_list = sorted(os.listdir(self.output_path))
        self.one_hot_mask = one_hot_mask
        self.rowtext = row_text
        self.task_name = task_name
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert_model = AutoModel.from_pretrained("bert-base-uncased")
        self.bert_model.eval()

        self._to_tensor = T.ToTensor()
        # The fallback must accept the sample dict, because __getitem__ always
        # calls the transform with a single dict argument.
        self.joint_transform = joint_transform if joint_transform else self._default_transform

    def __len__(self):
        # Use the listing captured at construction so length and indexing agree.
        return len(self.images_list)

    @staticmethod
    def _mask_filename(image_filename):
        """Return the mask file name for an image: same stem, ``.png`` extension."""
        return os.path.splitext(image_filename)[0] + ".png"

    @classmethod
    def _fit_token_rows(cls, embedding):
        """Truncate or zero-pad ``embedding`` along axis 0 to ``MAX_TEXT_TOKENS`` rows."""
        rows = embedding[:cls.MAX_TEXT_TOKENS]
        missing = cls.MAX_TEXT_TOKENS - rows.shape[0]
        if missing > 0:
            padding = np.zeros((missing,) + tuple(rows.shape[1:]), dtype=rows.dtype)
            rows = np.concatenate([rows, padding], axis=0)
        return rows

    def _default_transform(self, sample):
        """Convert an untransformed sample (numpy arrays) to tensors without augmentation."""
        return {
            'image': self._to_tensor(sample['image']),
            # (H, W, 1) class-index mask -> (1, H, W) int64, values left unscaled.
            'label': torch.from_numpy(np.ascontiguousarray(sample['label'])).permute(2, 0, 1).long(),
            'text': torch.from_numpy(sample['text']),
        }

    def _one_hot(self, label):
        """Expand a class-index label of shape (H, W) or (1, H, W) to ``one_hot_mask`` planes."""
        assert self.one_hot_mask > 0, 'one_hot_mask must be nonnegative'
        index = torch.as_tensor(label).long()
        if index.dim() == 2:
            index = index.unsqueeze(0)
        planes = torch.zeros((self.one_hot_mask,) + tuple(index.shape[1:]))
        return planes.scatter_(0, index, 1)

    def _embed_text(self, text):
        """Return the BERT last-hidden-state rows for ``text`` as a numpy array."""
        text = ' '.join(text.split('\n'))  # merge multi-line input if needed
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
        # Drop the batch axis: (seq_len, hidden_size).
        text_embedding = outputs.last_hidden_state.squeeze(0)
        return self._fit_token_rows(text_embedding.cpu().numpy())

    def __getitem__(self, idx):
        image_filename = self.images_list[idx]
        mask_filename = self._mask_filename(image_filename)

        image_path = os.path.join(self.input_path, image_filename)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv2.resize(image, (self.image_size, self.image_size))

        mask_path = os.path.join(self.output_path, mask_filename)
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")
        mask = cv2.resize(mask, (self.image_size, self.image_size))
        # Binarise: every non-zero pixel becomes class 1.
        mask[mask <= 0] = 0
        mask[mask > 0] = 1

        image, mask = correct_dims(image, mask)

        text = self._embed_text(self.rowtext[mask_filename])

        sample = {'image': image, 'label': mask, 'text': text}
        sample = self.joint_transform(sample)

        if self.one_hot_mask:
            # Done after the transform: at this point the label is a tensor,
            # whereas the pre-transform mask is a numpy array without .long().
            sample['label'] = self._one_hot(sample['label'])

        return sample, image_filename


In [49]:
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms as T
from PIL import Image

class CustomLV2D(Dataset):
    """Dataset pairing images, binary masks and description strings from an Excel sheet.

    Args:
        image_dir: folder holding the input images.
        label_dir: folder holding the masks; the mask of ``name.<ext>`` is ``name.png``.
        excel_path: spreadsheet with one row per sample. The image-name column is
            the first whose name contains ``'image'`` (else column 0); the text
            column is the first whose name contains ``'text'`` (else column 1).
        image_size: side length images and masks are resized to.
        one_hot_mask: when true the label is a float tensor of shape (2, H, W)
            (background plane, foreground plane); otherwise an int64 tensor of
            shape (1, H, W).

    ``__getitem__`` returns a dict with keys ``'image'``, ``'label'``, ``'text'``.
    """

    def __init__(self, image_dir, label_dir, excel_path, image_size=224, one_hot_mask=True):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.image_size = image_size
        self.one_hot_mask = one_hot_mask

        # Load Excel metadata
        df = pd.read_excel(excel_path)

        # Dynamically find columns (fallbacks if names vary)
        self.image_col = self._pick_column(df.columns, 'image', 0)
        self.text_col = self._pick_column(df.columns, 'text', 1)

        # Rows without an image name cannot be loaded; without this they would
        # turn into the literal file name 'nan' after the string conversion.
        df = df.dropna(subset=[self.image_col])

        self.image_names = df[self.image_col].astype(str).tolist()
        self.texts = df[self.text_col].astype(str).tolist()

        # Image transform
        self.transform_img = T.Compose([
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor()
        ])

        # Nearest-neighbour resize for the mask (no interpolated in-between values).
        self.transform_mask = T.Compose([
            T.Resize((self.image_size, self.image_size), interpolation=T.InterpolationMode.NEAREST)
        ])

    @staticmethod
    def _pick_column(columns, keyword, fallback_position):
        """Return the first column whose name contains ``keyword`` (case-insensitive).

        Falls back to ``columns[fallback_position]``; the fallback is only
        looked up when no column matches, and non-string column names are
        compared through ``str()``.
        """
        for col in columns:
            if keyword in str(col).lower():
                return col
        return columns[fallback_position]

    def _resolve_image_path(self, img_name):
        """Return the path of ``img_name``, preferring the ``.jpg`` spelling of ``.jpeg``.

        If only the path exactly as listed exists on disk, that path is returned
        instead of the rewritten one.
        """
        listed_path = os.path.join(self.image_dir, img_name)
        normalized_path = listed_path.replace('.jpeg', '.jpg')  # Optional: normalize extensions
        if (normalized_path != listed_path
                and not os.path.exists(normalized_path)
                and os.path.exists(listed_path)):
            return listed_path
        return normalized_path

    def _encode_mask(self, mask):
        """Binarise a grayscale mask array (> 127 is foreground) and convert it to a tensor."""
        binary = (mask > 127).astype(np.uint8)
        if self.one_hot_mask:
            one_hot = np.stack([binary == 0, binary == 1]).astype(np.float32)
            return torch.from_numpy(one_hot)
        return torch.from_numpy(binary).long().unsqueeze(0)  # Add channel dim

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = self._resolve_image_path(img_name)

        label_name = img_name.rsplit('.', 1)[0] + '.png'  # Replace extension with .png
        label_path = os.path.join(self.label_dir, label_name)

        # The context managers close the underlying files once the pixel data
        # has been converted, instead of leaving that to garbage collection.
        with Image.open(img_path) as img_file:
            image = self.transform_img(img_file.convert("RGB"))

        with Image.open(label_path) as mask_file:
            mask = np.array(self.transform_mask(mask_file.convert("L")))

        return {
            'image': image,
            'label': self._encode_mask(mask),
            'text': self.texts[idx]
        }


In [50]:
# Build the training dataset from the MoNuSeg training folder.
train_root = '/content/LViT/datasets/MoNuSeg/Train_Folder'
train_dataset = CustomLV2D(
    image_dir=f'{train_root}/img',
    label_dir=f'{train_root}/labelcol',
    excel_path=f'{train_root}/Train_text.xlsx',
    image_size=224,
    one_hot_mask=True,
)

# Check one sample: tensor shapes first, then the description string.
sample = train_dataset[0]
for tensor_key in ('image', 'label'):
    print(sample[tensor_key].shape)  # image: [3, 224, 224], label: [2, 224, 224]
print(sample['text'])


torch.Size([3, 224, 224])
torch.Size([2, 224, 224])
The nuclei are evenly distributed.


In [48]:
import pandas as pd

# Inspect the column names of the training-text spreadsheet.
train_text_path = '/content/LViT/datasets/MoNuSeg/Train_Folder/Train_text.xlsx'
df = pd.read_excel(train_text_path)
print(df.columns)

Index(['Image', 'Description'], dtype='object')
