Histolytics' modular model implementations allow you to select from a variety of pre-trained backbone encoders. For example, foundation models such as [UNI](https://huggingface.co/MahmoodLab/UNI), [Virchow](https://huggingface.co/paige-ai/Virchow), or [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) can be used, provided that you have been granted access to these gated models on the Hugging Face Hub. More generally, any backbone from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) (timm) library can be used.

In [1]:
from huggingface_hub import login

# log in to Hugging Face to load the gated weights
# NOTE: you must have been granted access to the model weights before running this
login()

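The interactive `login()` widget does not work in non-interactive settings such as batch jobs. A minimal alternative, assuming the access token is stored in the `HF_TOKEN` environment variable, is to pass it to `login` explicitly; `hf_login_from_env` is a hypothetical helper name used only for illustration.

```python
import os

from huggingface_hub import login


def hf_login_from_env(var: str = "HF_TOKEN") -> bool:
    """Hypothetical helper: log in with a token read from an environment variable.

    Returns False (without contacting the Hub) when the variable is unset.
    """
    token = os.environ.get(var)
    if not token:
        return False
    login(token=token)
    return True
```

`huggingface_hub` also reads `HF_TOKEN` on its own when downloading files, so in many setups setting the variable is enough.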

## Initializing panoptic Cellpose with UNI backbone

**Note**: Models such as UNI and Virchow require extra keyword arguments during timm initialization. These are passed through the `model_kwargs` dictionary, which accepts arbitrary arguments needed to initialize the model. Within `model_kwargs`, the `encoder_kws` entry holds the extra keyword arguments that are forwarded to the timm backbone encoder.

**Note**: The `enc_out_indices` entry in `model_kwargs` specifies which layers of the backbone encoder are passed to the panoptic segmentation model as U-Net skip connections. Each index refers to a layer (block) of the backbone encoder. By default, four feature maps are passed to the pixel decoders, so four indices are given.

In [3]:
from histolytics.models.cellpose_panoptic import CellposePanoptic

cpose_panoptic = CellposePanoptic(
    n_nuc_classes=6,
    n_tissue_classes=6,
    enc_name="hf_hub:MahmoodLab/uni",
    enc_pretrain=True,
    model_kwargs={
        "encoder_kws": {"init_values": 1e-5, "dynamic_img_size": True},
        "enc_out_indices": (2, 4, 6, 8),  # using layers 2, 4, 6, 8 from UNI
    },
)

# Print only a summary of the model instead of the full output
print(str(cpose_panoptic.model)[:500] + "\n...")

CellPoseUnet(
  (hf_hub:MahmoodLab/uni): Encoder(
    (encoder): TimmEncoder(
      (encoder): VisionTransformer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=Tru
...


## Initializing panoptic HoverNet with Virchow backbone

In [4]:
import torch
from timm.layers import SwiGLUPacked
from histolytics.models.hovernet_panoptic import HoverNetPanoptic

hnet_panoptic = HoverNetPanoptic(
    n_nuc_classes=6,
    n_tissue_classes=6,
    enc_name="hf-hub:paige-ai/Virchow",
    enc_pretrain=True,
    model_kwargs={
        "encoder_kws": {"mlp_layer": SwiGLUPacked, "act_layer": torch.nn.SiLU},
        "enc_out_indices": (2, 4, 6, 8),  # using layers 2, 4, 6, 8 from Virchow
    },
)

# Print only a summary of the model instead of the full output
print(str(hnet_panoptic.model)[:500] + "\n...")

HoverNetUnet(
  (hf-hub:paige-ai/Virchow): Encoder(
    (encoder): TimmEncoder(
      (encoder): VisionTransformer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=T
...


## Initializing panoptic StarDist with Prov-GigaPath backbone

In [None]:
from histolytics.models.stardist_panoptic import StarDistPanoptic

sdist_panoptic = StarDistPanoptic(
    n_nuc_classes=6,
    n_tissue_classes=6,
    enc_name="hf_hub:prov-gigapath/prov-gigapath",
    enc_pretrain=True,
    model_kwargs={
        "enc_out_indices": (2, 4, 6, 8),  # using layers 2, 4, 6, 8 from Prov-GigaPath
    },
)

# Print only a summary of the model instead of the full output
print(str(sdist_panoptic.model)[:500] + "\n...")

StarDistUnet(
  (hf_hub:prov-gigapath/prov-gigapath): Encoder(
    (encoder): TimmEncoder(
      (encoder): VisionTransformer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1536, kernel_size=(16, 16), stride=(16, 16))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((1536,), eps=1e-06, elementwi
...


## Initializing panoptic CPPNet with SAM image encoder backbone

In [6]:
from histolytics.models.cppnet_panoptic import CPPNetPanoptic

cpp_panoptic = CPPNetPanoptic(
    n_nuc_classes=6,
    n_tissue_classes=6,
    enc_name="samvit_base_patch16",
    enc_pretrain=True,
    model_kwargs={
        "enc_out_indices": (2, 4, 6, 8),  # using layers 2, 4, 6, 8 from SAM
    },
)

# Print only a summary of the model instead of the full output
print(str(cpp_panoptic.model)[:500] + "\n...")

CPPNetUnet(
  (samvit_base_patch16): Encoder(
    (encoder): TimmEncoder(
      (encoder): VisionTransformerSAM(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)

...


## Initializing panoptic CellVit with SAM huge image encoder backbone

In [7]:
from histolytics.models.cellvit_panoptic import CellVitPanoptic

cvit_panoptic = CellVitPanoptic(
    n_nuc_classes=6,
    n_tissue_classes=6,
    enc_name="samvit_huge_patch16",
    enc_pretrain=True,
    model_kwargs={
        "enc_out_indices": (2, 4, 6, 8),  # using layers 2, 4, 6, 8 from SAM Huge
    },
)

# Print only a summary of the model instead of the full output
print(str(cvit_panoptic.model)[:500] + "\n...")

CellVitSamUnet(
  (samvit_huge_patch16): Encoder(
    (encoder): TimmEncoder(
      (encoder): VisionTransformerSAM(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=
...
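The printed `PatchEmbed` layers show that Virchow uses a 14×14 patch stride, while UNI, Prov-GigaPath and both SAM encoders use 16×16. The same input tile therefore produces differently sized token grids, and some tile sizes do not divide evenly by 14. The helper below is a hypothetical illustration of that arithmetic only; how Histolytics handles non-divisible input sizes is not covered in this notebook.

```python
def token_grid(img_size: int, patch_size: int) -> tuple[int, bool]:
    """Return the ViT token-grid side length and whether img_size divides evenly."""
    return img_size // patch_size, img_size % patch_size == 0


# Patch sizes read from the printed PatchEmbed layers (kernel_size == stride).
patch_sizes = {"UNI": 16, "Virchow": 14, "Prov-GigaPath": 16, "SAM base": 16, "SAM huge": 16}

for name, p in patch_sizes.items():
    for size in (224, 256):
        side, exact = token_grid(size, p)
        print(f"{name:14s} input {size}: {side}x{side} tokens, divisible={exact}")
```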


To start fine-tuning your model with a pre-trained backbone, see the next chapter on fine-tuning.

## Available timm encoders

In general, any pre-trained model from the timm (pytorch-image-models) library can be used, as long as the model implements the `forward_intermediates` method. You can list all model architectures registered in timm as follows:

In [8]:
import timm

timm.list_models()

['aimv2_1b_patch14_224',
 'aimv2_1b_patch14_336',
 'aimv2_1b_patch14_448',
 'aimv2_3b_patch14_224',
 'aimv2_3b_patch14_336',
 'aimv2_3b_patch14_448',
 'aimv2_huge_patch14_224',
 'aimv2_huge_patch14_336',
 'aimv2_huge_patch14_448',
 'aimv2_large_patch14_224',
 'aimv2_large_patch14_336',
 'aimv2_large_patch14_448',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224',
 'beitv2_large_patch16_224',
 'botnet26t_256',
 'botnet50ts_256',
 'caformer_b36',
 'caformer_m36',
 'caformer_s18',
 'caformer_s36',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_medium',
 'coat_lite_medium_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_small',
 'coat_tiny',
 'coatnet_0_224',
 'coatnet_0_rw_224',
 'coa