<a href="https://colab.research.google.com/github/shemysyed/SSRDC-ViT/blob/main/pretrainViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import List
import time
import numpy as np
import torch
import torchvision
from torchvision import transforms

from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import os
import glob
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path


class PreTrain(nn.Module):
    """Masked-image pretraining wrapper around a torchvision ViT.

    A random subset of tokens is masked, the whole sequence is encoded by the
    backbone, and a single linear layer predicts the raw pixel values of the
    masked patches. ``forward`` returns the predictions together with the
    ground-truth pixels of the same patches so the caller can apply any
    reconstruction loss.

    NOTE(review): ``MaskedVisionTransformerTorchvision`` and ``utils`` are not
    imported in the visible part of this file; they are presumably provided by
    the ``lightly`` package -- confirm they are in scope before running.

    Args:
        vit: Vision transformer exposing ``hidden_dim``, ``patch_size`` and
            ``seq_length`` attributes.
    """

    def __init__(self, vit):
        super().__init__()

        decoder_dim = vit.hidden_dim
        self.mask_ratio = 0.75
        # Take the patch size from the backbone rather than hard-coding it.
        # The decoder below emits vit.patch_size ** 2 * 3 values per token, so
        # the targets produced by patchify() must use the same patch size;
        # otherwise predictions and targets differ in size (e.g. with a
        # 32-pixel-patch backbone and a hard-coded 16).
        self.patch_size = vit.patch_size
        self.sequence_length = vit.seq_length

        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)

        # the decoder is a simple linear layer
        self.decoder = nn.Linear(decoder_dim, vit.patch_size ** 2 * 3)

    def forward_encoder(self, images, batch_size, idx_mask):
        """Encode ``images`` with the tokens at ``idx_mask`` masked.

        All tokens (masked and non-masked) are passed to the encoder.
        ``batch_size`` is unused and kept only for interface compatibility.
        """
        return self.backbone.encode(images=images, idx_mask=idx_mask)

    def forward_decoder(self, x_encoded):
        """Project encoded tokens to flattened patch pixel values."""
        return self.decoder(x_encoded)

    def forward(self, images):
        """Return ``(predicted, target)`` pixel values for the masked patches.

        Args:
            images: Batch of images; the first dimension is the batch size.

        Returns:
            Tuple ``(x_out, target)``: decoder output for the masked tokens and
            the corresponding patches extracted from ``images``.
        """
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )

        # Encoding...
        x_encoded = self.forward_encoder(images, batch_size, idx_mask)
        x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)

        # Decoding...
        x_out = self.forward_decoder(x_encoded_masked)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)

        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)

        return x_out, target


class XrayDataset(Dataset):
    """Dataset that loads images from a list of file paths.

    Each item is opened with PIL, converted to RGB and, when a transform was
    supplied, passed through it before being returned. No labels are produced.

    Args:
        paths: Indexable collection of image file paths.
        transform: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.transform = transform
        self.paths = paths

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the transform, if any."""
        rgb_image = Image.open(self.paths[index]).convert('RGB')

        if not self.transform:
            return rgb_image
        return self.transform(rgb_image)


def calculate_accuracy(predictions, targets):
    """Return the fraction of rows whose highest-scoring index equals the target.

    Args:
        predictions: Tensor of scores; the maximum is taken along dimension 1.
        targets: Tensor of expected indices, one per row of ``predictions``.

    Returns:
        Number of matching rows divided by ``targets.size(0)``, as a float.
    """
    top_indices = predictions.max(dim=1).indices
    num_hits = top_indices.eq(targets).sum().item()
    return num_hits / targets.size(0)


if __name__ == '__main__':
    # Script entry point: resume masked-image pretraining from saved model,
    # optimizer and scheduler checkpoints, train for `num_epochs` epochs, plot
    # the loss curve, then save the model, training state, metrics and config.

    # Hyperparameters are defined once so the values used for training and the
    # values written to the saved config cannot drift apart.
    learning_rate = 1.5e-4
    weight_decay = 0.01
    batch_size = 512
    num_epochs = 100  # Total epochs after fine-tuning

    vit = torchvision.models.vit_b_32(weights=None)
    model = PreTrain(vit)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_save_path = r"/content/drive/MyDrive/models/500_model_NIHPADCHESTCOVIDx_pretrainedL2loss_model.pth"
    # The loss/accuracy paths below are not read in this block; they are
    # reassigned before saving further down.
    loss_value_save_path = r"/content/drive/MyDrive/modelforreconstruction/500_model_NIHPADCHESTCOVIDx_pretrainedL2loss_loss_values.pth"
    accuracy_value_save_path = r"/content/drive/MyDrive/modelforreconstruction/500_model_NIHPADCHESTCOVIDx_pretrainedL2loss_accuracy_values.pth"
    optimizer_save_path = r"/content/drive/MyDrive/modelforreconstruction/500_model_NIHPADCHESTCOVIDx_pretrainedL2loss_optimizer.pth"
    scheduler_save_path = r"/content/drive/MyDrive/modelforreconstruction/500_model_NIHPADCHESTCOVIDx_pretrainedL2loss_scheduler.pth"
    model.to(device)
    model.load_state_dict(torch.load(model_save_path, map_location=device))

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(3 / 4, 4 / 3)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Build the pattern with os.path.join: a hard-coded backslash before the
    # wildcard is a literal filename character on POSIX, so the pattern would
    # match nothing there.
    train_dir = r"/content/drive/MyDrive/NIH/Train_filtered"
    train_data = sorted(glob.glob(os.path.join(train_dir, '*')))
    if not train_data:
        raise FileNotFoundError(f"No training images found in {train_dir}")
    train_path = train_data

    train_dataset = XrayDataset(paths=train_path, transform=transform)
    dataset = train_dataset
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )
    # With drop_last=True a dataset smaller than one batch yields zero batches,
    # which would otherwise surface as a ZeroDivisionError when averaging.
    if len(dataloader) == 0:
        raise ValueError(
            f"Found {len(dataset)} images, fewer than one full batch of {batch_size}"
        )

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    # map_location lets checkpoints saved on GPU be restored on a CPU-only host.
    optimizer.load_state_dict(torch.load(optimizer_save_path, map_location=device))  # Resume optimizer state
    scheduler.load_state_dict(torch.load(scheduler_save_path, map_location=device))  # Resume scheduler state
    print("Starting Training")
    loss_value: List[float] = []  # Average loss per epoch
    accuracy_value: List[float] = []  # Average element-wise accuracy per epoch

    total_start_time = time.time()  # Record the start time of the training loop

    for epoch in range(num_epochs):
        epoch_start_time = time.time()  # Record the start time of the epoch
        total_loss = 0
        total_accuracy = 0
        model.train()
        for batch in dataloader:
            images = batch.to(device)

            predictions, targets = model(images)
            # Reshape the flattened patches to (N, 3, patch, patch)
            predictions = predictions.view(-1, 3, model.patch_size, model.patch_size)
            targets = targets.view(-1, 3, model.patch_size, model.patch_size)
            loss = criterion(predictions, targets)
            total_loss += loss.item()

            # Element-wise accuracy: fraction of predicted values within 0.1 of
            # the target. It is a metric only, so keep it out of autograd.
            with torch.no_grad():
                accuracy = ((predictions - targets).abs() < 0.1).float().mean().item()
            total_accuracy += accuracy

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = total_loss / len(dataloader)
        avg_accuracy = total_accuracy / len(dataloader)
        loss_value.append(avg_loss)  # Append the average loss for the current epoch
        accuracy_value.append(avg_accuracy)
        scheduler.step()
        epoch_end_time = time.time()  # Record the end time of the epoch
        epoch_time = epoch_end_time - epoch_start_time  # Calculate the epoch duration
        print(f"epoch: {epoch + 1:>02}, loss: {avg_loss:.5f}, accuracy: {avg_accuracy:.5f}, time: {epoch_time:.2f}s")

    total_end_time = time.time()  # Record the end time of the training loop
    total_training_time = total_end_time - total_start_time  # Calculate the total training time
    print(f"Total training time for {num_epochs} epochs: {total_training_time:.2f}s")

    print("Training completed. Loss values and accuracy for each epoch:")
    print(loss_value)
    print(accuracy_value)
    # Plotting loss after the training loop
    plt.plot(loss_value, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Save the model state dictionary and loss/accuracy values after all epochs
    # NOTE(review): this is a Windows-style path while every input above lives
    # under /content/drive. On a POSIX host os.makedirs will create a single
    # oddly named directory in the current working directory -- confirm the
    # intended output location.
    save_path = r"C:\PROJECT FILES\Models_saved\fine tune models\finetune_reconstruction\nih_reconstruction_finetunedfrom pretrained" # Directory to save models
    os.makedirs(save_path, exist_ok=True)
    unique_identifier = "finetunalltrain_forreconstruction_frombasemodel_.weightdecay01"  # Change this for each model
    model_save_path = os.path.join(save_path, f"{unique_identifier}_model.pth")
    loss_value_save_path = os.path.join(save_path, f"{unique_identifier}_loss_values.pth")
    accuracy_value_save_path = os.path.join(save_path, f"{unique_identifier}_accuracy_values.pth")
    optimizer_save_path = os.path.join(save_path, f"{unique_identifier}_optimizer.pth")
    scheduler_save_path = os.path.join(save_path, f"{unique_identifier}_scheduler.pth")
    config_save_path = os.path.join(save_path, f"{unique_identifier}_config.pth")

    torch.save(model.state_dict(), model_save_path)
    torch.save(loss_value, loss_value_save_path)  # Save loss values
    torch.save(accuracy_value, accuracy_value_save_path)  # Save accuracy values
    torch.save(optimizer.state_dict(), optimizer_save_path)  # Save optimizer state
    torch.save(scheduler.state_dict(), scheduler_save_path)  # Save scheduler state

    # Save training configuration
    config = {
        'learning_rate': learning_rate,
        'weight_decay': weight_decay,
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'mask_ratio': model.mask_ratio,
        'patch_size': model.patch_size,
        'sequence_length': model.sequence_length,
    }
    torch.save(config, config_save_path)

    print(f"Model saved to {model_save_path}")
    print(f"Loss values saved to {loss_value_save_path}")
    print(f"Accuracy values saved to {accuracy_value_save_path}")
    print(f"Optimizer state saved to {optimizer_save_path}")
    print(f"Scheduler state saved to {scheduler_save_path}")
    print(f"Training configuration saved to {config_save_path}")
