In [2]:
import pathlib
import zipfile
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset    

In [28]:
class CustomDataset(Dataset):
    """Image-classification dataset that reads samples lazily from a zip archive.

    Every archive member whose name contains a ``/`` and that is not a
    directory entry counts as one sample.  The class label is the second
    ``/``-separated component of the member name, e.g.
    ``compressed_dataset/Aland/canvas_1629480180.jpg`` -> ``Aland``.

    Args:
        path: Location of the zip archive; an ``AssertionError`` is raised if
            it does not exist.
        label2id: Mapping from label name to integer class id.
        id2label: Reverse mapping; stored on the instance, not used internally.
        transform: Optional callable applied to the decoded RGB PIL image.
    """

    def __init__(self, path, label2id, id2label, transform=None) -> None:
        super().__init__()
        self.path = pathlib.Path(path)
        self.label2id = label2id
        self.id2label = id2label
        assert self.path.exists(), f"{path} does not exist"
        self.transform = transform
        self.files_path = self._get_file_list()

    def _get_file_list(self):
        """Return the archive member names that are used as samples."""
        with zipfile.ZipFile(self.path, 'r') as zip_file:
            return [
                name
                for name in zip_file.namelist()
                # Directory entries end with '/' and hold no image data.
                if '/' in name and not name.endswith('/')
            ]

    def __len__(self):
        # The sample list lives in `files_path`; the previous code read the
        # non-existent attribute `file_list` and raised AttributeError.
        return len(self.files_path)

    def __getitem__(self, index):
        """Return ``(image, label_id)`` for the sample at ``index``.

        Raises:
            KeyError: If the sample's label is missing from ``label2id``.
        """
        member = self.files_path[index]
        # The archive is reopened on every access instead of keeping a handle
        # on the instance.
        with zipfile.ZipFile(self.path, 'r') as zip_file:
            with zip_file.open(member) as file:
                # convert() forces the pixel data to be read while the
                # archive member is still open.
                image = Image.open(file).convert('RGB')

        if self.transform:
            image = self.transform(image)
        label = self.label2id[member.split('/')[1]]
        return image, label

In [29]:
def create_label_encs(dataset_path):
    """Build label <-> id lookup tables from the member names of a zip archive.

    The label of a member is the second ``/``-separated component of its name
    (``root/<label>/...``).  Ids are assigned in order of first appearance in
    the archive listing, starting at 0.

    Args:
        dataset_path: Path to the zip archive.

    Returns:
        A ``(label2id, id2label)`` tuple of dicts that are inverses of each
        other.
    """
    label2id, id2label = {}, {}
    with zipfile.ZipFile(dataset_path, 'r') as zip_file:
        for name in zip_file.namelist():
            if '/' not in name:
                continue
            label = name.split('/')[1]
            # A top-level directory entry such as "root/" has an empty second
            # component; registering it would create a bogus '' class.
            if not label or label in label2id:
                continue
            new_id = len(label2id)
            label2id[label] = new_id
            id2label[new_id] = label
    return label2id, id2label


# Derive the label encodings from the 50k-image archive and print both tables,
# separated by a blank line.
label2id, id2label = create_label_encs('datasets/geolocation-geoguessr-images-50k.zip')
print(f"{label2id}\n\n{id2label}")

{'Aland': 0, 'Albania': 1, 'American Samoa': 2, 'Andorra': 3, 'Antarctica': 4, 'Argentina': 5, 'Armenia': 6, 'Australia': 7, 'Austria': 8, 'Bangladesh': 9, 'Belarus': 10, 'Belgium': 11, 'Bermuda': 12, 'Bhutan': 13, 'Bolivia': 14, 'Botswana': 15, 'Brazil': 16, 'Bulgaria': 17, 'Cambodia': 18, 'Canada': 19, 'Chile': 20, 'China': 21, 'Colombia': 22, 'Costa Rica': 23, 'Croatia': 24, 'Curacao': 25, 'Czechia': 26, 'Denmark': 27, 'Dominican Republic': 28, 'Ecuador': 29, 'Egypt': 30, 'Estonia': 31, 'Eswatini': 32, 'Faroe Islands': 33, 'Finland': 34, 'France': 35, 'Germany': 36, 'Ghana': 37, 'Gibraltar': 38, 'Greece': 39, 'Greenland': 40, 'Guam': 41, 'Guatemala': 42, 'Hong Kong': 43, 'Hungary': 44, 'Iceland': 45, 'India': 46, 'Indonesia': 47, 'Iraq': 48, 'Ireland': 49, 'Isle of Man': 50, 'Israel': 51, 'Italy': 52, 'Japan': 53, 'Jersey': 54, 'Jordan': 55, 'Kenya': 56, 'Kyrgyzstan': 57, 'Laos': 58, 'Latvia': 59, 'Lebanon': 60, 'Lesotho': 61, 'Lithuania': 62, 'Luxembourg': 63, 'Macao': 64, 'Madagas

In [31]:
# Both datasets are built with the same label2id / id2label mappings, so a
# label found only in the second archive is looked up in that shared mapping.
dataset1 = CustomDataset(
    path='datasets/geolocation-geoguessr-images-50k.zip',
    label2id=label2id,
    id2label=id2label,
)
dataset2 = CustomDataset(
    path='datasets/geoguessr-55countries.zip',
    label2id=label2id,
    id2label=id2label,
)

['compressed_dataset', 'Aland', 'canvas_1629480180.jpg']
0


In [None]:
class CustomModel(nn.Module):
    """Custom model made from scratch (placeholder: no layers are defined yet)."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Stub forward pass: ignores ``x`` and returns ``None``."""
        # TODO: implement the actual network.
        return None

In [None]:
# Scratch area: wild experimentation below (unused draft of a ConvBlock)

# class ConvBlock(nn.Module):
#     def __init__(self, **kwargs) -> None:
#         super().__init__()
#         self.convs = nn.Sequential()
#         for i in range(len(kwargs['filter_size'])):
#             # self.convs.add_module(f'conv_{i}', nn.Conv2d(kwargs['in_channels'], kwargs['out_channels'][i], kwargs['filter_size'][i], kwargs['stride'][i], kwargs['padding'][i]))
#             self.convs.add_module(f'activation_{i}', nn.ReLU())
#             if kwargs['batch_norm']:
#                 self.convs.add_module(f'bn_{i}', nn.BatchNorm2d(kwargs['out_channels'][i]))
#             self.convs.add_module(f'activation_{i}', nn.ReLU())
#         self.convs.add_module('pooling', nn.MaxPool2d(kwargs['pool_size'], kwargs['pool_stride']))


In [None]:
class Metrics:
    """Placeholder metrics object: holds no state and computes nothing yet."""

    def __init__(self) -> None:
        return None

    def __call__(self, *args, **kwds):
        """Accept any arguments and return ``None`` (not implemented)."""
        return None

In [None]:
class Callback:
    """Base class for training-loop hooks.

    Subclasses override the ``on_*`` methods they care about.  The base
    implementations do nothing and return ``None``; ``train`` treats a truthy
    return value from a hook as a request to break out of the current loop.

    Every hook accepts arbitrary positional and keyword arguments so that
    overrides taking extra context (for example ``model`` or ``val_loss``)
    stay signature-compatible with the base class, and so that callers may
    pass such context to any callback.
    """

    def __init__(self) -> None:
        # Bookkeeping fields; nothing in this class updates them.
        self.epoch_counter = 0
        self.batch_counter = 0
        self.train_loss = 0
        self.val_loss = 0
        self.train_acc = 0
        self.val_acc = 0

    def on_train_begin(self, *args, **kwargs): pass
    def on_train_end(self, *args, **kwargs): pass
    def on_epoch_begin(self, *args, **kwargs): pass
    def on_epoch_end(self, *args, **kwargs): pass
    def on_batch_begin(self, *args, **kwargs): pass
    def on_batch_end(self, *args, **kwargs): pass

class ModelCheckpoint(Callback):
    """Save the model's ``state_dict`` whenever the validation loss improves.

    Args:
        path: Destination file for the checkpoint; overwritten on every
            improvement.
    """

    def __init__(self, path) -> None:
        super().__init__()
        self.path = pathlib.Path(path)
        self.best_loss = float('inf')

    def on_epoch_end(self, model, val_loss):
        """Write a checkpoint if ``val_loss`` is strictly below the best seen.

        Args:
            model: Object exposing ``state_dict()``.
            val_loss: Validation loss of the epoch that just finished.
        """
        if val_loss < self.best_loss:
            # Create missing directories so torch.save does not fail on a
            # fresh output location.
            self.path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), self.path)
            # Record the new best only after the file was written, so a
            # failed save is retried on the next equally good epoch.
            self.best_loss = val_loss

class EarlyStopping(Callback):
    """Request a stop after ``patience`` consecutive epochs without improvement.

    Args:
        patience: Number of consecutive non-improving epochs that triggers
            the stop signal.
    """

    def __init__(self, patience) -> None:
        super().__init__()
        self.patience = patience
        self.counter = 0
        self.best_loss = float('inf')

    def on_epoch_end(self, val_loss):
        """Update the patience counter with this epoch's validation loss.

        Returns:
            ``True`` when training should stop, ``False`` otherwise.  (The
            non-stop paths used to return ``None`` implicitly; ``False`` is
            equally falsy but keeps the return type consistent.)
        """
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            return False

        self.counter += 1
        if self.counter >= self.patience:
            print("Early stopping")
            return True
        return False
            
class LearningRateScheduler(Callback):
    """Advance a learning-rate schedule at the end of every epoch.

    Args:
        lr_scheduler: Either an object with a ``step()`` method (the interface
            of ``torch.optim.lr_scheduler`` schedulers) or a plain
            zero-argument callable.
    """

    def __init__(self, lr_scheduler) -> None:
        super().__init__()
        self.lr_scheduler = lr_scheduler

    def on_epoch_end(self):
        # torch schedulers are not callable; they are advanced via .step().
        # Plain callables (the only form that worked before) are still called
        # directly.
        # NOTE(review): schedulers whose step() needs an argument (e.g. a
        # metric) are not supported here - confirm whether that is needed.
        step = getattr(self.lr_scheduler, 'step', None)
        if callable(step):
            step()
        else:
            self.lr_scheduler()

In [None]:

# Training loop
def _call_hooks(callbacks, hook_name, **context):
    """Invoke ``hook_name`` on every callback; return True if any asked to stop.

    Each hook receives only those ``context`` entries whose names appear in
    its signature (all of them if it declares ``**kwargs``), so hooks with
    different parameter lists can coexist.  Every callback is invoked even if
    an earlier one already requested a stop.
    """
    import inspect  # local import: keeps this cell self-contained

    stop_requested = False
    for callback in callbacks:
        hook = getattr(callback, hook_name)
        try:
            params = list(inspect.signature(hook).parameters.values())
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to a bare call.
            params = []
        if any(p.kind is p.VAR_KEYWORD for p in params):
            kwargs = dict(context)
        else:
            kwargs = {
                p.name: context[p.name]
                for p in params
                if p.name in context
                and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
            }
        if hook(**kwargs):
            stop_requested = True
    return stop_requested


def train(model, train_loader, val_loader, loss_func, optimizer, epochs, callbacks: "list[Callback] | None" = None):
    """Run a train/validate loop for ``epochs`` epochs.

    Args:
        model: Module to optimise; called as ``model(images)``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        loss_func: Callable ``loss_func(output, labels)`` returning a scalar
            tensor.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        epochs: Number of epochs to run.
        callbacks: Optional list of callback objects.  Hooks may declare any
            of the parameters ``model``, ``optimizer``, ``epoch`` and (for
            ``on_epoch_end`` only) ``val_loss``; a truthy return value breaks
            out of the enclosing loop.

    Notes:
        ``on_train_begin`` / ``on_train_end`` fire around the training phase
        of every epoch.  A stop request from a batch hook ends only the
        current epoch's batch loop; validation still runs.
    """
    # Previously `None` was iterated directly and raised TypeError.
    callbacks = list(callbacks) if callbacks is not None else []

    for epoch in range(epochs):
        context = {'model': model, 'optimizer': optimizer, 'epoch': epoch}

        if _call_hooks(callbacks, 'on_epoch_begin', **context):
            break
        model.train()
        if _call_hooks(callbacks, 'on_train_begin', **context):
            break
        for images, labels in train_loader:
            if _call_hooks(callbacks, 'on_batch_begin', **context):
                break
            optimizer.zero_grad()
            output = model(images)
            loss = loss_func(output, labels)
            loss.backward()
            optimizer.step()
            if _call_hooks(callbacks, 'on_batch_end', **context):
                break
        if _call_hooks(callbacks, 'on_train_end', **context):
            break

        model.eval()
        with torch.no_grad():
            val_loss = 0
            val_batches = 0
            for images, labels in val_loader:
                output = model(images)
                loss = loss_func(output, labels)
                val_loss += loss.item()
                val_batches += 1
            # Mean of the per-batch losses; NaN instead of ZeroDivisionError
            # when the validation loader yields nothing.
            avg_val_loss = val_loss / val_batches if val_batches else float('nan')
            print(f"Epoch: {epoch}, Validation Loss: {avg_val_loss}")

        if _call_hooks(callbacks, 'on_epoch_end', val_loss=avg_val_loss, **context):
            break

In [None]:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# loss_func = nn.CrossEntropyLoss()
# train(model, train_loader, val_loader, loss_func, optimizer, 10)