# In [None]:
import argparse
import torch
import os
from dataset import SKIN, HERDataset, SliceBatchSampler
from model import GR2ST
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import AvgMeter, get_lr
from utils import get_R
from scipy.stats import pearsonr
import numpy as np
import torch.nn.functional as F
import anndata
import pandas as pd
from types import SimpleNamespace
def generate_args():
    """Create a namespace object holding the default parameter values.

    Returns:
        SimpleNamespace: one attribute per entry of the defaults table below.
    """
    defaults = {
        "batch_size": 2048,
        "max_epochs": 100,
        "temperature": 1.0,
        "fold": 0,
        "dim": 171,
        "image_embedding_dim": 1024,
        "projection_dim": 256,
        "heads_num": 8,
        "dropout": 0.0,
        "dataset": 'cscc',
        "dynamic_topk": 20,
        "spatial_topk": 20,
        "fusion_type": 'sum',
    }
    return SimpleNamespace(**defaults)


def train(model, train_dataLoader, optimizer, epoch):
    """Run one pass over ``train_dataLoader``, updating ``model`` in place.

    Args:
        model: Module whose call on a batch dict returns the loss tensor that
            is backpropagated.
        train_dataLoader: Sized iterable yielding dict batches of tensors.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index; only displayed in the progress bar.

    Returns:
        The value of ``loss_meter.avg`` after the last batch.
    """
    # Only these entries of each batch are forwarded to the model.
    wanted_keys = ("image_features", "expression", "position", "cell_type")

    # Follow the model's device instead of hard-coding CUDA, so the CPU
    # fallback chosen by the caller does not crash on `.cuda()`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    loss_meter = AvgMeter()
    tqdm_train = tqdm(train_dataLoader, total=len(train_dataLoader))
    for batch in tqdm_train:
        batch = {k: v.to(device) for k, v in batch.items() if k in wanted_keys}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the meter update by the number of samples in this batch.
        count = batch["image_features"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_train.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer), epoch=epoch)

    return loss_meter.avg

def load_data(args):
    """Build the training loader and the test dataset for ``args.dataset``.

    Args:
        args: Object exposing ``dataset``, ``fold`` and ``batch_size``.

    Returns:
        tuple: ``(train_dataLoader, test_dataset)`` where the loader wraps the
        ``SKIN(train=True)`` split batched by ``SliceBatchSampler`` and the
        dataset is the ``SKIN(train=False)`` split of the same fold.

    Raises:
        ValueError: If ``args.dataset`` is not ``'cscc'``.
    """
    # Fail loudly instead of implicitly returning None, which would surface
    # as an unrelated unpacking error in the caller.
    if args.dataset != 'cscc':
        raise ValueError(
            f"Unsupported dataset: {args.dataset!r} (only 'cscc' is handled)"
        )

    print(f'load dataset: {args.dataset}')
    train_dataset = SKIN(train=True, fold=args.fold)
    batch_sampler = SliceBatchSampler(train_dataset, args.batch_size)
    train_dataLoader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=0,
    )
    test_dataset = SKIN(train=False, fold=args.fold)
    return train_dataLoader, test_dataset
    

def save_model(args, model, test_dataset=None):
    """Save ``model.state_dict()`` under ``./model_result``.

    The checkpoint is written to
    ``./model_result/<args.dataset>/<test_dataset.id2name[0]>/best_<args.fold>.pt``;
    the directory is created if missing.

    Args:
        args: Object exposing ``dataset`` and ``fold``.
        model: Module whose ``state_dict()`` is serialized.
        test_dataset: Object exposing ``id2name``; entry ``0`` names the
            output sub-directory. Required despite the ``None`` default.

    Raises:
        ValueError: If ``test_dataset`` is ``None``.
    """
    # The default of None cannot be used to build the path; report that
    # explicitly rather than failing with an opaque AttributeError.
    if test_dataset is None:
        raise ValueError("save_model requires a test_dataset to name the output directory")

    out_dir = f"./model_result/{args.dataset}/{test_dataset.id2name[0]}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{out_dir}/best_{args.fold}.pt")
    


def main():
    """Train a fresh GR2ST model for each of the 12 folds and save it.

    For every fold index the data is loaded via ``load_data``, a new model is
    built and trained for ``args.max_epochs`` epochs, and the final weights
    are written with ``save_model``.
    """
    args = generate_args()
    # The device does not depend on the fold, so select it once.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for fold in range(12):
        args.fold = fold
        print("当前fold:", args.fold)
        train_dataLoader, test_dataset = load_data(args)

        model = GR2ST(
            temperature=args.temperature,
            image_dim=args.image_embedding_dim,
            spot_dim=args.dim,
            projection_dim=args.projection_dim,
            heads_num=args.heads_num,
            # NOTE(review): args.dropout is 0.0 while the model has always
            # been built with 0.1; the literal is kept so training behavior
            # is unchanged. Confirm which value is intended.
            dropout=0.1,
            dynamic_topk=args.dynamic_topk,
            spatial_topk=args.spatial_topk,
            fusion_type=args.fusion_type,
        )
        model.to(device)

        # Build the optimizer once per fold: re-creating Adam every epoch
        # throws away its running moment estimates and step counts.
        optimizer = torch.optim.Adam(
            model.parameters(), lr=1e-4, weight_decay=1e-3)

        for epoch in range(args.max_epochs):
            model.train()
            train(model, train_dataLoader, optimizer, epoch)

        save_model(args, model, test_dataset=test_dataset)
                            
                            
            

# Run the per-fold training only when executed as a script, not on import.
if __name__ == '__main__':
    main()