# scikit-learn scratchpad

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report
import os
from PIL import Image

In [2]:
EXAMPLES = True

In [3]:
from dataclasses import dataclass

@dataclass
class Rect:
    """Specifies the region of the image we're interested in. 
    Resolution is a width:height of the image."""
    top: float
    left: float
    width: float
    height: float
    resolution: tuple = (2560, 1440)

    @property
    def right(self):
        return self.left + self.resolution[0] - self.width

    @property
    def bottom(self):
        return self.top + self.resolution[1] - self.height

if EXAMPLES:
    r = Rect(left=165, top=1192, width=2117, height=211)
    print(r.right, r.bottom)
    r = None

608 2421


In [4]:
class CustomDataset(Dataset):
    """Takes care of loading images from disk & transforming them.
    Makes use of predefined classes: 'other' and 'starmap'.
    """
    def __init__(self, root_dir, crop_rect: Rect, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['other', 'starmap']
        self.file_paths = []
        self.crop_rect = crop_rect
        
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            for file in os.listdir(class_path):
                self.file_paths.append(
                    (os.path.join(class_path, file), self.classes.index(class_name))
                )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path, label = self.file_paths[idx]
        image = Image.open(img_path).convert('RGB')
        rect = self.crop_rect
        image = image.crop((rect.left, rect.top, rect.right, rect.bottom))
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [5]:
@dataclass
class Session:
    transform: torchvision.transforms.transforms.Compose
    crop_region: Rect
    optimizer: object
    criterion: object
    device: object

@dataclass
class Data:
    dataset: CustomDataset
    batch_size: int
    shuffle: bool

    @property
    def train_size(self):
        return int(0.8 * len(self.dataset))

    @property
    def val_size(self):
        return len(self.dataset) - self.train_size

    def dataset_split(self):
         return torch.utils.data.random_split(self.dataset, [self.train_size, self.val_size])

@dataclass
class TrainingResult:
    model: object
    loader_val: object
    

In [8]:
CROP_REGION = Rect(left=165, top=1192, width=2117, height=211)

def make_criterion(): return nn.CrossEntropyLoss()
def make_optimizer(model): return optim.Adam(model.parameters(), lr=0.001)

sess = Session(
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    crop_region=CROP_REGION,
    criterion=make_criterion,
    optimizer=make_optimizer,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

data = Data(
    dataset=CustomDataset('data/screenshots', transform=sess.transform, crop_rect=CROP_REGION),
    batch_size=32,
    shuffle=True,
)

def train(sess: Session, data: Data, model, num_epochs=5):
    dataset_train, dataset_val = data.dataset_split()
    loader_train = DataLoader(dataset_train, batch_size=data.batch_size, shuffle=data.shuffle)
    loader_val = DataLoader(dataset_val, batch_size=data.batch_size, shuffle=data.shuffle)

    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2) # 2 output classes
    device = sess.device
    device_model = model.to(device)

    optimizer = sess.optimizer(device_model)
    criterion = sess.criterion()
    
    for epoch in range(num_epochs):
        device_model.train()
        running_loss = 0.0

        for inputs, labels, in loader_train:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = device_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(dataset_train)
        print(f'Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}')

    return TrainingResult(model=device_model, loader_val=loader_val)

def evaluate(result: TrainingResult, device):
    model, val_loader = result.model, result.loader_val
    
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    print(classification_report(all_labels, all_preds, target_names=data.dataset.classes))
    

if EXAMPLES:
    model = models.resnet34()
    result = train(sess, data, model, num_epochs=10)
    evaluate(result, sess.device)

# Evaluation
# if False:
#     model.eval()
#     all_preds = []
#     all_labels = []
    
#     with torch.no_grad():
#         for inputs, labels in val_loader:
#             inputs = inputs.to(device)
#             outputs = model(inputs)
#             _, preds = torch.max(outputs, 1)
            
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(labels.numpy())

#     print(classification_report(all_labels, all_preds, target_names=data.dataset.classes))


Epoch 1/10 - Loss: 1.3512
Epoch 2/10 - Loss: 0.0515
Epoch 3/10 - Loss: 0.0263
Epoch 4/10 - Loss: 0.0018
Epoch 5/10 - Loss: 0.0013
Epoch 6/10 - Loss: 0.0024
Epoch 7/10 - Loss: 0.0005
Epoch 8/10 - Loss: 0.0014
Epoch 9/10 - Loss: 0.0006
Epoch 10/10 - Loss: 0.0013
              precision    recall  f1-score   support

       other       1.00      1.00      1.00        11
     starmap       1.00      1.00      1.00         4

    accuracy                           1.00        15
   macro avg       1.00      1.00      1.00        15
weighted avg       1.00      1.00      1.00        15



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

let's save the model for later

In [32]:
torch.save(model.state_dict(), 'resnet34_state_dict_v2.pth')

Let's try to use the model now.

In [49]:
import torch
from torchvision import transforms
from PIL import Image

def load_model(model_path, device='cpu'):
    # Initialize model
    model = models.resnet34()
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.load_state_dict(torch.load(model_path, map_location=device,weights_only=False))
    model.eval()
    return model.to(device)

# Preprocessing (must match training preprocessing)
def preprocess_image(image_path):    
    img = Image.open(image_path).convert('RGB')    
    r = CROP_REGION
    img = img.crop((r.left, r.top, r.right, r.bottom)) 
    return transform(img).unsqueeze(0)  # Add batch dimension

# Example usage
def example_usage(image_path):
    model = load_model('resnet34_state_dict_v2.pth')
    input_tensor = preprocess_image(image_path)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        predicted_class = torch.argmax(probabilities).item()
        print(f'Predicted class: {["other", "starmap"][predicted_class]}')


In [50]:
example_usage('E:/bin/StarCitizen/LIVE/ScreenShots/ScreenShot-2024-05-11_21-58-40-1A3.jpg')

Predicted class: other


In [51]:
example_usage('E:/bin/StarCitizen/LIVE/ScreenShots/ScreenShot-2024-05-12_00-09-07-EF4.jpg')

Predicted class: other


In [56]:
example_usage(r"E:\bin\StarCitizen\LIVE\ScreenShots\ScreenShot-2024-07-31_01-13-41-B9A.jpg")

Predicted class: other


In [52]:
example_usage(r"E:\bin\StarCitizen\LIVE\ScreenShots\ScreenShot-2024-08-13_23-34-21-1CC.jpg") # starmap

Predicted class: starmap


In [54]:
example_usage(r"E:\bin\StarCitizen\LIVE\ScreenShots\ScreenShot-2025-03-06_22-49-40-3A1.jpg")

Predicted class: starmap


In [55]:
example_usage(r"E:\bin\StarCitizen\LIVE\ScreenShots\ScreenShot-2025-03-06_23-52-52-0D2.jpg") # starmap

Predicted class: starmap
