In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os
from utils import *

# --- Setup ---
from torch.utils.data import random_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Run configuration, collected in one place so the training loop, the
# evaluation schedule and the TensorBoard step of the final evaluation
# cannot drift apart.
NUM_EPOCHS = 100
EVAL_EVERY = 10          # evaluate on the test split every EVAL_EVERY epochs
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42          # fixed seed -> reproducible train/test split


def _set_linear_probe_mode(net):
    """Put *net* into training mode for linear evaluation with a frozen encoder.

    ``net.train()`` alone would also switch the encoder's normalization layers
    (e.g. BatchNorm, if present) to training mode, where they keep updating
    their running statistics on every forward pass even though their
    parameters have ``requires_grad=False``. Putting the whole model in eval
    mode and only the classifier head (``net[-1]``) in train mode keeps the
    encoder truly frozen.
    """
    net.eval()
    net[-1].train()


def _evaluate_accuracy(net, loader, desc):
    """Return the top-1 accuracy of *net* over *loader*, in percent.

    Args:
        net: model mapping a batch of images to class logits.
        loader: DataLoader yielding ``(images, labels)`` batches.
        desc: label shown on the tqdm progress bar.
    """
    net.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for x, y in tqdm(loader, desc=desc, leave=False):
            x, y = x.to(device), y.to(device)
            preds = net(x).argmax(dim=1)
            n_correct += (preds == y).sum().item()
            n_seen += y.size(0)
    return 100 * n_correct / n_seen


# --- EuroSAT Dataset ---
# EuroSAT images are RGB with shape (64, 64), but we'll resize to 224x224
# NOTE(review): mean/std of 0.5 is assumed to match the normalization used
# during SimSiam pre-training -- confirm against the pre-training pipeline.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

data_root = './data/EuroSAT/2750'

full_set = datasets.ImageFolder(root=data_root, transform=transform)
num_classes = len(full_set.classes)

# Split into train/test with a seeded generator so the split (and therefore
# the reported test accuracy) is reproducible across runs and kernel restarts.
train_size = int(TRAIN_FRACTION * len(full_set))
test_size = len(full_set) - train_size
train_set, test_set = random_split(
    full_set,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# --- Load pretrained encoder and attach linear classifier ---
# Built after the dataset so the head size comes from the data, not a literal.
encoder = load_pretrained_encoder('./checkpoints_new_backbone/simsiam_encoder.pth', backbone='resnet18')
model = create_linear_classifier(encoder, num_classes=num_classes, freeze_encoder=True).to(device)

# --- Training setup ---
criterion = nn.CrossEntropyLoss()
# Only the classifier head (last sub-module) is optimized.
optimizer = optim.Adam(model[-1].parameters(), lr=LEARNING_RATE)

writer = SummaryWriter(log_dir='../runs/linear_eval_eurosat')

# --- Training Loop ---
for epoch in range(NUM_EPOCHS):
    _set_linear_probe_mode(model)
    correct, total, total_loss = 0, 0, 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}", leave=False)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

        pbar.set_postfix({
            'Loss': f"{loss.item():.4f}",
            'Acc': f"{100 * correct / total:.2f}%"
        })

    acc = 100 * correct / total
    # Mean loss per batch: the same quantity is logged and printed.
    avg_loss = total_loss / len(train_loader)
    writer.add_scalar('Train/Loss', avg_loss, epoch + 1)
    writer.add_scalar('Train/Accuracy', acc, epoch + 1)

    print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | Train Acc: {acc:.2f}%")

    # --- Evaluate every EVAL_EVERY epochs ---
    if (epoch + 1) % EVAL_EVERY == 0:
        test_acc = _evaluate_accuracy(model, test_loader, f"Eval @ Epoch {epoch+1}")
        writer.add_scalar('Test/Accuracy', test_acc, epoch + 1)
        print(f"Test Accuracy @ Epoch {epoch+1:3d}: {test_acc:.2f}%")

# --- Final Evaluation ---
final_test_acc = _evaluate_accuracy(model, test_loader, "Final Evaluation")
# Log at the last training epoch so the point lines up with the rest of the
# Test/Accuracy curve instead of a step that was never trained to.
writer.add_scalar('Test/Accuracy', final_test_acc, NUM_EPOCHS)
writer.close()  # flush pending TensorBoard events to disk
print(f"\nFinal Linear Evaluation Accuracy on EuroSAT: {final_test_acc:.2f}%")

# --- Save Model ---
# NOTE(review): model.state_dict() presumably contains the wrapped (frozen)
# encoder's weights as well as the head, despite the file name -- confirm.
torch.save(model.state_dict(), 'linear_head_eurosat_final_weights.pth')
print("Final model weights saved as 'linear_head_eurosat_final_weights.pth'")


                                                                                     

Epoch   1 | Loss: 166.9954 | Train Acc: 70.62%


                                                                                     

Epoch   2 | Loss: 111.1044 | Train Acc: 78.12%


                                                                                     

Epoch   3 | Loss: 101.7263 | Train Acc: 79.65%


                                                                                     

Epoch   4 | Loss: 97.0383 | Train Acc: 80.24%


                                                                                     

Epoch   5 | Loss: 93.8080 | Train Acc: 80.75%


                                                                                     

Epoch   6 | Loss: 90.9266 | Train Acc: 81.29%


                                                                                     

Epoch   7 | Loss: 88.1845 | Train Acc: 81.64%


                                                                                     

Epoch   8 | Loss: 86.7217 | Train Acc: 82.15%


                                                                                     

Epoch   9 | Loss: 85.1550 | Train Acc: 82.33%


                                                                                     

Epoch  10 | Loss: 84.3194 | Train Acc: 82.55%


                                                                                     

Epoch  11 | Loss: 82.4495 | Train Acc: 82.88%


                                                                                     

Epoch  12 | Loss: 81.9416 | Train Acc: 82.85%


                                                                                     

Epoch  13 | Loss: 80.0282 | Train Acc: 83.38%


                                                                                     

Epoch  14 | Loss: 79.6530 | Train Acc: 83.09%


                                                                                     

Epoch  15 | Loss: 78.6143 | Train Acc: 83.59%


                                                                                     

Epoch  16 | Loss: 78.8534 | Train Acc: 83.38%


                                                                                     

Epoch  17 | Loss: 77.6386 | Train Acc: 83.90%


                                                                                     

Epoch  18 | Loss: 76.9537 | Train Acc: 84.04%


                                                                                     

Epoch  19 | Loss: 76.5158 | Train Acc: 84.10%


                                                                                     

Epoch  20 | Loss: 75.8300 | Train Acc: 84.38%


                                                                                     

Epoch  21 | Loss: 75.2347 | Train Acc: 84.25%


                                                                                     

Epoch  22 | Loss: 74.9111 | Train Acc: 84.47%


                                                                                     

Epoch  23 | Loss: 74.2798 | Train Acc: 84.38%


                                                                                     

Epoch  24 | Loss: 74.3959 | Train Acc: 84.53%


                                                                                     

Epoch  25 | Loss: 72.7068 | Train Acc: 84.78%


                                                                                     

Epoch  26 | Loss: 72.8056 | Train Acc: 84.75%


                                                                                     

Epoch  27 | Loss: 72.1901 | Train Acc: 85.00%


                                                                                     

Epoch  28 | Loss: 71.4157 | Train Acc: 85.13%


                                                                                     

Epoch  29 | Loss: 71.5258 | Train Acc: 84.99%


                                                                                     

Epoch  30 | Loss: 71.9124 | Train Acc: 85.12%


                                                                                     

Epoch  31 | Loss: 70.3786 | Train Acc: 85.34%


                                                                                     

Epoch  32 | Loss: 70.2142 | Train Acc: 85.48%


                                                                                     

Epoch  33 | Loss: 70.3116 | Train Acc: 85.36%


                                                                                     

Epoch  34 | Loss: 70.0658 | Train Acc: 85.32%


                                                                                     

Epoch  35 | Loss: 69.5912 | Train Acc: 85.47%


                                                                                     

Epoch  36 | Loss: 69.2227 | Train Acc: 85.53%


                                                                                     

Epoch  37 | Loss: 69.2061 | Train Acc: 85.56%


                                                                                     

Epoch  38 | Loss: 69.1600 | Train Acc: 85.54%


                                                                                     

Epoch  39 | Loss: 68.3559 | Train Acc: 85.69%


                                                                                     

Epoch  40 | Loss: 68.4857 | Train Acc: 85.75%


                                                                                     

Epoch  41 | Loss: 68.1869 | Train Acc: 85.72%


                                                                                     

Epoch  42 | Loss: 68.0445 | Train Acc: 85.76%


                                                                                     

Epoch  43 | Loss: 67.7991 | Train Acc: 85.85%


                                                                                     

Epoch  44 | Loss: 67.1280 | Train Acc: 86.15%


                                                                                     

Epoch  45 | Loss: 66.7258 | Train Acc: 86.08%


                                                                                     

Epoch  46 | Loss: 67.5363 | Train Acc: 85.85%


                                                                                     

Epoch  47 | Loss: 66.5911 | Train Acc: 86.16%


                                                                                     

Epoch  48 | Loss: 66.5493 | Train Acc: 86.12%


                                                                                     

Epoch  49 | Loss: 66.6312 | Train Acc: 86.08%


                                                                                     

Epoch  50 | Loss: 65.8446 | Train Acc: 86.56%


                                                                                     

Epoch  51 | Loss: 65.8294 | Train Acc: 86.16%


                                                                                     

Epoch  52 | Loss: 65.3110 | Train Acc: 86.52%


                                                                                     

Epoch  53 | Loss: 64.9421 | Train Acc: 86.56%


                                                                                     

Epoch  54 | Loss: 64.8996 | Train Acc: 86.71%


                                                                                     

Epoch  55 | Loss: 65.5203 | Train Acc: 86.44%


                                                                                     

Epoch  56 | Loss: 64.1778 | Train Acc: 86.76%


                                                                                     

Epoch  57 | Loss: 64.8208 | Train Acc: 86.55%


                                                                                     

Epoch  58 | Loss: 64.2777 | Train Acc: 86.67%


                                                                                     

Epoch  59 | Loss: 64.1837 | Train Acc: 86.54%


                                                                                     

Epoch  60 | Loss: 64.4200 | Train Acc: 86.59%


                                                                                     

Epoch  61 | Loss: 64.2048 | Train Acc: 86.79%


                                                                                     

Epoch  62 | Loss: 63.7937 | Train Acc: 86.75%


                                                                                     

Epoch  63 | Loss: 63.2263 | Train Acc: 86.67%


                                                                                     

Epoch  64 | Loss: 63.2417 | Train Acc: 86.89%


                                                                                     

Epoch  65 | Loss: 63.0776 | Train Acc: 86.83%


                                                                                     

Epoch  66 | Loss: 63.3988 | Train Acc: 86.75%


                                                                                     

Epoch  67 | Loss: 63.0548 | Train Acc: 86.89%


                                                                                     

Epoch  68 | Loss: 63.4032 | Train Acc: 86.77%


                                                                                     

Epoch  69 | Loss: 62.5496 | Train Acc: 86.94%


                                                                                     

Epoch  70 | Loss: 62.9782 | Train Acc: 86.95%


                                                                                     

Epoch  71 | Loss: 62.2625 | Train Acc: 87.06%


                                                                                     

Epoch  72 | Loss: 62.8343 | Train Acc: 86.84%


                                                                                     

Epoch  73 | Loss: 62.1964 | Train Acc: 87.03%


                                                                                     

Epoch  74 | Loss: 61.9639 | Train Acc: 87.20%


                                                                                     

Epoch  75 | Loss: 63.1336 | Train Acc: 86.90%


                                                                                     

Epoch  76 | Loss: 61.6711 | Train Acc: 87.05%


                                                                                     

Epoch  77 | Loss: 62.1049 | Train Acc: 87.06%


                                                                                     

Epoch  78 | Loss: 62.4849 | Train Acc: 87.07%


                                                                                     

Epoch  79 | Loss: 61.1076 | Train Acc: 87.30%


                                                                                     

Epoch  80 | Loss: 61.5288 | Train Acc: 87.46%


                                                                                     

Epoch  81 | Loss: 61.2726 | Train Acc: 87.20%


                                                                                     

Epoch  82 | Loss: 61.6224 | Train Acc: 87.22%


                                                                                     

Epoch  83 | Loss: 60.8649 | Train Acc: 87.31%


                                                                                     

Epoch  84 | Loss: 60.5361 | Train Acc: 87.47%


                                                                                     

Epoch  85 | Loss: 61.1851 | Train Acc: 87.09%


                                                                                     

Epoch  86 | Loss: 60.5520 | Train Acc: 87.31%


                                                                                     

Epoch  87 | Loss: 60.3419 | Train Acc: 87.29%


                                                                                     

Epoch  88 | Loss: 60.4889 | Train Acc: 87.46%


                                                                                     

Epoch  89 | Loss: 60.7061 | Train Acc: 87.42%


                                                                                     

Epoch  90 | Loss: 59.8389 | Train Acc: 87.54%


                                                                                     

Epoch  91 | Loss: 60.2314 | Train Acc: 87.53%


                                                                                     

Epoch  92 | Loss: 60.0969 | Train Acc: 87.46%


                                                                                     

Epoch  93 | Loss: 59.6624 | Train Acc: 87.56%


                                                                                     

Epoch  94 | Loss: 60.2592 | Train Acc: 87.29%


                                                                                     

Epoch  95 | Loss: 59.0386 | Train Acc: 87.81%


                                                                                     

Epoch  96 | Loss: 59.1517 | Train Acc: 87.89%


                                                                                     

Epoch  97 | Loss: 59.7031 | Train Acc: 87.61%


                                                                                     

Epoch  98 | Loss: 58.8972 | Train Acc: 87.78%


                                                                                     

Epoch  99 | Loss: 59.3804 | Train Acc: 87.71%


                                                                                     

Epoch 100 | Loss: 59.6722 | Train Acc: 87.53%


                                                                 

Test Accuracy @ Epoch 100: 86.65%


                                                                 


Final Linear Evaluation Accuracy on EuroSAT: 86.65%
Final model weights saved as 'linear_head_eurosat_final_weights.pth'
