Here, the encoder is pretrained on an external chest CT-scan classification dataset (lung-cancer subtypes vs. normal), enabling it to
extract domain-specific features. The dataset is available here: https://www.kaggle.com/datasets/mohamedhanyyy/chest-ctscan-images

# Import, load and preprocess external dataset


*   First load and augment the dataset.
*   Then split it between training and validation datasets.
*   Use the following code to train the transformer-based encoder on the augmented external dataset.

In [None]:
!pip install -q graphviz torchsummary torchview

In [None]:
from torchsummary import summary
import torchvision
from torchview import draw_graph
import cv2
import os
from pathlib import Path
from PIL import Image
from termcolor import colored
import glob
from typing import List, Tuple
from sklearn.metrics import f1_score, accuracy_score
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
import torch.nn.functional as F
import torch.optim as optim

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Plot styling. Order matters: seaborn's defaults are applied first, then the
# dark matplotlib style and the grid colour are layered on top.
SNS_CMAP = 'bwr'
sns.set()
sns.set_palette(SNS_CMAP)
plt.style.use("dark_background")
plt.rcParams.update({'grid.color': '#444444'})
colors = sns.color_palette(SNS_CMAP)

# Turn off pandas' chained-assignment warning.
pd.set_option('mode.chained_assignment', None)

def clrd(text: str, color: str = None, con: bool = None, c1:str = 'ok', c2:str = 'error')->str:
    """Wrap ``text`` in an ANSI escape sequence so it prints in colour.

    Args:
        text: Value to colourise; it is converted with ``str()`` first.
        color: Name of the colour to use. Known names are ``ok``, ``error``,
            ``warning``, ``success``, ``status``, ``special``, ``log`` and
            ``reset``. Unknown names (and ``None``) fall back to ``reset``.
        con: Optional condition. When it is not ``None`` it overrides
            ``color``: ``c1`` is used if ``con`` is truthy, ``c2`` otherwise.
        c1: Colour name used when ``con`` is truthy.
        c2: Colour name used when ``con`` is falsy.

    Returns:
        ``text`` prefixed with the colour code and followed by the reset code.
    """
    text = str(text)
    color_codes = {
        'ok': '\033[1;92m',
        'error': '\033[91m',
        'warning': '\033[93m',
        'success': '\033[92m',
        'status': '\033[95m',
        'special': '\033[94m',
        'log': '\033[96m',
        'reset': '\033[0m',
    }
    if con is not None:
        color = c1 if con else c2
    color_code = color_codes.get(color, color_codes['reset'])
    return f"{color_code}{text}{color_codes['reset']}"

In [None]:
# Root of the external dataset on Kaggle and the folder of each of its splits.
DATA_DIR = "/kaggle/input/chest-ctscan-images/Data"
TRAIN_DIR, TEST_DIR, VAL_DIR = (
    os.path.join(DATA_DIR, split_name) for split_name in ("train", "test", "valid")
)

In [None]:
# Folder name -> integer class id. Each cancer class is listed under both a
# short folder name and a longer, more specific folder name so that either
# spelling resolves to the same label.
LABEL_MAP = {
    "adenocarcinoma": 0,
    "adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib": 0,
    "large.cell.carcinoma": 1,
    "large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa": 1,
    "normal": 2,
    "squamous.cell.carcinoma": 3,
    "squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa": 3,
}

# Collect (image path, label) pairs from every split folder under DATA_DIR.
train_paths, train_labels = [], []
val_paths, val_labels = [], []

for split in ['train', 'test', 'valid']:
    split_path = os.path.join(DATA_DIR, split)
    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of the resulting DataFrames identical from run to run.
    for folder in sorted(os.listdir(split_path)):
        full_path = os.path.join(split_path, folder)
        if not os.path.isdir(full_path):
            continue
        label = LABEL_MAP.get(folder)
        if label is None:
            print(f"Unknown folder label: {folder}")
            continue
        for fname in sorted(os.listdir(full_path)):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(full_path, fname)
            # Only 'valid' is held out; images from both 'train' and 'test'
            # are used for training.
            if split == 'valid':
                val_paths.append(img_path)
                val_labels.append(label)
            else:
                train_paths.append(img_path)
                train_labels.append(label)

df_train = pd.DataFrame({
    "img": train_paths,
    "label": train_labels,
})
df_valid = pd.DataFrame({
    "img": val_paths,
    "label": val_labels,
})
df_train.shape, df_valid.shape

In [None]:
class ImageDatasetCSV(Dataset):
    """Image-classification dataset backed by a DataFrame of paths and labels."""

    def __init__(self, df, transform=None):
        """
        Args:
            df (pd.DataFrame): `img` and `label` columns.
            transform (callable, optional): Optional transforms to apply to the images.
        """
        # Keep a private copy and leave the caller's DataFrame untouched; the
        # label is converted to a tensor lazily in __getitem__.
        self.df = df.copy()
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is opened from the row's `img` path and converted to RGB;
        the label is returned as a 0-d ``torch.long`` tensor.
        """
        row = self.df.iloc[idx]
        image = Image.open(row['img']).convert("RGB")
        # int() also accepts numpy integers and 0-d tensors.
        label = torch.tensor(int(row['label']), dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
### augmentations on the dataset

# Training pipeline: convert to a single-channel 256x256 tensor, apply random
# augmentations, then normalise with mean 0.5 and std 0.5.
train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    # Small rotation, slight offset, slight zoom in/out and shear.
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05), shear=5),
    # Brightness/contrast jitter, applied to 80% of the samples.
    transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2)], p=0.8),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Validation pipeline: same resizing and normalisation, no random augmentation.
val_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Wrap each DataFrame in a dataset together with its transform pipeline.
train_dataset = ImageDatasetCSV(df=df_train, transform=train_transforms)
val_dataset = ImageDatasetCSV(df=df_valid, transform=val_transforms)

# Encoder training on external dataset

In [None]:
import torch.nn.functional as F
from torchvision.models import vit_b_16

class TransformerEncoder(nn.Module):
    """ViT-B/16 backbone that encodes a single-channel image batch into a latent vector."""

    def __init__(self, latent_dim=512):
        """
        Args:
            latent_dim (int): Size of the latent vector produced by ``forward``.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Untrained ViT-B/16 whose classification head is replaced by a pass-through.
        backbone = vit_b_16(weights=None)
        backbone.heads = nn.Identity()
        self.vit = backbone

        # 1x1 convolution that expands the single input channel to three channels.
        self.channel_adapter = nn.Conv2d(1, 3, kernel_size=1)
        # Projects the 768 backbone features down to the latent size.
        self.fc = nn.Linear(768, latent_dim)

    def forward(self, x):
        """Resize ``x`` to 224x224, run it through the backbone and project to the latent space."""
        resized = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        features = self.vit(self.channel_adapter(resized))
        return self.fc(features)

class TransformerClassifier(nn.Module):
    """Encoder followed by a single linear classification layer."""

    def __init__(self, encoder, latent_dim, num_classes):
        """
        Args:
            encoder (nn.Module): Module mapping an input batch to ``latent_dim`` features.
            latent_dim (int): Number of features the encoder outputs.
            num_classes (int): Number of output logits.
        """
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self.fc(self.encoder(x))

In [None]:
def get_accuracy_and_f1(y_true, y_pred):
    """Return ``(accuracy, macro F1)`` for one batch.

    Args:
        y_true: Tensor of target class indices.
        y_pred: Tensor of per-class scores; the arg-max over dim 1 is taken
            as the predicted class.
    """
    targets = y_true.cpu()
    predicted = torch.argmax(y_pred, dim=1).cpu()
    return (
        accuracy_score(targets, predicted),
        f1_score(targets, predicted, average='macro'),
    )

def debug_grad_norm(module):
    """Return the overall L2 norm of the gradients currently stored on ``module``.

    Parameters without a gradient are skipped, so the result is ``0.0`` when
    no parameter has one.
    """
    squared_sum = 0
    for param in module.parameters():
        grad = param.grad
        if grad is None:
            continue
        squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5

In [None]:
# Settings for pretraining the classifier on the external dataset.
NUM_CLASSES = 4
batch_size = 64
num_workers = 2

num_gpu = torch.cuda.device_count()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder plus linear head, moved to the selected device.
encoder = TransformerEncoder(latent_dim=512)
model = TransformerClassifier(encoder=encoder, latent_dim=512, num_classes=NUM_CLASSES).to(device)

# Replicate the model across GPUs when more than one is visible.
if num_gpu > 1:
    model = nn.DataParallel(model)
    print(f"Using {num_gpu} GPU(s)!")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Multiply the learning rate by 0.9 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

In [None]:
# Shuffled batches for training; validation batches keep the dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm

# Per-epoch metrics; every list gets one entry per completed epoch.
history = {"train_acc": [], "train_loss": [], "val_acc": [], "val_loss": []}

epochs = 100
DEBUG = False  # when True, stop after the first epoch

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    train_loss_sum, train_acc_sum, train_seen = 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for images, y in pbar:
        images, y = images.to(device), y.to(device)
        optimizer.zero_grad()
        yhat = model(images)

        loss = criterion(yhat, y)
        loss.backward()
        optimizer.step()

        acc, _ = get_accuracy_and_f1(y, yhat)
        n_samples = y.size(0)
        # criterion already returns the mean over the batch, so weight each
        # batch by its size and divide by the sample count afterwards. This
        # yields a true per-sample average even when the last batch is smaller.
        train_loss_sum += loss.item() * n_samples
        train_acc_sum += acc * n_samples
        train_seen += n_samples

        pbar.set_postfix(loss=loss.item(), acc=acc)

    scheduler.step()

    train_loss = train_loss_sum / max(train_seen, 1)
    train_acc = train_acc_sum / max(train_seen, 1)
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)

    # ---- Validation pass ----
    model.eval()
    val_loss_sum, val_acc_sum, val_seen = 0.0, 0.0, 0
    with torch.no_grad():
        for images, y in val_loader:
            images, y = images.to(device), y.to(device)
            yhat = model(images)

            loss = criterion(yhat, y)
            acc, _ = get_accuracy_and_f1(y, yhat)

            n_samples = y.size(0)
            val_loss_sum += loss.item() * n_samples
            val_acc_sum += acc * n_samples
            val_seen += n_samples

    val_loss = val_loss_sum / max(val_seen, 1)
    val_acc = val_acc_sum / max(val_seen, 1)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    # Gradient norm of the encoder. Validation runs under no_grad, so these are
    # still the gradients left by the last training step of this epoch.
    classifier_grad_norm = debug_grad_norm(model.module.encoder if isinstance(model, nn.DataParallel) else model.encoder)

    print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Grad Norm: {classifier_grad_norm:.4f}")

    if DEBUG:
        break

Save the encoder and use it in the following code.

# Training the autoencoder

### Load the dataset given in the task

The task dataset is not included here, to preserve privacy.

In [None]:
class MedicalImageDatasetBlackWhite(Dataset):
    """Unlabelled image dataset built from the image files directly inside one folder."""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Folder whose `.png`, `.jpg` and `.jpeg` files are used.
            transform (callable, optional): Optional transform applied to each image.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Match extensions case-insensitively (e.g. `.PNG`) and sort the
        # listing, because os.listdir returns entries in arbitrary order.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return it (transformed if a transform is set)."""
        img_path = self.image_paths[idx]
        image = Image.open(img_path)

        if self.transform:
            image = self.transform(image)
        return image

# Side length (in pixels) every task image is resized to.
image_size = 256
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Task images: single-channel, resized, converted to tensors and normalised
# with mean 0.5 and std 0.5. Replace the placeholder path with the real folder.
dataset = MedicalImageDatasetBlackWhite(
    root_dir='insert dataset here',
    transform=transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]),
)

### Autoencoder structure

In [None]:
class ResidualDecoderBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut around them."""

    def __init__(self, in_ch, out_ch):
        """
        Args:
            in_ch (int): Number of input channels.
            out_ch (int): Number of output channels.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # The shortcut needs a 1x1 conv only when the channel count changes.
        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        shortcut = self.skip(x)
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += shortcut
        return self.relu(hidden)

class TransformerDecoder(nn.Module):
    """Convolutional decoder that expands a latent vector into a 256x256 image."""

    def __init__(self, latent_dim=512, out_channels=1):
        """
        Args:
            latent_dim (int): Size of the input latent vector.
            out_channels (int): Number of channels in the reconstructed image.
        """
        super().__init__()
        self.latent_dim = latent_dim
        # The latent vector is projected to a 256-channel 8x8 feature map.
        self.fc = nn.Linear(latent_dim, 256 * 8 * 8)
        self.relu = nn.ReLU()

        def upsample(in_ch, out_ch):
            # Stride-2 transposed convolution: doubles height and width.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)

        stages = [
            upsample(256, 256),               # 8x8 -> 16x16
            ResidualDecoderBlock(256, 128),
            ResidualDecoderBlock(128, 64),
            upsample(64, 64),                 # 16x16 -> 32x32
            ResidualDecoderBlock(64, 32),
            upsample(32, 16),                 # 32x32 -> 64x64
            ResidualDecoderBlock(16, 16),
            upsample(16, 8),                  # 64x64 -> 128x128
            upsample(8, out_channels),        # 128x128 -> 256x256
            nn.Tanh(),  # Assuming input images are normalized between [-1, 1]
        ]
        self.decode = nn.Sequential(*stages)

    def forward(self, z):
        """Decode the latent batch ``z`` into an image batch."""
        hidden = self.relu(self.fc(z))
        feature_map = hidden.view(-1, 256, 8, 8)
        return self.decode(feature_map)

In [None]:
class AutoEncoder(nn.Module):
    """Chains an encoder and a decoder into a single reconstruction model."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        """
        Args:
            encoder: Module that maps an input batch to a latent representation.
            decoder: Module that maps that latent representation to a reconstruction.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return the decoder's reconstruction of ``x``."""
        return self.decoder(self.encoder(x))

In [None]:
import torch.optim as optim

# NOTE(review): this creates a freshly initialised encoder and rebinds the
# `encoder` name; nothing in this cell loads the weights from the
# classification pretraining above - confirm whether a saved state dict
# should be loaded here.
encoder = TransformerEncoder(latent_dim=512).to(device)
decoder = TransformerDecoder(latent_dim=512, out_channels=1).to(device)
autoencoder = AutoEncoder(encoder, decoder).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=1e-4)

# Training loop
# NOTE(review): batches come from `train_loader` (the external, augmented
# dataset), not from the task `dataset` defined earlier - confirm this is
# intended.
num_epochs = 10
for epoch in range(num_epochs):
    autoencoder.train()
    epoch_loss = 0.0
    for images, _labels in train_loader:  # Assuming train_loader yields (image, label) pairs
        optimizer.zero_grad()
        images = images.to(device)
        reconstruction = autoencoder(images)
        # Reconstruction error between the model output and its own input.
        loss = criterion(reconstruction, images)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(train_loader):.4f}")