In [7]:
import torch
from easydict import EasyDict
from torchvision import transforms
from PIL import Image
import sys
import os

# Repository root. It is made the working directory (the script later uses
# relative paths such as ./weights/...) and is placed first on sys.path so the
# project's `src` package can be imported.
_PROJECT_ROOT = '/home/yuhaowang/project/BRFound'
os.chdir(_PROJECT_ROOT)
sys.path.insert(0, _PROJECT_ROOT)

# Pull the model constructors and the weight loader straight from the project.
# An import failure is only reported, not re-raised, so the names below may be
# undefined afterwards.
try:
    from src.vision_transformer import vit_base, DinoVisionTransformer
    from src.utils import load_pretrained_weights
except ImportError as e:
    print(f"Import error: {e}")
else:
    print("✓ Direct imports successful")

def build_model_from_cfg(cfg, only_teacher=False):
    """Build the ViT backbone(s) described by ``cfg``.

    Args:
        cfg: Configuration object exposing ``cfg.student`` (architecture
            settings: ``arch``, ``patch_size``, ``layerscale``, ...) and
            ``cfg.crops.global_crops_size`` (used as the model image size).
        only_teacher: When true, build and return only the teacher model.

    Returns:
        ``(teacher, embed_dim)`` if ``only_teacher`` is true, otherwise
        ``(student, teacher, embed_dim)``. The student is built with the same
        arguments as the teacher plus the drop-path settings.

    Raises:
        ValueError: If ``cfg.student.arch`` does not name a ViT architecture.
            Previously this case silently returned ``None``, which made
            callers fail later with an unrelated unpacking error.
    """
    args = cfg.student
    if "vit" not in args.arch:
        raise ValueError(
            f"Unsupported architecture {args.arch!r}: only ViT variants are supported"
        )

    vit_kwargs = dict(
        img_size=cfg.crops.global_crops_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )

    # 'vit_base' has a dedicated factory; any other ViT name falls back to the
    # generic DinoVisionTransformer constructor.
    build = vit_base if args.arch == 'vit_base' else DinoVisionTransformer

    teacher = build(**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

    # The student additionally receives the drop-path parameters.
    student = build(
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim

def load_pretrained_model(weights_path, device='cuda'):
    """Create a ``vit_base`` model (patch size 16) and load weights into it.

    Args:
        weights_path: Path to a checkpoint that ``torch.load`` can read and
            that yields a mapping of parameter names to tensors.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    The weights are loaded non-strictly; any mismatch between the checkpoint
    and the model is printed rather than raised.
    """
    model = vit_base(patch_size=16)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(weights_path, map_location=device)

    # Remove potential prefixes in keys
    state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}

    # load_state_dict reports both directions of mismatch; the previous manual
    # check only looked at checkpoint keys absent from the model and therefore
    # claimed success even when model parameters were never loaded.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Model keys missing from the checkpoint: {sorted(result.missing_keys)}")
    if result.unexpected_keys:
        print(f"Checkpoint keys not present in the model: {sorted(result.unexpected_keys)}")
    if not result.missing_keys and not result.unexpected_keys:
        print("All keys loaded successfully.")
    model.to(device)
    model.eval()

    return model


def get_transform():
    """Return the image preprocessing pipeline used before inference.

    The composed steps are, in order: bicubic ``Resize(256)``,
    ``CenterCrop(224)``, ``ToTensor()``, and per-channel ``Normalize`` with the
    mean/std values listed below.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)


def extract_features(image_path, model, device='cuda'):
    """Run ``model`` on a single image file and return its output as NumPy.

    Args:
        image_path: Path of the image to load.
        model: Callable model; it is expected to already live on ``device``.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        The model output for a batch of one image, moved to the CPU and
        converted with ``.numpy()``.
    """
    transform = get_transform()

    # Open via a context manager so the underlying file handle is released
    # deterministically; convert() returns a separate, fully loaded image that
    # stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    # unsqueeze(0) adds the leading batch dimension.
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model(image_tensor)

    return features.cpu().numpy()

def build_model_for_eval(config, pretrained_weights, device=None):
    """Build the teacher backbone, load its weights and prepare it for inference.

    Args:
        config: Configuration passed to ``build_model_from_cfg``.
        pretrained_weights: Checkpoint location handed to
            ``load_pretrained_weights`` together with the key ``"teacher"``
            (presumably selecting the teacher weights; confirm against
            ``src.utils``).
        device: Target device for the model. When ``None`` (the default), CUDA
            is used if available and the CPU otherwise, so the function no
            longer crashes on machines without a GPU.

    Returns:
        The model in eval mode on the selected device.
    """
    model, _ = build_model_from_cfg(config, only_teacher=True)
    load_pretrained_weights(model, pretrained_weights, "teacher")
    model.eval()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    return model

if __name__ == "__main__":

    # Architecture settings read by build_model_from_cfg via cfg.student.
    student_cfg = EasyDict({
        'arch': 'vit_base',
        'patch_size': 16,
        'drop_path_rate': 0.3,
        'layerscale': 1.0e-05,
        'drop_path_uniform': True,
        'pretrained_weights': '',
        'ffn_layer': 'mlp',
        'block_chunks': 4,
        'qkv_bias': True,
        'proj_bias': True,
        'ffn_bias': True,
        'num_register_tokens': 0,
        'interpolate_antialias': False,
        'interpolate_offset': 0.1,
    })
    # Crop settings; global_crops_size is used as the model image size.
    crops_cfg = EasyDict({
        'global_crops_size': 224,
    })
    config = EasyDict({
        'student': student_cfg,
        'crops': crops_cfg,
    })

    weights_path = './weights/patch_encoder.pth'
    image_path = './images/patch_1.png'

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Probe both inputs once and warn about whichever is absent.
    have_weights = os.path.exists(weights_path)
    have_image = os.path.exists(image_path)
    if not have_weights:
        print(f"Warning: weights file not found at {weights_path}")
    if not have_image:
        print(f"Warning: image file not found at {image_path}")

    # Run the model only when both files are present; any failure is reported
    # with a traceback instead of propagating.
    try:
        if not (have_weights and have_image):
            print("Skipping model evaluation due to missing files")
        else:
            model = build_model_for_eval(config, weights_path)
            features = extract_features(image_path, model, device=device)
            print(f"Extracted features shape: {features.shape}")
    except Exception as e:
        print(f"Error during model evaluation: {e}")
        import traceback
        traceback.print_exc()

✓ Direct imports successful
Using device: cuda
Extracted features shape: (1, 768)
Extracted features shape: (1, 768)
