In [16]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split


# Update transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])


# Load dataset
dataset = datasets.ImageFolder(root='/home/vamsi/cv/project/neww/Image-Super-Resolution-via-Iterative-Refinement/classs/share_it100', transform=transform)

# Define split sizes
train_size = int(0.9 * len(dataset))  # 90% training
val_size = len(dataset) - train_size  # 20% validation

# Split dataset
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# Check class labels
print("Class to index mapping:", dataset.class_to_idx)


Class to index mapping: {'0n': 0, '1n': 1, '2n': 2, '3n': 3, '4n': 4, '5n': 5, '6n': 6, '7n': 7, '8n': 8, '9n': 9}


In [17]:
import timm
import torch.nn as nn
import torch

# Load a pre-trained ViT model
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Modify the classifier head for the number of classes
num_classes = len(dataset.classes)
model.head = nn.Linear(model.head.in_features, num_classes)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [18]:
from torch.optim import Adam
import torch.nn as nn

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 50
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    
    print(f"Epoch {epoch + 1}/{epochs}")

    # Training phase
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    print(f"Training Loss: {train_loss:.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = 100 * correct / total

    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%\n")


Epoch 1/50
Training Loss: 2.8197
Validation Loss: 2.2305, Validation Accuracy: 22.32%

Epoch 2/50
Training Loss: 2.2654
Validation Loss: 2.1049, Validation Accuracy: 26.79%

Epoch 3/50
Training Loss: 2.1565
Validation Loss: 2.1510, Validation Accuracy: 16.07%

Epoch 4/50
Training Loss: 2.1771
Validation Loss: 2.1603, Validation Accuracy: 16.96%

Epoch 5/50
Training Loss: 2.1192
Validation Loss: 2.2214, Validation Accuracy: 23.21%

Epoch 6/50
Training Loss: 2.1036
Validation Loss: 1.9056, Validation Accuracy: 25.00%

Epoch 7/50
Training Loss: 2.0486
Validation Loss: 1.9812, Validation Accuracy: 24.11%

Epoch 8/50
Training Loss: 2.0082
Validation Loss: 1.9557, Validation Accuracy: 31.25%

Epoch 9/50
Training Loss: 1.9993
Validation Loss: 1.8982, Validation Accuracy: 26.79%

Epoch 10/50
Training Loss: 1.9410
Validation Loss: 2.0771, Validation Accuracy: 22.32%

Epoch 11/50
Training Loss: 1.9605
Validation Loss: 1.9822, Validation Accuracy: 24.11%

Epoch 12/50
Training Loss: 1.9739
Validat