In [10]:
import os

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised in this
# process. If an earlier cell in this kernel already used CUDA, restart the kernel for the
# assignment below to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch

if torch.cuda.is_available():
    # Device indices are relative to CUDA_VISIBLE_DEVICES, so (when the variable took
    # effect) the single visible GPU is reported as index 0.
    current_gpu = torch.cuda.current_device()
    print(f"Using GPU: {current_gpu}")
    print(f"GPU name: {torch.cuda.get_device_name(current_gpu)}")

    # Memory usage for current GPU, converted from bytes to GB
    gib = 1024**3
    allocated = torch.cuda.memory_allocated(current_gpu) / gib
    reserved = torch.cuda.memory_reserved(current_gpu) / gib
    max_allocated = torch.cuda.max_memory_allocated(current_gpu) / gib

    print(f"Memory allocated: {allocated:.2f} GB")
    print(f"Memory reserved: {reserved:.2f} GB")
    print(f"Max memory allocated: {max_allocated:.2f} GB")
else:
    # Without this guard torch.cuda.current_device() raises on CPU-only machines.
    print("CUDA is not available; no GPU statistics to report.")

Using GPU: 0
GPU name: NVIDIA RTX A5000
Memory allocated: 0.00 GB
Memory reserved: 0.00 GB
Max memory allocated: 0.00 GB


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import subprocess
import sys


def _pip_install(*packages):
    """Install packages into the interpreter running this kernel.

    Uses ``sys.executable -m pip`` rather than a bare ``pip`` so the packages land in the
    same environment the notebook imports from.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])


try:
    from torchvision import transforms
except ImportError:
    # Fallback if torchvision is not installed.
    # NOTE(review): the "partially initialized module 'torchvision'" error recorded for this
    # cell is an AttributeError, not an ImportError, so it is not caught here and a reinstall
    # would not address it. Restart the kernel, and check for a local file/folder named
    # `torchvision` that could shadow the package.
    _pip_install("torchvision")
    from torchvision import transforms

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import time
from datetime import datetime

# tqdm is imported only inside this try block. An unconditional import before it would
# raise first and make the install fallback unreachable.
try:
    import timm
    from tqdm import tqdm
except ImportError:
    print("Installing required libraries...")
    _pip_install("timm", "tqdm")
    import timm
    from tqdm import tqdm

AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)

In [13]:
try:
    import timm
except ImportError:
    print("Installing timm library...")
    import subprocess
    import sys
    # Install with the kernel's own interpreter; a bare `pip` on PATH may belong to a
    # different environment than the one this notebook imports from.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "timm"])
    import timm

AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os
import time
import random
from datetime import datetime

# Install required packages
try:
    import timm
    from tqdm import tqdm
except ImportError:
    print("Installing required libraries...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "timm", "tqdm"])
    import timm
    from tqdm import tqdm

# Vision Transformer Configuration
class Config:
    """Experiment settings, used as a namespace: attributes are read as ``Config.<name>``."""

    # Dataset root. Set CHESTXRAY_DATA_DIR to point elsewhere instead of editing a
    # user-specific absolute path; the default is the original location.
    data_dir = os.environ.get("CHESTXRAY_DATA_DIR", "/data1/home/prakrutp/medical_imaging/dataset")
    # Derived from data_dir so the metadata CSV always follows the dataset root.
    csv_path = os.path.join(data_dir, "Data_Entry_2017.csv")
    image_size = 224  # ViT standard input size
    batch_size = 16   # Reduced for ViT memory requirements
    num_epochs = 50
    learning_rate = 1e-4  # Lower LR for ViT
    num_classes = 14
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Vision Transformer specific (model_name is passed to timm.create_model)
    model_name = "vit_base_patch16_224"
    use_pretrained = True

    # Class imbalance handling: enable inverse-frequency per-class loss weights
    use_class_weights = True

    # Split files (file names, joined with data_dir by the training entry point)
    test_list_file = "test_list.txt"
    train_val_list_file = "train_val_list.txt"

    # Checkpoint settings
    checkpoint_dir = "checkpoints_vit"  # relative to the current working directory
    checkpoint_interval = 5             # save every N epochs (and on every new best AUC)
    resume_training = False             # NOTE(review): not read by any code shown here

# Create the checkpoint directory up front (no error if it already exists) so that
# save_checkpoint can write into it. The path is relative, so it resolves against the
# current working directory of the kernel.
os.makedirs(Config.checkpoint_dir, exist_ok=True)

# ==================== CUSTOM TRANSFORMS USING PILLOW ====================
class CustomTransform:
    """PIL/torch image helpers used instead of torchvision.transforms.

    Every method is a staticmethod. The random augmentations draw from Python's
    ``random`` module, so they follow whatever seeding that module has.
    """

    # ImageNet RGB channel statistics, the default for normalize()
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    @staticmethod
    def normalize(tensor, mean=None, std=None):
        """Normalize a (C, H, W) tensor channel-wise: ``(tensor - mean) / std``.

        Args:
            tensor: float tensor whose first dimension indexes channels.
            mean: per-channel means; defaults to ImageNet statistics.
            std: per-channel standard deviations; defaults to ImageNet statistics.
        Returns:
            A new tensor. The statistics are created on ``tensor``'s device and dtype,
            so GPU tensors work too.
        """
        if mean is None:
            mean = CustomTransform.IMAGENET_MEAN
        if std is None:
            std = CustomTransform.IMAGENET_STD
        mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        return (tensor - mean) / std

    @staticmethod
    def to_tensor(pil_image):
        """Convert an image (PIL or array-like, H x W x C or H x W) to a float (C, H, W) tensor.

        Pixel values are divided by 255, which assumes 8-bit input.
        """
        array = np.array(pil_image)
        if array.ndim == 2:
            # Single-channel image: add a trailing channel axis so the permute below works.
            array = array[:, :, None]
        return torch.from_numpy(array).float().permute(2, 0, 1) / 255.0

    @staticmethod
    def resize(pil_image, size):
        """Bilinear resize.

        An int ``size`` gives a (size, size) square; a (width, height) pair is used as-is.
        """
        if isinstance(size, (tuple, list)):
            target = tuple(size)
        else:
            target = (size, size)
        return pil_image.resize(target, Image.BILINEAR)

    @staticmethod
    def random_horizontal_flip(pil_image, p=0.5):
        """Mirror the image left-right with probability ``p``."""
        if random.random() < p:
            return pil_image.transpose(Image.FLIP_LEFT_RIGHT)
        return pil_image

    @staticmethod
    def random_rotation(pil_image, degrees=15, p=0.5):
        """With probability ``p``, rotate by an angle drawn uniformly from [-degrees, degrees]."""
        if random.random() < p:
            angle = random.uniform(-degrees, degrees)
            return pil_image.rotate(angle, Image.BILINEAR)
        return pil_image

    @staticmethod
    def color_jitter(pil_image, brightness=0.2, contrast=0.2, p=0.5):
        """Randomly scale brightness and contrast.

        Each adjustment is applied independently with probability ``p``. Its factor is
        drawn uniformly from [1 - amount, 1 + amount].
        """
        if random.random() < p:
            factor = 1 + random.uniform(-brightness, brightness)
            pil_image = ImageEnhance.Brightness(pil_image).enhance(factor)

        if random.random() < p:
            factor = 1 + random.uniform(-contrast, contrast)
            pil_image = ImageEnhance.Contrast(pil_image).enhance(factor)

        return pil_image

class TrainTransform:
    """Training pipeline: square resize, random flip / rotation / colour jitter, then tensor + normalisation."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, pil_image):
        ops = CustomTransform
        # Geometry first, then photometric jitter, all still on the PIL image.
        augmented = ops.resize(pil_image, self.size)
        augmented = ops.random_horizontal_flip(augmented, p=0.5)
        augmented = ops.random_rotation(augmented, degrees=15)
        augmented = ops.color_jitter(augmented, brightness=0.2, contrast=0.2)
        as_tensor = ops.to_tensor(augmented)
        return ops.normalize(as_tensor)

class ValTransform:
    """Deterministic evaluation pipeline: square resize, then tensor + normalisation (no augmentation)."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, pil_image):
        resized = CustomTransform.resize(pil_image, self.size)
        return CustomTransform.normalize(CustomTransform.to_tensor(resized))

def get_vit_transforms():
    """Build the ``(train_transform, val_transform)`` pair, both sized to ``Config.image_size``."""
    side = Config.image_size
    train_tf = TrainTransform(size=side)
    val_tf = ValTransform(size=side)
    return train_tf, val_tf

# ==================== DATASET CLASS ====================
class ChestXrayDataset(Dataset):
    """Multi-label chest X-ray dataset backed by the metadata CSV.

    Each item is ``(image, labels)``. ``labels`` is a float32 tensor with one 0/1 entry
    per class in ``disease_classes`` (same order). ``image`` is the RGB PIL image, passed
    through ``transform`` when one is given.

    Args:
        csv_file: CSV with at least 'Image Index' and 'Finding Labels' columns.
        base_dir: root holding 'images_001' ... 'images_012', each with an 'images' folder.
        image_list: optional iterable of file names restricting which rows are used.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, csv_file, base_dir, image_list=None, transform=None):
        self.df = pd.read_csv(csv_file)
        self.base_dir = base_dir
        self.transform = transform

        self.disease_classes = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
            'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
            'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
            'Pleural_Thickening', 'Hernia'
        ]

        self._create_label_columns()

        if image_list is not None:
            self.data_frame = self.df[self.df['Image Index'].isin(image_list)].reset_index(drop=True)
        else:
            self.data_frame = self.df

        # Per-sample lookups are done once here instead of on every __getitem__ call:
        # an (N, num_classes) float32 label matrix and a file-name -> path index.
        self._labels = self.data_frame[self.disease_classes].to_numpy(dtype=np.float32)
        self._image_paths = self._build_image_index()

        print(f"Dataset initialized with {len(self.data_frame)} images")

    def _create_label_columns(self):
        """Add one 0/1 int column per disease: 1 when the name occurs in 'Finding Labels'."""
        finding_labels = self.df['Finding Labels']
        for disease in self.disease_classes:
            # Plain substring test (regex=False), equivalent to `disease in text`.
            self.df[disease] = finding_labels.str.contains(disease, regex=False).astype(int)

    def _build_image_index(self):
        """Map each image name used by this dataset to its path.

        The images_001..images_012 folders are searched in order, and the first folder
        holding the file wins.
        """
        wanted = set(self.data_frame['Image Index'])
        index = {}
        for i in range(1, 13):
            folder = os.path.join(self.base_dir, f"images_{i:03d}", "images")
            if not os.path.isdir(folder):
                continue
            for name in os.listdir(folder):
                if name in wanted and name not in index:
                    index[name] = os.path.join(folder, name)
        return index

    def _find_image_path(self, img_name):
        """Return the indexed path of ``img_name``, or ``images_001/images/<name>`` if not found."""
        path = self._image_paths.get(img_name)
        if path is None:
            path = os.path.join(self.base_dir, "images_001", "images", img_name)
        return path

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        img_name = self.data_frame['Image Index'].iat[idx]
        img_path = self._find_image_path(img_name)

        try:
            # `with` closes the underlying file once the RGB copy has been made.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Best-effort: a missing/corrupt file yields a black image so training continues.
            # NOTE(review): the sample still carries its real labels; consider skipping it.
            print(f"Warning: Could not load image {img_path}: {e}")
            image = Image.new('RGB', (Config.image_size, Config.image_size), color='black')

        if self.transform:
            image = self.transform(image)

        return image, torch.from_numpy(self._labels[idx].copy())

    def get_class_distribution(self):
        """Return ``{disease: {'count': positives, 'percentage': percent of rows}}``."""
        total = len(self.data_frame)
        distribution = {}
        for disease in self.disease_classes:
            count = self.data_frame[disease].sum()
            percentage = (count / total) * 100 if total else 0.0
            distribution[disease] = {'count': count, 'percentage': percentage}
        return distribution

# ==================== HELPER FUNCTIONS ====================
def load_split_file(file_path):
    """Read a split list with one image file name per line.

    Surrounding whitespace is stripped and blank lines are skipped, so a stray empty line
    does not turn into an empty image name.

    Returns:
        list of image names in file order.
    """
    with open(file_path, 'r') as f:
        return [name for name in (line.strip() for line in f) if name]

def create_datasets_from_splits(csv_file, base_dir, test_list_file, train_val_list_file, transform=None,
                                val_transform=None, val_size=0.2, random_state=42):
    """Build train / validation / test ChestXrayDataset objects from split list files.

    The train+val list is divided with ``train_test_split`` at image level.
    NOTE(review): if the dataset has several images per patient, image-level splitting
    can put the same patient in both train and validation and inflate validation AUC.
    Consider splitting by patient instead.

    Args:
        csv_file: metadata CSV forwarded to ChestXrayDataset.
        base_dir: image root forwarded to ChestXrayDataset.
        test_list_file: text file listing the test image names.
        train_val_list_file: text file listing the train+validation image names.
        transform: transform for the training set (and, by default, for val/test too).
        val_transform: optional transform for validation and test. Defaults to
            ``transform``, which keeps the previous behavior for existing callers.
        val_size: fraction of train+val images held out for validation.
        random_state: seed for the train/validation split.
    Returns:
        ``(train_dataset, val_dataset, test_dataset)``
    """
    test_images = load_split_file(test_list_file)
    train_val_images = load_split_file(train_val_list_file)

    print(f"Test images: {len(test_images)}")
    print(f"Train+Val images: {len(train_val_images)}")

    train_images, val_images = train_test_split(
        train_val_images, test_size=val_size, random_state=random_state
    )

    print(f"Train images: {len(train_images)}")
    print(f"Validation images: {len(val_images)}")

    eval_transform = transform if val_transform is None else val_transform

    def _make(images, tf):
        # All three splits share the CSV and image root; only the image subset and transform differ.
        return ChestXrayDataset(
            csv_file=csv_file,
            base_dir=base_dir,
            image_list=images,
            transform=tf
        )

    train_dataset = _make(train_images, transform)
    val_dataset = _make(val_images, eval_transform)
    test_dataset = _make(test_images, eval_transform)

    return train_dataset, val_dataset, test_dataset

# ==================== VISION TRANSFORMER MODEL ====================
class VisionTransformerModel(nn.Module):
    """A timm backbone (classifier removed), followed by dropout, a linear head and a sigmoid.

    forward() returns per-class probabilities in [0, 1], one column per class. This is
    the input nn.BCELoss (used by WeightedBCELoss) expects.
    NOTE(review): returning logits and training with BCEWithLogitsLoss would be more
    numerically stable, but callers here consume probabilities, so the contract is kept.

    Args:
        num_classes: number of output labels.
        model_name: timm model name passed to ``timm.create_model``.
        pretrained: load pretrained weights; defaults to ``Config.use_pretrained``.
        dropout_rate: dropout probability applied to the pooled features.
    """
    def __init__(self, num_classes=14, model_name="vit_base_patch16_224", pretrained=None, dropout_rate=0.1):
        super(VisionTransformerModel, self).__init__()

        if pretrained is None:
            pretrained = Config.use_pretrained

        # Load the backbone; num_classes=0 removes timm's default classifier.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )

        # Ask the backbone for its feature width. The old name-based guess ("base" -> 768,
        # "large" -> 1024) is wrong for e.g. swin_base / convnext_base (1024 features) and
        # would make the Linear head's input size mismatch.
        self.feature_dim = self.backbone.num_features

        # Custom classifier for multi-label classification
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        # Sigmoid turns each logit into an independent per-class probability
        self.sigmoid = nn.Sigmoid()

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)

        print(f"Initialized {model_name} with {self.feature_dim} feature dimensions")

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

    def forward(self, x):
        """Return sigmoid probabilities of shape (batch, num_classes) for input images ``x``."""
        features = self.backbone(x)
        features = self.dropout(features)
        return self.sigmoid(self.classifier(features))

# ==================== LOSS FUNCTION ====================
class WeightedBCELoss(nn.Module):
    """Per-class binary cross-entropy averaged over classes, with optional per-class weights.

    ``outputs`` must already be probabilities (as for nn.BCELoss), and ``targets`` are 0/1
    floats. Both have shape (batch, num_classes). The loss is

        loss = (1 / C) * sum_c  w_c * mean_over_batch(BCE(outputs[:, c], targets[:, c]))

    with w_c = 1 when no weights are given.

    Args:
        class_weights: optional length-C tensor or sequence of per-class multipliers.
    """
    def __init__(self, class_weights=None):
        super(WeightedBCELoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, outputs, targets):
        # One vectorised BCE over the (batch, C) matrix, then averaged over the batch
        # axis to get the (C,) per-class losses. This replaces the per-class Python loop
        # that built a new nn.BCELoss on every call.
        per_class = nn.functional.binary_cross_entropy(outputs, targets, reduction='none').mean(dim=0)

        if self.class_weights is not None:
            # as_tensor accepts tensors or plain sequences, aligned to the loss's device/dtype
            weights = torch.as_tensor(self.class_weights, dtype=per_class.dtype, device=per_class.device)
            per_class = per_class * weights

        return per_class.mean()

# ==================== CHECKPOINT FUNCTIONS ====================
def save_checkpoint(epoch, model, optimizer, train_losses, val_losses, val_auc_scores, best_auc, is_best=False):
    """Write the current training state into ``Config.checkpoint_dir``.

    The payload always goes to ``checkpoint_epoch_<epoch:03d>.pth``. When ``is_best`` is
    true, an identical copy is also written to ``best_model.pth``.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses=train_losses,
        val_losses=val_losses,
        val_auc_scores=val_auc_scores,
        best_auc=best_auc,
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )

    epoch_path = os.path.join(Config.checkpoint_dir, f"checkpoint_epoch_{epoch:03d}.pth")
    destinations = [epoch_path]
    if is_best:
        destinations.append(os.path.join(Config.checkpoint_dir, "best_model.pth"))

    for destination in destinations:
        torch.save(payload, destination)

    print(f"Checkpoint saved: {epoch_path}")

def find_latest_checkpoint(checkpoint_dir=None):
    """Return the path of the highest-epoch ``checkpoint_epoch_<N>.pth`` file, or None.

    Args:
        checkpoint_dir: directory to search; defaults to ``Config.checkpoint_dir``.
    Returns:
        Path of the actual file with the largest epoch number. Returns None when the
        directory is missing or contains no well-formed checkpoint names. Stray files
        such as ``checkpoint_epoch_tmp.pth`` are ignored rather than crashing.
    """
    if checkpoint_dir is None:
        checkpoint_dir = Config.checkpoint_dir
    if not os.path.isdir(checkpoint_dir):
        return None

    prefix, suffix = "checkpoint_epoch_", ".pth"
    latest_epoch, latest_name = -1, None
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        number = name[len(prefix):-len(suffix)]
        if not number.isdigit():
            continue
        if int(number) > latest_epoch:
            latest_epoch, latest_name = int(number), name

    if latest_name is None:
        return None
    return os.path.join(checkpoint_dir, latest_name)

# ==================== CLASS WEIGHTS CALCULATION ====================
def calculate_class_weights(dataset, device=None):
    """Inverse-frequency per-class weights, normalised so they sum to the number of classes.

    ``w_c = N / (C * max(n_c, 1))`` is rescaled so that ``sum_c w_c == C``. Here N is the
    number of samples, n_c the positive count of class c, and C the number of classes.

    Args:
        dataset: object with ``disease_classes``, a ``data_frame`` holding one 0/1 column
            per class, and ``__len__`` (e.g. ChestXrayDataset).
        device: device for the returned tensor; defaults to ``Config.device``.
    Returns:
        float32 tensor of shape (C,). Returns all ones when the dataset is empty and the
        weights would otherwise be NaN.
    """
    print("Calculating class weights...")
    if device is None:
        device = Config.device

    # Size everything from the dataset's own class list so it cannot disagree with Config.
    classes = list(dataset.disease_classes)
    num_classes = len(classes)

    # Positive counts come straight from the label columns; no images are loaded.
    class_counts = dataset.data_frame[classes].sum().to_numpy(dtype=float)

    total_samples = len(dataset)
    class_weights = total_samples / (num_classes * np.maximum(class_counts, 1))
    weight_sum = np.sum(class_weights)
    if weight_sum > 0:
        class_weights = class_weights / weight_sum * num_classes
    else:
        class_weights = np.ones(num_classes)

    print("Class weights calculated:")
    for disease, weight, count in zip(classes, class_weights, class_counts):
        print(f"  {disease}: {weight:.2f} (count: {count})")

    return torch.FloatTensor(class_weights).to(device)

# ==================== TRAINING FUNCTION ====================
def _per_class_auc(labels, preds):
    """Per-class ROC AUC for (N, C) label and prediction arrays.

    Returns a list of C floats. A class gets NaN when its AUC is undefined (only one
    label value present) or when roc_auc_score rejects the input (e.g. NaN predictions).
    """
    scores = []
    for i in range(labels.shape[1]):
        column = labels[:, i]
        if not 0 < column.sum() < len(column):
            scores.append(float('nan'))
            continue
        try:
            scores.append(roc_auc_score(column, preds[:, i]))
        except ValueError:
            scores.append(float('nan'))
    return scores


def _nanmean_or_nan(scores):
    """Mean over the non-NaN entries of ``scores``; NaN when none are defined."""
    defined = [s for s in scores if not np.isnan(s)]
    return float(np.mean(defined)) if defined else float('nan')


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, start_epoch=0, patience=15):
    """Train ``model`` with a cosine-annealed learning rate, validating after every epoch.

    Each epoch runs one pass over ``train_loader`` (gradient-norm clipped at 1.0). It then
    computes the validation loss and the mean per-class ROC AUC on ``val_loader``.
    Classes whose AUC is undefined are excluded from the mean rather than counted as 0.
    A checkpoint is saved every ``Config.checkpoint_interval`` epochs and whenever the
    mean AUC improves. Training stops after ``patience`` epochs without improvement.

    Args:
        model: network returning per-class probabilities for a batch.
        train_loader: DataLoader of ``(inputs, labels)``; its dataset must expose
            ``disease_classes``.
        val_loader: DataLoader of ``(inputs, labels)`` for validation.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: exclusive upper epoch index; also the cosine schedule's T_max.
        start_epoch: first epoch index to run.
        patience: epochs without mean-AUC improvement before early stopping.
    Returns:
        ``(model, train_losses, val_losses, val_auc_scores)``, one entry per epoch run.
    """
    train_losses = []
    val_losses = []
    val_auc_scores = []
    best_auc = 0.0
    epochs_no_improve = 0

    # NOTE(review): with start_epoch > 0 the schedule still starts at the initial LR and the
    # histories start empty; a real resume would also need the scheduler state restored.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(start_epoch, num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)

        # ---- training phase ----
        model.train()
        running_loss = 0.0

        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}')
        for inputs, labels in train_pbar:
            inputs = inputs.to(Config.device, non_blocking=True)
            labels = labels.to(Config.device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping for ViT stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean
            running_loss += loss.item() * inputs.size(0)
            train_pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []

        val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch+1}')
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs = inputs.to(Config.device, non_blocking=True)
                labels = labels.to(Config.device, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)

                all_labels.append(labels.cpu().numpy())
                all_preds.append(outputs.cpu().numpy())

        val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(val_loss)

        auc_scores = _per_class_auc(np.concatenate(all_labels), np.concatenate(all_preds))
        mean_auc = _nanmean_or_nan(auc_scores)
        val_auc_scores.append(mean_auc)

        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        print(f'Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}')
        print(f'Validation AUC: {mean_auc:.4f}, Learning Rate: {current_lr:.6f}')

        # Print top diseases by AUC; undefined (NaN) scores sort last
        sorted_auc = sorted(
            zip(train_loader.dataset.disease_classes, auc_scores),
            key=lambda item: -np.inf if np.isnan(item[1]) else item[1],
            reverse=True,
        )
        print('Top 5 diseases by AUC:')
        for disease, auc in sorted_auc[:5]:
            print(f'  {disease}: {auc:.4f}')

        # Early stopping and checkpoint logic. `nan > x` is False, so an undefined mean
        # AUC never counts as an improvement.
        is_best = False
        if mean_auc > best_auc:
            best_auc = mean_auc
            epochs_no_improve = 0
            is_best = True
            print(f'New best AUC: {best_auc:.4f}')
        else:
            epochs_no_improve += 1
            print(f'No improvement for {epochs_no_improve} epochs')

        if (epoch + 1) % Config.checkpoint_interval == 0 or is_best:
            save_checkpoint(epoch+1, model, optimizer, train_losses, val_losses, val_auc_scores, best_auc, is_best)

        if epochs_no_improve >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    return model, train_losses, val_losses, val_auc_scores

# ==================== MAIN TRAINING FUNCTION ====================
def main_vit():
    """Run the full ViT experiment: data setup, training, test evaluation, saving and plotting.

    The test set is scored with the best-validation weights. These are reloaded from
    ``best_model.pth`` when this run wrote that file, instead of the last-epoch weights.
    Side effects: checkpoints under ``Config.checkpoint_dir`` (via train_model),
    ``final_<model_name>_model.pth`` and ``<model_name>_results.png`` in the working
    directory, and a displayed figure.
    """
    run_start = time.time()

    print("Setting up VISION TRANSFORMER model...")
    print(f"Using device: {Config.device}")
    print(f"Model: {Config.model_name}")

    # Seed the RNGs used here (Python `random` drives the augmentations; NumPy; torch)
    # for repeatability.
    seed = getattr(Config, "seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Check GPU info
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    train_transform, val_transform = get_vit_transforms()

    test_list_path = os.path.join(Config.data_dir, Config.test_list_file)
    train_val_list_path = os.path.join(Config.data_dir, Config.train_val_list_file)
    # os.path.join returns Config.csv_path unchanged when it is already an absolute path
    csv_path = os.path.join(Config.data_dir, Config.csv_path)

    print("Loading datasets...")
    train_dataset, val_dataset, test_dataset = create_datasets_from_splits(
        csv_file=csv_path,
        base_dir=Config.data_dir,
        test_list_file=test_list_path,
        train_val_list_file=train_val_list_path,
        transform=train_transform
    )

    # Evaluation splits must not be augmented
    val_dataset.transform = val_transform
    test_dataset.transform = val_transform

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    class_weights = None
    if Config.use_class_weights:
        class_weights = calculate_class_weights(train_dataset)

    def _make_loader(dataset, shuffle):
        # Identical loader settings for every split; only the training data is shuffled.
        return DataLoader(
            dataset,
            batch_size=Config.batch_size,
            shuffle=shuffle,
            num_workers=4,
            pin_memory=True
        )

    print("Creating data loaders...")
    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    print(f"Initializing {Config.model_name}...")
    model = VisionTransformerModel(
        num_classes=Config.num_classes,
        model_name=Config.model_name
    )
    model = model.to(Config.device)

    # Optimizer with weight decay (important for ViT)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=Config.learning_rate,
        weight_decay=0.05,
        betas=(0.9, 0.999)
    )

    criterion = WeightedBCELoss(class_weights=class_weights)

    print("\nStarting Vision Transformer training...")
    model, train_losses, val_losses, val_auc_scores = train_model(
        model, train_loader, val_loader, criterion, optimizer,
        Config.num_epochs
    )

    # Evaluate the best-validation weights. The last epoch is generally not the best one,
    # because early stopping only halts after several non-improving epochs. Only trust a
    # best_model.pth written during this run, not one left over from an earlier run.
    best_path = os.path.join(Config.checkpoint_dir, "best_model.pth")
    if os.path.exists(best_path) and os.path.getmtime(best_path) >= run_start:
        # The file was written by this run's save_checkpoint, so it is trusted. It may hold
        # numpy scalars (e.g. best_auc), which weights_only=True loading can reject.
        try:
            best_checkpoint = torch.load(best_path, map_location=Config.device, weights_only=False)
        except TypeError:  # older torch without the weights_only argument
            best_checkpoint = torch.load(best_path, map_location=Config.device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        print(f"Loaded best checkpoint from epoch {best_checkpoint['epoch']} for testing")

    print("\nFinal evaluation on test set...")
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Testing'):
            inputs = inputs.to(Config.device)
            outputs = model(inputs)
            test_preds.append(outputs.cpu().numpy())
            test_labels.append(labels.numpy())

    test_preds = np.concatenate(test_preds)
    test_labels = np.concatenate(test_labels)

    # Per-class AUC. NaN marks classes where AUC is undefined (only one label value present,
    # or inputs rejected by roc_auc_score); these are excluded from the mean.
    test_auc_scores = []
    for i in range(Config.num_classes):
        column = test_labels[:, i]
        if 0 < column.sum() < len(column):
            try:
                test_auc_scores.append(roc_auc_score(column, test_preds[:, i]))
            except ValueError:
                test_auc_scores.append(float('nan'))
        else:
            test_auc_scores.append(float('nan'))

    defined_scores = [s for s in test_auc_scores if not np.isnan(s)]
    mean_auc = float(np.mean(defined_scores)) if defined_scores else float('nan')

    print("\n" + "="*60)
    print("FINAL TEST RESULTS (VISION TRANSFORMER)")
    print("="*60)
    print(f"Test Average AUC: {mean_auc:.4f}")
    print("\nPer-class Test AUC:")
    for disease, score in zip(train_dataset.disease_classes, test_auc_scores):
        print(f"  {disease}: {score:.4f}")

    final_model_path = f'final_{Config.model_name}_model.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'test_auc': mean_auc,
        'test_auc_scores': test_auc_scores,
        'training_history': {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_auc_scores': val_auc_scores
        }
    }, final_model_path)
    print(f"\nFinal model saved as '{final_model_path}'")

    # Training curves, using the explicit Figure/Axes API
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    loss_ax.plot(train_losses, label='Training Loss')
    loss_ax.plot(val_losses, label='Validation Loss')
    loss_ax.set_title(f'{Config.model_name} - Training and Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    auc_ax.plot(val_auc_scores)
    auc_ax.set_title(f'{Config.model_name} - Validation AUC Over Time')
    auc_ax.set_xlabel('Epoch')
    auc_ax.set_ylabel('AUC')
    auc_ax.grid(True)

    fig.tight_layout()
    fig.savefig(f'{Config.model_name}_results.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n" + "="*60)
    print("VISION TRANSFORMER TRAINING COMPLETE")
    print("="*60)
    print(f"Final Test AUC: {mean_auc:.4f}")
    # NOTE(review): the figure below is a hard-coded expectation, not a measured comparison.
    print("Expected improvement over ResNet-50: +4-6% AUC")

# Entry point. Inside a Jupyter/IPython kernel the cell executes in `__main__`, so running
# this cell starts the full training run immediately.
if __name__ == "__main__":
    main_vit()

Installing required libraries...
Collecting timm
  Using cached timm-1.0.22-py3-none-any.whl.metadata (63 kB)
Collecting safetensors (from timm)
  Using cached safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Using cached timm-1.0.22-py3-none-any.whl (2.5 MB)
Using cached safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (485 kB)
Installing collected packages: safetensors, timm
Successfully installed safetensors-0.6.2 timm-1.0.22


AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)