# Domain adaptation analysis

### Initial settings

#### Dependencies

In [None]:
%pip install --user -qr ./requirements.txt 

#### Import and Utilities

In [None]:
import pytorch_lightning as pl
# your favorite machine learning tracking tool
# from pytorch_lightning.loggers import WandbLogger

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader

from torchmetrics import Accuracy

from torchvision import transforms

import wget 
import numpy as np
# import wandb

### Import Dataset

In [None]:
# TODO: spostare questa implementazione in un file a parte
import os
import zipfile
from scripts.extract_patches import *
import shutil

class CulturalSiteDatasetsLoader():

    def __init__(self, download_path):    
        self._class_path_datasets = './CLASS-EGO-CH-OBJ-ADAPT/' # TODO: non è detto che si chiami così, configurarlo in modo che sia così
        if not self._classification_datasets_exists():
            print("Classification dataset not found. Checking main dataset...")
            self._main_path_datasets = './EGO-CH-OBJ-ADAPT/'
            if not self._main_datasets_exists():
                print("Main dataset not found. Checking zip file dataset...")
                self._get_main_datasets(download_path) 
            print("Start extracting patches...")
            self._extract_patches(self._main_path_datasets, self._class_path_datasets)  
            print("Patches extracted successfully!")  
            shutil.rmtree(self._main_path_datasets) # remove Main Dataset #TODO: add default value to do this

    def _get_main_datasets(self, download_path): # download and extract dataset
        save_path = "./"
        zip_file_path = './EGO-CH-OBJ-ADAPT.zip'
        if not self._main_datasets_zip_exists(zip_file_path):
            print("Zip file dataset not found. Pulling from resource(", download_path, ")...")
            wget.download(download_path, save_path)
        print("Zip file found, start unzipping...")
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(save_path)
        os.remove(zip_file_path)
        print("File unzipped successfully!")

    def _extract_patches(self, main_path, save_path):
        ExtractPatches(main_path, save_path)

    def _main_datasets_exists(self):
        return os.path.isdir(self._main_path_datasets)

    def _main_datasets_zip_exists(self, zip_file_path):
        return os.path.isfile(zip_file_path)

    def _classification_datasets_exists(self):
        return os.path.isdir(self._class_path_datasets)

In [None]:
from typing import Callable, Optional
import json
from PIL import Image
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset

class CulturalSiteDataset(VisionDataset):

    TRAIN = 0
    VALIDATION = 1
    TEST = 2

    REAL = 'real'
    SYNTEHTIC = 'syntehtic'

    def __init__(self, dataset_type=TRAIN, real=False, transform: Optional[Callable] = None) -> None:
        data_domain = CulturalSiteDataset.REAL if real else CulturalSiteDataset.SYNTEHTIC
        if dataset_type == CulturalSiteDataset.TRAIN:
            dataset_folder = './CLASS-EGO-CH-OBJ-ADAPT/' + data_domain + '/training/data/'
        elif dataset_type == CulturalSiteDataset.VALIDATION:
            dataset_folder = './CLASS-EGO-CH-OBJ-ADAPT/' + data_domain + '/validation/data/'
        elif dataset_type == CulturalSiteDataset.TEST:
            dataset_folder = './CLASS-EGO-CH-OBJ-ADAPT/' + data_domain + '/test/data/'
        super().__init__(root = dataset_folder, transform = transform, target_transform=None)
        self.images_data = []
        self.image_classes = []
        self._load_images(dataset_folder)
        self._load_image_classes("./utils/image_classes.json") # todo: remove hardcode (add config file)
        self._load_class_ids("./utils/image_classes.json")     # todo: remove hardcode o salvare 
        # todo: spostare in gpu se disponile

    def _load_images(self, path):
        for filename in os.listdir(path):
            im = Image.open(os.path.join(path, filename))
            self.images_data.append(np.asarray(im))

    def _load_image_classes(self, path):
        file = open(path)
        content = json.load(file)
        for key,el in content["categories"]:
            self.image_classes.append(el)

    def _load_class_ids(self, path):  # TODO: ottimizzare se possibile
        file = open(path)
        content = json.load(file)
        for i in range(len(self.image_classes)):
            for el in content["categories"]:
                if el["name"] == self.image_classes[i]:
                    self.image_classes[i] = el["id"]

    def __getitem__(self, index: int):
        img, image_class = self.images_data[index], self.image_classes[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform()
        return img, image_class


    def __len__(self) -> int:
        return len(self.images_data)

In [None]:
### TEST Dataset Cell

CulturalSiteDataset("./CLASS-EGO-CH-OBJ-ADAPT/syntehtic/test/data/")

In [None]:
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from torchvision.datasets.utils import check_integrity, download_and_extract_archive
from torchvision.datasets.vision import VisionDataset

class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, "rb") as f:
                entry = pickle.load(f, encoding="latin1")
                self.data.append(entry["data"])
                if "labels" in entry:
                    self.targets.extend(entry["labels"])
                else:
                    self.targets.extend(entry["fine_labels"])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
        with open(path, "rb") as infile:
            data = pickle.load(infile, encoding="latin1")
            self.classes = data[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        root = self.root
        for fentry in self.train_list + self.test_list:
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"

In [None]:
class CulturalSiteDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, real, num_classes):
        super().__init__()
        self.batch_size = batch_size
        self.real = real
        self.transform = transforms.Compose([
            transforms.ToTensor()
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.num_classes = num_classes
    
    def prepare_data(self):
        CulturalSiteDatasetsLoader("https://iplab.dmi.unict.it/EGO-CH-OBJ-ADAPT/EGO-CH-OBJ-ADAPT.zip") # todo: rimuovere hardcode
        CulturalSiteDataset() #dovremmo usare questa classe anziché il loader. O comunque l'importante qui è preparare i dati per averli pronti dopo

    
    def setup(self, stage=None):
        # TODO: qui istanziare CulturalSiteDataset, creare i dataset da passare ai dataloader sotto
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.cultural_site_train = CulturalSiteDataset(dataset_type=CulturalSiteDataset.TRAIN, real=self.real, transform=self.transform)
            self.cultural_site_val = CulturalSiteDataset(dataset_type=CulturalSiteDataset.VALIDATION, real=self.real, transform=self.transform)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cultural_site_test = CulturalSiteDataset(dataset_type=CulturalSiteDataset.VALIDATION, real=self.real, transform=self.transform)
    
    def train_dataloader(self):
        return DataLoader(self.cultural_site_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cultural_site_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cultural_site_test, batch_size=self.batch_size)

### Dataset Pre Analysis

*todo*

### Normalize input features

*todo*

## Domain adaptation study cases

### 1. Baseline approaches without adaption
Il primo caso di studio è quello di allenare il classificatore sul sintetico e poi testarlo nudo e crudo sul reale, valutare le performance

#### Init Logger

In [None]:
class ImagePredictionLogger(pl.callbacks.Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples
    
    def on_validation_epoch_end(self, trainer, pl_module):
        # Bring the tensors to CPU
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model prediction
        logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # Log the images as wandb Image
        ''' trainer.logger.experiment.log({
            "examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 
                           for x, pred, y in zip(val_imgs[:self.num_samples], 
                                                 preds[:self.num_samples], 
                                                 val_labels[:self.num_samples])]
            }) '''
        

### Classification Module

In [None]:
class LitModel(pl.LightningModule):
    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()
        
        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 64, 3, 1)
        self.conv4 = nn.Conv2d(64, 64, 3, 1)

        self.pool1 = torch.nn.MaxPool2d(2)
        self.pool2 = torch.nn.MaxPool2d(2)
        
        n_sizes = self._get_conv_output(input_shape)

        self.fc1 = nn.Linear(n_sizes, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        self.accuracy = Accuracy()

    # returns the size of the output tensor going into Linear layer from the conv block.
    def _get_conv_output(self, shape):
        batch_size = 1
        input = torch.autograd.Variable(torch.rand(batch_size, *shape))

        output_feat = self._forward_features(input) 
        n_size = output_feat.data.view(batch_size, -1).size(1)
        return n_size
        
    # returns the feature tensor from the conv block
    def _forward_features(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool2(F.relu(self.conv4(x)))
        return x
    
    # will be used during inference
    def forward(self, x):
       x = self._forward_features(x)
       x = x.view(x.size(0), -1)
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = F.log_softmax(self.fc3(x), dim=1)
       
       return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


### Preparing DataModule

In [9]:
dm = CulturalSiteDataModule(batch_size=32, num_classes=16) # todo: istanziare CulturalSiteDataset, preparare i dati
# To access the x_dataloader we need to call prepare_data and setup.
dm.prepare_data()
dm.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

TypeError: __init__() missing 1 required positional argument: 'dataset_folder'

### Model Training

In [None]:
model = LitModel((3, 32, 32), dm.num_classes)

# todo: inizializzare il logger, utilizzeremo tensorboard come anno scorso
# Initialize wandb logger
# wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')

# Initialize Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()

# Initialize a trainer
# todo: vedere se possibile riconoscere automaticamente il device
trainer = pl.Trainer(max_epochs=10,
                     accelerator="mps", #'mps' to use apple silicon graphics unit, 'gpu' for nvidia or amd 
                     devices=1,
                     # logger=wandb_logger,
                     callbacks=[early_stop_callback,
                                ImagePredictionLogger(val_samples),
                                checkpoint_callback],
                     )

# Train the model ⚡🚅⚡
trainer.fit(model, dm)

# Evaluate the model on the held-out test set ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())

# Close wandb run
# wandb.finish()