# Setup


In [None]:
# @title connect to drive
from google.colab import drive

# Mount Google Drive inside the Colab VM; PROJECT_DIR (set in a later cell)
# lives underneath this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


In [None]:
# @title download dataset
import os

import kagglehub

# Download latest version
path = kagglehub.dataset_download("orvile/gastric-cancer-histopathology-tissue-image-dataset")
print("Path to dataset files:", path)

# Location this notebook historically assumed for the image folders.
_LEGACY_DATA_DIR = "/kaggle/input/gastric-cancer-histopathology-tissue-image-dataset/HMU-GC-HE-30K/all_image"

# Build the image root from the directory kagglehub actually returned, so the
# notebook keeps working when the download does not land under /kaggle/input
# (e.g. a local kagglehub cache). Fall back to the legacy location only if the
# derived directory is missing and the legacy one exists.
DATA_DIR = os.path.join(path, "HMU-GC-HE-30K", "all_image")
if not os.path.isdir(DATA_DIR) and os.path.isdir(_LEGACY_DATA_DIR):
    DATA_DIR = _LEGACY_DATA_DIR

# Root folder on the mounted Drive where training outputs are written.
PROJECT_DIR = "/content/drive/MyDrive/CS184A"

In [None]:
# @title install libraries + settings
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
RNG_SEED = 123

!pip install evaluate


# Training

In [None]:
# @title Find # images of each class
import os

# Walk the immediate children of DATA_DIR; every sub-directory is treated as
# one class, and only regular files directly inside it are counted.
for class_dir in os.listdir(DATA_DIR):
    class_root = os.path.join(DATA_DIR, class_dir)
    if not os.path.isdir(class_root):
        continue
    num_files = sum(
        1
        for entry in os.listdir(class_root)
        if os.path.isfile(os.path.join(class_root, entry))
    )
    print(f"{class_dir} has {num_files}")


In [None]:
# @title Make train/test split
# This takes ≈10 min
import glob
import pandas as pd
from datasets import load_dataset, DatasetDict

# 80-10-10 train-validation-test
HOLDOUT_FRACTION = 0.2        # share of all images kept out of the training set
TEST_SHARE_OF_HOLDOUT = 0.5   # the hold-out is then halved into validation / test

# The image-folder loader exposes everything under a single "train" split.
inital_dataset = load_dataset("imagefolder", data_dir=DATA_DIR)

# First cut: 80% train vs. 20% hold-out.
tmp1_dataset = inital_dataset["train"].train_test_split(
    test_size=HOLDOUT_FRACTION,
    seed=RNG_SEED,
)

# Second cut: split the hold-out evenly (10% / 10% of the full data).
tmp2_dataset = tmp1_dataset["test"].train_test_split(
    test_size=TEST_SHARE_OF_HOLDOUT,
    seed=RNG_SEED,
)

dataset = DatasetDict({
    "train": tmp1_dataset["train"],
    "val": tmp2_dataset["train"],
    "test": tmp2_dataset["test"],
})


In [None]:
# @title Preprocess Data
from transformers import MobileViTImageProcessor, MobileViTForImageClassification
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop


model_name = "apple/mobilevit-x-small"
# NOTE(review): `processor` is loaded here but is not applied by the pipeline
# below in the visible cells - confirm that bypassing it is intentional.
processor = MobileViTImageProcessor.from_pretrained(model_name)

# Resize(256) with a single int only scales the SHORTER edge to 256 and keeps
# the aspect ratio, so a non-square image would produce a non-square tensor and
# the batch could not be stacked. CenterCrop(256) guarantees a 256x256 result;
# for square inputs it is a no-op, so existing behaviour is unchanged.
_transform_pipeline = Compose([
    Resize(256),
    CenterCrop(256),
    ToTensor(),
])

def transform_img(batch):
    """Add a "pixel_values" entry to a batch of examples.

    Each item of ``batch["image"]`` is converted to RGB and passed through
    ``_transform_pipeline``; the results are stored, in order, under
    ``batch["pixel_values"]``. The same (mutated) batch is returned.
    """
    pixel_values = []
    for picture in batch["image"]:
        rgb_picture = picture.convert("RGB")
        pixel_values.append(_transform_pipeline(rgb_picture))
    batch["pixel_values"] = pixel_values
    return batch

# Apply transform_img to every split in batches of 32 and drop the raw
# "image" column once "pixel_values" has been produced.
# NOTE: this rebinds `dataset`, so re-running the cell without re-running the
# split cell operates on data whose "image" column is already gone.
dataset = dataset.map(
    transform_img,
    batched=True,
    batch_size=32,
    remove_columns=["image"],
)

In [None]:
# @title Load model

# Load the pretrained weights named by `model_name`, configured for 8 output
# labels. ignore_mismatched_sizes=True lets loading proceed when a checkpoint
# tensor's shape differs from the model's.
model = MobileViTForImageClassification.from_pretrained(
    model_name,
    num_labels=8,
    ignore_mismatched_sizes=True,
)



In [None]:
# @title Collate Function

from transformers import DefaultDataCollator

# Batch collator handed to the Trainer; "pt" (PyTorch tensors) is the
# library default, spelled out here for clarity.
collator = DefaultDataCollator(return_tensors="pt")

In [None]:
import os

import numpy as np
import torch
import evaluate
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

# SECURITY: a literal API key was committed in this cell. Secrets must not live
# in the notebook; the value is now read from the environment instead, and the
# key that was exposed should be revoked/rotated.
# NOTE(review): the variable is not used by any visible cell, so the
# environment variable name below is a placeholder - adjust to the service
# this key belongs to.
api_key = os.environ.get("API_KEY")

# Evaluate and checkpoint every 500 steps; keep at most 3 checkpoints and
# reload the one with the lowest evaluation loss when training finishes.
training_args = TrainingArguments(
    output_dir=f"{PROJECT_DIR}/output",
    eval_strategy = "steps",
    save_strategy = "steps",
    learning_rate = 1e-4,
    weight_decay = 1e-4,
    num_train_epochs = 10,
    save_steps=500,
    eval_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    logging_steps = 500,
    report_to = "tensorboard",
    save_total_limit=3
)

# Only accuracy is tracked during training. The second positional argument of
# evaluate.load is a configuration name, not an additional metric, so the
# stray "roc_auc" that used to be passed there has been dropped.
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class of
    each example is the arg-max of its logits over the last axis; the
    predictions and the reference labels are passed to ``metric.compute`` and
    its result is returned unchanged.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)

# Stop training once the monitored metric has failed to improve for
# 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

# Fine-tune on the "train" split and evaluate on "val" during training.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()

In [None]:
# @title Testing
import numpy as np
import torch
import evaluate

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One metric object per reported score. ROC-AUC is loaded with its
# "multiclass" configuration.
roc_auc_metric = evaluate.load("roc_auc", "multiclass")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
f1_metric = evaluate.load("f1")
recall_metric = evaluate.load("recall")

def compute_metrics(eval_pred):
    """Compute the test-time metric suite from logits and labels.

    ``eval_pred`` unpacks into ``(logits, labels)``. The logits are turned
    into probabilities with a softmax over axis 1 (for ROC-AUC) and into hard
    predictions with an arg-max over the last axis (for the other metrics).
    Labels are cast to int32 before scoring.

    Returns a dict with the keys "roc_auc" (one-vs-rest), "accuracy",
    "precision", "f1" and "recall"; the last three are macro-averaged.
    """
    logits, labels = eval_pred

    probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=1).numpy()
    labels = labels.astype(np.int32)
    predictions = np.argmax(logits, axis=-1)

    # ROC-AUC consumes class probabilities; accuracy consumes hard predictions.
    results = {
        "roc_auc": roc_auc_metric.compute(
            references=labels,
            prediction_scores=probs,
            multi_class="ovr",
        )["roc_auc"],
        "accuracy": accuracy_metric.compute(
            predictions=predictions,
            references=labels,
        )["accuracy"],
    }

    # The remaining metrics share one call signature and are macro-averaged.
    macro_scorers = (
        ("precision", precision_metric),
        ("f1", f1_metric),
        ("recall", recall_metric),
    )
    for key, scorer in macro_scorers:
        results[key] = scorer.compute(
            predictions=predictions,
            references=labels,
            average="macro",
        )[key]

    return results


# AutoModelForImageClassification was never imported anywhere in the notebook,
# so this cell raised NameError on a fresh kernel. The other Trainer classes
# are imported here as well so the cell does not depend on the training cell's
# imports.
from transformers import (
    AutoModelForImageClassification,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# NOTE(review): the training cell writes checkpoints under
# f"{PROJECT_DIR}/output"; confirm this checkpoint really sits directly in
# PROJECT_DIR.
checkpoint_path = f"{PROJECT_DIR}/checkpoint-13000"

# Restore the fine-tuned weights from the saved checkpoint (safetensors format).
model = AutoModelForImageClassification.from_pretrained(
    checkpoint_path,
    use_safetensors=True
)

# Same settings as the training run; only `evaluate` is called below.
# remove_unused_columns=False keeps every dataset column in the batches passed
# to the collator.
training_args = TrainingArguments(
    output_dir=f"{PROJECT_DIR}/output",
    eval_strategy = "steps",
    save_strategy = "steps",
    learning_rate = 1e-4,
    weight_decay = 1e-4,
    num_train_epochs = 10,
    save_steps=500,
    eval_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_steps = 500,
    report_to = "tensorboard",
    save_total_limit=3,
    remove_unused_columns=False
    )

trainer = Trainer(
    model=model,
    data_collator = collator,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)


# Score the restored model on the held-out test split and show the metrics.
test_results = trainer.evaluate(dataset["test"])
print(test_results)



