## Download Model

This step is needed because I am training on Kaggle. Internet access is disabled on Kaggle, so I have to upload the model weights manually after training.

In [None]:
from huggingface_hub import hf_hub_download
import shutil
import os

# Hugging Face repository to pull from, and the weight file inside it.
filename = "model.safetensors"
repo_id = "timm/convnext_base.clip_laion2b_augreg_ft_in12k_in1k"

print(f"Starting download for: {repo_id}")

# Fetch the file into the Hugging Face cache and keep the path that is returned.
cached_path = hf_hub_download(filename=filename, repo_id=repo_id)

# Target location: a fixed file name inside the current working directory.
local_filename = "convnext_base_weights.safetensors"
destination = os.path.join(os.getcwd(), local_filename)

# Duplicate the cached file at the target location.
shutil.copy(src=cached_path, dst=destination)


In [None]:
import torch
import timm
from safetensors.torch import load_file

# 1. Architecture identifier and location of the locally stored weights
model_name = "convnext_base.clip_laion2b_augreg_ft_in12k_in1k"
weights_path = "convnext_base_weights.safetensors"

# 2. Build the model structure only.
#    pretrained=False because the weights are loaded manually from the local
#    file; num_classes=2 requests a two-class head.
model = timm.create_model(
    model_name,
    num_classes=2,
    pretrained=False,
)


In [None]:
print(f"Model created. Expecting head size: {model.head.fc.weight.shape}")

# 3. Load the weights from disk and copy them into the model.
try:
    if weights_path.endswith('.safetensors'):
        state_dict = load_file(weights_path)
    else:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints that come from a trusted source.
        state_dict = torch.load(weights_path, map_location='cpu')

    # Drop classifier-head tensors whose shape differs from the model's own
    # parameter of the same name, so the model keeps its fresh initialization
    # for them. The 'fc' substring may also match non-head layers; the shape
    # comparison below leaves any tensor that already fits untouched.
    # The model's state dict is snapshotted once rather than rebuilt per key.
    model_state = model.state_dict()
    keys_to_remove = [k for k in state_dict if 'head' in k or 'fc' in k]
    removed_keys = []
    for k in keys_to_remove:
        if k in model_state and state_dict[k].shape != model_state[k].shape:
            print(f"Removing mismatched key: {k} {state_dict[k].shape}")
            del state_dict[k]
            removed_keys.append(k)

    # 4. Load into model
    # strict=False tolerates the deleted head keys; the returned record lists
    # every key that did not line up, so it is checked instead of discarded.
    load_result = model.load_state_dict(state_dict, strict=False)

    # Missing keys are expected only for the head tensors removed above.
    missing_elsewhere = [
        k for k in load_result.missing_keys if k not in removed_keys
    ]
    unexpected = list(load_result.unexpected_keys)

    if missing_elsewhere:
        print(f"Warning: model keys not found in weights file: {missing_elsewhere}")
    if unexpected:
        print(f"Warning: weights file keys not used by model: {unexpected}")
    if not missing_elsewhere and not unexpected:
        print("Success! Pretrained backbone loaded. Head is fresh (random) for 2 classes.")

except Exception as e:
    print(f"Error: {e}")
    # Re-raise so later cells do not silently continue with a model whose
    # pretrained weights were never loaded.
    raise

In [None]:
# 5. Select CUDA when it is available, otherwise the CPU, and move the model there
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
model.to(device)