## Using Trident with Your Own Foundation Model 

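As more and more groups design their own foundation models, we want to offer simple tools for integrating them into Trident. This is the purpose of `CustomInferenceEncoder` from the `trident.patch_encoder_models` module. It wraps a PyTorch model together with its evaluation transforms and inference precision, so that the model can be passed to Trident as a patch encoder.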

In [None]:
from PIL import Image
import requests
import torch
import timm
import torchvision.transforms as transforms

from trident.patch_encoder_models import CustomInferenceEncoder

# Load your custom model. As an example we use a timm EVA-02 ViT-L/14 at 448px,
# pretrained with masked image modeling and fine-tuned on ImageNet-22k -> 1k
# (as its timm tag indicates). Replace this with your own foundation model.
model = timm.create_model('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', pretrained=True)

# Remove the classification head so the model returns embeddings instead of
# class logits. reset_classifier(0) is timm's architecture-agnostic way to do
# this: it works whether the head attribute is named `head`, `fc`, or `classifier`.
model.reset_classifier(0)
model = model.eval()

# Set precision. float16 inference is only reliably supported on GPU, so fall
# back to float32 on CPU. This mirrors the CUDA-or-CPU device choice made for
# feature extraction, where a float16 model on CPU would fail or be very slow.
precision = torch.float16 if torch.cuda.is_available() else torch.float32

# Set transforms: reuse the preprocessing the pretrained model expects
# (input size, interpolation, normalization) in evaluation mode.
data_config = timm.data.resolve_model_data_config(model)
eval_transforms = timm.data.create_transform(**data_config, is_training=False)

# Create custom encoder: bundles the model, its transforms and its precision
# under a name Trident uses to identify this encoder.
custom_patch_encoder = CustomInferenceEncoder(
    enc_name='my_custom_model',
    model=model,
    transforms=eval_transforms,
    precision=precision,
)


In [None]:
# Plug the custom encoder defined above into the regular Trident pipeline, driven by the Processor
import os
import torch
from huggingface_hub import snapshot_download

from trident.Processor import Processor
from trident.segmentation_models import segmentation_model_factory

OUTPUT_DIR = "tutorial-2/"
# Plain string literal: the original f-prefix had no placeholders.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
WSI_FNAME = '394140.svs'

# Download only the example slide (allow_patterns) from the Hugging Face Hub
# dataset into OUTPUT_DIR/wsis. snapshot_download returns that local folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)
local_wsi_dir = snapshot_download(
    repo_id="MahmoodLab/unit-testing",
    repo_type='dataset',
    local_dir=os.path.join(OUTPUT_DIR, 'wsis'),
    allow_patterns=[WSI_FNAME],
)

# Create processor
processor = Processor(
    job_dir=OUTPUT_DIR,        # Directory to store outputs
    wsi_source=local_wsi_dir,  # Directory containing WSI files
)

# Run tissue vs. background segmentation with the 'hest' segmentation model
segmentation_model = segmentation_model_factory('hest')
processor.run_segmentation_job(segmentation_model, device=DEVICE)

# Patching parameters (here: 256x256 px patches at 20x magnification, no overlap).
TARGET_MAGNIFICATION = 20
PATCH_SIZE = 256
OVERLAP = 0

# Derive the coordinates directory from the patching parameters so feature
# extraction always reads the coordinates the patching step just produced.
# NOTE(review): assumes the patching job names its output
# '{mag}x_{size}px_{overlap}px_overlap', which is the scheme of the previously
# hard-coded '20x_256px_0px_overlap'. Confirm against the Processor.
COORDS_DIR = f'{TARGET_MAGNIFICATION}x_{PATCH_SIZE}px_{OVERLAP}px_overlap'

# Run tissue coordinate extraction with the parameters above
processor.run_patching_job(
    target_magnification=TARGET_MAGNIFICATION,
    patch_size=PATCH_SIZE,
    overlap=OVERLAP,
)

# Run patch feature extraction using the custom encoder; features are saved as h5
processor.run_patch_feature_extraction_job(
    coords_dir=COORDS_DIR,
    patch_encoder=custom_patch_encoder,
    device=DEVICE,
    saveas='h5',
    batch_limit=32,
)

