# Example of the `aitlas` toolbox for multi-class image classification

This notebook shows a sample implementation of multi-class image classification with the `aitlas` toolbox. It uses the `UcMercedDataset` loader on a custom two-class dataset (background / tailing).

In [None]:
from aitlas.datasets import UcMercedDataset
from aitlas.models import ResNet50
from aitlas.transforms import ResizeCenterCropFlipHVToTensor, ResizeCenterCropToTensor
from aitlas.utils import image_loader

## Load the dataset

In [None]:
# Dataset used for visual inspection in the next cells; samples are listed in train.csv.
data_root = "G:/1.pond/3.pond_chengde/11.sc/aitlas-master/media/hdd/multi-class/tailing_dataset_nrg"
dataset_config = dict(
    data_dir=data_root,
    csv_file=f"{data_root}/train.csv",
)
dataset = UcMercedDataset(dataset_config)

## Show images from the dataset

In [None]:
# Preview two individual images (selected by index) and one batch of images.
# NOTE(review): index 1000 assumes the dataset has more than 1000 samples — confirm,
# otherwise this call will fail for a small dataset.
fig1 = dataset.show_image(1000)
fig2 = dataset.show_image(80)
# Argument 15 presumably sets how many images the batch preview shows — confirm.
fig3 = dataset.show_batch(15)

## Inspect the data

In [None]:
# Show sample entries from the dataset; the return value is rendered as the cell output.
dataset.show_samples()

In [None]:
# Tabulate how samples are distributed over the classes; the returned value is the cell output.
dataset.data_distribution_table()

In [None]:
# Plot the class distribution as a bar chart.
fig = dataset.data_distribution_barchart()

## Load the training and validation datasets

In [None]:
# Both splits live in the same dataset folder and differ only in their CSV index file.
nrg_root = "G:/1.pond/3.pond_chengde/11.sc/aitlas-master/media/hdd/multi-class/tailing_dataset_nrg"

# Training split: shuffled batches of 16.
train_dataset_config = dict(
    batch_size=16,
    shuffle=True,
    num_workers=4,
    data_dir=nrg_root,
    csv_file=f"{nrg_root}/train.csv",
)
train_dataset = UcMercedDataset(train_dataset_config)
# The training transform is attached after construction (judging by its name it adds
# horizontal/vertical flips), unlike the validation split, which uses the "transforms" key.
train_dataset.transform = ResizeCenterCropFlipHVToTensor()

# Validation split: fixed order, batches of 4, deterministic resize/crop transform.
val_dataset_config = dict(
    batch_size=4,
    shuffle=False,
    num_workers=4,
    data_dir=nrg_root,
    csv_file=f"{nrg_root}/val.csv",
    transforms=["aitlas.transforms.ResizeCenterCropToTensor"],
)
val_dataset = UcMercedDataset(val_dataset_config)

# Size of each split, shown as the cell output.
len(train_dataset), len(val_dataset)

## Setup and create the model for training

In [None]:
# ResNet-50 initialised from pretrained weights, with a 2-class output
# and the metrics tracked during training/evaluation.
model_config = dict(
    num_classes=2,
    learning_rate=0.0001,
    pretrained=True,
    metrics=["accuracy", "precision", "recall", "f1_score"],
)

# Training length and the directory that receives the run's outputs (checkpoints etc.).
epochs = 100
model_directory = "G:/1.pond/3.pond_chengde/11.sc/aitlas-master/media/hdd/multi-class/experiment/nrg"

model = ResNet50(model_config)
model.prepare()

## Training and evaluation

In [None]:
# Train for `epochs` epochs, evaluating on the validation split; outputs of this run
# go under `model_directory` (run_id '1' — later cells load checkpoints from
# .../experiment/nrg/1/).
model.train_and_evaluate_model(
    train_dataset=train_dataset,
    epochs=epochs,
    model_directory=model_directory,
    val_dataset=val_dataset,
    run_id='1',
)

## Predictions

In [None]:
# Sanity check: predict a few known "tailing" patches one by one and print each prediction.
# NOTE(review): this checkpoint is from experiment/1 rather than the experiment/nrg run
# trained above, and the patches come from `tailing_dataset` (not `tailing_dataset_nrg`);
# confirm the checkpoint and image band composition match.
model_path = "F:/2.tailing/11.sc/aitlas-master/media/hdd/multi-class/experiment/1/best_checkpoint_1700319155_13.pth.tar"
labels = ["background", "tailing",]
transform = ResizeCenterCropToTensor()
model.load_model(model_path)

sample_dir = 'F:/2.tailing/11.sc/aitlas-master/media/hdd/multi-class/tailing_dataset/tailing'
for patch_id in ('C011R010', 'C008R019', 'C012R020', 'C020R015'):
    image = image_loader(f'{sample_dir}/sentinel_cut_{patch_id}.tif')
    # predict_image yields a figure plus the predicted value.
    fig, pred = model.predict_image(image, labels, transform)
    print(pred)

In [None]:
import os

import torch
from tqdm import tqdm

# Batch inference: classify every .tif patch in `data_folder` with a trained checkpoint
# and write each file name into the list file for its predicted class.
model_path = "G:/1.pond/3.pond_chengde/11.sc/aitlas-master/media/hdd/multi-class/experiment/nrg/1/best_checkpoint_1700706198_15.pth.tar"
labels = ["background", "tailing",]
transform = ResizeCenterCropToTensor()
model.load_model(model_path)

data_folder = 'F:/20240128-EveryThing/3.MultiSourceTailing/tu12/1.sentinel2_patch/images'
output_background_file = 'F:/20240128-EveryThing/3.MultiSourceTailing/tu12/background-nrg.txt'
output_tailing_file = 'F:/20240128-EveryThing/3.MultiSourceTailing/tu12/tailing-nrg.txt'

# Sorted so the output lists are reproducible (os.listdir returns entries in arbitrary order).
tif_files = sorted(f for f in os.listdir(data_folder) if f.endswith('.tif'))

with open(output_background_file, 'w', encoding='utf-8') as background_file, open(output_tailing_file, 'w', encoding='utf-8') as tailing_file:
    for tif_file in tqdm(tif_files):
        image_path = os.path.join(data_folder, tif_file)
        image = image_loader(image_path)
        # predict_image returns a (figure, prediction) pair, as unpacked in the
        # single-image prediction cell. Binding the whole tuple to `pred` made
        # torch.equal() raise a TypeError on the first file.
        fig, pred = model.predict_image(image, labels, transform)
        # NOTE(review): one figure is created per image; if it is a matplotlib figure,
        # consider closing it so open figures do not pile up over a large folder.
        # Compare the predicted class index as a plain Python number instead of against a
        # tensor pinned to 'cuda:0', so this also works when the model runs on CPU.
        # Assumes `pred` is a single-element tensor holding the class index.
        class_index = pred.item() if pred.numel() == 1 else None
        if class_index == 0:
            background_file.write(tif_file + '\n')
        elif class_index == 1:
            tailing_file.write(tif_file + '\n')

## Evaluation

In [None]:
# Test set.
# NOTE(review): this reads val.csv, which appears to be the validation split used during
# training, so the scores below are not from held-out data — point it at a separate test
# CSV if one exists.
test_root = "F:/2.tailing/11.sc/aitlas-master/media/hdd/multi-class/tailing_dataset_nrg"
test_dataset_config = dict(
    batch_size=4,
    shuffle=False,
    num_workers=4,
    data_dir=test_root,
    csv_file=f"{test_root}/val.csv",
    transforms=["aitlas.transforms.ResizeCenterCropToTensor"],
)
test_dataset = UcMercedDataset(test_dataset_config)


In [None]:
# Evaluate a saved checkpoint on `test_dataset` and print the requested metric scores.
# NOTE(review): this checkpoint path is on F:, while the training output directory and the
# batch-prediction cell refer to the same file on G: — confirm which copy is intended.
model_path = "F:/2.tailing/11.sc/aitlas-master/media/hdd/multi-class/experiment/nrg/1/best_checkpoint_1700706198_15.pth.tar"
model.metrics = ["accuracy", "precision", "recall", "f1_score"]
# Clear previously accumulated running metrics so the scores reflect this evaluation only.
model.running_metrics.reset()
model.evaluate(dataset=test_dataset, model_path=model_path)
print(model.running_metrics.get_scores(model.metrics))