In [4]:
import os
import timm
import h5py
import anndata as ad
from torch.utils.data import Dataset, DataLoader
import sys 
import torch
from torch import einsum
import torch.nn.functional as F
from torch import nn
sys.path.append("/mnt/DATA-4/hx/Ruipath/MunchkinCat")
sys.path.append("/mnt/DATA-4/hx/Ruipath/scFoundation/model/")
# from model_finetune import *
from data_loader import *
from load import main_gene_selection

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Query the HEST metadata helper; it returns two values that are unpacked into
# slide_ids and slide_patch_ids.
# NOTE(review): Get_hest_meta is not defined in this notebook — presumably it is
# provided by `from data_loader import *`; confirm and import it by name.
slide_ids, slide_patch_ids = Get_hest_meta()
# Earlier dataset construction, kept for reference (disabled):
#mydataset = HESTDataset(ids, root_dir="/mnt/DATA-4/hx/Ruipath/hest_data")

In [2]:
# Dataset rooted at the local HEST data directory; it is later wrapped in a
# DataLoader by the training cell.
# NOTE(review): HESTDataset is presumably provided by `from data_loader import *`.
dataset = HESTDataset(data_root="/mnt/DATA-4/hx/Ruipath/hest_data")

In [5]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Reference snippet (disabled): embed a single WSI patch with a vision model.
# # Provide a patch image from a WSI
# patch_image = Image.open("wsi_patch.png")
# image = transform(patch_image).unsqueeze(0)
# with torch.inference_mode():
#     patch_feature = model(image).cpu().numpy().astype(np.float32)
#     print(patch_feature)

In [None]:
class RuiPathViT(nn.Module):
    """Frozen timm ViT-L/16 backbone followed by a trainable MLP projection head.

    Construction is two-phase: ``__init__`` only records configuration, and
    ``build()`` must be called before the module is used, because it creates
    ``self.model`` (the backbone) and ``self.head``.

    Args:
        ckpt_path: Path to the backbone state dict that ``build()`` loads with
            ``torch.load``.
        head_lr: Default learning rate used by ``construct_optimizers`` when no
            explicit rate is given.
    """

    def __init__(self, ckpt_path, head_lr=1e-3):
        super().__init__()
        self.ckpt_path = ckpt_path
        # `head_lr` used to be read from an undefined name, which raised
        # NameError on every instantiation; it is now an optional argument.
        self.head_lr = head_lr

    def build(self):
        """Create the backbone and head, load the checkpoint and freeze the backbone.

        Raises:
            RuntimeError: from ``load_state_dict(strict=True)`` when the
                checkpoint keys do not match the backbone exactly.
        """
        # num_classes=0 asks timm for a model without a classifier layer, so
        # the backbone output feeds the projection head directly.
        self.model = timm.create_model(
            "vit_large_patch16_224",
            img_size=224,
            patch_size=16,
            init_values=1e-5,
            num_classes=0,
            dynamic_img_size=True
        )
        embed_dim = self.model.num_features
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim*2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim*2, embed_dim)
        )
        # Load onto the CPU instead of a notebook-global `device`: the values are
        # copied into the existing parameters, and the module is moved to its
        # final device afterwards with .to()/.cuda().
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        self.model.load_state_dict(state_dict, strict=True)
        # Only the projection head is trained; the backbone stays frozen.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x, *args, **kwargs):
        """Run the backbone on ``x`` and project its output through the head."""
        features = self.model(x, *args, **kwargs)
        return self.head(features)

    def construct_optimizers(self, head_lr=None):
        """Return an AdamW optimizer over the projection-head parameters only.

        Args:
            head_lr: Learning rate for the head. When ``None`` the value given
                at construction time (``self.head_lr``, default 1e-3) is used.

        Returns:
            A ``torch.optim.AdamW`` instance with ``weight_decay=1e-2``.
        """
        if head_lr is None:
            head_lr = self.head_lr
        optimizer = torch.optim.AdamW(
            self.head.parameters(),
            lr=head_lr,
            weight_decay=1e-2,
        )
        return optimizer

In [22]:
# Despite its name, local_dir is handed to RuiPathViT as the checkpoint path
# (ckpt_path), i.e. the file that build() passes to torch.load.
local_dir = "/mnt/DATA-4/hx/Ruipath/RuiPathViT/Ruipath_visionfoundation_v1.bin"
ruipathvit = RuiPathViT(local_dir)
# build() creates the timm backbone and projection head, loads the checkpoint
# and freezes the backbone; the module is unusable until it has been called.
ruipathvit.build()

In [15]:
import sys 
sys.path.append("/mnt/DATA-4/hx/Ruipath/scFoundation/model/") # path to this folder
from load import *

In [16]:
class scFoundation(nn.Module):
    """Gene-expression encoder assembled from a pretrained scFoundation checkpoint.

    ``__init__`` only stores configuration; ``build()`` loads the checkpoint and
    creates every sub-module, so it must be called before ``forward``.

    Args:
        ckpt_path: Checkpoint path handed to ``load_model_frommmf``.
        out_dim: Width of the embedding produced by the ``fc1`` projection.
        frozenmore: When true, the token and position embeddings are frozen in
            addition to the encoder.
    """

    def __init__(self, ckpt_path, out_dim, frozenmore=True):
        super().__init__()
        self.out_dim = out_dim
        self.frozenmore = frozenmore
        self.ckpt_path = ckpt_path

    def build(self):
        """Load the pretrained parts, set their trainability and add the projection."""
        pretrained, config = load_model_frommmf(self.ckpt_path)
        self.token_emb = pretrained.token_emb
        self.pos_emb = pretrained.pos_emb
        self.encoder = pretrained.encoder

        if self.frozenmore:
            for embedding in (self.token_emb, self.pos_emb):
                for weight in embedding.parameters():
                    weight.requires_grad = False
            print('self.pos_emb and self.token_emb also frozen')

        # Freeze the whole encoder, then re-enable gradients for the
        # second-to-last entry of its transformer_encoder only.
        for weight in self.encoder.parameters():
            weight.requires_grad = False
        tunable_layer = self.encoder.transformer_encoder[-2]
        for param_name, weight in tunable_layer.named_parameters():
            print('self.encoder.transformer_encoder ', param_name, ' have grad')
            weight.requires_grad = True

        hidden_dim = config['encoder']['hidden_dim']
        self.fc1 = nn.Sequential(
            nn.Linear(hidden_dim, self.out_dim * 2),
            nn.ReLU(),
            nn.Linear(self.out_dim * 2, self.out_dim),
        )
        # NOTE(review): this BatchNorm is registered (and so appears in the
        # state dict) but forward() never calls it.
        self.norm = torch.nn.BatchNorm1d(hidden_dim, affine=False, eps=1e-6)
        self.model_config = config

    def forward(self, x, *args, **kwargs):
        """Embed an expression matrix into ``out_dim``-wide vectors.

        Only entries with ``x > 0`` are selected through ``gatherData``; the
        encoder output is max-pooled over dimension 1 and projected by ``fc1``.
        Assumes ``x`` is 2-D (cells x 19264 genes) — TODO confirm against the
        data loader.
        """
        pad_id = self.model_config['pad_token_id']
        expressed = x > 0

        values, padding_mask = gatherData(x, expressed, pad_id)
        # One row of gene indices per sample; 19264 is the hard-coded index range.
        all_gene_ids = torch.arange(19264, device=values.device).repeat(values.shape[0], 1)
        gene_positions, _ = gatherData(all_gene_ids, expressed, pad_id)

        tokens = self.token_emb(values.unsqueeze(2).float(), output_weight = 0)
        tokens += self.pos_emb(gene_positions)

        encoded = self.encoder(tokens, padding_mask)
        pooled = encoded.max(dim=1).values

        return self.fc1(pooled)

In [52]:
class RuiPathST(nn.Module):
    """Joint image / gene-expression model trained with a symmetric contrastive loss.

    Args:
        vision_model: Module mapping an image batch to a 2-D embedding batch.
        sc_model: Module mapping an expression batch to a 2-D embedding batch
            of the same width as the image embeddings.
        normalize: When true, both embeddings are L2-normalised before the
            similarity is computed, so the logits are scaled cosine
            similarities. Defaults to ``False``, which keeps the original raw
            dot-product behaviour.
    """

    def __init__(self, vision_model, sc_model, normalize=False):
        super().__init__()
        self.vision_model = vision_model
        self.sc_model = sc_model
        self.normalize = normalize
        # Learnable log-scale: logits are multiplied by exp(temperature).
        self.temperature = nn.Parameter(torch.Tensor([1.]))

    def embed_img(self, img):
        """Return the vision model's embedding of ``img``."""
        return self.vision_model(img)

    def embed_gene(self, gene_expr):
        """Return the expression model's embedding of ``gene_expr``."""
        return self.sc_model(gene_expr)

    def forward(self, img, gene_expr):
        """Compute the symmetric cross-entropy loss over the similarity matrix.

        Row ``i`` of ``img`` is treated as the positive match of row ``i`` of
        ``gene_expr``; all other pairings in the batch act as negatives. With a
        batch of one there are no negatives and the loss is 0.

        Returns:
            A scalar tensor: the mean of the image-to-gene and gene-to-image
            cross-entropy losses.
        """
        img_features = self.embed_img(img)
        gene_features = self.embed_gene(gene_expr)
        if self.normalize:
            img_features = F.normalize(img_features, dim=-1)
            gene_features = F.normalize(gene_features, dim=-1)

        # Pairwise dot products: sim[i, j] = <img_i, gene_j>.
        sim = img_features @ gene_features.t()
        sim = sim * self.temperature.exp()

        # Build the targets on the logits' own device rather than on a
        # notebook-global `device`, so the module works wherever it is placed.
        contrastive_labels = torch.arange(sim.shape[0], device=sim.device)

        img_to_gene = F.cross_entropy(sim, contrastive_labels)
        gene_to_img = F.cross_entropy(sim.t(), contrastive_labels)
        contrastive_loss = (img_to_gene + gene_to_img) * 0.5
        return contrastive_loss

In [53]:
# Smoke test: push a single (image, expression) pair through the joint model.
# With a batch of one the similarity matrix is 1x1, so the cross-entropy loss
# is 0 — this only checks that the forward pass runs.
# NOTE(review): sc_model, img and x are not defined in any visible cell; this
# relies on leftover kernel state and will fail after Restart & Run All.
RuipathST = RuiPathST(ruipathvit, sc_model).cuda()
RuipathST(img.unsqueeze(0).cuda(), x.unsqueeze(0).cuda())

tensor(0., device='cuda:0', grad_fn=<MulBackward0>)

In [57]:
import torch.optim as optim

# Training configuration, collected in one place instead of repeated literals.
NUM_EPOCHS = 3
LOG_EVERY = 3  # print running statistics every LOG_EVERY steps
CKPT_PATH = "/mnt/DATA-4/hx/Ruipath/RuiPathST_ckp"

train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
# NOTE(review): sc_model is not defined in any visible cell — confirm it is
# built (scFoundation(...).build()) before this cell runs.
model = RuiPathST(ruipathvit, sc_model).to(device)

# The model returns the contrastive loss itself, so no separate criterion
# (e.g. nn.CrossEntropyLoss) is needed here.
# Only trainable parameters are handed to the optimizer; the frozen backbone
# weights never receive gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    # The dataloader is expected to yield (image batch, gene-expression batch).
    for i, (images, genes) in enumerate(train_loader):
        images = images.to(device)
        genes = genes.to(device)

        optimizer.zero_grad()
        loss = model(images, genes)  # forward() returns the scalar loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        avg_loss = running_loss / (i + 1)
        if i % LOG_EVERY == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the next line's indentation in the printed message.
            print(
                f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                f'Total Loss: {running_loss:.4f}, Avg Loss: {avg_loss:.4f}'
            )

    # No validation pass: only the mean training loss of the epoch is reported.
    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    print(f"Loss on epoch {epoch+1}: {running_loss / max(len(train_loader), 1):.4f}")

torch.save(model.state_dict(), CKPT_PATH)
print(f"Model saved to {CKPT_PATH}")

Epoch [1/3], Step [1/3157],                   Total Loss: 0.0351, Avg Loss: 3.5122
Epoch [1/3], Step [4/3157],                   Total Loss: 0.4263, Avg Loss: 10.6580
Epoch [1/3], Step [7/3157],                   Total Loss: 0.7644, Avg Loss: 10.9200
Epoch [1/3], Step [10/3157],                   Total Loss: 1.1432, Avg Loss: 11.4322
Epoch [1/3], Step [13/3157],                   Total Loss: 1.2732, Avg Loss: 9.7938
Epoch [1/3], Step [16/3157],                   Total Loss: 1.3912, Avg Loss: 8.6950
Epoch [1/3], Step [19/3157],                   Total Loss: 1.5028, Avg Loss: 7.9094
Epoch [1/3], Step [22/3157],                   Total Loss: 1.6193, Avg Loss: 7.3606
Epoch [1/3], Step [25/3157],                   Total Loss: 1.7420, Avg Loss: 6.9681
Epoch [1/3], Step [28/3157],                   Total Loss: 1.8553, Avg Loss: 6.6261
Epoch [1/3], Step [31/3157],                   Total Loss: 1.9865, Avg Loss: 6.4079
Epoch [1/3], Step [34/3157],                   Total Loss: 2.0981, Avg Loss:

KeyboardInterrupt: 